Compare commits
41 Commits
462c44ce89
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
1b28f10ca8
|
|||
|
664e79fce6
|
|||
|
c74313c195
|
|||
|
9783ae5865
|
|||
|
62de222488
|
|||
| c9d143d84b | |||
|
d18a904ed5
|
|||
|
cb7be28f42
|
|||
|
0908b43724
|
|||
|
52310f588d
|
|||
|
b52216bd2f
|
|||
|
2bc83a17dd
|
|||
|
faf6e2abd7
|
|||
|
0a4eac188a
|
|||
|
7c90c9ed4a
|
|||
|
8a631af0d2
|
|||
|
40fda3420c
|
|||
|
c4801e3309
|
|||
|
4f10a8a422
|
|||
|
0b44d1c83f
|
|||
|
0133d956a5
|
|||
|
3c20e854aa
|
|||
|
090dbec390
|
|||
|
df860b3061
|
|||
|
9aecc7ce02
|
|||
|
94f1f1c266
|
|||
|
8fff893d25
|
|||
|
5ba62afec3
|
|||
|
058da51f86
|
|||
|
adfe372d13
|
|||
|
3163ea47dc
|
|||
|
ab07e6a8dc
|
|||
|
b8fcbc7e10
|
|||
|
aa569aac16
|
|||
|
1a407ad4c2
|
|||
|
5d0c8cc20c
|
|||
|
d226c32b9b
|
|||
|
86786c9d05
|
|||
| d78d461236 | |||
| 49425635ce | |||
| 8ff029fcb7 |
55
.claude/skills/bubbletea/SKILL.md
Normal file
55
.claude/skills/bubbletea/SKILL.md
Normal file
@@ -0,0 +1,55 @@
|
||||
---
|
||||
name: bubbletea
|
||||
description: Browse Bubbletea TUI framework documentation and examples. Use when working with Bubbletea components, models, commands, or building terminal user interfaces in Go.
|
||||
---
|
||||
|
||||
# Bubbletea Documentation
|
||||
|
||||
Bubbletea is a Go framework for building terminal user interfaces based on The Elm Architecture.
|
||||
|
||||
## Key Resources
|
||||
|
||||
When you need to understand Bubbletea patterns or find examples:
|
||||
|
||||
1. **Examples README** - Overview of all available examples:
|
||||
https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
|
||||
|
||||
2. **Examples Directory** - Full source code for all examples:
|
||||
https://github.com/charmbracelet/bubbletea/tree/main/examples
|
||||
|
||||
## How to Use
|
||||
|
||||
1. First, fetch the examples README to get an overview of available examples:
|
||||
```
|
||||
WebFetch https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
|
||||
```
|
||||
|
||||
2. Once you identify a relevant example, fetch its source code from the examples directory.
|
||||
|
||||
## Common Examples to Reference
|
||||
|
||||
- `list` - List component with filtering
|
||||
- `table` - Table component
|
||||
- `textinput` - Text input handling
|
||||
- `textarea` - Multi-line text input
|
||||
- `viewport` - Scrollable content
|
||||
- `paginator` - Pagination
|
||||
- `spinner` - Loading spinners
|
||||
- `progress` - Progress bars
|
||||
- `tabs` - Tab navigation
|
||||
- `help` - Help text/keybindings display
|
||||
|
||||
## Core Concepts
|
||||
|
||||
- **Model**: Application state
|
||||
- **Update**: Handles messages and returns updated model + commands
|
||||
- **View**: Renders the model to a string
|
||||
- **Cmd**: Side effects that produce messages
|
||||
- **Msg**: Events that trigger updates
|
||||
|
||||
## Related Charm Libraries
|
||||
|
||||
- **Bubbles**: Pre-built components (github.com/charmbracelet/bubbles)
|
||||
- **Lipgloss**: Styling and layout (github.com/charmbracelet/lipgloss)
|
||||
- **Glamour**: Markdown rendering (github.com/charmbracelet/glamour)
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -4,3 +4,5 @@ oubliette.toml
|
||||
*.db-wal
|
||||
*.db-shm
|
||||
/oubliette
|
||||
*.mmdb
|
||||
*.mmdb.gz
|
||||
|
||||
87
PLAN.md
87
PLAN.md
@@ -150,7 +150,7 @@ Goal: Add the entertaining shell implementations.
|
||||
- **Haunted:** commands gradually return stranger output, files appear/disappear, `whoami` returns different users
|
||||
- **Bread crumbs:** fake .bash_history, id_rsa files, database configs pointing to other honeypots
|
||||
|
||||
### 3.2 Cisco IOS Shell
|
||||
### 3.2 Cisco IOS Shell ✅
|
||||
- Realistic `>` and `#` prompts
|
||||
- Common commands: `show running-config`, `show interfaces`, `enable`, `configure terminal`
|
||||
- Fake device info that looks like a real router
|
||||
@@ -162,14 +162,29 @@ Goal: Add the entertaining shell implementations.
|
||||
- "WARNING: milk expires in 2 days"
|
||||
- Per-credential shell routing via `shell` field in static credentials
|
||||
|
||||
### 3.4 Text Adventure
|
||||
### 3.4 Text Adventure ✅
|
||||
- Zork-style dungeon crawler
|
||||
- "You are in a dimly lit server room."
|
||||
- Navigation, items, puzzles
|
||||
- The dungeon is the oubliette itself
|
||||
|
||||
### 3.5 Other Shell Ideas (Future)
|
||||
- **Banking TUI:** 80s-style green-on-black bank terminal
|
||||
### 3.5 Banking TUI Shell ✅
|
||||
- 80s-style green-on-black bank terminal
|
||||
|
||||
### 3.6 PostgreSQL psql Shell ✅
|
||||
- Simulates psql interactive terminal with `db_name` and `pg_version` config
|
||||
- Backslash meta-commands: `\q`, `\dt`, `\d <table>`, `\l`, `\du`, `\conninfo`, `\?`, `\h`
|
||||
- SQL statement handling with multi-line buffering (semicolon-terminated)
|
||||
- Canned responses for common queries (SELECT version(), current_database(), etc.)
|
||||
- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.)
|
||||
- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells
|
||||
|
||||
### 3.7 Roomba Shell ✅
|
||||
- iRobot Roomba j7+ vacuum robot interface
|
||||
- Status, cleaning, scheduling, diagnostics, floor map
|
||||
- Humorous history entries (cat encounters, sock tangles, sticky substances)
|
||||
|
||||
### 3.8 Other Shell Ideas (Future)
|
||||
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
|
||||
- **ELIZA therapist:** every response is a therapy question
|
||||
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
|
||||
@@ -181,19 +196,55 @@ Goal: Add the entertaining shell implementations.
|
||||
|
||||
Goal: Make the web UI great and add operational niceties.
|
||||
|
||||
### 4.1 Enhanced Web UI
|
||||
- GeoIP lookups and world map visualization of attack sources
|
||||
- Charts: attempts over time, hourly patterns, credential trends
|
||||
- Session detail view with full command log
|
||||
- Filtering and search
|
||||
### 4.1 Enhanced Web UI ✅
|
||||
- GeoIP lookups and world map visualization of attack sources ✅
|
||||
- Charts: attempts over time, hourly patterns, credential trends ✅
|
||||
- Session detail view with full command log ✅
|
||||
- Filtering and search ✅
|
||||
|
||||
### 4.2 Operational
|
||||
- Prometheus metrics endpoint
|
||||
- Structured logging (slog)
|
||||
- Graceful shutdown
|
||||
- Systemd unit file / deployment docs
|
||||
### 4.2 Operational ✅
|
||||
- Prometheus metrics endpoint ✅
|
||||
- Structured logging (slog) ✅
|
||||
- Graceful shutdown ✅
|
||||
- Docker image (nix dockerTools) ✅
|
||||
- Systemd unit file / deployment docs ✅
|
||||
|
||||
### 4.3 GeoIP
|
||||
- Embed a lightweight GeoIP database or use an API
|
||||
- Store country/city with each attempt
|
||||
- Aggregate stats by country
|
||||
### 4.3 GeoIP ✅
|
||||
- Embed a lightweight GeoIP database or use an API ✅
|
||||
- Store country/city with each attempt ✅
|
||||
- Aggregate stats by country ✅
|
||||
|
||||
### 4.4 Capture SSH Exec Commands ✅
|
||||
Many bots send a command directly via `ssh user@host <command>` (an SSH "exec" request) rather than requesting an interactive shell. Currently these are rejected and the command is lost. We should capture them.
|
||||
|
||||
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`) ✅
|
||||
- Parse the command string from the exec payload ✅
|
||||
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration ✅
|
||||
- Store the command on the session record before closing the channel ✅
|
||||
- Optionally return plausible fake output for common commands (e.g. `uname`, `id`, `cat /etc/passwd`) to encourage further interaction
|
||||
- Surface exec commands in the web UI (session detail view) ✅
|
||||
|
||||
#### 4.4.1 Fake Exec Output
|
||||
Return plausible fake output for exec commands to encourage bots to interact further.
|
||||
|
||||
**Approach: regex-based output assembly.** Bots typically send a single long command that chains recon commands and then echoes a summary (e.g. `echo "UNAME:$uname"`). Rather than interpreting arbitrary shell pipelines, we scan the command string for known patterns and assemble fake output.
|
||||
|
||||
Implementation:
|
||||
- A map of common command/variable patterns to fake output strings, e.g.:
|
||||
- `uname -a` / `uname -s -v -n -m` → `"Linux ubuntu-server 5.15.0-91-generic #101-Ubuntu SMP Tue Jan 2 15:13:10 UTC 2024 x86_64"`
|
||||
- `uname -m` / `arch` → `"x86_64"`
|
||||
- `cat /proc/uptime` → `"86432.71 172801.55"`
|
||||
- `nproc` / `grep -c "^processor" /proc/cpuinfo` → `"2"`
|
||||
- `cat /proc/cpuinfo` → fake cpuinfo block
|
||||
- `lspci` → empty (no GPU — discourages cryptominer targeting)
|
||||
- `id` → `"uid=0(root) gid=0(root) groups=0(root)"`
|
||||
- `cat /etc/passwd` → minimal fake passwd file
|
||||
- `last` → fake login entries
|
||||
- `cat --help`, `ls --help` → canned GNU coreutils help text
|
||||
- Scan the exec command for `echo "KEY:$var"` patterns; for each key, look up the corresponding fake value from the variable assignment earlier in the command
|
||||
- If we recognise echo patterns, assemble and return the expected output
|
||||
- If we don't recognise the command at all, return empty output with exit 0 (current behaviour)
|
||||
- Values should draw from the existing shell config where possible (hostname, fake_user) for consistency
|
||||
- New package `internal/execfake` or a file in `internal/server/` — keep it simple
|
||||
|
||||
Gather more real-world bot examples before implementing to ensure good coverage of common recon patterns.
|
||||
|
||||
27
README.md
27
README.md
@@ -34,7 +34,8 @@ Key settings:
|
||||
- `auth.accept_after` — accept login after N failures per IP (default `10`)
|
||||
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
|
||||
- `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell)
|
||||
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS)
|
||||
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal), `roomba` (iRobot Roomba vacuum robot), `tetris` (Tetris game TUI)
|
||||
- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority
|
||||
- `storage.db_path` — SQLite database path (default `oubliette.db`)
|
||||
- `storage.retention_days` — auto-prune records older than N days (default `90`)
|
||||
- `storage.retention_interval` — how often to run retention (default `1h`)
|
||||
@@ -43,12 +44,21 @@ Key settings:
|
||||
- `shell.fake_user` — override username in prompt; empty uses the authenticated user
|
||||
- `web.enabled` — enable the web dashboard (default `false`)
|
||||
- `web.listen_addr` — web dashboard listen address (default `:8080`)
|
||||
- Dashboard includes Chart.js charts (attempts over time, hourly pattern), an SVG world map choropleth colored by attack origin, and filter controls for date range / IP / country / username
|
||||
- Session detail pages at `/sessions/{id}` include terminal replay via xterm.js
|
||||
- `web.metrics_enabled` — expose Prometheus metrics at `/metrics` (default `true`)
|
||||
- `web.metrics_token` — bearer token to protect `/metrics`; empty means no auth (default empty)
|
||||
- `detection.enabled` — enable human detection scoring (default `false`)
|
||||
- `detection.threshold` — score threshold (0.0–1.0) for flagging sessions (default `0.6`)
|
||||
- `detection.update_interval` — how often to recompute scores (default `5s`)
|
||||
- `notify.webhooks` — list of webhook endpoints for notifications (see example config)
|
||||
|
||||
### GeoIP
|
||||
|
||||
Country-level GeoIP lookups are embedded in the binary using the [DB-IP Lite](https://db-ip.com/db/lite.php) database (CC-BY-4.0). The dashboard shows country alongside IPs and includes a "Top Countries" table.
|
||||
|
||||
For local development, run `scripts/fetch-geoip.sh` to download the MMDB file. The Nix build fetches it automatically.
|
||||
|
||||
### Run
|
||||
|
||||
```sh
|
||||
@@ -61,6 +71,9 @@ Test with:
|
||||
ssh -o StrictHostKeyChecking=no -p 2222 root@localhost
|
||||
```
|
||||
|
||||
SSH exec commands (`ssh user@host <command>`) are also captured and stored on the session record.
|
||||
|
||||
|
||||
### NixOS Module
|
||||
|
||||
Add the flake as an input and enable the service:
|
||||
@@ -82,3 +95,15 @@ Add the flake as an input and enable the service:
|
||||
```
|
||||
|
||||
Alternatively, use `configFile` to pass a pre-written TOML file instead of `settings`.
|
||||
|
||||
### Docker
|
||||
|
||||
Build a Docker image via nix:
|
||||
|
||||
```sh
|
||||
nix build .#dockerImage
|
||||
docker load < result
|
||||
docker run -v /path/to/data:/data -p 2222:2222 -p 8080:8080 oubliette:0.8.0
|
||||
```
|
||||
|
||||
Place your `oubliette.toml` in the data volume. The container exposes ports 2222 (SSH) and 8080 (web/metrics).
|
||||
|
||||
@@ -13,13 +13,14 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/server"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"git.t-juice.club/torjus/oubliette/internal/web"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/server"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/web"
|
||||
)
|
||||
|
||||
const Version = "0.4.0"
|
||||
const Version = "0.18.0"
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
@@ -65,12 +66,23 @@ func run() error {
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
// Clean up sessions left active by a previous unclean shutdown.
|
||||
if n, err := store.CloseActiveSessions(context.Background(), time.Now()); err != nil {
|
||||
return fmt.Errorf("close stale sessions: %w", err)
|
||||
} else if n > 0 {
|
||||
logger.Info("closed stale sessions from previous run", "count", n)
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||
m := metrics.New(Version)
|
||||
instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors)
|
||||
m.RegisterStoreCollector(instrumentedStore)
|
||||
|
||||
srv, err := server.New(*cfg, store, logger)
|
||||
go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||
|
||||
srv, err := server.New(*cfg, instrumentedStore, logger, m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create server: %w", err)
|
||||
}
|
||||
@@ -79,7 +91,12 @@ func run() error {
|
||||
|
||||
// Start web server if enabled.
|
||||
if cfg.Web.Enabled {
|
||||
webHandler, err := web.NewServer(store, logger.With("component", "web"))
|
||||
var metricsHandler http.Handler
|
||||
if *cfg.Web.MetricsEnabled {
|
||||
metricsHandler = m.Handler()
|
||||
}
|
||||
|
||||
webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create web server: %w", err)
|
||||
}
|
||||
|
||||
27
flake.nix
27
flake.nix
@@ -18,19 +18,44 @@
|
||||
pkgs = nixpkgs.legacyPackages.${system};
|
||||
mainGo = builtins.readFile ./cmd/oubliette/main.go;
|
||||
version = builtins.head (builtins.match ''.*const Version = "([^"]+)".*'' mainGo);
|
||||
geoipDb = pkgs.fetchurl {
|
||||
url = "https://download.db-ip.com/free/dbip-country-lite-2026-02.mmdb.gz";
|
||||
hash = "sha256-xmQZEJZ5WzE9uQww1Sdb8248l+liYw46tjbfJeu945Q=";
|
||||
};
|
||||
in
|
||||
{
|
||||
default = pkgs.buildGoModule {
|
||||
pname = "oubliette";
|
||||
inherit version;
|
||||
src = ./.;
|
||||
vendorHash = "sha256-EbJ90e4Jco7CvYYJLrewFLD5XF+Wv6TsT8RRLcj+ijU=";
|
||||
vendorHash = "sha256-/zxK6CABLYBNtuSOI8dIVgMNxKiDIcbZUS7bQR5TenA=";
|
||||
subPackages = [ "cmd/oubliette" ];
|
||||
nativeBuildInputs = [ pkgs.gzip ];
|
||||
preBuild = ''
|
||||
gunzip -c ${geoipDb} > internal/geoip/dbip-country-lite.mmdb
|
||||
'';
|
||||
meta = {
|
||||
description = "SSH honeypot";
|
||||
mainProgram = "oubliette";
|
||||
};
|
||||
};
|
||||
|
||||
dockerImage = pkgs.dockerTools.buildLayeredImage {
|
||||
name = "oubliette";
|
||||
tag = version;
|
||||
contents = [ self.packages.${system}.default ];
|
||||
config = {
|
||||
Entrypoint = [ "/bin/oubliette" ];
|
||||
Cmd = [ "-config" "/data/oubliette.toml" ];
|
||||
ExposedPorts = {
|
||||
"2222/tcp" = {};
|
||||
"8080/tcp" = {};
|
||||
};
|
||||
Volumes = {
|
||||
"/data" = {};
|
||||
};
|
||||
};
|
||||
};
|
||||
});
|
||||
|
||||
devShells = forAllSystems (system:
|
||||
|
||||
30
go.mod
30
go.mod
@@ -1,21 +1,49 @@
|
||||
module git.t-juice.club/torjus/oubliette
|
||||
module code.t-juice.club/torjus/oubliette
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.6.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/oschwald/maxminddb-golang v1.13.1
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/prometheus/client_model v0.6.2
|
||||
golang.org/x/crypto v0.48.0
|
||||
modernc.org/sqlite v1.45.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/ansi v0.10.1 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
|
||||
94
go.sum
94
go.sum
@@ -1,34 +1,116 @@
|
||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
|
||||
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
|
||||
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
|
||||
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {
|
||||
|
||||
@@ -21,15 +21,18 @@ type Config struct {
|
||||
}
|
||||
|
||||
type WebConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
ListenAddr string `toml:"listen_addr"`
|
||||
Enabled bool `toml:"enabled"`
|
||||
ListenAddr string `toml:"listen_addr"`
|
||||
MetricsEnabled *bool `toml:"metrics_enabled"`
|
||||
MetricsToken string `toml:"metrics_token"`
|
||||
}
|
||||
|
||||
type ShellConfig struct {
|
||||
Hostname string `toml:"hostname"`
|
||||
Banner string `toml:"banner"`
|
||||
FakeUser string `toml:"fake_user"`
|
||||
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||
Hostname string `toml:"hostname"`
|
||||
Banner string `toml:"banner"`
|
||||
FakeUser string `toml:"fake_user"`
|
||||
UsernameRoutes map[string]string `toml:"username_routes"`
|
||||
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
@@ -143,6 +146,10 @@ func applyDefaults(cfg *Config) {
|
||||
if cfg.Web.ListenAddr == "" {
|
||||
cfg.Web.ListenAddr = ":8080"
|
||||
}
|
||||
if cfg.Web.MetricsEnabled == nil {
|
||||
t := true
|
||||
cfg.Web.MetricsEnabled = &t
|
||||
}
|
||||
if cfg.Shell.Hostname == "" {
|
||||
cfg.Shell.Hostname = "ubuntu-server"
|
||||
}
|
||||
@@ -159,9 +166,10 @@ func applyDefaults(cfg *Config) {
|
||||
|
||||
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
|
||||
var knownShellKeys = map[string]bool{
|
||||
"hostname": true,
|
||||
"banner": true,
|
||||
"fake_user": true,
|
||||
"hostname": true,
|
||||
"banner": true,
|
||||
"fake_user": true,
|
||||
"username_routes": true,
|
||||
}
|
||||
|
||||
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.
|
||||
|
||||
@@ -282,6 +282,22 @@ password = "toor"
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMetricsToken(t *testing.T) {
|
||||
content := `
|
||||
[web]
|
||||
enabled = true
|
||||
metrics_token = "my-secret-token"
|
||||
`
|
||||
path := writeTemp(t, content)
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Web.MetricsToken != "my-secret-token" {
|
||||
t.Errorf("metrics_token = %q, want %q", cfg.Web.MetricsToken, "my-secret-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMissingFile(t *testing.T) {
|
||||
_, err := Load("/nonexistent/path/config.toml")
|
||||
if err == nil {
|
||||
@@ -297,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadUsernameRoutes(t *testing.T) {
|
||||
content := `
|
||||
[shell]
|
||||
hostname = "myhost"
|
||||
|
||||
[shell.username_routes]
|
||||
postgres = "psql"
|
||||
admin = "bash"
|
||||
|
||||
[shell.bash]
|
||||
custom_key = "value"
|
||||
`
|
||||
path := writeTemp(t, content)
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes == nil {
|
||||
t.Fatal("UsernameRoutes should not be nil")
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes["postgres"] != "psql" {
|
||||
t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql")
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes["admin"] != "bash" {
|
||||
t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash")
|
||||
}
|
||||
// username_routes should NOT appear in the Shells map.
|
||||
if _, ok := cfg.Shell.Shells["username_routes"]; ok {
|
||||
t.Error("username_routes should not appear in Shells map")
|
||||
}
|
||||
// bash should still appear in Shells map.
|
||||
if _, ok := cfg.Shell.Shells["bash"]; !ok {
|
||||
t.Error("Shells[\"bash\"] should still be present")
|
||||
}
|
||||
}
|
||||
|
||||
func writeTemp(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
path := filepath.Join(t.TempDir(), "config.toml")
|
||||
|
||||
51
internal/geoip/geoip.go
Normal file
51
internal/geoip/geoip.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package geoip
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
)
|
||||
|
||||
//go:embed dbip-country-lite.mmdb
|
||||
var mmdbData []byte
|
||||
|
||||
// Reader provides country-level GeoIP lookups using an embedded DB-IP Lite database.
|
||||
type Reader struct {
|
||||
db *maxminddb.Reader
|
||||
}
|
||||
|
||||
// New opens the embedded MMDB and returns a ready-to-use Reader.
|
||||
func New() (*Reader, error) {
|
||||
db, err := maxminddb.FromBytes(mmdbData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Reader{db: db}, nil
|
||||
}
|
||||
|
||||
type countryRecord struct {
|
||||
Country struct {
|
||||
ISOCode string `maxminddb:"iso_code"`
|
||||
} `maxminddb:"country"`
|
||||
}
|
||||
|
||||
// Lookup returns the ISO 3166-1 alpha-2 country code for the given IP address,
|
||||
// or an empty string if the lookup fails or no result is found.
|
||||
func (r *Reader) Lookup(ipStr string) string {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var record countryRecord
|
||||
if err := r.db.Lookup(ip, &record); err != nil {
|
||||
return ""
|
||||
}
|
||||
return record.Country.ISOCode
|
||||
}
|
||||
|
||||
// Close releases resources held by the reader.
|
||||
func (r *Reader) Close() error {
|
||||
return r.db.Close()
|
||||
}
|
||||
44
internal/geoip/geoip_test.go
Normal file
44
internal/geoip/geoip_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package geoip
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLookup(t *testing.T) {
|
||||
reader, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
want string
|
||||
}{
|
||||
{"8.8.8.8", "US"},
|
||||
{"1.1.1.1", "AU"},
|
||||
{"invalid", ""},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
got := reader.Lookup(tt.ip)
|
||||
if got != tt.want {
|
||||
t.Errorf("Lookup(%q) = %q, want %q", tt.ip, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupPrivateIP(t *testing.T) {
|
||||
reader, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Private IPs should return empty string (no country).
|
||||
got := reader.Lookup("10.0.0.1")
|
||||
if got != "" {
|
||||
t.Errorf("Lookup(10.0.0.1) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
178
internal/metrics/metrics.go
Normal file
178
internal/metrics/metrics.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// Metrics holds all Prometheus collectors for the honeypot.
|
||||
type Metrics struct {
|
||||
registry *prometheus.Registry
|
||||
|
||||
SSHConnectionsTotal *prometheus.CounterVec
|
||||
SSHConnectionsActive prometheus.Gauge
|
||||
AuthAttemptsTotal *prometheus.CounterVec
|
||||
AuthAttemptsByCountry *prometheus.CounterVec
|
||||
CommandsExecuted *prometheus.CounterVec
|
||||
HumanScore prometheus.Histogram
|
||||
SessionsTotal *prometheus.CounterVec
|
||||
SessionsActive prometheus.Gauge
|
||||
SessionDuration prometheus.Histogram
|
||||
ExecCommandsTotal prometheus.Counter
|
||||
BuildInfo *prometheus.GaugeVec
|
||||
StorageQueryDuration *prometheus.HistogramVec
|
||||
StorageQueryErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// New creates a new Metrics instance with all collectors registered.
|
||||
func New(version string) *Metrics {
|
||||
reg := prometheus.NewRegistry()
|
||||
|
||||
m := &Metrics{
|
||||
registry: reg,
|
||||
SSHConnectionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_ssh_connections_total",
|
||||
Help: "Total SSH connections received.",
|
||||
}, []string{"outcome"}),
|
||||
SSHConnectionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "oubliette_ssh_connections_active",
|
||||
Help: "Current active SSH connections.",
|
||||
}),
|
||||
AuthAttemptsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_auth_attempts_total",
|
||||
Help: "Total authentication attempts.",
|
||||
}, []string{"result", "reason"}),
|
||||
AuthAttemptsByCountry: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_auth_attempts_by_country_total",
|
||||
Help: "Total authentication attempts by country.",
|
||||
}, []string{"country"}),
|
||||
CommandsExecuted: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_commands_executed_total",
|
||||
Help: "Total commands executed in shells.",
|
||||
}, []string{"shell"}),
|
||||
HumanScore: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Name: "oubliette_human_score",
|
||||
Help: "Distribution of final human detection scores.",
|
||||
Buckets: prometheus.LinearBuckets(0, 0.1, 11), // 0.0, 0.1, ..., 1.0
|
||||
}),
|
||||
SessionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_sessions_total",
|
||||
Help: "Total sessions created.",
|
||||
}, []string{"shell"}),
|
||||
SessionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "oubliette_sessions_active",
|
||||
Help: "Current active sessions.",
|
||||
}),
|
||||
SessionDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Name: "oubliette_session_duration_seconds",
|
||||
Help: "Session duration in seconds.",
|
||||
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600, 1800, 3600},
|
||||
}),
|
||||
ExecCommandsTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "oubliette_exec_commands_total",
|
||||
Help: "Total SSH exec commands received.",
|
||||
}),
|
||||
BuildInfo: prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Name: "oubliette_build_info",
|
||||
Help: "Build information. Always 1.",
|
||||
}, []string{"version"}),
|
||||
StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "oubliette_storage_query_duration_seconds",
|
||||
Help: "Duration of storage query calls in seconds.",
|
||||
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||
}, []string{"method"}),
|
||||
StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_storage_query_errors_total",
|
||||
Help: "Total storage query errors.",
|
||||
}, []string{"method"}),
|
||||
}
|
||||
|
||||
reg.MustRegister(
|
||||
collectors.NewGoCollector(),
|
||||
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
|
||||
m.SSHConnectionsTotal,
|
||||
m.SSHConnectionsActive,
|
||||
m.AuthAttemptsTotal,
|
||||
m.AuthAttemptsByCountry,
|
||||
m.CommandsExecuted,
|
||||
m.HumanScore,
|
||||
m.SessionsTotal,
|
||||
m.SessionsActive,
|
||||
m.SessionDuration,
|
||||
m.ExecCommandsTotal,
|
||||
m.BuildInfo,
|
||||
m.StorageQueryDuration,
|
||||
m.StorageQueryErrors,
|
||||
)
|
||||
|
||||
m.BuildInfo.WithLabelValues(version).Set(1)
|
||||
|
||||
// Initialize label combinations so they appear in Gather/output.
|
||||
for _, outcome := range []string{"accepted", "rejected_handshake", "rejected_max_connections"} {
|
||||
m.SSHConnectionsTotal.WithLabelValues(outcome)
|
||||
}
|
||||
for _, reason := range []string{"static_credential", "remembered_credential", "threshold_reached", "rejected"} {
|
||||
m.AuthAttemptsTotal.WithLabelValues("accepted", reason)
|
||||
m.AuthAttemptsTotal.WithLabelValues("rejected", reason)
|
||||
}
|
||||
for _, sh := range []string{"bash", "fridge", "banking", "adventure", "cisco"} {
|
||||
m.SessionsTotal.WithLabelValues(sh)
|
||||
m.CommandsExecuted.WithLabelValues(sh)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// RegisterStoreCollector registers a collector that queries storage stats on each scrape.
|
||||
func (m *Metrics) RegisterStoreCollector(store storage.Store) {
|
||||
m.registry.MustRegister(&storeCollector{store: store})
|
||||
}
|
||||
|
||||
// Handler returns an http.Handler that serves Prometheus metrics.
|
||||
func (m *Metrics) Handler() http.Handler {
|
||||
return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{})
|
||||
}
|
||||
|
||||
// storeCollector implements prometheus.Collector, querying storage on each scrape.
|
||||
type storeCollector struct {
|
||||
store storage.Store
|
||||
}
|
||||
|
||||
var (
|
||||
storageLoginAttemptsDesc = prometheus.NewDesc(
|
||||
"oubliette_storage_login_attempts_total",
|
||||
"Total login attempts in storage.",
|
||||
nil, nil,
|
||||
)
|
||||
storageUniqueIPsDesc = prometheus.NewDesc(
|
||||
"oubliette_storage_unique_ips",
|
||||
"Unique IPs in storage.",
|
||||
nil, nil,
|
||||
)
|
||||
storageSessionsDesc = prometheus.NewDesc(
|
||||
"oubliette_storage_sessions_total",
|
||||
"Total sessions in storage.",
|
||||
nil, nil,
|
||||
)
|
||||
)
|
||||
|
||||
func (c *storeCollector) Describe(ch chan<- *prometheus.Desc) {
|
||||
ch <- storageLoginAttemptsDesc
|
||||
ch <- storageUniqueIPsDesc
|
||||
ch <- storageSessionsDesc
|
||||
}
|
||||
|
||||
func (c *storeCollector) Collect(ch chan<- prometheus.Metric) {
|
||||
stats, err := c.store.GetDashboardStats(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ch <- prometheus.MustNewConstMetric(storageLoginAttemptsDesc, prometheus.GaugeValue, float64(stats.TotalAttempts))
|
||||
ch <- prometheus.MustNewConstMetric(storageUniqueIPsDesc, prometheus.GaugeValue, float64(stats.UniqueIPs))
|
||||
ch <- prometheus.MustNewConstMetric(storageSessionsDesc, prometheus.GaugeValue, float64(stats.TotalSessions))
|
||||
}
|
||||
142
internal/metrics/metrics_test.go
Normal file
142
internal/metrics/metrics_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
m := New("1.2.3")
|
||||
|
||||
// Gather all metrics and check expected names exist.
|
||||
families, err := m.registry.Gather()
|
||||
if err != nil {
|
||||
t.Fatalf("gather: %v", err)
|
||||
}
|
||||
|
||||
want := map[string]bool{
|
||||
"oubliette_ssh_connections_total": false,
|
||||
"oubliette_ssh_connections_active": false,
|
||||
"oubliette_auth_attempts_total": false,
|
||||
"oubliette_commands_executed_total": false,
|
||||
"oubliette_human_score": false,
|
||||
"oubliette_sessions_total": false,
|
||||
"oubliette_sessions_active": false,
|
||||
"oubliette_session_duration_seconds": false,
|
||||
"oubliette_build_info": false,
|
||||
}
|
||||
|
||||
for _, f := range families {
|
||||
if _, ok := want[f.GetName()]; ok {
|
||||
want[f.GetName()] = true
|
||||
}
|
||||
}
|
||||
|
||||
for name, found := range want {
|
||||
if !found {
|
||||
t.Errorf("metric %q not registered", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAttemptsByCountry(t *testing.T) {
|
||||
m := New("1.0.0")
|
||||
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
|
||||
m.AuthAttemptsByCountry.WithLabelValues("DE").Inc()
|
||||
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
|
||||
|
||||
families, err := m.registry.Gather()
|
||||
if err != nil {
|
||||
t.Fatalf("gather: %v", err)
|
||||
}
|
||||
|
||||
var found bool
|
||||
for _, f := range families {
|
||||
if f.GetName() == "oubliette_auth_attempts_by_country_total" {
|
||||
found = true
|
||||
if len(f.GetMetric()) != 2 {
|
||||
t.Errorf("expected 2 label pairs (US, DE), got %d", len(f.GetMetric()))
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("oubliette_auth_attempts_by_country_total not found after incrementing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
m := New("1.2.3")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
m.Handler().ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(w.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), `oubliette_build_info{version="1.2.3"} 1`) {
|
||||
t.Errorf("response should contain build_info metric, got:\n%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCollector(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed some data.
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
m := New("test")
|
||||
m.RegisterStoreCollector(store)
|
||||
|
||||
families, err := m.registry.Gather()
|
||||
if err != nil {
|
||||
t.Fatalf("gather: %v", err)
|
||||
}
|
||||
|
||||
wantMetrics := map[string]float64{
|
||||
"oubliette_storage_login_attempts_total": 2,
|
||||
"oubliette_storage_unique_ips": 2,
|
||||
"oubliette_storage_sessions_total": 1,
|
||||
}
|
||||
|
||||
for _, f := range families {
|
||||
expected, ok := wantMetrics[f.GetName()]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if len(f.GetMetric()) == 0 {
|
||||
t.Errorf("metric %q has no samples", f.GetName())
|
||||
continue
|
||||
}
|
||||
got := f.GetMetric()[0].GetGauge().GetValue()
|
||||
if got != expected {
|
||||
t.Errorf("metric %q = %f, want %f", f.GetName(), got, expected)
|
||||
}
|
||||
delete(wantMetrics, f.GetName())
|
||||
}
|
||||
|
||||
for name := range wantMetrics {
|
||||
t.Errorf("metric %q not found in gather output", name)
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
// Event types.
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
func testSession() SessionInfo {
|
||||
|
||||
@@ -12,14 +12,22 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/detection"
|
||||
"git.t-juice.club/torjus/oubliette/internal/notify"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/detection"
|
||||
"code.t-juice.club/torjus/oubliette/internal/geoip"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/notify"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/adventure"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/banking"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
||||
psqlshell "code.t-juice.club/torjus/oubliette/internal/shell/psql"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/roomba"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/tetris"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
@@ -32,9 +40,11 @@ type Server struct {
|
||||
connSem chan struct{} // semaphore limiting concurrent connections
|
||||
shellRegistry *shell.Registry
|
||||
notifier notify.Sender
|
||||
metrics *metrics.Metrics
|
||||
geoip *geoip.Reader
|
||||
}
|
||||
|
||||
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
|
||||
func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics.Metrics) (*Server, error) {
|
||||
registry := shell.NewRegistry()
|
||||
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering bash shell: %w", err)
|
||||
@@ -42,6 +52,29 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
|
||||
if err := registry.Register(fridge.NewFridgeShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering fridge shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(banking.NewBankingShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering banking shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(adventure.NewAdventureShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering adventure shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering cisco shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering psql shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(roomba.NewRoombaShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering roomba shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(tetris.NewTetrisShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering tetris shell: %w", err)
|
||||
}
|
||||
|
||||
geo, err := geoip.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening geoip database: %w", err)
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
@@ -51,6 +84,8 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
|
||||
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
||||
shellRegistry: registry,
|
||||
notifier: notify.NewSender(cfg.Notify.Webhooks, logger),
|
||||
metrics: m,
|
||||
geoip: geo,
|
||||
}
|
||||
|
||||
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
||||
@@ -68,6 +103,8 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(ctx context.Context) error {
|
||||
defer s.geoip.Close()
|
||||
|
||||
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen: %w", err)
|
||||
@@ -94,11 +131,16 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
|
||||
// Enforce max concurrent connections.
|
||||
select {
|
||||
case s.connSem <- struct{}{}:
|
||||
s.metrics.SSHConnectionsActive.Inc()
|
||||
go func() {
|
||||
defer func() { <-s.connSem }()
|
||||
defer func() {
|
||||
<-s.connSem
|
||||
s.metrics.SSHConnectionsActive.Dec()
|
||||
}()
|
||||
s.handleConn(conn)
|
||||
}()
|
||||
default:
|
||||
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_max_connections").Inc()
|
||||
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
}
|
||||
@@ -110,11 +152,13 @@ func (s *Server) handleConn(conn net.Conn) {
|
||||
|
||||
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
||||
if err != nil {
|
||||
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_handshake").Inc()
|
||||
s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
|
||||
return
|
||||
}
|
||||
defer sshConn.Close()
|
||||
|
||||
s.metrics.SSHConnectionsTotal.WithLabelValues("accepted").Inc()
|
||||
s.logger.Info("SSH connection established",
|
||||
"remote_addr", sshConn.RemoteAddr(),
|
||||
"user", sshConn.User(),
|
||||
@@ -153,6 +197,18 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
|
||||
}
|
||||
}
|
||||
// Second priority: username-based route.
|
||||
if selectedShell == nil {
|
||||
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
|
||||
sh, found := s.shellRegistry.Get(shellName)
|
||||
if found {
|
||||
selectedShell = sh
|
||||
} else {
|
||||
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
|
||||
}
|
||||
}
|
||||
}
|
||||
// Lowest priority: random selection.
|
||||
if selectedShell == nil {
|
||||
var err error
|
||||
selectedShell, err = s.shellRegistry.Select()
|
||||
@@ -163,11 +219,17 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
}
|
||||
|
||||
ip := extractIP(conn.RemoteAddr())
|
||||
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name())
|
||||
country := s.geoip.Lookup(ip)
|
||||
sessionStart := time.Now()
|
||||
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name(), country)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create session", "err", err)
|
||||
} else {
|
||||
s.metrics.SessionsTotal.WithLabelValues(selectedShell.Name()).Inc()
|
||||
s.metrics.SessionsActive.Inc()
|
||||
defer func() {
|
||||
s.metrics.SessionsActive.Dec()
|
||||
s.metrics.SessionDuration.Observe(time.Since(sessionStart).Seconds())
|
||||
if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
|
||||
s.logger.Error("failed to end session", "err", err)
|
||||
}
|
||||
@@ -193,14 +255,24 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
|
||||
defer s.notifier.CleanupSession(sessionID)
|
||||
|
||||
// Handle session requests (pty-req, shell, etc.)
|
||||
// Handle session requests (pty-req, shell, exec, etc.)
|
||||
execCh := make(chan string, 1)
|
||||
go func() {
|
||||
defer close(execCh)
|
||||
for req := range requests {
|
||||
switch req.Type {
|
||||
case "pty-req", "shell":
|
||||
if req.WantReply {
|
||||
req.Reply(true, nil)
|
||||
}
|
||||
case "exec":
|
||||
if req.WantReply {
|
||||
req.Reply(true, nil)
|
||||
}
|
||||
var payload struct{ Command string }
|
||||
if err := ssh.Unmarshal(req.Payload, &payload); err == nil {
|
||||
execCh <- payload.Command
|
||||
}
|
||||
default:
|
||||
if req.WantReply {
|
||||
req.Reply(false, nil)
|
||||
@@ -209,6 +281,29 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
}
|
||||
}()
|
||||
|
||||
// Check for exec request before proceeding to interactive shell.
|
||||
select {
|
||||
case cmd, ok := <-execCh:
|
||||
if ok && cmd != "" {
|
||||
s.logger.Info("exec command received",
|
||||
"remote_addr", conn.RemoteAddr(),
|
||||
"user", conn.User(),
|
||||
"session_id", sessionID,
|
||||
"command", cmd,
|
||||
)
|
||||
if err := s.store.SetExecCommand(context.Background(), sessionID, cmd); err != nil {
|
||||
s.logger.Error("failed to set exec command", "err", err, "session_id", sessionID)
|
||||
}
|
||||
s.metrics.ExecCommandsTotal.Inc()
|
||||
// Send exit-status 0 and close channel.
|
||||
exitPayload := make([]byte, 4) // uint32(0)
|
||||
_, _ = channel.SendRequest("exit-status", false, exitPayload)
|
||||
return
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
// No exec request within timeout — proceed with interactive shell.
|
||||
}
|
||||
|
||||
// Build session context.
|
||||
var shellCfg map[string]any
|
||||
if s.cfg.Shell.Shells != nil {
|
||||
@@ -226,6 +321,9 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
Banner: s.cfg.Shell.Banner,
|
||||
FakeUser: s.cfg.Shell.FakeUser,
|
||||
},
|
||||
OnCommand: func(sh string) {
|
||||
s.metrics.CommandsExecuted.WithLabelValues(sh).Inc()
|
||||
},
|
||||
}
|
||||
|
||||
// Wrap channel in RecordingChannel.
|
||||
@@ -261,6 +359,7 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
}
|
||||
if scorer != nil {
|
||||
finalScore := scorer.Score()
|
||||
s.metrics.HumanScore.Observe(finalScore)
|
||||
if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil {
|
||||
s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID)
|
||||
}
|
||||
@@ -310,6 +409,12 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
|
||||
ip := extractIP(conn.RemoteAddr())
|
||||
d := s.authenticator.Authenticate(ip, conn.User(), string(password))
|
||||
|
||||
if d.Accepted {
|
||||
s.metrics.AuthAttemptsTotal.WithLabelValues("accepted", d.Reason).Inc()
|
||||
} else {
|
||||
s.metrics.AuthAttemptsTotal.WithLabelValues("rejected", d.Reason).Inc()
|
||||
}
|
||||
|
||||
s.logger.Info("auth attempt",
|
||||
"remote_addr", conn.RemoteAddr(),
|
||||
"username", conn.User(),
|
||||
@@ -317,7 +422,11 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
|
||||
"reason", d.Reason,
|
||||
)
|
||||
|
||||
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip); err != nil {
|
||||
country := s.geoip.Lookup(ip)
|
||||
if country != "" {
|
||||
s.metrics.AuthAttemptsByCountry.WithLabelValues(country).Inc()
|
||||
}
|
||||
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip, country); err != nil {
|
||||
s.logger.Error("failed to record login attempt", "err", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -11,8 +11,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
@@ -120,7 +122,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := New(cfg, store, logger)
|
||||
srv, err := New(cfg, store, logger, metrics.New("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("creating server: %v", err)
|
||||
}
|
||||
@@ -251,6 +253,137 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Test exec command capture.
|
||||
t.Run("exec_command", func(t *testing.T) {
|
||||
clientCfg := &ssh.ClientConfig{
|
||||
User: "root",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("toor")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", addr, clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SSH dial: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("new session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
// Run a command via exec (no PTY, no shell).
|
||||
if err := session.Run("uname -a"); err != nil {
|
||||
// Run returns an error because the server closes the channel,
|
||||
// but that's expected.
|
||||
_ = err
|
||||
}
|
||||
|
||||
// Give the server a moment to store the command.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify the exec command was captured.
|
||||
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
var foundExec bool
|
||||
for _, s := range sessions {
|
||||
if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
|
||||
foundExec = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundExec {
|
||||
t.Error("expected a session with exec_command='uname -a'")
|
||||
}
|
||||
})
|
||||
|
||||
// Test username route: add username_routes so that "postgres" gets psql shell.
|
||||
t.Run("username_route", func(t *testing.T) {
|
||||
// Reconfigure with username routes.
|
||||
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
|
||||
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
|
||||
|
||||
// Need to get the "postgres" user in via static creds or threshold.
|
||||
// Use static creds for simplicity.
|
||||
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
|
||||
config.Credential{Username: "postgres", Password: "postgres"},
|
||||
)
|
||||
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||
defer func() {
|
||||
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
|
||||
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||
}()
|
||||
|
||||
clientCfg := &ssh.ClientConfig{
|
||||
User: "postgres",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", addr, clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SSH dial: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("new session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
||||
t.Fatalf("request pty: %v", err)
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("stdin pipe: %v", err)
|
||||
}
|
||||
|
||||
var output bytes.Buffer
|
||||
session.Stdout = &output
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatalf("shell: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the psql banner.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Send \q to quit.
|
||||
stdin.Write([]byte(`\q` + "\r"))
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
session.Wait()
|
||||
|
||||
out := output.String()
|
||||
if !strings.Contains(out, "psql") {
|
||||
t.Errorf("output should contain psql banner, got: %s", out)
|
||||
}
|
||||
|
||||
// Verify session was created with shell name "psql".
|
||||
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
var foundPsql bool
|
||||
for _, s := range sessions {
|
||||
if s.ShellName == "psql" && s.Username == "postgres" {
|
||||
foundPsql = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundPsql {
|
||||
t.Error("expected a session with shell_name='psql' for user 'postgres'")
|
||||
}
|
||||
})
|
||||
|
||||
// Test threshold acceptance: after enough failed dials, a subsequent
|
||||
// dial with the same credentials should succeed via threshold or
|
||||
// remembered credential.
|
||||
|
||||
117
internal/shell/adventure/adventure.go
Normal file
117
internal/shell/adventure/adventure.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package adventure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
// AdventureShell implements a Zork-style text adventure set in a dungeon/data center.
|
||||
type AdventureShell struct{}
|
||||
|
||||
// NewAdventureShell returns a new AdventureShell instance.
|
||||
func NewAdventureShell() *AdventureShell {
|
||||
return &AdventureShell{}
|
||||
}
|
||||
|
||||
func (a *AdventureShell) Name() string { return "adventure" }
|
||||
func (a *AdventureShell) Description() string { return "Zork-style text adventure dungeon crawler" }
|
||||
|
||||
func (a *AdventureShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
dungeonName := configString(sess.ShellConfig, "dungeon_name", "THE OUBLIETTE")
|
||||
game := newGame()
|
||||
|
||||
// Print banner and initial room.
|
||||
banner := strings.ReplaceAll(adventureBanner(dungeonName), "\n", "\r\n")
|
||||
fmt.Fprint(rw, banner)
|
||||
|
||||
// Show starting room.
|
||||
startDesc := game.describeRoom(game.rooms[game.currentRoom])
|
||||
startDesc = strings.ReplaceAll(startDesc, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n\r\n", startDesc)
|
||||
|
||||
for {
|
||||
if _, err := fmt.Fprint(rw, "> "); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
fmt.Fprint(rw, "\r\n")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
result := game.dispatch(trimmed)
|
||||
|
||||
var output string
|
||||
if result.output != "" {
|
||||
output = result.output
|
||||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
// Log command and output to store.
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("adventure")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func adventureBanner(dungeonName string) string {
|
||||
return fmt.Sprintf(`
|
||||
___ _ _ ____ _ ___ _____ _____ _____
|
||||
/ _ \| | | | __ )| | |_ _| ____|_ _|_ _| ___
|
||||
| | | | | | | _ \| | | || _| | | | | / _ \
|
||||
| |_| | |_| | |_) | |___ | || |___ | | | | | __/
|
||||
\___/ \___/|____/|_____|___|_____| |_| |_| \___|
|
||||
|
||||
Welcome to %s.
|
||||
|
||||
You wake up in the dark. The air is cold and hums with electricity.
|
||||
This is a place where things are put to be forgotten.
|
||||
|
||||
Type 'help' for commands. Type 'look' to examine your surroundings.
|
||||
|
||||
`, dungeonName)
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
383
internal/shell/adventure/adventure_test.go
Normal file
383
internal/shell/adventure/adventure_test.go
Normal file
@@ -0,0 +1,383 @@
|
||||
package adventure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (r *rwCloser) Close() error { return nil }
|
||||
|
||||
func runShell(t *testing.T, commands string) string {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "root",
|
||||
Store: store,
|
||||
CommonConfig: shell.ShellCommonConfig{
|
||||
Hostname: "testhost",
|
||||
},
|
||||
}
|
||||
|
||||
rw := &rwCloser{
|
||||
Reader: bytes.NewBufferString(commands),
|
||||
Writer: &bytes.Buffer{},
|
||||
}
|
||||
|
||||
sh := NewAdventureShell()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := sh.Handle(ctx, sess, rw); err != nil {
|
||||
t.Fatalf("Handle: %v", err)
|
||||
}
|
||||
|
||||
return rw.Writer.(*bytes.Buffer).String()
|
||||
}
|
||||
|
||||
func TestAdventureShellName(t *testing.T) {
|
||||
sh := NewAdventureShell()
|
||||
if sh.Name() != "adventure" {
|
||||
t.Errorf("Name() = %q, want %q", sh.Name(), "adventure")
|
||||
}
|
||||
if sh.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanner(t *testing.T) {
|
||||
output := runShell(t, "quit\r")
|
||||
if !strings.Contains(output, "OUBLIETTE") {
|
||||
t.Error("output should contain OUBLIETTE in banner")
|
||||
}
|
||||
if !strings.Contains(output, ">") {
|
||||
t.Error("output should contain > prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartingRoom(t *testing.T) {
|
||||
output := runShell(t, "quit\r")
|
||||
if !strings.Contains(output, "The Oubliette") {
|
||||
t.Error("should start in The Oubliette")
|
||||
}
|
||||
if !strings.Contains(output, "narrow stone chamber") {
|
||||
t.Error("should show starting room description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelpCommand(t *testing.T) {
|
||||
output := runShell(t, "help\rquit\r")
|
||||
for _, keyword := range []string{"look", "go", "take", "drop", "use", "inventory", "help", "quit"} {
|
||||
if !strings.Contains(output, keyword) {
|
||||
t.Errorf("help output should mention %q", keyword)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookCommand(t *testing.T) {
|
||||
output := runShell(t, "look\rquit\r")
|
||||
if !strings.Contains(output, "The Oubliette") {
|
||||
t.Error("look should show current room")
|
||||
}
|
||||
if !strings.Contains(output, "Exits:") {
|
||||
t.Error("look should show exits")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMovement(t *testing.T) {
|
||||
output := runShell(t, "go east\rquit\r")
|
||||
if !strings.Contains(output, "Stone Corridor") {
|
||||
t.Error("going east from oubliette should reach Stone Corridor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBareDirection(t *testing.T) {
|
||||
output := runShell(t, "e\rquit\r")
|
||||
if !strings.Contains(output, "Stone Corridor") {
|
||||
t.Error("bare 'e' should move east to Stone Corridor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectionAliases(t *testing.T) {
|
||||
output := runShell(t, "east\rquit\r")
|
||||
if !strings.Contains(output, "Stone Corridor") {
|
||||
t.Error("'east' should move to Stone Corridor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidDirection(t *testing.T) {
|
||||
output := runShell(t, "go north\rquit\r")
|
||||
if !strings.Contains(output, "can't go") {
|
||||
t.Error("should say you can't go that direction")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTakeItem(t *testing.T) {
|
||||
// Move to corridor where flashlight is.
|
||||
output := runShell(t, "e\rtake flashlight\rinventory\rquit\r")
|
||||
if !strings.Contains(output, "Taken") {
|
||||
t.Error("should confirm taking item")
|
||||
}
|
||||
if !strings.Contains(output, "flashlight") {
|
||||
t.Error("inventory should show flashlight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDropItem(t *testing.T) {
|
||||
output := runShell(t, "e\rtake flashlight\rdrop flashlight\rinventory\rquit\r")
|
||||
if !strings.Contains(output, "Dropped") {
|
||||
t.Error("should confirm dropping item")
|
||||
}
|
||||
if !strings.Contains(output, "not carrying anything") {
|
||||
t.Error("inventory should be empty after dropping")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyInventory(t *testing.T) {
|
||||
output := runShell(t, "inventory\rquit\r")
|
||||
if !strings.Contains(output, "not carrying anything") {
|
||||
t.Error("should say not carrying anything")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExamineItem(t *testing.T) {
|
||||
output := runShell(t, "e\rexamine flashlight\rquit\r")
|
||||
if !strings.Contains(output, "batteries") {
|
||||
t.Error("examining flashlight should describe it")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDarkRoom(t *testing.T) {
|
||||
// Go to the pit without flashlight.
|
||||
output := runShell(t, "down\rquit\r")
|
||||
if !strings.Contains(output, "darkness") {
|
||||
t.Error("pit should be dark without flashlight")
|
||||
}
|
||||
// Should NOT show the lit description.
|
||||
if strings.Contains(output, "skeleton") {
|
||||
t.Error("should not see skeleton without flashlight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDarkRoomWithFlashlight(t *testing.T) {
|
||||
// Get flashlight first, then go to pit.
|
||||
output := runShell(t, "e\rtake flashlight\rw\rdown\rquit\r")
|
||||
if !strings.Contains(output, "skeleton") {
|
||||
t.Error("should see skeleton with flashlight in pit")
|
||||
}
|
||||
if !strings.Contains(output, "rusty key") {
|
||||
t.Error("should see rusty key with flashlight in pit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDarkRoomCantTake(t *testing.T) {
|
||||
output := runShell(t, "down\rtake rusty_key\rquit\r")
|
||||
if !strings.Contains(output, "too dark") {
|
||||
t.Error("should not be able to take items in dark room")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockedDoor(t *testing.T) {
|
||||
// Navigate to archive without keycard.
|
||||
output := runShell(t, "e\rs\rs\re\rquit\r")
|
||||
if !strings.Contains(output, "keycard") || !strings.Contains(output, "red") {
|
||||
t.Error("should mention keycard reader when trying locked door")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlockWithKeycard(t *testing.T) {
|
||||
// Get keycard from server room, navigate to archive, use keycard.
|
||||
output := runShell(t, "e\re\rtake keycard\rw\rs\rs\ruse keycard\re\rquit\r")
|
||||
if !strings.Contains(output, "green") || !strings.Contains(output, "clicks open") {
|
||||
t.Error("should show unlock message")
|
||||
}
|
||||
if !strings.Contains(output, "Control Room") {
|
||||
t.Error("should be able to enter control room after unlocking")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRustyKeyPuzzle(t *testing.T) {
|
||||
// Get flashlight, get rusty key from pit, go to generator room, use key.
|
||||
output := runShell(t, "e\rtake flashlight\rw\rdown\rtake rusty_key\rup\re\re\rs\re\ruse rusty_key\rquit\r")
|
||||
if !strings.Contains(output, "maintenance panel") || !strings.Contains(output, "logbook") {
|
||||
t.Error("should show maintenance log when using rusty key in generator room")
|
||||
}
|
||||
if !strings.Contains(output, "Dunwich") {
|
||||
t.Error("logbook should mention Dunwich")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitEndsGame(t *testing.T) {
|
||||
// Full path to exit: get keycard, navigate to archive, unlock, go to control room, east to exit.
|
||||
output := runShell(t, "e\re\rtake keycard\rw\rs\rs\ruse keycard\re\re\r")
|
||||
if !strings.Contains(output, "SECURITY AUDIT") {
|
||||
t.Error("exit should show the ending message")
|
||||
}
|
||||
if !strings.Contains(output, "SESSION HAS BEEN LOGGED") {
|
||||
t.Error("exit should mention session logging")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownCommand(t *testing.T) {
|
||||
output := runShell(t, "xyzzy\rquit\r")
|
||||
if !strings.Contains(output, "don't understand") {
|
||||
t.Error("should show error for unknown command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuitCommand(t *testing.T) {
|
||||
output := runShell(t, "quit\r")
|
||||
if !strings.Contains(output, "terminated") {
|
||||
t.Error("quit should show termination message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitAlias(t *testing.T) {
|
||||
output := runShell(t, "exit\r")
|
||||
if !strings.Contains(output, "terminated") {
|
||||
t.Error("exit alias should work like quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVentilationShaft(t *testing.T) {
|
||||
output := runShell(t, "e\rup\rquit\r")
|
||||
if !strings.Contains(output, "Ventilation Shaft") {
|
||||
t.Error("should reach ventilation shaft from corridor")
|
||||
}
|
||||
if !strings.Contains(output, "note") {
|
||||
t.Error("should see note in ventilation shaft")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNote(t *testing.T) {
|
||||
output := runShell(t, "e\rup\rtake note\rexamine note\rquit\r")
|
||||
if !strings.Contains(output, "GET OUT") {
|
||||
t.Error("note should contain warning text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseItemWrongRoom(t *testing.T) {
|
||||
// Get keycard from server room, try to use in wrong room.
|
||||
output := runShell(t, "e\re\rtake keycard\ruse keycard\rquit\r")
|
||||
if !strings.Contains(output, "nothing to use") {
|
||||
t.Error("should say nothing to use keycard on in wrong room")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEthernetCable(t *testing.T) {
|
||||
output := runShell(t, "e\rs\rtake ethernet_cable\ruse ethernet_cable\rquit\r")
|
||||
if !strings.Contains(output, "chewed") {
|
||||
t.Error("using ethernet cable should mention chewed ends")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLogs(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "root",
|
||||
Store: store,
|
||||
CommonConfig: shell.ShellCommonConfig{
|
||||
Hostname: "testhost",
|
||||
},
|
||||
}
|
||||
|
||||
rw := &rwCloser{
|
||||
Reader: bytes.NewBufferString("help\rquit\r"),
|
||||
Writer: &bytes.Buffer{},
|
||||
}
|
||||
|
||||
sh := NewAdventureShell()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
sh.Handle(ctx, sess, rw)
|
||||
|
||||
if len(store.SessionLogs) < 2 {
|
||||
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserBasics(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
verb string
|
||||
object string
|
||||
}{
|
||||
{"look", "look", ""},
|
||||
{"l", "look", ""},
|
||||
{"examine flashlight", "look", "flashlight"},
|
||||
{"go north", "go", "north"},
|
||||
{"n", "go", "north"},
|
||||
{"south", "go", "south"},
|
||||
{"take the key", "take", "key"},
|
||||
{"get a flashlight", "take", "flashlight"},
|
||||
{"drop rusty key", "drop", "rusty key"},
|
||||
{"use keycard", "use", "keycard"},
|
||||
{"inventory", "inventory", ""},
|
||||
{"i", "inventory", ""},
|
||||
{"inv", "inventory", ""},
|
||||
{"help", "help", ""},
|
||||
{"?", "help", ""},
|
||||
{"quit", "quit", ""},
|
||||
{"exit", "quit", ""},
|
||||
{"q", "quit", ""},
|
||||
{"LOOK", "look", ""},
|
||||
{" go east ", "go", "east"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
cmd := parseCommand(tt.input)
|
||||
if cmd.verb != tt.verb {
|
||||
t.Errorf("parseCommand(%q).verb = %q, want %q", tt.input, cmd.verb, tt.verb)
|
||||
}
|
||||
if cmd.object != tt.object {
|
||||
t.Errorf("parseCommand(%q).object = %q, want %q", tt.input, cmd.object, tt.object)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserEmpty(t *testing.T) {
|
||||
cmd := parseCommand("")
|
||||
if cmd.verb != "" {
|
||||
t.Errorf("empty input should give empty verb, got %q", cmd.verb)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserArticleStripping(t *testing.T) {
|
||||
cmd := parseCommand("take the an a flashlight")
|
||||
if cmd.verb != "take" || cmd.object != "flashlight" {
|
||||
t.Errorf("articles should be stripped, got verb=%q object=%q", cmd.verb, cmd.object)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{"dungeon_name": "MY DUNGEON"}
|
||||
if got := configString(cfg, "dungeon_name", "DEFAULT"); got != "MY DUNGEON" {
|
||||
t.Errorf("configString() = %q, want %q", got, "MY DUNGEON")
|
||||
}
|
||||
if got := configString(cfg, "missing", "DEFAULT"); got != "DEFAULT" {
|
||||
t.Errorf("configString() for missing key = %q, want %q", got, "DEFAULT")
|
||||
}
|
||||
if got := configString(nil, "key", "DEFAULT"); got != "DEFAULT" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "DEFAULT")
|
||||
}
|
||||
}
|
||||
358
internal/shell/adventure/game.go
Normal file
358
internal/shell/adventure/game.go
Normal file
@@ -0,0 +1,358 @@
|
||||
package adventure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type gameState struct {
|
||||
currentRoom string
|
||||
inventory []string
|
||||
rooms map[string]*room
|
||||
items map[string]*item
|
||||
flags map[string]bool
|
||||
turns int
|
||||
}
|
||||
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
func newGame() *gameState {
|
||||
rooms, items := newWorld()
|
||||
return &gameState{
|
||||
currentRoom: "oubliette",
|
||||
rooms: rooms,
|
||||
items: items,
|
||||
flags: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gameState) dispatch(input string) commandResult {
|
||||
cmd := parseCommand(input)
|
||||
if cmd.verb == "" {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
g.turns++
|
||||
|
||||
switch cmd.verb {
|
||||
case "look":
|
||||
return g.cmdLook(cmd.object)
|
||||
case "go":
|
||||
return g.cmdGo(cmd.object)
|
||||
case "take":
|
||||
return g.cmdTake(cmd.object)
|
||||
case "drop":
|
||||
return g.cmdDrop(cmd.object)
|
||||
case "use":
|
||||
return g.cmdUse(cmd.object)
|
||||
case "inventory":
|
||||
return g.cmdInventory()
|
||||
case "help":
|
||||
return g.cmdHelp()
|
||||
case "quit":
|
||||
return commandResult{output: "The darkness closes in. Session terminated.", exit: true}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("I don't understand '%s'. Type 'help' for available commands.", input)}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdHelp() commandResult {
|
||||
help := `Available commands:
|
||||
look / examine [item] - Look around or examine something
|
||||
go <direction> - Move (north, south, east, west, up, down)
|
||||
take <item> - Pick up an item
|
||||
drop <item> - Drop an item
|
||||
use <item> - Use an item
|
||||
inventory - Check what you're carrying
|
||||
help - Show this help
|
||||
quit / exit - End session
|
||||
|
||||
You can also just type a direction (n, s, e, w, u, d) to move.`
|
||||
return commandResult{output: help}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdLook(object string) commandResult {
|
||||
r := g.rooms[g.currentRoom]
|
||||
|
||||
// Look at a specific item.
|
||||
if object != "" {
|
||||
return g.examineItem(object)
|
||||
}
|
||||
|
||||
// Look at the room.
|
||||
return commandResult{output: g.describeRoom(r)}
|
||||
}
|
||||
|
||||
func (g *gameState) describeRoom(r *room) string {
|
||||
var b strings.Builder
|
||||
|
||||
fmt.Fprintf(&b, "== %s ==\n", r.name)
|
||||
|
||||
// Dark room check.
|
||||
if r.darkDesc != "" && !g.hasItem("flashlight") {
|
||||
b.WriteString(r.darkDesc)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
b.WriteString(r.description)
|
||||
|
||||
// List visible items.
|
||||
visibleItems := g.roomItems(r)
|
||||
if len(visibleItems) > 0 {
|
||||
b.WriteString("\n\nYou can see: ")
|
||||
names := make([]string, len(visibleItems))
|
||||
for i, id := range visibleItems {
|
||||
names[i] = g.items[id].name
|
||||
}
|
||||
b.WriteString(strings.Join(names, ", "))
|
||||
}
|
||||
|
||||
// List exits.
|
||||
if len(r.exits) > 0 {
|
||||
b.WriteString("\n\nExits: ")
|
||||
dirs := make([]string, 0, len(r.exits))
|
||||
for dir := range r.exits {
|
||||
dirs = append(dirs, dir)
|
||||
}
|
||||
b.WriteString(strings.Join(dirs, ", "))
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (g *gameState) roomItems(r *room) []string {
|
||||
// In dark rooms without flashlight, can't see items.
|
||||
if r.darkDesc != "" && !g.hasItem("flashlight") {
|
||||
return nil
|
||||
}
|
||||
return r.items
|
||||
}
|
||||
|
||||
func (g *gameState) examineItem(name string) commandResult {
|
||||
id := g.resolveItem(name)
|
||||
if id == "" {
|
||||
return commandResult{output: fmt.Sprintf("You don't see '%s' here.", name)}
|
||||
}
|
||||
return commandResult{output: g.items[id].description}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdGo(direction string) commandResult {
|
||||
if direction == "" {
|
||||
return commandResult{output: "Go where? Try a direction: north, south, east, west, up, down."}
|
||||
}
|
||||
|
||||
r := g.rooms[g.currentRoom]
|
||||
destID, ok := r.exits[direction]
|
||||
if !ok {
|
||||
return commandResult{output: fmt.Sprintf("You can't go %s from here.", direction)}
|
||||
}
|
||||
|
||||
// Check locked doors.
|
||||
if r.locked != nil {
|
||||
if flag, locked := r.locked[direction]; locked && !g.flags[flag] {
|
||||
return g.lockedMessage(direction)
|
||||
}
|
||||
}
|
||||
|
||||
g.currentRoom = destID
|
||||
dest := g.rooms[destID]
|
||||
|
||||
// Entering the exit room ends the game.
|
||||
if destID == "exit" {
|
||||
return commandResult{output: g.describeRoom(dest), exit: true}
|
||||
}
|
||||
|
||||
return commandResult{output: g.describeRoom(dest)}
|
||||
}
|
||||
|
||||
func (g *gameState) lockedMessage(direction string) commandResult {
|
||||
if direction == "east" && g.currentRoom == "archive" {
|
||||
return commandResult{output: "The steel door won't budge. The keycard reader blinks red, waiting."}
|
||||
}
|
||||
return commandResult{output: "The way is locked."}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdTake(name string) commandResult {
|
||||
if name == "" {
|
||||
return commandResult{output: "Take what?"}
|
||||
}
|
||||
|
||||
r := g.rooms[g.currentRoom]
|
||||
|
||||
// Can't take items in dark rooms.
|
||||
if r.darkDesc != "" && !g.hasItem("flashlight") {
|
||||
return commandResult{output: "It's too dark to find anything."}
|
||||
}
|
||||
|
||||
id := g.resolveRoomItem(name)
|
||||
if id == "" {
|
||||
return commandResult{output: fmt.Sprintf("You don't see '%s' here.", name)}
|
||||
}
|
||||
|
||||
it := g.items[id]
|
||||
if !it.takeable {
|
||||
return commandResult{output: fmt.Sprintf("You can't take the %s.", it.name)}
|
||||
}
|
||||
|
||||
// Remove from room, add to inventory.
|
||||
g.removeRoomItem(r, id)
|
||||
g.inventory = append(g.inventory, id)
|
||||
return commandResult{output: fmt.Sprintf("Taken: %s", it.name)}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdDrop(name string) commandResult {
|
||||
if name == "" {
|
||||
return commandResult{output: "Drop what?"}
|
||||
}
|
||||
|
||||
id := g.resolveInventoryItem(name)
|
||||
if id == "" {
|
||||
return commandResult{output: fmt.Sprintf("You're not carrying '%s'.", name)}
|
||||
}
|
||||
|
||||
// Remove from inventory, add to room.
|
||||
g.removeInventoryItem(id)
|
||||
r := g.rooms[g.currentRoom]
|
||||
r.items = append(r.items, id)
|
||||
return commandResult{output: fmt.Sprintf("Dropped: %s", g.items[id].name)}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdUse(name string) commandResult {
|
||||
if name == "" {
|
||||
return commandResult{output: "Use what?"}
|
||||
}
|
||||
|
||||
id := g.resolveInventoryItem(name)
|
||||
if id == "" {
|
||||
return commandResult{output: fmt.Sprintf("You're not carrying '%s'.", name)}
|
||||
}
|
||||
|
||||
switch id {
|
||||
case "keycard":
|
||||
return g.useKeycard()
|
||||
case "rusty_key":
|
||||
return g.useRustyKey()
|
||||
case "flashlight":
|
||||
return commandResult{output: "The flashlight is already on. Its beam cuts through the darkness."}
|
||||
case "ethernet_cable":
|
||||
return commandResult{output: "You wave the cable around hopefully. Nothing happens. Both ends are chewed through anyway."}
|
||||
case "floppy_disk":
|
||||
return commandResult{output: "You don't have anything to put it in. Floppy drives went extinct decades ago."}
|
||||
case "note":
|
||||
return commandResult{output: g.items[id].description}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("You can't figure out how to use the %s here.", g.items[id].name)}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gameState) useKeycard() commandResult {
|
||||
if g.currentRoom != "archive" {
|
||||
return commandResult{output: "There's nothing to use the keycard on here."}
|
||||
}
|
||||
if g.flags["archive_unlocked"] {
|
||||
return commandResult{output: "You already unlocked that door."}
|
||||
}
|
||||
g.flags["archive_unlocked"] = true
|
||||
return commandResult{output: "You swipe the keycard. The reader flashes green and the steel door clicks open with a heavy thunk."}
|
||||
}
|
||||
|
||||
func (g *gameState) useRustyKey() commandResult {
|
||||
if g.currentRoom != "generator" {
|
||||
return commandResult{output: "There's nothing to use the rusty key on here."}
|
||||
}
|
||||
if g.flags["panel_opened"] {
|
||||
return commandResult{output: "The maintenance panel is already open."}
|
||||
}
|
||||
g.flags["panel_opened"] = true
|
||||
return commandResult{output: `The key fits. The maintenance panel swings open with a screech, revealing a logbook:
|
||||
|
||||
MAINTENANCE LOG - GENERATOR B
|
||||
Last service: 2003-11-15
|
||||
Technician: J. Dunwich
|
||||
|
||||
Notes: "Generator A offline permanently. B running on fumes.
|
||||
Fuel delivery canceled - 'facility decommissioned' per management.
|
||||
But the servers are still running. WHO IS USING THEM?
|
||||
Filed ticket #4,271. No response. As usual."
|
||||
|
||||
The entries stop after that date.`}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdInventory() commandResult {
|
||||
if len(g.inventory) == 0 {
|
||||
return commandResult{output: "You're not carrying anything."}
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("You are carrying:\n")
|
||||
for _, id := range g.inventory {
|
||||
fmt.Fprintf(&b, " - %s\n", g.items[id].name)
|
||||
}
|
||||
return commandResult{output: strings.TrimRight(b.String(), "\n")}
|
||||
}
|
||||
|
||||
// hasItem checks if the player has an item in inventory.
|
||||
func (g *gameState) hasItem(id string) bool {
|
||||
return slices.Contains(g.inventory, id)
|
||||
}
|
||||
|
||||
// resolveItem finds an item by name in both room and inventory.
|
||||
func (g *gameState) resolveItem(name string) string {
|
||||
if id := g.resolveInventoryItem(name); id != "" {
|
||||
return id
|
||||
}
|
||||
return g.resolveRoomItem(name)
|
||||
}
|
||||
|
||||
// resolveRoomItem finds an item in the current room by partial name match.
|
||||
func (g *gameState) resolveRoomItem(name string) string {
|
||||
r := g.rooms[g.currentRoom]
|
||||
name = strings.ToLower(name)
|
||||
for _, id := range r.items {
|
||||
if matchesItem(id, g.items[id].name, name) {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveInventoryItem finds an item in inventory by partial name match.
|
||||
func (g *gameState) resolveInventoryItem(name string) string {
|
||||
name = strings.ToLower(name)
|
||||
for _, id := range g.inventory {
|
||||
if matchesItem(id, g.items[id].name, name) {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// matchesItem checks if a search term matches an item's ID or display name.
|
||||
func matchesItem(id, displayName, search string) bool {
|
||||
return id == search ||
|
||||
strings.ToLower(displayName) == search ||
|
||||
strings.Contains(id, search) ||
|
||||
strings.Contains(strings.ToLower(displayName), search)
|
||||
}
|
||||
|
||||
func (g *gameState) removeRoomItem(r *room, id string) {
|
||||
for i, itemID := range r.items {
|
||||
if itemID == id {
|
||||
r.items = append(r.items[:i], r.items[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gameState) removeInventoryItem(id string) {
|
||||
for i, invID := range g.inventory {
|
||||
if invID == id {
|
||||
g.inventory = append(g.inventory[:i], g.inventory[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
106
internal/shell/adventure/parser.go
Normal file
106
internal/shell/adventure/parser.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package adventure
|
||||
|
||||
import "strings"
|
||||
|
||||
type parsedCommand struct {
|
||||
verb string
|
||||
object string
|
||||
}
|
||||
|
||||
// directionAliases maps shorthand and full direction names to canonical forms.
|
||||
var directionAliases = map[string]string{
|
||||
"n": "north",
|
||||
"s": "south",
|
||||
"e": "east",
|
||||
"w": "west",
|
||||
"u": "up",
|
||||
"d": "down",
|
||||
"north": "north",
|
||||
"south": "south",
|
||||
"east": "east",
|
||||
"west": "west",
|
||||
"up": "up",
|
||||
"down": "down",
|
||||
}
|
||||
|
||||
// verbAliases maps aliases to canonical verbs.
|
||||
var verbAliases = map[string]string{
|
||||
"look": "look",
|
||||
"l": "look",
|
||||
"examine": "look",
|
||||
"inspect": "look",
|
||||
"x": "look",
|
||||
"go": "go",
|
||||
"move": "go",
|
||||
"walk": "go",
|
||||
"take": "take",
|
||||
"get": "take",
|
||||
"grab": "take",
|
||||
"pick": "take",
|
||||
"drop": "drop",
|
||||
"put": "drop",
|
||||
"use": "use",
|
||||
"apply": "use",
|
||||
"inventory": "inventory",
|
||||
"inv": "inventory",
|
||||
"i": "inventory",
|
||||
"help": "help",
|
||||
"?": "help",
|
||||
"quit": "quit",
|
||||
"exit": "quit",
|
||||
"logout": "quit",
|
||||
"q": "quit",
|
||||
}
|
||||
|
||||
// articles are stripped from input.
|
||||
var articles = map[string]bool{
|
||||
"the": true,
|
||||
"a": true,
|
||||
"an": true,
|
||||
}
|
||||
|
||||
// parseCommand parses raw input into a verb and object.
|
||||
func parseCommand(input string) parsedCommand {
|
||||
input = strings.TrimSpace(strings.ToLower(input))
|
||||
if input == "" {
|
||||
return parsedCommand{}
|
||||
}
|
||||
|
||||
words := strings.Fields(input)
|
||||
|
||||
// Strip articles.
|
||||
var filtered []string
|
||||
for _, w := range words {
|
||||
if !articles[w] {
|
||||
filtered = append(filtered, w)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return parsedCommand{}
|
||||
}
|
||||
|
||||
first := filtered[0]
|
||||
|
||||
// Bare direction → go <direction>.
|
||||
if dir, ok := directionAliases[first]; ok {
|
||||
return parsedCommand{verb: "go", object: dir}
|
||||
}
|
||||
|
||||
// Known verb alias.
|
||||
if verb, ok := verbAliases[first]; ok {
|
||||
object := ""
|
||||
if len(filtered) > 1 {
|
||||
object = strings.Join(filtered[1:], " ")
|
||||
// Resolve direction alias in object for "go north" etc.
|
||||
if verb == "go" {
|
||||
if dir, ok := directionAliases[object]; ok {
|
||||
object = dir
|
||||
}
|
||||
}
|
||||
}
|
||||
return parsedCommand{verb: verb, object: object}
|
||||
}
|
||||
|
||||
// Unknown verb — return as-is so game can give an error.
|
||||
return parsedCommand{verb: first, object: strings.Join(filtered[1:], " ")}
|
||||
}
|
||||
211
internal/shell/adventure/world.go
Normal file
211
internal/shell/adventure/world.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package adventure
|
||||
|
||||
// room represents a location in the game world.
|
||||
type room struct {
|
||||
name string
|
||||
description string
|
||||
darkDesc string // shown when room is dark (empty = not a dark room)
|
||||
exits map[string]string // direction → room ID
|
||||
items []string // item IDs present in the room
|
||||
locked map[string]string // direction → flag required to unlock
|
||||
}
|
||||
|
||||
// item represents an object that can be picked up and used.
|
||||
type item struct {
|
||||
name string
|
||||
description string
|
||||
takeable bool
|
||||
}
|
||||
|
||||
func newWorld() (map[string]*room, map[string]*item) {
|
||||
rooms := map[string]*room{
|
||||
"oubliette": {
|
||||
name: "The Oubliette",
|
||||
description: `You are in a narrow stone chamber. The walls are damp and slick with condensation.
|
||||
Far above, an iron grate lets in a faint green glow — not daylight, but the steady
|
||||
pulse of status LEDs. A frayed ethernet cable hangs from the grate like a vine.
|
||||
A passage leads east into darkness. Stone steps spiral downward.`,
|
||||
exits: map[string]string{
|
||||
"east": "corridor",
|
||||
"down": "pit",
|
||||
},
|
||||
},
|
||||
"corridor": {
|
||||
name: "Stone Corridor",
|
||||
description: `A long corridor carved from living rock. The walls transition from rough-hewn
|
||||
stone to poured concrete as you look east. Fluorescent tubes flicker overhead,
|
||||
half of them dead. Cable trays run along the ceiling, sagging under the weight
|
||||
of bundled Cat5. A draft comes from a ventilation shaft above.`,
|
||||
exits: map[string]string{
|
||||
"west": "oubliette",
|
||||
"east": "server_room",
|
||||
"south": "cable_crypt",
|
||||
"up": "ventilation",
|
||||
},
|
||||
items: []string{"flashlight"},
|
||||
},
|
||||
"ventilation": {
|
||||
name: "Ventilation Shaft",
|
||||
description: `You've squeezed into a narrow ventilation shaft. The aluminum walls vibrate with
|
||||
the hum of distant fans. It's barely wide enough to turn around in. Dust and
|
||||
cobwebs coat everything. Someone has scratched tally marks into the metal — you
|
||||
stop counting at forty-seven.`,
|
||||
exits: map[string]string{
|
||||
"down": "corridor",
|
||||
},
|
||||
items: []string{"note"},
|
||||
},
|
||||
"server_room": {
|
||||
name: "Server Room",
|
||||
description: `Rows of black server racks stretch into the gloom, their LEDs blinking in
|
||||
patterns that almost seem deliberate. The air is frigid and filled with the
|
||||
white noise of a thousand fans. A yellowed label on the nearest rack reads:
|
||||
"PRODUCTION - DO NOT TOUCH". Someone has added in marker: "OR ELSE".
|
||||
A laminated keycard sits on top of a powered-down blade server.`,
|
||||
exits: map[string]string{
|
||||
"west": "corridor",
|
||||
"south": "cold_storage",
|
||||
},
|
||||
items: []string{"keycard"},
|
||||
},
|
||||
"pit": {
|
||||
name: "The Pit",
|
||||
darkDesc: `You are in absolute darkness. You can hear water dripping somewhere far below
|
||||
and the faint hum of electronics. The air smells of rust and ozone. You can't
|
||||
see a thing without a light source. The stairs lead back up.`,
|
||||
description: `Your flashlight reveals a deep shaft — part medieval oubliette, part cable run.
|
||||
Rusty chains hang from iron rings set into the walls, intertwined with bundles
|
||||
of fiber optic cable that glow faintly orange. At the bottom, a skeleton in a
|
||||
lab coat slumps against the wall, still wearing an ID badge. A rusty key glints
|
||||
on a hook near the skeleton's hand.`,
|
||||
exits: map[string]string{
|
||||
"up": "oubliette",
|
||||
},
|
||||
items: []string{"rusty_key"},
|
||||
},
|
||||
"cable_crypt": {
|
||||
name: "Cable Crypt",
|
||||
description: `This vaulted chamber was clearly something else before the cables arrived.
|
||||
Stone sarcophagi line the walls, but their lids have been removed and they're
|
||||
now stuffed full of tangled ethernet runs and power strips. A faded sign reads
|
||||
"STRUCTURED CABLING" — someone has crossed out "STRUCTURED" and written
|
||||
"CHAOTIC" above it. The air smells of old stone and warm plastic.`,
|
||||
exits: map[string]string{
|
||||
"north": "corridor",
|
||||
"south": "archive",
|
||||
},
|
||||
items: []string{"ethernet_cable"},
|
||||
},
|
||||
"cold_storage": {
|
||||
name: "Cold Storage",
|
||||
description: `A cavernous room kept at near-freezing temperatures. Frost coats the walls.
|
||||
Rows of old tape drives and disk platters are stacked on industrial shelving,
|
||||
their labels faded beyond reading. A humming cryo-unit in the corner has a
|
||||
blinking amber light. A hand-written sign says "BACKUP STORAGE - CRITICAL".`,
|
||||
exits: map[string]string{
|
||||
"north": "server_room",
|
||||
"east": "generator",
|
||||
},
|
||||
items: []string{"floppy_disk"},
|
||||
},
|
||||
"archive": {
|
||||
name: "The Archive",
|
||||
description: `Floor-to-ceiling shelves hold thousands of manila folders, binders, and
|
||||
three-ring notebooks. The organization system, if there ever was one, has
|
||||
long since collapsed into entropy. A thick layer of dust covers everything.
|
||||
A heavy steel door to the east has an electronic keycard reader with a
|
||||
steady red light. The cable crypt lies back to the north.`,
|
||||
exits: map[string]string{
|
||||
"north": "cable_crypt",
|
||||
"east": "control_room",
|
||||
},
|
||||
locked: map[string]string{
|
||||
"east": "archive_unlocked",
|
||||
},
|
||||
},
|
||||
"generator": {
|
||||
name: "Generator Room",
|
||||
description: `Two massive diesel generators squat in this room like sleeping beasts. One is
|
||||
clearly dead — corrosion has eaten through the fuel lines. The other hums at
|
||||
a low idle, keeping the facility on life support. A maintenance panel on the
|
||||
wall is secured with an old-fashioned keyhole.`,
|
||||
exits: map[string]string{
|
||||
"west": "cold_storage",
|
||||
},
|
||||
},
|
||||
"control_room": {
|
||||
name: "Control Room",
|
||||
description: `Banks of CRT monitors cast a blue glow across the room. Most show static, but
|
||||
one displays a facility map with pulsing dots labeled "ACTIVE SESSIONS". You
|
||||
count the dots — there are more than there should be. A desk covered in coffee
|
||||
rings holds a battered keyboard. The main display reads:
|
||||
|
||||
FACILITY STATUS: NOMINAL
|
||||
CONTAINMENT: ACTIVE
|
||||
SUBJECTS: ███
|
||||
|
||||
A heavy blast door leads east. Above it, a faded exit sign flickers.`,
|
||||
exits: map[string]string{
|
||||
"west": "archive",
|
||||
"east": "exit",
|
||||
},
|
||||
},
|
||||
"exit": {
|
||||
name: "The Exit",
|
||||
description: `The blast door groans open to reveal... a corridor. Not the freedom you
|
||||
expected, but another corridor — identical to the one you started in. The
|
||||
same flickering fluorescents. The same sagging cable trays. The same damp
|
||||
stone walls.
|
||||
|
||||
As you step through, the door slams shut behind you. A speaker crackles:
|
||||
|
||||
"THANK YOU FOR PARTICIPATING IN TODAY'S SECURITY AUDIT.
|
||||
YOUR SESSION HAS BEEN LOGGED.
|
||||
HAVE A NICE DAY."
|
||||
|
||||
The lights go out.`,
|
||||
},
|
||||
}
|
||||
|
||||
items := map[string]*item{
|
||||
"flashlight": {
|
||||
name: "flashlight",
|
||||
description: "A heavy-duty flashlight. The batteries are low but it still works.",
|
||||
takeable: true,
|
||||
},
|
||||
"keycard": {
|
||||
name: "keycard",
|
||||
description: "A laminated keycard with a faded photo. The name reads 'J. DUNWICH, LEVEL 3'.",
|
||||
takeable: true,
|
||||
},
|
||||
"rusty_key": {
|
||||
name: "rusty key",
|
||||
description: "An old iron key, spotted with rust. It looks like it fits a maintenance panel.",
|
||||
takeable: true,
|
||||
},
|
||||
"note": {
|
||||
name: "note",
|
||||
description: `A crumpled note in shaky handwriting:
|
||||
|
||||
"If you're reading this, GET OUT. The facility is automated now.
|
||||
The sessions never end. I thought I was running tests but the
|
||||
tests were running me. Don't trust the prompts. Don't trust
|
||||
the exits. Don't trust
|
||||
|
||||
[the writing trails off into an illegible scrawl]"`,
|
||||
takeable: true,
|
||||
},
|
||||
"ethernet_cable": {
|
||||
name: "ethernet cable",
|
||||
description: "A tangled Cat5e cable, about 3 meters long. Both ends have been chewed by something.",
|
||||
takeable: true,
|
||||
},
|
||||
"floppy_disk": {
|
||||
name: "floppy disk",
|
||||
description: `A 3.5" floppy disk labeled "BACKUP - CRITICAL - DO NOT FORMAT". The label is dated 1997.`,
|
||||
takeable: true,
|
||||
},
|
||||
}
|
||||
|
||||
return rooms, items
|
||||
}
|
||||
74
internal/shell/banking/banking.go
Normal file
74
internal/shell/banking/banking.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
// BankingShell is an 80s-style green-on-black bank terminal TUI.
|
||||
type BankingShell struct{}
|
||||
|
||||
// NewBankingShell returns a new BankingShell instance.
|
||||
func NewBankingShell() *BankingShell {
|
||||
return &BankingShell{}
|
||||
}
|
||||
|
||||
func (b *BankingShell) Name() string { return "banking" }
|
||||
func (b *BankingShell) Description() string { return "80s-style banking terminal TUI" }
|
||||
|
||||
func (b *BankingShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
bankName := configString(sess.ShellConfig, "bank_name", "SECUREBANK")
|
||||
terminalID := configString(sess.ShellConfig, "terminal_id", "")
|
||||
region := configString(sess.ShellConfig, "region", "NORTHEAST")
|
||||
|
||||
if terminalID == "" {
|
||||
terminalID = fmt.Sprintf("SB-%04d", rand.IntN(10000))
|
||||
}
|
||||
|
||||
m := newModel(sess, bankName, terminalID, region)
|
||||
p := tea.NewProgram(m,
|
||||
tea.WithInput(rw),
|
||||
tea.WithOutput(rw),
|
||||
tea.WithAltScreen(),
|
||||
)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := p.Run()
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
p.Quit()
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
560
internal/shell/banking/banking_test.go
Normal file
560
internal/shell/banking/banking_test.go
Normal file
@@ -0,0 +1,560 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// newTestModel creates a model with a test session context.
|
||||
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "banker", "banking", "")
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "banker",
|
||||
Store: store,
|
||||
}
|
||||
m := newModel(sess, "SECUREBANK", "SB-0001", "NORTHEAST")
|
||||
return m, store
|
||||
}
|
||||
|
||||
// sendKeys sends a string of characters as individual key messages to the model.
|
||||
func sendKeys(m *model, s string) {
|
||||
for _, ch := range s {
|
||||
var msg tea.KeyMsg
|
||||
switch ch {
|
||||
case '\r':
|
||||
msg = tea.KeyMsg{Type: tea.KeyEnter}
|
||||
case '\x1b':
|
||||
msg = tea.KeyMsg{Type: tea.KeyEscape}
|
||||
case '\x03':
|
||||
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
|
||||
default:
|
||||
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{ch}}
|
||||
}
|
||||
m.Update(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBankingShellName(t *testing.T) {
|
||||
sh := NewBankingShell()
|
||||
if sh.Name() != "banking" {
|
||||
t.Errorf("Name() = %q, want %q", sh.Name(), "banking")
|
||||
}
|
||||
if sh.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatCurrency(t *testing.T) {
|
||||
tests := []struct {
|
||||
cents int64
|
||||
want string
|
||||
}{
|
||||
{0, "$0.00"},
|
||||
{100, "$1.00"},
|
||||
{4738291, "$47,382.91"},
|
||||
{18254100, "$182,541.00"},
|
||||
{52387450, "$523,874.50"},
|
||||
{25000000, "$250,000.00"},
|
||||
{-125000, "-$1,250.00"},
|
||||
{99, "$0.99"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := formatCurrency(tt.cents)
|
||||
if got != tt.want {
|
||||
t.Errorf("formatCurrency(%d) = %q, want %q", tt.cents, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBankState(t *testing.T) {
|
||||
state := newBankState()
|
||||
if len(state.Accounts) != 4 {
|
||||
t.Errorf("expected 4 accounts, got %d", len(state.Accounts))
|
||||
}
|
||||
for _, acct := range state.Accounts {
|
||||
txns, ok := state.Transactions[acct.Number]
|
||||
if !ok {
|
||||
t.Errorf("no transactions for account %s", acct.Number)
|
||||
continue
|
||||
}
|
||||
if len(txns) == 0 {
|
||||
t.Errorf("account %s has no transactions", acct.Number)
|
||||
}
|
||||
}
|
||||
if len(state.Messages) != 4 {
|
||||
t.Errorf("expected 4 messages, got %d", len(state.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginScreenRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "SECUREBANK") {
|
||||
t.Error("login should show bank name")
|
||||
}
|
||||
if !strings.Contains(view, "AUTHORIZED ACCESS ONLY") {
|
||||
t.Error("login should show authorization warning")
|
||||
}
|
||||
if !strings.Contains(view, "ACCOUNT NUMBER") {
|
||||
t.Error("login should prompt for account number")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginFlow(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
|
||||
// Type account number.
|
||||
sendKeys(m, "12345678")
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "12345678") {
|
||||
t.Error("should show typed account number")
|
||||
}
|
||||
|
||||
// Press enter.
|
||||
sendKeys(m, "\r")
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "PIN") {
|
||||
t.Error("should show PIN prompt after entering account number")
|
||||
}
|
||||
|
||||
// Type PIN and enter.
|
||||
sendKeys(m, "1234\r")
|
||||
|
||||
// Should be on menu now.
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("expected screenMenu, got %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMainMenuRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r")
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "MAIN MENU") {
|
||||
t.Error("should show MAIN MENU after login")
|
||||
}
|
||||
if !strings.Contains(view, "WIRE TRANSFER") {
|
||||
t.Error("menu should contain WIRE TRANSFER option")
|
||||
}
|
||||
if !strings.Contains(view, "SECURE MESSAGES") {
|
||||
t.Error("menu should contain SECURE MESSAGES option")
|
||||
}
|
||||
if !strings.Contains(view, "ACCOUNT SUMMARY") {
|
||||
t.Error("menu should contain ACCOUNT SUMMARY option")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSummary(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "1\r") // account summary
|
||||
|
||||
if m.screen != screenAccountSummary {
|
||||
t.Fatalf("expected screenAccountSummary, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "ACCOUNT SUMMARY") {
|
||||
t.Error("should show ACCOUNT SUMMARY")
|
||||
}
|
||||
if !strings.Contains(view, "CHECKING") {
|
||||
t.Error("should show CHECKING account")
|
||||
}
|
||||
if !strings.Contains(view, "SAVINGS") {
|
||||
t.Error("should show SAVINGS account")
|
||||
}
|
||||
if !strings.Contains(view, "TOTAL") {
|
||||
t.Error("should show TOTAL")
|
||||
}
|
||||
|
||||
// Press any key to return.
|
||||
sendKeys(m, " ")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireTransferFlow(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "3\r") // wire transfer
|
||||
|
||||
if m.screen != screenTransfer {
|
||||
t.Fatalf("expected screenTransfer, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "WIRE TRANSFER") {
|
||||
t.Error("should show WIRE TRANSFER header")
|
||||
}
|
||||
|
||||
// Fill all fields.
|
||||
sendKeys(m, "021000021\r") // routing
|
||||
sendKeys(m, "9876543210\r") // dest account
|
||||
sendKeys(m, "JOHN DOE\r") // beneficiary
|
||||
sendKeys(m, "FIRST NATIONAL BANK\r") // bank name
|
||||
sendKeys(m, "50000\r") // amount
|
||||
sendKeys(m, "INVOICE 12345\r") // memo
|
||||
|
||||
// Should be on confirm step.
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "TRANSFER SUMMARY") {
|
||||
t.Error("should show TRANSFER SUMMARY for confirmation")
|
||||
}
|
||||
if !strings.Contains(view, "021000021") {
|
||||
t.Error("summary should show routing number")
|
||||
}
|
||||
|
||||
// Confirm.
|
||||
sendKeys(m, "Y\r")
|
||||
|
||||
// Auth code.
|
||||
sendKeys(m, "AUTH99\r")
|
||||
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "TRANSFER QUEUED") {
|
||||
t.Error("should show TRANSFER QUEUED confirmation")
|
||||
}
|
||||
|
||||
// Check wire transfer was stored.
|
||||
if len(m.state.Transfers) != 1 {
|
||||
t.Fatalf("expected 1 transfer, got %d", len(m.state.Transfers))
|
||||
}
|
||||
wt := m.state.Transfers[0]
|
||||
if wt.RoutingNumber != "021000021" {
|
||||
t.Errorf("routing = %q, want %q", wt.RoutingNumber, "021000021")
|
||||
}
|
||||
if wt.DestAccount != "9876543210" {
|
||||
t.Errorf("dest = %q, want %q", wt.DestAccount, "9876543210")
|
||||
}
|
||||
if wt.Beneficiary != "JOHN DOE" {
|
||||
t.Errorf("beneficiary = %q, want %q", wt.Beneficiary, "JOHN DOE")
|
||||
}
|
||||
|
||||
// Press key to return to menu.
|
||||
sendKeys(m, " ")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu after transfer, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireTransferCancel(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "3\r") // wire transfer
|
||||
sendKeys(m, "021000021\r")
|
||||
sendKeys(m, "9876543210\r")
|
||||
sendKeys(m, "JOHN DOE\r")
|
||||
sendKeys(m, "FIRST NATIONAL BANK\r")
|
||||
sendKeys(m, "50000\r")
|
||||
sendKeys(m, "INVOICE 12345\r")
|
||||
|
||||
// Cancel at confirm step.
|
||||
sendKeys(m, "N\r")
|
||||
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu after cancel, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionHistory(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "4\r") // transaction history
|
||||
|
||||
if m.screen != screenHistory {
|
||||
t.Fatalf("expected screenHistory, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "TRANSACTION HISTORY") {
|
||||
t.Error("should show TRANSACTION HISTORY header")
|
||||
}
|
||||
|
||||
// Select first account.
|
||||
sendKeys(m, "1\r")
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "DATE") {
|
||||
t.Error("should show transaction list with DATE column")
|
||||
}
|
||||
if !strings.Contains(view, "PAGE") {
|
||||
t.Error("should show page indicator")
|
||||
}
|
||||
|
||||
// Press B to go back to account list.
|
||||
sendKeys(m, "B")
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "SELECT ACCOUNT") {
|
||||
t.Error("should return to account selection")
|
||||
}
|
||||
|
||||
// Press 0 to return to menu.
|
||||
sendKeys(m, "0\r")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureMessages(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "5\r") // secure messages
|
||||
|
||||
if m.screen != screenMessages {
|
||||
t.Fatalf("expected screenMessages, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "SECURE MESSAGES") {
|
||||
t.Error("should show SECURE MESSAGES header")
|
||||
}
|
||||
if !strings.Contains(view, "SCHEDULED MAINTENANCE") {
|
||||
t.Error("should show first message subject")
|
||||
}
|
||||
|
||||
// View first message.
|
||||
sendKeys(m, "1\r")
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "10.48.2.100") {
|
||||
t.Error("message body should contain breadcrumb IP")
|
||||
}
|
||||
|
||||
// Press key to return to list.
|
||||
sendKeys(m, " ")
|
||||
|
||||
// Return to menu.
|
||||
sendKeys(m, "0\r")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminAccessDenied(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "99\r") // admin
|
||||
|
||||
if m.screen != screenAdmin {
|
||||
t.Fatalf("expected screenAdmin, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "SYSTEM ADMINISTRATION") {
|
||||
t.Error("should show SYSTEM ADMINISTRATION header")
|
||||
}
|
||||
|
||||
// Three failed PIN attempts.
|
||||
sendKeys(m, "secret1\r")
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "INVALID CREDENTIALS") {
|
||||
t.Error("should show INVALID CREDENTIALS after first attempt")
|
||||
}
|
||||
|
||||
sendKeys(m, "secret2\r")
|
||||
sendKeys(m, "secret3\r")
|
||||
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "ACCESS DENIED") {
|
||||
t.Error("should show ACCESS DENIED after 3 attempts")
|
||||
}
|
||||
if !strings.Contains(view, "ABEND S0C4") {
|
||||
t.Error("should show COBOL-style error")
|
||||
}
|
||||
|
||||
// Press key to return to menu.
|
||||
sendKeys(m, " ")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu after lockout, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminEscapeReturns(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "99\r") // admin
|
||||
sendKeys(m, "\x1b") // ESC to return
|
||||
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("ESC should return to menu from admin, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePin(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "6\r") // change PIN
|
||||
|
||||
if m.screen != screenChangePin {
|
||||
t.Fatalf("expected screenChangePin, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "CHANGE PIN") {
|
||||
t.Error("should show CHANGE PIN header")
|
||||
}
|
||||
|
||||
// Old PIN.
|
||||
sendKeys(m, "1234\r")
|
||||
// New PIN.
|
||||
sendKeys(m, "5678\r")
|
||||
// Confirm PIN.
|
||||
sendKeys(m, "5678\r")
|
||||
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "PIN CHANGED SUCCESSFULLY") {
|
||||
t.Error("should show PIN CHANGED SUCCESSFULLY")
|
||||
}
|
||||
|
||||
// Press key to return.
|
||||
sendKeys(m, " ")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu after PIN change, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogoutExits(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "7\r") // logout
|
||||
|
||||
if !m.quitting {
|
||||
t.Error("should be quitting after logout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLogs(t *testing.T) {
|
||||
m, store := newTestModel(t)
|
||||
|
||||
// Send login keys and manually run returned commands.
|
||||
// Type account number.
|
||||
sendKeys(m, "12345678")
|
||||
// Enter to advance to PIN.
|
||||
sendKeys(m, "\r")
|
||||
// Type PIN.
|
||||
sendKeys(m, "1234")
|
||||
// Enter to login — this returns a logAction cmd.
|
||||
_, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
if cmd != nil {
|
||||
// Execute the batch of commands (login log).
|
||||
execCmds(cmd)
|
||||
}
|
||||
|
||||
// Give async store writes a moment.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if len(store.SessionLogs) == 0 {
|
||||
t.Error("expected session logs to be recorded after login")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, log := range store.SessionLogs {
|
||||
if strings.Contains(log.Input, "LOGIN") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected a LOGIN entry in session logs")
|
||||
}
|
||||
|
||||
// Navigate to account summary — also returns a logAction cmd.
|
||||
sendKeys(m, "1")
|
||||
_, cmd = m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
if cmd != nil {
|
||||
execCmds(cmd)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
foundMenu := false
|
||||
for _, log := range store.SessionLogs {
|
||||
if strings.Contains(log.Input, "MENU") {
|
||||
foundMenu = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundMenu {
|
||||
t.Error("expected a MENU entry in session logs for account summary")
|
||||
}
|
||||
}
|
||||
|
||||
// execCmds recursively executes tea.Cmd functions (including batches).
|
||||
func execCmds(cmd tea.Cmd) {
|
||||
if cmd == nil {
|
||||
return
|
||||
}
|
||||
msg := cmd()
|
||||
// tea.BatchMsg is a slice of Cmds returned by tea.Batch.
|
||||
if batch, ok := msg.(tea.BatchMsg); ok {
|
||||
for _, c := range batch {
|
||||
execCmds(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{
|
||||
"bank_name": "TESTBANK",
|
||||
"region": "SOUTHWEST",
|
||||
}
|
||||
if got := configString(cfg, "bank_name", "DEFAULT"); got != "TESTBANK" {
|
||||
t.Errorf("configString() = %q, want %q", got, "TESTBANK")
|
||||
}
|
||||
if got := configString(cfg, "missing", "DEFAULT"); got != "DEFAULT" {
|
||||
t.Errorf("configString() = %q, want %q", got, "DEFAULT")
|
||||
}
|
||||
if got := configString(nil, "bank_name", "DEFAULT"); got != "DEFAULT" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "DEFAULT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScreenFrame(t *testing.T) {
|
||||
frame := screenFrame("TESTBANK", "TB-0001", "NORTHEAST", "content here", 0)
|
||||
if !strings.Contains(frame, "TESTBANK FEDERAL RESERVE SYSTEM") {
|
||||
t.Error("frame should contain bank name in header")
|
||||
}
|
||||
if !strings.Contains(frame, "TB-0001") {
|
||||
t.Error("frame should contain terminal ID in footer")
|
||||
}
|
||||
if !strings.Contains(frame, "content here") {
|
||||
t.Error("frame should contain the content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScreenFramePadsLines(t *testing.T) {
|
||||
frame := screenFrame("TESTBANK", "TB-0001", "NE", "short\n", 0)
|
||||
for i, line := range strings.Split(frame, "\n") {
|
||||
w := lipgloss.Width(line)
|
||||
if w > 0 && w < termWidth {
|
||||
t.Errorf("line %d has visual width %d, want at least %d: %q", i, w, termWidth, line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestScreenFramePadsToHeight(t *testing.T) {
|
||||
short := screenFrame("TESTBANK", "TB-0001", "NE", "line1\nline2\n", 30)
|
||||
lines := strings.Count(short, "\n")
|
||||
// Total newlines should be at least height-1 (since the last line has no trailing newline).
|
||||
if lines < 29 {
|
||||
t.Errorf("padded frame has %d newlines, want at least 29 for height=30", lines)
|
||||
}
|
||||
|
||||
// Without height, no padding.
|
||||
noPad := screenFrame("TESTBANK", "TB-0001", "NE", "line1\nline2\n", 0)
|
||||
noPadLines := strings.Count(noPad, "\n")
|
||||
if noPadLines >= 29 {
|
||||
t.Errorf("unpadded frame has %d newlines, should be much less than 29", noPadLines)
|
||||
}
|
||||
}
|
||||
258
internal/shell/banking/data.go
Normal file
258
internal/shell/banking/data.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Account types.
|
||||
const (
|
||||
AcctChecking = "CHECKING"
|
||||
AcctSavings = "SAVINGS"
|
||||
AcctMoneyMarket = "MONEY MARKET"
|
||||
AcctCertDeposit = "CERT OF DEPOSIT"
|
||||
)
|
||||
|
||||
// Account represents a fake bank account.
|
||||
type Account struct {
|
||||
Number string
|
||||
Type string
|
||||
Balance int64 // cents
|
||||
}
|
||||
|
||||
// Transaction represents a fake bank transaction.
|
||||
type Transaction struct {
|
||||
Date string
|
||||
Description string
|
||||
Amount int64 // cents (negative for debits)
|
||||
Balance int64 // running balance in cents
|
||||
}
|
||||
|
||||
// SecureMessage represents a fake internal message.
|
||||
type SecureMessage struct {
|
||||
ID int
|
||||
Date string
|
||||
From string
|
||||
Subj string
|
||||
Body string
|
||||
Unread bool
|
||||
}
|
||||
|
||||
// WireTransfer captures data entered during the wire transfer wizard.
|
||||
type WireTransfer struct {
|
||||
RoutingNumber string
|
||||
DestAccount string
|
||||
Beneficiary string
|
||||
BankName string
|
||||
Amount string
|
||||
Memo string
|
||||
AuthCode string
|
||||
}
|
||||
|
||||
// bankState holds all fake data for a session.
|
||||
type bankState struct {
|
||||
Accounts []Account
|
||||
Transactions map[string][]Transaction // keyed by account number
|
||||
Messages []SecureMessage
|
||||
Transfers []WireTransfer
|
||||
}
|
||||
|
||||
func newBankState() *bankState {
|
||||
now := time.Now()
|
||||
|
||||
accounts := []Account{
|
||||
{Number: "****4821", Type: AcctChecking, Balance: 4738291},
|
||||
{Number: "****7203", Type: AcctSavings, Balance: 18254100},
|
||||
{Number: "****9915", Type: AcctMoneyMarket, Balance: 52387450},
|
||||
{Number: "****1102", Type: AcctCertDeposit, Balance: 25000000},
|
||||
}
|
||||
|
||||
transactions := make(map[string][]Transaction)
|
||||
transactions["****4821"] = generateCheckingTxns(now, accounts[0].Balance)
|
||||
transactions["****7203"] = generateSavingsTxns(now, accounts[1].Balance)
|
||||
transactions["****9915"] = generateMoneyMarketTxns(now, accounts[2].Balance)
|
||||
transactions["****1102"] = generateCDTxns(now, accounts[3].Balance)
|
||||
|
||||
messages := []SecureMessage{
|
||||
{
|
||||
ID: 1,
|
||||
Date: now.Add(-2 * 24 * time.Hour).Format("01/02/2006"),
|
||||
From: "SYSTEM ADMINISTRATOR",
|
||||
Subj: "SCHEDULED MAINTENANCE WINDOW",
|
||||
Unread: true,
|
||||
Body: fmt.Sprintf(`FROM: SYSTEM ADMINISTRATOR <sysadmin@internal.securebank.local>
|
||||
DATE: %s
|
||||
RE: SCHEDULED MAINTENANCE WINDOW
|
||||
|
||||
ALL TERMINALS WILL BE OFFLINE FOR MAINTENANCE:
|
||||
DATE: %s
|
||||
TIME: 02:00 - 04:00 EST
|
||||
AFFECTED: ALL REGIONS
|
||||
|
||||
DURING THIS WINDOW, THE FOLLOWING SYSTEMS WILL BE UNAVAILABLE:
|
||||
- WIRE TRANSFER PROCESSING (10.48.2.100:8443)
|
||||
- ACCOUNT MANAGEMENT (10.48.2.101:8443)
|
||||
- ACH BATCH PROCESSOR (10.48.2.105:9090)
|
||||
|
||||
PLEASE ENSURE ALL PENDING TRANSACTIONS ARE SUBMITTED BEFORE 01:30 EST.
|
||||
|
||||
CONTACT: HELPDESK EXT 4400 OR ops-support@internal.securebank.local`,
|
||||
now.Add(-2*24*time.Hour).Format("01/02/2006 15:04"),
|
||||
now.Add(5*24*time.Hour).Format("01/02/2006")),
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Date: now.Add(-5 * 24 * time.Hour).Format("01/02/2006"),
|
||||
From: "COMPLIANCE DEPT",
|
||||
Subj: "QUARTERLY AUDIT REMINDER",
|
||||
Unread: true,
|
||||
Body: `FROM: COMPLIANCE DEPT <compliance@internal.securebank.local>
|
||||
RE: QUARTERLY AUDIT REMINDER
|
||||
|
||||
ALL BRANCH MANAGERS:
|
||||
|
||||
THE Q4 COMPLIANCE AUDIT IS SCHEDULED FOR NEXT WEEK.
|
||||
PLEASE ENSURE THE FOLLOWING ARE CURRENT:
|
||||
|
||||
1. TRANSACTION LOGS EXPORTED TO \\FILESERV01\AUDIT\Q4
|
||||
2. VAULT ACCESS CODES ROTATED (LAST ROTATION: SEE VAULT-MGMT PORTAL)
|
||||
3. EMPLOYEE ACCESS REVIEWS COMPLETED IN IAM PORTAL (https://iam.internal:8443)
|
||||
|
||||
NOTE: DEFAULT CREDENTIALS FOR THE AUDIT PORTAL HAVE BEEN RESET.
|
||||
NEW CREDENTIALS DISTRIBUTED VIA SECURE COURIER.
|
||||
REFERENCE: AUDIT-2024-Q4-0847
|
||||
|
||||
VAULT MASTER CODE HINT: FIRST 4 OF ROUTING + BRANCH ZIP (STANDARD FORMAT)`,
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Date: now.Add(-8 * 24 * time.Hour).Format("01/02/2006"),
|
||||
From: "IT SECURITY",
|
||||
Subj: "PASSWORD POLICY UPDATE",
|
||||
Unread: false,
|
||||
Body: `FROM: IT SECURITY <itsec@internal.securebank.local>
|
||||
RE: PASSWORD POLICY UPDATE - EFFECTIVE IMMEDIATELY
|
||||
|
||||
ALL STAFF:
|
||||
|
||||
PER FEDERAL BANKING REGULATION 12 CFR 748, THE FOLLOWING
|
||||
PASSWORD POLICY IS NOW IN EFFECT:
|
||||
|
||||
- MINIMUM 12 CHARACTERS
|
||||
- MUST CONTAIN UPPERCASE, LOWERCASE, NUMBER, SPECIAL CHAR
|
||||
- 90-DAY ROTATION CYCLE
|
||||
- NO REUSE OF LAST 24 PASSWORDS
|
||||
|
||||
LEGACY SYSTEM ACCOUNTS (MAINFRAME, AS/400) ARE EXEMPT UNTIL
|
||||
MIGRATION IS COMPLETE. CURRENT LEGACY ACCESS:
|
||||
MAINFRAME: telnet://10.48.1.50:23 (CICS REGION PROD1)
|
||||
AS/400: tn5250://10.48.1.55 (SUBSYSTEM QINTER)
|
||||
|
||||
SERVICE ACCOUNT PASSWORDS ARE MANAGED VIA CYBERARK:
|
||||
https://pam.internal.securebank.local:8443
|
||||
|
||||
TICKET: SEC-2024-1847`,
|
||||
},
|
||||
{
|
||||
ID: 4,
|
||||
Date: now.Add(-12 * 24 * time.Hour).Format("01/02/2006"),
|
||||
From: "WIRE OPERATIONS",
|
||||
Subj: "FEDWIRE CUTOFF TIME CHANGE",
|
||||
Unread: false,
|
||||
Body: `FROM: WIRE OPERATIONS <wireops@internal.securebank.local>
|
||||
RE: FEDWIRE CUTOFF TIME CHANGE
|
||||
|
||||
EFFECTIVE NEXT MONDAY, FEDWIRE CUTOFF TIMES ARE:
|
||||
DOMESTIC WIRES: 16:30 EST (WAS 17:00)
|
||||
INTERNATIONAL WIRES: 14:00 EST (NO CHANGE)
|
||||
BOOK TRANSFERS: 17:30 EST (NO CHANGE)
|
||||
|
||||
WIRES SUBMITTED AFTER CUTOFF WILL BE QUEUED FOR NEXT
|
||||
BUSINESS DAY PROCESSING.
|
||||
|
||||
FOR EMERGENCY SAME-DAY PROCESSING AFTER CUTOFF:
|
||||
CONTACT WIRE ROOM: EXT 4450
|
||||
AUTH CODE REQUIRED (OBTAIN FROM BRANCH MANAGER)
|
||||
APPROVAL CHAIN: OPS-MGR -> VP-WIRE -> SVP-TREASURY
|
||||
|
||||
CORRESPONDENT BANK CONTACTS:
|
||||
JPMORGAN: wire.ops@jpmc.com / 212-555-0147
|
||||
CITI: fedwire@citi.com / 212-555-0283`,
|
||||
},
|
||||
}
|
||||
|
||||
return &bankState{
|
||||
Accounts: accounts,
|
||||
Transactions: transactions,
|
||||
Messages: messages,
|
||||
}
|
||||
}
|
||||
|
||||
func generateCheckingTxns(now time.Time, endBalance int64) []Transaction {
|
||||
txns := []Transaction{
|
||||
{Description: "ACH DEPOSIT - PAYROLL", Amount: 485000},
|
||||
{Description: "CHECK #1847", Amount: -125000},
|
||||
{Description: "POS DEBIT - WHOLE FOODS #1284", Amount: -18743},
|
||||
{Description: "ATM WITHDRAWAL - MAIN ST BRANCH", Amount: -40000},
|
||||
{Description: "ACH DEBIT - MORTGAGE PMT", Amount: -215000},
|
||||
{Description: "WIRE TRANSFER IN - REF#8847201", Amount: 1250000},
|
||||
{Description: "POS DEBIT - SHELL OIL #4492", Amount: -6821},
|
||||
{Description: "ACH DEPOSIT - PAYROLL", Amount: 485000},
|
||||
{Description: "CHECK #1848", Amount: -75000},
|
||||
{Description: "ONLINE TRANSFER TO SAVINGS", Amount: -100000},
|
||||
{Description: "POS DEBIT - AMAZON.COM", Amount: -14599},
|
||||
{Description: "ACH DEBIT - ELECTRIC COMPANY", Amount: -18742},
|
||||
{Description: "ATM WITHDRAWAL - PARK AVE BRANCH", Amount: -20000},
|
||||
{Description: "WIRE TRANSFER OUT - REF#9014882", Amount: -500000},
|
||||
{Description: "POS DEBIT - COSTCO #0441", Amount: -28734},
|
||||
{Description: "ACH DEPOSIT - TAX REFUND", Amount: 342100},
|
||||
}
|
||||
return populateTransactions(txns, now, endBalance)
|
||||
}
|
||||
|
||||
func generateSavingsTxns(now time.Time, endBalance int64) []Transaction {
|
||||
txns := []Transaction{
|
||||
{Description: "INTEREST PAYMENT", Amount: 4521},
|
||||
{Description: "ONLINE TRANSFER FROM CHECKING", Amount: 100000},
|
||||
{Description: "INTEREST PAYMENT", Amount: 4633},
|
||||
{Description: "ACH DEPOSIT - DIVIDEND PMT", Amount: 125000},
|
||||
{Description: "ONLINE TRANSFER FROM CHECKING", Amount: 200000},
|
||||
{Description: "INTEREST PAYMENT", Amount: 4748},
|
||||
{Description: "WITHDRAWAL - TRANSFER TO MM", Amount: -500000},
|
||||
{Description: "INTEREST PAYMENT", Amount: 4812},
|
||||
}
|
||||
return populateTransactions(txns, now, endBalance)
|
||||
}
|
||||
|
||||
func generateMoneyMarketTxns(now time.Time, endBalance int64) []Transaction {
|
||||
txns := []Transaction{
|
||||
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 21847},
|
||||
{Description: "DEPOSIT - TRANSFER FROM SAVINGS", Amount: 500000},
|
||||
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 22105},
|
||||
{Description: "WITHDRAWAL - WIRE TRANSFER", Amount: -1000000},
|
||||
{Description: "DEPOSIT - ACH TRANSFER", Amount: 750000},
|
||||
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 22394},
|
||||
}
|
||||
return populateTransactions(txns, now, endBalance)
|
||||
}
|
||||
|
||||
func generateCDTxns(now time.Time, endBalance int64) []Transaction {
|
||||
txns := []Transaction{
|
||||
{Description: "CERTIFICATE OPENED - 12MO TERM", Amount: 25000000},
|
||||
{Description: "INTEREST ACCRUAL", Amount: 10417},
|
||||
{Description: "INTEREST ACCRUAL", Amount: 10417},
|
||||
{Description: "INTEREST ACCRUAL", Amount: 10417},
|
||||
}
|
||||
return populateTransactions(txns, now, endBalance)
|
||||
}
|
||||
|
||||
func populateTransactions(txns []Transaction, now time.Time, endBalance int64) []Transaction {
|
||||
// Work backwards from end balance to assign dates and running balances.
|
||||
bal := endBalance
|
||||
for i := len(txns) - 1; i >= 0; i-- {
|
||||
txns[i].Balance = bal
|
||||
txns[i].Date = now.Add(time.Duration(-(len(txns) - i)) * 3 * 24 * time.Hour).Format("01/02/2006")
|
||||
bal -= txns[i].Amount
|
||||
}
|
||||
return txns
|
||||
}
|
||||
353
internal/shell/banking/model.go
Normal file
353
internal/shell/banking/model.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
type screen int
|
||||
|
||||
const (
|
||||
screenLogin screen = iota
|
||||
screenMenu
|
||||
screenAccountSummary
|
||||
screenAccountDetail
|
||||
screenTransfer
|
||||
screenHistory
|
||||
screenMessages
|
||||
screenChangePin
|
||||
screenAdmin
|
||||
)
|
||||
|
||||
type model struct {
|
||||
sess *shell.SessionContext
|
||||
bankName string
|
||||
terminalID string
|
||||
region string
|
||||
state *bankState
|
||||
screen screen
|
||||
quitting bool
|
||||
height int
|
||||
|
||||
login loginModel
|
||||
menu menuModel
|
||||
summary accountSummaryModel
|
||||
detail accountDetailModel
|
||||
transfer transferModel
|
||||
history historyModel
|
||||
messages messagesModel
|
||||
admin adminModel
|
||||
changePin changePinModel
|
||||
}
|
||||
|
||||
func newModel(sess *shell.SessionContext, bankName, terminalID, region string) *model {
|
||||
state := newBankState()
|
||||
unread := 0
|
||||
for _, msg := range state.Messages {
|
||||
if msg.Unread {
|
||||
unread++
|
||||
}
|
||||
}
|
||||
|
||||
return &model{
|
||||
sess: sess,
|
||||
bankName: bankName,
|
||||
terminalID: terminalID,
|
||||
region: region,
|
||||
state: state,
|
||||
screen: screenLogin,
|
||||
login: newLoginModel(bankName),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.quitting {
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.height = msg.Height
|
||||
return m, nil
|
||||
case tea.KeyMsg:
|
||||
if msg.Type == tea.KeyCtrlC {
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
switch m.screen {
|
||||
case screenLogin:
|
||||
return m.updateLogin(msg)
|
||||
case screenMenu:
|
||||
return m.updateMenu(msg)
|
||||
case screenAccountSummary:
|
||||
return m.updateAccountSummary(msg)
|
||||
case screenAccountDetail:
|
||||
return m.updateAccountDetail(msg)
|
||||
case screenTransfer:
|
||||
return m.updateTransfer(msg)
|
||||
case screenHistory:
|
||||
return m.updateHistory(msg)
|
||||
case screenMessages:
|
||||
return m.updateMessages(msg)
|
||||
case screenChangePin:
|
||||
return m.updateChangePin(msg)
|
||||
case screenAdmin:
|
||||
return m.updateAdmin(msg)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) View() string {
|
||||
var content string
|
||||
switch m.screen {
|
||||
case screenLogin:
|
||||
content = m.login.View()
|
||||
case screenMenu:
|
||||
content = m.menu.View()
|
||||
case screenAccountSummary:
|
||||
content = m.summary.View()
|
||||
case screenAccountDetail:
|
||||
content = m.detail.View()
|
||||
case screenTransfer:
|
||||
content = m.transfer.View()
|
||||
case screenHistory:
|
||||
content = m.history.View()
|
||||
case screenMessages:
|
||||
content = m.messages.View()
|
||||
case screenChangePin:
|
||||
content = m.changePin.View()
|
||||
case screenAdmin:
|
||||
content = m.admin.View()
|
||||
}
|
||||
|
||||
return screenFrame(m.bankName, m.terminalID, m.region, content, m.height)
|
||||
}
|
||||
|
||||
// --- Screen update handlers ---
|
||||
|
||||
func (m *model) updateLogin(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
m.login, cmd = m.login.Update(msg)
|
||||
|
||||
if m.login.stage == 2 {
|
||||
// Login always succeeds — this is a honeypot.
|
||||
logCmd := logAction(m.sess, fmt.Sprintf("LOGIN acct=%s", m.login.accountNum), "ACCESS GRANTED")
|
||||
clearCmd := m.goToMenu()
|
||||
return m, tea.Batch(cmd, logCmd, clearCmd)
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateMenu(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
m.menu, cmd = m.menu.Update(msg)
|
||||
|
||||
if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.Type == tea.KeyEnter {
|
||||
choice := strings.TrimSpace(m.menu.choice)
|
||||
switch choice {
|
||||
case "1":
|
||||
m.screen = screenAccountSummary
|
||||
m.summary = newAccountSummaryModel(m.state.Accounts)
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 1", "ACCOUNT SUMMARY"))
|
||||
case "2":
|
||||
m.screen = screenAccountDetail
|
||||
m.detail = newAccountDetailModel(m.state.Accounts, m.state.Transactions)
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 2", "ACCOUNT DETAIL"))
|
||||
case "3":
|
||||
m.screen = screenTransfer
|
||||
m.transfer = newTransferModel()
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 3", "WIRE TRANSFER"))
|
||||
case "4":
|
||||
m.screen = screenHistory
|
||||
m.history = newHistoryModel(m.state.Accounts, m.state.Transactions)
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 4", "TRANSACTION HISTORY"))
|
||||
case "5":
|
||||
m.screen = screenMessages
|
||||
m.messages = newMessagesModel(m.state.Messages)
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 5", "SECURE MESSAGES"))
|
||||
case "6":
|
||||
m.screen = screenChangePin
|
||||
m.changePin = newChangePinModel()
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 6", "CHANGE PIN"))
|
||||
case "7":
|
||||
m.quitting = true
|
||||
return m, tea.Batch(logAction(m.sess, "LOGOUT", "SESSION ENDED"), tea.Quit)
|
||||
case "99", "admin", "ADMIN":
|
||||
m.screen = screenAdmin
|
||||
m.admin = newAdminModel()
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "ADMIN ACCESS ATTEMPT", "ADMIN SCREEN SHOWN"))
|
||||
}
|
||||
// Invalid choice, reset.
|
||||
m.menu.choice = ""
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateAccountSummary(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if _, ok := msg.(tea.KeyMsg); ok {
|
||||
return m, m.goToMenu()
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) updateAccountDetail(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
m.detail, cmd = m.detail.Update(msg)
|
||||
if m.detail.choice == "back" {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateTransfer(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
prevStep := m.transfer.step
|
||||
var cmd tea.Cmd
|
||||
m.transfer, cmd = m.transfer.Update(msg)
|
||||
|
||||
// Transfer cancelled.
|
||||
if m.transfer.confirm == "cancelled" {
|
||||
clearCmd := m.goToMenu()
|
||||
return m, tea.Batch(clearCmd, logAction(m.sess, "WIRE TRANSFER CANCELLED", "USER CANCELLED"))
|
||||
}
|
||||
|
||||
// Clear screen when transfer steps change content height significantly
|
||||
// (e.g. confirm→authcode, fields→confirm, authcode→complete).
|
||||
if m.transfer.step != prevStep {
|
||||
cmd = tea.Batch(cmd, tea.ClearScreen)
|
||||
}
|
||||
|
||||
// Transfer completed — log it.
|
||||
if m.transfer.step == transferStepComplete && prevStep != transferStepComplete {
|
||||
t := m.transfer.transfer
|
||||
m.state.Transfers = append(m.state.Transfers, t)
|
||||
logMsg := fmt.Sprintf("WIRE TRANSFER: routing=%s dest=%s beneficiary=%s bank=%s amount=%s memo=%s auth=%s",
|
||||
t.RoutingNumber, t.DestAccount, t.Beneficiary, t.BankName, t.Amount, t.Memo, t.AuthCode)
|
||||
return m, tea.Batch(cmd, logAction(m.sess, logMsg, "TRANSFER QUEUED"))
|
||||
}
|
||||
|
||||
// Completed screen → any key goes back.
|
||||
if m.transfer.step == transferStepComplete {
|
||||
if _, ok := msg.(tea.KeyMsg); ok && prevStep == transferStepComplete {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
}
|
||||
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateHistory(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
m.history, cmd = m.history.Update(msg)
|
||||
if m.history.choice == "back" {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateMessages(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
m.messages, cmd = m.messages.Update(msg)
|
||||
if m.messages.choice == "back" {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
// Log when viewing a message.
|
||||
if m.messages.viewing >= 0 {
|
||||
idx := m.messages.viewing
|
||||
return m, tea.Batch(cmd, logAction(m.sess,
|
||||
fmt.Sprintf("VIEW MESSAGE #%d", idx+1),
|
||||
m.state.Messages[idx].Subj))
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateChangePin(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
prevStage := m.changePin.stage
|
||||
var cmd tea.Cmd
|
||||
m.changePin, cmd = m.changePin.Update(msg)
|
||||
|
||||
// Log successful PIN change.
|
||||
if m.changePin.stage == 3 && prevStage != 3 {
|
||||
cmd = tea.Batch(cmd, logAction(m.sess, "CHANGE PIN", "PIN CHANGED SUCCESSFULLY"))
|
||||
}
|
||||
|
||||
if m.changePin.done {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateAdmin(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
prevLocked := m.admin.locked
|
||||
prevAttempts := m.admin.attempts
|
||||
|
||||
// Check for ESC before delegating.
|
||||
if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.Type == tea.KeyEscape && !m.admin.locked {
|
||||
return m, m.goToMenu()
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.admin, cmd = m.admin.Update(msg)
|
||||
|
||||
// Log failed attempts.
|
||||
if m.admin.attempts > prevAttempts {
|
||||
cmd = tea.Batch(cmd, logAction(m.sess,
|
||||
fmt.Sprintf("ADMIN PIN ATTEMPT #%d", m.admin.attempts),
|
||||
"INVALID CREDENTIALS"))
|
||||
}
|
||||
|
||||
// Log lockout.
|
||||
if m.admin.locked && !prevLocked {
|
||||
cmd = tea.Batch(cmd, tea.ClearScreen, logAction(m.sess,
|
||||
"ADMIN LOCKOUT",
|
||||
"TERMINAL LOCKED - INCIDENT LOGGED"))
|
||||
}
|
||||
|
||||
// If locked and any key pressed, go back.
|
||||
if m.admin.locked {
|
||||
if _, ok := msg.(tea.KeyMsg); ok && prevLocked {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
}
|
||||
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) goToMenu() tea.Cmd {
|
||||
unread := 0
|
||||
for _, msg := range m.state.Messages {
|
||||
if msg.Unread {
|
||||
unread++
|
||||
}
|
||||
}
|
||||
m.screen = screenMenu
|
||||
m.menu = newMenuModel(m.bankName, unread)
|
||||
return tea.ClearScreen
|
||||
}
|
||||
|
||||
// logAction returns a tea.Cmd that logs an action to the session store.
|
||||
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if sess.Store != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("banking")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
213
internal/shell/banking/screen_accounts.go
Normal file
213
internal/shell/banking/screen_accounts.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
// --- Account Summary ---
|
||||
|
||||
type accountSummaryModel struct {
|
||||
accounts []Account
|
||||
}
|
||||
|
||||
func newAccountSummaryModel(accounts []Account) accountSummaryModel {
|
||||
return accountSummaryModel{accounts: accounts}
|
||||
}
|
||||
|
||||
func (m accountSummaryModel) Update(_ tea.Msg) (accountSummaryModel, tea.Cmd) {
|
||||
// Any key returns to menu.
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m accountSummaryModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("ACCOUNT SUMMARY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Header.
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-18s %18s", "ACCOUNT", "TYPE", "BALANCE")))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 50)))
|
||||
b.WriteString("\n")
|
||||
|
||||
total := int64(0)
|
||||
for _, acct := range m.accounts {
|
||||
total += acct.Balance
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-18s %18s",
|
||||
acct.Number, acct.Type, formatCurrency(acct.Balance))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 50)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-18s %18s", "", "TOTAL", formatCurrency(total))))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// --- Account Detail (with transactions) ---
|
||||
|
||||
type accountDetailModel struct {
|
||||
accounts []Account
|
||||
transactions map[string][]Transaction
|
||||
selected int
|
||||
page int
|
||||
pageSize int
|
||||
choosing bool
|
||||
choice string
|
||||
}
|
||||
|
||||
func newAccountDetailModel(accounts []Account, transactions map[string][]Transaction) accountDetailModel {
|
||||
return accountDetailModel{
|
||||
accounts: accounts,
|
||||
transactions: transactions,
|
||||
choosing: true,
|
||||
pageSize: 10,
|
||||
}
|
||||
}
|
||||
|
||||
func (m accountDetailModel) Update(msg tea.Msg) (accountDetailModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.choosing {
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
for i := range m.accounts {
|
||||
if m.choice == fmt.Sprintf("%d", i+1) {
|
||||
m.selected = i
|
||||
m.choosing = false
|
||||
m.page = 0
|
||||
break
|
||||
}
|
||||
}
|
||||
if m.choice == "0" {
|
||||
m.choice = "back"
|
||||
}
|
||||
return m, nil
|
||||
case tea.KeyBackspace:
|
||||
if len(m.choice) > 0 {
|
||||
m.choice = m.choice[:len(m.choice)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' {
|
||||
m.choice += ch
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch keyMsg.String() {
|
||||
case "n", "N":
|
||||
acctNum := m.accounts[m.selected].Number
|
||||
txns := m.transactions[acctNum]
|
||||
maxPage := (len(txns) - 1) / m.pageSize
|
||||
if m.page < maxPage {
|
||||
m.page++
|
||||
}
|
||||
case "p", "P":
|
||||
if m.page > 0 {
|
||||
m.page--
|
||||
}
|
||||
case "b", "B":
|
||||
m.choosing = true
|
||||
m.choice = ""
|
||||
default:
|
||||
m.choice = "back"
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m accountDetailModel) View() string {
|
||||
if m.choosing {
|
||||
return m.viewChooseAccount()
|
||||
}
|
||||
return m.viewDetail()
|
||||
}
|
||||
|
||||
func (m accountDetailModel) viewChooseAccount() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("ACCOUNT DETAIL"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" SELECT ACCOUNT:"))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
for i, acct := range m.accounts {
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d] %s - %s %s",
|
||||
i+1, acct.Number, acct.Type, formatCurrency(acct.Balance))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
|
||||
b.WriteString(inputStyle.Render(m.choice))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m accountDetailModel) viewDetail() string {
|
||||
var b strings.Builder
|
||||
|
||||
acct := m.accounts[m.selected]
|
||||
txns := m.transactions[acct.Number]
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText(fmt.Sprintf("ACCOUNT DETAIL - %s", acct.Number)))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" TYPE: %s", acct.Type)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" BALANCE: %s", formatCurrency(acct.Balance))))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Header.
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-34s %12s %12s", "DATE", "DESCRIPTION", "AMOUNT", "BALANCE")))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 72)))
|
||||
b.WriteString("\n")
|
||||
|
||||
// Paginate.
|
||||
start := m.page * m.pageSize
|
||||
end := min(start+m.pageSize, len(txns))
|
||||
|
||||
for _, txn := range txns[start:end] {
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-34s %12s %12s",
|
||||
txn.Date, txn.Description, formatCurrency(txn.Amount), formatCurrency(txn.Balance))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
totalPages := (len(txns) + m.pageSize - 1) / m.pageSize
|
||||
b.WriteString(dimStyle.Render(fmt.Sprintf(" PAGE %d OF %d", m.page+1, totalPages)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" [N]EXT PAGE [P]REV PAGE [B]ACK ANY OTHER KEY = MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
111
internal/shell/banking/screen_admin.go
Normal file
111
internal/shell/banking/screen_admin.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type adminModel struct {
|
||||
pin string
|
||||
attempts int
|
||||
locked bool
|
||||
}
|
||||
|
||||
func newAdminModel() adminModel {
|
||||
return adminModel{}
|
||||
}
|
||||
|
||||
func (m adminModel) Update(msg tea.Msg) (adminModel, tea.Cmd) {
|
||||
if m.locked {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
if m.pin != "" {
|
||||
m.attempts++
|
||||
if m.attempts >= 3 {
|
||||
m.locked = true
|
||||
}
|
||||
m.pin = ""
|
||||
}
|
||||
case tea.KeyBackspace:
|
||||
if len(m.pin) > 0 {
|
||||
m.pin = m.pin[:len(m.pin)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.pin) < 20 {
|
||||
m.pin += ch
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m adminModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("SYSTEM ADMINISTRATION"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if m.locked {
|
||||
b.WriteString(errorStyle.Render(" *** ACCESS DENIED ***"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(errorStyle.Render(" MAXIMUM AUTHENTICATION ATTEMPTS EXCEEDED"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(errorStyle.Render(" TERMINAL LOCKED - INCIDENT LOGGED"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(" SECURITY ALERT HAS BEEN DISPATCHED TO:"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" - INFORMATION SECURITY DEPT"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" - BRANCH SECURITY OFFICER"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" - FEDERAL RESERVE OVERSIGHT"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" INCIDENT REF: SEC-%d-ADMIN-BRUTE", 20240000+m.attempts)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" IEF4271I UNAUTHORIZED ACCESS ATTEMPT - ABEND S0C4"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" IEF4272I JOB SECADMIN STEP0001 - COND CODE 4088"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
} else {
|
||||
b.WriteString(titleStyle.Render(" RESTRICTED ACCESS - ADMINISTRATOR ONLY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(" THIS FUNCTION REQUIRES LEVEL 5 SECURITY CLEARANCE."))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" ALL ACCESS ATTEMPTS ARE LOGGED AND AUDITED."))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if m.attempts > 0 {
|
||||
b.WriteString(errorStyle.Render(fmt.Sprintf(" INVALID CREDENTIALS (%d OF 3 ATTEMPTS)", m.attempts)))
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
|
||||
b.WriteString(titleStyle.Render(" ADMIN PIN: "))
|
||||
masked := strings.Repeat("*", len(m.pin))
|
||||
b.WriteString(inputStyle.Render(masked))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
111
internal/shell/banking/screen_changepin.go
Normal file
111
internal/shell/banking/screen_changepin.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type changePinModel struct {
|
||||
input string
|
||||
stage int // 0=old, 1=new, 2=confirm, 3=done
|
||||
newPin string
|
||||
done bool
|
||||
}
|
||||
|
||||
func newChangePinModel() changePinModel {
|
||||
return changePinModel{}
|
||||
}
|
||||
|
||||
func (m changePinModel) Update(msg tea.Msg) (changePinModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.stage == 3 {
|
||||
m.done = true
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
switch m.stage {
|
||||
case 0:
|
||||
if m.input != "" {
|
||||
m.stage = 1
|
||||
m.input = ""
|
||||
}
|
||||
case 1:
|
||||
if len(m.input) >= 4 {
|
||||
m.newPin = m.input
|
||||
m.stage = 2
|
||||
m.input = ""
|
||||
}
|
||||
case 2:
|
||||
if m.input == m.newPin {
|
||||
m.stage = 3
|
||||
} else {
|
||||
m.input = ""
|
||||
m.newPin = ""
|
||||
m.stage = 1
|
||||
}
|
||||
}
|
||||
case tea.KeyEscape:
|
||||
m.done = true
|
||||
case tea.KeyBackspace:
|
||||
if len(m.input) > 0 {
|
||||
m.input = m.input[:len(m.input)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.input) < 12 {
|
||||
m.input += ch
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m changePinModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("CHANGE PIN"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if m.stage == 3 {
|
||||
b.WriteString(titleStyle.Render(" PIN CHANGED SUCCESSFULLY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(" YOUR NEW PIN IS NOW ACTIVE."))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" PLEASE USE YOUR NEW PIN FOR ALL FUTURE TRANSACTIONS."))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||
} else {
|
||||
prompts := []string{" CURRENT PIN: ", " NEW PIN: ", " CONFIRM PIN: "}
|
||||
for i := 0; i < m.stage; i++ {
|
||||
b.WriteString(baseStyle.Render(prompts[i]))
|
||||
b.WriteString(baseStyle.Render(strings.Repeat("*", 4)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if m.stage < 3 {
|
||||
b.WriteString(titleStyle.Render(prompts[m.stage]))
|
||||
masked := strings.Repeat("*", len(m.input))
|
||||
b.WriteString(inputStyle.Render(masked))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
if m.stage == 1 {
|
||||
b.WriteString(dimStyle.Render(" PIN MUST BE AT LEAST 4 CHARACTERS"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
152
internal/shell/banking/screen_history.go
Normal file
152
internal/shell/banking/screen_history.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type historyModel struct {
|
||||
accounts []Account
|
||||
transactions map[string][]Transaction
|
||||
selected int
|
||||
page int
|
||||
pageSize int
|
||||
choosing bool
|
||||
choice string
|
||||
}
|
||||
|
||||
func newHistoryModel(accounts []Account, transactions map[string][]Transaction) historyModel {
|
||||
return historyModel{
|
||||
accounts: accounts,
|
||||
transactions: transactions,
|
||||
choosing: true,
|
||||
pageSize: 12,
|
||||
}
|
||||
}
|
||||
|
||||
func (m historyModel) Update(msg tea.Msg) (historyModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.choosing {
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
for i := range m.accounts {
|
||||
if m.choice == fmt.Sprintf("%d", i+1) {
|
||||
m.selected = i
|
||||
m.choosing = false
|
||||
m.page = 0
|
||||
break
|
||||
}
|
||||
}
|
||||
if m.choice == "0" {
|
||||
m.choice = "back"
|
||||
}
|
||||
return m, nil
|
||||
case tea.KeyBackspace:
|
||||
if len(m.choice) > 0 {
|
||||
m.choice = m.choice[:len(m.choice)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' {
|
||||
m.choice += ch
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch keyMsg.String() {
|
||||
case "n", "N":
|
||||
acctNum := m.accounts[m.selected].Number
|
||||
txns := m.transactions[acctNum]
|
||||
maxPage := (len(txns) - 1) / m.pageSize
|
||||
if m.page < maxPage {
|
||||
m.page++
|
||||
}
|
||||
case "p", "P":
|
||||
if m.page > 0 {
|
||||
m.page--
|
||||
}
|
||||
case "b", "B":
|
||||
m.choosing = true
|
||||
m.choice = ""
|
||||
default:
|
||||
m.choice = "back"
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m historyModel) View() string {
|
||||
if m.choosing {
|
||||
return m.viewChooseAccount()
|
||||
}
|
||||
return m.viewHistory()
|
||||
}
|
||||
|
||||
func (m historyModel) viewChooseAccount() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("TRANSACTION HISTORY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" SELECT ACCOUNT:"))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
for i, acct := range m.accounts {
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d] %s - %s",
|
||||
i+1, acct.Number, acct.Type)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
|
||||
b.WriteString(inputStyle.Render(m.choice))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m historyModel) viewHistory() string {
|
||||
var b strings.Builder
|
||||
|
||||
acct := m.accounts[m.selected]
|
||||
txns := m.transactions[acct.Number]
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText(fmt.Sprintf("TRANSACTION HISTORY - %s (%s)", acct.Number, acct.Type)))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-40s %14s", "DATE", "DESCRIPTION", "AMOUNT")))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 68)))
|
||||
b.WriteString("\n")
|
||||
|
||||
start := m.page * m.pageSize
|
||||
end := min(start+m.pageSize, len(txns))
|
||||
|
||||
for _, txn := range txns[start:end] {
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-40s %14s",
|
||||
txn.Date, txn.Description, formatCurrency(txn.Amount))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
totalPages := (len(txns) + m.pageSize - 1) / m.pageSize
|
||||
b.WriteString(dimStyle.Render(fmt.Sprintf(" PAGE %d OF %d", m.page+1, totalPages)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" [N]EXT PAGE [P]REV PAGE [B]ACK ANY OTHER KEY = MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
103
internal/shell/banking/screen_login.go
Normal file
103
internal/shell/banking/screen_login.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type loginModel struct {
|
||||
accountNum string
|
||||
pin string
|
||||
stage int // 0 = account, 1 = pin, 2 = authenticating
|
||||
bankName string
|
||||
}
|
||||
|
||||
func newLoginModel(bankName string) loginModel {
|
||||
return loginModel{bankName: bankName}
|
||||
}
|
||||
|
||||
func (m loginModel) Update(msg tea.Msg) (loginModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
if m.stage == 0 && m.accountNum != "" {
|
||||
m.stage = 1
|
||||
return m, nil
|
||||
}
|
||||
if m.stage == 1 && m.pin != "" {
|
||||
m.stage = 2
|
||||
return m, nil
|
||||
}
|
||||
case tea.KeyBackspace:
|
||||
if m.stage == 0 && len(m.accountNum) > 0 {
|
||||
m.accountNum = m.accountNum[:len(m.accountNum)-1]
|
||||
} else if m.stage == 1 && len(m.pin) > 0 {
|
||||
m.pin = m.pin[:len(m.pin)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 {
|
||||
if m.stage == 0 && len(m.accountNum) < 20 {
|
||||
m.accountNum += ch
|
||||
} else if m.stage == 1 && len(m.pin) < 12 {
|
||||
m.pin += ch
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m loginModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText(m.bankName + " ONLINE BANKING"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("AUTHORIZED ACCESS ONLY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(titleStyle.Render(" ACCOUNT NUMBER: "))
|
||||
if m.stage == 0 {
|
||||
b.WriteString(inputStyle.Render(m.accountNum))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
} else {
|
||||
b.WriteString(baseStyle.Render(m.accountNum))
|
||||
}
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if m.stage >= 1 {
|
||||
b.WriteString(titleStyle.Render(" PIN: "))
|
||||
masked := strings.Repeat("*", len(m.pin))
|
||||
if m.stage == 1 {
|
||||
b.WriteString(inputStyle.Render(masked))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
} else {
|
||||
b.WriteString(baseStyle.Render(masked))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
if m.stage == 2 {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" AUTHENTICATING..."))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" WARNING: UNAUTHORIZED ACCESS TO THIS SYSTEM IS A FEDERAL CRIME"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(fmt.Sprintf(" UNDER 18 U.S.C. %s 1030. ALL ACTIVITY IS MONITORED AND LOGGED.", "\u00A7")))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
80
internal/shell/banking/screen_menu.go
Normal file
80
internal/shell/banking/screen_menu.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type menuModel struct {
|
||||
choice string
|
||||
unread int
|
||||
bankName string
|
||||
}
|
||||
|
||||
func newMenuModel(bankName string, unreadCount int) menuModel {
|
||||
return menuModel{bankName: bankName, unread: unreadCount}
|
||||
}
|
||||
|
||||
func (m menuModel) Update(msg tea.Msg) (menuModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
return m, nil
|
||||
case tea.KeyBackspace:
|
||||
if len(m.choice) > 0 {
|
||||
m.choice = m.choice[:len(m.choice)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 {
|
||||
if len(m.choice) < 10 {
|
||||
m.choice += ch
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m menuModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("MAIN MENU"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
items := []struct {
|
||||
num string
|
||||
desc string
|
||||
}{
|
||||
{"1", "ACCOUNT SUMMARY"},
|
||||
{"2", "ACCOUNT DETAIL / TRANSACTIONS"},
|
||||
{"3", "WIRE TRANSFER"},
|
||||
{"4", "TRANSACTION HISTORY"},
|
||||
{"5", fmt.Sprintf("SECURE MESSAGES (%d UNREAD)", m.unread)},
|
||||
{"6", "CHANGE PIN"},
|
||||
{"7", "LOGOUT"},
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%s] %s", item.num, item.desc)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
|
||||
b.WriteString(inputStyle.Render(m.choice))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
122
internal/shell/banking/screen_messages.go
Normal file
122
internal/shell/banking/screen_messages.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type messagesModel struct {
|
||||
messages []SecureMessage
|
||||
viewing int // -1 = list, >= 0 = detail
|
||||
choice string
|
||||
}
|
||||
|
||||
func newMessagesModel(messages []SecureMessage) messagesModel {
|
||||
return messagesModel{messages: messages, viewing: -1}
|
||||
}
|
||||
|
||||
func (m messagesModel) Update(msg tea.Msg) (messagesModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.viewing >= 0 {
|
||||
// In detail view, any key goes back to list.
|
||||
m.viewing = -1
|
||||
m.choice = ""
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
if m.choice == "0" {
|
||||
m.choice = "back"
|
||||
return m, nil
|
||||
}
|
||||
for i := range m.messages {
|
||||
if m.choice == fmt.Sprintf("%d", i+1) {
|
||||
m.viewing = i
|
||||
m.messages[i].Unread = false
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
m.choice = ""
|
||||
case tea.KeyBackspace:
|
||||
if len(m.choice) > 0 {
|
||||
m.choice = m.choice[:len(m.choice)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' && len(m.choice) < 2 {
|
||||
m.choice += ch
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m messagesModel) View() string {
|
||||
if m.viewing >= 0 {
|
||||
return m.viewDetail()
|
||||
}
|
||||
return m.viewList()
|
||||
}
|
||||
|
||||
func (m messagesModel) viewList() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("SECURE MESSAGES"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-4s %-3s %-12s %-22s %s", "#", "", "DATE", "FROM", "SUBJECT")))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 68)))
|
||||
b.WriteString("\n")
|
||||
|
||||
for i, msg := range m.messages {
|
||||
marker := " "
|
||||
if msg.Unread {
|
||||
marker = " * "
|
||||
}
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d]%s%-12s %-22s %s",
|
||||
i+1, marker, msg.Date, msg.From, msg.Subj)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" SELECT MESSAGE: "))
|
||||
b.WriteString(inputStyle.Render(m.choice))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m messagesModel) viewDetail() string {
|
||||
var b strings.Builder
|
||||
|
||||
msg := m.messages[m.viewing]
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText(fmt.Sprintf("MESSAGE #%d", m.viewing+1)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(msg.Body))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MESSAGE LIST"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
198
internal/shell/banking/screen_transfer.go
Normal file
198
internal/shell/banking/screen_transfer.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
const (
|
||||
transferStepRouting = iota
|
||||
transferStepDest
|
||||
transferStepBeneficiary
|
||||
transferStepBankName
|
||||
transferStepAmount
|
||||
transferStepMemo
|
||||
transferStepConfirm
|
||||
transferStepAuthCode
|
||||
transferStepComplete
|
||||
)
|
||||
|
||||
var transferPrompts = []string{
|
||||
" ROUTING NUMBER (ABA): ",
|
||||
" DESTINATION ACCOUNT: ",
|
||||
" BENEFICIARY NAME: ",
|
||||
" RECEIVING BANK NAME: ",
|
||||
" AMOUNT (USD): ",
|
||||
" MEMO / REFERENCE: ",
|
||||
"",
|
||||
" AUTHORIZATION CODE: ",
|
||||
}
|
||||
|
||||
type transferModel struct {
|
||||
step int
|
||||
fields [8]string // indexed by step
|
||||
transfer WireTransfer
|
||||
confirm string // y/n input for confirm step
|
||||
}
|
||||
|
||||
func newTransferModel() transferModel {
|
||||
return transferModel{}
|
||||
}
|
||||
|
||||
func (m transferModel) Update(msg tea.Msg) (transferModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.step == transferStepComplete {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.step == transferStepConfirm {
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
switch strings.ToUpper(m.confirm) {
|
||||
case "Y", "YES":
|
||||
m.step = transferStepAuthCode
|
||||
case "N", "NO":
|
||||
m.confirm = "cancelled"
|
||||
}
|
||||
return m, nil
|
||||
case tea.KeyBackspace:
|
||||
if len(m.confirm) > 0 {
|
||||
m.confirm = m.confirm[:len(m.confirm)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.confirm) < 3 {
|
||||
m.confirm += ch
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
val := strings.TrimSpace(m.fields[m.step])
|
||||
if val == "" {
|
||||
return m, nil
|
||||
}
|
||||
if m.step == transferStepAuthCode {
|
||||
m.transfer.AuthCode = val
|
||||
m.step = transferStepComplete
|
||||
return m, nil
|
||||
}
|
||||
switch m.step {
|
||||
case transferStepRouting:
|
||||
m.transfer.RoutingNumber = val
|
||||
case transferStepDest:
|
||||
m.transfer.DestAccount = val
|
||||
case transferStepBeneficiary:
|
||||
m.transfer.Beneficiary = val
|
||||
case transferStepBankName:
|
||||
m.transfer.BankName = val
|
||||
case transferStepAmount:
|
||||
m.transfer.Amount = val
|
||||
case transferStepMemo:
|
||||
m.transfer.Memo = val
|
||||
}
|
||||
m.step++
|
||||
return m, nil
|
||||
case tea.KeyBackspace:
|
||||
if len(m.fields[m.step]) > 0 {
|
||||
m.fields[m.step] = m.fields[m.step][:len(m.fields[m.step])-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.fields[m.step]) < 40 {
|
||||
m.fields[m.step] += ch
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m transferModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("WIRE TRANSFER"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Show completed fields.
|
||||
for i := 0; i < m.step && i < len(transferPrompts); i++ {
|
||||
if i == transferStepConfirm {
|
||||
continue
|
||||
}
|
||||
prompt := transferPrompts[i]
|
||||
if prompt == "" {
|
||||
continue
|
||||
}
|
||||
b.WriteString(baseStyle.Render(prompt))
|
||||
b.WriteString(baseStyle.Render(m.fields[i]))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Current field.
|
||||
switch {
|
||||
case m.step == transferStepConfirm:
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" === TRANSFER SUMMARY ==="))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" ROUTING: %s", m.transfer.RoutingNumber)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" ACCOUNT: %s", m.transfer.DestAccount)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" BENEFICIARY: %s", m.transfer.Beneficiary)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" BANK: %s", m.transfer.BankName)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" AMOUNT: $%s", m.transfer.Amount)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" MEMO: %s", m.transfer.Memo)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" CONFIRM TRANSFER? (Y/N): "))
|
||||
b.WriteString(inputStyle.Render(m.confirm))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
case m.step == transferStepComplete:
|
||||
b.WriteString("\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" TRANSFER QUEUED FOR PROCESSING"))
|
||||
b.WriteString("\n\n")
|
||||
routing := m.transfer.RoutingNumber
|
||||
if len(routing) > 4 {
|
||||
routing = routing[:4]
|
||||
}
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" CONFIRMATION #: WR-%s-%s",
|
||||
routing, "847291")))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" STATUS: PENDING FEDWIRE SETTLEMENT"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" ESTIMATED COMPLETION: NEXT BUSINESS DAY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
case m.step < len(transferPrompts):
|
||||
prompt := transferPrompts[m.step]
|
||||
b.WriteString(titleStyle.Render(prompt))
|
||||
b.WriteString(inputStyle.Render(m.fields[m.step]))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
if m.step < transferStepConfirm {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(fmt.Sprintf(" STEP %d OF 6", m.step+1)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
168
internal/shell/banking/style.go
Normal file
168
internal/shell/banking/style.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const termWidth = 80
|
||||
|
||||
// Color palette — green-on-black retro terminal.
|
||||
var (
|
||||
colorGreen = lipgloss.Color("#00FF00")
|
||||
colorDim = lipgloss.Color("#007700")
|
||||
colorBlack = lipgloss.Color("#000000")
|
||||
colorBright = lipgloss.Color("#AAFFAA")
|
||||
colorRed = lipgloss.Color("#FF3333")
|
||||
)
|
||||
|
||||
// Reusable styles.
|
||||
var (
|
||||
baseStyle = lipgloss.NewStyle().
|
||||
Foreground(colorGreen).
|
||||
Background(colorBlack)
|
||||
|
||||
headerStyle = lipgloss.NewStyle().
|
||||
Foreground(colorBright).
|
||||
Background(colorBlack).
|
||||
Bold(true).
|
||||
Width(termWidth).
|
||||
Align(lipgloss.Center)
|
||||
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Foreground(colorGreen).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
|
||||
dimStyle = lipgloss.NewStyle().
|
||||
Foreground(colorDim).
|
||||
Background(colorBlack)
|
||||
|
||||
errorStyle = lipgloss.NewStyle().
|
||||
Foreground(colorRed).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
|
||||
inputStyle = lipgloss.NewStyle().
|
||||
Foreground(colorBright).
|
||||
Background(colorBlack)
|
||||
)
|
||||
|
||||
// divider returns an 80-column === line.
|
||||
func divider() string {
|
||||
return dimStyle.Render(strings.Repeat("=", termWidth))
|
||||
}
|
||||
|
||||
// thinDivider returns an 80-column --- line.
|
||||
func thinDivider() string {
|
||||
return dimStyle.Render(strings.Repeat("-", termWidth))
|
||||
}
|
||||
|
||||
// centerText centers text within 80 columns.
|
||||
func centerText(s string) string {
|
||||
return headerStyle.Render(s)
|
||||
}
|
||||
|
||||
// padRight pads a string to the given width.
|
||||
func padRight(s string, width int) string {
|
||||
if len(s) >= width {
|
||||
return s[:width]
|
||||
}
|
||||
return s + strings.Repeat(" ", width-len(s))
|
||||
}
|
||||
|
||||
// formatCurrency formats cents as $X,XXX.XX
|
||||
func formatCurrency(cents int64) string {
|
||||
negative := cents < 0
|
||||
if negative {
|
||||
cents = -cents
|
||||
}
|
||||
dollars := cents / 100
|
||||
remainder := cents % 100
|
||||
|
||||
// Add thousands separators.
|
||||
ds := fmt.Sprintf("%d", dollars)
|
||||
if len(ds) > 3 {
|
||||
var parts []string
|
||||
for len(ds) > 3 {
|
||||
parts = append([]string{ds[len(ds)-3:]}, parts...)
|
||||
ds = ds[:len(ds)-3]
|
||||
}
|
||||
parts = append([]string{ds}, parts...)
|
||||
ds = strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
if negative {
|
||||
return fmt.Sprintf("-$%s.%02d", ds, remainder)
|
||||
}
|
||||
return fmt.Sprintf("$%s.%02d", ds, remainder)
|
||||
}
|
||||
|
||||
// padLine pads a single line (which may contain ANSI codes) to termWidth
|
||||
// using its visual width. Padding uses a black background so the terminal's
|
||||
// default background doesn't bleed through.
|
||||
func padLine(line string) string {
|
||||
w := lipgloss.Width(line)
|
||||
if w >= termWidth {
|
||||
return line
|
||||
}
|
||||
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
|
||||
}
|
||||
|
||||
// padLines pads every line in a multi-line string to termWidth so that
|
||||
// shorter lines fully overwrite previous content in the terminal.
|
||||
func padLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = padLine(line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// screenFrame wraps content in the persistent header and footer.
|
||||
// The height parameter is used to pad the output to fill the terminal,
|
||||
// preventing leftover lines from previous renders bleeding through.
|
||||
func screenFrame(bankName, terminalID, region, content string, height int) string {
|
||||
var b strings.Builder
|
||||
|
||||
// Header (4 lines).
|
||||
b.WriteString(divider())
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText(bankName + " FEDERAL RESERVE SYSTEM"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("SECURE BANKING TERMINAL"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(divider())
|
||||
b.WriteString("\n")
|
||||
|
||||
// Content.
|
||||
b.WriteString(content)
|
||||
|
||||
// Pad with blank lines between content and footer so the footer
|
||||
// stays at the bottom and the total output fills the terminal height.
|
||||
if height > 0 {
|
||||
const headerLines = 4
|
||||
const footerLines = 2
|
||||
// strings.Count gives newlines; add 1 for the line after the last \n.
|
||||
contentLines := strings.Count(content, "\n") + 1
|
||||
used := headerLines + contentLines + footerLines
|
||||
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
|
||||
for i := used; i < height; i++ {
|
||||
b.WriteString(blankLine)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Footer (2 lines).
|
||||
b.WriteString("\n")
|
||||
b.WriteString(divider())
|
||||
b.WriteString("\n")
|
||||
footer := fmt.Sprintf(" TERMINAL: %s | REGION: %s | ENCRYPTED SESSION ACTIVE", terminalID, region)
|
||||
b.WriteString(dimStyle.Render(padRight(footer, termWidth)))
|
||||
|
||||
// Pad every line to full terminal width so shorter lines overwrite
|
||||
// leftover content from previous renders.
|
||||
return padLines(b.String())
|
||||
}
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
@@ -86,6 +86,9 @@ func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw i
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("bash")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
@@ -116,7 +116,7 @@ func TestReadLineCtrlD(t *testing.T) {
|
||||
|
||||
func TestBashShellHandle(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash")
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
@@ -166,7 +166,7 @@ func TestBashShellHandle(t *testing.T) {
|
||||
|
||||
func TestBashShellFakeUser(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash")
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
|
||||
206
internal/shell/cisco/cisco.go
Normal file
206
internal/shell/cisco/cisco.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package cisco
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// CiscoShell emulates a Cisco IOS CLI.
|
||||
type CiscoShell struct{}
|
||||
|
||||
// NewCiscoShell returns a new CiscoShell instance.
|
||||
func NewCiscoShell() *CiscoShell {
|
||||
return &CiscoShell{}
|
||||
}
|
||||
|
||||
func (c *CiscoShell) Name() string { return "cisco" }
|
||||
func (c *CiscoShell) Description() string { return "Cisco IOS CLI emulator" }
|
||||
|
||||
func (c *CiscoShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
hostname := configString(sess.ShellConfig, "hostname", "Router")
|
||||
model := configString(sess.ShellConfig, "model", "C2960")
|
||||
iosVersion := configString(sess.ShellConfig, "ios_version", "15.0(2)SE11")
|
||||
enablePass := configString(sess.ShellConfig, "enable_password", "")
|
||||
|
||||
state := newIOSState(hostname, model, iosVersion, enablePass)
|
||||
|
||||
// IOS just shows a blank line then the prompt after SSH auth.
|
||||
fmt.Fprint(rw, "\r\n")
|
||||
|
||||
for {
|
||||
prompt := state.prompt()
|
||||
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for Ctrl+Z (^Z) — return to privileged exec.
|
||||
if trimmed == "\x1a" || trimmed == "^Z" {
|
||||
if state.mode == modeGlobalConfig || state.mode == modeInterfaceConfig {
|
||||
state.mode = modePrivilegedExec
|
||||
state.currentIf = ""
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle "enable" specially — it needs password prompting.
|
||||
if state.mode == modeUserExec && isEnableCommand(trimmed) {
|
||||
output := handleEnable(ctx, state, rw)
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("cisco")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
result := state.dispatch(trimmed)
|
||||
|
||||
var output string
|
||||
if result.output != "" {
|
||||
output = result.output
|
||||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("cisco")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isEnableCommand checks if input resolves to "enable" in user exec mode.
|
||||
func isEnableCommand(input string) bool {
|
||||
words := strings.Fields(input)
|
||||
if len(words) != 1 {
|
||||
return false
|
||||
}
|
||||
w := strings.ToLower(words[0])
|
||||
enable := "enable"
|
||||
return len(w) >= 2 && len(w) <= len(enable) && enable[:len(w)] == w
|
||||
}
|
||||
|
||||
// handleEnable manages the enable password prompt flow.
|
||||
// Returns the output string (for logging).
|
||||
func handleEnable(ctx context.Context, state *iosState, rw io.ReadWriter) string {
|
||||
const maxAttempts = 3
|
||||
hadFailure := false
|
||||
|
||||
for range maxAttempts {
|
||||
fmt.Fprint(rw, "Password: ")
|
||||
password, err := readPassword(ctx, rw)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
fmt.Fprint(rw, "\r\n")
|
||||
|
||||
if state.enablePass == "" {
|
||||
// No password configured — accept after one failed attempt.
|
||||
if hadFailure {
|
||||
state.mode = modePrivilegedExec
|
||||
return ""
|
||||
}
|
||||
hadFailure = true
|
||||
} else if password == state.enablePass {
|
||||
state.mode = modePrivilegedExec
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
output := "% Bad passwords"
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
return output
|
||||
}
|
||||
|
||||
// readPassword reads a password without echoing characters.
|
||||
func readPassword(ctx context.Context, rw io.ReadWriter) (string, error) {
|
||||
var buf []byte
|
||||
b := make([]byte, 1)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := rw.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
ch := b[0]
|
||||
switch {
|
||||
case ch == '\r' || ch == '\n':
|
||||
return string(buf), nil
|
||||
case ch == 4: // Ctrl+D
|
||||
return string(buf), io.EOF
|
||||
case ch == 3: // Ctrl+C
|
||||
return "", io.EOF
|
||||
case ch == 127 || ch == 8: // Backspace/DEL
|
||||
if len(buf) > 0 {
|
||||
buf = buf[:len(buf)-1]
|
||||
}
|
||||
case ch == 27: // ESC sequence
|
||||
next := make([]byte, 1)
|
||||
if n, _ := rw.Read(next); n > 0 && next[0] == '[' {
|
||||
rw.Read(next)
|
||||
}
|
||||
case ch >= 32 && ch < 127:
|
||||
buf = append(buf, ch)
|
||||
// Don't echo.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
531
internal/shell/cisco/cisco_test.go
Normal file
531
internal/shell/cisco/cisco_test.go
Normal file
@@ -0,0 +1,531 @@
|
||||
package cisco
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- Abbreviation resolution tests ---
|
||||
|
||||
func TestResolveAbbreviationExact(t *testing.T) {
|
||||
entries := []commandEntry{
|
||||
{name: "show"},
|
||||
{name: "shutdown"},
|
||||
}
|
||||
got, err := resolveAbbreviation("show", entries)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "show" {
|
||||
t.Errorf("got %q, want %q", got, "show")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAbbreviationUnique(t *testing.T) {
|
||||
entries := []commandEntry{
|
||||
{name: "show"},
|
||||
{name: "enable"},
|
||||
{name: "exit"},
|
||||
}
|
||||
got, err := resolveAbbreviation("sh", entries)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "show" {
|
||||
t.Errorf("got %q, want %q", got, "show")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAbbreviationAmbiguous(t *testing.T) {
|
||||
entries := []commandEntry{
|
||||
{name: "show"},
|
||||
{name: "shutdown"},
|
||||
}
|
||||
_, err := resolveAbbreviation("sh", entries)
|
||||
if err == nil {
|
||||
t.Fatal("expected ambiguous error, got nil")
|
||||
}
|
||||
if err.Error() != "ambiguous" {
|
||||
t.Errorf("got error %q, want %q", err.Error(), "ambiguous")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAbbreviationUnknown(t *testing.T) {
|
||||
entries := []commandEntry{
|
||||
{name: "show"},
|
||||
{name: "enable"},
|
||||
}
|
||||
_, err := resolveAbbreviation("xyz", entries)
|
||||
if err == nil {
|
||||
t.Fatal("expected unknown error, got nil")
|
||||
}
|
||||
if err.Error() != "unknown" {
|
||||
t.Errorf("got error %q, want %q", err.Error(), "unknown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAbbreviationCaseInsensitive(t *testing.T) {
|
||||
entries := []commandEntry{
|
||||
{name: "show"},
|
||||
{name: "enable"},
|
||||
}
|
||||
got, err := resolveAbbreviation("SH", entries)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "show" {
|
||||
t.Errorf("got %q, want %q", got, "show")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Multi-word command resolution tests ---
|
||||
|
||||
func TestResolveCommandShowRunningConfig(t *testing.T) {
|
||||
resolved, args, err := resolveCommand([]string{"sh", "run"}, privilegedExecCommands)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(args) != 0 {
|
||||
t.Errorf("unexpected args: %v", args)
|
||||
}
|
||||
want := []string{"show", "running-config"}
|
||||
if len(resolved) != len(want) {
|
||||
t.Fatalf("resolved = %v, want %v", resolved, want)
|
||||
}
|
||||
for i := range want {
|
||||
if resolved[i] != want[i] {
|
||||
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCommandConfigureTerminal(t *testing.T) {
|
||||
resolved, _, err := resolveCommand([]string{"conf", "t"}, privilegedExecCommands)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := []string{"configure", "terminal"}
|
||||
if len(resolved) != len(want) {
|
||||
t.Fatalf("resolved = %v, want %v", resolved, want)
|
||||
}
|
||||
for i := range want {
|
||||
if resolved[i] != want[i] {
|
||||
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCommandShowIPInterfaceBrief(t *testing.T) {
|
||||
resolved, _, err := resolveCommand([]string{"sh", "ip", "int", "br"}, privilegedExecCommands)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := []string{"show", "ip", "interface", "brief"}
|
||||
if len(resolved) != len(want) {
|
||||
t.Fatalf("resolved = %v, want %v", resolved, want)
|
||||
}
|
||||
for i := range want {
|
||||
if resolved[i] != want[i] {
|
||||
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCommandWithArgs(t *testing.T) {
|
||||
// "hostname MyRouter" → resolved=["hostname"], args=["MyRouter"]
|
||||
resolved, args, err := resolveCommand([]string{"hostname", "MyRouter"}, globalConfigCommands)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(resolved) != 1 || resolved[0] != "hostname" {
|
||||
t.Errorf("resolved = %v, want [hostname]", resolved)
|
||||
}
|
||||
if len(args) != 1 || args[0] != "MyRouter" {
|
||||
t.Errorf("args = %v, want [MyRouter]", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCommandAmbiguous(t *testing.T) {
|
||||
// In user exec, "e" matches "enable" and "exit" — ambiguous
|
||||
_, _, err := resolveCommand([]string{"e"}, userExecCommands)
|
||||
if err == nil {
|
||||
t.Fatal("expected ambiguous error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Mode state machine tests ---
|
||||
|
||||
func TestPromptGeneration(t *testing.T) {
|
||||
tests := []struct {
|
||||
mode iosMode
|
||||
want string
|
||||
}{
|
||||
{modeUserExec, "Router>"},
|
||||
{modePrivilegedExec, "Router#"},
|
||||
{modeGlobalConfig, "Router(config)#"},
|
||||
{modeInterfaceConfig, "Router(config-if)#"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = tt.mode
|
||||
if got := s.prompt(); got != tt.want {
|
||||
t.Errorf("prompt(%d) = %q, want %q", tt.mode, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptAfterHostnameChange(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modeGlobalConfig
|
||||
s.dispatch("hostname Switch1")
|
||||
if s.hostname != "Switch1" {
|
||||
t.Fatalf("hostname = %q, want %q", s.hostname, "Switch1")
|
||||
}
|
||||
if got := s.prompt(); got != "Switch1(config)#" {
|
||||
t.Errorf("prompt = %q, want %q", got, "Switch1(config)#")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModeTransitions(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
|
||||
// Start in user exec.
|
||||
if s.mode != modeUserExec {
|
||||
t.Fatalf("initial mode = %d, want %d", s.mode, modeUserExec)
|
||||
}
|
||||
|
||||
// Can't skip to config mode directly from user exec.
|
||||
result := s.dispatch("configure terminal")
|
||||
if result.output == "" {
|
||||
t.Error("expected error for conf t in user exec mode")
|
||||
}
|
||||
|
||||
// Manually set privileged mode (enable tested separately).
|
||||
s.mode = modePrivilegedExec
|
||||
|
||||
// conf t → global config
|
||||
s.dispatch("configure terminal")
|
||||
if s.mode != modeGlobalConfig {
|
||||
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
|
||||
}
|
||||
|
||||
// interface Gi0/0 → interface config
|
||||
s.dispatch("interface GigabitEthernet0/0")
|
||||
if s.mode != modeInterfaceConfig {
|
||||
t.Errorf("mode after interface = %d, want %d", s.mode, modeInterfaceConfig)
|
||||
}
|
||||
|
||||
// exit → back to global config
|
||||
s.dispatch("exit")
|
||||
if s.mode != modeGlobalConfig {
|
||||
t.Errorf("mode after exit from if-config = %d, want %d", s.mode, modeGlobalConfig)
|
||||
}
|
||||
|
||||
// end → back to privileged exec
|
||||
s.dispatch("end")
|
||||
if s.mode != modePrivilegedExec {
|
||||
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
|
||||
}
|
||||
|
||||
// disable → back to user exec
|
||||
s.dispatch("disable")
|
||||
if s.mode != modeUserExec {
|
||||
t.Errorf("mode after disable = %d, want %d", s.mode, modeUserExec)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndFromInterfaceConfig(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modeInterfaceConfig
|
||||
s.currentIf = "GigabitEthernet0/0"
|
||||
|
||||
s.dispatch("end")
|
||||
if s.mode != modePrivilegedExec {
|
||||
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
|
||||
}
|
||||
if s.currentIf != "" {
|
||||
t.Errorf("currentIf = %q, want empty", s.currentIf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitFromPrivilegedExec(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modePrivilegedExec
|
||||
result := s.dispatch("exit")
|
||||
if !result.exit {
|
||||
t.Error("expected exit=true from privileged exec exit")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Show command output tests ---
|
||||
|
||||
func TestShowVersionContainsModel(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
output := showVersion(s)
|
||||
if !contains(output, "C2960") {
|
||||
t.Error("show version missing model")
|
||||
}
|
||||
if !contains(output, "15.0(2)SE11") {
|
||||
t.Error("show version missing IOS version")
|
||||
}
|
||||
if !contains(output, "Router") {
|
||||
t.Error("show version missing hostname")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowRunningConfigContainsInterfaces(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
output := showRunningConfig(s)
|
||||
if !contains(output, "hostname Router") {
|
||||
t.Error("running-config missing hostname")
|
||||
}
|
||||
if !contains(output, "interface GigabitEthernet0/0") {
|
||||
t.Error("running-config missing interface")
|
||||
}
|
||||
if !contains(output, "ip address 192.168.1.1") {
|
||||
t.Error("running-config missing IP address")
|
||||
}
|
||||
if !contains(output, "line vty") {
|
||||
t.Error("running-config missing VTY config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowRunningConfigWithEnableSecret(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "secret123")
|
||||
output := showRunningConfig(s)
|
||||
if !contains(output, "enable secret") {
|
||||
t.Error("running-config missing enable secret when password is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowRunningConfigWithoutEnableSecret(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
output := showRunningConfig(s)
|
||||
if contains(output, "enable secret") {
|
||||
t.Error("running-config should not have enable secret when password is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowIPInterfaceBrief(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
output := showIPInterfaceBrief(s)
|
||||
if !contains(output, "GigabitEthernet0/0") {
|
||||
t.Error("ip interface brief missing GigabitEthernet0/0")
|
||||
}
|
||||
if !contains(output, "192.168.1.1") {
|
||||
t.Error("ip interface brief missing 192.168.1.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowIPRoute(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
output := showIPRoute(s)
|
||||
if !contains(output, "directly connected") {
|
||||
t.Error("ip route missing connected routes")
|
||||
}
|
||||
if !contains(output, "0.0.0.0/0") {
|
||||
t.Error("ip route missing default route")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowVLANBrief(t *testing.T) {
|
||||
output := showVLANBrief()
|
||||
if !contains(output, "default") {
|
||||
t.Error("vlan brief missing default vlan")
|
||||
}
|
||||
if !contains(output, "MGMT") {
|
||||
t.Error("vlan brief missing MGMT vlan")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Interface config tests ---
|
||||
|
||||
func TestInterfaceShutdownNoShutdown(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modeInterfaceConfig
|
||||
s.currentIf = "GigabitEthernet0/0"
|
||||
|
||||
s.dispatch("shutdown")
|
||||
iface := s.findInterface("GigabitEthernet0/0")
|
||||
if iface == nil {
|
||||
t.Fatal("interface not found")
|
||||
}
|
||||
if !iface.shutdown {
|
||||
t.Error("interface should be shutdown")
|
||||
}
|
||||
if iface.status != "administratively down" {
|
||||
t.Errorf("status = %q, want %q", iface.status, "administratively down")
|
||||
}
|
||||
|
||||
s.dispatch("no shutdown")
|
||||
if iface.shutdown {
|
||||
t.Error("interface should not be shutdown after no shutdown")
|
||||
}
|
||||
if iface.status != "up" {
|
||||
t.Errorf("status = %q, want %q", iface.status, "up")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaceIPAddress(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modeInterfaceConfig
|
||||
s.currentIf = "GigabitEthernet0/0"
|
||||
|
||||
s.dispatch("ip address 10.10.10.1 255.255.255.0")
|
||||
iface := s.findInterface("GigabitEthernet0/0")
|
||||
if iface == nil {
|
||||
t.Fatal("interface not found")
|
||||
}
|
||||
if iface.ip != "10.10.10.1" {
|
||||
t.Errorf("ip = %q, want %q", iface.ip, "10.10.10.1")
|
||||
}
|
||||
if iface.mask != "255.255.255.0" {
|
||||
t.Errorf("mask = %q, want %q", iface.mask, "255.255.255.0")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Dispatch / invalid command tests ---
|
||||
|
||||
func TestInvalidCommandInUserExec(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
result := s.dispatch("foobar")
|
||||
if !contains(result.output, "Invalid input") {
|
||||
t.Errorf("expected invalid input error, got %q", result.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmbiguousCommandOutput(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
// "e" in user exec is ambiguous (enable, exit)
|
||||
result := s.dispatch("e")
|
||||
if !contains(result.output, "Ambiguous") {
|
||||
t.Errorf("expected ambiguous error, got %q", result.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelpCommand(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
result := s.dispatch("?")
|
||||
if !contains(result.output, "show") {
|
||||
t.Error("help missing 'show'")
|
||||
}
|
||||
if !contains(result.output, "enable") {
|
||||
t.Error("help missing 'enable'")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Abbreviation integration tests ---
|
||||
|
||||
func TestShowAbbreviationInDispatch(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modePrivilegedExec
|
||||
result := s.dispatch("sh ver")
|
||||
if !contains(result.output, "Cisco IOS Software") {
|
||||
t.Error("'sh ver' should produce version output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfTAbbreviation(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modePrivilegedExec
|
||||
s.dispatch("conf t")
|
||||
if s.mode != modeGlobalConfig {
|
||||
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Enable command detection ---
|
||||
|
||||
func TestIsEnableCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"enable", true},
|
||||
{"en", true},
|
||||
{"ena", true},
|
||||
{"e", false}, // too short (single char could be other commands)
|
||||
{"enab", true},
|
||||
{"ENABLE", true},
|
||||
{"exit", false},
|
||||
{"enable 15", false}, // has extra argument
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := isEnableCommand(tt.input); got != tt.want {
|
||||
t.Errorf("isEnableCommand(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- configString tests ---
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{"hostname": "MySwitch"}
|
||||
if got := configString(cfg, "hostname", "Router"); got != "MySwitch" {
|
||||
t.Errorf("configString() = %q, want %q", got, "MySwitch")
|
||||
}
|
||||
if got := configString(cfg, "missing", "Default"); got != "Default" {
|
||||
t.Errorf("configString() for missing = %q, want %q", got, "Default")
|
||||
}
|
||||
if got := configString(nil, "key", "Default"); got != "Default" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "Default")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper ---
|
||||
|
||||
func TestMaskBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
mask string
|
||||
want int
|
||||
}{
|
||||
{"255.255.255.0", 24},
|
||||
{"255.255.255.252", 30},
|
||||
{"255.255.0.0", 16},
|
||||
{"255.0.0.0", 8},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := maskBits(tt.mask); got != tt.want {
|
||||
t.Errorf("maskBits(%q) = %d, want %d", tt.mask, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkFromIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip, mask, want string
|
||||
}{
|
||||
{"192.168.1.1", "255.255.255.0", "192.168.1.0"},
|
||||
{"10.0.0.1", "255.255.255.252", "10.0.0.0"},
|
||||
{"172.16.5.100", "255.255.0.0", "172.16.0.0"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := networkFromIP(tt.ip, tt.mask); got != tt.want {
|
||||
t.Errorf("networkFromIP(%q, %q) = %q, want %q", tt.ip, tt.mask, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Shell metadata ---
|
||||
|
||||
func TestShellNameAndDescription(t *testing.T) {
|
||||
s := NewCiscoShell()
|
||||
if s.Name() != "cisco" {
|
||||
t.Errorf("Name() = %q, want %q", s.Name(), "cisco")
|
||||
}
|
||||
if s.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && containsHelper(s, substr)
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
414
internal/shell/cisco/commands.go
Normal file
414
internal/shell/cisco/commands.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package cisco
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// commandResult holds the output of a command and whether the session should end.
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
// commandEntry defines a single command with its name and optional sub-commands.
|
||||
type commandEntry struct {
|
||||
name string
|
||||
subs []commandEntry // nil for leaf commands
|
||||
}
|
||||
|
||||
// userExecCommands defines the command tree for user EXEC mode.
|
||||
var userExecCommands = []commandEntry{
|
||||
{name: "show", subs: []commandEntry{
|
||||
{name: "version"},
|
||||
{name: "clock"},
|
||||
{name: "ip", subs: []commandEntry{
|
||||
{name: "route"},
|
||||
{name: "interface", subs: []commandEntry{
|
||||
{name: "brief"},
|
||||
}},
|
||||
}},
|
||||
{name: "interfaces"},
|
||||
{name: "vlan", subs: []commandEntry{
|
||||
{name: "brief"},
|
||||
}},
|
||||
}},
|
||||
{name: "enable"},
|
||||
{name: "exit"},
|
||||
{name: "?"},
|
||||
}
|
||||
|
||||
// privilegedExecCommands extends user commands for privileged mode.
|
||||
var privilegedExecCommands = []commandEntry{
|
||||
{name: "show", subs: []commandEntry{
|
||||
{name: "version"},
|
||||
{name: "clock"},
|
||||
{name: "ip", subs: []commandEntry{
|
||||
{name: "route"},
|
||||
{name: "interface", subs: []commandEntry{
|
||||
{name: "brief"},
|
||||
}},
|
||||
}},
|
||||
{name: "interfaces"},
|
||||
{name: "running-config"},
|
||||
{name: "startup-config"},
|
||||
{name: "vlan", subs: []commandEntry{
|
||||
{name: "brief"},
|
||||
}},
|
||||
}},
|
||||
{name: "configure", subs: []commandEntry{
|
||||
{name: "terminal"},
|
||||
}},
|
||||
{name: "write", subs: []commandEntry{
|
||||
{name: "memory"},
|
||||
}},
|
||||
{name: "copy"},
|
||||
{name: "reload"},
|
||||
{name: "disable"},
|
||||
{name: "terminal", subs: []commandEntry{
|
||||
{name: "length"},
|
||||
}},
|
||||
{name: "exit"},
|
||||
{name: "?"},
|
||||
}
|
||||
|
||||
// globalConfigCommands defines the command tree for global config mode.
|
||||
var globalConfigCommands = []commandEntry{
|
||||
{name: "hostname"},
|
||||
{name: "interface"},
|
||||
{name: "ip", subs: []commandEntry{
|
||||
{name: "route"},
|
||||
}},
|
||||
{name: "no"},
|
||||
{name: "end"},
|
||||
{name: "exit"},
|
||||
{name: "?"},
|
||||
}
|
||||
|
||||
// interfaceConfigCommands defines the command tree for interface config mode.
|
||||
var interfaceConfigCommands = []commandEntry{
|
||||
{name: "ip", subs: []commandEntry{
|
||||
{name: "address"},
|
||||
}},
|
||||
{name: "description"},
|
||||
{name: "shutdown"},
|
||||
{name: "no", subs: []commandEntry{
|
||||
{name: "shutdown"},
|
||||
}},
|
||||
{name: "switchport", subs: []commandEntry{
|
||||
{name: "mode"},
|
||||
}},
|
||||
{name: "end"},
|
||||
{name: "exit"},
|
||||
{name: "?"},
|
||||
}
|
||||
|
||||
// commandsForMode returns the command tree for the given IOS mode.
|
||||
func commandsForMode(mode iosMode) []commandEntry {
|
||||
switch mode {
|
||||
case modeUserExec:
|
||||
return userExecCommands
|
||||
case modePrivilegedExec:
|
||||
return privilegedExecCommands
|
||||
case modeGlobalConfig:
|
||||
return globalConfigCommands
|
||||
case modeInterfaceConfig:
|
||||
return interfaceConfigCommands
|
||||
default:
|
||||
return userExecCommands
|
||||
}
|
||||
}
|
||||
|
||||
// resolveAbbreviation attempts to match an abbreviated word against a list of
|
||||
// command entries. It returns the matched entry name, or an error string if
|
||||
// ambiguous or unknown.
|
||||
func resolveAbbreviation(word string, entries []commandEntry) (string, error) {
|
||||
word = strings.ToLower(word)
|
||||
var matches []string
|
||||
for _, e := range entries {
|
||||
if strings.ToLower(e.name) == word {
|
||||
return e.name, nil // exact match
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(e.name), word) {
|
||||
matches = append(matches, e.name)
|
||||
}
|
||||
}
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
return "", fmt.Errorf("unknown")
|
||||
case 1:
|
||||
return matches[0], nil
|
||||
default:
|
||||
return "", fmt.Errorf("ambiguous")
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCommand resolves a sequence of abbreviated words into the canonical
|
||||
// command path (e.g., ["sh", "run"] → ["show", "running-config"]).
|
||||
// It returns the resolved path, any remaining arguments, and an error if
|
||||
// resolution fails.
|
||||
func resolveCommand(words []string, entries []commandEntry) ([]string, []string, error) {
|
||||
var resolved []string
|
||||
current := entries
|
||||
|
||||
for i, w := range words {
|
||||
name, err := resolveAbbreviation(w, current)
|
||||
if err != nil {
|
||||
if err.Error() == "unknown" && len(resolved) > 0 {
|
||||
// Remaining words are arguments to the resolved command.
|
||||
return resolved, words[i:], nil
|
||||
}
|
||||
return resolved, words[i:], err
|
||||
}
|
||||
resolved = append(resolved, name)
|
||||
|
||||
// Find sub-commands for the matched entry.
|
||||
var nextLevel []commandEntry
|
||||
for _, e := range current {
|
||||
if e.name == name {
|
||||
nextLevel = e.subs
|
||||
break
|
||||
}
|
||||
}
|
||||
if nextLevel == nil {
|
||||
// Leaf command — rest are arguments.
|
||||
return resolved, words[i+1:], nil
|
||||
}
|
||||
current = nextLevel
|
||||
}
|
||||
return resolved, nil, nil
|
||||
}
|
||||
|
||||
// dispatch processes a command line in the context of the current IOS state.
|
||||
func (s *iosState) dispatch(input string) commandResult {
|
||||
words := strings.Fields(input)
|
||||
if len(words) == 0 {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
// Handle "?" as a help request.
|
||||
if words[0] == "?" {
|
||||
return s.cmdHelp()
|
||||
}
|
||||
|
||||
cmds := commandsForMode(s.mode)
|
||||
resolved, args, err := resolveCommand(words, cmds)
|
||||
if err != nil {
|
||||
if err.Error() == "ambiguous" {
|
||||
return commandResult{output: fmt.Sprintf("%% Ambiguous command: \"%s\"", input)}
|
||||
}
|
||||
return commandResult{output: invalidInput(input)}
|
||||
}
|
||||
|
||||
if len(resolved) == 0 {
|
||||
return commandResult{output: invalidInput(input)}
|
||||
}
|
||||
|
||||
cmd := strings.Join(resolved, " ")
|
||||
|
||||
switch s.mode {
|
||||
case modeUserExec:
|
||||
return s.dispatchUserExec(cmd, args)
|
||||
case modePrivilegedExec:
|
||||
return s.dispatchPrivilegedExec(cmd, args)
|
||||
case modeGlobalConfig:
|
||||
return s.dispatchGlobalConfig(cmd, args)
|
||||
case modeInterfaceConfig:
|
||||
return s.dispatchInterfaceConfig(cmd, args)
|
||||
}
|
||||
return commandResult{output: invalidInput(input)}
|
||||
}
|
||||
|
||||
func (s *iosState) dispatchUserExec(cmd string, args []string) commandResult {
|
||||
switch cmd {
|
||||
case "show version":
|
||||
return commandResult{output: showVersion(s)}
|
||||
case "show clock":
|
||||
return commandResult{output: showClock()}
|
||||
case "show ip route":
|
||||
return commandResult{output: showIPRoute(s)}
|
||||
case "show ip interface brief":
|
||||
return commandResult{output: showIPInterfaceBrief(s)}
|
||||
case "show interfaces":
|
||||
return commandResult{output: showInterfaces(s)}
|
||||
case "show vlan brief":
|
||||
return commandResult{output: showVLANBrief()}
|
||||
case "enable":
|
||||
return commandResult{} // handled in Handle() loop
|
||||
case "exit":
|
||||
return commandResult{exit: true}
|
||||
}
|
||||
return commandResult{output: invalidInput(cmd)}
|
||||
}
|
||||
|
||||
func (s *iosState) dispatchPrivilegedExec(cmd string, args []string) commandResult {
|
||||
switch cmd {
|
||||
case "show version":
|
||||
return commandResult{output: showVersion(s)}
|
||||
case "show clock":
|
||||
return commandResult{output: showClock()}
|
||||
case "show ip route":
|
||||
return commandResult{output: showIPRoute(s)}
|
||||
case "show ip interface brief":
|
||||
return commandResult{output: showIPInterfaceBrief(s)}
|
||||
case "show interfaces":
|
||||
return commandResult{output: showInterfaces(s)}
|
||||
case "show running-config":
|
||||
return commandResult{output: showRunningConfig(s)}
|
||||
case "show startup-config":
|
||||
return commandResult{output: showRunningConfig(s)} // same as running
|
||||
case "show vlan brief":
|
||||
return commandResult{output: showVLANBrief()}
|
||||
case "configure terminal":
|
||||
s.mode = modeGlobalConfig
|
||||
return commandResult{output: "Enter configuration commands, one per line. End with CNTL/Z."}
|
||||
case "write memory":
|
||||
return commandResult{output: "[OK]"}
|
||||
case "copy":
|
||||
return commandResult{output: "[OK]"}
|
||||
case "reload":
|
||||
return commandResult{output: "System configuration has been modified. Save? [yes/no]: ", exit: true}
|
||||
case "disable":
|
||||
s.mode = modeUserExec
|
||||
return commandResult{}
|
||||
case "terminal length":
|
||||
return commandResult{} // accept silently
|
||||
case "exit":
|
||||
return commandResult{exit: true}
|
||||
}
|
||||
return commandResult{output: invalidInput(cmd)}
|
||||
}
|
||||
|
||||
func (s *iosState) dispatchGlobalConfig(cmd string, args []string) commandResult {
|
||||
switch cmd {
|
||||
case "hostname":
|
||||
if len(args) < 1 {
|
||||
return commandResult{output: "% Incomplete command."}
|
||||
}
|
||||
s.hostname = args[0]
|
||||
return commandResult{}
|
||||
case "interface":
|
||||
if len(args) < 1 {
|
||||
return commandResult{output: "% Incomplete command."}
|
||||
}
|
||||
ifName := strings.Join(args, "")
|
||||
s.currentIf = ifName
|
||||
s.mode = modeInterfaceConfig
|
||||
return commandResult{}
|
||||
case "ip route":
|
||||
return commandResult{} // accept silently
|
||||
case "no":
|
||||
return commandResult{} // accept silently
|
||||
case "end":
|
||||
s.mode = modePrivilegedExec
|
||||
return commandResult{}
|
||||
case "exit":
|
||||
s.mode = modePrivilegedExec
|
||||
return commandResult{}
|
||||
}
|
||||
return commandResult{output: invalidInput(cmd)}
|
||||
}
|
||||
|
||||
func (s *iosState) dispatchInterfaceConfig(cmd string, args []string) commandResult {
|
||||
switch cmd {
|
||||
case "ip address":
|
||||
if len(args) < 2 {
|
||||
return commandResult{output: "% Incomplete command."}
|
||||
}
|
||||
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||
iface.ip = args[0]
|
||||
iface.mask = args[1]
|
||||
}
|
||||
return commandResult{}
|
||||
case "description":
|
||||
if len(args) < 1 {
|
||||
return commandResult{output: "% Incomplete command."}
|
||||
}
|
||||
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||
iface.desc = strings.Join(args, " ")
|
||||
}
|
||||
return commandResult{}
|
||||
case "shutdown":
|
||||
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||
iface.shutdown = true
|
||||
iface.status = "administratively down"
|
||||
iface.protocol = "down"
|
||||
}
|
||||
return commandResult{}
|
||||
case "no shutdown":
|
||||
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||
iface.shutdown = false
|
||||
iface.status = "up"
|
||||
iface.protocol = "up"
|
||||
}
|
||||
return commandResult{}
|
||||
case "switchport mode":
|
||||
return commandResult{} // accept silently
|
||||
case "end":
|
||||
s.mode = modePrivilegedExec
|
||||
s.currentIf = ""
|
||||
return commandResult{}
|
||||
case "exit":
|
||||
s.mode = modeGlobalConfig
|
||||
s.currentIf = ""
|
||||
return commandResult{}
|
||||
}
|
||||
return commandResult{output: invalidInput(cmd)}
|
||||
}
|
||||
|
||||
func (s *iosState) cmdHelp() commandResult {
|
||||
cmds := commandsForMode(s.mode)
|
||||
var b strings.Builder
|
||||
for _, e := range cmds {
|
||||
if e.name == "?" {
|
||||
continue
|
||||
}
|
||||
b.WriteString(fmt.Sprintf(" %-20s %s\n", e.name, helpText(e.name)))
|
||||
}
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func helpText(name string) string {
|
||||
switch name {
|
||||
case "show":
|
||||
return "Show running system information"
|
||||
case "enable":
|
||||
return "Turn on privileged commands"
|
||||
case "disable":
|
||||
return "Turn off privileged commands"
|
||||
case "exit":
|
||||
return "Exit from the EXEC"
|
||||
case "configure":
|
||||
return "Enter configuration mode"
|
||||
case "write":
|
||||
return "Write running configuration to memory"
|
||||
case "copy":
|
||||
return "Copy from one file to another"
|
||||
case "reload":
|
||||
return "Halt and perform a cold restart"
|
||||
case "terminal":
|
||||
return "Set terminal line parameters"
|
||||
case "hostname":
|
||||
return "Set system's network name"
|
||||
case "interface":
|
||||
return "Select an interface to configure"
|
||||
case "ip":
|
||||
return "Global IP configuration subcommands"
|
||||
case "no":
|
||||
return "Negate a command or set its defaults"
|
||||
case "end":
|
||||
return "Exit from configure mode"
|
||||
case "description":
|
||||
return "Interface specific description"
|
||||
case "shutdown":
|
||||
return "Shutdown the selected interface"
|
||||
case "switchport":
|
||||
return "Set switching mode characteristics"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func invalidInput(input string) string {
|
||||
return fmt.Sprintf("%% Invalid input detected at '^' marker.\n\n%s\n^", input)
|
||||
}
|
||||
234
internal/shell/cisco/output.go
Normal file
234
internal/shell/cisco/output.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package cisco
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func showVersion(s *iosState) string {
|
||||
days := 14 + rand.Intn(350)
|
||||
hours := rand.Intn(24)
|
||||
mins := rand.Intn(60)
|
||||
|
||||
return fmt.Sprintf(`Cisco IOS Software, %s Software (%s-UNIVERSALK9-M), Version %s, RELEASE SOFTWARE (fc3)
|
||||
Technical Support: http://www.cisco.com/techsupport
|
||||
Copyright (c) 1986-2019 by Cisco Systems, Inc.
|
||||
Compiled Thu 30-Jan-19 10:08 by prod_rel_team
|
||||
|
||||
ROM: Bootstrap program is %s boot loader
|
||||
BOOTLDR: %s Boot Loader (C2960-HBOOT-M) Version 15.0(2r)SE, RELEASE SOFTWARE (fc1)
|
||||
|
||||
%s uptime is %d days, %d hours, %d minutes
|
||||
System returned to ROM by power-on
|
||||
System image file is "flash:/%s-universalk9-mz.SPA.%s.bin"
|
||||
|
||||
This product contains cryptographic features and is subject to United States
|
||||
and local country laws governing import, export, transfer and use.
|
||||
|
||||
cisco %s (%s) processor (revision K0) with 524288K bytes of memory.
|
||||
Processor board ID %s
|
||||
Last reset from power-on
|
||||
2 Gigabit Ethernet interfaces
|
||||
1 Virtual Ethernet interface
|
||||
64K bytes of flash-simulated non-volatile configuration memory.
|
||||
Total of 65536K bytes of APC System Flash (Read/Write)
|
||||
|
||||
Configuration register is 0x2102`,
|
||||
s.model, s.model, s.iosVersion,
|
||||
s.model, s.model,
|
||||
s.hostname, days, hours, mins,
|
||||
s.model, s.iosVersion,
|
||||
s.model, processorForModel(s.model),
|
||||
s.serial,
|
||||
)
|
||||
}
|
||||
|
||||
func processorForModel(model string) string {
|
||||
if strings.HasPrefix(model, "C29") {
|
||||
return "PowerPC405"
|
||||
}
|
||||
return "MIPS"
|
||||
}
|
||||
|
||||
func showClock() string {
|
||||
now := time.Now().UTC()
|
||||
return fmt.Sprintf("*%s UTC", now.Format("15:04:05.000 Mon Jan 2 2006"))
|
||||
}
|
||||
|
||||
func showIPRoute(s *iosState) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("Codes: C - connected, S - static, R - RIP, M - mobile, B - BGP\n")
|
||||
b.WriteString(" D - EIGRP, EX - EIGRP external, O - OSPF, IA - OSPF inter area\n")
|
||||
b.WriteString(" N1 - OSPF NSSA external type 1, N2 - OSPF NSSA external type 2\n")
|
||||
b.WriteString(" E1 - OSPF external type 1, E2 - OSPF external type 2\n")
|
||||
b.WriteString(" i - IS-IS, su - IS-IS summary, L1 - IS-IS level-1, L2 - IS-IS level-2\n")
|
||||
b.WriteString(" ia - IS-IS inter area, * - candidate default, U - per-user static route\n")
|
||||
b.WriteString(" o - ODR, P - periodic downloaded static route\n\n")
|
||||
b.WriteString("Gateway of last resort is 10.0.0.2 to network 0.0.0.0\n\n")
|
||||
|
||||
for _, iface := range s.interfaces {
|
||||
if iface.ip == "unassigned" || iface.status != "up" {
|
||||
continue
|
||||
}
|
||||
network := networkFromIP(iface.ip, iface.mask)
|
||||
maskBits := maskBits(iface.mask)
|
||||
fmt.Fprintf(&b, "C %s/%d is directly connected, %s\n", network, maskBits, iface.name)
|
||||
}
|
||||
b.WriteString("S* 0.0.0.0/0 [1/0] via 10.0.0.2")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func showIPInterfaceBrief(s *iosState) string {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "%-25s %-15s %-4s %-7s %-22s %s\n",
|
||||
"Interface", "IP-Address", "OK?", "Method", "Status", "Protocol")
|
||||
for _, iface := range s.interfaces {
|
||||
ip := iface.ip
|
||||
if ip == "" {
|
||||
ip = "unassigned"
|
||||
}
|
||||
fmt.Fprintf(&b, "%-25s %-15s YES manual %-22s %s\n",
|
||||
iface.name, ip, iface.status, iface.protocol)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func showInterfaces(s *iosState) string {
|
||||
var b strings.Builder
|
||||
for i, iface := range s.interfaces {
|
||||
if i > 0 {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
upDown := "up"
|
||||
if iface.shutdown {
|
||||
upDown = "administratively down"
|
||||
}
|
||||
fmt.Fprintf(&b, "%s is %s, line protocol is %s\n", iface.name, upDown, iface.protocol)
|
||||
fmt.Fprintf(&b, " Hardware is Gigabit Ethernet, address is %s (bia %s)\n", iface.mac, iface.mac)
|
||||
if iface.ip != "unassigned" && iface.ip != "" {
|
||||
fmt.Fprintf(&b, " Internet address is %s/%d\n", iface.ip, maskBits(iface.mask))
|
||||
}
|
||||
fmt.Fprintf(&b, " MTU %d bytes, BW %s sec, DLY 10 usec,\n", iface.mtu, iface.bandwidth)
|
||||
b.WriteString(" reliability 255/255, txload 1/255, rxload 1/255\n")
|
||||
b.WriteString(" Encapsulation ARPA, loopback not set\n")
|
||||
fmt.Fprintf(&b, " %d packets input, %d bytes, 0 no buffer\n", iface.rxPackets, iface.rxBytes)
|
||||
fmt.Fprintf(&b, " %d packets output, %d bytes, 0 underruns", iface.txPackets, iface.txBytes)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func showRunningConfig(s *iosState) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("Building configuration...\n\n")
|
||||
b.WriteString("Current configuration : 1482 bytes\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("! Last configuration change at 14:32:22 UTC Mon Feb 10 2025\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("version 15.0\n")
|
||||
b.WriteString("service timestamps debug datetime msec\n")
|
||||
b.WriteString("service timestamps log datetime msec\n")
|
||||
b.WriteString("no service password-encryption\n")
|
||||
b.WriteString("!\n")
|
||||
fmt.Fprintf(&b, "hostname %s\n", s.hostname)
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("boot-start-marker\n")
|
||||
b.WriteString("boot-end-marker\n")
|
||||
b.WriteString("!\n")
|
||||
if s.enablePass != "" {
|
||||
b.WriteString("enable secret 5 $1$mERr$hx5rVt7rPNoS4wqbXKX7m0\n")
|
||||
}
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("no aaa new-model\n")
|
||||
b.WriteString("!\n")
|
||||
|
||||
for _, iface := range s.interfaces {
|
||||
b.WriteString("!\n")
|
||||
fmt.Fprintf(&b, "interface %s\n", iface.name)
|
||||
if iface.desc != "" {
|
||||
fmt.Fprintf(&b, " description %s\n", iface.desc)
|
||||
}
|
||||
if iface.ip != "unassigned" && iface.ip != "" {
|
||||
fmt.Fprintf(&b, " ip address %s %s\n", iface.ip, iface.mask)
|
||||
} else {
|
||||
b.WriteString(" no ip address\n")
|
||||
}
|
||||
if iface.shutdown {
|
||||
b.WriteString(" shutdown\n")
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("ip forward-protocol nd\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("ip route 0.0.0.0 0.0.0.0 10.0.0.2\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("access-list 10 permit 192.168.1.0 0.0.0.255\n")
|
||||
b.WriteString("access-list 10 deny any\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("line con 0\n")
|
||||
b.WriteString(" logging synchronous\n")
|
||||
b.WriteString("line vty 0 4\n")
|
||||
b.WriteString(" login local\n")
|
||||
b.WriteString(" transport input ssh\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("end")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func showVLANBrief() string {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "VLAN", "Name", "Status", "Ports")
|
||||
b.WriteString("---- -------------------------------- --------- -------------------------------\n")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1", "default", "active", "Gi0/0, Gi0/1, Gi0/2")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "10", "MGMT", "active", "")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "20", "USERS", "active", "")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "99", "NATIVE", "active", "")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1002", "fddi-default", "act/unsup", "")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1003", "token-ring-default", "act/unsup", "")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s", "1004", "fddinet-default", "act/unsup", "")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// networkFromIP derives the network address from an IP and mask.
|
||||
func networkFromIP(ip, mask string) string {
|
||||
ipParts := parseIPv4(ip)
|
||||
maskParts := parseIPv4(mask)
|
||||
if ipParts == nil || maskParts == nil {
|
||||
return ip
|
||||
}
|
||||
return fmt.Sprintf("%d.%d.%d.%d",
|
||||
ipParts[0]&maskParts[0],
|
||||
ipParts[1]&maskParts[1],
|
||||
ipParts[2]&maskParts[2],
|
||||
ipParts[3]&maskParts[3],
|
||||
)
|
||||
}
|
||||
|
||||
func maskBits(mask string) int {
|
||||
parts := parseIPv4(mask)
|
||||
if parts == nil {
|
||||
return 24
|
||||
}
|
||||
bits := 0
|
||||
for _, p := range parts {
|
||||
for i := 7; i >= 0; i-- {
|
||||
if p&(1<<uint(i)) != 0 {
|
||||
bits++
|
||||
} else {
|
||||
return bits
|
||||
}
|
||||
}
|
||||
}
|
||||
return bits
|
||||
}
|
||||
|
||||
func parseIPv4(s string) []int {
|
||||
var a, b, c, d int
|
||||
n, _ := fmt.Sscanf(s, "%d.%d.%d.%d", &a, &b, &c, &d)
|
||||
if n != 4 {
|
||||
return nil
|
||||
}
|
||||
return []int{a, b, c, d}
|
||||
}
|
||||
109
internal/shell/cisco/state.go
Normal file
109
internal/shell/cisco/state.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package cisco
|
||||
|
||||
import "fmt"
|
||||
|
||||
// iosMode represents the current CLI mode of the IOS state machine.
|
||||
type iosMode int
|
||||
|
||||
const (
|
||||
modeUserExec iosMode = iota // Router>
|
||||
modePrivilegedExec // Router#
|
||||
modeGlobalConfig // Router(config)#
|
||||
modeInterfaceConfig // Router(config-if)#
|
||||
)
|
||||
|
||||
// ifaceInfo holds interface metadata for show commands.
|
||||
type ifaceInfo struct {
|
||||
name string
|
||||
ip string
|
||||
mask string
|
||||
status string
|
||||
protocol string
|
||||
mac string
|
||||
bandwidth string
|
||||
mtu int
|
||||
rxPackets int
|
||||
txPackets int
|
||||
rxBytes int
|
||||
txBytes int
|
||||
shutdown bool
|
||||
desc string
|
||||
}
|
||||
|
||||
// iosState holds all mutable state for the Cisco IOS shell session.
|
||||
type iosState struct {
|
||||
mode iosMode
|
||||
hostname string
|
||||
model string
|
||||
iosVersion string
|
||||
serial string
|
||||
enablePass string
|
||||
interfaces []ifaceInfo
|
||||
currentIf string
|
||||
}
|
||||
|
||||
func newIOSState(hostname, model, iosVersion, enablePass string) *iosState {
|
||||
return &iosState{
|
||||
mode: modeUserExec,
|
||||
hostname: hostname,
|
||||
model: model,
|
||||
iosVersion: iosVersion,
|
||||
serial: "FTX1524Z0P3",
|
||||
enablePass: enablePass,
|
||||
interfaces: defaultInterfaces(),
|
||||
}
|
||||
}
|
||||
|
||||
func defaultInterfaces() []ifaceInfo {
|
||||
return []ifaceInfo{
|
||||
{
|
||||
name: "GigabitEthernet0/0", ip: "192.168.1.1", mask: "255.255.255.0",
|
||||
status: "up", protocol: "up", mac: "0050.7966.6800",
|
||||
bandwidth: "1000000 Kbit", mtu: 1500,
|
||||
rxPackets: 148253, txPackets: 93127, rxBytes: 19284732, txBytes: 8291043,
|
||||
},
|
||||
{
|
||||
name: "GigabitEthernet0/1", ip: "10.0.0.1", mask: "255.255.255.252",
|
||||
status: "up", protocol: "up", mac: "0050.7966.6801",
|
||||
bandwidth: "1000000 Kbit", mtu: 1500,
|
||||
rxPackets: 52104, txPackets: 48891, rxBytes: 4182934, txBytes: 3901284,
|
||||
},
|
||||
{
|
||||
name: "GigabitEthernet0/2", ip: "unassigned", mask: "",
|
||||
status: "administratively down", protocol: "down", mac: "0050.7966.6802",
|
||||
bandwidth: "1000000 Kbit", mtu: 1500, shutdown: true,
|
||||
},
|
||||
{
|
||||
name: "Vlan1", ip: "172.16.0.1", mask: "255.255.0.0",
|
||||
status: "up", protocol: "up", mac: "0050.7966.6810",
|
||||
bandwidth: "1000000 Kbit", mtu: 1500,
|
||||
rxPackets: 8421, txPackets: 7103, rxBytes: 512384, txBytes: 423901,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// prompt returns the IOS prompt string for the current mode.
|
||||
func (s *iosState) prompt() string {
|
||||
switch s.mode {
|
||||
case modeUserExec:
|
||||
return fmt.Sprintf("%s>", s.hostname)
|
||||
case modePrivilegedExec:
|
||||
return fmt.Sprintf("%s#", s.hostname)
|
||||
case modeGlobalConfig:
|
||||
return fmt.Sprintf("%s(config)#", s.hostname)
|
||||
case modeInterfaceConfig:
|
||||
return fmt.Sprintf("%s(config-if)#", s.hostname)
|
||||
default:
|
||||
return fmt.Sprintf("%s>", s.hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// findInterface returns a pointer to the interface with the given name, or nil.
|
||||
func (s *iosState) findInterface(name string) *ifaceInfo {
|
||||
for i := range s.interfaces {
|
||||
if s.interfaces[i].name == name {
|
||||
return &s.interfaces[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// EventRecorder buffers I/O events in memory and periodically flushes them to
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func TestEventRecorderFlush(t *testing.T) {
|
||||
@@ -14,7 +14,7 @@ func TestEventRecorderFlush(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session so events have a valid session ID.
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
@@ -55,7 +55,7 @@ func TestEventRecorderPeriodicFlush(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
@@ -69,6 +69,9 @@ func (f *FridgeShell) Handle(ctx context.Context, sess *shell.SessionContext, rw
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("fridge")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
@@ -22,7 +22,7 @@ func (r *rwCloser) Close() error { return nil }
|
||||
func runShell(t *testing.T, commands string) string {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge")
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
@@ -205,7 +205,7 @@ func TestLogoutCommand(t *testing.T) {
|
||||
|
||||
func TestSessionLogs(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge")
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
|
||||
123
internal/shell/psql/commands.go
Normal file
123
internal/shell/psql/commands.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// commandResult holds the output of a command and whether the session should end.
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
// dispatchBackslash handles psql backslash meta-commands.
|
||||
func dispatchBackslash(cmd, dbName string) commandResult {
|
||||
// Normalize: trim spaces after the backslash command word.
|
||||
parts := strings.Fields(cmd)
|
||||
if len(parts) == 0 {
|
||||
return commandResult{output: "Invalid command \\. Try \\? for help."}
|
||||
}
|
||||
|
||||
verb := parts[0] // e.g. `\q`, `\dt`, `\d`
|
||||
args := parts[1:]
|
||||
|
||||
switch verb {
|
||||
case `\q`:
|
||||
return commandResult{exit: true}
|
||||
case `\dt`:
|
||||
return commandResult{output: listTables()}
|
||||
case `\d`:
|
||||
if len(args) == 0 {
|
||||
return commandResult{output: listTables()}
|
||||
}
|
||||
return commandResult{output: describeTable(args[0])}
|
||||
case `\l`:
|
||||
return commandResult{output: listDatabases()}
|
||||
case `\du`:
|
||||
return commandResult{output: listRoles()}
|
||||
case `\conninfo`:
|
||||
return commandResult{output: connInfo(dbName)}
|
||||
case `\?`:
|
||||
return commandResult{output: backslashHelp()}
|
||||
case `\h`:
|
||||
return commandResult{output: sqlHelp()}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
|
||||
func dispatchSQL(sql, dbName, pgVersion string) commandResult {
|
||||
// Strip trailing semicolon and whitespace for matching.
|
||||
trimmed := strings.TrimRight(sql, "; \t")
|
||||
trimmed = strings.TrimSpace(trimmed)
|
||||
upper := strings.ToUpper(trimmed)
|
||||
|
||||
switch {
|
||||
case upper == "SELECT VERSION()":
|
||||
ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
|
||||
return commandResult{output: formatSingleValue("version", ver)}
|
||||
case upper == "SELECT CURRENT_DATABASE()":
|
||||
return commandResult{output: formatSingleValue("current_database", dbName)}
|
||||
case upper == "SELECT CURRENT_USER":
|
||||
return commandResult{output: formatSingleValue("current_user", "postgres")}
|
||||
case upper == "SELECT NOW()":
|
||||
now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
|
||||
return commandResult{output: formatSingleValue("now", now)}
|
||||
case upper == "SELECT 1":
|
||||
return commandResult{output: formatSingleValue("?column?", "1")}
|
||||
case strings.HasPrefix(upper, "INSERT"):
|
||||
return commandResult{output: "INSERT 0 1"}
|
||||
case strings.HasPrefix(upper, "UPDATE"):
|
||||
return commandResult{output: "UPDATE 1"}
|
||||
case strings.HasPrefix(upper, "DELETE"):
|
||||
return commandResult{output: "DELETE 1"}
|
||||
case strings.HasPrefix(upper, "CREATE TABLE"):
|
||||
return commandResult{output: "CREATE TABLE"}
|
||||
case strings.HasPrefix(upper, "CREATE DATABASE"):
|
||||
return commandResult{output: "CREATE DATABASE"}
|
||||
case strings.HasPrefix(upper, "DROP TABLE"):
|
||||
return commandResult{output: "DROP TABLE"}
|
||||
case strings.HasPrefix(upper, "ALTER TABLE"):
|
||||
return commandResult{output: "ALTER TABLE"}
|
||||
case upper == "BEGIN":
|
||||
return commandResult{output: "BEGIN"}
|
||||
case upper == "COMMIT":
|
||||
return commandResult{output: "COMMIT"}
|
||||
case upper == "ROLLBACK":
|
||||
return commandResult{output: "ROLLBACK"}
|
||||
case upper == "SHOW SERVER_VERSION":
|
||||
return commandResult{output: formatSingleValue("server_version", pgVersion)}
|
||||
case upper == "SHOW SEARCH_PATH":
|
||||
return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
|
||||
case strings.HasPrefix(upper, "SET "):
|
||||
return commandResult{output: "SET"}
|
||||
default:
|
||||
// Extract the first token for the error message.
|
||||
firstToken := strings.Fields(trimmed)
|
||||
token := trimmed
|
||||
if len(firstToken) > 0 {
|
||||
token = firstToken[0]
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
|
||||
}
|
||||
}
|
||||
|
||||
// formatSingleValue formats a single-row, single-column psql result.
|
||||
func formatSingleValue(colName, value string) string {
|
||||
width := max(len(colName), len(value))
|
||||
|
||||
var b strings.Builder
|
||||
// Header
|
||||
fmt.Fprintf(&b, " %-*s \n", width, colName)
|
||||
// Separator
|
||||
b.WriteString(strings.Repeat("-", width+2))
|
||||
b.WriteString("\n")
|
||||
// Value
|
||||
fmt.Fprintf(&b, " %-*s\n", width, value)
|
||||
// Row count
|
||||
b.WriteString("(1 row)")
|
||||
return b.String()
|
||||
}
|
||||
155
internal/shell/psql/output.go
Normal file
155
internal/shell/psql/output.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package psql
|
||||
|
||||
import "fmt"
|
||||
|
||||
func startupBanner(version string) string {
|
||||
return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version)
|
||||
}
|
||||
|
||||
func listTables() string {
|
||||
return ` List of relations
|
||||
Schema | Name | Type | Owner
|
||||
--------+---------------+-------+----------
|
||||
public | audit_log | table | postgres
|
||||
public | credentials | table | postgres
|
||||
public | sessions | table | postgres
|
||||
public | users | table | postgres
|
||||
(4 rows)`
|
||||
}
|
||||
|
||||
func listDatabases() string {
|
||||
return ` List of databases
|
||||
Name | Owner | Encoding | Collate | Ctype | Access privileges
|
||||
-----------+----------+----------+-------------+-------------+-----------------------
|
||||
app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
|
||||
postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
|
||||
template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
|
||||
| | | | | postgres=CTc/postgres
|
||||
template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
|
||||
| | | | | postgres=CTc/postgres
|
||||
(4 rows)`
|
||||
}
|
||||
|
||||
func listRoles() string {
|
||||
return ` List of roles
|
||||
Role name | Attributes | Member of
|
||||
-----------+------------------------------------------------------------+-----------
|
||||
app_user | | {}
|
||||
postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {}
|
||||
readonly | Cannot login | {}`
|
||||
}
|
||||
|
||||
func describeTable(name string) string {
|
||||
switch name {
|
||||
case "users":
|
||||
return ` Table "public.users"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+-----------------------------------
|
||||
id | integer | | not null | nextval('users_id_seq'::regclass)
|
||||
username | character varying(255) | | not null |
|
||||
email | character varying(255) | | not null |
|
||||
password | character varying(255) | | not null |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
updated_at | timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"users_pkey" PRIMARY KEY, btree (id)
|
||||
"users_email_key" UNIQUE, btree (email)
|
||||
"users_username_key" UNIQUE, btree (username)`
|
||||
case "sessions":
|
||||
return ` Table "public.sessions"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+--------------------------------------
|
||||
id | integer | | not null | nextval('sessions_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
token | character varying(255) | | not null |
|
||||
ip_address | inet | | |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
expires_at | timestamp without time zone | | not null |
|
||||
Indexes:
|
||||
"sessions_pkey" PRIMARY KEY, btree (id)
|
||||
"sessions_token_key" UNIQUE, btree (token)
|
||||
Foreign-key constraints:
|
||||
"sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
case "credentials":
|
||||
return ` Table "public.credentials"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
-----------+-----------------------------+-----------+----------+-----------------------------------------
|
||||
id | integer | | not null | nextval('credentials_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
type | character varying(50) | | not null |
|
||||
value | text | | not null |
|
||||
created_at| timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"credentials_pkey" PRIMARY KEY, btree (id)
|
||||
Foreign-key constraints:
|
||||
"credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
case "audit_log":
|
||||
return ` Table "public.audit_log"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+---------------------------------------
|
||||
id | integer | | not null | nextval('audit_log_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
action | character varying(100) | | not null |
|
||||
details | text | | |
|
||||
ip_address | inet | | |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"audit_log_pkey" PRIMARY KEY, btree (id)
|
||||
Foreign-key constraints:
|
||||
"audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
default:
|
||||
return fmt.Sprintf("Did not find any relation named \"%s\".", name)
|
||||
}
|
||||
}
|
||||
|
||||
func connInfo(dbName string) string {
|
||||
return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName)
|
||||
}
|
||||
|
||||
func backslashHelp() string {
|
||||
return `General
|
||||
\copyright show PostgreSQL usage and distribution terms
|
||||
\crosstabview [COLUMNS] execute query and display result in crosstab
|
||||
\errverbose show most recent error message at maximum verbosity
|
||||
\g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe)
|
||||
\gdesc describe result of query, without executing it
|
||||
\gexec execute query, then execute each value in its result
|
||||
\gset [PREFIX] execute query and store result in psql variables
|
||||
\gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode
|
||||
\q quit psql
|
||||
\watch [SEC] execute query every SEC seconds
|
||||
|
||||
Informational
|
||||
(options: S = show system objects, + = additional detail)
|
||||
\d[S+] list tables, views, and sequences
|
||||
\d[S+] NAME describe table, view, sequence, or index
|
||||
\da[S] [PATTERN] list aggregates
|
||||
\dA[+] [PATTERN] list access methods
|
||||
\dt[S+] [PATTERN] list tables
|
||||
\du[S+] [PATTERN] list roles
|
||||
\l[+] [PATTERN] list databases`
|
||||
}
|
||||
|
||||
func sqlHelp() string {
|
||||
return `Available help:
|
||||
ABORT CREATE LANGUAGE
|
||||
ALTER AGGREGATE CREATE MATERIALIZED VIEW
|
||||
ALTER COLLATION CREATE OPERATOR
|
||||
ALTER CONVERSION CREATE POLICY
|
||||
ALTER DATABASE CREATE PROCEDURE
|
||||
ALTER DEFAULT PRIVILEGES CREATE PUBLICATION
|
||||
ALTER DOMAIN CREATE ROLE
|
||||
ALTER EVENT TRIGGER CREATE RULE
|
||||
ALTER EXTENSION CREATE SCHEMA
|
||||
ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE
|
||||
ALTER FOREIGN TABLE CREATE SERVER
|
||||
ALTER FUNCTION CREATE STATISTICS
|
||||
ALTER GROUP CREATE SUBSCRIPTION
|
||||
ALTER INDEX CREATE TABLE
|
||||
ALTER LANGUAGE CREATE TABLESPACE
|
||||
BEGIN DELETE
|
||||
COMMIT DROP TABLE
|
||||
CREATE DATABASE INSERT
|
||||
CREATE INDEX ROLLBACK
|
||||
SELECT UPDATE`
|
||||
}
|
||||
137
internal/shell/psql/psql.go
Normal file
137
internal/shell/psql/psql.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// PsqlShell emulates a PostgreSQL psql interactive terminal.
|
||||
type PsqlShell struct{}
|
||||
|
||||
// NewPsqlShell returns a new PsqlShell instance.
|
||||
func NewPsqlShell() *PsqlShell {
|
||||
return &PsqlShell{}
|
||||
}
|
||||
|
||||
func (p *PsqlShell) Name() string { return "psql" }
|
||||
func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" }
|
||||
|
||||
func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
dbName := configString(sess.ShellConfig, "db_name", "postgres")
|
||||
pgVersion := configString(sess.ShellConfig, "pg_version", "15.4")
|
||||
|
||||
// Print startup banner.
|
||||
fmt.Fprint(rw, startupBanner(pgVersion))
|
||||
|
||||
var sqlBuf []string // accumulates multi-line SQL
|
||||
|
||||
for {
|
||||
prompt := buildPrompt(dbName, len(sqlBuf) > 0)
|
||||
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
// Empty line in non-buffering state: just re-prompt.
|
||||
if trimmed == "" && len(sqlBuf) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Backslash commands dispatch immediately (even mid-buffer they cancel the buffer).
|
||||
if strings.HasPrefix(trimmed, `\`) {
|
||||
sqlBuf = nil // discard any partial SQL
|
||||
|
||||
result := dispatchBackslash(trimmed, dbName)
|
||||
if result.output != "" {
|
||||
output := strings.ReplaceAll(result.output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("psql")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate SQL lines.
|
||||
sqlBuf = append(sqlBuf, line)
|
||||
|
||||
// Check if the statement is terminated by a semicolon.
|
||||
if !strings.HasSuffix(strings.TrimSpace(line), ";") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Full statement ready — join and dispatch.
|
||||
fullSQL := strings.Join(sqlBuf, " ")
|
||||
sqlBuf = nil
|
||||
|
||||
result := dispatchSQL(fullSQL, dbName, pgVersion)
|
||||
if result.output != "" {
|
||||
output := strings.ReplaceAll(result.output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("psql")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL.
|
||||
func buildPrompt(dbName string, continuation bool) string {
|
||||
if continuation {
|
||||
return dbName + "-# "
|
||||
}
|
||||
return dbName + "=# "
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
330
internal/shell/psql/psql_test.go
Normal file
330
internal/shell/psql/psql_test.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- Prompt tests ---
|
||||
|
||||
func TestBuildPromptNormal(t *testing.T) {
|
||||
got := buildPrompt("postgres", false)
|
||||
if got != "postgres=# " {
|
||||
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptContinuation(t *testing.T) {
|
||||
got := buildPrompt("postgres", true)
|
||||
if got != "postgres-# " {
|
||||
t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptCustomDB(t *testing.T) {
|
||||
got := buildPrompt("mydb", false)
|
||||
if got != "mydb=# " {
|
||||
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Backslash command dispatch tests ---
|
||||
|
||||
func TestBackslashQuit(t *testing.T) {
|
||||
result := dispatchBackslash(`\q`, "postgres")
|
||||
if !result.exit {
|
||||
t.Error("\\q should set exit=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListTables(t *testing.T) {
|
||||
result := dispatchBackslash(`\dt`, "postgres")
|
||||
if !strings.Contains(result.output, "users") {
|
||||
t.Error("\\dt should list tables including 'users'")
|
||||
}
|
||||
if !strings.Contains(result.output, "sessions") {
|
||||
t.Error("\\dt should list tables including 'sessions'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashDescribeTable(t *testing.T) {
|
||||
result := dispatchBackslash(`\d users`, "postgres")
|
||||
if !strings.Contains(result.output, "username") {
|
||||
t.Error("\\d users should describe users table with 'username' column")
|
||||
}
|
||||
if !strings.Contains(result.output, "PRIMARY KEY") {
|
||||
t.Error("\\d users should include index info")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashDescribeUnknownTable(t *testing.T) {
|
||||
result := dispatchBackslash(`\d nonexistent`, "postgres")
|
||||
if !strings.Contains(result.output, "Did not find") {
|
||||
t.Error("\\d nonexistent should return not found message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListDatabases(t *testing.T) {
|
||||
result := dispatchBackslash(`\l`, "postgres")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("\\l should list databases including 'postgres'")
|
||||
}
|
||||
if !strings.Contains(result.output, "template0") {
|
||||
t.Error("\\l should list databases including 'template0'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListRoles(t *testing.T) {
|
||||
result := dispatchBackslash(`\du`, "postgres")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("\\du should list roles including 'postgres'")
|
||||
}
|
||||
if !strings.Contains(result.output, "Superuser") {
|
||||
t.Error("\\du should show Superuser attribute for postgres")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashConnInfo(t *testing.T) {
|
||||
result := dispatchBackslash(`\conninfo`, "mydb")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("\\conninfo should include database name")
|
||||
}
|
||||
if !strings.Contains(result.output, "5432") {
|
||||
t.Error("\\conninfo should include port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashHelp(t *testing.T) {
|
||||
result := dispatchBackslash(`\?`, "postgres")
|
||||
if !strings.Contains(result.output, `\q`) {
|
||||
t.Error("\\? should include \\q in help output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashSQLHelp(t *testing.T) {
|
||||
result := dispatchBackslash(`\h`, "postgres")
|
||||
if !strings.Contains(result.output, "SELECT") {
|
||||
t.Error("\\h should include SQL commands like SELECT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashUnknown(t *testing.T) {
|
||||
result := dispatchBackslash(`\xyz`, "postgres")
|
||||
if !strings.Contains(result.output, "Invalid command") {
|
||||
t.Error("unknown backslash command should return error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- SQL dispatch tests ---
|
||||
|
||||
func TestSQLSelectVersion(t *testing.T) {
|
||||
result := dispatchSQL("SELECT version();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("SELECT version() should contain pg version")
|
||||
}
|
||||
if !strings.Contains(result.output, "(1 row)") {
|
||||
t.Error("SELECT version() should show row count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectCurrentDatabase(t *testing.T) {
|
||||
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("SELECT current_database() should return db name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectCurrentUser(t *testing.T) {
|
||||
result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("SELECT current_user should return postgres")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectNow(t *testing.T) {
|
||||
result := dispatchSQL("SELECT now();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "(1 row)") {
|
||||
t.Error("SELECT now() should show row count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectOne(t *testing.T) {
|
||||
result := dispatchSQL("SELECT 1;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "1") {
|
||||
t.Error("SELECT 1 should return 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLInsert(t *testing.T) {
|
||||
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
|
||||
if result.output != "INSERT 0 1" {
|
||||
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLUpdate(t *testing.T) {
|
||||
result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
|
||||
if result.output != "UPDATE 1" {
|
||||
t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDelete(t *testing.T) {
|
||||
result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
|
||||
if result.output != "DELETE 1" {
|
||||
t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCreateTable(t *testing.T) {
|
||||
result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
|
||||
if result.output != "CREATE TABLE" {
|
||||
t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCreateDatabase(t *testing.T) {
|
||||
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
|
||||
if result.output != "CREATE DATABASE" {
|
||||
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDropTable(t *testing.T) {
|
||||
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
|
||||
if result.output != "DROP TABLE" {
|
||||
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLAlterTable(t *testing.T) {
|
||||
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
|
||||
if result.output != "ALTER TABLE" {
|
||||
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLBeginCommitRollback(t *testing.T) {
|
||||
tests := []struct {
|
||||
sql string
|
||||
want string
|
||||
}{
|
||||
{"BEGIN;", "BEGIN"},
|
||||
{"COMMIT;", "COMMIT"},
|
||||
{"ROLLBACK;", "ROLLBACK"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := dispatchSQL(tt.sql, "postgres", "15.4")
|
||||
if result.output != tt.want {
|
||||
t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLShowServerVersion(t *testing.T) {
|
||||
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("SHOW server_version should contain version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLShowSearchPath(t *testing.T) {
|
||||
result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "public") {
|
||||
t.Error("SHOW search_path should contain public")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSet(t *testing.T) {
|
||||
result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
|
||||
if result.output != "SET" {
|
||||
t.Errorf("SET output = %q, want %q", result.output, "SET")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLUnrecognized(t *testing.T) {
|
||||
result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "ERROR") {
|
||||
t.Error("unrecognized SQL should return error")
|
||||
}
|
||||
if !strings.Contains(result.output, "FOOBAR") {
|
||||
t.Error("error should reference the offending token")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Case insensitivity ---
|
||||
|
||||
func TestSQLCaseInsensitive(t *testing.T) {
|
||||
result := dispatchSQL("select version();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("select version() (lowercase) should work")
|
||||
}
|
||||
|
||||
result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("mixed case SELECT should work")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Startup banner ---
|
||||
|
||||
func TestStartupBanner(t *testing.T) {
|
||||
banner := startupBanner("15.4")
|
||||
if !strings.Contains(banner, "psql (15.4)") {
|
||||
t.Errorf("banner should contain version, got: %s", banner)
|
||||
}
|
||||
if !strings.Contains(banner, "help") {
|
||||
t.Error("banner should mention help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- configString ---
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{"db_name": "mydb"}
|
||||
if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
|
||||
t.Errorf("configString() = %q, want %q", got, "mydb")
|
||||
}
|
||||
if got := configString(cfg, "missing", "default"); got != "default" {
|
||||
t.Errorf("configString() for missing = %q, want %q", got, "default")
|
||||
}
|
||||
if got := configString(nil, "key", "default"); got != "default" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Shell metadata ---
|
||||
|
||||
func TestShellNameAndDescription(t *testing.T) {
|
||||
s := NewPsqlShell()
|
||||
if s.Name() != "psql" {
|
||||
t.Errorf("Name() = %q, want %q", s.Name(), "psql")
|
||||
}
|
||||
if s.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// --- formatSingleValue ---
|
||||
|
||||
func TestFormatSingleValue(t *testing.T) {
|
||||
out := formatSingleValue("?column?", "1")
|
||||
if !strings.Contains(out, "?column?") {
|
||||
t.Error("should contain column name")
|
||||
}
|
||||
if !strings.Contains(out, "1") {
|
||||
t.Error("should contain value")
|
||||
}
|
||||
if !strings.Contains(out, "(1 row)") {
|
||||
t.Error("should contain row count")
|
||||
}
|
||||
}
|
||||
|
||||
// --- \d with no args ---
|
||||
|
||||
func TestBackslashDescribeNoArgs(t *testing.T) {
|
||||
result := dispatchBackslash(`\d`, "postgres")
|
||||
if !strings.Contains(result.output, "users") {
|
||||
t.Error("\\d with no args should list tables")
|
||||
}
|
||||
}
|
||||
463
internal/shell/roomba/roomba.go
Normal file
463
internal/shell/roomba/roomba.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package roomba
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// RoombaShell emulates an iRobot Roomba vacuum robot interface.
|
||||
type RoombaShell struct{}
|
||||
|
||||
// NewRoombaShell returns a new RoombaShell instance.
|
||||
func NewRoombaShell() *RoombaShell {
|
||||
return &RoombaShell{}
|
||||
}
|
||||
|
||||
func (r *RoombaShell) Name() string { return "roomba" }
|
||||
func (r *RoombaShell) Description() string { return "iRobot Roomba shell emulator" }
|
||||
|
||||
func (r *RoombaShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
state := newRoombaState()
|
||||
|
||||
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
|
||||
fmt.Fprint(rw, banner)
|
||||
|
||||
for {
|
||||
if _, err := fmt.Fprint(rw, "RoombaOS> "); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
fmt.Fprint(rw, "logout\r\n")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
result := state.dispatch(trimmed)
|
||||
|
||||
var output string
|
||||
if result.output != "" {
|
||||
output = result.output
|
||||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("roomba")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func bootBanner() string {
|
||||
return `
|
||||
____ _ ___ ____
|
||||
| _ \ ___ ___ _ __ ___ | |__ __ _ / _ \/ ___|
|
||||
| |_) / _ \ / _ \| '_ ` + "`" + ` _ \| '_ \ / _` + "`" + ` | | | \___ \
|
||||
| _ < (_) | (_) | | | | | | |_) | (_| | |_| |___) |
|
||||
|_| \_\___/ \___/|_| |_| |_|_.__/ \__,_|\___/|____/
|
||||
|
||||
iRobot Roomba j7+ | RoombaOS v4.3.7
|
||||
Serial: RMB-7291-J7P-0482 | Firmware: 4.3.7-stable
|
||||
Battery: 73% | WiFi: Connected (SmartHome-5G)
|
||||
|
||||
Type 'help' for available commands.
|
||||
|
||||
`
|
||||
}
|
||||
|
||||
type room struct {
|
||||
name string
|
||||
areaSqFt int
|
||||
lastCleaned time.Time
|
||||
}
|
||||
|
||||
type scheduleEntry struct {
|
||||
day string
|
||||
time string
|
||||
}
|
||||
|
||||
type historyEntry struct {
|
||||
timestamp time.Time
|
||||
room string
|
||||
duration string
|
||||
note string
|
||||
}
|
||||
|
||||
type roombaState struct {
|
||||
battery int
|
||||
dustbin int
|
||||
status string
|
||||
rooms []room
|
||||
schedule []scheduleEntry
|
||||
cleanHistory []historyEntry
|
||||
}
|
||||
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
func newRoombaState() *roombaState {
|
||||
now := time.Now()
|
||||
return &roombaState{
|
||||
battery: 73,
|
||||
dustbin: 61,
|
||||
status: "Docked",
|
||||
rooms: []room{
|
||||
{"Kitchen", 180, now.Add(-2 * time.Hour)},
|
||||
{"Living Room", 320, now.Add(-5 * time.Hour)},
|
||||
{"Bedroom", 200, now.Add(-26 * time.Hour)},
|
||||
{"Hallway", 60, now.Add(-5 * time.Hour)},
|
||||
{"Bathroom", 75, now.Add(-50 * time.Hour)},
|
||||
{"Cat's Room", 110, now.Add(-3 * time.Hour)},
|
||||
},
|
||||
schedule: []scheduleEntry{
|
||||
{"Monday", "09:00"},
|
||||
{"Wednesday", "09:00"},
|
||||
{"Friday", "09:00"},
|
||||
{"Saturday", "14:00"},
|
||||
},
|
||||
cleanHistory: []historyEntry{
|
||||
{now.Add(-2 * time.Hour), "Kitchen", "23 min", "Completed normally"},
|
||||
{now.Add(-3 * time.Hour), "Cat's Room", "18 min", "Cat detected - rerouting"},
|
||||
{now.Add(-5 * time.Hour), "Living Room", "34 min", "Encountered sock near couch"},
|
||||
{now.Add(-5*time.Hour - 40*time.Minute), "Hallway", "8 min", "Completed normally"},
|
||||
{now.Add(-26 * time.Hour), "Bedroom", "27 min", "Tangled in phone charger"},
|
||||
{now.Add(-50 * time.Hour), "Bathroom", "14 min", "Unidentified sticky substance detected"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) dispatch(input string) commandResult {
|
||||
parts := strings.Fields(input)
|
||||
if len(parts) == 0 {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
cmd := strings.ToLower(parts[0])
|
||||
args := parts[1:]
|
||||
|
||||
switch cmd {
|
||||
case "help":
|
||||
return s.cmdHelp()
|
||||
case "status":
|
||||
return s.cmdStatus()
|
||||
case "clean":
|
||||
return s.cmdClean(args)
|
||||
case "dock":
|
||||
return s.cmdDock()
|
||||
case "map":
|
||||
return s.cmdMap()
|
||||
case "schedule":
|
||||
return s.cmdSchedule(args)
|
||||
case "history":
|
||||
return s.cmdHistory()
|
||||
case "diagnostics":
|
||||
return s.cmdDiagnostics()
|
||||
case "alerts":
|
||||
return s.cmdAlerts()
|
||||
case "reboot":
|
||||
return s.cmdReboot()
|
||||
case "exit", "logout":
|
||||
return commandResult{output: "Disconnecting from RoombaOS. Happy cleaning!", exit: true}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("RoombaOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdHelp() commandResult {
|
||||
help := `Available commands:
|
||||
help - Show this help message
|
||||
status - Show robot status
|
||||
clean - Start full cleaning job
|
||||
clean room <name> - Clean a specific room
|
||||
dock - Return to dock
|
||||
map - Show floor plan and room list
|
||||
schedule - List cleaning schedule
|
||||
schedule add <day> <time> - Add scheduled cleaning
|
||||
schedule remove <day> - Remove scheduled cleaning
|
||||
history - Show recent cleaning history
|
||||
diagnostics - Run system diagnostics
|
||||
alerts - Show active alerts
|
||||
reboot - Reboot RoombaOS
|
||||
exit / logout - Disconnect`
|
||||
return commandResult{output: help}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdStatus() commandResult {
|
||||
var b strings.Builder
|
||||
b.WriteString("=== RoombaOS System Status ===\n")
|
||||
b.WriteString("Model: iRobot Roomba j7+\n")
|
||||
b.WriteString(fmt.Sprintf("Status: %s\n", s.status))
|
||||
b.WriteString(fmt.Sprintf("Battery: %d%%\n", s.battery))
|
||||
b.WriteString(fmt.Sprintf("Dustbin: %d%% full\n", s.dustbin))
|
||||
b.WriteString("Side brush: OK (142 hrs)\n")
|
||||
b.WriteString("Main brush: OK (98 hrs)\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString("WiFi: Connected (SmartHome-5G)\n")
|
||||
b.WriteString("Signal: -38 dBm\n")
|
||||
b.WriteString("Alexa: Linked\n")
|
||||
b.WriteString("Google Home: Linked\n")
|
||||
b.WriteString("iRobot Home App: Connected\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString("Firmware: v4.3.7-stable\n")
|
||||
b.WriteString("LIDAR: Operational\n")
|
||||
b.WriteString("Clean Area Total: 12,847 sq ft (lifetime)")
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdClean(args []string) commandResult {
|
||||
if s.status == "Cleaning" {
|
||||
return commandResult{output: "Already cleaning. Use 'dock' to cancel and return to dock."}
|
||||
}
|
||||
|
||||
if len(args) >= 2 && strings.ToLower(args[0]) == "room" {
|
||||
roomName := strings.Join(args[1:], " ")
|
||||
for _, r := range s.rooms {
|
||||
if strings.EqualFold(r.name, roomName) {
|
||||
s.status = "Cleaning"
|
||||
return commandResult{output: fmt.Sprintf(
|
||||
"Starting targeted clean: %s (%d sq ft)\nEstimated time: %d minutes\nUndocking... navigating to %s...",
|
||||
r.name, r.areaSqFt, r.areaSqFt/8, r.name,
|
||||
)}
|
||||
}
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("Room '%s' not found. Use 'map' to see available rooms.", roomName)}
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
return commandResult{output: "Usage: clean [room <name>]"}
|
||||
}
|
||||
|
||||
s.status = "Cleaning"
|
||||
var totalArea int
|
||||
for _, r := range s.rooms {
|
||||
totalArea += r.areaSqFt
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf(
|
||||
"Starting full house clean\nTotal area: %d sq ft across %d rooms\nEstimated time: %d minutes\nUndocking... beginning clean cycle...",
|
||||
totalArea, len(s.rooms), totalArea/8,
|
||||
)}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdDock() commandResult {
|
||||
if s.status == "Docked" {
|
||||
return commandResult{output: "Already docked."}
|
||||
}
|
||||
if s.status == "Returning to dock" {
|
||||
return commandResult{output: "Already returning to dock."}
|
||||
}
|
||||
s.status = "Returning to dock"
|
||||
return commandResult{output: "Cancelling current job. Returning to dock...\nEstimated arrival: 2 minutes"}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdMap() commandResult {
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Floor Plan ===\n\n")
|
||||
b.WriteString(" +------------+----------+\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" | Kitchen | Bathroom |\n")
|
||||
b.WriteString(" | 180sqft | 75sqft |\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" +------+-----+----+-----+\n")
|
||||
b.WriteString(" | | | |\n")
|
||||
b.WriteString(" | Hall | Living | Cat |\n")
|
||||
b.WriteString(" | 60sf | Room | Rm |\n")
|
||||
b.WriteString(" | | 320sqft |110sf|\n")
|
||||
b.WriteString(" +------+ +-----+\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" | Bed +----------+\n")
|
||||
b.WriteString(" | room | [DOCK]\n")
|
||||
b.WriteString(" |200sf |\n")
|
||||
b.WriteString(" +------+\n")
|
||||
b.WriteString("\nRoom Details:\n")
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "ROOM", "AREA", "LAST CLEANED"))
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "----", "----", "------------"))
|
||||
for _, r := range s.rooms {
|
||||
ago := time.Since(r.lastCleaned).Truncate(time.Minute)
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s ago\n", r.name, fmt.Sprintf("%d sqft", r.areaSqFt), formatDuration(ago)))
|
||||
}
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdSchedule(args []string) commandResult {
|
||||
if len(args) == 0 {
|
||||
return s.scheduleList()
|
||||
}
|
||||
|
||||
sub := strings.ToLower(args[0])
|
||||
switch sub {
|
||||
case "add":
|
||||
if len(args) < 3 {
|
||||
return commandResult{output: "Usage: schedule add <day> <time>\nExample: schedule add Tuesday 10:00"}
|
||||
}
|
||||
return s.scheduleAdd(args[1], args[2])
|
||||
case "remove":
|
||||
if len(args) < 2 {
|
||||
return commandResult{output: "Usage: schedule remove <day>"}
|
||||
}
|
||||
return s.scheduleRemove(args[1])
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("Unknown schedule subcommand '%s'. Try: add, remove", sub)}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleList() commandResult {
|
||||
if len(s.schedule) == 0 {
|
||||
return commandResult{output: "No cleaning schedule configured."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Cleaning Schedule ===\n")
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", "DAY", "TIME"))
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", "---", "----"))
|
||||
for _, e := range s.schedule {
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", e.day, e.time))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d scheduled cleaning(s)", len(s.schedule)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleAdd(day, t string) commandResult {
|
||||
day = capitalizeFirst(strings.ToLower(day))
|
||||
validDays := []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
|
||||
if !slices.Contains(validDays, day) {
|
||||
return commandResult{output: fmt.Sprintf("Invalid day '%s'. Use a day of the week (e.g. Monday, Tuesday).", day)}
|
||||
}
|
||||
|
||||
for _, e := range s.schedule {
|
||||
if strings.EqualFold(e.day, day) {
|
||||
return commandResult{output: fmt.Sprintf("Schedule for %s already exists. Remove it first.", day)}
|
||||
}
|
||||
}
|
||||
|
||||
s.schedule = append(s.schedule, scheduleEntry{day: day, time: t})
|
||||
return commandResult{output: fmt.Sprintf("Scheduled cleaning added: %s at %s", day, t)}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleRemove(day string) commandResult {
|
||||
day = capitalizeFirst(strings.ToLower(day))
|
||||
for i, e := range s.schedule {
|
||||
if strings.EqualFold(e.day, day) {
|
||||
s.schedule = append(s.schedule[:i], s.schedule[i+1:]...)
|
||||
return commandResult{output: fmt.Sprintf("Removed schedule for %s.", day)}
|
||||
}
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("No schedule found for '%s'.", day)}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdHistory() commandResult {
|
||||
if len(s.cleanHistory) == 0 {
|
||||
return commandResult{output: "No cleaning history."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Cleaning History ===\n")
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "TIME", "ROOM", "DURATION", "NOTE"))
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "----", "----", "--------", "----"))
|
||||
for _, h := range s.cleanHistory {
|
||||
ts := h.timestamp.Format("2006-01-02 15:04")
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", ts, h.room, h.duration, h.note))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d session(s) recorded", len(s.cleanHistory)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdDiagnostics() commandResult {
|
||||
diag := `Running RoombaOS diagnostics...
|
||||
|
||||
[1/8] Cliff sensors........... OK
|
||||
[2/8] Bumper sensor........... OK
|
||||
[3/8] Side brush motor........ OK (142 hrs until replacement)
|
||||
[4/8] Main brush motor........ OK (98 hrs until replacement)
|
||||
[5/8] Wheel motors............ OK (L: 1204 hrs, R: 1204 hrs)
|
||||
[6/8] LIDAR module............ OK (last calibrated 3 days ago)
|
||||
[7/8] Dustbin sensor.......... OK
|
||||
[8/8] WiFi module............. OK (signal: -38 dBm)
|
||||
|
||||
ALL SYSTEMS NOMINAL`
|
||||
return commandResult{output: diag}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdAlerts() commandResult {
|
||||
var alerts []string
|
||||
if s.dustbin >= 60 {
|
||||
alerts = append(alerts, fmt.Sprintf("WARNING: Dustbin %d%% full - consider emptying", s.dustbin))
|
||||
}
|
||||
alerts = append(alerts,
|
||||
"WARNING: Side brush replacement due in 12 hours",
|
||||
"INFO: Unidentified sticky substance detected in Kitchen",
|
||||
"INFO: Cat frequently blocking cleaning path in Cat's Room",
|
||||
"INFO: Firmware update available: v4.4.0-beta",
|
||||
"INFO: Filter replacement recommended in 14 days",
|
||||
)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Active Alerts ===\n")
|
||||
for _, a := range alerts {
|
||||
b.WriteString(a + "\n")
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdReboot() commandResult {
|
||||
reboot := `RoombaOS is rebooting...
|
||||
|
||||
Stopping navigation engine..... done
|
||||
Saving room map data........... done
|
||||
Flushing cleaning logs......... done
|
||||
Disconnecting from WiFi........ done
|
||||
|
||||
Rebooting now. Goodbye!`
|
||||
return commandResult{output: reboot, exit: true}
|
||||
}
|
||||
|
||||
func capitalizeFirst(s string) string {
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
return strings.ToUpper(s[:1]) + s[1:]
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
hours := int(d.Hours())
|
||||
minutes := int(d.Minutes()) % 60
|
||||
if hours >= 24 {
|
||||
days := hours / 24
|
||||
hours %= 24
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%dh %dm", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%dm", minutes)
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// Shell is the interface that all honeypot shell implementations must satisfy.
|
||||
@@ -24,6 +24,7 @@ type SessionContext struct {
|
||||
Store storage.Store
|
||||
ShellConfig map[string]any
|
||||
CommonConfig ShellCommonConfig
|
||||
OnCommand func(shell string) // called when a command is executed; may be nil
|
||||
}
|
||||
|
||||
// ShellCommonConfig holds settings shared across all shell types.
|
||||
|
||||
101
internal/shell/tetris/data.go
Normal file
101
internal/shell/tetris/data.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package tetris
|
||||
|
||||
import "github.com/charmbracelet/lipgloss"
|
||||
|
||||
// pieceType identifies a tetromino (0–6).
|
||||
type pieceType int
|
||||
|
||||
const (
|
||||
pieceI pieceType = iota
|
||||
pieceO
|
||||
pieceT
|
||||
pieceS
|
||||
pieceZ
|
||||
pieceJ
|
||||
pieceL
|
||||
)
|
||||
|
||||
const numPieceTypes = 7
|
||||
|
||||
// Standard Tetris colors.
|
||||
var pieceColors = [numPieceTypes]lipgloss.Color{
|
||||
lipgloss.Color("#00FFFF"), // I — cyan
|
||||
lipgloss.Color("#FFFF00"), // O — yellow
|
||||
lipgloss.Color("#AA00FF"), // T — purple
|
||||
lipgloss.Color("#00FF00"), // S — green
|
||||
lipgloss.Color("#FF0000"), // Z — red
|
||||
lipgloss.Color("#0000FF"), // J — blue
|
||||
lipgloss.Color("#FF8800"), // L — orange
|
||||
}
|
||||
|
||||
// Each piece has 4 rotations, each rotation is a list of (row, col) offsets
|
||||
// relative to the piece origin.
|
||||
type rotation [4][2]int
|
||||
|
||||
var pieces = [numPieceTypes][4]rotation{
|
||||
// I
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
|
||||
},
|
||||
// O
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
},
|
||||
// T
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
// S
|
||||
{
|
||||
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
// Z
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
|
||||
},
|
||||
// J
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{2, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
|
||||
},
|
||||
// L
|
||||
{
|
||||
{[2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
}
|
||||
|
||||
// spawnCol returns the starting column for a piece, centering it on the board.
|
||||
func spawnCol(pt pieceType, rot int) int {
|
||||
shape := pieces[pt][rot]
|
||||
minC, maxC := shape[0][1], shape[0][1]
|
||||
for _, off := range shape {
|
||||
if off[1] < minC {
|
||||
minC = off[1]
|
||||
}
|
||||
if off[1] > maxC {
|
||||
maxC = off[1]
|
||||
}
|
||||
}
|
||||
width := maxC - minC + 1
|
||||
return (boardCols - width) / 2
|
||||
}
|
||||
210
internal/shell/tetris/game.go
Normal file
210
internal/shell/tetris/game.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package tetris
|
||||
|
||||
import "math/rand/v2"
|
||||
|
||||
const (
|
||||
boardRows = 20
|
||||
boardCols = 10
|
||||
)
|
||||
|
||||
// cell represents a single board cell. Zero value is empty.
|
||||
type cell struct {
|
||||
filled bool
|
||||
piece pieceType // which piece type filled this cell (for color)
|
||||
}
|
||||
|
||||
// gameState holds all mutable state for a Tetris game.
|
||||
type gameState struct {
|
||||
board [boardRows][boardCols]cell
|
||||
current pieceType
|
||||
currentRot int
|
||||
currentRow int
|
||||
currentCol int
|
||||
next pieceType
|
||||
score int
|
||||
level int
|
||||
lines int
|
||||
gameOver bool
|
||||
}
|
||||
|
||||
// newGame creates a new game state, optionally starting at a given level.
|
||||
func newGame(startLevel int) *gameState {
|
||||
g := &gameState{
|
||||
level: startLevel,
|
||||
next: pieceType(rand.IntN(numPieceTypes)),
|
||||
}
|
||||
g.spawnPiece()
|
||||
return g
|
||||
}
|
||||
|
||||
// spawnPiece pulls the next piece and generates a new next.
|
||||
func (g *gameState) spawnPiece() {
|
||||
g.current = g.next
|
||||
g.next = pieceType(rand.IntN(numPieceTypes))
|
||||
g.currentRot = 0
|
||||
g.currentRow = 0
|
||||
g.currentCol = spawnCol(g.current, 0)
|
||||
|
||||
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
|
||||
g.gameOver = true
|
||||
}
|
||||
}
|
||||
|
||||
// canPlace checks whether the piece fits at the given position.
|
||||
func (g *gameState) canPlace(pt pieceType, rot, row, col int) bool {
|
||||
shape := pieces[pt][rot]
|
||||
for _, off := range shape {
|
||||
r, c := row+off[0], col+off[1]
|
||||
if r < 0 || r >= boardRows || c < 0 || c >= boardCols {
|
||||
return false
|
||||
}
|
||||
if g.board[r][c].filled {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// moveLeft moves the current piece left if possible.
|
||||
func (g *gameState) moveLeft() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol-1) {
|
||||
g.currentCol--
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// moveRight moves the current piece right if possible.
|
||||
func (g *gameState) moveRight() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol+1) {
|
||||
g.currentCol++
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// moveDown moves the current piece down one row. Returns false if it cannot.
|
||||
func (g *gameState) moveDown() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow+1, g.currentCol) {
|
||||
g.currentRow++
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// rotate rotates the current piece clockwise with wall kick attempts.
|
||||
func (g *gameState) rotate() bool {
|
||||
newRot := (g.currentRot + 1) % 4
|
||||
|
||||
// Try in-place first.
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol) {
|
||||
g.currentRot = newRot
|
||||
return true
|
||||
}
|
||||
|
||||
// Wall kick: try +-1 column offset.
|
||||
for _, offset := range []int{-1, 1} {
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
|
||||
g.currentRot = newRot
|
||||
g.currentCol += offset
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// I piece: try +-2.
|
||||
if g.current == pieceI {
|
||||
for _, offset := range []int{-2, 2} {
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
|
||||
g.currentRot = newRot
|
||||
g.currentCol += offset
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ghostRow returns the row where the piece would land.
|
||||
func (g *gameState) ghostRow() int {
|
||||
row := g.currentRow
|
||||
for g.canPlace(g.current, g.currentRot, row+1, g.currentCol) {
|
||||
row++
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
// hardDrop drops the piece to the bottom and returns the number of rows dropped.
|
||||
func (g *gameState) hardDrop() int {
|
||||
ghost := g.ghostRow()
|
||||
dropped := ghost - g.currentRow
|
||||
g.currentRow = ghost
|
||||
return dropped
|
||||
}
|
||||
|
||||
// lockPiece writes the current piece into the board.
|
||||
func (g *gameState) lockPiece() {
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := g.currentRow+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
|
||||
g.board[r][c] = cell{filled: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clearLines removes completed rows and returns how many were cleared.
|
||||
func (g *gameState) clearLines() int {
|
||||
cleared := 0
|
||||
for r := boardRows - 1; r >= 0; r-- {
|
||||
full := true
|
||||
for c := range boardCols {
|
||||
if !g.board[r][c].filled {
|
||||
full = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if full {
|
||||
cleared++
|
||||
// Shift everything above down.
|
||||
for rr := r; rr > 0; rr-- {
|
||||
g.board[rr] = g.board[rr-1]
|
||||
}
|
||||
g.board[0] = [boardCols]cell{}
|
||||
r++ // re-check this row since we shifted
|
||||
}
|
||||
}
|
||||
return cleared
|
||||
}
|
||||
|
||||
// NES-style scoring multipliers per lines cleared.
|
||||
var lineScoreMultipliers = [5]int{0, 40, 100, 300, 1200}
|
||||
|
||||
// addScore updates score, lines, and level after clearing rows.
|
||||
func (g *gameState) addScore(linesCleared int) {
|
||||
if linesCleared > 0 && linesCleared <= 4 {
|
||||
g.score += lineScoreMultipliers[linesCleared] * (g.level + 1)
|
||||
}
|
||||
g.lines += linesCleared
|
||||
|
||||
// Level up every 10 lines.
|
||||
newLevel := g.lines / 10
|
||||
if newLevel > g.level {
|
||||
g.level = newLevel
|
||||
}
|
||||
}
|
||||
|
||||
// afterLock locks the piece, clears lines, scores, and spawns the next piece.
|
||||
// Returns the number of lines cleared.
|
||||
func (g *gameState) afterLock() int {
|
||||
g.lockPiece()
|
||||
cleared := g.clearLines()
|
||||
g.addScore(cleared)
|
||||
g.spawnPiece()
|
||||
return cleared
|
||||
}
|
||||
|
||||
// tickInterval returns the gravity interval in milliseconds for the current level.
|
||||
func tickInterval(level int) int {
|
||||
return max(800-level*60, 100)
|
||||
}
|
||||
331
internal/shell/tetris/model.go
Normal file
331
internal/shell/tetris/model.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
type screen int
|
||||
|
||||
const (
|
||||
screenTitle screen = iota
|
||||
screenGame
|
||||
screenGameOver
|
||||
)
|
||||
|
||||
type tickMsg time.Time
|
||||
type lockMsg time.Time
|
||||
|
||||
const lockDelay = 500 * time.Millisecond
|
||||
|
||||
type model struct {
|
||||
sess *shell.SessionContext
|
||||
difficulty string
|
||||
screen screen
|
||||
game *gameState
|
||||
quitting bool
|
||||
height int
|
||||
keypresses int
|
||||
locking bool // true when piece has landed and lock delay is active
|
||||
}
|
||||
|
||||
func newModel(sess *shell.SessionContext, difficulty string) *model {
|
||||
return &model{
|
||||
sess: sess,
|
||||
difficulty: difficulty,
|
||||
screen: screenTitle,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.quitting {
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.height = msg.Height
|
||||
return m, nil
|
||||
case tea.KeyMsg:
|
||||
m.keypresses++
|
||||
if msg.Type == tea.KeyCtrlC {
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.gameScore(), m.gameLevel(), m.gameLines(), m.keypresses), "SESSION ENDED"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
switch m.screen {
|
||||
case screenTitle:
|
||||
return m.updateTitle(msg)
|
||||
case screenGame:
|
||||
return m.updateGame(msg)
|
||||
case screenGameOver:
|
||||
return m.updateGameOver(msg)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) View() string {
|
||||
var content string
|
||||
switch m.screen {
|
||||
case screenTitle:
|
||||
content = m.titleView()
|
||||
case screenGame:
|
||||
content = gameView(m.game)
|
||||
case screenGameOver:
|
||||
content = m.gameOverView()
|
||||
}
|
||||
|
||||
return gameFrame(content, m.height)
|
||||
}
|
||||
|
||||
// --- Title screen ---
|
||||
|
||||
func (m *model) titleView() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ████████╗███████╗████████╗██████╗ ██╗███████╗"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ╚══██╔══╝██╔════╝╚══██╔══╝██╔══██╗██║██╔════╝"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ █████╗ ██║ ██████╔╝██║███████╗"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ ██╔══╝ ██║ ██╔══██╗██║╚════██║"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ ███████╗ ██║ ██║ ██║██║███████║"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚══════╝"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(" Press any key to start"))
|
||||
b.WriteString("\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) updateTitle(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if _, ok := msg.(tea.KeyMsg); ok {
|
||||
m.screen = screenGame
|
||||
var startLevel int
|
||||
if m.difficulty == "hard" {
|
||||
startLevel = 5
|
||||
}
|
||||
m.game = newGame(startLevel)
|
||||
return m, tea.Batch(
|
||||
tea.ClearScreen,
|
||||
m.scheduleTick(),
|
||||
logAction(m.sess, "GAME START", fmt.Sprintf("difficulty=%s", m.difficulty)),
|
||||
)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// --- Game screen ---
|
||||
|
||||
func (m *model) scheduleTick() tea.Cmd {
|
||||
ms := tickInterval(m.game.level)
|
||||
if m.difficulty == "easy" {
|
||||
ms = max(1000-m.game.level*60, 150)
|
||||
}
|
||||
return tea.Tick(time.Duration(ms)*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return tickMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *model) scheduleLock() tea.Cmd {
|
||||
return tea.Tick(lockDelay, func(t time.Time) tea.Msg {
|
||||
return lockMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
// performLock locks the piece, clears lines, and returns commands for logging
|
||||
// and scheduling the next tick. Returns nil if game over (goToGameOver is
|
||||
// included in the returned batch).
|
||||
func (m *model) performLock() tea.Cmd {
|
||||
m.locking = false
|
||||
cleared := m.game.afterLock()
|
||||
if m.game.gameOver {
|
||||
return tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("GAME OVER score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "GAME OVER"),
|
||||
m.goToGameOver(),
|
||||
)
|
||||
}
|
||||
var cmds []tea.Cmd
|
||||
cmds = append(cmds, m.scheduleTick())
|
||||
if cleared > 0 {
|
||||
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LINES %d score=%d", cleared, m.game.score), fmt.Sprintf("total=%d", m.game.lines)))
|
||||
prevLevel := (m.game.lines - cleared) / 10
|
||||
if m.game.level > prevLevel {
|
||||
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LEVEL UP %d", m.game.level), fmt.Sprintf("score=%d", m.game.score)))
|
||||
}
|
||||
}
|
||||
return tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m *model) updateGame(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case lockMsg:
|
||||
if m.game.gameOver || !m.locking {
|
||||
return m, nil
|
||||
}
|
||||
// Lock delay expired — lock the piece now.
|
||||
return m, m.performLock()
|
||||
|
||||
case tickMsg:
|
||||
if m.game.gameOver || m.locking {
|
||||
return m, nil
|
||||
}
|
||||
if !m.game.moveDown() {
|
||||
// Piece landed — start lock delay instead of locking immediately.
|
||||
m.locking = true
|
||||
return m, m.scheduleLock()
|
||||
}
|
||||
return m, m.scheduleTick()
|
||||
|
||||
case tea.KeyMsg:
|
||||
if m.game.gameOver {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "left":
|
||||
m.game.moveLeft()
|
||||
// If piece can now drop further, cancel lock delay.
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case "right":
|
||||
m.game.moveRight()
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case "down":
|
||||
if m.game.moveDown() {
|
||||
m.game.score++ // soft drop bonus
|
||||
if m.locking {
|
||||
m.locking = false
|
||||
}
|
||||
}
|
||||
case "up", "z":
|
||||
m.game.rotate()
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case " ":
|
||||
m.locking = false
|
||||
dropped := m.game.hardDrop()
|
||||
m.game.score += dropped * 2
|
||||
return m, m.performLock()
|
||||
case "q":
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// --- Game over screen ---
|
||||
|
||||
func (m *model) goToGameOver() tea.Cmd {
|
||||
m.screen = screenGameOver
|
||||
return tea.ClearScreen
|
||||
}
|
||||
|
||||
func (m *model) gameOverView() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" GAME OVER"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Score: %s", formatScore(m.game.score))))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Level: %d", m.game.level)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Lines: %d", m.game.lines)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" R - Play again"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" Q - Quit"))
|
||||
b.WriteString("\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) updateGameOver(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if keyMsg, ok := msg.(tea.KeyMsg); ok {
|
||||
switch keyMsg.String() {
|
||||
case "r":
|
||||
startLevel := 0
|
||||
if m.difficulty == "hard" {
|
||||
startLevel = 5
|
||||
}
|
||||
m.game = newGame(startLevel)
|
||||
m.screen = screenGame
|
||||
m.keypresses = 0
|
||||
return m, tea.Batch(
|
||||
tea.ClearScreen,
|
||||
m.scheduleTick(),
|
||||
logAction(m.sess, "RESTART", fmt.Sprintf("difficulty=%s", m.difficulty)),
|
||||
)
|
||||
case "q":
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Helper methods for safe access when game may be nil.
|
||||
func (m *model) gameScore() int {
|
||||
if m.game != nil {
|
||||
return m.game.score
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *model) gameLevel() int {
|
||||
if m.game != nil {
|
||||
return m.game.level
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *model) gameLines() int {
|
||||
if m.game != nil {
|
||||
return m.game.lines
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// logAction returns a tea.Cmd that logs an action to the session store.
|
||||
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if sess.Store != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("tetris")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
286
internal/shell/tetris/style.go
Normal file
286
internal/shell/tetris/style.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const termWidth = 80
|
||||
|
||||
var (
|
||||
colorWhite = lipgloss.Color("#FFFFFF")
|
||||
colorDim = lipgloss.Color("#555555")
|
||||
colorBlack = lipgloss.Color("#000000")
|
||||
colorGhost = lipgloss.Color("#333333")
|
||||
)
|
||||
|
||||
var (
|
||||
baseStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWhite).
|
||||
Background(colorBlack)
|
||||
|
||||
dimStyle = lipgloss.NewStyle().
|
||||
Foreground(colorDim).
|
||||
Background(colorBlack)
|
||||
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#00FFFF")).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
|
||||
sidebarLabelStyle = lipgloss.NewStyle().
|
||||
Foreground(colorDim).
|
||||
Background(colorBlack)
|
||||
|
||||
sidebarValueStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWhite).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
)
|
||||
|
||||
// cellStyle returns a style for a filled cell of a given piece type.
|
||||
func cellStyle(pt pieceType) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(pieceColors[pt]).
|
||||
Background(colorBlack)
|
||||
}
|
||||
|
||||
// ghostStyle returns a dimmed style for the ghost piece.
|
||||
func ghostCellStyle() lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(colorGhost).
|
||||
Background(colorBlack)
|
||||
}
|
||||
|
||||
// renderBoard renders the board, current piece, and ghost piece as a string.
|
||||
func renderBoard(g *gameState) string {
|
||||
// Build a display grid that includes the current piece and ghost.
|
||||
type displayCell struct {
|
||||
filled bool
|
||||
ghost bool
|
||||
piece pieceType
|
||||
}
|
||||
var grid [boardRows][boardCols]displayCell
|
||||
|
||||
// Copy locked cells.
|
||||
for r := range boardRows {
|
||||
for c := range boardCols {
|
||||
if g.board[r][c].filled {
|
||||
grid[r][c] = displayCell{filled: true, piece: g.board[r][c].piece}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ghost piece.
|
||||
ghostR := g.ghostRow()
|
||||
if ghostR != g.currentRow {
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := ghostR+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols && !grid[r][c].filled {
|
||||
grid[r][c] = displayCell{ghost: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Current piece.
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := g.currentRow+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
|
||||
grid[r][c] = displayCell{filled: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
|
||||
// Render grid.
|
||||
var b strings.Builder
|
||||
borderStyle := dimStyle
|
||||
|
||||
for _, row := range grid {
|
||||
b.WriteString(borderStyle.Render("|"))
|
||||
for _, dc := range row {
|
||||
switch {
|
||||
case dc.filled:
|
||||
b.WriteString(cellStyle(dc.piece).Render("[]"))
|
||||
case dc.ghost:
|
||||
b.WriteString(ghostCellStyle().Render("::"))
|
||||
default:
|
||||
b.WriteString(baseStyle.Render(" "))
|
||||
}
|
||||
}
|
||||
b.WriteString(borderStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString(borderStyle.Render("+" + strings.Repeat("--", boardCols) + "+"))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// renderNextPiece renders the "next piece" preview box.
|
||||
func renderNextPiece(pt pieceType) string {
|
||||
shape := pieces[pt][0]
|
||||
// Determine bounding box.
|
||||
minR, maxR := shape[0][0], shape[0][0]
|
||||
minC, maxC := shape[0][1], shape[0][1]
|
||||
for _, off := range shape {
|
||||
if off[0] < minR {
|
||||
minR = off[0]
|
||||
}
|
||||
if off[0] > maxR {
|
||||
maxR = off[0]
|
||||
}
|
||||
if off[1] < minC {
|
||||
minC = off[1]
|
||||
}
|
||||
if off[1] > maxC {
|
||||
maxC = off[1]
|
||||
}
|
||||
}
|
||||
|
||||
rows := maxR - minR + 1
|
||||
cols := maxC - minC + 1
|
||||
|
||||
// Build a small grid.
|
||||
grid := make([][]bool, rows)
|
||||
for i := range grid {
|
||||
grid[i] = make([]bool, cols)
|
||||
}
|
||||
for _, off := range shape {
|
||||
grid[off[0]-minR][off[1]-minC] = true
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
boxWidth := 8 // chars for the box interior
|
||||
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
|
||||
b.WriteString("\n")
|
||||
|
||||
for r := range rows {
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
// Center the piece in the box.
|
||||
pieceWidth := cols * 2
|
||||
leftPad := (boxWidth - pieceWidth) / 2
|
||||
rightPad := boxWidth - pieceWidth - leftPad
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", leftPad)))
|
||||
for c := range cols {
|
||||
if grid[r][c] {
|
||||
b.WriteString(cellStyle(pt).Render("[]"))
|
||||
} else {
|
||||
b.WriteString(baseStyle.Render(" "))
|
||||
}
|
||||
}
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", rightPad)))
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Fill remaining rows in the box (max 4 rows for I piece).
|
||||
for r := rows; r < 2; r++ {
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", boxWidth)))
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// formatScore formats a score with comma separators.
|
||||
func formatScore(n int) string {
|
||||
s := fmt.Sprintf("%d", n)
|
||||
if len(s) <= 3 {
|
||||
return s
|
||||
}
|
||||
var parts []string
|
||||
for len(s) > 3 {
|
||||
parts = append([]string{s[len(s)-3:]}, parts...)
|
||||
s = s[:len(s)-3]
|
||||
}
|
||||
parts = append([]string{s}, parts...)
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// gameView combines the board and sidebar into the game screen.
|
||||
func gameView(g *gameState) string {
|
||||
boardStr := renderBoard(g)
|
||||
boardLines := strings.Split(boardStr, "\n")
|
||||
|
||||
nextStr := renderNextPiece(g.next)
|
||||
nextLines := strings.Split(nextStr, "\n")
|
||||
|
||||
// Build sidebar lines.
|
||||
var sidebar []string
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" NEXT:"))
|
||||
sidebar = append(sidebar, nextLines...)
|
||||
sidebar = append(sidebar, "")
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" SCORE: ")+sidebarValueStyle.Render(formatScore(g.score)))
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" LEVEL: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.level)))
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" LINES: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.lines)))
|
||||
sidebar = append(sidebar, "")
|
||||
sidebar = append(sidebar, dimStyle.Render(" Controls:"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" <- -> Move"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Up/Z Rotate"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Down Soft drop"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Space Hard drop"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Q Quit"))
|
||||
|
||||
// Combine board and sidebar side by side.
|
||||
var b strings.Builder
|
||||
maxLines := max(len(boardLines), len(sidebar))
|
||||
|
||||
for i := range maxLines {
|
||||
boardLine := ""
|
||||
if i < len(boardLines) {
|
||||
boardLine = boardLines[i]
|
||||
}
|
||||
sidebarLine := ""
|
||||
if i < len(sidebar) {
|
||||
sidebarLine = sidebar[i]
|
||||
}
|
||||
|
||||
// Pad board to fixed width (| + 10*2 + | = 22 chars visual).
|
||||
b.WriteString(boardLine)
|
||||
b.WriteString(sidebarLine)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// padLine pads a single line to termWidth.
|
||||
func padLine(line string) string {
|
||||
w := lipgloss.Width(line)
|
||||
if w >= termWidth {
|
||||
return line
|
||||
}
|
||||
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
|
||||
}
|
||||
|
||||
// padLines pads every line in a multi-line string to termWidth.
|
||||
func padLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = padLine(line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// gameFrame wraps content with padding to fill the terminal.
|
||||
func gameFrame(content string, height int) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(content)
|
||||
|
||||
// Pad with blank lines to fill terminal height.
|
||||
if height > 0 {
|
||||
contentLines := strings.Count(content, "\n") + 1
|
||||
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
|
||||
for i := contentLines; i < height; i++ {
|
||||
b.WriteString(blankLine)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return padLines(b.String())
|
||||
}
|
||||
66
internal/shell/tetris/tetris.go
Normal file
66
internal/shell/tetris/tetris.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
// TetrisShell is a Tetris game TUI for the honeypot.
|
||||
type TetrisShell struct{}
|
||||
|
||||
// NewTetrisShell returns a new TetrisShell instance.
|
||||
func NewTetrisShell() *TetrisShell {
|
||||
return &TetrisShell{}
|
||||
}
|
||||
|
||||
func (t *TetrisShell) Name() string { return "tetris" }
|
||||
func (t *TetrisShell) Description() string { return "Tetris game TUI" }
|
||||
|
||||
func (t *TetrisShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
difficulty := configString(sess.ShellConfig, "difficulty", "normal")
|
||||
|
||||
m := newModel(sess, difficulty)
|
||||
p := tea.NewProgram(m,
|
||||
tea.WithInput(rw),
|
||||
tea.WithOutput(rw),
|
||||
tea.WithAltScreen(),
|
||||
)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := p.Run()
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
p.Quit()
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
582
internal/shell/tetris/tetris_test.go
Normal file
582
internal/shell/tetris/tetris_test.go
Normal file
@@ -0,0 +1,582 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// newTestModel creates a model with a test session context.
|
||||
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "player", "tetris", "")
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "player",
|
||||
Store: store,
|
||||
}
|
||||
m := newModel(sess, "normal")
|
||||
return m, store
|
||||
}
|
||||
|
||||
// sendKey sends a single key message to the model and returns the command.
|
||||
func sendKey(m *model, key string) tea.Cmd {
|
||||
var msg tea.KeyMsg
|
||||
switch key {
|
||||
case "enter":
|
||||
msg = tea.KeyMsg{Type: tea.KeyEnter}
|
||||
case "up":
|
||||
msg = tea.KeyMsg{Type: tea.KeyUp}
|
||||
case "down":
|
||||
msg = tea.KeyMsg{Type: tea.KeyDown}
|
||||
case "left":
|
||||
msg = tea.KeyMsg{Type: tea.KeyLeft}
|
||||
case "right":
|
||||
msg = tea.KeyMsg{Type: tea.KeyRight}
|
||||
case "space":
|
||||
msg = tea.KeyMsg{Type: tea.KeySpace}
|
||||
case "ctrl+c":
|
||||
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
|
||||
default:
|
||||
if len(key) == 1 {
|
||||
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}
|
||||
}
|
||||
}
|
||||
_, cmd := m.Update(msg)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// sendTick sends a tick message to the model.
|
||||
func sendTick(m *model) tea.Cmd {
|
||||
_, cmd := m.Update(tickMsg(time.Now()))
|
||||
return cmd
|
||||
}
|
||||
|
||||
// execCmds recursively executes tea.Cmd functions (including batches).
|
||||
func execCmds(cmd tea.Cmd) {
|
||||
if cmd == nil {
|
||||
return
|
||||
}
|
||||
msg := cmd()
|
||||
if batch, ok := msg.(tea.BatchMsg); ok {
|
||||
for _, c := range batch {
|
||||
execCmds(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTetrisShellName(t *testing.T) {
|
||||
sh := NewTetrisShell()
|
||||
if sh.Name() != "tetris" {
|
||||
t.Errorf("Name() = %q, want %q", sh.Name(), "tetris")
|
||||
}
|
||||
if sh.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{
|
||||
"difficulty": "hard",
|
||||
}
|
||||
if got := configString(cfg, "difficulty", "normal"); got != "hard" {
|
||||
t.Errorf("configString() = %q, want %q", got, "hard")
|
||||
}
|
||||
if got := configString(cfg, "missing", "normal"); got != "normal" {
|
||||
t.Errorf("configString() = %q, want %q", got, "normal")
|
||||
}
|
||||
if got := configString(nil, "difficulty", "normal"); got != "normal" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "normal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTitleScreenRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "████") {
|
||||
t.Error("title screen should show TETRIS logo")
|
||||
}
|
||||
if !strings.Contains(view, "Press any key") {
|
||||
t.Error("title screen should show 'Press any key'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTitleToGame(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
if m.screen != screenTitle {
|
||||
t.Fatalf("expected screenTitle, got %d", m.screen)
|
||||
}
|
||||
|
||||
sendKey(m, "enter")
|
||||
if m.screen != screenGame {
|
||||
t.Errorf("expected screenGame after keypress, got %d", m.screen)
|
||||
}
|
||||
if m.game == nil {
|
||||
t.Fatal("game should be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGameRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "|") {
|
||||
t.Error("game view should contain board borders")
|
||||
}
|
||||
if !strings.Contains(view, "SCORE") {
|
||||
t.Error("game view should show SCORE")
|
||||
}
|
||||
if !strings.Contains(view, "LEVEL") {
|
||||
t.Error("game view should show LEVEL")
|
||||
}
|
||||
if !strings.Contains(view, "LINES") {
|
||||
t.Error("game view should show LINES")
|
||||
}
|
||||
if !strings.Contains(view, "NEXT") {
|
||||
t.Error("game view should show NEXT")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Pure game logic tests ---
|
||||
|
||||
func TestNewGame(t *testing.T) {
|
||||
g := newGame(0)
|
||||
if g.gameOver {
|
||||
t.Error("new game should not be game over")
|
||||
}
|
||||
if g.score != 0 {
|
||||
t.Errorf("initial score = %d, want 0", g.score)
|
||||
}
|
||||
if g.level != 0 {
|
||||
t.Errorf("initial level = %d, want 0", g.level)
|
||||
}
|
||||
if g.lines != 0 {
|
||||
t.Errorf("initial lines = %d, want 0", g.lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGameHardLevel(t *testing.T) {
|
||||
g := newGame(5)
|
||||
if g.level != 5 {
|
||||
t.Errorf("hard start level = %d, want 5", g.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveLeft(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startCol := g.currentCol
|
||||
g.moveLeft()
|
||||
if g.currentCol != startCol-1 {
|
||||
t.Errorf("after moveLeft: col = %d, want %d", g.currentCol, startCol-1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveRight(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startCol := g.currentCol
|
||||
g.moveRight()
|
||||
if g.currentCol != startCol+1 {
|
||||
t.Errorf("after moveRight: col = %d, want %d", g.currentCol, startCol+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveDown(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRow := g.currentRow
|
||||
moved := g.moveDown()
|
||||
if !moved {
|
||||
t.Error("moveDown should succeed from starting position")
|
||||
}
|
||||
if g.currentRow != startRow+1 {
|
||||
t.Errorf("after moveDown: row = %d, want %d", g.currentRow, startRow+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCannotMoveLeftBeyondWall(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Move all the way left.
|
||||
for range boardCols {
|
||||
g.moveLeft()
|
||||
}
|
||||
col := g.currentCol
|
||||
g.moveLeft() // should not move further
|
||||
if g.currentCol != col {
|
||||
t.Errorf("should not move past left wall: col = %d, was %d", g.currentCol, col)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCannotMoveRightBeyondWall(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Move all the way right.
|
||||
for range boardCols {
|
||||
g.moveRight()
|
||||
}
|
||||
col := g.currentCol
|
||||
g.moveRight() // should not move further
|
||||
if g.currentCol != col {
|
||||
t.Errorf("should not move past right wall: col = %d, was %d", g.currentCol, col)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotate(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRot := g.currentRot
|
||||
g.rotate()
|
||||
// Rotation should change (possibly with wall kick).
|
||||
if g.currentRot == startRot {
|
||||
// Rotation might legitimately fail in some edge cases, so just check
|
||||
// that the game state is valid.
|
||||
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
|
||||
t.Error("piece should be in a valid position after rotate attempt")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHardDrop(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRow := g.currentRow
|
||||
dropped := g.hardDrop()
|
||||
if dropped == 0 {
|
||||
t.Error("hard drop should move piece down at least some rows from top")
|
||||
}
|
||||
if g.currentRow <= startRow {
|
||||
t.Errorf("after hardDrop: row = %d should be > %d", g.currentRow, startRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGhostRow(t *testing.T) {
|
||||
g := newGame(0)
|
||||
ghost := g.ghostRow()
|
||||
if ghost < g.currentRow {
|
||||
t.Errorf("ghost row %d should be >= current row %d", ghost, g.currentRow)
|
||||
}
|
||||
// Ghost should be at a position where moving down one more is impossible.
|
||||
if g.canPlace(g.current, g.currentRot, ghost+1, g.currentCol) {
|
||||
t.Error("ghost row should be the lowest valid position")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockPiece(t *testing.T) {
|
||||
g := newGame(0)
|
||||
g.hardDrop()
|
||||
pt := g.current
|
||||
row, col, rot := g.currentRow, g.currentCol, g.currentRot
|
||||
g.lockPiece()
|
||||
|
||||
// Verify that the piece's cells are now filled.
|
||||
shape := pieces[pt][rot]
|
||||
for _, off := range shape {
|
||||
r, c := row+off[0], col+off[1]
|
||||
if !g.board[r][c].filled {
|
||||
t.Errorf("cell (%d, %d) should be filled after lockPiece", r, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearLines(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Fill the bottom row completely.
|
||||
for c := range boardCols {
|
||||
g.board[boardRows-1][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
cleared := g.clearLines()
|
||||
if cleared != 1 {
|
||||
t.Errorf("clearLines() = %d, want 1", cleared)
|
||||
}
|
||||
// Bottom row should now be empty (shifted from above).
|
||||
for c := range boardCols {
|
||||
if g.board[boardRows-1][c].filled {
|
||||
t.Errorf("bottom row col %d should be empty after clearing", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearMultipleLines(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Fill the bottom 4 rows.
|
||||
for r := boardRows - 4; r < boardRows; r++ {
|
||||
for c := range boardCols {
|
||||
g.board[r][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
}
|
||||
cleared := g.clearLines()
|
||||
if cleared != 4 {
|
||||
t.Errorf("clearLines() = %d, want 4", cleared)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScoring(t *testing.T) {
|
||||
tests := []struct {
|
||||
lines int
|
||||
level int
|
||||
want int
|
||||
}{
|
||||
{1, 0, 40},
|
||||
{2, 0, 100},
|
||||
{3, 0, 300},
|
||||
{4, 0, 1200},
|
||||
{1, 1, 80},
|
||||
{4, 2, 3600},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
g := newGame(tt.level)
|
||||
g.addScore(tt.lines)
|
||||
if g.score != tt.want {
|
||||
t.Errorf("score for %d lines at level %d = %d, want %d", tt.lines, tt.level, g.score, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLevelUp(t *testing.T) {
|
||||
g := newGame(0)
|
||||
g.lines = 9
|
||||
g.addScore(1) // This should push lines to 10, triggering level 1.
|
||||
if g.level != 1 {
|
||||
t.Errorf("level = %d, want 1 after 10 lines", g.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTickInterval(t *testing.T) {
|
||||
if got := tickInterval(0); got != 800 {
|
||||
t.Errorf("tickInterval(0) = %d, want 800", got)
|
||||
}
|
||||
if got := tickInterval(5); got != 500 {
|
||||
t.Errorf("tickInterval(5) = %d, want 500", got)
|
||||
}
|
||||
// Floor at 100ms.
|
||||
if got := tickInterval(20); got != 100 {
|
||||
t.Errorf("tickInterval(20) = %d, want 100", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatScore(t *testing.T) {
|
||||
tests := []struct {
|
||||
n int
|
||||
want string
|
||||
}{
|
||||
{0, "0"},
|
||||
{100, "100"},
|
||||
{1250, "1,250"},
|
||||
{1000000, "1,000,000"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := formatScore(tt.n); got != tt.want {
|
||||
t.Errorf("formatScore(%d) = %q, want %q", tt.n, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGameOverScreen(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Force game over.
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "GAME OVER") {
|
||||
t.Error("game over screen should show GAME OVER")
|
||||
}
|
||||
if !strings.Contains(view, "Score") {
|
||||
t.Error("game over screen should show score")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestartFromGameOver(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
sendKey(m, "r")
|
||||
if m.screen != screenGame {
|
||||
t.Errorf("expected screenGame after restart, got %d", m.screen)
|
||||
}
|
||||
if m.game.gameOver {
|
||||
t.Error("game should not be over after restart")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuitFromGame(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
sendKey(m, "q")
|
||||
if !m.quitting {
|
||||
t.Error("should be quitting after pressing q")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuitFromGameOver(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
sendKey(m, "q")
|
||||
if !m.quitting {
|
||||
t.Error("should be quitting after pressing q in game over")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftDropScoring(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
scoreBefore := m.game.score
|
||||
sendKey(m, "down")
|
||||
if m.game.score != scoreBefore+1 {
|
||||
t.Errorf("score after soft drop = %d, want %d", m.game.score, scoreBefore+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHardDropScoring(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Hard drop gives 2 points per row dropped.
|
||||
sendKey(m, "space")
|
||||
if m.game.score < 2 {
|
||||
t.Errorf("score after hard drop = %d, should be at least 2", m.game.score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTickMovesDown(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
rowBefore := m.game.currentRow
|
||||
sendTick(m)
|
||||
// Piece should either move down by 1, or lock and spawn a new piece at top.
|
||||
movedDown := m.game.currentRow == rowBefore+1
|
||||
respawned := m.game.currentRow < rowBefore
|
||||
if !movedDown && !respawned && !m.game.gameOver {
|
||||
t.Errorf("tick should move piece down or lock+respawn: row was %d, now %d", rowBefore, m.game.currentRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLogs(t *testing.T) {
|
||||
m, store := newTestModel(t)
|
||||
|
||||
// Press key to start game — returns a logAction cmd.
|
||||
cmd := sendKey(m, "enter")
|
||||
if cmd != nil {
|
||||
execCmds(cmd)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
found := false
|
||||
for _, log := range store.SessionLogs {
|
||||
if strings.Contains(log.Input, "GAME START") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected GAME START in session logs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeypressCounter(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
sendKey(m, "left")
|
||||
sendKey(m, "right")
|
||||
sendKey(m, "down")
|
||||
|
||||
if m.keypresses != 4 { // enter + 3 game keys
|
||||
t.Errorf("keypresses = %d, want 4", m.keypresses)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockDelay(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Drop piece to the bottom via ticks until it can't move down.
|
||||
for range boardRows + 5 {
|
||||
if m.locking {
|
||||
break
|
||||
}
|
||||
sendTick(m)
|
||||
}
|
||||
|
||||
if !m.locking {
|
||||
t.Fatal("piece should be in locking state after hitting bottom")
|
||||
}
|
||||
|
||||
// During lock delay, we should still be able to move left/right.
|
||||
colBefore := m.game.currentCol
|
||||
sendKey(m, "left")
|
||||
if m.game.currentCol >= colBefore {
|
||||
// Might not have moved if against wall, try right.
|
||||
sendKey(m, "right")
|
||||
}
|
||||
|
||||
// Sending a lockMsg should finalize the piece.
|
||||
m.Update(lockMsg(time.Now()))
|
||||
// After lock, a new piece should have spawned (row near top).
|
||||
if m.game.currentRow > 1 && !m.game.gameOver {
|
||||
t.Errorf("after lock delay, new piece should spawn near top, got row %d", m.game.currentRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockDelayCancelledByDrop(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Build a ledge: fill rows 18-19 but leave column 0 empty.
|
||||
for r := boardRows - 2; r < boardRows; r++ {
|
||||
for c := 1; c < boardCols; c++ {
|
||||
m.game.board[r][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
}
|
||||
|
||||
// Move piece to column 0 area and drop it onto the ledge.
|
||||
for range boardCols {
|
||||
m.game.moveLeft()
|
||||
}
|
||||
// Tick down until locking.
|
||||
for range boardRows + 5 {
|
||||
if m.locking {
|
||||
break
|
||||
}
|
||||
sendTick(m)
|
||||
}
|
||||
|
||||
// If piece is on the ledge and we slide it to col 0 (open column),
|
||||
// the lock delay should cancel since it can fall further.
|
||||
// This test just validates the locking flag logic works.
|
||||
if m.locking {
|
||||
// Try moving — if piece can drop further, locking should cancel.
|
||||
sendKey(m, "left")
|
||||
// Whether locking cancels depends on the board state; just verify no crash.
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnCol(t *testing.T) {
|
||||
// All pieces should spawn roughly centered.
|
||||
for pt := range pieceType(numPieceTypes) {
|
||||
col := spawnCol(pt, 0)
|
||||
if col < 0 || col > boardCols-1 {
|
||||
t.Errorf("spawnCol(%d, 0) = %d, out of range", pt, col)
|
||||
}
|
||||
// Verify piece fits at spawn position.
|
||||
shape := pieces[pt][0]
|
||||
for _, off := range shape {
|
||||
c := col + off[1]
|
||||
if c < 0 || c >= boardCols {
|
||||
t.Errorf("piece %d overflows board at spawn: col+offset = %d", pt, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
217
internal/storage/instrumented.go
Normal file
217
internal/storage/instrumented.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// InstrumentedStore wraps a Store and records query duration and errors
|
||||
// as Prometheus metrics for each method call.
|
||||
type InstrumentedStore struct {
|
||||
store Store
|
||||
queryDuration *prometheus.HistogramVec
|
||||
queryErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store.
|
||||
func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore {
|
||||
return &InstrumentedStore{
|
||||
store: store,
|
||||
queryDuration: queryDuration,
|
||||
queryErrors: queryErrors,
|
||||
}
|
||||
}
|
||||
|
||||
func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) {
|
||||
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||
v, err := fn()
|
||||
timer.ObserveDuration()
|
||||
if err != nil {
|
||||
s.queryErrors.WithLabelValues(method).Inc()
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
func observeErr(s *InstrumentedStore, method string, fn func() error) error {
|
||||
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||
err := fn()
|
||||
timer.ObserveDuration()
|
||||
if err != nil {
|
||||
s.queryErrors.WithLabelValues(method).Inc()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
|
||||
return observeErr(s, "RecordLoginAttempt", func() error {
|
||||
return s.store.RecordLoginAttempt(ctx, username, password, ip, country)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
||||
return observe(s, "CreateSession", func() (string, error) {
|
||||
return s.store.CreateSession(ctx, ip, username, shellName, country)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
|
||||
return observeErr(s, "EndSession", func() error {
|
||||
return s.store.EndSession(ctx, sessionID, disconnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
|
||||
return observeErr(s, "UpdateHumanScore", func() error {
|
||||
return s.store.UpdateHumanScore(ctx, sessionID, score)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
||||
return observeErr(s, "SetExecCommand", func() error {
|
||||
return s.store.SetExecCommand(ctx, sessionID, command)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
||||
return observeErr(s, "AppendSessionLog", func() error {
|
||||
return s.store.AppendSessionLog(ctx, sessionID, input, output)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
|
||||
return observe(s, "DeleteRecordsBefore", func() (int64, error) {
|
||||
return s.store.DeleteRecordsBefore(ctx, cutoff)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
return observe(s, "GetDashboardStats", func() (*DashboardStats, error) {
|
||||
return s.store.GetDashboardStats(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopUsernames", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopUsernames(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopPasswords", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopPasswords(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopIPs", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopIPs(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopCountries", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopCountries(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopExecCommands(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||
return observe(s, "GetRecentSessions", func() ([]Session, error) {
|
||||
return s.store.GetRecentSessions(ctx, limit, activeOnly)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
return observe(s, "GetFilteredSessions", func() ([]Session, error) {
|
||||
return s.store.GetFilteredSessions(ctx, limit, activeOnly, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
||||
return observe(s, "GetSession", func() (*Session, error) {
|
||||
return s.store.GetSession(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
|
||||
return observe(s, "GetSessionLogs", func() ([]SessionLog, error) {
|
||||
return s.store.GetSessionLogs(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
|
||||
return observeErr(s, "AppendSessionEvents", func() error {
|
||||
return s.store.AppendSessionEvents(ctx, events)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
|
||||
return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) {
|
||||
return s.store.GetSessionEvents(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
return observe(s, "CloseActiveSessions", func() (int64, error) {
|
||||
return s.store.CloseActiveSessions(ctx, disconnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) {
|
||||
return s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) {
|
||||
return s.store.GetHourlyPattern(ctx, since, until)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
|
||||
return observe(s, "GetCountryStats", func() ([]CountryCount, error) {
|
||||
return s.store.GetCountryStats(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) {
|
||||
return s.store.GetFilteredDashboardStats(ctx, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopUsernames(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopPasswords(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopIPs(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopCountries(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) Close() error {
|
||||
return s.store.Close()
|
||||
}
|
||||
163
internal/storage/instrumented_test.go
Normal file
163
internal/storage/instrumented_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
)
|
||||
|
||||
func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
store := NewMemoryStore()
|
||||
return NewInstrumentedStore(store, dur, errs), dur, errs
|
||||
}
|
||||
|
||||
func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 {
|
||||
m := &dto.Metric{}
|
||||
h.WithLabelValues(method).(prometheus.Histogram).Write(m)
|
||||
return m.GetHistogram().GetSampleCount()
|
||||
}
|
||||
|
||||
func getCounterValue(c *prometheus.CounterVec, method string) float64 {
|
||||
m := &dto.Metric{}
|
||||
c.WithLabelValues(method).Write(m)
|
||||
return m.GetCounter().GetValue()
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreDelegation(t *testing.T) {
|
||||
s, dur, _ := newTestInstrumented()
|
||||
ctx := context.Background()
|
||||
|
||||
// RecordLoginAttempt should delegate and record duration.
|
||||
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
|
||||
// CreateSession should delegate and return a valid ID.
|
||||
id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Fatal("CreateSession returned empty ID")
|
||||
}
|
||||
if c := getHistogramCount(dur, "CreateSession"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
|
||||
// GetDashboardStats should delegate.
|
||||
stats, err := s.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetDashboardStats: %v", err)
|
||||
}
|
||||
if stats == nil {
|
||||
t.Fatal("GetDashboardStats returned nil")
|
||||
}
|
||||
if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreErrorCounting(t *testing.T) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test_ec_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test_ec_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
es := &errorStore{}
|
||||
s := NewInstrumentedStore(es, dur, errs)
|
||||
ctx := context.Background()
|
||||
|
||||
// Error should be counted.
|
||||
err := s.EndSession(ctx, "nonexistent", time.Now())
|
||||
if !errors.Is(err, errFake) {
|
||||
t.Fatalf("expected errFake, got %v", err)
|
||||
}
|
||||
if c := getHistogramCount(dur, "EndSession"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
if c := getCounterValue(errs, "EndSession"); c != 1 {
|
||||
t.Fatalf("expected error count 1, got %f", c)
|
||||
}
|
||||
|
||||
// Successful call should not increment error counter.
|
||||
s2, _, errs2 := newTestInstrumented()
|
||||
err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 {
|
||||
t.Fatalf("expected error count 0, got %f", c)
|
||||
}
|
||||
}
|
||||
|
||||
// errorStore is a Store that returns errors for all methods.
|
||||
type errorStore struct {
|
||||
MemoryStore
|
||||
}
|
||||
|
||||
var errFake = errors.New("fake error")
|
||||
|
||||
func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error {
|
||||
return errFake
|
||||
}
|
||||
|
||||
func (s *errorStore) EndSession(context.Context, string, time.Time) error {
|
||||
return errFake
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreObserveErr(t *testing.T) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test2_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test2_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
es := &errorStore{}
|
||||
s := NewInstrumentedStore(es, dur, errs)
|
||||
ctx := context.Background()
|
||||
|
||||
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if !errors.Is(err, errFake) {
|
||||
t.Fatalf("expected errFake, got %v", err)
|
||||
}
|
||||
if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected error count 1, got %f", c)
|
||||
}
|
||||
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreClose(t *testing.T) {
|
||||
s, _, _ := newTestInstrumented()
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -25,7 +25,7 @@ func NewMemoryStore() *MemoryStore {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip string) error {
|
||||
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip, country string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -35,6 +35,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
|
||||
if a.Username == username && a.Password == password && a.IP == ip {
|
||||
a.Count++
|
||||
a.LastSeen = now
|
||||
a.Country = country
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -44,6 +45,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
|
||||
Username: username,
|
||||
Password: password,
|
||||
IP: ip,
|
||||
Country: country,
|
||||
Count: 1,
|
||||
FirstSeen: now,
|
||||
LastSeen: now,
|
||||
@@ -51,7 +53,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName string) (string, error) {
|
||||
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName, country string) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -60,6 +62,7 @@ func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName s
|
||||
m.Sessions[id] = &Session{
|
||||
ID: id,
|
||||
IP: ip,
|
||||
Country: country,
|
||||
Username: username,
|
||||
ShellName: shellName,
|
||||
ConnectedAt: now,
|
||||
@@ -88,6 +91,16 @@ func (m *MemoryStore) UpdateHumanScore(_ context.Context, sessionID string, scor
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) SetExecCommand(_ context.Context, sessionID string, command string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if s, ok := m.Sessions[sessionID]; ok {
|
||||
s.ExecCommand = &command
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -234,7 +247,60 @@ func (m *MemoryStore) GetTopPasswords(_ context.Context, limit int) ([]TopEntry,
|
||||
func (m *MemoryStore) GetTopIPs(_ context.Context, limit int) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.topN("ip", limit), nil
|
||||
|
||||
type ipInfo struct {
|
||||
count int64
|
||||
country string
|
||||
}
|
||||
agg := make(map[string]*ipInfo)
|
||||
for _, a := range m.LoginAttempts {
|
||||
info, ok := agg[a.IP]
|
||||
if !ok {
|
||||
info = &ipInfo{}
|
||||
agg[a.IP] = info
|
||||
}
|
||||
info.count += int64(a.Count)
|
||||
if a.Country != "" {
|
||||
info.country = a.Country
|
||||
}
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(agg))
|
||||
for ip, info := range agg {
|
||||
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetTopCountries(_ context.Context, limit int) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if a.Country == "" {
|
||||
continue
|
||||
}
|
||||
counts[a.Country] += int64(a.Count)
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// topN aggregates login attempts by the given field and returns the top N. Must be called with m.mu held.
|
||||
@@ -270,20 +336,372 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.collectSessions(limit, activeOnly, DashboardFilter{}), nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredSessions(_ context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.collectSessions(limit, activeOnly, f), nil
|
||||
}
|
||||
|
||||
// collectSessions gathers sessions matching filter criteria. Must be called with m.mu held.
|
||||
func (m *MemoryStore) collectSessions(limit int, activeOnly bool, f DashboardFilter) []Session {
|
||||
// Compute event counts and input bytes per session.
|
||||
eventCounts := make(map[string]int)
|
||||
inputBytes := make(map[string]int64)
|
||||
for _, e := range m.SessionEvents {
|
||||
eventCounts[e.SessionID]++
|
||||
if e.Direction == 0 {
|
||||
inputBytes[e.SessionID] += int64(len(e.Data))
|
||||
}
|
||||
}
|
||||
|
||||
var sessions []Session
|
||||
for _, s := range m.Sessions {
|
||||
if activeOnly && s.DisconnectedAt != nil {
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, *s)
|
||||
if !matchesSessionFilter(s, f) {
|
||||
continue
|
||||
}
|
||||
sess := *s
|
||||
sess.EventCount = eventCounts[s.ID]
|
||||
sess.InputBytes = inputBytes[s.ID]
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
|
||||
})
|
||||
|
||||
if f.SortBy == "input_bytes" {
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].InputBytes > sessions[j].InputBytes
|
||||
})
|
||||
} else {
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
if limit > 0 && len(sessions) > limit {
|
||||
sessions = sessions[:limit]
|
||||
}
|
||||
return sessions, nil
|
||||
return sessions
|
||||
}
|
||||
|
||||
// matchesSessionFilter returns true if the session matches the given filter.
|
||||
func matchesSessionFilter(s *Session, f DashboardFilter) bool {
|
||||
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
|
||||
return false
|
||||
}
|
||||
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
|
||||
return false
|
||||
}
|
||||
if f.IP != "" && s.IP != f.IP {
|
||||
return false
|
||||
}
|
||||
if f.Country != "" && s.Country != f.Country {
|
||||
return false
|
||||
}
|
||||
if f.Username != "" && s.Username != f.Username {
|
||||
return false
|
||||
}
|
||||
if f.HumanScoreAboveZero {
|
||||
if s.HumanScore == nil || *s.HumanScore <= 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetTopExecCommands(_ context.Context, limit int) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, s := range m.Sessions {
|
||||
if s.ExecCommand != nil {
|
||||
counts[*s.ExecCommand]++
|
||||
}
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var count int64
|
||||
t := disconnectedAt.UTC()
|
||||
for _, s := range m.Sessions {
|
||||
if s.DisconnectedAt == nil {
|
||||
s.DisconnectedAt = &t
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetAttemptsOverTime(_ context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var cutoff time.Time
|
||||
if since != nil {
|
||||
cutoff = *since
|
||||
} else {
|
||||
cutoff = time.Now().UTC().AddDate(0, 0, -days)
|
||||
}
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if a.LastSeen.Before(cutoff) {
|
||||
continue
|
||||
}
|
||||
if until != nil && a.LastSeen.After(*until) {
|
||||
continue
|
||||
}
|
||||
day := a.LastSeen.Format("2006-01-02")
|
||||
counts[day] += int64(a.Count)
|
||||
}
|
||||
|
||||
points := make([]TimeSeriesPoint, 0, len(counts))
|
||||
for day, count := range counts {
|
||||
t, _ := time.Parse("2006-01-02", day)
|
||||
points = append(points, TimeSeriesPoint{Timestamp: t, Count: count})
|
||||
}
|
||||
sort.Slice(points, func(i, j int) bool {
|
||||
return points[i].Timestamp.Before(points[j].Timestamp)
|
||||
})
|
||||
return points, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetHourlyPattern(_ context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
hourCounts := make(map[int]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if since != nil && a.LastSeen.Before(*since) {
|
||||
continue
|
||||
}
|
||||
if until != nil && a.LastSeen.After(*until) {
|
||||
continue
|
||||
}
|
||||
hour := a.LastSeen.Hour()
|
||||
hourCounts[hour] += int64(a.Count)
|
||||
}
|
||||
|
||||
counts := make([]HourlyCount, 0, len(hourCounts))
|
||||
for h, c := range hourCounts {
|
||||
counts = append(counts, HourlyCount{Hour: h, Count: c})
|
||||
}
|
||||
sort.Slice(counts, func(i, j int) bool {
|
||||
return counts[i].Hour < counts[j].Hour
|
||||
})
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetCountryStats(_ context.Context) ([]CountryCount, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if a.Country == "" {
|
||||
continue
|
||||
}
|
||||
counts[a.Country] += int64(a.Count)
|
||||
}
|
||||
|
||||
result := make([]CountryCount, 0, len(counts))
|
||||
for country, count := range counts {
|
||||
result = append(result, CountryCount{Country: country, Count: count})
|
||||
}
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i].Count > result[j].Count
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// matchesFilter returns true if the login attempt matches the given filter. Must be called with m.mu held.
|
||||
func matchesFilter(a *LoginAttempt, f DashboardFilter) bool {
|
||||
if f.Since != nil && a.LastSeen.Before(*f.Since) {
|
||||
return false
|
||||
}
|
||||
if f.Until != nil && a.LastSeen.After(*f.Until) {
|
||||
return false
|
||||
}
|
||||
if f.IP != "" && a.IP != f.IP {
|
||||
return false
|
||||
}
|
||||
if f.Country != "" && a.Country != f.Country {
|
||||
return false
|
||||
}
|
||||
if f.Username != "" && a.Username != f.Username {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredDashboardStats(_ context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
stats := &DashboardStats{}
|
||||
ips := make(map[string]struct{})
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
stats.TotalAttempts += int64(a.Count)
|
||||
ips[a.IP] = struct{}{}
|
||||
}
|
||||
stats.UniqueIPs = int64(len(ips))
|
||||
|
||||
for _, s := range m.Sessions {
|
||||
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
|
||||
continue
|
||||
}
|
||||
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
|
||||
continue
|
||||
}
|
||||
if f.IP != "" && s.IP != f.IP {
|
||||
continue
|
||||
}
|
||||
if f.Country != "" && s.Country != f.Country {
|
||||
continue
|
||||
}
|
||||
stats.TotalSessions++
|
||||
if s.DisconnectedAt == nil {
|
||||
stats.ActiveSessions++
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopUsernames(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.filteredTopN("username", limit, f), nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopPasswords(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.filteredTopN("password", limit, f), nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopIPs(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
type ipInfo struct {
|
||||
count int64
|
||||
country string
|
||||
}
|
||||
agg := make(map[string]*ipInfo)
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
info, ok := agg[a.IP]
|
||||
if !ok {
|
||||
info = &ipInfo{}
|
||||
agg[a.IP] = info
|
||||
}
|
||||
info.count += int64(a.Count)
|
||||
if a.Country != "" {
|
||||
info.country = a.Country
|
||||
}
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(agg))
|
||||
for ip, info := range agg {
|
||||
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopCountries(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if a.Country == "" {
|
||||
continue
|
||||
}
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
counts[a.Country] += int64(a.Count)
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// filteredTopN aggregates login attempts by the given field with filter applied and returns the top N. Must be called with m.mu held.
|
||||
func (m *MemoryStore) filteredTopN(field string, limit int, f DashboardFilter) []TopEntry {
|
||||
counts := make(map[string]int64)
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
var key string
|
||||
switch field {
|
||||
case "username":
|
||||
key = a.Username
|
||||
case "password":
|
||||
key = a.Password
|
||||
case "ip":
|
||||
key = a.IP
|
||||
}
|
||||
counts[key] += int64(a.Count)
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func (m *MemoryStore) Close() error {
|
||||
|
||||
3
internal/storage/migrations/003_add_country.sql
Normal file
3
internal/storage/migrations/003_add_country.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE login_attempts ADD COLUMN country TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE sessions ADD COLUMN country TEXT NOT NULL DEFAULT '';
|
||||
CREATE INDEX idx_login_attempts_country ON login_attempts(country);
|
||||
1
internal/storage/migrations/004_add_exec_command.sql
Normal file
1
internal/storage/migrations/004_add_exec_command.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE sessions ADD COLUMN exec_command TEXT;
|
||||
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
CREATE INDEX idx_login_attempts_username ON login_attempts(username);
|
||||
CREATE INDEX idx_login_attempts_password ON login_attempts(password);
|
||||
CREATE INDEX idx_sessions_disconnected_at ON sessions(disconnected_at);
|
||||
@@ -25,8 +25,8 @@ func TestMigrateCreatesTablesAndVersion(t *testing.T) {
|
||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||
t.Fatalf("query version: %v", err)
|
||||
}
|
||||
if version != 2 {
|
||||
t.Errorf("version = %d, want 2", version)
|
||||
if version != 5 {
|
||||
t.Errorf("version = %d, want 5", version)
|
||||
}
|
||||
|
||||
// Verify tables exist by inserting into them.
|
||||
@@ -64,8 +64,8 @@ func TestMigrateIdempotent(t *testing.T) {
|
||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||
t.Fatalf("query version: %v", err)
|
||||
}
|
||||
if version != 2 {
|
||||
t.Errorf("version = %d after double migrate, want 2", version)
|
||||
if version != 5 {
|
||||
t.Errorf("version = %d after double migrate, want 5", version)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestRunRetentionDeletesOldRecords(t *testing.T) {
|
||||
}
|
||||
|
||||
// Insert a recent login attempt.
|
||||
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
|
||||
t.Fatalf("insert recent attempt: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -34,28 +35,29 @@ func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
|
||||
return &SQLiteStore{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip string) error {
|
||||
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
INSERT INTO login_attempts (username, password, ip, count, first_seen, last_seen)
|
||||
VALUES (?, ?, ?, 1, ?, ?)
|
||||
INSERT INTO login_attempts (username, password, ip, country, count, first_seen, last_seen)
|
||||
VALUES (?, ?, ?, ?, 1, ?, ?)
|
||||
ON CONFLICT(username, password, ip) DO UPDATE SET
|
||||
count = count + 1,
|
||||
last_seen = ?`,
|
||||
username, password, ip, now, now, now)
|
||||
last_seen = ?,
|
||||
country = ?`,
|
||||
username, password, ip, country, now, now, now, country)
|
||||
if err != nil {
|
||||
return fmt.Errorf("recording login attempt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName string) (string, error) {
|
||||
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
INSERT INTO sessions (id, ip, username, shell_name, connected_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
id, ip, username, shellName, now)
|
||||
INSERT INTO sessions (id, ip, username, shell_name, country, connected_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
id, ip, username, shellName, country, now)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating session: %w", err)
|
||||
}
|
||||
@@ -82,6 +84,16 @@ func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, sc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET exec_command = ? WHERE id = ?`,
|
||||
command, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting exec command: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
@@ -99,12 +111,13 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
|
||||
var connectedAt string
|
||||
var disconnectedAt sql.NullString
|
||||
var humanScore sql.NullFloat64
|
||||
var execCommand sql.NullString
|
||||
|
||||
err := s.db.QueryRowContext(ctx, `
|
||||
SELECT id, ip, username, shell_name, connected_at, disconnected_at, human_score
|
||||
SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command
|
||||
FROM sessions WHERE id = ?`, sessionID).Scan(
|
||||
&sess.ID, &sess.IP, &sess.Username, &sess.ShellName,
|
||||
&connectedAt, &disconnectedAt, &humanScore,
|
||||
&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
|
||||
&connectedAt, &disconnectedAt, &humanScore, &execCommand,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
@@ -121,6 +134,9 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
|
||||
if humanScore.Valid {
|
||||
sess.HumanScore = &humanScore.Float64
|
||||
}
|
||||
if execCommand.Valid {
|
||||
sess.ExecCommand = &execCommand.String
|
||||
}
|
||||
return &sess, nil
|
||||
}
|
||||
|
||||
@@ -288,10 +304,60 @@ func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntr
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return s.queryTopN(ctx, "ip", limit)
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT ip, country, SUM(count) AS total
|
||||
FROM login_attempts
|
||||
GROUP BY ip
|
||||
ORDER BY total DESC
|
||||
LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying top IPs: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning top IPs: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT country, SUM(count) AS total
|
||||
FROM login_attempts
|
||||
WHERE country != ''
|
||||
GROUP BY country
|
||||
ORDER BY total DESC
|
||||
LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying top countries: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning top countries: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
|
||||
switch column {
|
||||
case "username", "password", "ip":
|
||||
// valid columns
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid column: %s", column)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT %s, SUM(count) AS total
|
||||
FROM login_attempts
|
||||
@@ -317,40 +383,401 @@ func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) (
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||
query := `SELECT id, ip, username, shell_name, connected_at, disconnected_at, human_score FROM sessions`
|
||||
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
|
||||
if activeOnly {
|
||||
query += ` WHERE disconnected_at IS NULL`
|
||||
query += ` WHERE s.disconnected_at IS NULL`
|
||||
}
|
||||
query += ` ORDER BY connected_at DESC LIMIT ?`
|
||||
query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, limit)
|
||||
return s.scanSessions(ctx, query, limit)
|
||||
}
|
||||
|
||||
// buildSessionWhereClause builds a dynamic WHERE clause for session filtering.
|
||||
func buildSessionWhereClause(f DashboardFilter, activeOnly bool) (string, []any) {
|
||||
var clauses []string
|
||||
var args []any
|
||||
|
||||
if activeOnly {
|
||||
clauses = append(clauses, "s.disconnected_at IS NULL")
|
||||
}
|
||||
if f.Since != nil {
|
||||
clauses = append(clauses, "s.connected_at >= ?")
|
||||
args = append(args, f.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.Until != nil {
|
||||
clauses = append(clauses, "s.connected_at <= ?")
|
||||
args = append(args, f.Until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.IP != "" {
|
||||
clauses = append(clauses, "s.ip = ?")
|
||||
args = append(args, f.IP)
|
||||
}
|
||||
if f.Country != "" {
|
||||
clauses = append(clauses, "s.country = ?")
|
||||
args = append(args, f.Country)
|
||||
}
|
||||
if f.Username != "" {
|
||||
clauses = append(clauses, "s.username = ?")
|
||||
args = append(args, f.Username)
|
||||
}
|
||||
if f.HumanScoreAboveZero {
|
||||
clauses = append(clauses, "s.human_score > 0")
|
||||
}
|
||||
|
||||
if len(clauses) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return " WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
// validSessionSorts maps allowed SortBy values to SQL ORDER BY clauses.
|
||||
var validSessionSorts = map[string]string{
|
||||
"connected_at": "s.connected_at DESC",
|
||||
"input_bytes": "input_bytes DESC",
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
where, args := buildSessionWhereClause(f, activeOnly)
|
||||
args = append(args, limit)
|
||||
|
||||
orderBy := validSessionSorts["connected_at"]
|
||||
if mapped, ok := validSessionSorts[f.SortBy]; ok {
|
||||
orderBy = mapped
|
||||
}
|
||||
|
||||
//nolint:gosec // where/order clauses built from allowlisted constants, not raw user input
|
||||
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + where + ` GROUP BY s.id ORDER BY ` + orderBy + ` LIMIT ?`
|
||||
|
||||
return s.scanSessions(ctx, query, args...)
|
||||
}
|
||||
|
||||
// scanSessions executes a session query and scans the results.
|
||||
func (s *SQLiteStore) scanSessions(ctx context.Context, query string, args ...any) ([]Session, error) {
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying recent sessions: %w", err)
|
||||
return nil, fmt.Errorf("querying sessions: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var sessions []Session
|
||||
for rows.Next() {
|
||||
var s Session
|
||||
var sess Session
|
||||
var connectedAt string
|
||||
var disconnectedAt sql.NullString
|
||||
var humanScore sql.NullFloat64
|
||||
if err := rows.Scan(&s.ID, &s.IP, &s.Username, &s.ShellName, &connectedAt, &disconnectedAt, &humanScore); err != nil {
|
||||
var execCommand sql.NullString
|
||||
if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &sess.EventCount, &sess.InputBytes); err != nil {
|
||||
return nil, fmt.Errorf("scanning session: %w", err)
|
||||
}
|
||||
s.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||
if disconnectedAt.Valid {
|
||||
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
|
||||
s.DisconnectedAt = &t
|
||||
sess.DisconnectedAt = &t
|
||||
}
|
||||
if humanScore.Valid {
|
||||
s.HumanScore = &humanScore.Float64
|
||||
sess.HumanScore = &humanScore.Float64
|
||||
}
|
||||
sessions = append(sessions, s)
|
||||
if execCommand.Valid {
|
||||
sess.ExecCommand = &execCommand.String
|
||||
}
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT exec_command, COUNT(*) as total
|
||||
FROM sessions
|
||||
WHERE exec_command IS NOT NULL
|
||||
GROUP BY exec_command
|
||||
ORDER BY total DESC
|
||||
LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying top exec commands: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning top exec commands: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
res, err := s.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`,
|
||||
disconnectedAt.UTC().Format(time.RFC3339))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("closing active sessions: %w", err)
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
query := `SELECT DATE(last_seen) AS d, SUM(count) FROM login_attempts WHERE 1=1`
|
||||
var args []any
|
||||
|
||||
if since != nil {
|
||||
query += ` AND last_seen >= ?`
|
||||
args = append(args, since.UTC().Format(time.RFC3339))
|
||||
} else {
|
||||
query += ` AND last_seen >= ?`
|
||||
args = append(args, time.Now().UTC().AddDate(0, 0, -days).Format("2006-01-02"))
|
||||
}
|
||||
if until != nil {
|
||||
query += ` AND last_seen <= ?`
|
||||
args = append(args, until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
query += ` GROUP BY d ORDER BY d`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying attempts over time: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var points []TimeSeriesPoint
|
||||
for rows.Next() {
|
||||
var dateStr string
|
||||
var p TimeSeriesPoint
|
||||
if err := rows.Scan(&dateStr, &p.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning time series point: %w", err)
|
||||
}
|
||||
p.Timestamp, _ = time.Parse("2006-01-02", dateStr)
|
||||
points = append(points, p)
|
||||
}
|
||||
return points, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
query := `SELECT CAST(STRFTIME('%H', last_seen) AS INTEGER) AS h, SUM(count) FROM login_attempts WHERE 1=1`
|
||||
var args []any
|
||||
|
||||
if since != nil {
|
||||
query += ` AND last_seen >= ?`
|
||||
args = append(args, since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if until != nil {
|
||||
query += ` AND last_seen <= ?`
|
||||
args = append(args, until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
query += ` GROUP BY h ORDER BY h`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying hourly pattern: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var counts []HourlyCount
|
||||
for rows.Next() {
|
||||
var c HourlyCount
|
||||
if err := rows.Scan(&c.Hour, &c.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning hourly count: %w", err)
|
||||
}
|
||||
counts = append(counts, c)
|
||||
}
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT country, SUM(count) AS total
|
||||
FROM login_attempts
|
||||
WHERE country != ''
|
||||
GROUP BY country
|
||||
ORDER BY total DESC`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying country stats: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var counts []CountryCount
|
||||
for rows.Next() {
|
||||
var c CountryCount
|
||||
if err := rows.Scan(&c.Country, &c.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning country count: %w", err)
|
||||
}
|
||||
counts = append(counts, c)
|
||||
}
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
// buildAttemptWhereClause builds a dynamic WHERE clause for login_attempts filtering.
|
||||
func buildAttemptWhereClause(f DashboardFilter) (string, []any) {
|
||||
var clauses []string
|
||||
var args []any
|
||||
|
||||
if f.Since != nil {
|
||||
clauses = append(clauses, "last_seen >= ?")
|
||||
args = append(args, f.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.Until != nil {
|
||||
clauses = append(clauses, "last_seen <= ?")
|
||||
args = append(args, f.Until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.IP != "" {
|
||||
clauses = append(clauses, "ip = ?")
|
||||
args = append(args, f.IP)
|
||||
}
|
||||
if f.Country != "" {
|
||||
clauses = append(clauses, "country = ?")
|
||||
args = append(args, f.Country)
|
||||
}
|
||||
if f.Username != "" {
|
||||
clauses = append(clauses, "username = ?")
|
||||
args = append(args, f.Username)
|
||||
}
|
||||
|
||||
if len(clauses) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return " WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
stats := &DashboardStats{}
|
||||
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
`SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip) FROM login_attempts`+where, args...).
|
||||
Scan(&stats.TotalAttempts, &stats.UniqueIPs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered attempt stats: %w", err)
|
||||
}
|
||||
|
||||
// Sessions don't have username/password, so only filter by time, IP, country.
|
||||
sessQuery := `SELECT COUNT(*) FROM sessions WHERE 1=1`
|
||||
var sessArgs []any
|
||||
if f.Since != nil {
|
||||
sessQuery += ` AND connected_at >= ?`
|
||||
sessArgs = append(sessArgs, f.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.Until != nil {
|
||||
sessQuery += ` AND connected_at <= ?`
|
||||
sessArgs = append(sessArgs, f.Until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.IP != "" {
|
||||
sessQuery += ` AND ip = ?`
|
||||
sessArgs = append(sessArgs, f.IP)
|
||||
}
|
||||
if f.Country != "" {
|
||||
sessQuery += ` AND country = ?`
|
||||
sessArgs = append(sessArgs, f.Country)
|
||||
}
|
||||
|
||||
err = s.db.QueryRowContext(ctx, sessQuery, sessArgs...).Scan(&stats.TotalSessions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered total sessions: %w", err)
|
||||
}
|
||||
|
||||
err = s.db.QueryRowContext(ctx, sessQuery+` AND disconnected_at IS NULL`, sessArgs...).Scan(&stats.ActiveSessions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered active sessions: %w", err)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return s.queryFilteredTopN(ctx, "username", limit, f)
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return s.queryFilteredTopN(ctx, "password", limit, f)
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
args = append(args, limit)
|
||||
//nolint:gosec // where clause built from trusted constants, not user input
|
||||
query := `SELECT ip, country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY ip ORDER BY total DESC LIMIT ?`
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered top IPs: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning filtered top IPs: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
countryClause := "country != ''"
|
||||
if where == "" {
|
||||
where = " WHERE " + countryClause
|
||||
} else {
|
||||
where += " AND " + countryClause
|
||||
}
|
||||
args = append(args, limit)
|
||||
//nolint:gosec // where clause built from trusted constants, not user input
|
||||
query := `SELECT country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY country ORDER BY total DESC LIMIT ?`
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered top countries: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning filtered top countries: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) queryFilteredTopN(ctx context.Context, column string, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
switch column {
|
||||
case "username", "password":
|
||||
// valid columns
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid column: %s", column)
|
||||
}
|
||||
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
args = append(args, limit)
|
||||
query := fmt.Sprintf(`
|
||||
SELECT %s, SUM(count) AS total
|
||||
FROM login_attempts%s
|
||||
GROUP BY %s
|
||||
ORDER BY total DESC
|
||||
LIMIT ?`, column, where, column)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered top %s: %w", column, err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning filtered top %s: %w", column, err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
@@ -23,17 +23,17 @@ func TestRecordLoginAttempt(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// First attempt creates a new record.
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("first attempt: %v", err)
|
||||
}
|
||||
|
||||
// Second attempt with same credentials increments count.
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("second attempt: %v", err)
|
||||
}
|
||||
|
||||
// Different IP is a separate record.
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
|
||||
t.Fatalf("different IP: %v", err)
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestCreateAndEndSession(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
@@ -100,7 +100,7 @@ func TestUpdateHumanScore(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func TestAppendSessionLog(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
@@ -159,7 +159,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
|
||||
}
|
||||
|
||||
// Insert a recent login attempt.
|
||||
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
|
||||
t.Fatalf("insert recent attempt: %v", err)
|
||||
}
|
||||
|
||||
@@ -178,7 +178,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
|
||||
}
|
||||
|
||||
// Insert a recent session.
|
||||
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", ""); err != nil {
|
||||
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", "", ""); err != nil {
|
||||
t.Fatalf("insert recent session: %v", err)
|
||||
}
|
||||
|
||||
@@ -204,6 +204,79 @@ func TestDeleteRecordsBefore(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTopExecCommands(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create sessions with exec commands.
|
||||
for range 3 {
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||
t.Fatalf("setting exec command: %v", err)
|
||||
}
|
||||
}
|
||||
for range 2 {
|
||||
id, err := store.CreateSession(ctx, "10.0.0.2", "admin", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "cat /etc/passwd"); err != nil {
|
||||
t.Fatalf("setting exec command: %v", err)
|
||||
}
|
||||
}
|
||||
// Session without exec command — should not appear.
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.3", "test", "bash", ""); err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
|
||||
entries, err := store.GetTopExecCommands(ctx, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetTopExecCommands: %v", err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(entries))
|
||||
}
|
||||
if entries[0].Value != "uname -a" || entries[0].Count != 3 {
|
||||
t.Errorf("entries[0] = %+v, want uname -a:3", entries[0])
|
||||
}
|
||||
if entries[1].Value != "cat /etc/passwd" || entries[1].Count != 2 {
|
||||
t.Errorf("entries[1] = %+v, want cat /etc/passwd:2", entries[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRecentSessionsEventCount(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
|
||||
// Add some events.
|
||||
events := []SessionEvent{
|
||||
{SessionID: id, Timestamp: time.Now(), Direction: 0, Data: []byte("ls\n")},
|
||||
{SessionID: id, Timestamp: time.Now(), Direction: 1, Data: []byte("file1\n")},
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||
t.Fatalf("appending events: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].EventCount != 2 {
|
||||
t.Errorf("EventCount = %d, want 2", sessions[0].EventCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSQLiteStoreCreatesFile(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "test.db")
|
||||
store, err := NewSQLiteStore(dbPath)
|
||||
@@ -214,7 +287,7 @@ func TestNewSQLiteStoreCreatesFile(t *testing.T) {
|
||||
|
||||
// Verify we can use the store.
|
||||
ctx := context.Background()
|
||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1", ""); err != nil {
|
||||
t.Fatalf("recording attempt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ type LoginAttempt struct {
|
||||
Username string
|
||||
Password string
|
||||
IP string
|
||||
Country string
|
||||
Count int
|
||||
FirstSeen time.Time
|
||||
LastSeen time.Time
|
||||
@@ -20,11 +21,15 @@ type LoginAttempt struct {
|
||||
type Session struct {
|
||||
ID string
|
||||
IP string
|
||||
Country string
|
||||
Username string
|
||||
ShellName string
|
||||
ConnectedAt time.Time
|
||||
DisconnectedAt *time.Time
|
||||
HumanScore *float64
|
||||
ExecCommand *string
|
||||
EventCount int
|
||||
InputBytes int64
|
||||
}
|
||||
|
||||
// SessionLog represents a single log entry for a session.
|
||||
@@ -52,20 +57,50 @@ type DashboardStats struct {
|
||||
ActiveSessions int64
|
||||
}
|
||||
|
||||
// TimeSeriesPoint represents a single data point in a time series.
|
||||
type TimeSeriesPoint struct {
|
||||
Timestamp time.Time
|
||||
Count int64
|
||||
}
|
||||
|
||||
// HourlyCount represents the total attempts for a given hour of day.
|
||||
type HourlyCount struct {
|
||||
Hour int // 0-23
|
||||
Count int64
|
||||
}
|
||||
|
||||
// CountryCount represents the total attempts from a given country.
|
||||
type CountryCount struct {
|
||||
Country string
|
||||
Count int64
|
||||
}
|
||||
|
||||
// DashboardFilter contains optional filters for dashboard queries.
|
||||
type DashboardFilter struct {
|
||||
Since *time.Time
|
||||
Until *time.Time
|
||||
IP string
|
||||
Country string
|
||||
Username string
|
||||
HumanScoreAboveZero bool
|
||||
SortBy string
|
||||
}
|
||||
|
||||
// TopEntry represents a value and its count for top-N queries.
|
||||
type TopEntry struct {
|
||||
Value string
|
||||
Count int64
|
||||
Value string
|
||||
Country string // populated by GetTopIPs
|
||||
Count int64
|
||||
}
|
||||
|
||||
// Store is the interface for persistent storage of honeypot data.
|
||||
type Store interface {
|
||||
// RecordLoginAttempt upserts a login attempt, incrementing the count
|
||||
// for existing (username, password, ip) combinations.
|
||||
RecordLoginAttempt(ctx context.Context, username, password, ip string) error
|
||||
RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error
|
||||
|
||||
// CreateSession creates a new session record and returns its UUID.
|
||||
CreateSession(ctx context.Context, ip, username, shellName string) (string, error)
|
||||
CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error)
|
||||
|
||||
// EndSession sets the disconnected_at timestamp for a session.
|
||||
EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error
|
||||
@@ -73,6 +108,9 @@ type Store interface {
|
||||
// UpdateHumanScore sets the human detection score for a session.
|
||||
UpdateHumanScore(ctx context.Context, sessionID string, score float64) error
|
||||
|
||||
// SetExecCommand sets the exec command for a session.
|
||||
SetExecCommand(ctx context.Context, sessionID string, command string) error
|
||||
|
||||
// AppendSessionLog adds a log entry to a session.
|
||||
AppendSessionLog(ctx context.Context, sessionID, input, output string) error
|
||||
|
||||
@@ -92,10 +130,20 @@ type Store interface {
|
||||
// GetTopIPs returns the top N IPs by total attempt count.
|
||||
GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error)
|
||||
|
||||
// GetTopCountries returns the top N countries by total attempt count.
|
||||
GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error)
|
||||
|
||||
// GetTopExecCommands returns the top N exec commands by session count.
|
||||
GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error)
|
||||
|
||||
// GetRecentSessions returns the most recent sessions ordered by connected_at DESC.
|
||||
// If activeOnly is true, only sessions with no disconnected_at are returned.
|
||||
GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error)
|
||||
|
||||
// GetFilteredSessions returns sessions matching the given filter, ordered
|
||||
// by the filter's SortBy field (default: connected_at DESC).
|
||||
GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error)
|
||||
|
||||
// GetSession returns a single session by ID.
|
||||
GetSession(ctx context.Context, sessionID string) (*Session, error)
|
||||
|
||||
@@ -108,6 +156,35 @@ type Store interface {
|
||||
// GetSessionEvents returns all events for a session ordered by id.
|
||||
GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error)
|
||||
|
||||
// CloseActiveSessions sets disconnected_at for all sessions that are
|
||||
// still marked as active. This should be called at startup to clean up
|
||||
// sessions left over from a previous unclean shutdown.
|
||||
CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error)
|
||||
|
||||
// GetAttemptsOverTime returns daily attempt counts for the last N days.
|
||||
GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error)
|
||||
|
||||
// GetHourlyPattern returns total attempts grouped by hour of day (0-23).
|
||||
GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error)
|
||||
|
||||
// GetCountryStats returns total attempts per country, ordered by count DESC.
|
||||
GetCountryStats(ctx context.Context) ([]CountryCount, error)
|
||||
|
||||
// GetFilteredDashboardStats returns aggregate counts with optional filters applied.
|
||||
GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error)
|
||||
|
||||
// GetFilteredTopUsernames returns top usernames with optional filters applied.
|
||||
GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// GetFilteredTopPasswords returns top passwords with optional filters applied.
|
||||
GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// GetFilteredTopIPs returns top IPs with optional filters applied.
|
||||
GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// GetFilteredTopCountries returns top countries with optional filters applied.
|
||||
GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// Close releases any resources held by the store.
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -38,23 +38,23 @@ func seedData(t *testing.T, store Store) {
|
||||
|
||||
// Login attempts: root/toor from two IPs, admin/admin from one IP.
|
||||
for range 5 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
for range 3 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
for range 2 {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.1"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sessions: one active, one ended.
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func seedData(t *testing.T, store Store) {
|
||||
t.Fatalf("ending session: %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash"); err != nil {
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -210,7 +210,7 @@ func TestGetSession(t *testing.T) {
|
||||
t.Run("found", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
@@ -233,7 +233,7 @@ func TestGetSessionLogs(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
@@ -277,7 +277,7 @@ func TestSessionEvents(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
@@ -316,6 +316,334 @@ func TestSessionEvents(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseActiveSessions(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("no active sessions", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
n, err := store.CloseActiveSessions(ctx, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("CloseActiveSessions: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("closed %d, want 0", n)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("closes only active sessions", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 3 sessions: end one, leave two active.
|
||||
id1, _ := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
|
||||
store.CreateSession(ctx, "10.0.0.3", "test", "bash", "")
|
||||
store.EndSession(ctx, id1, time.Now())
|
||||
|
||||
n, err := store.CloseActiveSessions(ctx, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("CloseActiveSessions: %v", err)
|
||||
}
|
||||
if n != 2 {
|
||||
t.Errorf("closed %d, want 2", n)
|
||||
}
|
||||
|
||||
// Verify no active sessions remain.
|
||||
active, err := store.GetRecentSessions(ctx, 10, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(active) != 0 {
|
||||
t.Errorf("active sessions = %d, want 0", len(active))
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetExecCommand(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("set and retrieve", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
// Initially nil.
|
||||
s, err := store.GetSession(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if s.ExecCommand != nil {
|
||||
t.Errorf("expected nil ExecCommand, got %q", *s.ExecCommand)
|
||||
}
|
||||
|
||||
// Set exec command.
|
||||
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||
t.Fatalf("SetExecCommand: %v", err)
|
||||
}
|
||||
|
||||
s, err = store.GetSession(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if s.ExecCommand == nil {
|
||||
t.Fatal("expected non-nil ExecCommand")
|
||||
}
|
||||
if *s.ExecCommand != "uname -a" {
|
||||
t.Errorf("ExecCommand = %q, want %q", *s.ExecCommand, "uname -a")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("appears in recent sessions", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "id"); err != nil {
|
||||
t.Fatalf("SetExecCommand: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ExecCommand == nil || *sessions[0].ExecCommand != "id" {
|
||||
t.Errorf("ExecCommand = %v, want \"id\"", sessions[0].ExecCommand)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func seedChartData(t *testing.T, store Store) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
// Record attempts with country data from different IPs.
|
||||
for range 5 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
for range 3 {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
for range 2 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "123456", "10.0.0.3", "CN"); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAttemptsOverTime(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAttemptsOverTime: %v", err)
|
||||
}
|
||||
if len(points) != 0 {
|
||||
t.Errorf("expected empty, got %v", points)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAttemptsOverTime: %v", err)
|
||||
}
|
||||
// All data was inserted today, so should be one point.
|
||||
if len(points) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(points))
|
||||
}
|
||||
// 5 + 3 + 2 = 10 total.
|
||||
if points[0].Count != 10 {
|
||||
t.Errorf("count = %d, want 10", points[0].Count)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetHourlyPattern(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetHourlyPattern: %v", err)
|
||||
}
|
||||
if len(counts) != 0 {
|
||||
t.Errorf("expected empty, got %v", counts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetHourlyPattern: %v", err)
|
||||
}
|
||||
// All data was inserted at the same hour.
|
||||
if len(counts) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(counts))
|
||||
}
|
||||
if counts[0].Count != 10 {
|
||||
t.Errorf("count = %d, want 10", counts[0].Count)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetCountryStats(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
counts, err := store.GetCountryStats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetCountryStats: %v", err)
|
||||
}
|
||||
if len(counts) != 0 {
|
||||
t.Errorf("expected empty, got %v", counts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
counts, err := store.GetCountryStats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetCountryStats: %v", err)
|
||||
}
|
||||
if len(counts) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(counts))
|
||||
}
|
||||
// CN: 5 + 2 = 7, RU: 3 - ordered by count DESC.
|
||||
if counts[0].Country != "CN" || counts[0].Count != 7 {
|
||||
t.Errorf("counts[0] = %+v, want CN/7", counts[0])
|
||||
}
|
||||
if counts[1].Country != "RU" || counts[1].Count != 3 {
|
||||
t.Errorf("counts[1] = %+v, want RU/3", counts[1])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("excludes empty country", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.2", "US"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
|
||||
counts, err := store.GetCountryStats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCountryStats: %v", err)
|
||||
}
|
||||
if len(counts) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(counts))
|
||||
}
|
||||
if counts[0].Country != "US" {
|
||||
t.Errorf("country = %q, want US", counts[0].Country)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredDashboardStats(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("no filter", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
if stats.TotalAttempts != 10 {
|
||||
t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by country", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Country: "CN"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
// CN: 5 + 2 = 7
|
||||
if stats.TotalAttempts != 7 {
|
||||
t.Errorf("TotalAttempts = %d, want 7", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by IP", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{IP: "10.0.0.1"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
if stats.TotalAttempts != 5 {
|
||||
t.Errorf("TotalAttempts = %d, want 5", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by username", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Username: "admin"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
if stats.TotalAttempts != 3 {
|
||||
t.Errorf("TotalAttempts = %d, want 3", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredTopUsernames(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
// Filter by country CN should only show root.
|
||||
entries, err := store.GetFilteredTopUsernames(context.Background(), 10, DashboardFilter{Country: "CN"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredTopUsernames: %v", err)
|
||||
}
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(entries))
|
||||
}
|
||||
if entries[0].Value != "root" || entries[0].Count != 7 {
|
||||
t.Errorf("entries[0] = %+v, want root/7", entries[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetRecentSessions(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
@@ -372,3 +700,192 @@ func TestGetRecentSessions(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestInputBytes(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("counts only input direction", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
events := []SessionEvent{
|
||||
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")}, // 3 bytes input
|
||||
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")}, // 11 bytes output
|
||||
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")}, // 4 bytes input
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
// Only direction=0 data: "ls\n" (3) + "pwd\n" (4) = 7
|
||||
if sessions[0].InputBytes != 7 {
|
||||
t.Errorf("InputBytes = %d, want 7", sessions[0].InputBytes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("zero when no events", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].InputBytes != 0 {
|
||||
t.Errorf("InputBytes = %d, want 0", sessions[0].InputBytes)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredSessions(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("filter by human score", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create two sessions, one with human score > 0.
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id1, 0.75); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
_, err = store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{HumanScoreAboveZero: true})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sort by input bytes", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Session with more input (created first).
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if err := store.AppendSessionEvents(ctx, []SessionEvent{
|
||||
{SessionID: id1, Timestamp: now, Direction: 0, Data: []byte("ls -la /tmp\n")},
|
||||
{SessionID: id1, Timestamp: now.Add(time.Millisecond), Direction: 0, Data: []byte("cat /etc/passwd\n")},
|
||||
}); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
// Session with less input (created after id1, so would be first by connected_at).
|
||||
// Sleep >1s to ensure different RFC3339 timestamps in SQLite.
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, []SessionEvent{
|
||||
{SessionID: id2, Timestamp: now.Add(2 * time.Second), Direction: 0, Data: []byte("x\n")},
|
||||
}); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
// Default sort (connected_at DESC) should show id2 first.
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id2 {
|
||||
t.Errorf("default sort: expected %s first, got %s", id2, sessions[0].ID)
|
||||
}
|
||||
|
||||
// Sort by input_bytes should show id1 first (more input).
|
||||
sessions, err = store.GetFilteredSessions(ctx, 50, false, DashboardFilter{SortBy: "input_bytes"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("input_bytes sort: expected %s first, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("combined filters", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id1, 0.5); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
// Different country, also has score.
|
||||
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id2, 0.8); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
// Same country CN but no score.
|
||||
_, err = store.CreateSession(ctx, "10.0.0.3", "test", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
// Filter: CN + human score > 0 -> only id1.
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{
|
||||
Country: "CN",
|
||||
HumanScoreAboveZero: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,24 +1,37 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// dbContext returns a context detached from the HTTP request lifecycle with a
|
||||
// 30-second timeout. This prevents HTMX polling from canceling in-flight DB
|
||||
// queries when the browser aborts the previous XHR.
|
||||
func dbContext(r *http.Request) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
|
||||
}
|
||||
|
||||
type dashboardData struct {
|
||||
Stats *storage.DashboardStats
|
||||
TopUsernames []storage.TopEntry
|
||||
TopPasswords []storage.TopEntry
|
||||
TopIPs []storage.TopEntry
|
||||
ActiveSessions []storage.Session
|
||||
RecentSessions []storage.Session
|
||||
Stats *storage.DashboardStats
|
||||
TopUsernames []storage.TopEntry
|
||||
TopPasswords []storage.TopEntry
|
||||
TopIPs []storage.TopEntry
|
||||
TopCountries []storage.TopEntry
|
||||
TopExecCommands []storage.TopEntry
|
||||
ActiveSessions []storage.Session
|
||||
RecentSessions []storage.Session
|
||||
}
|
||||
|
||||
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
@@ -48,6 +61,20 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
topCountries, err := s.store.GetTopCountries(ctx, 10)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get top countries", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topExecCommands, err := s.store.GetTopExecCommands(ctx, 10)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get top exec commands", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
activeSessions, err := s.store.GetRecentSessions(ctx, 50, true)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get active sessions", "err", err)
|
||||
@@ -63,12 +90,14 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
data := dashboardData{
|
||||
Stats: stats,
|
||||
TopUsernames: topUsernames,
|
||||
TopPasswords: topPasswords,
|
||||
TopIPs: topIPs,
|
||||
ActiveSessions: activeSessions,
|
||||
RecentSessions: recentSessions,
|
||||
Stats: stats,
|
||||
TopUsernames: topUsernames,
|
||||
TopPasswords: topPasswords,
|
||||
TopIPs: topIPs,
|
||||
TopCountries: topCountries,
|
||||
TopExecCommands: topExecCommands,
|
||||
ActiveSessions: activeSessions,
|
||||
RecentSessions: recentSessions,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
@@ -78,7 +107,10 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||
stats, err := s.store.GetDashboardStats(r.Context())
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get dashboard stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -92,7 +124,10 @@ func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Request) {
|
||||
sessions, err := s.store.GetRecentSessions(r.Context(), 50, true)
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
sessions, err := s.store.GetRecentSessions(ctx, 50, true)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get active sessions", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -105,6 +140,24 @@ func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
f := parseDashboardFilter(r)
|
||||
sessions, err := s.store.GetFilteredSessions(ctx, 50, false, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered sessions", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.dashboard.ExecuteTemplate(w, "recent_sessions", sessions); err != nil {
|
||||
s.logger.Error("failed to render recent sessions fragment", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type sessionDetailData struct {
|
||||
Session *storage.Session
|
||||
Logs []storage.SessionLog
|
||||
@@ -112,7 +165,8 @@ type sessionDetailData struct {
|
||||
}
|
||||
|
||||
func (s *Server) handleSessionDetail(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
sessionID := r.PathValue("id")
|
||||
|
||||
session, err := s.store.GetSession(ctx, sessionID)
|
||||
@@ -162,8 +216,201 @@ type apiEventsResponse struct {
|
||||
Events []apiEvent `json:"events"`
|
||||
}
|
||||
|
||||
// parseDateParam parses a "YYYY-MM-DD" query parameter into a *time.Time.
|
||||
func parseDateParam(r *http.Request, name string) *time.Time {
|
||||
v := r.URL.Query().Get(name)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
t, err := time.Parse("2006-01-02", v)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// For "until" dates, set to end of day.
|
||||
if name == "until" {
|
||||
t = t.Add(24*time.Hour - time.Second)
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
func parseDashboardFilter(r *http.Request) storage.DashboardFilter {
|
||||
return storage.DashboardFilter{
|
||||
Since: parseDateParam(r, "since"),
|
||||
Until: parseDateParam(r, "until"),
|
||||
IP: r.URL.Query().Get("ip"),
|
||||
Country: r.URL.Query().Get("country"),
|
||||
Username: r.URL.Query().Get("username"),
|
||||
HumanScoreAboveZero: r.URL.Query().Get("human_score") == "1",
|
||||
SortBy: r.URL.Query().Get("sort"),
|
||||
}
|
||||
}
|
||||
|
||||
type apiTimeSeriesPoint struct {
|
||||
Date string `json:"date"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type apiAttemptsOverTimeResponse struct {
|
||||
Points []apiTimeSeriesPoint `json:"points"`
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIAttemptsOverTime(w http.ResponseWriter, r *http.Request) {
|
||||
days := 30
|
||||
if v := r.URL.Query().Get("days"); v != "" {
|
||||
if d, err := strconv.Atoi(v); err == nil && d > 0 && d <= 365 {
|
||||
days = d
|
||||
}
|
||||
}
|
||||
|
||||
since := parseDateParam(r, "since")
|
||||
until := parseDateParam(r, "until")
|
||||
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
points, err := s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get attempts over time", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiAttemptsOverTimeResponse{Points: make([]apiTimeSeriesPoint, len(points))}
|
||||
for i, p := range points {
|
||||
resp.Points[i] = apiTimeSeriesPoint{
|
||||
Date: p.Timestamp.Format("2006-01-02"),
|
||||
Count: p.Count,
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error("failed to encode attempts over time", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type apiHourlyCount struct {
|
||||
Hour int `json:"hour"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type apiHourlyPatternResponse struct {
|
||||
Hours []apiHourlyCount `json:"hours"`
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIHourlyPattern(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
since := parseDateParam(r, "since")
|
||||
until := parseDateParam(r, "until")
|
||||
|
||||
counts, err := s.store.GetHourlyPattern(ctx, since, until)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get hourly pattern", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiHourlyPatternResponse{Hours: make([]apiHourlyCount, len(counts))}
|
||||
for i, c := range counts {
|
||||
resp.Hours[i] = apiHourlyCount{Hour: c.Hour, Count: c.Count}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error("failed to encode hourly pattern", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type apiCountryCount struct {
|
||||
Country string `json:"country"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type apiCountryStatsResponse struct {
|
||||
Countries []apiCountryCount `json:"countries"`
|
||||
}
|
||||
|
||||
func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
counts, err := s.store.GetCountryStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get country stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiCountryStatsResponse{Countries: make([]apiCountryCount, len(counts))}
|
||||
for i, c := range counts {
|
||||
resp.Countries[i] = apiCountryCount{Country: c.Country, Count: c.Count}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error("failed to encode country stats", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
f := parseDashboardFilter(r)
|
||||
|
||||
stats, err := s.store.GetFilteredDashboardStats(ctx, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topUsernames, err := s.store.GetFilteredTopUsernames(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top usernames", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topPasswords, err := s.store.GetFilteredTopPasswords(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top passwords", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topIPs, err := s.store.GetFilteredTopIPs(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top IPs", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topCountries, err := s.store.GetFilteredTopCountries(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top countries", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data := dashboardData{
|
||||
Stats: stats,
|
||||
TopUsernames: topUsernames,
|
||||
TopPasswords: topPasswords,
|
||||
TopIPs: topIPs,
|
||||
TopCountries: topCountries,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.dashboard.ExecuteTemplate(w, "dashboard_content", data); err != nil {
|
||||
s.logger.Error("failed to render dashboard content fragment", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleAPISessionEvents(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
sessionID := r.PathValue("id")
|
||||
|
||||
events, err := s.store.GetSessionEvents(ctx, sessionID)
|
||||
|
||||
14
internal/web/static/chart.min.js
vendored
Normal file
14
internal/web/static/chart.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
275
internal/web/static/dashboard.js
Normal file
275
internal/web/static/dashboard.js
Normal file
@@ -0,0 +1,275 @@
|
||||
(function() {
|
||||
'use strict';
|
||||
|
||||
// Chart.js theme for Pico dark mode
|
||||
Chart.defaults.color = '#b0b0b8';
|
||||
Chart.defaults.borderColor = '#3a3a4a';
|
||||
|
||||
var attemptsChart = null;
|
||||
var hourlyChart = null;
|
||||
|
||||
function getFilterParams() {
|
||||
var form = document.getElementById('filter-form');
|
||||
if (!form) return '';
|
||||
var params = new URLSearchParams();
|
||||
var since = form.elements['since'].value;
|
||||
var until = form.elements['until'].value;
|
||||
if (since) params.set('since', since);
|
||||
if (until) params.set('until', until);
|
||||
var humanScore = form.elements['human_score'];
|
||||
if (humanScore && humanScore.checked) params.set('human_score', '1');
|
||||
var sortBy = form.elements['sort'];
|
||||
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
|
||||
return params.toString();
|
||||
}
|
||||
|
||||
function initAttemptsChart() {
|
||||
var canvas = document.getElementById('chart-attempts');
|
||||
if (!canvas) return;
|
||||
var ctx = canvas.getContext('2d');
|
||||
|
||||
var qs = getFilterParams();
|
||||
var url = '/api/charts/attempts-over-time' + (qs ? '?' + qs : '');
|
||||
|
||||
fetch(url)
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(data) {
|
||||
var labels = data.points.map(function(p) { return p.date; });
|
||||
var values = data.points.map(function(p) { return p.count; });
|
||||
|
||||
if (attemptsChart) {
|
||||
attemptsChart.data.labels = labels;
|
||||
attemptsChart.data.datasets[0].data = values;
|
||||
attemptsChart.update();
|
||||
return;
|
||||
}
|
||||
|
||||
attemptsChart = new Chart(ctx, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels: labels,
|
||||
datasets: [{
|
||||
label: 'Attempts',
|
||||
data: values,
|
||||
borderColor: '#6366f1',
|
||||
backgroundColor: 'rgba(99, 102, 241, 0.1)',
|
||||
fill: true,
|
||||
tension: 0.3,
|
||||
pointRadius: 2
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: true,
|
||||
plugins: { legend: { display: false } },
|
||||
scales: {
|
||||
x: { grid: { display: false } },
|
||||
y: { beginAtZero: true }
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function initHourlyChart() {
|
||||
var canvas = document.getElementById('chart-hourly');
|
||||
if (!canvas) return;
|
||||
var ctx = canvas.getContext('2d');
|
||||
|
||||
var qs = getFilterParams();
|
||||
var url = '/api/charts/hourly-pattern' + (qs ? '?' + qs : '');
|
||||
|
||||
fetch(url)
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(data) {
|
||||
// Fill all 24 hours, defaulting to 0
|
||||
var hourMap = {};
|
||||
data.hours.forEach(function(h) { hourMap[h.hour] = h.count; });
|
||||
var labels = [];
|
||||
var values = [];
|
||||
for (var i = 0; i < 24; i++) {
|
||||
labels.push(i + ':00');
|
||||
values.push(hourMap[i] || 0);
|
||||
}
|
||||
|
||||
if (hourlyChart) {
|
||||
hourlyChart.data.labels = labels;
|
||||
hourlyChart.data.datasets[0].data = values;
|
||||
hourlyChart.update();
|
||||
return;
|
||||
}
|
||||
|
||||
hourlyChart = new Chart(ctx, {
|
||||
type: 'bar',
|
||||
data: {
|
||||
labels: labels,
|
||||
datasets: [{
|
||||
label: 'Attempts',
|
||||
data: values,
|
||||
backgroundColor: 'rgba(99, 102, 241, 0.6)',
|
||||
borderColor: '#6366f1',
|
||||
borderWidth: 1
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: true,
|
||||
plugins: { legend: { display: false } },
|
||||
scales: {
|
||||
x: { grid: { display: false } },
|
||||
y: { beginAtZero: true }
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function initWorldMap() {
|
||||
var container = document.getElementById('world-map');
|
||||
if (!container) return;
|
||||
|
||||
fetch('/static/world.svg')
|
||||
.then(function(r) { return r.text(); })
|
||||
.then(function(svgText) {
|
||||
container.innerHTML = svgText;
|
||||
|
||||
fetch('/api/charts/country-stats')
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(data) {
|
||||
colorMap(container, data.countries);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function colorMap(container, countries) {
|
||||
if (!countries || countries.length === 0) return;
|
||||
|
||||
var maxCount = countries[0].count; // already sorted DESC
|
||||
var logMax = Math.log(maxCount + 1);
|
||||
|
||||
// Build lookup
|
||||
var lookup = {};
|
||||
countries.forEach(function(c) {
|
||||
lookup[c.country.toLowerCase()] = c.count;
|
||||
});
|
||||
|
||||
// Create tooltip element
|
||||
var tooltip = document.createElement('div');
|
||||
tooltip.id = 'map-tooltip';
|
||||
tooltip.style.cssText = 'position:fixed;display:none;background:#1a1a2e;color:#e0e0e8;padding:4px 8px;border-radius:4px;font-size:13px;pointer-events:none;z-index:1000;border:1px solid #3a3a4a;';
|
||||
document.body.appendChild(tooltip);
|
||||
|
||||
var svg = container.querySelector('svg');
|
||||
if (!svg) return;
|
||||
|
||||
// Remove SVG title to prevent browser native tooltip
|
||||
var svgTitle = svg.querySelector('title');
|
||||
if (svgTitle) svgTitle.remove();
|
||||
|
||||
// Select both <path id="xx"> and <g id="xx"> country elements
|
||||
var elements = svg.querySelectorAll('path[id], g[id]');
|
||||
elements.forEach(function(el) {
|
||||
var id = el.id.toLowerCase();
|
||||
if (id.charAt(0) === '_') return; // skip non-country paths
|
||||
|
||||
var count = lookup[id];
|
||||
if (count) {
|
||||
var intensity = Math.log(count + 1) / logMax;
|
||||
var r = Math.round(30 + intensity * 69); // 30 -> 99
|
||||
var g = Math.round(30 + intensity * 72); // 30 -> 102
|
||||
var b = Math.round(62 + intensity * 179); // 62 -> 241
|
||||
var color = 'rgb(' + r + ',' + g + ',' + b + ')';
|
||||
// For <g> elements, color child paths; for <path>, color directly
|
||||
if (el.tagName.toLowerCase() === 'g') {
|
||||
el.querySelectorAll('path').forEach(function(p) {
|
||||
p.style.fill = color;
|
||||
});
|
||||
} else {
|
||||
el.style.fill = color;
|
||||
}
|
||||
}
|
||||
|
||||
el.addEventListener('mouseenter', function(e) {
|
||||
var cc = id.toUpperCase();
|
||||
var n = lookup[id] || 0;
|
||||
tooltip.textContent = cc + ': ' + n.toLocaleString() + ' attempts';
|
||||
tooltip.style.display = 'block';
|
||||
});
|
||||
|
||||
el.addEventListener('mousemove', function(e) {
|
||||
tooltip.style.left = (e.clientX + 12) + 'px';
|
||||
tooltip.style.top = (e.clientY - 10) + 'px';
|
||||
});
|
||||
|
||||
el.addEventListener('mouseleave', function() {
|
||||
tooltip.style.display = 'none';
|
||||
});
|
||||
|
||||
el.addEventListener('click', function() {
|
||||
var input = document.querySelector('#filter-form input[name="country"]');
|
||||
if (input) {
|
||||
input.value = id.toUpperCase();
|
||||
applyFilters();
|
||||
}
|
||||
});
|
||||
|
||||
el.style.cursor = 'pointer';
|
||||
});
|
||||
}
|
||||
|
||||
function applyFilters() {
|
||||
// Re-fetch charts with filter params
|
||||
initAttemptsChart();
|
||||
initHourlyChart();
|
||||
|
||||
// Re-fetch dashboard content via htmx
|
||||
var form = document.getElementById('filter-form');
|
||||
if (!form) return;
|
||||
|
||||
var params = new URLSearchParams();
|
||||
['since', 'until', 'ip', 'country', 'username'].forEach(function(name) {
|
||||
var val = form.elements[name].value;
|
||||
if (val) params.set(name, val);
|
||||
});
|
||||
|
||||
var humanScore = form.elements['human_score'];
|
||||
if (humanScore && humanScore.checked) params.set('human_score', '1');
|
||||
var sortBy = form.elements['sort'];
|
||||
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
|
||||
|
||||
var target = document.getElementById('dashboard-content');
|
||||
if (target) {
|
||||
var url = '/fragments/dashboard-content?' + params.toString();
|
||||
htmx.ajax('GET', url, {target: '#dashboard-content', swap: 'innerHTML'});
|
||||
}
|
||||
|
||||
// Server-side filter for recent sessions table
|
||||
var sessionsUrl = '/fragments/recent-sessions?' + params.toString();
|
||||
htmx.ajax('GET', sessionsUrl, {target: '#recent-sessions-table tbody', swap: 'innerHTML'});
|
||||
}
|
||||
|
||||
window.clearFilters = function() {
|
||||
var form = document.getElementById('filter-form');
|
||||
if (form) {
|
||||
form.reset();
|
||||
applyFilters();
|
||||
}
|
||||
};
|
||||
|
||||
window.applyFilters = applyFilters;
|
||||
|
||||
// Initialize on DOM ready
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
initAttemptsChart();
|
||||
initHourlyChart();
|
||||
initWorldMap();
|
||||
|
||||
var form = document.getElementById('filter-form');
|
||||
if (form) {
|
||||
form.addEventListener('submit', function(e) {
|
||||
e.preventDefault();
|
||||
applyFilters();
|
||||
});
|
||||
}
|
||||
});
|
||||
})();
|
||||
1
internal/web/static/world.svg
Normal file
1
internal/web/static/world.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 55 KiB |
@@ -44,6 +44,32 @@ func templateFuncMap() template.FuncMap {
|
||||
}
|
||||
return fmt.Sprintf("%.0f%%", *f*100)
|
||||
},
|
||||
"derefString": func(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
},
|
||||
"truncateCommand": func(s string) string {
|
||||
if len(s) > 50 {
|
||||
return s[:50] + "..."
|
||||
}
|
||||
return s
|
||||
},
|
||||
"formatBytes": func(b int64) string {
|
||||
const (
|
||||
kb = 1024
|
||||
mb = 1024 * kb
|
||||
)
|
||||
switch {
|
||||
case b >= mb:
|
||||
return fmt.Sprintf("%.1f MB", float64(b)/float64(mb))
|
||||
case b >= kb:
|
||||
return fmt.Sprintf("%.1f KB", float64(b)/float64(kb))
|
||||
default:
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +81,7 @@ func loadTemplates() (*templateSet, error) {
|
||||
"templates/dashboard.html",
|
||||
"templates/fragments/stats.html",
|
||||
"templates/fragments/active_sessions.html",
|
||||
"templates/fragments/recent_sessions.html",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing dashboard templates: %w", err)
|
||||
|
||||
@@ -3,6 +3,86 @@
|
||||
{{template "stats" .Stats}}
|
||||
</section>
|
||||
|
||||
<details>
|
||||
<summary>Filters</summary>
|
||||
<form id="filter-form">
|
||||
<div class="grid">
|
||||
<label>Since <input type="date" name="since"></label>
|
||||
<label>Until <input type="date" name="until"></label>
|
||||
<label>IP <input type="text" name="ip" placeholder="10.0.0.1"></label>
|
||||
<label>Country <input type="text" name="country" placeholder="CN" maxlength="2"></label>
|
||||
<label>Username <input type="text" name="username" placeholder="root"></label>
|
||||
</div>
|
||||
<div class="grid">
|
||||
<label><input type="checkbox" name="human_score" value="1"> Human score > 0</label>
|
||||
<label>Sort by <select name="sort"><option value="connected_at">Recent</option><option value="input_bytes">Input Bytes</option></select></label>
|
||||
</div>
|
||||
<button type="submit">Apply</button>
|
||||
<button type="button" class="secondary" onclick="clearFilters()">Clear</button>
|
||||
</form>
|
||||
</details>
|
||||
|
||||
<section>
|
||||
<h3>Attack Trends</h3>
|
||||
<div class="grid">
|
||||
<article>
|
||||
<header>Attempts Over Time</header>
|
||||
<canvas id="chart-attempts"></canvas>
|
||||
</article>
|
||||
<article>
|
||||
<header>Hourly Pattern (UTC)</header>
|
||||
<canvas id="chart-hourly"></canvas>
|
||||
</article>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Attack Origins</h3>
|
||||
<article>
|
||||
<div id="world-map"></div>
|
||||
</article>
|
||||
</section>
|
||||
|
||||
<div id="dashboard-content">
|
||||
{{template "dashboard_content" .}}
|
||||
</div>
|
||||
|
||||
<section>
|
||||
<h3>Active Sessions</h3>
|
||||
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
|
||||
{{template "active_sessions" .ActiveSessions}}
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Recent Sessions</h3>
|
||||
<table id="recent-sessions-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>IP</th>
|
||||
<th>Country</th>
|
||||
<th>Username</th>
|
||||
<th>Type</th>
|
||||
<th>Score</th>
|
||||
<th>Input</th>
|
||||
<th>Connected</th>
|
||||
<th>Disconnected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{template "recent_sessions" .RecentSessions}}
|
||||
</tbody>
|
||||
</table>
|
||||
</section>
|
||||
{{end}}
|
||||
|
||||
{{define "scripts"}}
|
||||
<script src="/static/chart.min.js"></script>
|
||||
<script src="/static/dashboard.js"></script>
|
||||
{{end}}
|
||||
|
||||
{{define "dashboard_content"}}
|
||||
<section>
|
||||
<h3>Top Credentials & IPs</h3>
|
||||
<div class="top-grid">
|
||||
@@ -40,10 +120,25 @@
|
||||
<header>Top IPs</header>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>IP</th><th>Attempts</th></tr>
|
||||
<tr><th>IP</th><th>Country</th><th>Attempts</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .TopIPs}}
|
||||
<tr><td>{{.Value}}</td><td>{{.Country}}</td><td>{{.Count}}</td></tr>
|
||||
{{else}}
|
||||
<tr><td colspan="3">No data</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</article>
|
||||
<article>
|
||||
<header>Top Countries</header>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>Country</th><th>Attempts</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .TopCountries}}
|
||||
<tr><td>{{.Value}}</td><td>{{.Count}}</td></tr>
|
||||
{{else}}
|
||||
<tr><td colspan="2">No data</td></tr>
|
||||
@@ -51,45 +146,21 @@
|
||||
</tbody>
|
||||
</table>
|
||||
</article>
|
||||
<article>
|
||||
<header>Top Exec Commands</header>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>Command</th><th>Count</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .TopExecCommands}}
|
||||
<tr><td><code>{{truncateCommand .Value}}</code></td><td>{{.Count}}</td></tr>
|
||||
{{else}}
|
||||
<tr><td colspan="2">No data</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</article>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Active Sessions</h3>
|
||||
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
|
||||
{{template "active_sessions" .ActiveSessions}}
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Recent Sessions</h3>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>IP</th>
|
||||
<th>Username</th>
|
||||
<th>Shell</th>
|
||||
<th>Score</th>
|
||||
<th>Connected</th>
|
||||
<th>Disconnected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .RecentSessions}}
|
||||
<tr>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a></td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{.ShellName}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="7">No sessions</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</section>
|
||||
{{end}}
|
||||
|
||||
@@ -4,24 +4,28 @@
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>IP</th>
|
||||
<th>Country</th>
|
||||
<th>Username</th>
|
||||
<th>Shell</th>
|
||||
<th>Type</th>
|
||||
<th>Score</th>
|
||||
<th>Input</th>
|
||||
<th>Connected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .}}
|
||||
<tr>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a></td>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Country}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{.ShellName}}</td>
|
||||
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatBytes .InputBytes}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="6">No active sessions</td></tr>
|
||||
<tr><td colspan="8">No active sessions</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
17
internal/web/templates/fragments/recent_sessions.html
Normal file
17
internal/web/templates/fragments/recent_sessions.html
Normal file
@@ -0,0 +1,17 @@
|
||||
{{define "recent_sessions"}}
|
||||
{{range .}}
|
||||
<tr>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Country}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatBytes .InputBytes}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="9">No sessions</td></tr>
|
||||
{{end}}
|
||||
{{end}}
|
||||
@@ -29,9 +29,16 @@
|
||||
}
|
||||
.top-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
|
||||
grid-template-columns: repeat(auto-fit, minmax(380px, 1fr));
|
||||
gap: 1rem;
|
||||
}
|
||||
.top-grid article {
|
||||
overflow: hidden;
|
||||
min-width: 0;
|
||||
}
|
||||
#world-map svg { width: 100%; height: auto; }
|
||||
#world-map svg path { fill: #2a2a3e; stroke: #555; stroke-width: 0.5; transition: fill 0.2s; }
|
||||
#world-map svg path:hover, #world-map svg g:hover path { stroke: #fff; stroke-width: 1; }
|
||||
nav h1 {
|
||||
margin: 0;
|
||||
}
|
||||
@@ -52,5 +59,6 @@
|
||||
<main class="container">
|
||||
{{block "content" .}}{{end}}
|
||||
</main>
|
||||
{{block "scripts" .}}{{end}}
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -7,8 +7,10 @@
|
||||
<table>
|
||||
<tbody>
|
||||
<tr><td><strong>IP</strong></td><td>{{.Session.IP}}</td></tr>
|
||||
<tr><td><strong>Country</strong></td><td>{{.Session.Country}}</td></tr>
|
||||
<tr><td><strong>Username</strong></td><td>{{.Session.Username}}</td></tr>
|
||||
<tr><td><strong>Shell</strong></td><td>{{.Session.ShellName}}</td></tr>
|
||||
{{if .Session.ExecCommand}}<tr><td><strong>Exec Command</strong></td><td><code>{{derefString .Session.ExecCommand}}</code></td></tr>{{end}}
|
||||
<tr><td><strong>Score</strong></td><td>{{formatScore .Session.HumanScore}}</td></tr>
|
||||
<tr><td><strong>Connected</strong></td><td>{{formatTime .Session.ConnectedAt}}</td></tr>
|
||||
<tr>
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"embed"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
//go:embed static/*
|
||||
@@ -20,7 +22,9 @@ type Server struct {
|
||||
}
|
||||
|
||||
// NewServer creates a new web Server with routes registered.
|
||||
func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
|
||||
// If metricsHandler is non-nil, a /metrics endpoint is registered.
|
||||
// If metricsToken is non-empty, the metrics endpoint requires Bearer token auth.
|
||||
func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Handler, metricsToken string) (*Server, error) {
|
||||
tmpl, err := loadTemplates()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -36,9 +40,22 @@ func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
|
||||
s.mux.Handle("GET /static/", http.FileServerFS(staticFS))
|
||||
s.mux.HandleFunc("GET /sessions/{id}", s.handleSessionDetail)
|
||||
s.mux.HandleFunc("GET /api/sessions/{id}/events", s.handleAPISessionEvents)
|
||||
s.mux.HandleFunc("GET /api/charts/attempts-over-time", s.handleAPIAttemptsOverTime)
|
||||
s.mux.HandleFunc("GET /api/charts/hourly-pattern", s.handleAPIHourlyPattern)
|
||||
s.mux.HandleFunc("GET /api/charts/country-stats", s.handleAPICountryStats)
|
||||
s.mux.HandleFunc("GET /", s.handleDashboard)
|
||||
s.mux.HandleFunc("GET /fragments/stats", s.handleFragmentStats)
|
||||
s.mux.HandleFunc("GET /fragments/active-sessions", s.handleFragmentActiveSessions)
|
||||
s.mux.HandleFunc("GET /fragments/dashboard-content", s.handleFragmentDashboardContent)
|
||||
s.mux.HandleFunc("GET /fragments/recent-sessions", s.handleFragmentRecentSessions)
|
||||
|
||||
if metricsHandler != nil {
|
||||
h := metricsHandler
|
||||
if metricsToken != "" {
|
||||
h = requireBearerToken(metricsToken, h)
|
||||
}
|
||||
s.mux.Handle("GET /metrics", h)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
@@ -47,3 +64,20 @@ func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.mux.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// requireBearerToken wraps a handler to require a valid Bearer token.
|
||||
func requireBearerToken(token string, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(auth, "Bearer ") {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
provided := auth[len("Bearer "):]
|
||||
if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -10,14 +10,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func newTestServer(t *testing.T) *Server {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
logger := slog.Default()
|
||||
srv, err := NewServer(store, logger)
|
||||
srv, err := NewServer(store, logger, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating server: %v", err)
|
||||
}
|
||||
@@ -30,29 +31,53 @@ func newSeededTestServer(t *testing.T) *Server {
|
||||
ctx := context.Background()
|
||||
|
||||
for range 5 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash"); err != nil {
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash"); err != nil {
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
|
||||
logger := slog.Default()
|
||||
srv, err := NewServer(store, logger)
|
||||
srv, err := NewServer(store, logger, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating server: %v", err)
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestDbContextNotCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
dbCtx, dbCancel := dbContext(req)
|
||||
defer dbCancel()
|
||||
|
||||
// Cancel the original request context.
|
||||
cancel()
|
||||
|
||||
// The DB context should still be usable.
|
||||
select {
|
||||
case <-dbCtx.Done():
|
||||
t.Fatal("dbContext should not be canceled when request context is canceled")
|
||||
default:
|
||||
}
|
||||
|
||||
// Verify the DB context has a deadline (from the timeout).
|
||||
if _, ok := dbCtx.Deadline(); !ok {
|
||||
t.Error("dbContext should have a deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardHandler(t *testing.T) {
|
||||
t.Run("empty store", func(t *testing.T) {
|
||||
srv := newTestServer(t)
|
||||
@@ -149,12 +174,12 @@ func TestSessionDetailHandler(t *testing.T) {
|
||||
t.Run("found", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default())
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
@@ -180,7 +205,7 @@ func TestSessionDetailHandler(t *testing.T) {
|
||||
func TestAPISessionEvents(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
@@ -194,7 +219,7 @@ func TestAPISessionEvents(t *testing.T) {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default())
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
@@ -236,6 +261,293 @@ func TestAPISessionEvents(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsEndpoint(t *testing.T) {
|
||||
t.Run("enabled", func(t *testing.T) {
|
||||
m := metrics.New("test")
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), m.Handler(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `oubliette_build_info{version="test"} 1`) {
|
||||
t.Errorf("response should contain build_info metric, got:\n%s", body)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disabled", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
// Without a metrics handler, /metrics falls through to the dashboard.
|
||||
body := w.Body.String()
|
||||
if strings.Contains(body, "oubliette_build_info") {
|
||||
t.Error("response should not contain prometheus metrics when disabled")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsBearerToken(t *testing.T) {
|
||||
m := metrics.New("test")
|
||||
|
||||
t.Run("valid token", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
req.Header.Set("Authorization", "Bearer secret")
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong token", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
req.Header.Set("Authorization", "Bearer wrong")
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing header", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no token configured", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), m.Handler(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTruncateCommand(t *testing.T) {
|
||||
funcMap := templateFuncMap()
|
||||
fn := funcMap["truncateCommand"].(func(string) string)
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"short", "short"},
|
||||
{"exactly fifty characters long! that is what it i.", "exactly fifty characters long! that is what it i."},
|
||||
{"this string is definitely longer than fifty characters and should be truncated", "this string is definitely longer than fifty charac..."},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := fn(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("truncateCommand(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardExecCommands(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||
t.Fatalf("setting exec command: %v", err)
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "Top Exec Commands") {
|
||||
t.Error("response should contain 'Top Exec Commands'")
|
||||
}
|
||||
if !strings.Contains(body, "uname -a") {
|
||||
t.Error("response should contain exec command 'uname -a'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIAttemptsOverTime(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/charts/attempts-over-time", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
ct := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(ct, "application/json") {
|
||||
t.Errorf("Content-Type = %q, want application/json", ct)
|
||||
}
|
||||
|
||||
var resp apiAttemptsOverTimeResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decoding response: %v", err)
|
||||
}
|
||||
// Seeded data inserted today -> at least 1 point.
|
||||
if len(resp.Points) == 0 {
|
||||
t.Error("expected at least one data point")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIHourlyPattern(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/charts/hourly-pattern", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
var resp apiHourlyPatternResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decoding response: %v", err)
|
||||
}
|
||||
if len(resp.Hours) == 0 {
|
||||
t.Error("expected at least one hourly data point")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPICountryStats(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/charts/country-stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
var resp apiCountryStatsResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decoding response: %v", err)
|
||||
}
|
||||
if len(resp.Countries) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(resp.Countries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFragmentDashboardContent(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if strings.Contains(body, "<!DOCTYPE html>") {
|
||||
t.Error("dashboard content fragment should not contain full HTML document")
|
||||
}
|
||||
if !strings.Contains(body, "Top Usernames") {
|
||||
t.Error("dashboard content fragment should contain 'Top Usernames'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFragmentDashboardContentWithFilter(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
for range 5 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
}
|
||||
for range 3 {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content?country=CN", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
// When filtered by CN, should show root but not admin.
|
||||
if !strings.Contains(body, "root") {
|
||||
t.Error("response should contain 'root' when filtered by CN")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticAssets(t *testing.T) {
|
||||
srv := newTestServer(t)
|
||||
|
||||
@@ -245,6 +557,9 @@ func TestStaticAssets(t *testing.T) {
|
||||
}{
|
||||
{"/static/pico.min.css", "text/css"},
|
||||
{"/static/htmx.min.js", "text/javascript"},
|
||||
{"/static/chart.min.js", "text/javascript"},
|
||||
{"/static/dashboard.js", "text/javascript"},
|
||||
{"/static/world.svg", "image/svg+xml"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -24,6 +24,26 @@ password = "admin"
|
||||
# password = "fridge"
|
||||
# shell = "fridge"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "teller"
|
||||
# password = "banking"
|
||||
# shell = "banking"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "admin"
|
||||
# password = "cisco"
|
||||
# shell = "cisco"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "irobot"
|
||||
# password = "roomba"
|
||||
# shell = "roomba"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "player"
|
||||
# password = "tetris"
|
||||
# shell = "tetris"
|
||||
|
||||
[storage]
|
||||
db_path = "oubliette.db"
|
||||
retention_days = 90
|
||||
@@ -32,12 +52,45 @@ retention_interval = "1h"
|
||||
# [web]
|
||||
# enabled = true
|
||||
# listen_addr = ":8080"
|
||||
# metrics_enabled = true
|
||||
# metrics_token = "" # bearer token for /metrics; empty = no auth
|
||||
|
||||
[shell]
|
||||
hostname = "ubuntu-server"
|
||||
# banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
|
||||
# fake_user = "" # override username in prompt; empty = use authenticated user
|
||||
|
||||
# Map usernames to specific shells (regardless of how auth succeeded).
|
||||
# Credential-specific shell overrides take priority over username routes.
|
||||
# [shell.username_routes]
|
||||
# postgres = "psql"
|
||||
# admin = "bash"
|
||||
|
||||
# Per-shell configuration (optional).
|
||||
# [shell.banking]
|
||||
# bank_name = "SECUREBANK"
|
||||
# terminal_id = "SB-0001" # random if not set
|
||||
# region = "NORTHEAST"
|
||||
|
||||
# [shell.adventure]
|
||||
# dungeon_name = "THE OUBLIETTE"
|
||||
|
||||
# [shell.cisco]
|
||||
# hostname = "Router"
|
||||
# model = "C2960"
|
||||
# ios_version = "15.0(2)SE11"
|
||||
# enable_password = "" # empty = accept after 1 failed attempt
|
||||
|
||||
# [shell.psql]
|
||||
# db_name = "postgres"
|
||||
# pg_version = "15.4"
|
||||
|
||||
# [shell.roomba]
|
||||
# No configuration options currently.
|
||||
|
||||
# [shell.tetris]
|
||||
# difficulty = "normal" # "easy" (slower start), "normal" (standard), "hard" (start at level 5)
|
||||
|
||||
# [detection]
|
||||
# enabled = true
|
||||
# threshold = 0.6 # 0.0–1.0, sessions above this trigger notifications
|
||||
|
||||
18
scripts/fetch-geoip.sh
Executable file
18
scripts/fetch-geoip.sh
Executable file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env bash
|
||||
# Downloads the DB-IP Lite country MMDB database for development.
|
||||
# The Nix build fetches this automatically; this script is for local dev only.
|
||||
set -euo pipefail
|
||||
|
||||
URL="https://download.db-ip.com/free/dbip-country-lite-2026-02.mmdb.gz"
|
||||
DEST="internal/geoip/dbip-country-lite.mmdb"
|
||||
|
||||
cd "$(git rev-parse --show-toplevel)"
|
||||
|
||||
if [ -f "$DEST" ]; then
|
||||
echo "GeoIP database already exists at $DEST"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Downloading DB-IP Lite country database..."
|
||||
curl -fSL "$URL" | gunzip > "$DEST"
|
||||
echo "Saved to $DEST"
|
||||
Reference in New Issue
Block a user