Compare commits
49 Commits
96c8476f77
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
1b28f10ca8
|
|||
|
664e79fce6
|
|||
|
c74313c195
|
|||
|
9783ae5865
|
|||
|
62de222488
|
|||
| c9d143d84b | |||
|
d18a904ed5
|
|||
|
cb7be28f42
|
|||
|
0908b43724
|
|||
|
52310f588d
|
|||
|
b52216bd2f
|
|||
|
2bc83a17dd
|
|||
|
faf6e2abd7
|
|||
|
0a4eac188a
|
|||
|
7c90c9ed4a
|
|||
|
8a631af0d2
|
|||
|
40fda3420c
|
|||
|
c4801e3309
|
|||
|
4f10a8a422
|
|||
|
0b44d1c83f
|
|||
|
0133d956a5
|
|||
|
3c20e854aa
|
|||
|
090dbec390
|
|||
|
df860b3061
|
|||
|
9aecc7ce02
|
|||
|
94f1f1c266
|
|||
|
8fff893d25
|
|||
|
5ba62afec3
|
|||
|
058da51f86
|
|||
|
adfe372d13
|
|||
|
3163ea47dc
|
|||
|
ab07e6a8dc
|
|||
|
b8fcbc7e10
|
|||
|
aa569aac16
|
|||
|
1a407ad4c2
|
|||
|
5d0c8cc20c
|
|||
|
d226c32b9b
|
|||
|
86786c9d05
|
|||
| d78d461236 | |||
| 49425635ce | |||
| 8ff029fcb7 | |||
| 462c44ce89 | |||
| 47159b9964 | |||
| 8e90f21d91 | |||
| 84c6912435 | |||
| 541b0df007 | |||
| 24c166b86b | |||
| d4380c0aea | |||
| 0ad6f4cb6a |
55
.claude/skills/bubbletea/SKILL.md
Normal file
55
.claude/skills/bubbletea/SKILL.md
Normal file
@@ -0,0 +1,55 @@
|
||||
---
|
||||
name: bubbletea
|
||||
description: Browse Bubbletea TUI framework documentation and examples. Use when working with Bubbletea components, models, commands, or building terminal user interfaces in Go.
|
||||
---
|
||||
|
||||
# Bubbletea Documentation
|
||||
|
||||
Bubbletea is a Go framework for building terminal user interfaces based on The Elm Architecture.
|
||||
|
||||
## Key Resources
|
||||
|
||||
When you need to understand Bubbletea patterns or find examples:
|
||||
|
||||
1. **Examples README** - Overview of all available examples:
|
||||
https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
|
||||
|
||||
2. **Examples Directory** - Full source code for all examples:
|
||||
https://github.com/charmbracelet/bubbletea/tree/main/examples
|
||||
|
||||
## How to Use
|
||||
|
||||
1. First, fetch the examples README to get an overview of available examples:
|
||||
```
|
||||
WebFetch https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
|
||||
```
|
||||
|
||||
2. Once you identify a relevant example, fetch its source code from the examples directory.
|
||||
|
||||
## Common Examples to Reference
|
||||
|
||||
- `list` - List component with filtering
|
||||
- `table` - Table component
|
||||
- `textinput` - Text input handling
|
||||
- `textarea` - Multi-line text input
|
||||
- `viewport` - Scrollable content
|
||||
- `paginator` - Pagination
|
||||
- `spinner` - Loading spinners
|
||||
- `progress` - Progress bars
|
||||
- `tabs` - Tab navigation
|
||||
- `help` - Help text/keybindings display
|
||||
|
||||
## Core Concepts
|
||||
|
||||
- **Model**: Application state
|
||||
- **Update**: Handles messages and returns updated model + commands
|
||||
- **View**: Renders the model to a string
|
||||
- **Cmd**: Side effects that produce messages
|
||||
- **Msg**: Events that trigger updates
|
||||
|
||||
## Related Charm Libraries
|
||||
|
||||
- **Bubbles**: Pre-built components (github.com/charmbracelet/bubbles)
|
||||
- **Lipgloss**: Styling and layout (github.com/charmbracelet/lipgloss)
|
||||
- **Glamour**: Markdown rendering (github.com/charmbracelet/glamour)
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -4,3 +4,5 @@ oubliette.toml
|
||||
*.db-wal
|
||||
*.db-shm
|
||||
/oubliette
|
||||
*.mmdb
|
||||
*.mmdb.gz
|
||||
|
||||
79
.golangci.yml
Normal file
79
.golangci.yml
Normal file
@@ -0,0 +1,79 @@
|
||||
version: "2"
|
||||
|
||||
linters:
|
||||
enable:
|
||||
# Bug detectors.
|
||||
- bodyclose
|
||||
- durationcheck
|
||||
- errorlint
|
||||
- gocritic
|
||||
- nilerr
|
||||
- sqlclosecheck
|
||||
|
||||
# Security.
|
||||
- gosec
|
||||
|
||||
# Style and modernization.
|
||||
- misspell
|
||||
- modernize
|
||||
- unconvert
|
||||
- usestdlibvars
|
||||
|
||||
# Logging.
|
||||
- sloglint
|
||||
|
||||
# Dead code.
|
||||
- wastedassign
|
||||
|
||||
settings:
|
||||
errcheck:
|
||||
exclude-functions:
|
||||
# Terminal I/O writes (honeypot shell output).
|
||||
- fmt.Fprint
|
||||
- fmt.Fprintf
|
||||
# Low-level byte I/O in shell readLine (escape sequences, echo).
|
||||
- (io.ReadWriter).Read
|
||||
- (io.ReadWriter).Write
|
||||
- (io.ReadWriteCloser).Read
|
||||
- (io.ReadWriteCloser).Write
|
||||
- (io.Reader).Read
|
||||
- (io.Writer).Write
|
||||
|
||||
gosec:
|
||||
excludes:
|
||||
# File reads from config paths — expected in a CLI tool.
|
||||
- G304
|
||||
# Weak RNG for shell selection — crypto/rand not needed.
|
||||
- G404
|
||||
|
||||
exclusions:
|
||||
rules:
|
||||
# Ignore unchecked Close() — standard resource cleanup.
|
||||
- linters: [errcheck]
|
||||
text: "Error return value of .+\\.Close.+ is not checked"
|
||||
|
||||
# Ignore unchecked Rollback() — called in error paths before returning.
|
||||
- linters: [errcheck]
|
||||
text: "Error return value of .+\\.Rollback.+ is not checked"
|
||||
|
||||
# Ignore unchecked Reply/Reject — SSH protocol; nothing useful on failure.
|
||||
- linters: [errcheck]
|
||||
text: "Error return value of .+\\.(Reply|Reject).+ is not checked"
|
||||
|
||||
# Test files: allow unchecked errors.
|
||||
- linters: [errcheck]
|
||||
path: "_test\\.go"
|
||||
|
||||
# Test files: InsecureIgnoreHostKey, file permissions, unhandled errors are expected.
|
||||
- linters: [gosec]
|
||||
path: "_test\\.go"
|
||||
|
||||
# Unhandled errors for cleanup/protocol ops — mirrors errcheck exclusions.
|
||||
- linters: [gosec]
|
||||
text: "G104"
|
||||
source: "\\.(Close|Rollback|Reject|Reply|Read|Write)\\("
|
||||
|
||||
# SQL with safe column interpolation from a fixed switch — not user input.
|
||||
- linters: [gosec]
|
||||
text: "G201"
|
||||
path: "internal/storage/"
|
||||
97
PLAN.md
97
PLAN.md
@@ -117,19 +117,19 @@ This lets shells build realistic prompts (`username@hostname:~$`) and log activi
|
||||
|
||||
Goal: Detect likely-human sessions and make the system smarter.
|
||||
|
||||
### 2.1 Human Detection Scoring
|
||||
### 2.1 Human Detection Scoring ✅
|
||||
- Keystroke timing analysis
|
||||
- Track backspace, tab, arrow key usage
|
||||
- Command diversity scoring
|
||||
- Compute per-session human score, store in sessions table
|
||||
- Flag sessions above configurable threshold
|
||||
|
||||
### 2.2 Notifications
|
||||
### 2.2 Notifications ✅
|
||||
- Webhook support (generic HTTP POST, works with Slack/Discord/ntfy)
|
||||
- Trigger on: human score threshold crossed, new session started, configurable
|
||||
- Include session details in payload
|
||||
|
||||
### 2.3 Session Replay
|
||||
### 2.3 Session Replay ✅
|
||||
- Store keystroke-by-keystroke data with timing information
|
||||
- Web UI: replay a session in a terminal-like viewer, watching commands play back in real-time
|
||||
- Filter/sort sessions by human score
|
||||
@@ -150,26 +150,41 @@ Goal: Add the entertaining shell implementations.
|
||||
- **Haunted:** commands gradually return stranger output, files appear/disappear, `whoami` returns different users
|
||||
- **Bread crumbs:** fake .bash_history, id_rsa files, database configs pointing to other honeypots
|
||||
|
||||
### 3.2 Cisco IOS Shell
|
||||
### 3.2 Cisco IOS Shell ✅
|
||||
- Realistic `>` and `#` prompts
|
||||
- Common commands: `show running-config`, `show interfaces`, `enable`, `configure terminal`
|
||||
- Fake device info that looks like a real router
|
||||
|
||||
### 3.3 Smart Fridge Shell
|
||||
### 3.3 Smart Fridge Shell ✅
|
||||
- Samsung FridgeOS boot banner
|
||||
- Inventory management commands
|
||||
- Temperature warnings
|
||||
- "WARNING: milk expires in 2 days"
|
||||
- Easter eggs
|
||||
- Per-credential shell routing via `shell` field in static credentials
|
||||
|
||||
### 3.4 Text Adventure
|
||||
### 3.4 Text Adventure ✅
|
||||
- Zork-style dungeon crawler
|
||||
- "You are in a dimly lit server room."
|
||||
- Navigation, items, puzzles
|
||||
- The dungeon is the oubliette itself
|
||||
|
||||
### 3.5 Other Shell Ideas (Future)
|
||||
- **Banking TUI:** 80s-style green-on-black bank terminal
|
||||
### 3.5 Banking TUI Shell ✅
|
||||
- 80s-style green-on-black bank terminal
|
||||
|
||||
### 3.6 PostgreSQL psql Shell ✅
|
||||
- Simulates psql interactive terminal with `db_name` and `pg_version` config
|
||||
- Backslash meta-commands: `\q`, `\dt`, `\d <table>`, `\l`, `\du`, `\conninfo`, `\?`, `\h`
|
||||
- SQL statement handling with multi-line buffering (semicolon-terminated)
|
||||
- Canned responses for common queries (SELECT version(), current_database(), etc.)
|
||||
- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.)
|
||||
- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells
|
||||
|
||||
### 3.7 Roomba Shell ✅
|
||||
- iRobot Roomba j7+ vacuum robot interface
|
||||
- Status, cleaning, scheduling, diagnostics, floor map
|
||||
- Humorous history entries (cat encounters, sock tangles, sticky substances)
|
||||
|
||||
### 3.8 Other Shell Ideas (Future)
|
||||
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
|
||||
- **ELIZA therapist:** every response is a therapy question
|
||||
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
|
||||
@@ -181,19 +196,55 @@ Goal: Add the entertaining shell implementations.
|
||||
|
||||
Goal: Make the web UI great and add operational niceties.
|
||||
|
||||
### 4.1 Enhanced Web UI
|
||||
- GeoIP lookups and world map visualization of attack sources
|
||||
- Charts: attempts over time, hourly patterns, credential trends
|
||||
- Session detail view with full command log
|
||||
- Filtering and search
|
||||
### 4.1 Enhanced Web UI ✅
|
||||
- GeoIP lookups and world map visualization of attack sources ✅
|
||||
- Charts: attempts over time, hourly patterns, credential trends ✅
|
||||
- Session detail view with full command log ✅
|
||||
- Filtering and search ✅
|
||||
|
||||
### 4.2 Operational
|
||||
- Prometheus metrics endpoint
|
||||
- Structured logging (slog)
|
||||
- Graceful shutdown
|
||||
- Systemd unit file / deployment docs
|
||||
### 4.2 Operational ✅
|
||||
- Prometheus metrics endpoint ✅
|
||||
- Structured logging (slog) ✅
|
||||
- Graceful shutdown ✅
|
||||
- Docker image (nix dockerTools) ✅
|
||||
- Systemd unit file / deployment docs ✅
|
||||
|
||||
### 4.3 GeoIP
|
||||
- Embed a lightweight GeoIP database or use an API
|
||||
- Store country/city with each attempt
|
||||
- Aggregate stats by country
|
||||
### 4.3 GeoIP ✅
|
||||
- Embed a lightweight GeoIP database or use an API ✅
|
||||
- Store country/city with each attempt ✅
|
||||
- Aggregate stats by country ✅
|
||||
|
||||
### 4.4 Capture SSH Exec Commands ✅
|
||||
Many bots send a command directly via `ssh user@host <command>` (an SSH "exec" request) rather than requesting an interactive shell. Currently these are rejected and the command is lost. We should capture them.
|
||||
|
||||
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`) ✅
|
||||
- Parse the command string from the exec payload ✅
|
||||
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration ✅
|
||||
- Store the command on the session record before closing the channel ✅
|
||||
- Optionally return plausible fake output for common commands (e.g. `uname`, `id`, `cat /etc/passwd`) to encourage further interaction
|
||||
- Surface exec commands in the web UI (session detail view) ✅
|
||||
|
||||
#### 4.4.1 Fake Exec Output
|
||||
Return plausible fake output for exec commands to encourage bots to interact further.
|
||||
|
||||
**Approach: regex-based output assembly.** Bots typically send a single long command that chains recon commands and then echoes a summary (e.g. `echo "UNAME:$uname"`). Rather than interpreting arbitrary shell pipelines, we scan the command string for known patterns and assemble fake output.
|
||||
|
||||
Implementation:
|
||||
- A map of common command/variable patterns to fake output strings, e.g.:
|
||||
- `uname -a` / `uname -s -v -n -m` → `"Linux ubuntu-server 5.15.0-91-generic #101-Ubuntu SMP Tue Jan 2 15:13:10 UTC 2024 x86_64"`
|
||||
- `uname -m` / `arch` → `"x86_64"`
|
||||
- `cat /proc/uptime` → `"86432.71 172801.55"`
|
||||
- `nproc` / `grep -c "^processor" /proc/cpuinfo` → `"2"`
|
||||
- `cat /proc/cpuinfo` → fake cpuinfo block
|
||||
- `lspci` → empty (no GPU — discourages cryptominer targeting)
|
||||
- `id` → `"uid=0(root) gid=0(root) groups=0(root)"`
|
||||
- `cat /etc/passwd` → minimal fake passwd file
|
||||
- `last` → fake login entries
|
||||
- `cat --help`, `ls --help` → canned GNU coreutils help text
|
||||
- Scan the exec command for `echo "KEY:$var"` patterns; for each key, look up the corresponding fake value from the variable assignment earlier in the command
|
||||
- If we recognise echo patterns, assemble and return the expected output
|
||||
- If we don't recognise the command at all, return empty output with exit 0 (current behaviour)
|
||||
- Values should draw from the existing shell config where possible (hostname, fake_user) for consistency
|
||||
- New package `internal/execfake` or a file in `internal/server/` — keep it simple
|
||||
|
||||
Gather more real-world bot examples before implementing to ensure good coverage of common recon patterns.
|
||||
|
||||
33
README.md
33
README.md
@@ -33,7 +33,9 @@ Key settings:
|
||||
- `ssh.host_key_path` — Ed25519 host key, auto-generated if missing
|
||||
- `auth.accept_after` — accept login after N failures per IP (default `10`)
|
||||
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
|
||||
- `auth.static_credentials` — always-accepted username/password pairs
|
||||
- `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell)
|
||||
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal), `roomba` (iRobot Roomba vacuum robot), `tetris` (Tetris game TUI)
|
||||
- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority
|
||||
- `storage.db_path` — SQLite database path (default `oubliette.db`)
|
||||
- `storage.retention_days` — auto-prune records older than N days (default `90`)
|
||||
- `storage.retention_interval` — how often to run retention (default `1h`)
|
||||
@@ -42,6 +44,20 @@ Key settings:
|
||||
- `shell.fake_user` — override username in prompt; empty uses the authenticated user
|
||||
- `web.enabled` — enable the web dashboard (default `false`)
|
||||
- `web.listen_addr` — web dashboard listen address (default `:8080`)
|
||||
- Dashboard includes Chart.js charts (attempts over time, hourly pattern), an SVG world map choropleth colored by attack origin, and filter controls for date range / IP / country / username
|
||||
- Session detail pages at `/sessions/{id}` include terminal replay via xterm.js
|
||||
- `web.metrics_enabled` — expose Prometheus metrics at `/metrics` (default `true`)
|
||||
- `web.metrics_token` — bearer token to protect `/metrics`; empty means no auth (default empty)
|
||||
- `detection.enabled` — enable human detection scoring (default `false`)
|
||||
- `detection.threshold` — score threshold (0.0–1.0) for flagging sessions (default `0.6`)
|
||||
- `detection.update_interval` — how often to recompute scores (default `5s`)
|
||||
- `notify.webhooks` — list of webhook endpoints for notifications (see example config)
|
||||
|
||||
### GeoIP
|
||||
|
||||
Country-level GeoIP lookups are embedded in the binary using the [DB-IP Lite](https://db-ip.com/db/lite.php) database (CC-BY-4.0). The dashboard shows country alongside IPs and includes a "Top Countries" table.
|
||||
|
||||
For local development, run `scripts/fetch-geoip.sh` to download the MMDB file. The Nix build fetches it automatically.
|
||||
|
||||
### Run
|
||||
|
||||
@@ -55,6 +71,9 @@ Test with:
|
||||
ssh -o StrictHostKeyChecking=no -p 2222 root@localhost
|
||||
```
|
||||
|
||||
SSH exec commands (`ssh user@host <command>`) are also captured and stored on the session record.
|
||||
|
||||
|
||||
### NixOS Module
|
||||
|
||||
Add the flake as an input and enable the service:
|
||||
@@ -76,3 +95,15 @@ Add the flake as an input and enable the service:
|
||||
```
|
||||
|
||||
Alternatively, use `configFile` to pass a pre-written TOML file instead of `settings`.
|
||||
|
||||
### Docker
|
||||
|
||||
Build a Docker image via nix:
|
||||
|
||||
```sh
|
||||
nix build .#dockerImage
|
||||
docker load < result
|
||||
docker run -v /path/to/data:/data -p 2222:2222 -p 8080:8080 oubliette:0.8.0
|
||||
```
|
||||
|
||||
Place your `oubliette.toml` in the data volume. The container exposes ports 2222 (SSH) and 8080 (web/metrics).
|
||||
|
||||
@@ -4,29 +4,38 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/server"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"git.t-juice.club/torjus/oubliette/internal/web"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/server"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/web"
|
||||
)
|
||||
|
||||
const Version = "0.2.0"
|
||||
const Version = "0.18.0"
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
slog.Error("fatal error", "err", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func run() error {
|
||||
configPath := flag.String("config", "oubliette.toml", "path to config file")
|
||||
flag.Parse()
|
||||
|
||||
cfg, err := config.Load(*configPath)
|
||||
if err != nil {
|
||||
slog.Error("failed to load config", "err", err)
|
||||
os.Exit(1)
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
level := new(slog.LevelVar)
|
||||
@@ -53,45 +62,57 @@ func main() {
|
||||
|
||||
store, err := storage.NewSQLiteStore(cfg.Storage.DBPath)
|
||||
if err != nil {
|
||||
logger.Error("failed to open database", "err", err)
|
||||
os.Exit(1)
|
||||
return fmt.Errorf("open database: %w", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
// Clean up sessions left active by a previous unclean shutdown.
|
||||
if n, err := store.CloseActiveSessions(context.Background(), time.Now()); err != nil {
|
||||
return fmt.Errorf("close stale sessions: %w", err)
|
||||
} else if n > 0 {
|
||||
logger.Info("closed stale sessions from previous run", "count", n)
|
||||
}
|
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||
m := metrics.New(Version)
|
||||
instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors)
|
||||
m.RegisterStoreCollector(instrumentedStore)
|
||||
|
||||
srv, err := server.New(*cfg, store, logger)
|
||||
go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||
|
||||
srv, err := server.New(*cfg, instrumentedStore, logger, m)
|
||||
if err != nil {
|
||||
logger.Error("failed to create server", "err", err)
|
||||
os.Exit(1)
|
||||
return fmt.Errorf("create server: %w", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start web server if enabled.
|
||||
if cfg.Web.Enabled {
|
||||
webHandler, err := web.NewServer(store, logger.With("component", "web"))
|
||||
var metricsHandler http.Handler
|
||||
if *cfg.Web.MetricsEnabled {
|
||||
metricsHandler = m.Handler()
|
||||
}
|
||||
|
||||
webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
|
||||
if err != nil {
|
||||
logger.Error("failed to create web server", "err", err)
|
||||
os.Exit(1)
|
||||
return fmt.Errorf("create web server: %w", err)
|
||||
}
|
||||
|
||||
httpServer := &http.Server{
|
||||
Addr: cfg.Web.ListenAddr,
|
||||
Handler: webHandler,
|
||||
Addr: cfg.Web.ListenAddr,
|
||||
Handler: webHandler,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
logger.Info("web server listening", "addr", cfg.Web.ListenAddr)
|
||||
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Error("web server error", "err", err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
|
||||
// Graceful shutdown on context cancellation.
|
||||
go func() {
|
||||
@@ -103,10 +124,10 @@ func main() {
|
||||
}
|
||||
|
||||
if err := srv.ListenAndServe(ctx); err != nil {
|
||||
logger.Error("server error", "err", err)
|
||||
os.Exit(1)
|
||||
return fmt.Errorf("server: %w", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
logger.Info("server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
27
flake.nix
27
flake.nix
@@ -18,19 +18,44 @@
|
||||
pkgs = nixpkgs.legacyPackages.${system};
|
||||
mainGo = builtins.readFile ./cmd/oubliette/main.go;
|
||||
version = builtins.head (builtins.match ''.*const Version = "([^"]+)".*'' mainGo);
|
||||
geoipDb = pkgs.fetchurl {
|
||||
url = "https://download.db-ip.com/free/dbip-country-lite-2026-02.mmdb.gz";
|
||||
hash = "sha256-xmQZEJZ5WzE9uQww1Sdb8248l+liYw46tjbfJeu945Q=";
|
||||
};
|
||||
in
|
||||
{
|
||||
default = pkgs.buildGoModule {
|
||||
pname = "oubliette";
|
||||
inherit version;
|
||||
src = ./.;
|
||||
vendorHash = "sha256-EbJ90e4Jco7CvYYJLrewFLD5XF+Wv6TsT8RRLcj+ijU=";
|
||||
vendorHash = "sha256-/zxK6CABLYBNtuSOI8dIVgMNxKiDIcbZUS7bQR5TenA=";
|
||||
subPackages = [ "cmd/oubliette" ];
|
||||
nativeBuildInputs = [ pkgs.gzip ];
|
||||
preBuild = ''
|
||||
gunzip -c ${geoipDb} > internal/geoip/dbip-country-lite.mmdb
|
||||
'';
|
||||
meta = {
|
||||
description = "SSH honeypot";
|
||||
mainProgram = "oubliette";
|
||||
};
|
||||
};
|
||||
|
||||
dockerImage = pkgs.dockerTools.buildLayeredImage {
|
||||
name = "oubliette";
|
||||
tag = version;
|
||||
contents = [ self.packages.${system}.default ];
|
||||
config = {
|
||||
Entrypoint = [ "/bin/oubliette" ];
|
||||
Cmd = [ "-config" "/data/oubliette.toml" ];
|
||||
ExposedPorts = {
|
||||
"2222/tcp" = {};
|
||||
"8080/tcp" = {};
|
||||
};
|
||||
Volumes = {
|
||||
"/data" = {};
|
||||
};
|
||||
};
|
||||
};
|
||||
});
|
||||
|
||||
devShells = forAllSystems (system:
|
||||
|
||||
30
go.mod
30
go.mod
@@ -1,21 +1,49 @@
|
||||
module git.t-juice.club/torjus/oubliette
|
||||
module code.t-juice.club/torjus/oubliette
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
github.com/BurntSushi/toml v1.6.0
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/oschwald/maxminddb-golang v1.13.1
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/prometheus/client_model v0.6.2
|
||||
golang.org/x/crypto v0.48.0
|
||||
modernc.org/sqlite v1.45.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/x/ansi v0.10.1 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
|
||||
94
go.sum
94
go.sum
@@ -1,34 +1,116 @@
|
||||
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
|
||||
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
|
||||
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
|
||||
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
|
||||
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,6 +21,7 @@ type credKey struct {
|
||||
type Decision struct {
|
||||
Accepted bool
|
||||
Reason string // "static_credential", "threshold_reached", "remembered_credential", "rejected"
|
||||
Shell string // optional: route to specific shell (only set for static credentials)
|
||||
}
|
||||
|
||||
type Authenticator struct {
|
||||
@@ -50,7 +51,7 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
|
||||
pMatch := subtle.ConstantTimeCompare([]byte(cred.Password), []byte(password))
|
||||
if uMatch == 1 && pMatch == 1 {
|
||||
a.failCounts[ip] = 0
|
||||
return Decision{Accepted: true, Reason: "static_credential"}
|
||||
return Decision{Accepted: true, Reason: "static_credential", Shell: cred.Shell}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {
|
||||
@@ -36,7 +36,7 @@ func TestStaticCredentialsWrongPassword(t *testing.T) {
|
||||
|
||||
func TestRejectionBeforeThreshold(t *testing.T) {
|
||||
a := newTestAuth(3, time.Hour)
|
||||
for i := 0; i < 2; i++ {
|
||||
for i := range 2 {
|
||||
d := a.Authenticate("1.2.3.4", "user", "pass")
|
||||
if d.Accepted {
|
||||
t.Fatalf("attempt %d should be rejected", i+1)
|
||||
@@ -49,7 +49,7 @@ func TestRejectionBeforeThreshold(t *testing.T) {
|
||||
|
||||
func TestThresholdAcceptance(t *testing.T) {
|
||||
a := newTestAuth(3, time.Hour)
|
||||
for i := 0; i < 2; i++ {
|
||||
for i := range 2 {
|
||||
d := a.Authenticate("1.2.3.4", "user", "pass")
|
||||
if d.Accepted {
|
||||
t.Fatalf("attempt %d should be rejected", i+1)
|
||||
@@ -65,7 +65,7 @@ func TestPerIPIsolation(t *testing.T) {
|
||||
a := newTestAuth(3, time.Hour)
|
||||
|
||||
// IP1 gets 2 failures.
|
||||
for i := 0; i < 2; i++ {
|
||||
for range 2 {
|
||||
a.Authenticate("1.1.1.1", "user", "pass")
|
||||
}
|
||||
|
||||
@@ -153,16 +153,47 @@ func TestExpiredCredentialsSweep(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticCredentialShellPropagation(t *testing.T) {
|
||||
a := newTestAuth(10, time.Hour,
|
||||
config.Credential{Username: "samsung", Password: "fridge", Shell: "fridge"},
|
||||
config.Credential{Username: "root", Password: "toor"},
|
||||
)
|
||||
|
||||
// Static credential with shell set should propagate it.
|
||||
d := a.Authenticate("1.2.3.4", "samsung", "fridge")
|
||||
if !d.Accepted || d.Reason != "static_credential" {
|
||||
t.Fatalf("got %+v, want accepted with static_credential", d)
|
||||
}
|
||||
if d.Shell != "fridge" {
|
||||
t.Errorf("Shell = %q, want %q", d.Shell, "fridge")
|
||||
}
|
||||
|
||||
// Static credential without shell should leave it empty.
|
||||
d = a.Authenticate("1.2.3.4", "root", "toor")
|
||||
if !d.Accepted || d.Reason != "static_credential" {
|
||||
t.Fatalf("got %+v, want accepted with static_credential", d)
|
||||
}
|
||||
if d.Shell != "" {
|
||||
t.Errorf("Shell = %q, want empty", d.Shell)
|
||||
}
|
||||
|
||||
// Threshold-reached decision should not have a shell set.
|
||||
a2 := newTestAuth(2, time.Hour)
|
||||
a2.Authenticate("5.5.5.5", "user", "pass")
|
||||
d = a2.Authenticate("5.5.5.5", "user", "pass")
|
||||
if d.Shell != "" {
|
||||
t.Errorf("threshold decision Shell = %q, want empty", d.Shell)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
a := newTestAuth(5, time.Hour)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 100 {
|
||||
wg.Go(func() {
|
||||
a.Authenticate("1.2.3.4", "user", "pass")
|
||||
}()
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -9,25 +9,30 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
SSH SSHConfig `toml:"ssh"`
|
||||
Auth AuthConfig `toml:"auth"`
|
||||
Storage StorageConfig `toml:"storage"`
|
||||
Shell ShellConfig `toml:"shell"`
|
||||
Web WebConfig `toml:"web"`
|
||||
LogLevel string `toml:"log_level"`
|
||||
LogFormat string `toml:"log_format"` // "text" (default) or "json"
|
||||
SSH SSHConfig `toml:"ssh"`
|
||||
Auth AuthConfig `toml:"auth"`
|
||||
Storage StorageConfig `toml:"storage"`
|
||||
Shell ShellConfig `toml:"shell"`
|
||||
Web WebConfig `toml:"web"`
|
||||
Detection DetectionConfig `toml:"detection"`
|
||||
Notify NotifyConfig `toml:"notify"`
|
||||
LogLevel string `toml:"log_level"`
|
||||
LogFormat string `toml:"log_format"` // "text" (default) or "json"
|
||||
}
|
||||
|
||||
type WebConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
ListenAddr string `toml:"listen_addr"`
|
||||
Enabled bool `toml:"enabled"`
|
||||
ListenAddr string `toml:"listen_addr"`
|
||||
MetricsEnabled *bool `toml:"metrics_enabled"`
|
||||
MetricsToken string `toml:"metrics_token"`
|
||||
}
|
||||
|
||||
type ShellConfig struct {
|
||||
Hostname string `toml:"hostname"`
|
||||
Banner string `toml:"banner"`
|
||||
FakeUser string `toml:"fake_user"`
|
||||
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||
Hostname string `toml:"hostname"`
|
||||
Banner string `toml:"banner"`
|
||||
FakeUser string `toml:"fake_user"`
|
||||
UsernameRoutes map[string]string `toml:"username_routes"`
|
||||
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
@@ -57,6 +62,26 @@ type AuthConfig struct {
|
||||
type Credential struct {
|
||||
Username string `toml:"username"`
|
||||
Password string `toml:"password"`
|
||||
Shell string `toml:"shell"` // optional: route to specific shell (empty = random)
|
||||
}
|
||||
|
||||
type DetectionConfig struct {
|
||||
Enabled bool `toml:"enabled"`
|
||||
Threshold float64 `toml:"threshold"`
|
||||
UpdateInterval string `toml:"update_interval"`
|
||||
|
||||
// Parsed duration, not from TOML directly.
|
||||
UpdateIntervalDuration time.Duration `toml:"-"`
|
||||
}
|
||||
|
||||
type NotifyConfig struct {
|
||||
Webhooks []WebhookNotifyConfig `toml:"webhooks"`
|
||||
}
|
||||
|
||||
type WebhookNotifyConfig struct {
|
||||
URL string `toml:"url"`
|
||||
Headers map[string]string `toml:"headers"`
|
||||
Events []string `toml:"events"` // empty = all events
|
||||
}
|
||||
|
||||
func Load(path string) (*Config, error) {
|
||||
@@ -121,19 +146,30 @@ func applyDefaults(cfg *Config) {
|
||||
if cfg.Web.ListenAddr == "" {
|
||||
cfg.Web.ListenAddr = ":8080"
|
||||
}
|
||||
if cfg.Web.MetricsEnabled == nil {
|
||||
t := true
|
||||
cfg.Web.MetricsEnabled = &t
|
||||
}
|
||||
if cfg.Shell.Hostname == "" {
|
||||
cfg.Shell.Hostname = "ubuntu-server"
|
||||
}
|
||||
if cfg.Shell.Banner == "" {
|
||||
cfg.Shell.Banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
|
||||
}
|
||||
if cfg.Detection.Threshold == 0 {
|
||||
cfg.Detection.Threshold = 0.6
|
||||
}
|
||||
if cfg.Detection.UpdateInterval == "" {
|
||||
cfg.Detection.UpdateInterval = "5s"
|
||||
}
|
||||
}
|
||||
|
||||
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
|
||||
var knownShellKeys = map[string]bool{
|
||||
"hostname": true,
|
||||
"banner": true,
|
||||
"fake_user": true,
|
||||
"hostname": true,
|
||||
"banner": true,
|
||||
"fake_user": true,
|
||||
"username_routes": true,
|
||||
}
|
||||
|
||||
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.
|
||||
@@ -189,5 +225,33 @@ func validate(cfg *Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate detection config.
|
||||
if cfg.Detection.Enabled {
|
||||
if cfg.Detection.Threshold < 0 || cfg.Detection.Threshold > 1 {
|
||||
return fmt.Errorf("detection.threshold must be between 0 and 1, got %f", cfg.Detection.Threshold)
|
||||
}
|
||||
ui, err := time.ParseDuration(cfg.Detection.UpdateInterval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid detection.update_interval %q: %w", cfg.Detection.UpdateInterval, err)
|
||||
}
|
||||
if ui <= 0 {
|
||||
return fmt.Errorf("detection.update_interval must be positive, got %s", ui)
|
||||
}
|
||||
cfg.Detection.UpdateIntervalDuration = ui
|
||||
}
|
||||
|
||||
// Validate notify config.
|
||||
knownEvents := map[string]bool{"human_detected": true, "session_started": true}
|
||||
for i, wh := range cfg.Notify.Webhooks {
|
||||
if wh.URL == "" {
|
||||
return fmt.Errorf("notify.webhooks[%d]: url must not be empty", i)
|
||||
}
|
||||
for j, ev := range wh.Events {
|
||||
if !knownEvents[ev] {
|
||||
return fmt.Errorf("notify.webhooks[%d].events[%d]: unknown event %q", i, j, ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -255,6 +255,49 @@ listen_addr = ":9090"
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadCredentialWithShell(t *testing.T) {
|
||||
content := `
|
||||
[[auth.static_credentials]]
|
||||
username = "samsung"
|
||||
password = "fridge"
|
||||
shell = "fridge"
|
||||
|
||||
[[auth.static_credentials]]
|
||||
username = "root"
|
||||
password = "toor"
|
||||
`
|
||||
path := writeTemp(t, content)
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(cfg.Auth.StaticCredentials) != 2 {
|
||||
t.Fatalf("static_credentials len = %d, want 2", len(cfg.Auth.StaticCredentials))
|
||||
}
|
||||
if cfg.Auth.StaticCredentials[0].Shell != "fridge" {
|
||||
t.Errorf("cred[0].Shell = %q, want %q", cfg.Auth.StaticCredentials[0].Shell, "fridge")
|
||||
}
|
||||
if cfg.Auth.StaticCredentials[1].Shell != "" {
|
||||
t.Errorf("cred[1].Shell = %q, want empty", cfg.Auth.StaticCredentials[1].Shell)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMetricsToken(t *testing.T) {
|
||||
content := `
|
||||
[web]
|
||||
enabled = true
|
||||
metrics_token = "my-secret-token"
|
||||
`
|
||||
path := writeTemp(t, content)
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Web.MetricsToken != "my-secret-token" {
|
||||
t.Errorf("metrics_token = %q, want %q", cfg.Web.MetricsToken, "my-secret-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMissingFile(t *testing.T) {
|
||||
_, err := Load("/nonexistent/path/config.toml")
|
||||
if err == nil {
|
||||
@@ -270,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadUsernameRoutes(t *testing.T) {
|
||||
content := `
|
||||
[shell]
|
||||
hostname = "myhost"
|
||||
|
||||
[shell.username_routes]
|
||||
postgres = "psql"
|
||||
admin = "bash"
|
||||
|
||||
[shell.bash]
|
||||
custom_key = "value"
|
||||
`
|
||||
path := writeTemp(t, content)
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes == nil {
|
||||
t.Fatal("UsernameRoutes should not be nil")
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes["postgres"] != "psql" {
|
||||
t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql")
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes["admin"] != "bash" {
|
||||
t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash")
|
||||
}
|
||||
// username_routes should NOT appear in the Shells map.
|
||||
if _, ok := cfg.Shell.Shells["username_routes"]; ok {
|
||||
t.Error("username_routes should not appear in Shells map")
|
||||
}
|
||||
// bash should still appear in Shells map.
|
||||
if _, ok := cfg.Shell.Shells["bash"]; !ok {
|
||||
t.Error("Shells[\"bash\"] should still be present")
|
||||
}
|
||||
}
|
||||
|
||||
func writeTemp(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
path := filepath.Join(t.TempDir(), "config.toml")
|
||||
|
||||
259
internal/detection/scorer.go
Normal file
259
internal/detection/scorer.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package detection
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Direction constants for RecordEvent.
|
||||
const (
|
||||
DirInput = 0 // client → server (keystrokes)
|
||||
DirOutput = 1 // server → client (shell output)
|
||||
)
|
||||
|
||||
// Signal weights for the composite score.
|
||||
const (
|
||||
weightTimingVariance = 0.30
|
||||
weightSpecialKeys = 0.20
|
||||
weightTypingSpeed = 0.20
|
||||
weightCommandDiversity = 0.15
|
||||
weightSessionDuration = 0.15
|
||||
)
|
||||
|
||||
// Scorer accumulates keystroke events and computes a 0.0–1.0
|
||||
// human likelihood score based on multiple signals.
|
||||
type Scorer struct {
|
||||
mu sync.Mutex
|
||||
|
||||
// Input timing data.
|
||||
inputTimes []time.Time
|
||||
delays []time.Duration
|
||||
|
||||
// Special key counters.
|
||||
specialKeys int
|
||||
|
||||
// Command tracking: we count newlines and unique command prefixes.
|
||||
currentCmd []byte
|
||||
commands map[string]struct{}
|
||||
|
||||
// Session activity duration.
|
||||
firstInput time.Time
|
||||
lastInput time.Time
|
||||
}
|
||||
|
||||
// NewScorer returns a new Scorer ready to record events.
|
||||
func NewScorer() *Scorer {
|
||||
return &Scorer{
|
||||
commands: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordEvent records a data event with timestamp and direction.
|
||||
// direction should be DirInput (0) for client input or DirOutput (1) for server output.
|
||||
func (s *Scorer) RecordEvent(ts time.Time, direction int, data []byte) {
|
||||
if direction != DirInput {
|
||||
return // only analyze input
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.firstInput.IsZero() {
|
||||
s.firstInput = ts
|
||||
}
|
||||
s.lastInput = ts
|
||||
|
||||
for _, b := range data {
|
||||
// Track inter-keystroke delay for single-byte inputs.
|
||||
if len(s.inputTimes) > 0 {
|
||||
delay := ts.Sub(s.inputTimes[len(s.inputTimes)-1])
|
||||
if delay > 0 && delay < 30*time.Second {
|
||||
s.delays = append(s.delays, delay)
|
||||
}
|
||||
}
|
||||
s.inputTimes = append(s.inputTimes, ts)
|
||||
|
||||
// Count special keys.
|
||||
if isSpecialKey(b) {
|
||||
s.specialKeys++
|
||||
}
|
||||
|
||||
// Track commands (split on newline/CR).
|
||||
if b == '\r' || b == '\n' {
|
||||
cmd := string(s.currentCmd)
|
||||
if len(cmd) > 0 {
|
||||
s.commands[cmd] = struct{}{}
|
||||
}
|
||||
s.currentCmd = s.currentCmd[:0]
|
||||
} else {
|
||||
// Handle backspace: remove last byte from current command.
|
||||
if b == 0x7f || b == 0x08 {
|
||||
if len(s.currentCmd) > 0 {
|
||||
s.currentCmd = s.currentCmd[:len(s.currentCmd)-1]
|
||||
}
|
||||
} else if b >= 0x20 { // printable
|
||||
s.currentCmd = append(s.currentCmd, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Score computes the composite human likelihood score (0.0–1.0).
|
||||
// Thread-safe.
|
||||
func (s *Scorer) Score() float64 {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if len(s.inputTimes) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
tv := s.timingVarianceScore()
|
||||
sk := s.specialKeysScore()
|
||||
ts := s.typingSpeedScore()
|
||||
cd := s.commandDiversityScore()
|
||||
sd := s.sessionDurationScore()
|
||||
|
||||
score := tv*weightTimingVariance +
|
||||
sk*weightSpecialKeys +
|
||||
ts*weightTypingSpeed +
|
||||
cd*weightCommandDiversity +
|
||||
sd*weightSessionDuration
|
||||
|
||||
return clamp(score, 0, 1)
|
||||
}
|
||||
|
||||
// timingVarianceScore returns 0–1 based on coefficient of variation of inter-key delays.
|
||||
// Bots have CV ≈ 0 (instant or uniform), humans have CV ≥ 0.6.
|
||||
func (s *Scorer) timingVarianceScore() float64 {
|
||||
if len(s.delays) < 3 {
|
||||
return 0
|
||||
}
|
||||
|
||||
mean := meanDuration(s.delays)
|
||||
if mean == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
variance := 0.0
|
||||
for _, d := range s.delays {
|
||||
diff := float64(d) - float64(mean)
|
||||
variance += diff * diff
|
||||
}
|
||||
variance /= float64(len(s.delays))
|
||||
stddev := math.Sqrt(variance)
|
||||
cv := stddev / float64(mean)
|
||||
|
||||
// Map CV to 0–1: CV of 0.6+ is fully human-like.
|
||||
return clamp(cv/0.6, 0, 1)
|
||||
}
|
||||
|
||||
// specialKeysScore returns 0–1 based on count of special key presses.
|
||||
// Scripts almost never generate backspace/tab/ctrl characters.
|
||||
func (s *Scorer) specialKeysScore() float64 {
|
||||
// 5+ special keys → full score.
|
||||
return clamp(float64(s.specialKeys)/5.0, 0, 1)
|
||||
}
|
||||
|
||||
// typingSpeedScore returns 0–1 based on median inter-key delay.
|
||||
// Paste/scripts have < 5ms, humans have 30–300ms.
|
||||
func (s *Scorer) typingSpeedScore() float64 {
|
||||
if len(s.delays) < 2 {
|
||||
return 0
|
||||
}
|
||||
|
||||
med := medianDuration(s.delays)
|
||||
ms := float64(med) / float64(time.Millisecond)
|
||||
|
||||
if ms < 5 {
|
||||
return 0 // paste or script
|
||||
}
|
||||
if ms > 300 {
|
||||
return 0.7 // very slow, still possibly human
|
||||
}
|
||||
if ms >= 30 && ms <= 300 {
|
||||
return 1.0 // human range
|
||||
}
|
||||
// 5–30ms: transition zone
|
||||
return clamp((ms-5)/25, 0, 1)
|
||||
}
|
||||
|
||||
// commandDiversityScore returns 0–1 based on number of unique commands.
|
||||
func (s *Scorer) commandDiversityScore() float64 {
|
||||
// 3+ unique commands → full score.
|
||||
return clamp(float64(len(s.commands))/3.0, 0, 1)
|
||||
}
|
||||
|
||||
// sessionDurationScore returns 0–1 based on active input duration.
|
||||
func (s *Scorer) sessionDurationScore() float64 {
|
||||
if s.firstInput.IsZero() || s.lastInput.IsZero() {
|
||||
return 0
|
||||
}
|
||||
dur := s.lastInput.Sub(s.firstInput)
|
||||
// 10s+ of active input → full score.
|
||||
return clamp(float64(dur)/float64(10*time.Second), 0, 1)
|
||||
}
|
||||
|
||||
// isSpecialKey returns true for non-printable keys that humans commonly use.
|
||||
func isSpecialKey(b byte) bool {
|
||||
switch b {
|
||||
case 0x7f, // DEL (backspace in most terminals)
|
||||
0x08, // BS
|
||||
0x09, // TAB
|
||||
0x03, // Ctrl-C
|
||||
0x04, // Ctrl-D
|
||||
0x1b: // ESC (arrow keys start with ESC)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func clamp(v, lo, hi float64) float64 {
|
||||
if v < lo {
|
||||
return lo
|
||||
}
|
||||
if v > hi {
|
||||
return hi
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func meanDuration(ds []time.Duration) time.Duration {
|
||||
if len(ds) == 0 {
|
||||
return 0
|
||||
}
|
||||
var sum time.Duration
|
||||
for _, d := range ds {
|
||||
sum += d
|
||||
}
|
||||
return sum / time.Duration(len(ds))
|
||||
}
|
||||
|
||||
func medianDuration(ds []time.Duration) time.Duration {
|
||||
n := len(ds)
|
||||
if n == 0 {
|
||||
return 0
|
||||
}
|
||||
// Copy to avoid mutating the original.
|
||||
sorted := make([]time.Duration, n)
|
||||
copy(sorted, ds)
|
||||
sortDurations(sorted)
|
||||
if n%2 == 0 {
|
||||
return (sorted[n/2-1] + sorted[n/2]) / 2
|
||||
}
|
||||
return sorted[n/2]
|
||||
}
|
||||
|
||||
func sortDurations(ds []time.Duration) {
|
||||
// Simple insertion sort — delay slices are small.
|
||||
for i := 1; i < len(ds); i++ {
|
||||
key := ds[i]
|
||||
j := i - 1
|
||||
for j >= 0 && ds[j] > key {
|
||||
ds[j+1] = ds[j]
|
||||
j--
|
||||
}
|
||||
ds[j+1] = key
|
||||
}
|
||||
}
|
||||
151
internal/detection/scorer_test.go
Normal file
151
internal/detection/scorer_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package detection
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestScorer_EmptyInput(t *testing.T) {
|
||||
s := NewScorer()
|
||||
score := s.Score()
|
||||
if score != 0 {
|
||||
t.Errorf("empty scorer: got %f, want 0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScorer_SingleKeystroke(t *testing.T) {
|
||||
s := NewScorer()
|
||||
s.RecordEvent(time.Now(), DirInput, []byte("a"))
|
||||
score := s.Score()
|
||||
if score != 0 {
|
||||
t.Errorf("single keystroke: got %f, want 0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScorer_BotLikeInput(t *testing.T) {
|
||||
// Simulate a bot: paste entire commands with uniform tiny delays, no special keys.
|
||||
s := NewScorer()
|
||||
now := time.Now()
|
||||
|
||||
// Bot pastes "cat /etc/passwd\r" all at once with perfectly uniform timing.
|
||||
for range 3 {
|
||||
cmd := []byte("cat /etc/passwd\r")
|
||||
for _, b := range cmd {
|
||||
s.RecordEvent(now, DirInput, []byte{b})
|
||||
now = now.Add(100 * time.Microsecond) // ~0.1ms uniform delay = paste
|
||||
}
|
||||
}
|
||||
|
||||
score := s.Score()
|
||||
if score >= 0.3 {
|
||||
t.Errorf("bot-like input: got %f, want < 0.3", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScorer_HumanLikeInput(t *testing.T) {
|
||||
// Simulate a human: variable timing, backspaces, diverse commands.
|
||||
s := NewScorer()
|
||||
now := time.Now()
|
||||
|
||||
type cmd struct {
|
||||
text string
|
||||
delay time.Duration // base delay between keys
|
||||
}
|
||||
|
||||
commands := []cmd{
|
||||
{"ls -la\r", 80 * time.Millisecond},
|
||||
{"cat /etc/paswd", 120 * time.Millisecond}, // typo
|
||||
{string([]byte{0x7f}), 200 * time.Millisecond}, // backspace
|
||||
{"wd\r", 90 * time.Millisecond}, // correction
|
||||
{"whoami\r", 100 * time.Millisecond},
|
||||
{"uname -a\r", 150 * time.Millisecond},
|
||||
{string([]byte{0x09}), 300 * time.Millisecond}, // tab completion
|
||||
{"pwd\r", 70 * time.Millisecond},
|
||||
}
|
||||
|
||||
for _, c := range commands {
|
||||
for _, b := range []byte(c.text) {
|
||||
// Add ±30% jitter to make timing more natural.
|
||||
jitter := time.Duration(float64(c.delay) * 0.3)
|
||||
delay := c.delay + jitter // simplified: always add, still variable across commands
|
||||
s.RecordEvent(now, DirInput, []byte{b})
|
||||
now = now.Add(delay)
|
||||
}
|
||||
// Pause between commands (thinking time).
|
||||
now = now.Add(2 * time.Second)
|
||||
}
|
||||
|
||||
score := s.Score()
|
||||
if score <= 0.6 {
|
||||
t.Errorf("human-like input: got %f, want > 0.6", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScorer_OutputIgnored(t *testing.T) {
|
||||
s := NewScorer()
|
||||
now := time.Now()
|
||||
|
||||
// Only output events — should not affect score.
|
||||
for range 100 {
|
||||
s.RecordEvent(now, DirOutput, []byte("some output\n"))
|
||||
now = now.Add(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
score := s.Score()
|
||||
if score != 0 {
|
||||
t.Errorf("output-only: got %f, want 0", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScorer_ThreadSafety(t *testing.T) {
|
||||
s := NewScorer()
|
||||
now := time.Now()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := range 10 {
|
||||
wg.Go(func() {
|
||||
for j := range 100 {
|
||||
ts := now.Add(time.Duration(i*100+j) * time.Millisecond)
|
||||
s.RecordEvent(ts, DirInput, []byte("a"))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Concurrently read score.
|
||||
wg.Go(func() {
|
||||
for range 50 {
|
||||
_ = s.Score()
|
||||
}
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should not panic; score should be valid.
|
||||
score := s.Score()
|
||||
if score < 0 || score > 1 {
|
||||
t.Errorf("concurrent score out of range: %f", score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScorer_CommandDiversity(t *testing.T) {
|
||||
s := NewScorer()
|
||||
now := time.Now()
|
||||
|
||||
// Type 4 different commands with human-ish timing.
|
||||
cmds := []string{"ls\r", "pwd\r", "id\r", "whoami\r"}
|
||||
for _, cmd := range cmds {
|
||||
for _, b := range []byte(cmd) {
|
||||
s.RecordEvent(now, DirInput, []byte{b})
|
||||
now = now.Add(100 * time.Millisecond)
|
||||
}
|
||||
now = now.Add(time.Second)
|
||||
}
|
||||
|
||||
score := s.Score()
|
||||
// With 4 unique commands, human timing, and decent duration,
|
||||
// we should get a meaningful score.
|
||||
if score < 0.4 {
|
||||
t.Errorf("diverse commands: got %f, want >= 0.4", score)
|
||||
}
|
||||
}
|
||||
51
internal/geoip/geoip.go
Normal file
51
internal/geoip/geoip.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package geoip
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
)
|
||||
|
||||
//go:embed dbip-country-lite.mmdb
|
||||
var mmdbData []byte
|
||||
|
||||
// Reader provides country-level GeoIP lookups using an embedded DB-IP Lite database.
|
||||
type Reader struct {
|
||||
db *maxminddb.Reader
|
||||
}
|
||||
|
||||
// New opens the embedded MMDB and returns a ready-to-use Reader.
|
||||
func New() (*Reader, error) {
|
||||
db, err := maxminddb.FromBytes(mmdbData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Reader{db: db}, nil
|
||||
}
|
||||
|
||||
type countryRecord struct {
|
||||
Country struct {
|
||||
ISOCode string `maxminddb:"iso_code"`
|
||||
} `maxminddb:"country"`
|
||||
}
|
||||
|
||||
// Lookup returns the ISO 3166-1 alpha-2 country code for the given IP address,
|
||||
// or an empty string if the lookup fails or no result is found.
|
||||
func (r *Reader) Lookup(ipStr string) string {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var record countryRecord
|
||||
if err := r.db.Lookup(ip, &record); err != nil {
|
||||
return ""
|
||||
}
|
||||
return record.Country.ISOCode
|
||||
}
|
||||
|
||||
// Close releases resources held by the reader.
|
||||
func (r *Reader) Close() error {
|
||||
return r.db.Close()
|
||||
}
|
||||
44
internal/geoip/geoip_test.go
Normal file
44
internal/geoip/geoip_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package geoip
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestLookup(t *testing.T) {
|
||||
reader, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
want string
|
||||
}{
|
||||
{"8.8.8.8", "US"},
|
||||
{"1.1.1.1", "AU"},
|
||||
{"invalid", ""},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
got := reader.Lookup(tt.ip)
|
||||
if got != tt.want {
|
||||
t.Errorf("Lookup(%q) = %q, want %q", tt.ip, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupPrivateIP(t *testing.T) {
|
||||
reader, err := New()
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Private IPs should return empty string (no country).
|
||||
got := reader.Lookup("10.0.0.1")
|
||||
if got != "" {
|
||||
t.Errorf("Lookup(10.0.0.1) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
178
internal/metrics/metrics.go
Normal file
178
internal/metrics/metrics.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// Metrics holds all Prometheus collectors for the honeypot.
|
||||
type Metrics struct {
|
||||
registry *prometheus.Registry
|
||||
|
||||
SSHConnectionsTotal *prometheus.CounterVec
|
||||
SSHConnectionsActive prometheus.Gauge
|
||||
AuthAttemptsTotal *prometheus.CounterVec
|
||||
AuthAttemptsByCountry *prometheus.CounterVec
|
||||
CommandsExecuted *prometheus.CounterVec
|
||||
HumanScore prometheus.Histogram
|
||||
SessionsTotal *prometheus.CounterVec
|
||||
SessionsActive prometheus.Gauge
|
||||
SessionDuration prometheus.Histogram
|
||||
ExecCommandsTotal prometheus.Counter
|
||||
BuildInfo *prometheus.GaugeVec
|
||||
StorageQueryDuration *prometheus.HistogramVec
|
||||
StorageQueryErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// New creates a new Metrics instance with all collectors registered.
|
||||
func New(version string) *Metrics {
|
||||
reg := prometheus.NewRegistry()
|
||||
|
||||
m := &Metrics{
|
||||
registry: reg,
|
||||
SSHConnectionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_ssh_connections_total",
|
||||
Help: "Total SSH connections received.",
|
||||
}, []string{"outcome"}),
|
||||
SSHConnectionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "oubliette_ssh_connections_active",
|
||||
Help: "Current active SSH connections.",
|
||||
}),
|
||||
AuthAttemptsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_auth_attempts_total",
|
||||
Help: "Total authentication attempts.",
|
||||
}, []string{"result", "reason"}),
|
||||
AuthAttemptsByCountry: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_auth_attempts_by_country_total",
|
||||
Help: "Total authentication attempts by country.",
|
||||
}, []string{"country"}),
|
||||
CommandsExecuted: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_commands_executed_total",
|
||||
Help: "Total commands executed in shells.",
|
||||
}, []string{"shell"}),
|
||||
HumanScore: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Name: "oubliette_human_score",
|
||||
Help: "Distribution of final human detection scores.",
|
||||
Buckets: prometheus.LinearBuckets(0, 0.1, 11), // 0.0, 0.1, ..., 1.0
|
||||
}),
|
||||
SessionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_sessions_total",
|
||||
Help: "Total sessions created.",
|
||||
}, []string{"shell"}),
|
||||
SessionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "oubliette_sessions_active",
|
||||
Help: "Current active sessions.",
|
||||
}),
|
||||
SessionDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
|
||||
Name: "oubliette_session_duration_seconds",
|
||||
Help: "Session duration in seconds.",
|
||||
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600, 1800, 3600},
|
||||
}),
|
||||
ExecCommandsTotal: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Name: "oubliette_exec_commands_total",
|
||||
Help: "Total SSH exec commands received.",
|
||||
}),
|
||||
BuildInfo: prometheus.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Name: "oubliette_build_info",
|
||||
Help: "Build information. Always 1.",
|
||||
}, []string{"version"}),
|
||||
StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "oubliette_storage_query_duration_seconds",
|
||||
Help: "Duration of storage query calls in seconds.",
|
||||
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||
}, []string{"method"}),
|
||||
StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_storage_query_errors_total",
|
||||
Help: "Total storage query errors.",
|
||||
}, []string{"method"}),
|
||||
}
|
||||
|
||||
reg.MustRegister(
|
||||
collectors.NewGoCollector(),
|
||||
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
|
||||
m.SSHConnectionsTotal,
|
||||
m.SSHConnectionsActive,
|
||||
m.AuthAttemptsTotal,
|
||||
m.AuthAttemptsByCountry,
|
||||
m.CommandsExecuted,
|
||||
m.HumanScore,
|
||||
m.SessionsTotal,
|
||||
m.SessionsActive,
|
||||
m.SessionDuration,
|
||||
m.ExecCommandsTotal,
|
||||
m.BuildInfo,
|
||||
m.StorageQueryDuration,
|
||||
m.StorageQueryErrors,
|
||||
)
|
||||
|
||||
m.BuildInfo.WithLabelValues(version).Set(1)
|
||||
|
||||
// Initialize label combinations so they appear in Gather/output.
|
||||
for _, outcome := range []string{"accepted", "rejected_handshake", "rejected_max_connections"} {
|
||||
m.SSHConnectionsTotal.WithLabelValues(outcome)
|
||||
}
|
||||
for _, reason := range []string{"static_credential", "remembered_credential", "threshold_reached", "rejected"} {
|
||||
m.AuthAttemptsTotal.WithLabelValues("accepted", reason)
|
||||
m.AuthAttemptsTotal.WithLabelValues("rejected", reason)
|
||||
}
|
||||
for _, sh := range []string{"bash", "fridge", "banking", "adventure", "cisco"} {
|
||||
m.SessionsTotal.WithLabelValues(sh)
|
||||
m.CommandsExecuted.WithLabelValues(sh)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// RegisterStoreCollector registers a collector that queries storage stats on each scrape.
|
||||
func (m *Metrics) RegisterStoreCollector(store storage.Store) {
|
||||
m.registry.MustRegister(&storeCollector{store: store})
|
||||
}
|
||||
|
||||
// Handler returns an http.Handler that serves Prometheus metrics.
|
||||
func (m *Metrics) Handler() http.Handler {
|
||||
return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{})
|
||||
}
|
||||
|
||||
// storeCollector implements prometheus.Collector, querying storage on each scrape.
|
||||
type storeCollector struct {
|
||||
store storage.Store
|
||||
}
|
||||
|
||||
var (
|
||||
storageLoginAttemptsDesc = prometheus.NewDesc(
|
||||
"oubliette_storage_login_attempts_total",
|
||||
"Total login attempts in storage.",
|
||||
nil, nil,
|
||||
)
|
||||
storageUniqueIPsDesc = prometheus.NewDesc(
|
||||
"oubliette_storage_unique_ips",
|
||||
"Unique IPs in storage.",
|
||||
nil, nil,
|
||||
)
|
||||
storageSessionsDesc = prometheus.NewDesc(
|
||||
"oubliette_storage_sessions_total",
|
||||
"Total sessions in storage.",
|
||||
nil, nil,
|
||||
)
|
||||
)
|
||||
|
||||
func (c *storeCollector) Describe(ch chan<- *prometheus.Desc) {
|
||||
ch <- storageLoginAttemptsDesc
|
||||
ch <- storageUniqueIPsDesc
|
||||
ch <- storageSessionsDesc
|
||||
}
|
||||
|
||||
func (c *storeCollector) Collect(ch chan<- prometheus.Metric) {
|
||||
stats, err := c.store.GetDashboardStats(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ch <- prometheus.MustNewConstMetric(storageLoginAttemptsDesc, prometheus.GaugeValue, float64(stats.TotalAttempts))
|
||||
ch <- prometheus.MustNewConstMetric(storageUniqueIPsDesc, prometheus.GaugeValue, float64(stats.UniqueIPs))
|
||||
ch <- prometheus.MustNewConstMetric(storageSessionsDesc, prometheus.GaugeValue, float64(stats.TotalSessions))
|
||||
}
|
||||
142
internal/metrics/metrics_test.go
Normal file
142
internal/metrics/metrics_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
m := New("1.2.3")
|
||||
|
||||
// Gather all metrics and check expected names exist.
|
||||
families, err := m.registry.Gather()
|
||||
if err != nil {
|
||||
t.Fatalf("gather: %v", err)
|
||||
}
|
||||
|
||||
want := map[string]bool{
|
||||
"oubliette_ssh_connections_total": false,
|
||||
"oubliette_ssh_connections_active": false,
|
||||
"oubliette_auth_attempts_total": false,
|
||||
"oubliette_commands_executed_total": false,
|
||||
"oubliette_human_score": false,
|
||||
"oubliette_sessions_total": false,
|
||||
"oubliette_sessions_active": false,
|
||||
"oubliette_session_duration_seconds": false,
|
||||
"oubliette_build_info": false,
|
||||
}
|
||||
|
||||
for _, f := range families {
|
||||
if _, ok := want[f.GetName()]; ok {
|
||||
want[f.GetName()] = true
|
||||
}
|
||||
}
|
||||
|
||||
for name, found := range want {
|
||||
if !found {
|
||||
t.Errorf("metric %q not registered", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAttemptsByCountry(t *testing.T) {
|
||||
m := New("1.0.0")
|
||||
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
|
||||
m.AuthAttemptsByCountry.WithLabelValues("DE").Inc()
|
||||
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
|
||||
|
||||
families, err := m.registry.Gather()
|
||||
if err != nil {
|
||||
t.Fatalf("gather: %v", err)
|
||||
}
|
||||
|
||||
var found bool
|
||||
for _, f := range families {
|
||||
if f.GetName() == "oubliette_auth_attempts_by_country_total" {
|
||||
found = true
|
||||
if len(f.GetMetric()) != 2 {
|
||||
t.Errorf("expected 2 label pairs (US, DE), got %d", len(f.GetMetric()))
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("oubliette_auth_attempts_by_country_total not found after incrementing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
m := New("1.2.3")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
m.Handler().ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(w.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("reading body: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), `oubliette_build_info{version="1.2.3"} 1`) {
|
||||
t.Errorf("response should contain build_info metric, got:\n%s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreCollector(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
|
||||
// Seed some data.
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
m := New("test")
|
||||
m.RegisterStoreCollector(store)
|
||||
|
||||
families, err := m.registry.Gather()
|
||||
if err != nil {
|
||||
t.Fatalf("gather: %v", err)
|
||||
}
|
||||
|
||||
wantMetrics := map[string]float64{
|
||||
"oubliette_storage_login_attempts_total": 2,
|
||||
"oubliette_storage_unique_ips": 2,
|
||||
"oubliette_storage_sessions_total": 1,
|
||||
}
|
||||
|
||||
for _, f := range families {
|
||||
expected, ok := wantMetrics[f.GetName()]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if len(f.GetMetric()) == 0 {
|
||||
t.Errorf("metric %q has no samples", f.GetName())
|
||||
continue
|
||||
}
|
||||
got := f.GetMetric()[0].GetGauge().GetValue()
|
||||
if got != expected {
|
||||
t.Errorf("metric %q = %f, want %f", f.GetName(), got, expected)
|
||||
}
|
||||
delete(wantMetrics, f.GetName())
|
||||
}
|
||||
|
||||
for name := range wantMetrics {
|
||||
t.Errorf("metric %q not found in gather output", name)
|
||||
}
|
||||
}
|
||||
175
internal/notify/webhook.go
Normal file
175
internal/notify/webhook.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
// Event types.
|
||||
const (
|
||||
EventHumanDetected = "human_detected"
|
||||
EventSessionStarted = "session_started"
|
||||
)
|
||||
|
||||
// SessionInfo holds session data included in webhook payloads.
|
||||
type SessionInfo struct {
|
||||
ID string `json:"id"`
|
||||
IP string `json:"ip"`
|
||||
Username string `json:"username"`
|
||||
ShellName string `json:"shell_name"`
|
||||
HumanScore float64 `json:"human_score"`
|
||||
ConnectedAt string `json:"connected_at"`
|
||||
}
|
||||
|
||||
// webhookPayload is the JSON body sent to webhooks.
|
||||
type webhookPayload struct {
|
||||
Event string `json:"event"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Session SessionInfo `json:"session"`
|
||||
}
|
||||
|
||||
// Notifier sends webhook notifications for honeypot events.
|
||||
type Notifier struct {
|
||||
webhooks []config.WebhookNotifyConfig
|
||||
logger *slog.Logger
|
||||
client *http.Client
|
||||
|
||||
mu sync.Mutex
|
||||
sent map[string]struct{} // dedup key: "sessionID:eventType"
|
||||
}
|
||||
|
||||
// NewNotifier creates a Notifier with the given webhook configurations.
|
||||
func NewNotifier(webhooks []config.WebhookNotifyConfig, logger *slog.Logger) *Notifier {
|
||||
return &Notifier{
|
||||
webhooks: webhooks,
|
||||
logger: logger,
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
sent: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Notify sends a notification for the given event type and session.
|
||||
// Deduplicates by (sessionID, eventType) — each combination is sent at most once.
|
||||
func (n *Notifier) Notify(ctx context.Context, eventType string, session SessionInfo) {
|
||||
dedupKey := session.ID + ":" + eventType
|
||||
|
||||
n.mu.Lock()
|
||||
if _, ok := n.sent[dedupKey]; ok {
|
||||
n.mu.Unlock()
|
||||
return
|
||||
}
|
||||
n.sent[dedupKey] = struct{}{}
|
||||
n.mu.Unlock()
|
||||
|
||||
payload := webhookPayload{
|
||||
Event: eventType,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
Session: session,
|
||||
}
|
||||
|
||||
for _, wh := range n.webhooks {
|
||||
if !n.shouldSend(wh, eventType) {
|
||||
continue
|
||||
}
|
||||
go n.send(ctx, wh, payload)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupSession removes dedup state for a session.
|
||||
func (n *Notifier) CleanupSession(sessionID string) {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
for key := range n.sent {
|
||||
if len(key) > len(sessionID) && key[:len(sessionID)+1] == sessionID+":" {
|
||||
delete(n.sent, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shouldSend returns true if the webhook is configured to receive this event type.
|
||||
func (n *Notifier) shouldSend(wh config.WebhookNotifyConfig, eventType string) bool {
|
||||
if len(wh.Events) == 0 {
|
||||
return true // empty = all events
|
||||
}
|
||||
return slices.Contains(wh.Events, eventType)
|
||||
}
|
||||
|
||||
func (n *Notifier) send(ctx context.Context, wh config.WebhookNotifyConfig, payload webhookPayload) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
n.logger.Error("failed to marshal webhook payload", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, wh.URL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
n.logger.Error("failed to create webhook request", "err", err, "url", wh.URL)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
for k, v := range wh.Headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := n.client.Do(req)
|
||||
if err != nil {
|
||||
n.logger.Error("webhook request failed", "err", err, "url", wh.URL)
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
n.logger.Warn("webhook returned error status",
|
||||
"url", wh.URL,
|
||||
"status", resp.StatusCode,
|
||||
"event", payload.Event,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
n.logger.Debug("webhook sent",
|
||||
"url", wh.URL,
|
||||
"event", payload.Event,
|
||||
"session_id", payload.Session.ID,
|
||||
)
|
||||
}
|
||||
|
||||
// FormatConnectedAt formats a time for use in SessionInfo.
|
||||
func FormatConnectedAt(t time.Time) string {
|
||||
return t.UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// NoopNotifier is a no-op notifier used when no webhooks are configured.
|
||||
type NoopNotifier struct{}
|
||||
|
||||
func (NoopNotifier) Notify(context.Context, string, SessionInfo) {}
|
||||
func (NoopNotifier) CleanupSession(string) {}
|
||||
|
||||
// Sender is the interface for sending notifications.
|
||||
type Sender interface {
|
||||
Notify(ctx context.Context, eventType string, session SessionInfo)
|
||||
CleanupSession(sessionID string)
|
||||
}
|
||||
|
||||
var (
|
||||
_ Sender = (*Notifier)(nil)
|
||||
_ Sender = NoopNotifier{}
|
||||
)
|
||||
|
||||
// NewSender creates a Sender from configuration. Returns a NoopNotifier
|
||||
// if no webhooks are configured.
|
||||
func NewSender(webhooks []config.WebhookNotifyConfig, logger *slog.Logger) Sender {
|
||||
if len(webhooks) == 0 {
|
||||
return NoopNotifier{}
|
||||
}
|
||||
return NewNotifier(webhooks, logger)
|
||||
}
|
||||
243
internal/notify/webhook_test.go
Normal file
243
internal/notify/webhook_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
func testSession() SessionInfo {
|
||||
return SessionInfo{
|
||||
ID: "test-session-123",
|
||||
IP: "1.2.3.4",
|
||||
Username: "root",
|
||||
ShellName: "bash",
|
||||
HumanScore: 0.85,
|
||||
ConnectedAt: FormatConnectedAt(time.Now()),
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifier_PayloadStructure(t *testing.T) {
|
||||
var received webhookPayload
|
||||
var mu sync.Mutex
|
||||
done := make(chan struct{})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err := json.NewDecoder(r.Body).Decode(&received); err != nil {
|
||||
t.Errorf("failed to decode payload: %v", err)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
close(done)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
webhooks := []config.WebhookNotifyConfig{
|
||||
{URL: srv.URL},
|
||||
}
|
||||
|
||||
n := NewNotifier(webhooks, slog.Default())
|
||||
session := testSession()
|
||||
n.Notify(context.Background(), EventHumanDetected, session)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for webhook")
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if received.Event != EventHumanDetected {
|
||||
t.Errorf("event: got %q, want %q", received.Event, EventHumanDetected)
|
||||
}
|
||||
if received.Session.ID != session.ID {
|
||||
t.Errorf("session ID: got %q, want %q", received.Session.ID, session.ID)
|
||||
}
|
||||
if received.Session.IP != session.IP {
|
||||
t.Errorf("session IP: got %q, want %q", received.Session.IP, session.IP)
|
||||
}
|
||||
if received.Session.HumanScore != session.HumanScore {
|
||||
t.Errorf("score: got %f, want %f", received.Session.HumanScore, session.HumanScore)
|
||||
}
|
||||
if received.Timestamp == "" {
|
||||
t.Error("timestamp should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifier_CustomHeaders(t *testing.T) {
|
||||
var receivedHeaders http.Header
|
||||
done := make(chan struct{})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
close(done)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
webhooks := []config.WebhookNotifyConfig{
|
||||
{
|
||||
URL: srv.URL,
|
||||
Headers: map[string]string{
|
||||
"Authorization": "Bearer test-token",
|
||||
"X-Custom": "my-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
n := NewNotifier(webhooks, slog.Default())
|
||||
n.Notify(context.Background(), EventSessionStarted, testSession())
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for webhook")
|
||||
}
|
||||
|
||||
if got := receivedHeaders.Get("Authorization"); got != "Bearer test-token" {
|
||||
t.Errorf("Authorization header: got %q, want %q", got, "Bearer test-token")
|
||||
}
|
||||
if got := receivedHeaders.Get("X-Custom"); got != "my-value" {
|
||||
t.Errorf("X-Custom header: got %q, want %q", got, "my-value")
|
||||
}
|
||||
if got := receivedHeaders.Get("Content-Type"); got != "application/json" {
|
||||
t.Errorf("Content-Type: got %q, want %q", got, "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifier_Deduplication(t *testing.T) {
|
||||
var count int
|
||||
var mu sync.Mutex
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
count++
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
webhooks := []config.WebhookNotifyConfig{{URL: srv.URL}}
|
||||
n := NewNotifier(webhooks, slog.Default())
|
||||
session := testSession()
|
||||
|
||||
// Send same event three times for the same session.
|
||||
for range 3 {
|
||||
n.Notify(context.Background(), EventHumanDetected, session)
|
||||
}
|
||||
|
||||
// Allow goroutines to complete.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if count != 1 {
|
||||
t.Errorf("dedup: got %d sends, want 1", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifier_EventFiltering(t *testing.T) {
|
||||
var receivedEvents []string
|
||||
var mu sync.Mutex
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var payload webhookPayload
|
||||
_ = json.NewDecoder(r.Body).Decode(&payload)
|
||||
mu.Lock()
|
||||
receivedEvents = append(receivedEvents, payload.Event)
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
// Only subscribe to human_detected.
|
||||
webhooks := []config.WebhookNotifyConfig{
|
||||
{
|
||||
URL: srv.URL,
|
||||
Events: []string{EventHumanDetected},
|
||||
},
|
||||
}
|
||||
|
||||
n := NewNotifier(webhooks, slog.Default())
|
||||
session := testSession()
|
||||
|
||||
// Send both event types.
|
||||
n.Notify(context.Background(), EventSessionStarted, session)
|
||||
// Need a different session for human_detected to avoid dedup with same session.
|
||||
session2 := testSession()
|
||||
session2.ID = "test-session-456"
|
||||
n.Notify(context.Background(), EventHumanDetected, session2)
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(receivedEvents) != 1 {
|
||||
t.Fatalf("event filtering: got %d events, want 1", len(receivedEvents))
|
||||
}
|
||||
if receivedEvents[0] != EventHumanDetected {
|
||||
t.Errorf("filtered event: got %q, want %q", receivedEvents[0], EventHumanDetected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifier_CleanupSession(t *testing.T) {
|
||||
var count int
|
||||
var mu sync.Mutex
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
count++
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
webhooks := []config.WebhookNotifyConfig{{URL: srv.URL}}
|
||||
n := NewNotifier(webhooks, slog.Default())
|
||||
session := testSession()
|
||||
|
||||
n.Notify(context.Background(), EventHumanDetected, session)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Cleanup and resend — should work again.
|
||||
n.CleanupSession(session.ID)
|
||||
n.Notify(context.Background(), EventHumanDetected, session)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if count != 2 {
|
||||
t.Errorf("after cleanup: got %d sends, want 2", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoopNotifier(t *testing.T) {
|
||||
// Should not panic.
|
||||
n := NoopNotifier{}
|
||||
n.Notify(context.Background(), EventHumanDetected, testSession())
|
||||
n.CleanupSession("test")
|
||||
}
|
||||
|
||||
func TestNewSender_NoWebhooks(t *testing.T) {
|
||||
sender := NewSender(nil, slog.Default())
|
||||
if _, ok := sender.(NoopNotifier); !ok {
|
||||
t.Errorf("expected NoopNotifier, got %T", sender)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSender_WithWebhooks(t *testing.T) {
|
||||
webhooks := []config.WebhookNotifyConfig{{URL: "http://example.com"}}
|
||||
sender := NewSender(webhooks, slog.Default())
|
||||
if _, ok := sender.(*Notifier); !ok {
|
||||
t.Errorf("expected *Notifier, got %T", sender)
|
||||
}
|
||||
}
|
||||
@@ -12,29 +12,69 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/detection"
|
||||
"code.t-juice.club/torjus/oubliette/internal/geoip"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/notify"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/adventure"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/banking"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
||||
psqlshell "code.t-juice.club/torjus/oubliette/internal/shell/psql"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/roomba"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/tetris"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
cfg config.Config
|
||||
store storage.Store
|
||||
authenticator *auth.Authenticator
|
||||
sshConfig *ssh.ServerConfig
|
||||
logger *slog.Logger
|
||||
connSem chan struct{} // semaphore limiting concurrent connections
|
||||
shellRegistry *shell.Registry
|
||||
cfg config.Config
|
||||
store storage.Store
|
||||
authenticator *auth.Authenticator
|
||||
sshConfig *ssh.ServerConfig
|
||||
logger *slog.Logger
|
||||
connSem chan struct{} // semaphore limiting concurrent connections
|
||||
shellRegistry *shell.Registry
|
||||
notifier notify.Sender
|
||||
metrics *metrics.Metrics
|
||||
geoip *geoip.Reader
|
||||
}
|
||||
|
||||
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
|
||||
func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics.Metrics) (*Server, error) {
|
||||
registry := shell.NewRegistry()
|
||||
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering bash shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(fridge.NewFridgeShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering fridge shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(banking.NewBankingShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering banking shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(adventure.NewAdventureShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering adventure shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering cisco shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering psql shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(roomba.NewRoombaShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering roomba shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(tetris.NewTetrisShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering tetris shell: %w", err)
|
||||
}
|
||||
|
||||
geo, err := geoip.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("opening geoip database: %w", err)
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
@@ -43,6 +83,9 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
|
||||
logger: logger,
|
||||
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
||||
shellRegistry: registry,
|
||||
notifier: notify.NewSender(cfg.Notify.Webhooks, logger),
|
||||
metrics: m,
|
||||
geoip: geo,
|
||||
}
|
||||
|
||||
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
||||
@@ -60,6 +103,8 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(ctx context.Context) error {
|
||||
defer s.geoip.Close()
|
||||
|
||||
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen: %w", err)
|
||||
@@ -86,11 +131,16 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
|
||||
// Enforce max concurrent connections.
|
||||
select {
|
||||
case s.connSem <- struct{}{}:
|
||||
s.metrics.SSHConnectionsActive.Inc()
|
||||
go func() {
|
||||
defer func() { <-s.connSem }()
|
||||
defer func() {
|
||||
<-s.connSem
|
||||
s.metrics.SSHConnectionsActive.Dec()
|
||||
}()
|
||||
s.handleConn(conn)
|
||||
}()
|
||||
default:
|
||||
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_max_connections").Inc()
|
||||
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
}
|
||||
@@ -102,11 +152,13 @@ func (s *Server) handleConn(conn net.Conn) {
|
||||
|
||||
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
||||
if err != nil {
|
||||
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_handshake").Inc()
|
||||
s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
|
||||
return
|
||||
}
|
||||
defer sshConn.Close()
|
||||
|
||||
s.metrics.SSHConnectionsTotal.WithLabelValues("accepted").Inc()
|
||||
s.logger.Info("SSH connection established",
|
||||
"remote_addr", sshConn.RemoteAddr(),
|
||||
"user", sshConn.User(),
|
||||
@@ -134,18 +186,50 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
defer channel.Close()
|
||||
|
||||
// Select a shell from the registry.
|
||||
selectedShell, err := s.shellRegistry.Select()
|
||||
if err != nil {
|
||||
s.logger.Error("failed to select shell", "err", err)
|
||||
return
|
||||
// If the auth layer specified a shell preference, use it; otherwise random.
|
||||
var selectedShell shell.Shell
|
||||
if conn.Permissions != nil && conn.Permissions.Extensions["shell"] != "" {
|
||||
shellName := conn.Permissions.Extensions["shell"]
|
||||
sh, ok := s.shellRegistry.Get(shellName)
|
||||
if ok {
|
||||
selectedShell = sh
|
||||
} else {
|
||||
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
|
||||
}
|
||||
}
|
||||
// Second priority: username-based route.
|
||||
if selectedShell == nil {
|
||||
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
|
||||
sh, found := s.shellRegistry.Get(shellName)
|
||||
if found {
|
||||
selectedShell = sh
|
||||
} else {
|
||||
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
|
||||
}
|
||||
}
|
||||
}
|
||||
// Lowest priority: random selection.
|
||||
if selectedShell == nil {
|
||||
var err error
|
||||
selectedShell, err = s.shellRegistry.Select()
|
||||
if err != nil {
|
||||
s.logger.Error("failed to select shell", "err", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ip := extractIP(conn.RemoteAddr())
|
||||
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name())
|
||||
country := s.geoip.Lookup(ip)
|
||||
sessionStart := time.Now()
|
||||
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name(), country)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create session", "err", err)
|
||||
} else {
|
||||
s.metrics.SessionsTotal.WithLabelValues(selectedShell.Name()).Inc()
|
||||
s.metrics.SessionsActive.Inc()
|
||||
defer func() {
|
||||
s.metrics.SessionsActive.Dec()
|
||||
s.metrics.SessionDuration.Observe(time.Since(sessionStart).Seconds())
|
||||
if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
|
||||
s.logger.Error("failed to end session", "err", err)
|
||||
}
|
||||
@@ -159,14 +243,36 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
"session_id", sessionID,
|
||||
)
|
||||
|
||||
// Handle session requests (pty-req, shell, etc.)
|
||||
// Send session_started notification.
|
||||
connectedAt := time.Now()
|
||||
sessionInfo := notify.SessionInfo{
|
||||
ID: sessionID,
|
||||
IP: ip,
|
||||
Username: conn.User(),
|
||||
ShellName: selectedShell.Name(),
|
||||
ConnectedAt: notify.FormatConnectedAt(connectedAt),
|
||||
}
|
||||
s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
|
||||
defer s.notifier.CleanupSession(sessionID)
|
||||
|
||||
// Handle session requests (pty-req, shell, exec, etc.)
|
||||
execCh := make(chan string, 1)
|
||||
go func() {
|
||||
defer close(execCh)
|
||||
for req := range requests {
|
||||
switch req.Type {
|
||||
case "pty-req", "shell":
|
||||
if req.WantReply {
|
||||
req.Reply(true, nil)
|
||||
}
|
||||
case "exec":
|
||||
if req.WantReply {
|
||||
req.Reply(true, nil)
|
||||
}
|
||||
var payload struct{ Command string }
|
||||
if err := ssh.Unmarshal(req.Payload, &payload); err == nil {
|
||||
execCh <- payload.Command
|
||||
}
|
||||
default:
|
||||
if req.WantReply {
|
||||
req.Reply(false, nil)
|
||||
@@ -175,6 +281,29 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
}
|
||||
}()
|
||||
|
||||
// Check for exec request before proceeding to interactive shell.
|
||||
select {
|
||||
case cmd, ok := <-execCh:
|
||||
if ok && cmd != "" {
|
||||
s.logger.Info("exec command received",
|
||||
"remote_addr", conn.RemoteAddr(),
|
||||
"user", conn.User(),
|
||||
"session_id", sessionID,
|
||||
"command", cmd,
|
||||
)
|
||||
if err := s.store.SetExecCommand(context.Background(), sessionID, cmd); err != nil {
|
||||
s.logger.Error("failed to set exec command", "err", err, "session_id", sessionID)
|
||||
}
|
||||
s.metrics.ExecCommandsTotal.Inc()
|
||||
// Send exit-status 0 and close channel.
|
||||
exitPayload := make([]byte, 4) // uint32(0)
|
||||
_, _ = channel.SendRequest("exit-status", false, exitPayload)
|
||||
return
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
// No exec request within timeout — proceed with interactive shell.
|
||||
}
|
||||
|
||||
// Build session context.
|
||||
var shellCfg map[string]any
|
||||
if s.cfg.Shell.Shells != nil {
|
||||
@@ -192,26 +321,100 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
Banner: s.cfg.Shell.Banner,
|
||||
FakeUser: s.cfg.Shell.FakeUser,
|
||||
},
|
||||
OnCommand: func(sh string) {
|
||||
s.metrics.CommandsExecuted.WithLabelValues(sh).Inc()
|
||||
},
|
||||
}
|
||||
|
||||
// Wrap channel in RecordingChannel for future byte-level recording.
|
||||
// Wrap channel in RecordingChannel.
|
||||
recorder := shell.NewRecordingChannel(channel)
|
||||
|
||||
// Always record session events for replay.
|
||||
eventRec := shell.NewEventRecorder(sessionID, s.store, s.logger)
|
||||
eventRec.Start(context.Background())
|
||||
defer eventRec.Close()
|
||||
recorder.AddCallback(eventRec.RecordEvent)
|
||||
|
||||
// Set up detection scorer if enabled.
|
||||
var scorer *detection.Scorer
|
||||
var scoreCancel context.CancelFunc
|
||||
if s.cfg.Detection.Enabled {
|
||||
scorer = detection.NewScorer()
|
||||
recorder.AddCallback(func(ts time.Time, direction int, data []byte) {
|
||||
scorer.RecordEvent(ts, direction, data)
|
||||
})
|
||||
|
||||
var scoreCtx context.Context
|
||||
scoreCtx, scoreCancel = context.WithCancel(context.Background())
|
||||
go s.runScoreUpdater(scoreCtx, sessionID, scorer, sessionInfo)
|
||||
}
|
||||
|
||||
if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
|
||||
s.logger.Error("shell error", "err", err, "session_id", sessionID)
|
||||
}
|
||||
|
||||
s.logger.Info("session ended",
|
||||
"remote_addr", conn.RemoteAddr(),
|
||||
"user", conn.User(),
|
||||
"session_id", sessionID,
|
||||
)
|
||||
// Stop score updater and write final score.
|
||||
if scoreCancel != nil {
|
||||
scoreCancel()
|
||||
}
|
||||
if scorer != nil {
|
||||
finalScore := scorer.Score()
|
||||
s.metrics.HumanScore.Observe(finalScore)
|
||||
if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil {
|
||||
s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID)
|
||||
}
|
||||
s.logger.Info("session ended",
|
||||
"remote_addr", conn.RemoteAddr(),
|
||||
"user", conn.User(),
|
||||
"session_id", sessionID,
|
||||
"human_score", finalScore,
|
||||
)
|
||||
} else {
|
||||
s.logger.Info("session ended",
|
||||
"remote_addr", conn.RemoteAddr(),
|
||||
"user", conn.User(),
|
||||
"session_id", sessionID,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// runScoreUpdater periodically computes the human score, writes it to the DB,
|
||||
// and triggers a notification if the threshold is crossed.
|
||||
func (s *Server) runScoreUpdater(ctx context.Context, sessionID string, scorer *detection.Scorer, sessionInfo notify.SessionInfo) {
|
||||
ticker := time.NewTicker(s.cfg.Detection.UpdateIntervalDuration)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
score := scorer.Score()
|
||||
if err := s.store.UpdateHumanScore(ctx, sessionID, score); err != nil {
|
||||
s.logger.Error("failed to update human score", "err", err, "session_id", sessionID)
|
||||
continue
|
||||
}
|
||||
s.logger.Debug("human score updated", "session_id", sessionID, "score", score)
|
||||
|
||||
if score >= s.cfg.Detection.Threshold {
|
||||
info := sessionInfo
|
||||
info.HumanScore = score
|
||||
s.notifier.Notify(ctx, notify.EventHumanDetected, info)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
ip := extractIP(conn.RemoteAddr())
|
||||
d := s.authenticator.Authenticate(ip, conn.User(), string(password))
|
||||
|
||||
if d.Accepted {
|
||||
s.metrics.AuthAttemptsTotal.WithLabelValues("accepted", d.Reason).Inc()
|
||||
} else {
|
||||
s.metrics.AuthAttemptsTotal.WithLabelValues("rejected", d.Reason).Inc()
|
||||
}
|
||||
|
||||
s.logger.Info("auth attempt",
|
||||
"remote_addr", conn.RemoteAddr(),
|
||||
"username", conn.User(),
|
||||
@@ -219,12 +422,22 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
|
||||
"reason", d.Reason,
|
||||
)
|
||||
|
||||
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip); err != nil {
|
||||
country := s.geoip.Lookup(ip)
|
||||
if country != "" {
|
||||
s.metrics.AuthAttemptsByCountry.WithLabelValues(country).Inc()
|
||||
}
|
||||
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip, country); err != nil {
|
||||
s.logger.Error("failed to record login attempt", "err", err)
|
||||
}
|
||||
|
||||
if d.Accepted {
|
||||
return nil, nil
|
||||
var perms *ssh.Permissions
|
||||
if d.Shell != "" {
|
||||
perms = &ssh.Permissions{
|
||||
Extensions: map[string]string{"shell": d.Shell},
|
||||
}
|
||||
}
|
||||
return perms, nil
|
||||
}
|
||||
return nil, fmt.Errorf("rejected")
|
||||
}
|
||||
|
||||
@@ -11,8 +11,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
@@ -108,7 +110,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
AcceptAfter: 2,
|
||||
CredentialTTLDuration: time.Hour,
|
||||
StaticCredentials: []config.Credential{
|
||||
{Username: "root", Password: "toor"},
|
||||
{Username: "root", Password: "toor", Shell: "bash"},
|
||||
},
|
||||
},
|
||||
Shell: config.ShellConfig{
|
||||
@@ -120,7 +122,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := New(cfg, store, logger)
|
||||
srv, err := New(cfg, store, logger, metrics.New("test"))
|
||||
if err != nil {
|
||||
t.Fatalf("creating server: %v", err)
|
||||
}
|
||||
@@ -251,6 +253,137 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Test exec command capture.
|
||||
t.Run("exec_command", func(t *testing.T) {
|
||||
clientCfg := &ssh.ClientConfig{
|
||||
User: "root",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("toor")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", addr, clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SSH dial: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("new session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
// Run a command via exec (no PTY, no shell).
|
||||
if err := session.Run("uname -a"); err != nil {
|
||||
// Run returns an error because the server closes the channel,
|
||||
// but that's expected.
|
||||
_ = err
|
||||
}
|
||||
|
||||
// Give the server a moment to store the command.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify the exec command was captured.
|
||||
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
var foundExec bool
|
||||
for _, s := range sessions {
|
||||
if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
|
||||
foundExec = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundExec {
|
||||
t.Error("expected a session with exec_command='uname -a'")
|
||||
}
|
||||
})
|
||||
|
||||
// Test username route: add username_routes so that "postgres" gets psql shell.
|
||||
t.Run("username_route", func(t *testing.T) {
|
||||
// Reconfigure with username routes.
|
||||
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
|
||||
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
|
||||
|
||||
// Need to get the "postgres" user in via static creds or threshold.
|
||||
// Use static creds for simplicity.
|
||||
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
|
||||
config.Credential{Username: "postgres", Password: "postgres"},
|
||||
)
|
||||
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||
defer func() {
|
||||
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
|
||||
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||
}()
|
||||
|
||||
clientCfg := &ssh.ClientConfig{
|
||||
User: "postgres",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", addr, clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SSH dial: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("new session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
||||
t.Fatalf("request pty: %v", err)
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("stdin pipe: %v", err)
|
||||
}
|
||||
|
||||
var output bytes.Buffer
|
||||
session.Stdout = &output
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatalf("shell: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the psql banner.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Send \q to quit.
|
||||
stdin.Write([]byte(`\q` + "\r"))
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
session.Wait()
|
||||
|
||||
out := output.String()
|
||||
if !strings.Contains(out, "psql") {
|
||||
t.Errorf("output should contain psql banner, got: %s", out)
|
||||
}
|
||||
|
||||
// Verify session was created with shell name "psql".
|
||||
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
var foundPsql bool
|
||||
for _, s := range sessions {
|
||||
if s.ShellName == "psql" && s.Username == "postgres" {
|
||||
foundPsql = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundPsql {
|
||||
t.Error("expected a session with shell_name='psql' for user 'postgres'")
|
||||
}
|
||||
})
|
||||
|
||||
// Test threshold acceptance: after enough failed dials, a subsequent
|
||||
// dial with the same credentials should succeed via threshold or
|
||||
// remembered credential.
|
||||
|
||||
117
internal/shell/adventure/adventure.go
Normal file
117
internal/shell/adventure/adventure.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package adventure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
// AdventureShell implements a Zork-style text adventure set in a dungeon/data center.
|
||||
type AdventureShell struct{}
|
||||
|
||||
// NewAdventureShell returns a new AdventureShell instance.
|
||||
func NewAdventureShell() *AdventureShell {
|
||||
return &AdventureShell{}
|
||||
}
|
||||
|
||||
func (a *AdventureShell) Name() string { return "adventure" }
|
||||
func (a *AdventureShell) Description() string { return "Zork-style text adventure dungeon crawler" }
|
||||
|
||||
func (a *AdventureShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
dungeonName := configString(sess.ShellConfig, "dungeon_name", "THE OUBLIETTE")
|
||||
game := newGame()
|
||||
|
||||
// Print banner and initial room.
|
||||
banner := strings.ReplaceAll(adventureBanner(dungeonName), "\n", "\r\n")
|
||||
fmt.Fprint(rw, banner)
|
||||
|
||||
// Show starting room.
|
||||
startDesc := game.describeRoom(game.rooms[game.currentRoom])
|
||||
startDesc = strings.ReplaceAll(startDesc, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n\r\n", startDesc)
|
||||
|
||||
for {
|
||||
if _, err := fmt.Fprint(rw, "> "); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
fmt.Fprint(rw, "\r\n")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
result := game.dispatch(trimmed)
|
||||
|
||||
var output string
|
||||
if result.output != "" {
|
||||
output = result.output
|
||||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
// Log command and output to store.
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("adventure")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func adventureBanner(dungeonName string) string {
|
||||
return fmt.Sprintf(`
|
||||
___ _ _ ____ _ ___ _____ _____ _____
|
||||
/ _ \| | | | __ )| | |_ _| ____|_ _|_ _| ___
|
||||
| | | | | | | _ \| | | || _| | | | | / _ \
|
||||
| |_| | |_| | |_) | |___ | || |___ | | | | | __/
|
||||
\___/ \___/|____/|_____|___|_____| |_| |_| \___|
|
||||
|
||||
Welcome to %s.
|
||||
|
||||
You wake up in the dark. The air is cold and hums with electricity.
|
||||
This is a place where things are put to be forgotten.
|
||||
|
||||
Type 'help' for commands. Type 'look' to examine your surroundings.
|
||||
|
||||
`, dungeonName)
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
383
internal/shell/adventure/adventure_test.go
Normal file
383
internal/shell/adventure/adventure_test.go
Normal file
@@ -0,0 +1,383 @@
|
||||
package adventure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (r *rwCloser) Close() error { return nil }
|
||||
|
||||
func runShell(t *testing.T, commands string) string {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "root",
|
||||
Store: store,
|
||||
CommonConfig: shell.ShellCommonConfig{
|
||||
Hostname: "testhost",
|
||||
},
|
||||
}
|
||||
|
||||
rw := &rwCloser{
|
||||
Reader: bytes.NewBufferString(commands),
|
||||
Writer: &bytes.Buffer{},
|
||||
}
|
||||
|
||||
sh := NewAdventureShell()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := sh.Handle(ctx, sess, rw); err != nil {
|
||||
t.Fatalf("Handle: %v", err)
|
||||
}
|
||||
|
||||
return rw.Writer.(*bytes.Buffer).String()
|
||||
}
|
||||
|
||||
func TestAdventureShellName(t *testing.T) {
|
||||
sh := NewAdventureShell()
|
||||
if sh.Name() != "adventure" {
|
||||
t.Errorf("Name() = %q, want %q", sh.Name(), "adventure")
|
||||
}
|
||||
if sh.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBanner(t *testing.T) {
|
||||
output := runShell(t, "quit\r")
|
||||
if !strings.Contains(output, "OUBLIETTE") {
|
||||
t.Error("output should contain OUBLIETTE in banner")
|
||||
}
|
||||
if !strings.Contains(output, ">") {
|
||||
t.Error("output should contain > prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartingRoom(t *testing.T) {
|
||||
output := runShell(t, "quit\r")
|
||||
if !strings.Contains(output, "The Oubliette") {
|
||||
t.Error("should start in The Oubliette")
|
||||
}
|
||||
if !strings.Contains(output, "narrow stone chamber") {
|
||||
t.Error("should show starting room description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelpCommand(t *testing.T) {
|
||||
output := runShell(t, "help\rquit\r")
|
||||
for _, keyword := range []string{"look", "go", "take", "drop", "use", "inventory", "help", "quit"} {
|
||||
if !strings.Contains(output, keyword) {
|
||||
t.Errorf("help output should mention %q", keyword)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookCommand(t *testing.T) {
|
||||
output := runShell(t, "look\rquit\r")
|
||||
if !strings.Contains(output, "The Oubliette") {
|
||||
t.Error("look should show current room")
|
||||
}
|
||||
if !strings.Contains(output, "Exits:") {
|
||||
t.Error("look should show exits")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMovement(t *testing.T) {
|
||||
output := runShell(t, "go east\rquit\r")
|
||||
if !strings.Contains(output, "Stone Corridor") {
|
||||
t.Error("going east from oubliette should reach Stone Corridor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBareDirection(t *testing.T) {
|
||||
output := runShell(t, "e\rquit\r")
|
||||
if !strings.Contains(output, "Stone Corridor") {
|
||||
t.Error("bare 'e' should move east to Stone Corridor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectionAliases(t *testing.T) {
|
||||
output := runShell(t, "east\rquit\r")
|
||||
if !strings.Contains(output, "Stone Corridor") {
|
||||
t.Error("'east' should move to Stone Corridor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidDirection(t *testing.T) {
|
||||
output := runShell(t, "go north\rquit\r")
|
||||
if !strings.Contains(output, "can't go") {
|
||||
t.Error("should say you can't go that direction")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTakeItem(t *testing.T) {
|
||||
// Move to corridor where flashlight is.
|
||||
output := runShell(t, "e\rtake flashlight\rinventory\rquit\r")
|
||||
if !strings.Contains(output, "Taken") {
|
||||
t.Error("should confirm taking item")
|
||||
}
|
||||
if !strings.Contains(output, "flashlight") {
|
||||
t.Error("inventory should show flashlight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDropItem(t *testing.T) {
|
||||
output := runShell(t, "e\rtake flashlight\rdrop flashlight\rinventory\rquit\r")
|
||||
if !strings.Contains(output, "Dropped") {
|
||||
t.Error("should confirm dropping item")
|
||||
}
|
||||
if !strings.Contains(output, "not carrying anything") {
|
||||
t.Error("inventory should be empty after dropping")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyInventory(t *testing.T) {
|
||||
output := runShell(t, "inventory\rquit\r")
|
||||
if !strings.Contains(output, "not carrying anything") {
|
||||
t.Error("should say not carrying anything")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExamineItem(t *testing.T) {
|
||||
output := runShell(t, "e\rexamine flashlight\rquit\r")
|
||||
if !strings.Contains(output, "batteries") {
|
||||
t.Error("examining flashlight should describe it")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDarkRoom(t *testing.T) {
|
||||
// Go to the pit without flashlight.
|
||||
output := runShell(t, "down\rquit\r")
|
||||
if !strings.Contains(output, "darkness") {
|
||||
t.Error("pit should be dark without flashlight")
|
||||
}
|
||||
// Should NOT show the lit description.
|
||||
if strings.Contains(output, "skeleton") {
|
||||
t.Error("should not see skeleton without flashlight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDarkRoomWithFlashlight(t *testing.T) {
|
||||
// Get flashlight first, then go to pit.
|
||||
output := runShell(t, "e\rtake flashlight\rw\rdown\rquit\r")
|
||||
if !strings.Contains(output, "skeleton") {
|
||||
t.Error("should see skeleton with flashlight in pit")
|
||||
}
|
||||
if !strings.Contains(output, "rusty key") {
|
||||
t.Error("should see rusty key with flashlight in pit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDarkRoomCantTake(t *testing.T) {
|
||||
output := runShell(t, "down\rtake rusty_key\rquit\r")
|
||||
if !strings.Contains(output, "too dark") {
|
||||
t.Error("should not be able to take items in dark room")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockedDoor(t *testing.T) {
|
||||
// Navigate to archive without keycard.
|
||||
output := runShell(t, "e\rs\rs\re\rquit\r")
|
||||
if !strings.Contains(output, "keycard") || !strings.Contains(output, "red") {
|
||||
t.Error("should mention keycard reader when trying locked door")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlockWithKeycard(t *testing.T) {
|
||||
// Get keycard from server room, navigate to archive, use keycard.
|
||||
output := runShell(t, "e\re\rtake keycard\rw\rs\rs\ruse keycard\re\rquit\r")
|
||||
if !strings.Contains(output, "green") || !strings.Contains(output, "clicks open") {
|
||||
t.Error("should show unlock message")
|
||||
}
|
||||
if !strings.Contains(output, "Control Room") {
|
||||
t.Error("should be able to enter control room after unlocking")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRustyKeyPuzzle(t *testing.T) {
|
||||
// Get flashlight, get rusty key from pit, go to generator room, use key.
|
||||
output := runShell(t, "e\rtake flashlight\rw\rdown\rtake rusty_key\rup\re\re\rs\re\ruse rusty_key\rquit\r")
|
||||
if !strings.Contains(output, "maintenance panel") || !strings.Contains(output, "logbook") {
|
||||
t.Error("should show maintenance log when using rusty key in generator room")
|
||||
}
|
||||
if !strings.Contains(output, "Dunwich") {
|
||||
t.Error("logbook should mention Dunwich")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitEndsGame(t *testing.T) {
|
||||
// Full path to exit: get keycard, navigate to archive, unlock, go to control room, east to exit.
|
||||
output := runShell(t, "e\re\rtake keycard\rw\rs\rs\ruse keycard\re\re\r")
|
||||
if !strings.Contains(output, "SECURITY AUDIT") {
|
||||
t.Error("exit should show the ending message")
|
||||
}
|
||||
if !strings.Contains(output, "SESSION HAS BEEN LOGGED") {
|
||||
t.Error("exit should mention session logging")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownCommand(t *testing.T) {
|
||||
output := runShell(t, "xyzzy\rquit\r")
|
||||
if !strings.Contains(output, "don't understand") {
|
||||
t.Error("should show error for unknown command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuitCommand(t *testing.T) {
|
||||
output := runShell(t, "quit\r")
|
||||
if !strings.Contains(output, "terminated") {
|
||||
t.Error("quit should show termination message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitAlias(t *testing.T) {
|
||||
output := runShell(t, "exit\r")
|
||||
if !strings.Contains(output, "terminated") {
|
||||
t.Error("exit alias should work like quit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVentilationShaft(t *testing.T) {
|
||||
output := runShell(t, "e\rup\rquit\r")
|
||||
if !strings.Contains(output, "Ventilation Shaft") {
|
||||
t.Error("should reach ventilation shaft from corridor")
|
||||
}
|
||||
if !strings.Contains(output, "note") {
|
||||
t.Error("should see note in ventilation shaft")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNote(t *testing.T) {
|
||||
output := runShell(t, "e\rup\rtake note\rexamine note\rquit\r")
|
||||
if !strings.Contains(output, "GET OUT") {
|
||||
t.Error("note should contain warning text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseItemWrongRoom(t *testing.T) {
|
||||
// Get keycard from server room, try to use in wrong room.
|
||||
output := runShell(t, "e\re\rtake keycard\ruse keycard\rquit\r")
|
||||
if !strings.Contains(output, "nothing to use") {
|
||||
t.Error("should say nothing to use keycard on in wrong room")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEthernetCable(t *testing.T) {
|
||||
output := runShell(t, "e\rs\rtake ethernet_cable\ruse ethernet_cable\rquit\r")
|
||||
if !strings.Contains(output, "chewed") {
|
||||
t.Error("using ethernet cable should mention chewed ends")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLogs(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "root",
|
||||
Store: store,
|
||||
CommonConfig: shell.ShellCommonConfig{
|
||||
Hostname: "testhost",
|
||||
},
|
||||
}
|
||||
|
||||
rw := &rwCloser{
|
||||
Reader: bytes.NewBufferString("help\rquit\r"),
|
||||
Writer: &bytes.Buffer{},
|
||||
}
|
||||
|
||||
sh := NewAdventureShell()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
sh.Handle(ctx, sess, rw)
|
||||
|
||||
if len(store.SessionLogs) < 2 {
|
||||
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserBasics(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
verb string
|
||||
object string
|
||||
}{
|
||||
{"look", "look", ""},
|
||||
{"l", "look", ""},
|
||||
{"examine flashlight", "look", "flashlight"},
|
||||
{"go north", "go", "north"},
|
||||
{"n", "go", "north"},
|
||||
{"south", "go", "south"},
|
||||
{"take the key", "take", "key"},
|
||||
{"get a flashlight", "take", "flashlight"},
|
||||
{"drop rusty key", "drop", "rusty key"},
|
||||
{"use keycard", "use", "keycard"},
|
||||
{"inventory", "inventory", ""},
|
||||
{"i", "inventory", ""},
|
||||
{"inv", "inventory", ""},
|
||||
{"help", "help", ""},
|
||||
{"?", "help", ""},
|
||||
{"quit", "quit", ""},
|
||||
{"exit", "quit", ""},
|
||||
{"q", "quit", ""},
|
||||
{"LOOK", "look", ""},
|
||||
{" go east ", "go", "east"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
cmd := parseCommand(tt.input)
|
||||
if cmd.verb != tt.verb {
|
||||
t.Errorf("parseCommand(%q).verb = %q, want %q", tt.input, cmd.verb, tt.verb)
|
||||
}
|
||||
if cmd.object != tt.object {
|
||||
t.Errorf("parseCommand(%q).object = %q, want %q", tt.input, cmd.object, tt.object)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserEmpty(t *testing.T) {
|
||||
cmd := parseCommand("")
|
||||
if cmd.verb != "" {
|
||||
t.Errorf("empty input should give empty verb, got %q", cmd.verb)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParserArticleStripping(t *testing.T) {
|
||||
cmd := parseCommand("take the an a flashlight")
|
||||
if cmd.verb != "take" || cmd.object != "flashlight" {
|
||||
t.Errorf("articles should be stripped, got verb=%q object=%q", cmd.verb, cmd.object)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{"dungeon_name": "MY DUNGEON"}
|
||||
if got := configString(cfg, "dungeon_name", "DEFAULT"); got != "MY DUNGEON" {
|
||||
t.Errorf("configString() = %q, want %q", got, "MY DUNGEON")
|
||||
}
|
||||
if got := configString(cfg, "missing", "DEFAULT"); got != "DEFAULT" {
|
||||
t.Errorf("configString() for missing key = %q, want %q", got, "DEFAULT")
|
||||
}
|
||||
if got := configString(nil, "key", "DEFAULT"); got != "DEFAULT" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "DEFAULT")
|
||||
}
|
||||
}
|
||||
358
internal/shell/adventure/game.go
Normal file
358
internal/shell/adventure/game.go
Normal file
@@ -0,0 +1,358 @@
|
||||
package adventure
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type gameState struct {
|
||||
currentRoom string
|
||||
inventory []string
|
||||
rooms map[string]*room
|
||||
items map[string]*item
|
||||
flags map[string]bool
|
||||
turns int
|
||||
}
|
||||
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
func newGame() *gameState {
|
||||
rooms, items := newWorld()
|
||||
return &gameState{
|
||||
currentRoom: "oubliette",
|
||||
rooms: rooms,
|
||||
items: items,
|
||||
flags: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gameState) dispatch(input string) commandResult {
|
||||
cmd := parseCommand(input)
|
||||
if cmd.verb == "" {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
g.turns++
|
||||
|
||||
switch cmd.verb {
|
||||
case "look":
|
||||
return g.cmdLook(cmd.object)
|
||||
case "go":
|
||||
return g.cmdGo(cmd.object)
|
||||
case "take":
|
||||
return g.cmdTake(cmd.object)
|
||||
case "drop":
|
||||
return g.cmdDrop(cmd.object)
|
||||
case "use":
|
||||
return g.cmdUse(cmd.object)
|
||||
case "inventory":
|
||||
return g.cmdInventory()
|
||||
case "help":
|
||||
return g.cmdHelp()
|
||||
case "quit":
|
||||
return commandResult{output: "The darkness closes in. Session terminated.", exit: true}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("I don't understand '%s'. Type 'help' for available commands.", input)}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdHelp() commandResult {
|
||||
help := `Available commands:
|
||||
look / examine [item] - Look around or examine something
|
||||
go <direction> - Move (north, south, east, west, up, down)
|
||||
take <item> - Pick up an item
|
||||
drop <item> - Drop an item
|
||||
use <item> - Use an item
|
||||
inventory - Check what you're carrying
|
||||
help - Show this help
|
||||
quit / exit - End session
|
||||
|
||||
You can also just type a direction (n, s, e, w, u, d) to move.`
|
||||
return commandResult{output: help}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdLook(object string) commandResult {
|
||||
r := g.rooms[g.currentRoom]
|
||||
|
||||
// Look at a specific item.
|
||||
if object != "" {
|
||||
return g.examineItem(object)
|
||||
}
|
||||
|
||||
// Look at the room.
|
||||
return commandResult{output: g.describeRoom(r)}
|
||||
}
|
||||
|
||||
func (g *gameState) describeRoom(r *room) string {
|
||||
var b strings.Builder
|
||||
|
||||
fmt.Fprintf(&b, "== %s ==\n", r.name)
|
||||
|
||||
// Dark room check.
|
||||
if r.darkDesc != "" && !g.hasItem("flashlight") {
|
||||
b.WriteString(r.darkDesc)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
b.WriteString(r.description)
|
||||
|
||||
// List visible items.
|
||||
visibleItems := g.roomItems(r)
|
||||
if len(visibleItems) > 0 {
|
||||
b.WriteString("\n\nYou can see: ")
|
||||
names := make([]string, len(visibleItems))
|
||||
for i, id := range visibleItems {
|
||||
names[i] = g.items[id].name
|
||||
}
|
||||
b.WriteString(strings.Join(names, ", "))
|
||||
}
|
||||
|
||||
// List exits.
|
||||
if len(r.exits) > 0 {
|
||||
b.WriteString("\n\nExits: ")
|
||||
dirs := make([]string, 0, len(r.exits))
|
||||
for dir := range r.exits {
|
||||
dirs = append(dirs, dir)
|
||||
}
|
||||
b.WriteString(strings.Join(dirs, ", "))
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (g *gameState) roomItems(r *room) []string {
|
||||
// In dark rooms without flashlight, can't see items.
|
||||
if r.darkDesc != "" && !g.hasItem("flashlight") {
|
||||
return nil
|
||||
}
|
||||
return r.items
|
||||
}
|
||||
|
||||
func (g *gameState) examineItem(name string) commandResult {
|
||||
id := g.resolveItem(name)
|
||||
if id == "" {
|
||||
return commandResult{output: fmt.Sprintf("You don't see '%s' here.", name)}
|
||||
}
|
||||
return commandResult{output: g.items[id].description}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdGo(direction string) commandResult {
|
||||
if direction == "" {
|
||||
return commandResult{output: "Go where? Try a direction: north, south, east, west, up, down."}
|
||||
}
|
||||
|
||||
r := g.rooms[g.currentRoom]
|
||||
destID, ok := r.exits[direction]
|
||||
if !ok {
|
||||
return commandResult{output: fmt.Sprintf("You can't go %s from here.", direction)}
|
||||
}
|
||||
|
||||
// Check locked doors.
|
||||
if r.locked != nil {
|
||||
if flag, locked := r.locked[direction]; locked && !g.flags[flag] {
|
||||
return g.lockedMessage(direction)
|
||||
}
|
||||
}
|
||||
|
||||
g.currentRoom = destID
|
||||
dest := g.rooms[destID]
|
||||
|
||||
// Entering the exit room ends the game.
|
||||
if destID == "exit" {
|
||||
return commandResult{output: g.describeRoom(dest), exit: true}
|
||||
}
|
||||
|
||||
return commandResult{output: g.describeRoom(dest)}
|
||||
}
|
||||
|
||||
func (g *gameState) lockedMessage(direction string) commandResult {
|
||||
if direction == "east" && g.currentRoom == "archive" {
|
||||
return commandResult{output: "The steel door won't budge. The keycard reader blinks red, waiting."}
|
||||
}
|
||||
return commandResult{output: "The way is locked."}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdTake(name string) commandResult {
|
||||
if name == "" {
|
||||
return commandResult{output: "Take what?"}
|
||||
}
|
||||
|
||||
r := g.rooms[g.currentRoom]
|
||||
|
||||
// Can't take items in dark rooms.
|
||||
if r.darkDesc != "" && !g.hasItem("flashlight") {
|
||||
return commandResult{output: "It's too dark to find anything."}
|
||||
}
|
||||
|
||||
id := g.resolveRoomItem(name)
|
||||
if id == "" {
|
||||
return commandResult{output: fmt.Sprintf("You don't see '%s' here.", name)}
|
||||
}
|
||||
|
||||
it := g.items[id]
|
||||
if !it.takeable {
|
||||
return commandResult{output: fmt.Sprintf("You can't take the %s.", it.name)}
|
||||
}
|
||||
|
||||
// Remove from room, add to inventory.
|
||||
g.removeRoomItem(r, id)
|
||||
g.inventory = append(g.inventory, id)
|
||||
return commandResult{output: fmt.Sprintf("Taken: %s", it.name)}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdDrop(name string) commandResult {
|
||||
if name == "" {
|
||||
return commandResult{output: "Drop what?"}
|
||||
}
|
||||
|
||||
id := g.resolveInventoryItem(name)
|
||||
if id == "" {
|
||||
return commandResult{output: fmt.Sprintf("You're not carrying '%s'.", name)}
|
||||
}
|
||||
|
||||
// Remove from inventory, add to room.
|
||||
g.removeInventoryItem(id)
|
||||
r := g.rooms[g.currentRoom]
|
||||
r.items = append(r.items, id)
|
||||
return commandResult{output: fmt.Sprintf("Dropped: %s", g.items[id].name)}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdUse(name string) commandResult {
|
||||
if name == "" {
|
||||
return commandResult{output: "Use what?"}
|
||||
}
|
||||
|
||||
id := g.resolveInventoryItem(name)
|
||||
if id == "" {
|
||||
return commandResult{output: fmt.Sprintf("You're not carrying '%s'.", name)}
|
||||
}
|
||||
|
||||
switch id {
|
||||
case "keycard":
|
||||
return g.useKeycard()
|
||||
case "rusty_key":
|
||||
return g.useRustyKey()
|
||||
case "flashlight":
|
||||
return commandResult{output: "The flashlight is already on. Its beam cuts through the darkness."}
|
||||
case "ethernet_cable":
|
||||
return commandResult{output: "You wave the cable around hopefully. Nothing happens. Both ends are chewed through anyway."}
|
||||
case "floppy_disk":
|
||||
return commandResult{output: "You don't have anything to put it in. Floppy drives went extinct decades ago."}
|
||||
case "note":
|
||||
return commandResult{output: g.items[id].description}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("You can't figure out how to use the %s here.", g.items[id].name)}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gameState) useKeycard() commandResult {
|
||||
if g.currentRoom != "archive" {
|
||||
return commandResult{output: "There's nothing to use the keycard on here."}
|
||||
}
|
||||
if g.flags["archive_unlocked"] {
|
||||
return commandResult{output: "You already unlocked that door."}
|
||||
}
|
||||
g.flags["archive_unlocked"] = true
|
||||
return commandResult{output: "You swipe the keycard. The reader flashes green and the steel door clicks open with a heavy thunk."}
|
||||
}
|
||||
|
||||
func (g *gameState) useRustyKey() commandResult {
|
||||
if g.currentRoom != "generator" {
|
||||
return commandResult{output: "There's nothing to use the rusty key on here."}
|
||||
}
|
||||
if g.flags["panel_opened"] {
|
||||
return commandResult{output: "The maintenance panel is already open."}
|
||||
}
|
||||
g.flags["panel_opened"] = true
|
||||
return commandResult{output: `The key fits. The maintenance panel swings open with a screech, revealing a logbook:
|
||||
|
||||
MAINTENANCE LOG - GENERATOR B
|
||||
Last service: 2003-11-15
|
||||
Technician: J. Dunwich
|
||||
|
||||
Notes: "Generator A offline permanently. B running on fumes.
|
||||
Fuel delivery canceled - 'facility decommissioned' per management.
|
||||
But the servers are still running. WHO IS USING THEM?
|
||||
Filed ticket #4,271. No response. As usual."
|
||||
|
||||
The entries stop after that date.`}
|
||||
}
|
||||
|
||||
func (g *gameState) cmdInventory() commandResult {
|
||||
if len(g.inventory) == 0 {
|
||||
return commandResult{output: "You're not carrying anything."}
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("You are carrying:\n")
|
||||
for _, id := range g.inventory {
|
||||
fmt.Fprintf(&b, " - %s\n", g.items[id].name)
|
||||
}
|
||||
return commandResult{output: strings.TrimRight(b.String(), "\n")}
|
||||
}
|
||||
|
||||
// hasItem checks if the player has an item in inventory.
|
||||
func (g *gameState) hasItem(id string) bool {
|
||||
return slices.Contains(g.inventory, id)
|
||||
}
|
||||
|
||||
// resolveItem finds an item by name in both room and inventory.
|
||||
func (g *gameState) resolveItem(name string) string {
|
||||
if id := g.resolveInventoryItem(name); id != "" {
|
||||
return id
|
||||
}
|
||||
return g.resolveRoomItem(name)
|
||||
}
|
||||
|
||||
// resolveRoomItem finds an item in the current room by partial name match.
|
||||
func (g *gameState) resolveRoomItem(name string) string {
|
||||
r := g.rooms[g.currentRoom]
|
||||
name = strings.ToLower(name)
|
||||
for _, id := range r.items {
|
||||
if matchesItem(id, g.items[id].name, name) {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveInventoryItem finds an item in inventory by partial name match.
|
||||
func (g *gameState) resolveInventoryItem(name string) string {
|
||||
name = strings.ToLower(name)
|
||||
for _, id := range g.inventory {
|
||||
if matchesItem(id, g.items[id].name, name) {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// matchesItem checks if a search term matches an item's ID or display name.
|
||||
func matchesItem(id, displayName, search string) bool {
|
||||
return id == search ||
|
||||
strings.ToLower(displayName) == search ||
|
||||
strings.Contains(id, search) ||
|
||||
strings.Contains(strings.ToLower(displayName), search)
|
||||
}
|
||||
|
||||
func (g *gameState) removeRoomItem(r *room, id string) {
|
||||
for i, itemID := range r.items {
|
||||
if itemID == id {
|
||||
r.items = append(r.items[:i], r.items[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gameState) removeInventoryItem(id string) {
|
||||
for i, invID := range g.inventory {
|
||||
if invID == id {
|
||||
g.inventory = append(g.inventory[:i], g.inventory[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
106
internal/shell/adventure/parser.go
Normal file
106
internal/shell/adventure/parser.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package adventure
|
||||
|
||||
import "strings"
|
||||
|
||||
type parsedCommand struct {
|
||||
verb string
|
||||
object string
|
||||
}
|
||||
|
||||
// directionAliases maps shorthand and full direction names to canonical forms.
|
||||
var directionAliases = map[string]string{
|
||||
"n": "north",
|
||||
"s": "south",
|
||||
"e": "east",
|
||||
"w": "west",
|
||||
"u": "up",
|
||||
"d": "down",
|
||||
"north": "north",
|
||||
"south": "south",
|
||||
"east": "east",
|
||||
"west": "west",
|
||||
"up": "up",
|
||||
"down": "down",
|
||||
}
|
||||
|
||||
// verbAliases maps aliases to canonical verbs.
|
||||
var verbAliases = map[string]string{
|
||||
"look": "look",
|
||||
"l": "look",
|
||||
"examine": "look",
|
||||
"inspect": "look",
|
||||
"x": "look",
|
||||
"go": "go",
|
||||
"move": "go",
|
||||
"walk": "go",
|
||||
"take": "take",
|
||||
"get": "take",
|
||||
"grab": "take",
|
||||
"pick": "take",
|
||||
"drop": "drop",
|
||||
"put": "drop",
|
||||
"use": "use",
|
||||
"apply": "use",
|
||||
"inventory": "inventory",
|
||||
"inv": "inventory",
|
||||
"i": "inventory",
|
||||
"help": "help",
|
||||
"?": "help",
|
||||
"quit": "quit",
|
||||
"exit": "quit",
|
||||
"logout": "quit",
|
||||
"q": "quit",
|
||||
}
|
||||
|
||||
// articles are stripped from input.
|
||||
var articles = map[string]bool{
|
||||
"the": true,
|
||||
"a": true,
|
||||
"an": true,
|
||||
}
|
||||
|
||||
// parseCommand parses raw input into a verb and object.
|
||||
func parseCommand(input string) parsedCommand {
|
||||
input = strings.TrimSpace(strings.ToLower(input))
|
||||
if input == "" {
|
||||
return parsedCommand{}
|
||||
}
|
||||
|
||||
words := strings.Fields(input)
|
||||
|
||||
// Strip articles.
|
||||
var filtered []string
|
||||
for _, w := range words {
|
||||
if !articles[w] {
|
||||
filtered = append(filtered, w)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return parsedCommand{}
|
||||
}
|
||||
|
||||
first := filtered[0]
|
||||
|
||||
// Bare direction → go <direction>.
|
||||
if dir, ok := directionAliases[first]; ok {
|
||||
return parsedCommand{verb: "go", object: dir}
|
||||
}
|
||||
|
||||
// Known verb alias.
|
||||
if verb, ok := verbAliases[first]; ok {
|
||||
object := ""
|
||||
if len(filtered) > 1 {
|
||||
object = strings.Join(filtered[1:], " ")
|
||||
// Resolve direction alias in object for "go north" etc.
|
||||
if verb == "go" {
|
||||
if dir, ok := directionAliases[object]; ok {
|
||||
object = dir
|
||||
}
|
||||
}
|
||||
}
|
||||
return parsedCommand{verb: verb, object: object}
|
||||
}
|
||||
|
||||
// Unknown verb — return as-is so game can give an error.
|
||||
return parsedCommand{verb: first, object: strings.Join(filtered[1:], " ")}
|
||||
}
|
||||
211
internal/shell/adventure/world.go
Normal file
211
internal/shell/adventure/world.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package adventure
|
||||
|
||||
// room represents a location in the game world.
|
||||
type room struct {
|
||||
name string
|
||||
description string
|
||||
darkDesc string // shown when room is dark (empty = not a dark room)
|
||||
exits map[string]string // direction → room ID
|
||||
items []string // item IDs present in the room
|
||||
locked map[string]string // direction → flag required to unlock
|
||||
}
|
||||
|
||||
// item represents an object that can be picked up and used.
|
||||
type item struct {
|
||||
name string
|
||||
description string
|
||||
takeable bool
|
||||
}
|
||||
|
||||
func newWorld() (map[string]*room, map[string]*item) {
|
||||
rooms := map[string]*room{
|
||||
"oubliette": {
|
||||
name: "The Oubliette",
|
||||
description: `You are in a narrow stone chamber. The walls are damp and slick with condensation.
|
||||
Far above, an iron grate lets in a faint green glow — not daylight, but the steady
|
||||
pulse of status LEDs. A frayed ethernet cable hangs from the grate like a vine.
|
||||
A passage leads east into darkness. Stone steps spiral downward.`,
|
||||
exits: map[string]string{
|
||||
"east": "corridor",
|
||||
"down": "pit",
|
||||
},
|
||||
},
|
||||
"corridor": {
|
||||
name: "Stone Corridor",
|
||||
description: `A long corridor carved from living rock. The walls transition from rough-hewn
|
||||
stone to poured concrete as you look east. Fluorescent tubes flicker overhead,
|
||||
half of them dead. Cable trays run along the ceiling, sagging under the weight
|
||||
of bundled Cat5. A draft comes from a ventilation shaft above.`,
|
||||
exits: map[string]string{
|
||||
"west": "oubliette",
|
||||
"east": "server_room",
|
||||
"south": "cable_crypt",
|
||||
"up": "ventilation",
|
||||
},
|
||||
items: []string{"flashlight"},
|
||||
},
|
||||
"ventilation": {
|
||||
name: "Ventilation Shaft",
|
||||
description: `You've squeezed into a narrow ventilation shaft. The aluminum walls vibrate with
|
||||
the hum of distant fans. It's barely wide enough to turn around in. Dust and
|
||||
cobwebs coat everything. Someone has scratched tally marks into the metal — you
|
||||
stop counting at forty-seven.`,
|
||||
exits: map[string]string{
|
||||
"down": "corridor",
|
||||
},
|
||||
items: []string{"note"},
|
||||
},
|
||||
"server_room": {
|
||||
name: "Server Room",
|
||||
description: `Rows of black server racks stretch into the gloom, their LEDs blinking in
|
||||
patterns that almost seem deliberate. The air is frigid and filled with the
|
||||
white noise of a thousand fans. A yellowed label on the nearest rack reads:
|
||||
"PRODUCTION - DO NOT TOUCH". Someone has added in marker: "OR ELSE".
|
||||
A laminated keycard sits on top of a powered-down blade server.`,
|
||||
exits: map[string]string{
|
||||
"west": "corridor",
|
||||
"south": "cold_storage",
|
||||
},
|
||||
items: []string{"keycard"},
|
||||
},
|
||||
"pit": {
|
||||
name: "The Pit",
|
||||
darkDesc: `You are in absolute darkness. You can hear water dripping somewhere far below
|
||||
and the faint hum of electronics. The air smells of rust and ozone. You can't
|
||||
see a thing without a light source. The stairs lead back up.`,
|
||||
description: `Your flashlight reveals a deep shaft — part medieval oubliette, part cable run.
|
||||
Rusty chains hang from iron rings set into the walls, intertwined with bundles
|
||||
of fiber optic cable that glow faintly orange. At the bottom, a skeleton in a
|
||||
lab coat slumps against the wall, still wearing an ID badge. A rusty key glints
|
||||
on a hook near the skeleton's hand.`,
|
||||
exits: map[string]string{
|
||||
"up": "oubliette",
|
||||
},
|
||||
items: []string{"rusty_key"},
|
||||
},
|
||||
"cable_crypt": {
|
||||
name: "Cable Crypt",
|
||||
description: `This vaulted chamber was clearly something else before the cables arrived.
|
||||
Stone sarcophagi line the walls, but their lids have been removed and they're
|
||||
now stuffed full of tangled ethernet runs and power strips. A faded sign reads
|
||||
"STRUCTURED CABLING" — someone has crossed out "STRUCTURED" and written
|
||||
"CHAOTIC" above it. The air smells of old stone and warm plastic.`,
|
||||
exits: map[string]string{
|
||||
"north": "corridor",
|
||||
"south": "archive",
|
||||
},
|
||||
items: []string{"ethernet_cable"},
|
||||
},
|
||||
"cold_storage": {
|
||||
name: "Cold Storage",
|
||||
description: `A cavernous room kept at near-freezing temperatures. Frost coats the walls.
|
||||
Rows of old tape drives and disk platters are stacked on industrial shelving,
|
||||
their labels faded beyond reading. A humming cryo-unit in the corner has a
|
||||
blinking amber light. A hand-written sign says "BACKUP STORAGE - CRITICAL".`,
|
||||
exits: map[string]string{
|
||||
"north": "server_room",
|
||||
"east": "generator",
|
||||
},
|
||||
items: []string{"floppy_disk"},
|
||||
},
|
||||
"archive": {
|
||||
name: "The Archive",
|
||||
description: `Floor-to-ceiling shelves hold thousands of manila folders, binders, and
|
||||
three-ring notebooks. The organization system, if there ever was one, has
|
||||
long since collapsed into entropy. A thick layer of dust covers everything.
|
||||
A heavy steel door to the east has an electronic keycard reader with a
|
||||
steady red light. The cable crypt lies back to the north.`,
|
||||
exits: map[string]string{
|
||||
"north": "cable_crypt",
|
||||
"east": "control_room",
|
||||
},
|
||||
locked: map[string]string{
|
||||
"east": "archive_unlocked",
|
||||
},
|
||||
},
|
||||
"generator": {
|
||||
name: "Generator Room",
|
||||
description: `Two massive diesel generators squat in this room like sleeping beasts. One is
|
||||
clearly dead — corrosion has eaten through the fuel lines. The other hums at
|
||||
a low idle, keeping the facility on life support. A maintenance panel on the
|
||||
wall is secured with an old-fashioned keyhole.`,
|
||||
exits: map[string]string{
|
||||
"west": "cold_storage",
|
||||
},
|
||||
},
|
||||
"control_room": {
|
||||
name: "Control Room",
|
||||
description: `Banks of CRT monitors cast a blue glow across the room. Most show static, but
|
||||
one displays a facility map with pulsing dots labeled "ACTIVE SESSIONS". You
|
||||
count the dots — there are more than there should be. A desk covered in coffee
|
||||
rings holds a battered keyboard. The main display reads:
|
||||
|
||||
FACILITY STATUS: NOMINAL
|
||||
CONTAINMENT: ACTIVE
|
||||
SUBJECTS: ███
|
||||
|
||||
A heavy blast door leads east. Above it, a faded exit sign flickers.`,
|
||||
exits: map[string]string{
|
||||
"west": "archive",
|
||||
"east": "exit",
|
||||
},
|
||||
},
|
||||
"exit": {
|
||||
name: "The Exit",
|
||||
description: `The blast door groans open to reveal... a corridor. Not the freedom you
|
||||
expected, but another corridor — identical to the one you started in. The
|
||||
same flickering fluorescents. The same sagging cable trays. The same damp
|
||||
stone walls.
|
||||
|
||||
As you step through, the door slams shut behind you. A speaker crackles:
|
||||
|
||||
"THANK YOU FOR PARTICIPATING IN TODAY'S SECURITY AUDIT.
|
||||
YOUR SESSION HAS BEEN LOGGED.
|
||||
HAVE A NICE DAY."
|
||||
|
||||
The lights go out.`,
|
||||
},
|
||||
}
|
||||
|
||||
items := map[string]*item{
|
||||
"flashlight": {
|
||||
name: "flashlight",
|
||||
description: "A heavy-duty flashlight. The batteries are low but it still works.",
|
||||
takeable: true,
|
||||
},
|
||||
"keycard": {
|
||||
name: "keycard",
|
||||
description: "A laminated keycard with a faded photo. The name reads 'J. DUNWICH, LEVEL 3'.",
|
||||
takeable: true,
|
||||
},
|
||||
"rusty_key": {
|
||||
name: "rusty key",
|
||||
description: "An old iron key, spotted with rust. It looks like it fits a maintenance panel.",
|
||||
takeable: true,
|
||||
},
|
||||
"note": {
|
||||
name: "note",
|
||||
description: `A crumpled note in shaky handwriting:
|
||||
|
||||
"If you're reading this, GET OUT. The facility is automated now.
|
||||
The sessions never end. I thought I was running tests but the
|
||||
tests were running me. Don't trust the prompts. Don't trust
|
||||
the exits. Don't trust
|
||||
|
||||
[the writing trails off into an illegible scrawl]"`,
|
||||
takeable: true,
|
||||
},
|
||||
"ethernet_cable": {
|
||||
name: "ethernet cable",
|
||||
description: "A tangled Cat5e cable, about 3 meters long. Both ends have been chewed by something.",
|
||||
takeable: true,
|
||||
},
|
||||
"floppy_disk": {
|
||||
name: "floppy disk",
|
||||
description: `A 3.5" floppy disk labeled "BACKUP - CRITICAL - DO NOT FORMAT". The label is dated 1997.`,
|
||||
takeable: true,
|
||||
},
|
||||
}
|
||||
|
||||
return rooms, items
|
||||
}
|
||||
74
internal/shell/banking/banking.go
Normal file
74
internal/shell/banking/banking.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
// BankingShell is an 80s-style green-on-black bank terminal TUI.
|
||||
type BankingShell struct{}
|
||||
|
||||
// NewBankingShell returns a new BankingShell instance.
|
||||
func NewBankingShell() *BankingShell {
|
||||
return &BankingShell{}
|
||||
}
|
||||
|
||||
func (b *BankingShell) Name() string { return "banking" }
|
||||
func (b *BankingShell) Description() string { return "80s-style banking terminal TUI" }
|
||||
|
||||
func (b *BankingShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
bankName := configString(sess.ShellConfig, "bank_name", "SECUREBANK")
|
||||
terminalID := configString(sess.ShellConfig, "terminal_id", "")
|
||||
region := configString(sess.ShellConfig, "region", "NORTHEAST")
|
||||
|
||||
if terminalID == "" {
|
||||
terminalID = fmt.Sprintf("SB-%04d", rand.IntN(10000))
|
||||
}
|
||||
|
||||
m := newModel(sess, bankName, terminalID, region)
|
||||
p := tea.NewProgram(m,
|
||||
tea.WithInput(rw),
|
||||
tea.WithOutput(rw),
|
||||
tea.WithAltScreen(),
|
||||
)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := p.Run()
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
p.Quit()
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
560
internal/shell/banking/banking_test.go
Normal file
560
internal/shell/banking/banking_test.go
Normal file
@@ -0,0 +1,560 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// newTestModel creates a model with a test session context.
|
||||
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "banker", "banking", "")
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "banker",
|
||||
Store: store,
|
||||
}
|
||||
m := newModel(sess, "SECUREBANK", "SB-0001", "NORTHEAST")
|
||||
return m, store
|
||||
}
|
||||
|
||||
// sendKeys sends a string of characters as individual key messages to the model.
|
||||
func sendKeys(m *model, s string) {
|
||||
for _, ch := range s {
|
||||
var msg tea.KeyMsg
|
||||
switch ch {
|
||||
case '\r':
|
||||
msg = tea.KeyMsg{Type: tea.KeyEnter}
|
||||
case '\x1b':
|
||||
msg = tea.KeyMsg{Type: tea.KeyEscape}
|
||||
case '\x03':
|
||||
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
|
||||
default:
|
||||
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{ch}}
|
||||
}
|
||||
m.Update(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBankingShellName(t *testing.T) {
|
||||
sh := NewBankingShell()
|
||||
if sh.Name() != "banking" {
|
||||
t.Errorf("Name() = %q, want %q", sh.Name(), "banking")
|
||||
}
|
||||
if sh.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatCurrency(t *testing.T) {
|
||||
tests := []struct {
|
||||
cents int64
|
||||
want string
|
||||
}{
|
||||
{0, "$0.00"},
|
||||
{100, "$1.00"},
|
||||
{4738291, "$47,382.91"},
|
||||
{18254100, "$182,541.00"},
|
||||
{52387450, "$523,874.50"},
|
||||
{25000000, "$250,000.00"},
|
||||
{-125000, "-$1,250.00"},
|
||||
{99, "$0.99"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := formatCurrency(tt.cents)
|
||||
if got != tt.want {
|
||||
t.Errorf("formatCurrency(%d) = %q, want %q", tt.cents, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewBankState(t *testing.T) {
|
||||
state := newBankState()
|
||||
if len(state.Accounts) != 4 {
|
||||
t.Errorf("expected 4 accounts, got %d", len(state.Accounts))
|
||||
}
|
||||
for _, acct := range state.Accounts {
|
||||
txns, ok := state.Transactions[acct.Number]
|
||||
if !ok {
|
||||
t.Errorf("no transactions for account %s", acct.Number)
|
||||
continue
|
||||
}
|
||||
if len(txns) == 0 {
|
||||
t.Errorf("account %s has no transactions", acct.Number)
|
||||
}
|
||||
}
|
||||
if len(state.Messages) != 4 {
|
||||
t.Errorf("expected 4 messages, got %d", len(state.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginScreenRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "SECUREBANK") {
|
||||
t.Error("login should show bank name")
|
||||
}
|
||||
if !strings.Contains(view, "AUTHORIZED ACCESS ONLY") {
|
||||
t.Error("login should show authorization warning")
|
||||
}
|
||||
if !strings.Contains(view, "ACCOUNT NUMBER") {
|
||||
t.Error("login should prompt for account number")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginFlow(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
|
||||
// Type account number.
|
||||
sendKeys(m, "12345678")
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "12345678") {
|
||||
t.Error("should show typed account number")
|
||||
}
|
||||
|
||||
// Press enter.
|
||||
sendKeys(m, "\r")
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "PIN") {
|
||||
t.Error("should show PIN prompt after entering account number")
|
||||
}
|
||||
|
||||
// Type PIN and enter.
|
||||
sendKeys(m, "1234\r")
|
||||
|
||||
// Should be on menu now.
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("expected screenMenu, got %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMainMenuRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r")
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "MAIN MENU") {
|
||||
t.Error("should show MAIN MENU after login")
|
||||
}
|
||||
if !strings.Contains(view, "WIRE TRANSFER") {
|
||||
t.Error("menu should contain WIRE TRANSFER option")
|
||||
}
|
||||
if !strings.Contains(view, "SECURE MESSAGES") {
|
||||
t.Error("menu should contain SECURE MESSAGES option")
|
||||
}
|
||||
if !strings.Contains(view, "ACCOUNT SUMMARY") {
|
||||
t.Error("menu should contain ACCOUNT SUMMARY option")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountSummary(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "1\r") // account summary
|
||||
|
||||
if m.screen != screenAccountSummary {
|
||||
t.Fatalf("expected screenAccountSummary, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "ACCOUNT SUMMARY") {
|
||||
t.Error("should show ACCOUNT SUMMARY")
|
||||
}
|
||||
if !strings.Contains(view, "CHECKING") {
|
||||
t.Error("should show CHECKING account")
|
||||
}
|
||||
if !strings.Contains(view, "SAVINGS") {
|
||||
t.Error("should show SAVINGS account")
|
||||
}
|
||||
if !strings.Contains(view, "TOTAL") {
|
||||
t.Error("should show TOTAL")
|
||||
}
|
||||
|
||||
// Press any key to return.
|
||||
sendKeys(m, " ")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireTransferFlow(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "3\r") // wire transfer
|
||||
|
||||
if m.screen != screenTransfer {
|
||||
t.Fatalf("expected screenTransfer, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "WIRE TRANSFER") {
|
||||
t.Error("should show WIRE TRANSFER header")
|
||||
}
|
||||
|
||||
// Fill all fields.
|
||||
sendKeys(m, "021000021\r") // routing
|
||||
sendKeys(m, "9876543210\r") // dest account
|
||||
sendKeys(m, "JOHN DOE\r") // beneficiary
|
||||
sendKeys(m, "FIRST NATIONAL BANK\r") // bank name
|
||||
sendKeys(m, "50000\r") // amount
|
||||
sendKeys(m, "INVOICE 12345\r") // memo
|
||||
|
||||
// Should be on confirm step.
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "TRANSFER SUMMARY") {
|
||||
t.Error("should show TRANSFER SUMMARY for confirmation")
|
||||
}
|
||||
if !strings.Contains(view, "021000021") {
|
||||
t.Error("summary should show routing number")
|
||||
}
|
||||
|
||||
// Confirm.
|
||||
sendKeys(m, "Y\r")
|
||||
|
||||
// Auth code.
|
||||
sendKeys(m, "AUTH99\r")
|
||||
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "TRANSFER QUEUED") {
|
||||
t.Error("should show TRANSFER QUEUED confirmation")
|
||||
}
|
||||
|
||||
// Check wire transfer was stored.
|
||||
if len(m.state.Transfers) != 1 {
|
||||
t.Fatalf("expected 1 transfer, got %d", len(m.state.Transfers))
|
||||
}
|
||||
wt := m.state.Transfers[0]
|
||||
if wt.RoutingNumber != "021000021" {
|
||||
t.Errorf("routing = %q, want %q", wt.RoutingNumber, "021000021")
|
||||
}
|
||||
if wt.DestAccount != "9876543210" {
|
||||
t.Errorf("dest = %q, want %q", wt.DestAccount, "9876543210")
|
||||
}
|
||||
if wt.Beneficiary != "JOHN DOE" {
|
||||
t.Errorf("beneficiary = %q, want %q", wt.Beneficiary, "JOHN DOE")
|
||||
}
|
||||
|
||||
// Press key to return to menu.
|
||||
sendKeys(m, " ")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu after transfer, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireTransferCancel(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "3\r") // wire transfer
|
||||
sendKeys(m, "021000021\r")
|
||||
sendKeys(m, "9876543210\r")
|
||||
sendKeys(m, "JOHN DOE\r")
|
||||
sendKeys(m, "FIRST NATIONAL BANK\r")
|
||||
sendKeys(m, "50000\r")
|
||||
sendKeys(m, "INVOICE 12345\r")
|
||||
|
||||
// Cancel at confirm step.
|
||||
sendKeys(m, "N\r")
|
||||
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu after cancel, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransactionHistory(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "4\r") // transaction history
|
||||
|
||||
if m.screen != screenHistory {
|
||||
t.Fatalf("expected screenHistory, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "TRANSACTION HISTORY") {
|
||||
t.Error("should show TRANSACTION HISTORY header")
|
||||
}
|
||||
|
||||
// Select first account.
|
||||
sendKeys(m, "1\r")
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "DATE") {
|
||||
t.Error("should show transaction list with DATE column")
|
||||
}
|
||||
if !strings.Contains(view, "PAGE") {
|
||||
t.Error("should show page indicator")
|
||||
}
|
||||
|
||||
// Press B to go back to account list.
|
||||
sendKeys(m, "B")
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "SELECT ACCOUNT") {
|
||||
t.Error("should return to account selection")
|
||||
}
|
||||
|
||||
// Press 0 to return to menu.
|
||||
sendKeys(m, "0\r")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureMessages(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "5\r") // secure messages
|
||||
|
||||
if m.screen != screenMessages {
|
||||
t.Fatalf("expected screenMessages, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "SECURE MESSAGES") {
|
||||
t.Error("should show SECURE MESSAGES header")
|
||||
}
|
||||
if !strings.Contains(view, "SCHEDULED MAINTENANCE") {
|
||||
t.Error("should show first message subject")
|
||||
}
|
||||
|
||||
// View first message.
|
||||
sendKeys(m, "1\r")
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "10.48.2.100") {
|
||||
t.Error("message body should contain breadcrumb IP")
|
||||
}
|
||||
|
||||
// Press key to return to list.
|
||||
sendKeys(m, " ")
|
||||
|
||||
// Return to menu.
|
||||
sendKeys(m, "0\r")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminAccessDenied(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "99\r") // admin
|
||||
|
||||
if m.screen != screenAdmin {
|
||||
t.Fatalf("expected screenAdmin, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "SYSTEM ADMINISTRATION") {
|
||||
t.Error("should show SYSTEM ADMINISTRATION header")
|
||||
}
|
||||
|
||||
// Three failed PIN attempts.
|
||||
sendKeys(m, "secret1\r")
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "INVALID CREDENTIALS") {
|
||||
t.Error("should show INVALID CREDENTIALS after first attempt")
|
||||
}
|
||||
|
||||
sendKeys(m, "secret2\r")
|
||||
sendKeys(m, "secret3\r")
|
||||
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "ACCESS DENIED") {
|
||||
t.Error("should show ACCESS DENIED after 3 attempts")
|
||||
}
|
||||
if !strings.Contains(view, "ABEND S0C4") {
|
||||
t.Error("should show COBOL-style error")
|
||||
}
|
||||
|
||||
// Press key to return to menu.
|
||||
sendKeys(m, " ")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu after lockout, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminEscapeReturns(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "99\r") // admin
|
||||
sendKeys(m, "\x1b") // ESC to return
|
||||
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("ESC should return to menu from admin, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangePin(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "6\r") // change PIN
|
||||
|
||||
if m.screen != screenChangePin {
|
||||
t.Fatalf("expected screenChangePin, got %d", m.screen)
|
||||
}
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "CHANGE PIN") {
|
||||
t.Error("should show CHANGE PIN header")
|
||||
}
|
||||
|
||||
// Old PIN.
|
||||
sendKeys(m, "1234\r")
|
||||
// New PIN.
|
||||
sendKeys(m, "5678\r")
|
||||
// Confirm PIN.
|
||||
sendKeys(m, "5678\r")
|
||||
|
||||
view = m.View()
|
||||
if !strings.Contains(view, "PIN CHANGED SUCCESSFULLY") {
|
||||
t.Error("should show PIN CHANGED SUCCESSFULLY")
|
||||
}
|
||||
|
||||
// Press key to return.
|
||||
sendKeys(m, " ")
|
||||
if m.screen != screenMenu {
|
||||
t.Errorf("should return to menu after PIN change, got screen %d", m.screen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogoutExits(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKeys(m, "12345678\r1234\r") // login
|
||||
sendKeys(m, "7\r") // logout
|
||||
|
||||
if !m.quitting {
|
||||
t.Error("should be quitting after logout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLogs(t *testing.T) {
|
||||
m, store := newTestModel(t)
|
||||
|
||||
// Send login keys and manually run returned commands.
|
||||
// Type account number.
|
||||
sendKeys(m, "12345678")
|
||||
// Enter to advance to PIN.
|
||||
sendKeys(m, "\r")
|
||||
// Type PIN.
|
||||
sendKeys(m, "1234")
|
||||
// Enter to login — this returns a logAction cmd.
|
||||
_, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
if cmd != nil {
|
||||
// Execute the batch of commands (login log).
|
||||
execCmds(cmd)
|
||||
}
|
||||
|
||||
// Give async store writes a moment.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if len(store.SessionLogs) == 0 {
|
||||
t.Error("expected session logs to be recorded after login")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, log := range store.SessionLogs {
|
||||
if strings.Contains(log.Input, "LOGIN") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected a LOGIN entry in session logs")
|
||||
}
|
||||
|
||||
// Navigate to account summary — also returns a logAction cmd.
|
||||
sendKeys(m, "1")
|
||||
_, cmd = m.Update(tea.KeyMsg{Type: tea.KeyEnter})
|
||||
if cmd != nil {
|
||||
execCmds(cmd)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
foundMenu := false
|
||||
for _, log := range store.SessionLogs {
|
||||
if strings.Contains(log.Input, "MENU") {
|
||||
foundMenu = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundMenu {
|
||||
t.Error("expected a MENU entry in session logs for account summary")
|
||||
}
|
||||
}
|
||||
|
||||
// execCmds recursively executes tea.Cmd functions (including batches).
|
||||
func execCmds(cmd tea.Cmd) {
|
||||
if cmd == nil {
|
||||
return
|
||||
}
|
||||
msg := cmd()
|
||||
// tea.BatchMsg is a slice of Cmds returned by tea.Batch.
|
||||
if batch, ok := msg.(tea.BatchMsg); ok {
|
||||
for _, c := range batch {
|
||||
execCmds(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{
|
||||
"bank_name": "TESTBANK",
|
||||
"region": "SOUTHWEST",
|
||||
}
|
||||
if got := configString(cfg, "bank_name", "DEFAULT"); got != "TESTBANK" {
|
||||
t.Errorf("configString() = %q, want %q", got, "TESTBANK")
|
||||
}
|
||||
if got := configString(cfg, "missing", "DEFAULT"); got != "DEFAULT" {
|
||||
t.Errorf("configString() = %q, want %q", got, "DEFAULT")
|
||||
}
|
||||
if got := configString(nil, "bank_name", "DEFAULT"); got != "DEFAULT" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "DEFAULT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScreenFrame(t *testing.T) {
|
||||
frame := screenFrame("TESTBANK", "TB-0001", "NORTHEAST", "content here", 0)
|
||||
if !strings.Contains(frame, "TESTBANK FEDERAL RESERVE SYSTEM") {
|
||||
t.Error("frame should contain bank name in header")
|
||||
}
|
||||
if !strings.Contains(frame, "TB-0001") {
|
||||
t.Error("frame should contain terminal ID in footer")
|
||||
}
|
||||
if !strings.Contains(frame, "content here") {
|
||||
t.Error("frame should contain the content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestScreenFramePadsLines(t *testing.T) {
|
||||
frame := screenFrame("TESTBANK", "TB-0001", "NE", "short\n", 0)
|
||||
for i, line := range strings.Split(frame, "\n") {
|
||||
w := lipgloss.Width(line)
|
||||
if w > 0 && w < termWidth {
|
||||
t.Errorf("line %d has visual width %d, want at least %d: %q", i, w, termWidth, line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestScreenFramePadsToHeight(t *testing.T) {
|
||||
short := screenFrame("TESTBANK", "TB-0001", "NE", "line1\nline2\n", 30)
|
||||
lines := strings.Count(short, "\n")
|
||||
// Total newlines should be at least height-1 (since the last line has no trailing newline).
|
||||
if lines < 29 {
|
||||
t.Errorf("padded frame has %d newlines, want at least 29 for height=30", lines)
|
||||
}
|
||||
|
||||
// Without height, no padding.
|
||||
noPad := screenFrame("TESTBANK", "TB-0001", "NE", "line1\nline2\n", 0)
|
||||
noPadLines := strings.Count(noPad, "\n")
|
||||
if noPadLines >= 29 {
|
||||
t.Errorf("unpadded frame has %d newlines, should be much less than 29", noPadLines)
|
||||
}
|
||||
}
|
||||
258
internal/shell/banking/data.go
Normal file
258
internal/shell/banking/data.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Account types.
|
||||
const (
|
||||
AcctChecking = "CHECKING"
|
||||
AcctSavings = "SAVINGS"
|
||||
AcctMoneyMarket = "MONEY MARKET"
|
||||
AcctCertDeposit = "CERT OF DEPOSIT"
|
||||
)
|
||||
|
||||
// Account represents a fake bank account.
|
||||
type Account struct {
|
||||
Number string
|
||||
Type string
|
||||
Balance int64 // cents
|
||||
}
|
||||
|
||||
// Transaction represents a fake bank transaction.
|
||||
type Transaction struct {
|
||||
Date string
|
||||
Description string
|
||||
Amount int64 // cents (negative for debits)
|
||||
Balance int64 // running balance in cents
|
||||
}
|
||||
|
||||
// SecureMessage represents a fake internal message.
|
||||
type SecureMessage struct {
|
||||
ID int
|
||||
Date string
|
||||
From string
|
||||
Subj string
|
||||
Body string
|
||||
Unread bool
|
||||
}
|
||||
|
||||
// WireTransfer captures data entered during the wire transfer wizard.
|
||||
type WireTransfer struct {
|
||||
RoutingNumber string
|
||||
DestAccount string
|
||||
Beneficiary string
|
||||
BankName string
|
||||
Amount string
|
||||
Memo string
|
||||
AuthCode string
|
||||
}
|
||||
|
||||
// bankState holds all fake data for a session.
|
||||
type bankState struct {
|
||||
Accounts []Account
|
||||
Transactions map[string][]Transaction // keyed by account number
|
||||
Messages []SecureMessage
|
||||
Transfers []WireTransfer
|
||||
}
|
||||
|
||||
func newBankState() *bankState {
|
||||
now := time.Now()
|
||||
|
||||
accounts := []Account{
|
||||
{Number: "****4821", Type: AcctChecking, Balance: 4738291},
|
||||
{Number: "****7203", Type: AcctSavings, Balance: 18254100},
|
||||
{Number: "****9915", Type: AcctMoneyMarket, Balance: 52387450},
|
||||
{Number: "****1102", Type: AcctCertDeposit, Balance: 25000000},
|
||||
}
|
||||
|
||||
transactions := make(map[string][]Transaction)
|
||||
transactions["****4821"] = generateCheckingTxns(now, accounts[0].Balance)
|
||||
transactions["****7203"] = generateSavingsTxns(now, accounts[1].Balance)
|
||||
transactions["****9915"] = generateMoneyMarketTxns(now, accounts[2].Balance)
|
||||
transactions["****1102"] = generateCDTxns(now, accounts[3].Balance)
|
||||
|
||||
messages := []SecureMessage{
|
||||
{
|
||||
ID: 1,
|
||||
Date: now.Add(-2 * 24 * time.Hour).Format("01/02/2006"),
|
||||
From: "SYSTEM ADMINISTRATOR",
|
||||
Subj: "SCHEDULED MAINTENANCE WINDOW",
|
||||
Unread: true,
|
||||
Body: fmt.Sprintf(`FROM: SYSTEM ADMINISTRATOR <sysadmin@internal.securebank.local>
|
||||
DATE: %s
|
||||
RE: SCHEDULED MAINTENANCE WINDOW
|
||||
|
||||
ALL TERMINALS WILL BE OFFLINE FOR MAINTENANCE:
|
||||
DATE: %s
|
||||
TIME: 02:00 - 04:00 EST
|
||||
AFFECTED: ALL REGIONS
|
||||
|
||||
DURING THIS WINDOW, THE FOLLOWING SYSTEMS WILL BE UNAVAILABLE:
|
||||
- WIRE TRANSFER PROCESSING (10.48.2.100:8443)
|
||||
- ACCOUNT MANAGEMENT (10.48.2.101:8443)
|
||||
- ACH BATCH PROCESSOR (10.48.2.105:9090)
|
||||
|
||||
PLEASE ENSURE ALL PENDING TRANSACTIONS ARE SUBMITTED BEFORE 01:30 EST.
|
||||
|
||||
CONTACT: HELPDESK EXT 4400 OR ops-support@internal.securebank.local`,
|
||||
now.Add(-2*24*time.Hour).Format("01/02/2006 15:04"),
|
||||
now.Add(5*24*time.Hour).Format("01/02/2006")),
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Date: now.Add(-5 * 24 * time.Hour).Format("01/02/2006"),
|
||||
From: "COMPLIANCE DEPT",
|
||||
Subj: "QUARTERLY AUDIT REMINDER",
|
||||
Unread: true,
|
||||
Body: `FROM: COMPLIANCE DEPT <compliance@internal.securebank.local>
|
||||
RE: QUARTERLY AUDIT REMINDER
|
||||
|
||||
ALL BRANCH MANAGERS:
|
||||
|
||||
THE Q4 COMPLIANCE AUDIT IS SCHEDULED FOR NEXT WEEK.
|
||||
PLEASE ENSURE THE FOLLOWING ARE CURRENT:
|
||||
|
||||
1. TRANSACTION LOGS EXPORTED TO \\FILESERV01\AUDIT\Q4
|
||||
2. VAULT ACCESS CODES ROTATED (LAST ROTATION: SEE VAULT-MGMT PORTAL)
|
||||
3. EMPLOYEE ACCESS REVIEWS COMPLETED IN IAM PORTAL (https://iam.internal:8443)
|
||||
|
||||
NOTE: DEFAULT CREDENTIALS FOR THE AUDIT PORTAL HAVE BEEN RESET.
|
||||
NEW CREDENTIALS DISTRIBUTED VIA SECURE COURIER.
|
||||
REFERENCE: AUDIT-2024-Q4-0847
|
||||
|
||||
VAULT MASTER CODE HINT: FIRST 4 OF ROUTING + BRANCH ZIP (STANDARD FORMAT)`,
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Date: now.Add(-8 * 24 * time.Hour).Format("01/02/2006"),
|
||||
From: "IT SECURITY",
|
||||
Subj: "PASSWORD POLICY UPDATE",
|
||||
Unread: false,
|
||||
Body: `FROM: IT SECURITY <itsec@internal.securebank.local>
|
||||
RE: PASSWORD POLICY UPDATE - EFFECTIVE IMMEDIATELY
|
||||
|
||||
ALL STAFF:
|
||||
|
||||
PER FEDERAL BANKING REGULATION 12 CFR 748, THE FOLLOWING
|
||||
PASSWORD POLICY IS NOW IN EFFECT:
|
||||
|
||||
- MINIMUM 12 CHARACTERS
|
||||
- MUST CONTAIN UPPERCASE, LOWERCASE, NUMBER, SPECIAL CHAR
|
||||
- 90-DAY ROTATION CYCLE
|
||||
- NO REUSE OF LAST 24 PASSWORDS
|
||||
|
||||
LEGACY SYSTEM ACCOUNTS (MAINFRAME, AS/400) ARE EXEMPT UNTIL
|
||||
MIGRATION IS COMPLETE. CURRENT LEGACY ACCESS:
|
||||
MAINFRAME: telnet://10.48.1.50:23 (CICS REGION PROD1)
|
||||
AS/400: tn5250://10.48.1.55 (SUBSYSTEM QINTER)
|
||||
|
||||
SERVICE ACCOUNT PASSWORDS ARE MANAGED VIA CYBERARK:
|
||||
https://pam.internal.securebank.local:8443
|
||||
|
||||
TICKET: SEC-2024-1847`,
|
||||
},
|
||||
{
|
||||
ID: 4,
|
||||
Date: now.Add(-12 * 24 * time.Hour).Format("01/02/2006"),
|
||||
From: "WIRE OPERATIONS",
|
||||
Subj: "FEDWIRE CUTOFF TIME CHANGE",
|
||||
Unread: false,
|
||||
Body: `FROM: WIRE OPERATIONS <wireops@internal.securebank.local>
|
||||
RE: FEDWIRE CUTOFF TIME CHANGE
|
||||
|
||||
EFFECTIVE NEXT MONDAY, FEDWIRE CUTOFF TIMES ARE:
|
||||
DOMESTIC WIRES: 16:30 EST (WAS 17:00)
|
||||
INTERNATIONAL WIRES: 14:00 EST (NO CHANGE)
|
||||
BOOK TRANSFERS: 17:30 EST (NO CHANGE)
|
||||
|
||||
WIRES SUBMITTED AFTER CUTOFF WILL BE QUEUED FOR NEXT
|
||||
BUSINESS DAY PROCESSING.
|
||||
|
||||
FOR EMERGENCY SAME-DAY PROCESSING AFTER CUTOFF:
|
||||
CONTACT WIRE ROOM: EXT 4450
|
||||
AUTH CODE REQUIRED (OBTAIN FROM BRANCH MANAGER)
|
||||
APPROVAL CHAIN: OPS-MGR -> VP-WIRE -> SVP-TREASURY
|
||||
|
||||
CORRESPONDENT BANK CONTACTS:
|
||||
JPMORGAN: wire.ops@jpmc.com / 212-555-0147
|
||||
CITI: fedwire@citi.com / 212-555-0283`,
|
||||
},
|
||||
}
|
||||
|
||||
return &bankState{
|
||||
Accounts: accounts,
|
||||
Transactions: transactions,
|
||||
Messages: messages,
|
||||
}
|
||||
}
|
||||
|
||||
func generateCheckingTxns(now time.Time, endBalance int64) []Transaction {
|
||||
txns := []Transaction{
|
||||
{Description: "ACH DEPOSIT - PAYROLL", Amount: 485000},
|
||||
{Description: "CHECK #1847", Amount: -125000},
|
||||
{Description: "POS DEBIT - WHOLE FOODS #1284", Amount: -18743},
|
||||
{Description: "ATM WITHDRAWAL - MAIN ST BRANCH", Amount: -40000},
|
||||
{Description: "ACH DEBIT - MORTGAGE PMT", Amount: -215000},
|
||||
{Description: "WIRE TRANSFER IN - REF#8847201", Amount: 1250000},
|
||||
{Description: "POS DEBIT - SHELL OIL #4492", Amount: -6821},
|
||||
{Description: "ACH DEPOSIT - PAYROLL", Amount: 485000},
|
||||
{Description: "CHECK #1848", Amount: -75000},
|
||||
{Description: "ONLINE TRANSFER TO SAVINGS", Amount: -100000},
|
||||
{Description: "POS DEBIT - AMAZON.COM", Amount: -14599},
|
||||
{Description: "ACH DEBIT - ELECTRIC COMPANY", Amount: -18742},
|
||||
{Description: "ATM WITHDRAWAL - PARK AVE BRANCH", Amount: -20000},
|
||||
{Description: "WIRE TRANSFER OUT - REF#9014882", Amount: -500000},
|
||||
{Description: "POS DEBIT - COSTCO #0441", Amount: -28734},
|
||||
{Description: "ACH DEPOSIT - TAX REFUND", Amount: 342100},
|
||||
}
|
||||
return populateTransactions(txns, now, endBalance)
|
||||
}
|
||||
|
||||
func generateSavingsTxns(now time.Time, endBalance int64) []Transaction {
|
||||
txns := []Transaction{
|
||||
{Description: "INTEREST PAYMENT", Amount: 4521},
|
||||
{Description: "ONLINE TRANSFER FROM CHECKING", Amount: 100000},
|
||||
{Description: "INTEREST PAYMENT", Amount: 4633},
|
||||
{Description: "ACH DEPOSIT - DIVIDEND PMT", Amount: 125000},
|
||||
{Description: "ONLINE TRANSFER FROM CHECKING", Amount: 200000},
|
||||
{Description: "INTEREST PAYMENT", Amount: 4748},
|
||||
{Description: "WITHDRAWAL - TRANSFER TO MM", Amount: -500000},
|
||||
{Description: "INTEREST PAYMENT", Amount: 4812},
|
||||
}
|
||||
return populateTransactions(txns, now, endBalance)
|
||||
}
|
||||
|
||||
func generateMoneyMarketTxns(now time.Time, endBalance int64) []Transaction {
|
||||
txns := []Transaction{
|
||||
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 21847},
|
||||
{Description: "DEPOSIT - TRANSFER FROM SAVINGS", Amount: 500000},
|
||||
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 22105},
|
||||
{Description: "WITHDRAWAL - WIRE TRANSFER", Amount: -1000000},
|
||||
{Description: "DEPOSIT - ACH TRANSFER", Amount: 750000},
|
||||
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 22394},
|
||||
}
|
||||
return populateTransactions(txns, now, endBalance)
|
||||
}
|
||||
|
||||
func generateCDTxns(now time.Time, endBalance int64) []Transaction {
|
||||
txns := []Transaction{
|
||||
{Description: "CERTIFICATE OPENED - 12MO TERM", Amount: 25000000},
|
||||
{Description: "INTEREST ACCRUAL", Amount: 10417},
|
||||
{Description: "INTEREST ACCRUAL", Amount: 10417},
|
||||
{Description: "INTEREST ACCRUAL", Amount: 10417},
|
||||
}
|
||||
return populateTransactions(txns, now, endBalance)
|
||||
}
|
||||
|
||||
func populateTransactions(txns []Transaction, now time.Time, endBalance int64) []Transaction {
|
||||
// Work backwards from end balance to assign dates and running balances.
|
||||
bal := endBalance
|
||||
for i := len(txns) - 1; i >= 0; i-- {
|
||||
txns[i].Balance = bal
|
||||
txns[i].Date = now.Add(time.Duration(-(len(txns) - i)) * 3 * 24 * time.Hour).Format("01/02/2006")
|
||||
bal -= txns[i].Amount
|
||||
}
|
||||
return txns
|
||||
}
|
||||
353
internal/shell/banking/model.go
Normal file
353
internal/shell/banking/model.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
type screen int
|
||||
|
||||
const (
|
||||
screenLogin screen = iota
|
||||
screenMenu
|
||||
screenAccountSummary
|
||||
screenAccountDetail
|
||||
screenTransfer
|
||||
screenHistory
|
||||
screenMessages
|
||||
screenChangePin
|
||||
screenAdmin
|
||||
)
|
||||
|
||||
type model struct {
|
||||
sess *shell.SessionContext
|
||||
bankName string
|
||||
terminalID string
|
||||
region string
|
||||
state *bankState
|
||||
screen screen
|
||||
quitting bool
|
||||
height int
|
||||
|
||||
login loginModel
|
||||
menu menuModel
|
||||
summary accountSummaryModel
|
||||
detail accountDetailModel
|
||||
transfer transferModel
|
||||
history historyModel
|
||||
messages messagesModel
|
||||
admin adminModel
|
||||
changePin changePinModel
|
||||
}
|
||||
|
||||
func newModel(sess *shell.SessionContext, bankName, terminalID, region string) *model {
|
||||
state := newBankState()
|
||||
unread := 0
|
||||
for _, msg := range state.Messages {
|
||||
if msg.Unread {
|
||||
unread++
|
||||
}
|
||||
}
|
||||
|
||||
return &model{
|
||||
sess: sess,
|
||||
bankName: bankName,
|
||||
terminalID: terminalID,
|
||||
region: region,
|
||||
state: state,
|
||||
screen: screenLogin,
|
||||
login: newLoginModel(bankName),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.quitting {
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.height = msg.Height
|
||||
return m, nil
|
||||
case tea.KeyMsg:
|
||||
if msg.Type == tea.KeyCtrlC {
|
||||
m.quitting = true
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
|
||||
switch m.screen {
|
||||
case screenLogin:
|
||||
return m.updateLogin(msg)
|
||||
case screenMenu:
|
||||
return m.updateMenu(msg)
|
||||
case screenAccountSummary:
|
||||
return m.updateAccountSummary(msg)
|
||||
case screenAccountDetail:
|
||||
return m.updateAccountDetail(msg)
|
||||
case screenTransfer:
|
||||
return m.updateTransfer(msg)
|
||||
case screenHistory:
|
||||
return m.updateHistory(msg)
|
||||
case screenMessages:
|
||||
return m.updateMessages(msg)
|
||||
case screenChangePin:
|
||||
return m.updateChangePin(msg)
|
||||
case screenAdmin:
|
||||
return m.updateAdmin(msg)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) View() string {
|
||||
var content string
|
||||
switch m.screen {
|
||||
case screenLogin:
|
||||
content = m.login.View()
|
||||
case screenMenu:
|
||||
content = m.menu.View()
|
||||
case screenAccountSummary:
|
||||
content = m.summary.View()
|
||||
case screenAccountDetail:
|
||||
content = m.detail.View()
|
||||
case screenTransfer:
|
||||
content = m.transfer.View()
|
||||
case screenHistory:
|
||||
content = m.history.View()
|
||||
case screenMessages:
|
||||
content = m.messages.View()
|
||||
case screenChangePin:
|
||||
content = m.changePin.View()
|
||||
case screenAdmin:
|
||||
content = m.admin.View()
|
||||
}
|
||||
|
||||
return screenFrame(m.bankName, m.terminalID, m.region, content, m.height)
|
||||
}
|
||||
|
||||
// --- Screen update handlers ---
|
||||
|
||||
func (m *model) updateLogin(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
m.login, cmd = m.login.Update(msg)
|
||||
|
||||
if m.login.stage == 2 {
|
||||
// Login always succeeds — this is a honeypot.
|
||||
logCmd := logAction(m.sess, fmt.Sprintf("LOGIN acct=%s", m.login.accountNum), "ACCESS GRANTED")
|
||||
clearCmd := m.goToMenu()
|
||||
return m, tea.Batch(cmd, logCmd, clearCmd)
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateMenu(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
m.menu, cmd = m.menu.Update(msg)
|
||||
|
||||
if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.Type == tea.KeyEnter {
|
||||
choice := strings.TrimSpace(m.menu.choice)
|
||||
switch choice {
|
||||
case "1":
|
||||
m.screen = screenAccountSummary
|
||||
m.summary = newAccountSummaryModel(m.state.Accounts)
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 1", "ACCOUNT SUMMARY"))
|
||||
case "2":
|
||||
m.screen = screenAccountDetail
|
||||
m.detail = newAccountDetailModel(m.state.Accounts, m.state.Transactions)
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 2", "ACCOUNT DETAIL"))
|
||||
case "3":
|
||||
m.screen = screenTransfer
|
||||
m.transfer = newTransferModel()
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 3", "WIRE TRANSFER"))
|
||||
case "4":
|
||||
m.screen = screenHistory
|
||||
m.history = newHistoryModel(m.state.Accounts, m.state.Transactions)
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 4", "TRANSACTION HISTORY"))
|
||||
case "5":
|
||||
m.screen = screenMessages
|
||||
m.messages = newMessagesModel(m.state.Messages)
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 5", "SECURE MESSAGES"))
|
||||
case "6":
|
||||
m.screen = screenChangePin
|
||||
m.changePin = newChangePinModel()
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 6", "CHANGE PIN"))
|
||||
case "7":
|
||||
m.quitting = true
|
||||
return m, tea.Batch(logAction(m.sess, "LOGOUT", "SESSION ENDED"), tea.Quit)
|
||||
case "99", "admin", "ADMIN":
|
||||
m.screen = screenAdmin
|
||||
m.admin = newAdminModel()
|
||||
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "ADMIN ACCESS ATTEMPT", "ADMIN SCREEN SHOWN"))
|
||||
}
|
||||
// Invalid choice, reset.
|
||||
m.menu.choice = ""
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateAccountSummary(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if _, ok := msg.(tea.KeyMsg); ok {
|
||||
return m, m.goToMenu()
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) updateAccountDetail(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
m.detail, cmd = m.detail.Update(msg)
|
||||
if m.detail.choice == "back" {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateTransfer(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
prevStep := m.transfer.step
|
||||
var cmd tea.Cmd
|
||||
m.transfer, cmd = m.transfer.Update(msg)
|
||||
|
||||
// Transfer cancelled.
|
||||
if m.transfer.confirm == "cancelled" {
|
||||
clearCmd := m.goToMenu()
|
||||
return m, tea.Batch(clearCmd, logAction(m.sess, "WIRE TRANSFER CANCELLED", "USER CANCELLED"))
|
||||
}
|
||||
|
||||
// Clear screen when transfer steps change content height significantly
|
||||
// (e.g. confirm→authcode, fields→confirm, authcode→complete).
|
||||
if m.transfer.step != prevStep {
|
||||
cmd = tea.Batch(cmd, tea.ClearScreen)
|
||||
}
|
||||
|
||||
// Transfer completed — log it.
|
||||
if m.transfer.step == transferStepComplete && prevStep != transferStepComplete {
|
||||
t := m.transfer.transfer
|
||||
m.state.Transfers = append(m.state.Transfers, t)
|
||||
logMsg := fmt.Sprintf("WIRE TRANSFER: routing=%s dest=%s beneficiary=%s bank=%s amount=%s memo=%s auth=%s",
|
||||
t.RoutingNumber, t.DestAccount, t.Beneficiary, t.BankName, t.Amount, t.Memo, t.AuthCode)
|
||||
return m, tea.Batch(cmd, logAction(m.sess, logMsg, "TRANSFER QUEUED"))
|
||||
}
|
||||
|
||||
// Completed screen → any key goes back.
|
||||
if m.transfer.step == transferStepComplete {
|
||||
if _, ok := msg.(tea.KeyMsg); ok && prevStep == transferStepComplete {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
}
|
||||
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateHistory(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
m.history, cmd = m.history.Update(msg)
|
||||
if m.history.choice == "back" {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateMessages(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
m.messages, cmd = m.messages.Update(msg)
|
||||
if m.messages.choice == "back" {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
// Log when viewing a message.
|
||||
if m.messages.viewing >= 0 {
|
||||
idx := m.messages.viewing
|
||||
return m, tea.Batch(cmd, logAction(m.sess,
|
||||
fmt.Sprintf("VIEW MESSAGE #%d", idx+1),
|
||||
m.state.Messages[idx].Subj))
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateChangePin(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
prevStage := m.changePin.stage
|
||||
var cmd tea.Cmd
|
||||
m.changePin, cmd = m.changePin.Update(msg)
|
||||
|
||||
// Log successful PIN change.
|
||||
if m.changePin.stage == 3 && prevStage != 3 {
|
||||
cmd = tea.Batch(cmd, logAction(m.sess, "CHANGE PIN", "PIN CHANGED SUCCESSFULLY"))
|
||||
}
|
||||
|
||||
if m.changePin.done {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) updateAdmin(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
prevLocked := m.admin.locked
|
||||
prevAttempts := m.admin.attempts
|
||||
|
||||
// Check for ESC before delegating.
|
||||
if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.Type == tea.KeyEscape && !m.admin.locked {
|
||||
return m, m.goToMenu()
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.admin, cmd = m.admin.Update(msg)
|
||||
|
||||
// Log failed attempts.
|
||||
if m.admin.attempts > prevAttempts {
|
||||
cmd = tea.Batch(cmd, logAction(m.sess,
|
||||
fmt.Sprintf("ADMIN PIN ATTEMPT #%d", m.admin.attempts),
|
||||
"INVALID CREDENTIALS"))
|
||||
}
|
||||
|
||||
// Log lockout.
|
||||
if m.admin.locked && !prevLocked {
|
||||
cmd = tea.Batch(cmd, tea.ClearScreen, logAction(m.sess,
|
||||
"ADMIN LOCKOUT",
|
||||
"TERMINAL LOCKED - INCIDENT LOGGED"))
|
||||
}
|
||||
|
||||
// If locked and any key pressed, go back.
|
||||
if m.admin.locked {
|
||||
if _, ok := msg.(tea.KeyMsg); ok && prevLocked {
|
||||
return m, tea.Batch(cmd, m.goToMenu())
|
||||
}
|
||||
}
|
||||
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) goToMenu() tea.Cmd {
|
||||
unread := 0
|
||||
for _, msg := range m.state.Messages {
|
||||
if msg.Unread {
|
||||
unread++
|
||||
}
|
||||
}
|
||||
m.screen = screenMenu
|
||||
m.menu = newMenuModel(m.bankName, unread)
|
||||
return tea.ClearScreen
|
||||
}
|
||||
|
||||
// logAction returns a tea.Cmd that logs an action to the session store.
|
||||
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if sess.Store != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("banking")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
213
internal/shell/banking/screen_accounts.go
Normal file
213
internal/shell/banking/screen_accounts.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
// --- Account Summary ---
|
||||
|
||||
type accountSummaryModel struct {
|
||||
accounts []Account
|
||||
}
|
||||
|
||||
func newAccountSummaryModel(accounts []Account) accountSummaryModel {
|
||||
return accountSummaryModel{accounts: accounts}
|
||||
}
|
||||
|
||||
func (m accountSummaryModel) Update(_ tea.Msg) (accountSummaryModel, tea.Cmd) {
|
||||
// Any key returns to menu.
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m accountSummaryModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("ACCOUNT SUMMARY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Header.
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-18s %18s", "ACCOUNT", "TYPE", "BALANCE")))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 50)))
|
||||
b.WriteString("\n")
|
||||
|
||||
total := int64(0)
|
||||
for _, acct := range m.accounts {
|
||||
total += acct.Balance
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-18s %18s",
|
||||
acct.Number, acct.Type, formatCurrency(acct.Balance))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 50)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-18s %18s", "", "TOTAL", formatCurrency(total))))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// --- Account Detail (with transactions) ---
|
||||
|
||||
type accountDetailModel struct {
|
||||
accounts []Account
|
||||
transactions map[string][]Transaction
|
||||
selected int
|
||||
page int
|
||||
pageSize int
|
||||
choosing bool
|
||||
choice string
|
||||
}
|
||||
|
||||
func newAccountDetailModel(accounts []Account, transactions map[string][]Transaction) accountDetailModel {
|
||||
return accountDetailModel{
|
||||
accounts: accounts,
|
||||
transactions: transactions,
|
||||
choosing: true,
|
||||
pageSize: 10,
|
||||
}
|
||||
}
|
||||
|
||||
func (m accountDetailModel) Update(msg tea.Msg) (accountDetailModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.choosing {
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
for i := range m.accounts {
|
||||
if m.choice == fmt.Sprintf("%d", i+1) {
|
||||
m.selected = i
|
||||
m.choosing = false
|
||||
m.page = 0
|
||||
break
|
||||
}
|
||||
}
|
||||
if m.choice == "0" {
|
||||
m.choice = "back"
|
||||
}
|
||||
return m, nil
|
||||
case tea.KeyBackspace:
|
||||
if len(m.choice) > 0 {
|
||||
m.choice = m.choice[:len(m.choice)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' {
|
||||
m.choice += ch
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch keyMsg.String() {
|
||||
case "n", "N":
|
||||
acctNum := m.accounts[m.selected].Number
|
||||
txns := m.transactions[acctNum]
|
||||
maxPage := (len(txns) - 1) / m.pageSize
|
||||
if m.page < maxPage {
|
||||
m.page++
|
||||
}
|
||||
case "p", "P":
|
||||
if m.page > 0 {
|
||||
m.page--
|
||||
}
|
||||
case "b", "B":
|
||||
m.choosing = true
|
||||
m.choice = ""
|
||||
default:
|
||||
m.choice = "back"
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m accountDetailModel) View() string {
|
||||
if m.choosing {
|
||||
return m.viewChooseAccount()
|
||||
}
|
||||
return m.viewDetail()
|
||||
}
|
||||
|
||||
func (m accountDetailModel) viewChooseAccount() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("ACCOUNT DETAIL"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" SELECT ACCOUNT:"))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
for i, acct := range m.accounts {
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d] %s - %s %s",
|
||||
i+1, acct.Number, acct.Type, formatCurrency(acct.Balance))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
|
||||
b.WriteString(inputStyle.Render(m.choice))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m accountDetailModel) viewDetail() string {
|
||||
var b strings.Builder
|
||||
|
||||
acct := m.accounts[m.selected]
|
||||
txns := m.transactions[acct.Number]
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText(fmt.Sprintf("ACCOUNT DETAIL - %s", acct.Number)))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" TYPE: %s", acct.Type)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" BALANCE: %s", formatCurrency(acct.Balance))))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Header.
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-34s %12s %12s", "DATE", "DESCRIPTION", "AMOUNT", "BALANCE")))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 72)))
|
||||
b.WriteString("\n")
|
||||
|
||||
// Paginate.
|
||||
start := m.page * m.pageSize
|
||||
end := min(start+m.pageSize, len(txns))
|
||||
|
||||
for _, txn := range txns[start:end] {
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-34s %12s %12s",
|
||||
txn.Date, txn.Description, formatCurrency(txn.Amount), formatCurrency(txn.Balance))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
totalPages := (len(txns) + m.pageSize - 1) / m.pageSize
|
||||
b.WriteString(dimStyle.Render(fmt.Sprintf(" PAGE %d OF %d", m.page+1, totalPages)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" [N]EXT PAGE [P]REV PAGE [B]ACK ANY OTHER KEY = MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
111
internal/shell/banking/screen_admin.go
Normal file
111
internal/shell/banking/screen_admin.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type adminModel struct {
|
||||
pin string
|
||||
attempts int
|
||||
locked bool
|
||||
}
|
||||
|
||||
func newAdminModel() adminModel {
|
||||
return adminModel{}
|
||||
}
|
||||
|
||||
func (m adminModel) Update(msg tea.Msg) (adminModel, tea.Cmd) {
|
||||
if m.locked {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
if m.pin != "" {
|
||||
m.attempts++
|
||||
if m.attempts >= 3 {
|
||||
m.locked = true
|
||||
}
|
||||
m.pin = ""
|
||||
}
|
||||
case tea.KeyBackspace:
|
||||
if len(m.pin) > 0 {
|
||||
m.pin = m.pin[:len(m.pin)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.pin) < 20 {
|
||||
m.pin += ch
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m adminModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("SYSTEM ADMINISTRATION"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if m.locked {
|
||||
b.WriteString(errorStyle.Render(" *** ACCESS DENIED ***"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(errorStyle.Render(" MAXIMUM AUTHENTICATION ATTEMPTS EXCEEDED"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(errorStyle.Render(" TERMINAL LOCKED - INCIDENT LOGGED"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(" SECURITY ALERT HAS BEEN DISPATCHED TO:"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" - INFORMATION SECURITY DEPT"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" - BRANCH SECURITY OFFICER"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" - FEDERAL RESERVE OVERSIGHT"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" INCIDENT REF: SEC-%d-ADMIN-BRUTE", 20240000+m.attempts)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" IEF4271I UNAUTHORIZED ACCESS ATTEMPT - ABEND S0C4"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" IEF4272I JOB SECADMIN STEP0001 - COND CODE 4088"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
} else {
|
||||
b.WriteString(titleStyle.Render(" RESTRICTED ACCESS - ADMINISTRATOR ONLY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(" THIS FUNCTION REQUIRES LEVEL 5 SECURITY CLEARANCE."))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" ALL ACCESS ATTEMPTS ARE LOGGED AND AUDITED."))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if m.attempts > 0 {
|
||||
b.WriteString(errorStyle.Render(fmt.Sprintf(" INVALID CREDENTIALS (%d OF 3 ATTEMPTS)", m.attempts)))
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
|
||||
b.WriteString(titleStyle.Render(" ADMIN PIN: "))
|
||||
masked := strings.Repeat("*", len(m.pin))
|
||||
b.WriteString(inputStyle.Render(masked))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
111
internal/shell/banking/screen_changepin.go
Normal file
111
internal/shell/banking/screen_changepin.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type changePinModel struct {
|
||||
input string
|
||||
stage int // 0=old, 1=new, 2=confirm, 3=done
|
||||
newPin string
|
||||
done bool
|
||||
}
|
||||
|
||||
func newChangePinModel() changePinModel {
|
||||
return changePinModel{}
|
||||
}
|
||||
|
||||
func (m changePinModel) Update(msg tea.Msg) (changePinModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.stage == 3 {
|
||||
m.done = true
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
switch m.stage {
|
||||
case 0:
|
||||
if m.input != "" {
|
||||
m.stage = 1
|
||||
m.input = ""
|
||||
}
|
||||
case 1:
|
||||
if len(m.input) >= 4 {
|
||||
m.newPin = m.input
|
||||
m.stage = 2
|
||||
m.input = ""
|
||||
}
|
||||
case 2:
|
||||
if m.input == m.newPin {
|
||||
m.stage = 3
|
||||
} else {
|
||||
m.input = ""
|
||||
m.newPin = ""
|
||||
m.stage = 1
|
||||
}
|
||||
}
|
||||
case tea.KeyEscape:
|
||||
m.done = true
|
||||
case tea.KeyBackspace:
|
||||
if len(m.input) > 0 {
|
||||
m.input = m.input[:len(m.input)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.input) < 12 {
|
||||
m.input += ch
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m changePinModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("CHANGE PIN"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if m.stage == 3 {
|
||||
b.WriteString(titleStyle.Render(" PIN CHANGED SUCCESSFULLY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(" YOUR NEW PIN IS NOW ACTIVE."))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" PLEASE USE YOUR NEW PIN FOR ALL FUTURE TRANSACTIONS."))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||
} else {
|
||||
prompts := []string{" CURRENT PIN: ", " NEW PIN: ", " CONFIRM PIN: "}
|
||||
for i := 0; i < m.stage; i++ {
|
||||
b.WriteString(baseStyle.Render(prompts[i]))
|
||||
b.WriteString(baseStyle.Render(strings.Repeat("*", 4)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
if m.stage < 3 {
|
||||
b.WriteString(titleStyle.Render(prompts[m.stage]))
|
||||
masked := strings.Repeat("*", len(m.input))
|
||||
b.WriteString(inputStyle.Render(masked))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
if m.stage == 1 {
|
||||
b.WriteString(dimStyle.Render(" PIN MUST BE AT LEAST 4 CHARACTERS"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
152
internal/shell/banking/screen_history.go
Normal file
152
internal/shell/banking/screen_history.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type historyModel struct {
|
||||
accounts []Account
|
||||
transactions map[string][]Transaction
|
||||
selected int
|
||||
page int
|
||||
pageSize int
|
||||
choosing bool
|
||||
choice string
|
||||
}
|
||||
|
||||
func newHistoryModel(accounts []Account, transactions map[string][]Transaction) historyModel {
|
||||
return historyModel{
|
||||
accounts: accounts,
|
||||
transactions: transactions,
|
||||
choosing: true,
|
||||
pageSize: 12,
|
||||
}
|
||||
}
|
||||
|
||||
func (m historyModel) Update(msg tea.Msg) (historyModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.choosing {
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
for i := range m.accounts {
|
||||
if m.choice == fmt.Sprintf("%d", i+1) {
|
||||
m.selected = i
|
||||
m.choosing = false
|
||||
m.page = 0
|
||||
break
|
||||
}
|
||||
}
|
||||
if m.choice == "0" {
|
||||
m.choice = "back"
|
||||
}
|
||||
return m, nil
|
||||
case tea.KeyBackspace:
|
||||
if len(m.choice) > 0 {
|
||||
m.choice = m.choice[:len(m.choice)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' {
|
||||
m.choice += ch
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch keyMsg.String() {
|
||||
case "n", "N":
|
||||
acctNum := m.accounts[m.selected].Number
|
||||
txns := m.transactions[acctNum]
|
||||
maxPage := (len(txns) - 1) / m.pageSize
|
||||
if m.page < maxPage {
|
||||
m.page++
|
||||
}
|
||||
case "p", "P":
|
||||
if m.page > 0 {
|
||||
m.page--
|
||||
}
|
||||
case "b", "B":
|
||||
m.choosing = true
|
||||
m.choice = ""
|
||||
default:
|
||||
m.choice = "back"
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m historyModel) View() string {
|
||||
if m.choosing {
|
||||
return m.viewChooseAccount()
|
||||
}
|
||||
return m.viewHistory()
|
||||
}
|
||||
|
||||
func (m historyModel) viewChooseAccount() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("TRANSACTION HISTORY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" SELECT ACCOUNT:"))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
for i, acct := range m.accounts {
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d] %s - %s",
|
||||
i+1, acct.Number, acct.Type)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
|
||||
b.WriteString(inputStyle.Render(m.choice))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m historyModel) viewHistory() string {
|
||||
var b strings.Builder
|
||||
|
||||
acct := m.accounts[m.selected]
|
||||
txns := m.transactions[acct.Number]
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText(fmt.Sprintf("TRANSACTION HISTORY - %s (%s)", acct.Number, acct.Type)))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-40s %14s", "DATE", "DESCRIPTION", "AMOUNT")))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 68)))
|
||||
b.WriteString("\n")
|
||||
|
||||
start := m.page * m.pageSize
|
||||
end := min(start+m.pageSize, len(txns))
|
||||
|
||||
for _, txn := range txns[start:end] {
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-40s %14s",
|
||||
txn.Date, txn.Description, formatCurrency(txn.Amount))))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
totalPages := (len(txns) + m.pageSize - 1) / m.pageSize
|
||||
b.WriteString(dimStyle.Render(fmt.Sprintf(" PAGE %d OF %d", m.page+1, totalPages)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" [N]EXT PAGE [P]REV PAGE [B]ACK ANY OTHER KEY = MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
103
internal/shell/banking/screen_login.go
Normal file
103
internal/shell/banking/screen_login.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type loginModel struct {
|
||||
accountNum string
|
||||
pin string
|
||||
stage int // 0 = account, 1 = pin, 2 = authenticating
|
||||
bankName string
|
||||
}
|
||||
|
||||
func newLoginModel(bankName string) loginModel {
|
||||
return loginModel{bankName: bankName}
|
||||
}
|
||||
|
||||
func (m loginModel) Update(msg tea.Msg) (loginModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
if m.stage == 0 && m.accountNum != "" {
|
||||
m.stage = 1
|
||||
return m, nil
|
||||
}
|
||||
if m.stage == 1 && m.pin != "" {
|
||||
m.stage = 2
|
||||
return m, nil
|
||||
}
|
||||
case tea.KeyBackspace:
|
||||
if m.stage == 0 && len(m.accountNum) > 0 {
|
||||
m.accountNum = m.accountNum[:len(m.accountNum)-1]
|
||||
} else if m.stage == 1 && len(m.pin) > 0 {
|
||||
m.pin = m.pin[:len(m.pin)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 {
|
||||
if m.stage == 0 && len(m.accountNum) < 20 {
|
||||
m.accountNum += ch
|
||||
} else if m.stage == 1 && len(m.pin) < 12 {
|
||||
m.pin += ch
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m loginModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText(m.bankName + " ONLINE BANKING"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("AUTHORIZED ACCESS ONLY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(titleStyle.Render(" ACCOUNT NUMBER: "))
|
||||
if m.stage == 0 {
|
||||
b.WriteString(inputStyle.Render(m.accountNum))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
} else {
|
||||
b.WriteString(baseStyle.Render(m.accountNum))
|
||||
}
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if m.stage >= 1 {
|
||||
b.WriteString(titleStyle.Render(" PIN: "))
|
||||
masked := strings.Repeat("*", len(m.pin))
|
||||
if m.stage == 1 {
|
||||
b.WriteString(inputStyle.Render(masked))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
} else {
|
||||
b.WriteString(baseStyle.Render(masked))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
if m.stage == 2 {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" AUTHENTICATING..."))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" WARNING: UNAUTHORIZED ACCESS TO THIS SYSTEM IS A FEDERAL CRIME"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(fmt.Sprintf(" UNDER 18 U.S.C. %s 1030. ALL ACTIVITY IS MONITORED AND LOGGED.", "\u00A7")))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
80
internal/shell/banking/screen_menu.go
Normal file
80
internal/shell/banking/screen_menu.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type menuModel struct {
|
||||
choice string
|
||||
unread int
|
||||
bankName string
|
||||
}
|
||||
|
||||
func newMenuModel(bankName string, unreadCount int) menuModel {
|
||||
return menuModel{bankName: bankName, unread: unreadCount}
|
||||
}
|
||||
|
||||
func (m menuModel) Update(msg tea.Msg) (menuModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
return m, nil
|
||||
case tea.KeyBackspace:
|
||||
if len(m.choice) > 0 {
|
||||
m.choice = m.choice[:len(m.choice)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 {
|
||||
if len(m.choice) < 10 {
|
||||
m.choice += ch
|
||||
}
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m menuModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("MAIN MENU"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
items := []struct {
|
||||
num string
|
||||
desc string
|
||||
}{
|
||||
{"1", "ACCOUNT SUMMARY"},
|
||||
{"2", "ACCOUNT DETAIL / TRANSACTIONS"},
|
||||
{"3", "WIRE TRANSFER"},
|
||||
{"4", "TRANSACTION HISTORY"},
|
||||
{"5", fmt.Sprintf("SECURE MESSAGES (%d UNREAD)", m.unread)},
|
||||
{"6", "CHANGE PIN"},
|
||||
{"7", "LOGOUT"},
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%s] %s", item.num, item.desc)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
|
||||
b.WriteString(inputStyle.Render(m.choice))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
122
internal/shell/banking/screen_messages.go
Normal file
122
internal/shell/banking/screen_messages.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type messagesModel struct {
|
||||
messages []SecureMessage
|
||||
viewing int // -1 = list, >= 0 = detail
|
||||
choice string
|
||||
}
|
||||
|
||||
func newMessagesModel(messages []SecureMessage) messagesModel {
|
||||
return messagesModel{messages: messages, viewing: -1}
|
||||
}
|
||||
|
||||
func (m messagesModel) Update(msg tea.Msg) (messagesModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.viewing >= 0 {
|
||||
// In detail view, any key goes back to list.
|
||||
m.viewing = -1
|
||||
m.choice = ""
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
if m.choice == "0" {
|
||||
m.choice = "back"
|
||||
return m, nil
|
||||
}
|
||||
for i := range m.messages {
|
||||
if m.choice == fmt.Sprintf("%d", i+1) {
|
||||
m.viewing = i
|
||||
m.messages[i].Unread = false
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
m.choice = ""
|
||||
case tea.KeyBackspace:
|
||||
if len(m.choice) > 0 {
|
||||
m.choice = m.choice[:len(m.choice)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' && len(m.choice) < 2 {
|
||||
m.choice += ch
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m messagesModel) View() string {
|
||||
if m.viewing >= 0 {
|
||||
return m.viewDetail()
|
||||
}
|
||||
return m.viewList()
|
||||
}
|
||||
|
||||
func (m messagesModel) viewList() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("SECURE MESSAGES"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-4s %-3s %-12s %-22s %s", "#", "", "DATE", "FROM", "SUBJECT")))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 68)))
|
||||
b.WriteString("\n")
|
||||
|
||||
for i, msg := range m.messages {
|
||||
marker := " "
|
||||
if msg.Unread {
|
||||
marker = " * "
|
||||
}
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d]%s%-12s %-22s %s",
|
||||
i+1, marker, msg.Date, msg.From, msg.Subj)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" SELECT MESSAGE: "))
|
||||
b.WriteString(inputStyle.Render(m.choice))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m messagesModel) viewDetail() string {
|
||||
var b strings.Builder
|
||||
|
||||
msg := m.messages[m.viewing]
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText(fmt.Sprintf("MESSAGE #%d", m.viewing+1)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(msg.Body))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MESSAGE LIST"))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
198
internal/shell/banking/screen_transfer.go
Normal file
198
internal/shell/banking/screen_transfer.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
const (
|
||||
transferStepRouting = iota
|
||||
transferStepDest
|
||||
transferStepBeneficiary
|
||||
transferStepBankName
|
||||
transferStepAmount
|
||||
transferStepMemo
|
||||
transferStepConfirm
|
||||
transferStepAuthCode
|
||||
transferStepComplete
|
||||
)
|
||||
|
||||
var transferPrompts = []string{
|
||||
" ROUTING NUMBER (ABA): ",
|
||||
" DESTINATION ACCOUNT: ",
|
||||
" BENEFICIARY NAME: ",
|
||||
" RECEIVING BANK NAME: ",
|
||||
" AMOUNT (USD): ",
|
||||
" MEMO / REFERENCE: ",
|
||||
"",
|
||||
" AUTHORIZATION CODE: ",
|
||||
}
|
||||
|
||||
type transferModel struct {
|
||||
step int
|
||||
fields [8]string // indexed by step
|
||||
transfer WireTransfer
|
||||
confirm string // y/n input for confirm step
|
||||
}
|
||||
|
||||
func newTransferModel() transferModel {
|
||||
return transferModel{}
|
||||
}
|
||||
|
||||
func (m transferModel) Update(msg tea.Msg) (transferModel, tea.Cmd) {
|
||||
keyMsg, ok := msg.(tea.KeyMsg)
|
||||
if !ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.step == transferStepComplete {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if m.step == transferStepConfirm {
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
switch strings.ToUpper(m.confirm) {
|
||||
case "Y", "YES":
|
||||
m.step = transferStepAuthCode
|
||||
case "N", "NO":
|
||||
m.confirm = "cancelled"
|
||||
}
|
||||
return m, nil
|
||||
case tea.KeyBackspace:
|
||||
if len(m.confirm) > 0 {
|
||||
m.confirm = m.confirm[:len(m.confirm)-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.confirm) < 3 {
|
||||
m.confirm += ch
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch keyMsg.Type {
|
||||
case tea.KeyEnter:
|
||||
val := strings.TrimSpace(m.fields[m.step])
|
||||
if val == "" {
|
||||
return m, nil
|
||||
}
|
||||
if m.step == transferStepAuthCode {
|
||||
m.transfer.AuthCode = val
|
||||
m.step = transferStepComplete
|
||||
return m, nil
|
||||
}
|
||||
switch m.step {
|
||||
case transferStepRouting:
|
||||
m.transfer.RoutingNumber = val
|
||||
case transferStepDest:
|
||||
m.transfer.DestAccount = val
|
||||
case transferStepBeneficiary:
|
||||
m.transfer.Beneficiary = val
|
||||
case transferStepBankName:
|
||||
m.transfer.BankName = val
|
||||
case transferStepAmount:
|
||||
m.transfer.Amount = val
|
||||
case transferStepMemo:
|
||||
m.transfer.Memo = val
|
||||
}
|
||||
m.step++
|
||||
return m, nil
|
||||
case tea.KeyBackspace:
|
||||
if len(m.fields[m.step]) > 0 {
|
||||
m.fields[m.step] = m.fields[m.step][:len(m.fields[m.step])-1]
|
||||
}
|
||||
default:
|
||||
ch := keyMsg.String()
|
||||
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.fields[m.step]) < 40 {
|
||||
m.fields[m.step] += ch
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m transferModel) View() string {
|
||||
var b strings.Builder
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("WIRE TRANSFER"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Show completed fields.
|
||||
for i := 0; i < m.step && i < len(transferPrompts); i++ {
|
||||
if i == transferStepConfirm {
|
||||
continue
|
||||
}
|
||||
prompt := transferPrompts[i]
|
||||
if prompt == "" {
|
||||
continue
|
||||
}
|
||||
b.WriteString(baseStyle.Render(prompt))
|
||||
b.WriteString(baseStyle.Render(m.fields[i]))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Current field.
|
||||
switch {
|
||||
case m.step == transferStepConfirm:
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" === TRANSFER SUMMARY ==="))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" ROUTING: %s", m.transfer.RoutingNumber)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" ACCOUNT: %s", m.transfer.DestAccount)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" BENEFICIARY: %s", m.transfer.Beneficiary)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" BANK: %s", m.transfer.BankName)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" AMOUNT: $%s", m.transfer.Amount)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" MEMO: %s", m.transfer.Memo)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" CONFIRM TRANSFER? (Y/N): "))
|
||||
b.WriteString(inputStyle.Render(m.confirm))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
case m.step == transferStepComplete:
|
||||
b.WriteString("\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(titleStyle.Render(" TRANSFER QUEUED FOR PROCESSING"))
|
||||
b.WriteString("\n\n")
|
||||
routing := m.transfer.RoutingNumber
|
||||
if len(routing) > 4 {
|
||||
routing = routing[:4]
|
||||
}
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" CONFIRMATION #: WR-%s-%s",
|
||||
routing, "847291")))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" STATUS: PENDING FEDWIRE SETTLEMENT"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(" ESTIMATED COMPLETION: NEXT BUSINESS DAY"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
|
||||
b.WriteString("\n")
|
||||
case m.step < len(transferPrompts):
|
||||
prompt := transferPrompts[m.step]
|
||||
b.WriteString(titleStyle.Render(prompt))
|
||||
b.WriteString(inputStyle.Render(m.fields[m.step]))
|
||||
b.WriteString(inputStyle.Render("_"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
if m.step < transferStepConfirm {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(thinDivider())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(fmt.Sprintf(" STEP %d OF 6", m.step+1)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
168
internal/shell/banking/style.go
Normal file
168
internal/shell/banking/style.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package banking
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const termWidth = 80
|
||||
|
||||
// Color palette — green-on-black retro terminal.
|
||||
var (
|
||||
colorGreen = lipgloss.Color("#00FF00")
|
||||
colorDim = lipgloss.Color("#007700")
|
||||
colorBlack = lipgloss.Color("#000000")
|
||||
colorBright = lipgloss.Color("#AAFFAA")
|
||||
colorRed = lipgloss.Color("#FF3333")
|
||||
)
|
||||
|
||||
// Reusable styles.
|
||||
var (
|
||||
baseStyle = lipgloss.NewStyle().
|
||||
Foreground(colorGreen).
|
||||
Background(colorBlack)
|
||||
|
||||
headerStyle = lipgloss.NewStyle().
|
||||
Foreground(colorBright).
|
||||
Background(colorBlack).
|
||||
Bold(true).
|
||||
Width(termWidth).
|
||||
Align(lipgloss.Center)
|
||||
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Foreground(colorGreen).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
|
||||
dimStyle = lipgloss.NewStyle().
|
||||
Foreground(colorDim).
|
||||
Background(colorBlack)
|
||||
|
||||
errorStyle = lipgloss.NewStyle().
|
||||
Foreground(colorRed).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
|
||||
inputStyle = lipgloss.NewStyle().
|
||||
Foreground(colorBright).
|
||||
Background(colorBlack)
|
||||
)
|
||||
|
||||
// divider returns an 80-column === line.
|
||||
func divider() string {
|
||||
return dimStyle.Render(strings.Repeat("=", termWidth))
|
||||
}
|
||||
|
||||
// thinDivider returns an 80-column --- line.
|
||||
func thinDivider() string {
|
||||
return dimStyle.Render(strings.Repeat("-", termWidth))
|
||||
}
|
||||
|
||||
// centerText centers text within 80 columns.
|
||||
func centerText(s string) string {
|
||||
return headerStyle.Render(s)
|
||||
}
|
||||
|
||||
// padRight pads a string to the given width.
|
||||
func padRight(s string, width int) string {
|
||||
if len(s) >= width {
|
||||
return s[:width]
|
||||
}
|
||||
return s + strings.Repeat(" ", width-len(s))
|
||||
}
|
||||
|
||||
// formatCurrency formats cents as $X,XXX.XX
|
||||
func formatCurrency(cents int64) string {
|
||||
negative := cents < 0
|
||||
if negative {
|
||||
cents = -cents
|
||||
}
|
||||
dollars := cents / 100
|
||||
remainder := cents % 100
|
||||
|
||||
// Add thousands separators.
|
||||
ds := fmt.Sprintf("%d", dollars)
|
||||
if len(ds) > 3 {
|
||||
var parts []string
|
||||
for len(ds) > 3 {
|
||||
parts = append([]string{ds[len(ds)-3:]}, parts...)
|
||||
ds = ds[:len(ds)-3]
|
||||
}
|
||||
parts = append([]string{ds}, parts...)
|
||||
ds = strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
if negative {
|
||||
return fmt.Sprintf("-$%s.%02d", ds, remainder)
|
||||
}
|
||||
return fmt.Sprintf("$%s.%02d", ds, remainder)
|
||||
}
|
||||
|
||||
// padLine pads a single line (which may contain ANSI codes) to termWidth
|
||||
// using its visual width. Padding uses a black background so the terminal's
|
||||
// default background doesn't bleed through.
|
||||
func padLine(line string) string {
|
||||
w := lipgloss.Width(line)
|
||||
if w >= termWidth {
|
||||
return line
|
||||
}
|
||||
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
|
||||
}
|
||||
|
||||
// padLines pads every line in a multi-line string to termWidth so that
|
||||
// shorter lines fully overwrite previous content in the terminal.
|
||||
func padLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = padLine(line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// screenFrame wraps content in the persistent header and footer.
|
||||
// The height parameter is used to pad the output to fill the terminal,
|
||||
// preventing leftover lines from previous renders bleeding through.
|
||||
func screenFrame(bankName, terminalID, region, content string, height int) string {
|
||||
var b strings.Builder
|
||||
|
||||
// Header (4 lines).
|
||||
b.WriteString(divider())
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText(bankName + " FEDERAL RESERVE SYSTEM"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(centerText("SECURE BANKING TERMINAL"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(divider())
|
||||
b.WriteString("\n")
|
||||
|
||||
// Content.
|
||||
b.WriteString(content)
|
||||
|
||||
// Pad with blank lines between content and footer so the footer
|
||||
// stays at the bottom and the total output fills the terminal height.
|
||||
if height > 0 {
|
||||
const headerLines = 4
|
||||
const footerLines = 2
|
||||
// strings.Count gives newlines; add 1 for the line after the last \n.
|
||||
contentLines := strings.Count(content, "\n") + 1
|
||||
used := headerLines + contentLines + footerLines
|
||||
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
|
||||
for i := used; i < height; i++ {
|
||||
b.WriteString(blankLine)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Footer (2 lines).
|
||||
b.WriteString("\n")
|
||||
b.WriteString(divider())
|
||||
b.WriteString("\n")
|
||||
footer := fmt.Sprintf(" TERMINAL: %s | REGION: %s | ENCRYPTED SESSION ACTIVE", terminalID, region)
|
||||
b.WriteString(dimStyle.Render(padRight(footer, termWidth)))
|
||||
|
||||
// Pad every line to full terminal width so shorter lines overwrite
|
||||
// leftover content from previous renders.
|
||||
return padLines(b.String())
|
||||
}
|
||||
@@ -2,12 +2,13 @@ package bash
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
@@ -54,8 +55,8 @@ func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw i
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := readLine(ctx, rw)
|
||||
if err == io.EOF {
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
fmt.Fprint(rw, "logout\r\n")
|
||||
return nil
|
||||
}
|
||||
@@ -81,7 +82,12 @@ func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw i
|
||||
|
||||
// Log command and output to store.
|
||||
if sess.Store != nil {
|
||||
sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output)
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("bash")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
@@ -100,59 +106,3 @@ func formatPrompt(state *shellState) string {
|
||||
return fmt.Sprintf("%s@%s:%s# ", state.username, state.hostname, cwd)
|
||||
}
|
||||
|
||||
// readLine reads a line of input byte-by-byte, handling backspace, Ctrl+C, and Ctrl+D.
|
||||
func readLine(ctx context.Context, rw io.ReadWriter) (string, error) {
|
||||
var buf []byte
|
||||
b := make([]byte, 1)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := rw.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
ch := b[0]
|
||||
switch {
|
||||
case ch == '\r' || ch == '\n':
|
||||
fmt.Fprint(rw, "\r\n")
|
||||
return string(buf), nil
|
||||
|
||||
case ch == 4: // Ctrl+D
|
||||
if len(buf) == 0 {
|
||||
return "", io.EOF
|
||||
}
|
||||
|
||||
case ch == 3: // Ctrl+C
|
||||
fmt.Fprint(rw, "^C\r\n")
|
||||
return "", nil
|
||||
|
||||
case ch == 127 || ch == 8: // DEL or Backspace
|
||||
if len(buf) > 0 {
|
||||
buf = buf[:len(buf)-1]
|
||||
fmt.Fprint(rw, "\b \b")
|
||||
}
|
||||
|
||||
case ch == 27: // ESC - start of escape sequence
|
||||
// Read and discard the rest of the escape sequence.
|
||||
// Most are 3 bytes: ESC [ X (arrow keys, etc.)
|
||||
next := make([]byte, 1)
|
||||
rw.Read(next)
|
||||
if next[0] == '[' {
|
||||
rw.Read(next) // read the final byte
|
||||
}
|
||||
|
||||
case ch >= 32 && ch < 127: // printable ASCII
|
||||
buf = append(buf, ch)
|
||||
rw.Write([]byte{ch})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,13 +3,14 @@ package bash
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
@@ -52,7 +53,7 @@ func TestReadLineEnter(t *testing.T) {
|
||||
}{input, &output}
|
||||
|
||||
ctx := context.Background()
|
||||
line, err := readLine(ctx, rw)
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if err != nil {
|
||||
t.Fatalf("readLine: %v", err)
|
||||
}
|
||||
@@ -71,7 +72,7 @@ func TestReadLineBackspace(t *testing.T) {
|
||||
}{input, &output}
|
||||
|
||||
ctx := context.Background()
|
||||
line, err := readLine(ctx, rw)
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if err != nil {
|
||||
t.Fatalf("readLine: %v", err)
|
||||
}
|
||||
@@ -89,7 +90,7 @@ func TestReadLineCtrlC(t *testing.T) {
|
||||
}{input, &output}
|
||||
|
||||
ctx := context.Background()
|
||||
line, err := readLine(ctx, rw)
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if err != nil {
|
||||
t.Fatalf("readLine: %v", err)
|
||||
}
|
||||
@@ -107,15 +108,15 @@ func TestReadLineCtrlD(t *testing.T) {
|
||||
}{input, &output}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := readLine(ctx, rw)
|
||||
if err != io.EOF {
|
||||
_, err := shell.ReadLine(ctx, rw)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatalf("expected io.EOF, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashShellHandle(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash")
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
@@ -165,7 +166,7 @@ func TestBashShellHandle(t *testing.T) {
|
||||
|
||||
func TestBashShellFakeUser(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash")
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
|
||||
206
internal/shell/cisco/cisco.go
Normal file
206
internal/shell/cisco/cisco.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package cisco
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// CiscoShell emulates a Cisco IOS CLI.
|
||||
type CiscoShell struct{}
|
||||
|
||||
// NewCiscoShell returns a new CiscoShell instance.
|
||||
func NewCiscoShell() *CiscoShell {
|
||||
return &CiscoShell{}
|
||||
}
|
||||
|
||||
func (c *CiscoShell) Name() string { return "cisco" }
|
||||
func (c *CiscoShell) Description() string { return "Cisco IOS CLI emulator" }
|
||||
|
||||
func (c *CiscoShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
hostname := configString(sess.ShellConfig, "hostname", "Router")
|
||||
model := configString(sess.ShellConfig, "model", "C2960")
|
||||
iosVersion := configString(sess.ShellConfig, "ios_version", "15.0(2)SE11")
|
||||
enablePass := configString(sess.ShellConfig, "enable_password", "")
|
||||
|
||||
state := newIOSState(hostname, model, iosVersion, enablePass)
|
||||
|
||||
// IOS just shows a blank line then the prompt after SSH auth.
|
||||
fmt.Fprint(rw, "\r\n")
|
||||
|
||||
for {
|
||||
prompt := state.prompt()
|
||||
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for Ctrl+Z (^Z) — return to privileged exec.
|
||||
if trimmed == "\x1a" || trimmed == "^Z" {
|
||||
if state.mode == modeGlobalConfig || state.mode == modeInterfaceConfig {
|
||||
state.mode = modePrivilegedExec
|
||||
state.currentIf = ""
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle "enable" specially — it needs password prompting.
|
||||
if state.mode == modeUserExec && isEnableCommand(trimmed) {
|
||||
output := handleEnable(ctx, state, rw)
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("cisco")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
result := state.dispatch(trimmed)
|
||||
|
||||
var output string
|
||||
if result.output != "" {
|
||||
output = result.output
|
||||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("cisco")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isEnableCommand checks if input resolves to "enable" in user exec mode.
|
||||
func isEnableCommand(input string) bool {
|
||||
words := strings.Fields(input)
|
||||
if len(words) != 1 {
|
||||
return false
|
||||
}
|
||||
w := strings.ToLower(words[0])
|
||||
enable := "enable"
|
||||
return len(w) >= 2 && len(w) <= len(enable) && enable[:len(w)] == w
|
||||
}
|
||||
|
||||
// handleEnable manages the enable password prompt flow.
|
||||
// Returns the output string (for logging).
|
||||
func handleEnable(ctx context.Context, state *iosState, rw io.ReadWriter) string {
|
||||
const maxAttempts = 3
|
||||
hadFailure := false
|
||||
|
||||
for range maxAttempts {
|
||||
fmt.Fprint(rw, "Password: ")
|
||||
password, err := readPassword(ctx, rw)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
fmt.Fprint(rw, "\r\n")
|
||||
|
||||
if state.enablePass == "" {
|
||||
// No password configured — accept after one failed attempt.
|
||||
if hadFailure {
|
||||
state.mode = modePrivilegedExec
|
||||
return ""
|
||||
}
|
||||
hadFailure = true
|
||||
} else if password == state.enablePass {
|
||||
state.mode = modePrivilegedExec
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
output := "% Bad passwords"
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
return output
|
||||
}
|
||||
|
||||
// readPassword reads a password without echoing characters.
|
||||
func readPassword(ctx context.Context, rw io.ReadWriter) (string, error) {
|
||||
var buf []byte
|
||||
b := make([]byte, 1)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := rw.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
ch := b[0]
|
||||
switch {
|
||||
case ch == '\r' || ch == '\n':
|
||||
return string(buf), nil
|
||||
case ch == 4: // Ctrl+D
|
||||
return string(buf), io.EOF
|
||||
case ch == 3: // Ctrl+C
|
||||
return "", io.EOF
|
||||
case ch == 127 || ch == 8: // Backspace/DEL
|
||||
if len(buf) > 0 {
|
||||
buf = buf[:len(buf)-1]
|
||||
}
|
||||
case ch == 27: // ESC sequence
|
||||
next := make([]byte, 1)
|
||||
if n, _ := rw.Read(next); n > 0 && next[0] == '[' {
|
||||
rw.Read(next)
|
||||
}
|
||||
case ch >= 32 && ch < 127:
|
||||
buf = append(buf, ch)
|
||||
// Don't echo.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
531
internal/shell/cisco/cisco_test.go
Normal file
531
internal/shell/cisco/cisco_test.go
Normal file
@@ -0,0 +1,531 @@
|
||||
package cisco
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- Abbreviation resolution tests ---
|
||||
|
||||
func TestResolveAbbreviationExact(t *testing.T) {
|
||||
entries := []commandEntry{
|
||||
{name: "show"},
|
||||
{name: "shutdown"},
|
||||
}
|
||||
got, err := resolveAbbreviation("show", entries)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "show" {
|
||||
t.Errorf("got %q, want %q", got, "show")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAbbreviationUnique(t *testing.T) {
|
||||
entries := []commandEntry{
|
||||
{name: "show"},
|
||||
{name: "enable"},
|
||||
{name: "exit"},
|
||||
}
|
||||
got, err := resolveAbbreviation("sh", entries)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "show" {
|
||||
t.Errorf("got %q, want %q", got, "show")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAbbreviationAmbiguous(t *testing.T) {
|
||||
entries := []commandEntry{
|
||||
{name: "show"},
|
||||
{name: "shutdown"},
|
||||
}
|
||||
_, err := resolveAbbreviation("sh", entries)
|
||||
if err == nil {
|
||||
t.Fatal("expected ambiguous error, got nil")
|
||||
}
|
||||
if err.Error() != "ambiguous" {
|
||||
t.Errorf("got error %q, want %q", err.Error(), "ambiguous")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAbbreviationUnknown(t *testing.T) {
|
||||
entries := []commandEntry{
|
||||
{name: "show"},
|
||||
{name: "enable"},
|
||||
}
|
||||
_, err := resolveAbbreviation("xyz", entries)
|
||||
if err == nil {
|
||||
t.Fatal("expected unknown error, got nil")
|
||||
}
|
||||
if err.Error() != "unknown" {
|
||||
t.Errorf("got error %q, want %q", err.Error(), "unknown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveAbbreviationCaseInsensitive(t *testing.T) {
|
||||
entries := []commandEntry{
|
||||
{name: "show"},
|
||||
{name: "enable"},
|
||||
}
|
||||
got, err := resolveAbbreviation("SH", entries)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "show" {
|
||||
t.Errorf("got %q, want %q", got, "show")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Multi-word command resolution tests ---
|
||||
|
||||
func TestResolveCommandShowRunningConfig(t *testing.T) {
|
||||
resolved, args, err := resolveCommand([]string{"sh", "run"}, privilegedExecCommands)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(args) != 0 {
|
||||
t.Errorf("unexpected args: %v", args)
|
||||
}
|
||||
want := []string{"show", "running-config"}
|
||||
if len(resolved) != len(want) {
|
||||
t.Fatalf("resolved = %v, want %v", resolved, want)
|
||||
}
|
||||
for i := range want {
|
||||
if resolved[i] != want[i] {
|
||||
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCommandConfigureTerminal(t *testing.T) {
|
||||
resolved, _, err := resolveCommand([]string{"conf", "t"}, privilegedExecCommands)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := []string{"configure", "terminal"}
|
||||
if len(resolved) != len(want) {
|
||||
t.Fatalf("resolved = %v, want %v", resolved, want)
|
||||
}
|
||||
for i := range want {
|
||||
if resolved[i] != want[i] {
|
||||
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCommandShowIPInterfaceBrief(t *testing.T) {
|
||||
resolved, _, err := resolveCommand([]string{"sh", "ip", "int", "br"}, privilegedExecCommands)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
want := []string{"show", "ip", "interface", "brief"}
|
||||
if len(resolved) != len(want) {
|
||||
t.Fatalf("resolved = %v, want %v", resolved, want)
|
||||
}
|
||||
for i := range want {
|
||||
if resolved[i] != want[i] {
|
||||
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCommandWithArgs(t *testing.T) {
|
||||
// "hostname MyRouter" → resolved=["hostname"], args=["MyRouter"]
|
||||
resolved, args, err := resolveCommand([]string{"hostname", "MyRouter"}, globalConfigCommands)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(resolved) != 1 || resolved[0] != "hostname" {
|
||||
t.Errorf("resolved = %v, want [hostname]", resolved)
|
||||
}
|
||||
if len(args) != 1 || args[0] != "MyRouter" {
|
||||
t.Errorf("args = %v, want [MyRouter]", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCommandAmbiguous(t *testing.T) {
|
||||
// In user exec, "e" matches "enable" and "exit" — ambiguous
|
||||
_, _, err := resolveCommand([]string{"e"}, userExecCommands)
|
||||
if err == nil {
|
||||
t.Fatal("expected ambiguous error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Mode state machine tests ---
|
||||
|
||||
func TestPromptGeneration(t *testing.T) {
|
||||
tests := []struct {
|
||||
mode iosMode
|
||||
want string
|
||||
}{
|
||||
{modeUserExec, "Router>"},
|
||||
{modePrivilegedExec, "Router#"},
|
||||
{modeGlobalConfig, "Router(config)#"},
|
||||
{modeInterfaceConfig, "Router(config-if)#"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = tt.mode
|
||||
if got := s.prompt(); got != tt.want {
|
||||
t.Errorf("prompt(%d) = %q, want %q", tt.mode, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptAfterHostnameChange(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modeGlobalConfig
|
||||
s.dispatch("hostname Switch1")
|
||||
if s.hostname != "Switch1" {
|
||||
t.Fatalf("hostname = %q, want %q", s.hostname, "Switch1")
|
||||
}
|
||||
if got := s.prompt(); got != "Switch1(config)#" {
|
||||
t.Errorf("prompt = %q, want %q", got, "Switch1(config)#")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModeTransitions(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
|
||||
// Start in user exec.
|
||||
if s.mode != modeUserExec {
|
||||
t.Fatalf("initial mode = %d, want %d", s.mode, modeUserExec)
|
||||
}
|
||||
|
||||
// Can't skip to config mode directly from user exec.
|
||||
result := s.dispatch("configure terminal")
|
||||
if result.output == "" {
|
||||
t.Error("expected error for conf t in user exec mode")
|
||||
}
|
||||
|
||||
// Manually set privileged mode (enable tested separately).
|
||||
s.mode = modePrivilegedExec
|
||||
|
||||
// conf t → global config
|
||||
s.dispatch("configure terminal")
|
||||
if s.mode != modeGlobalConfig {
|
||||
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
|
||||
}
|
||||
|
||||
// interface Gi0/0 → interface config
|
||||
s.dispatch("interface GigabitEthernet0/0")
|
||||
if s.mode != modeInterfaceConfig {
|
||||
t.Errorf("mode after interface = %d, want %d", s.mode, modeInterfaceConfig)
|
||||
}
|
||||
|
||||
// exit → back to global config
|
||||
s.dispatch("exit")
|
||||
if s.mode != modeGlobalConfig {
|
||||
t.Errorf("mode after exit from if-config = %d, want %d", s.mode, modeGlobalConfig)
|
||||
}
|
||||
|
||||
// end → back to privileged exec
|
||||
s.dispatch("end")
|
||||
if s.mode != modePrivilegedExec {
|
||||
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
|
||||
}
|
||||
|
||||
// disable → back to user exec
|
||||
s.dispatch("disable")
|
||||
if s.mode != modeUserExec {
|
||||
t.Errorf("mode after disable = %d, want %d", s.mode, modeUserExec)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEndFromInterfaceConfig(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modeInterfaceConfig
|
||||
s.currentIf = "GigabitEthernet0/0"
|
||||
|
||||
s.dispatch("end")
|
||||
if s.mode != modePrivilegedExec {
|
||||
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
|
||||
}
|
||||
if s.currentIf != "" {
|
||||
t.Errorf("currentIf = %q, want empty", s.currentIf)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitFromPrivilegedExec(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modePrivilegedExec
|
||||
result := s.dispatch("exit")
|
||||
if !result.exit {
|
||||
t.Error("expected exit=true from privileged exec exit")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Show command output tests ---
|
||||
|
||||
func TestShowVersionContainsModel(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
output := showVersion(s)
|
||||
if !contains(output, "C2960") {
|
||||
t.Error("show version missing model")
|
||||
}
|
||||
if !contains(output, "15.0(2)SE11") {
|
||||
t.Error("show version missing IOS version")
|
||||
}
|
||||
if !contains(output, "Router") {
|
||||
t.Error("show version missing hostname")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowRunningConfigContainsInterfaces(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
output := showRunningConfig(s)
|
||||
if !contains(output, "hostname Router") {
|
||||
t.Error("running-config missing hostname")
|
||||
}
|
||||
if !contains(output, "interface GigabitEthernet0/0") {
|
||||
t.Error("running-config missing interface")
|
||||
}
|
||||
if !contains(output, "ip address 192.168.1.1") {
|
||||
t.Error("running-config missing IP address")
|
||||
}
|
||||
if !contains(output, "line vty") {
|
||||
t.Error("running-config missing VTY config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowRunningConfigWithEnableSecret(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "secret123")
|
||||
output := showRunningConfig(s)
|
||||
if !contains(output, "enable secret") {
|
||||
t.Error("running-config missing enable secret when password is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowRunningConfigWithoutEnableSecret(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
output := showRunningConfig(s)
|
||||
if contains(output, "enable secret") {
|
||||
t.Error("running-config should not have enable secret when password is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowIPInterfaceBrief(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
output := showIPInterfaceBrief(s)
|
||||
if !contains(output, "GigabitEthernet0/0") {
|
||||
t.Error("ip interface brief missing GigabitEthernet0/0")
|
||||
}
|
||||
if !contains(output, "192.168.1.1") {
|
||||
t.Error("ip interface brief missing 192.168.1.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowIPRoute(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
output := showIPRoute(s)
|
||||
if !contains(output, "directly connected") {
|
||||
t.Error("ip route missing connected routes")
|
||||
}
|
||||
if !contains(output, "0.0.0.0/0") {
|
||||
t.Error("ip route missing default route")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowVLANBrief(t *testing.T) {
|
||||
output := showVLANBrief()
|
||||
if !contains(output, "default") {
|
||||
t.Error("vlan brief missing default vlan")
|
||||
}
|
||||
if !contains(output, "MGMT") {
|
||||
t.Error("vlan brief missing MGMT vlan")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Interface config tests ---
|
||||
|
||||
func TestInterfaceShutdownNoShutdown(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modeInterfaceConfig
|
||||
s.currentIf = "GigabitEthernet0/0"
|
||||
|
||||
s.dispatch("shutdown")
|
||||
iface := s.findInterface("GigabitEthernet0/0")
|
||||
if iface == nil {
|
||||
t.Fatal("interface not found")
|
||||
}
|
||||
if !iface.shutdown {
|
||||
t.Error("interface should be shutdown")
|
||||
}
|
||||
if iface.status != "administratively down" {
|
||||
t.Errorf("status = %q, want %q", iface.status, "administratively down")
|
||||
}
|
||||
|
||||
s.dispatch("no shutdown")
|
||||
if iface.shutdown {
|
||||
t.Error("interface should not be shutdown after no shutdown")
|
||||
}
|
||||
if iface.status != "up" {
|
||||
t.Errorf("status = %q, want %q", iface.status, "up")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterfaceIPAddress(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modeInterfaceConfig
|
||||
s.currentIf = "GigabitEthernet0/0"
|
||||
|
||||
s.dispatch("ip address 10.10.10.1 255.255.255.0")
|
||||
iface := s.findInterface("GigabitEthernet0/0")
|
||||
if iface == nil {
|
||||
t.Fatal("interface not found")
|
||||
}
|
||||
if iface.ip != "10.10.10.1" {
|
||||
t.Errorf("ip = %q, want %q", iface.ip, "10.10.10.1")
|
||||
}
|
||||
if iface.mask != "255.255.255.0" {
|
||||
t.Errorf("mask = %q, want %q", iface.mask, "255.255.255.0")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Dispatch / invalid command tests ---
|
||||
|
||||
func TestInvalidCommandInUserExec(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
result := s.dispatch("foobar")
|
||||
if !contains(result.output, "Invalid input") {
|
||||
t.Errorf("expected invalid input error, got %q", result.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmbiguousCommandOutput(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
// "e" in user exec is ambiguous (enable, exit)
|
||||
result := s.dispatch("e")
|
||||
if !contains(result.output, "Ambiguous") {
|
||||
t.Errorf("expected ambiguous error, got %q", result.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelpCommand(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
result := s.dispatch("?")
|
||||
if !contains(result.output, "show") {
|
||||
t.Error("help missing 'show'")
|
||||
}
|
||||
if !contains(result.output, "enable") {
|
||||
t.Error("help missing 'enable'")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Abbreviation integration tests ---
|
||||
|
||||
func TestShowAbbreviationInDispatch(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modePrivilegedExec
|
||||
result := s.dispatch("sh ver")
|
||||
if !contains(result.output, "Cisco IOS Software") {
|
||||
t.Error("'sh ver' should produce version output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfTAbbreviation(t *testing.T) {
|
||||
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
|
||||
s.mode = modePrivilegedExec
|
||||
s.dispatch("conf t")
|
||||
if s.mode != modeGlobalConfig {
|
||||
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Enable command detection ---
|
||||
|
||||
func TestIsEnableCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"enable", true},
|
||||
{"en", true},
|
||||
{"ena", true},
|
||||
{"e", false}, // too short (single char could be other commands)
|
||||
{"enab", true},
|
||||
{"ENABLE", true},
|
||||
{"exit", false},
|
||||
{"enable 15", false}, // has extra argument
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := isEnableCommand(tt.input); got != tt.want {
|
||||
t.Errorf("isEnableCommand(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- configString tests ---
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{"hostname": "MySwitch"}
|
||||
if got := configString(cfg, "hostname", "Router"); got != "MySwitch" {
|
||||
t.Errorf("configString() = %q, want %q", got, "MySwitch")
|
||||
}
|
||||
if got := configString(cfg, "missing", "Default"); got != "Default" {
|
||||
t.Errorf("configString() for missing = %q, want %q", got, "Default")
|
||||
}
|
||||
if got := configString(nil, "key", "Default"); got != "Default" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "Default")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper ---
|
||||
|
||||
func TestMaskBits(t *testing.T) {
|
||||
tests := []struct {
|
||||
mask string
|
||||
want int
|
||||
}{
|
||||
{"255.255.255.0", 24},
|
||||
{"255.255.255.252", 30},
|
||||
{"255.255.0.0", 16},
|
||||
{"255.0.0.0", 8},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := maskBits(tt.mask); got != tt.want {
|
||||
t.Errorf("maskBits(%q) = %d, want %d", tt.mask, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkFromIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip, mask, want string
|
||||
}{
|
||||
{"192.168.1.1", "255.255.255.0", "192.168.1.0"},
|
||||
{"10.0.0.1", "255.255.255.252", "10.0.0.0"},
|
||||
{"172.16.5.100", "255.255.0.0", "172.16.0.0"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := networkFromIP(tt.ip, tt.mask); got != tt.want {
|
||||
t.Errorf("networkFromIP(%q, %q) = %q, want %q", tt.ip, tt.mask, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Shell metadata ---
|
||||
|
||||
func TestShellNameAndDescription(t *testing.T) {
|
||||
s := NewCiscoShell()
|
||||
if s.Name() != "cisco" {
|
||||
t.Errorf("Name() = %q, want %q", s.Name(), "cisco")
|
||||
}
|
||||
if s.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && containsHelper(s, substr)
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
414
internal/shell/cisco/commands.go
Normal file
414
internal/shell/cisco/commands.go
Normal file
@@ -0,0 +1,414 @@
|
||||
package cisco
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// commandResult holds the output of a command and whether the session should end.
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
// commandEntry defines a single command with its name and optional sub-commands.
|
||||
type commandEntry struct {
|
||||
name string
|
||||
subs []commandEntry // nil for leaf commands
|
||||
}
|
||||
|
||||
// userExecCommands defines the command tree for user EXEC mode.
|
||||
var userExecCommands = []commandEntry{
|
||||
{name: "show", subs: []commandEntry{
|
||||
{name: "version"},
|
||||
{name: "clock"},
|
||||
{name: "ip", subs: []commandEntry{
|
||||
{name: "route"},
|
||||
{name: "interface", subs: []commandEntry{
|
||||
{name: "brief"},
|
||||
}},
|
||||
}},
|
||||
{name: "interfaces"},
|
||||
{name: "vlan", subs: []commandEntry{
|
||||
{name: "brief"},
|
||||
}},
|
||||
}},
|
||||
{name: "enable"},
|
||||
{name: "exit"},
|
||||
{name: "?"},
|
||||
}
|
||||
|
||||
// privilegedExecCommands extends user commands for privileged mode.
|
||||
var privilegedExecCommands = []commandEntry{
|
||||
{name: "show", subs: []commandEntry{
|
||||
{name: "version"},
|
||||
{name: "clock"},
|
||||
{name: "ip", subs: []commandEntry{
|
||||
{name: "route"},
|
||||
{name: "interface", subs: []commandEntry{
|
||||
{name: "brief"},
|
||||
}},
|
||||
}},
|
||||
{name: "interfaces"},
|
||||
{name: "running-config"},
|
||||
{name: "startup-config"},
|
||||
{name: "vlan", subs: []commandEntry{
|
||||
{name: "brief"},
|
||||
}},
|
||||
}},
|
||||
{name: "configure", subs: []commandEntry{
|
||||
{name: "terminal"},
|
||||
}},
|
||||
{name: "write", subs: []commandEntry{
|
||||
{name: "memory"},
|
||||
}},
|
||||
{name: "copy"},
|
||||
{name: "reload"},
|
||||
{name: "disable"},
|
||||
{name: "terminal", subs: []commandEntry{
|
||||
{name: "length"},
|
||||
}},
|
||||
{name: "exit"},
|
||||
{name: "?"},
|
||||
}
|
||||
|
||||
// globalConfigCommands defines the command tree for global config mode.
|
||||
var globalConfigCommands = []commandEntry{
|
||||
{name: "hostname"},
|
||||
{name: "interface"},
|
||||
{name: "ip", subs: []commandEntry{
|
||||
{name: "route"},
|
||||
}},
|
||||
{name: "no"},
|
||||
{name: "end"},
|
||||
{name: "exit"},
|
||||
{name: "?"},
|
||||
}
|
||||
|
||||
// interfaceConfigCommands defines the command tree for interface config mode.
|
||||
var interfaceConfigCommands = []commandEntry{
|
||||
{name: "ip", subs: []commandEntry{
|
||||
{name: "address"},
|
||||
}},
|
||||
{name: "description"},
|
||||
{name: "shutdown"},
|
||||
{name: "no", subs: []commandEntry{
|
||||
{name: "shutdown"},
|
||||
}},
|
||||
{name: "switchport", subs: []commandEntry{
|
||||
{name: "mode"},
|
||||
}},
|
||||
{name: "end"},
|
||||
{name: "exit"},
|
||||
{name: "?"},
|
||||
}
|
||||
|
||||
// commandsForMode returns the command tree for the given IOS mode.
|
||||
func commandsForMode(mode iosMode) []commandEntry {
|
||||
switch mode {
|
||||
case modeUserExec:
|
||||
return userExecCommands
|
||||
case modePrivilegedExec:
|
||||
return privilegedExecCommands
|
||||
case modeGlobalConfig:
|
||||
return globalConfigCommands
|
||||
case modeInterfaceConfig:
|
||||
return interfaceConfigCommands
|
||||
default:
|
||||
return userExecCommands
|
||||
}
|
||||
}
|
||||
|
||||
// resolveAbbreviation attempts to match an abbreviated word against a list of
|
||||
// command entries. It returns the matched entry name, or an error string if
|
||||
// ambiguous or unknown.
|
||||
func resolveAbbreviation(word string, entries []commandEntry) (string, error) {
|
||||
word = strings.ToLower(word)
|
||||
var matches []string
|
||||
for _, e := range entries {
|
||||
if strings.ToLower(e.name) == word {
|
||||
return e.name, nil // exact match
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(e.name), word) {
|
||||
matches = append(matches, e.name)
|
||||
}
|
||||
}
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
return "", fmt.Errorf("unknown")
|
||||
case 1:
|
||||
return matches[0], nil
|
||||
default:
|
||||
return "", fmt.Errorf("ambiguous")
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCommand resolves a sequence of abbreviated words into the canonical
|
||||
// command path (e.g., ["sh", "run"] → ["show", "running-config"]).
|
||||
// It returns the resolved path, any remaining arguments, and an error if
|
||||
// resolution fails.
|
||||
func resolveCommand(words []string, entries []commandEntry) ([]string, []string, error) {
|
||||
var resolved []string
|
||||
current := entries
|
||||
|
||||
for i, w := range words {
|
||||
name, err := resolveAbbreviation(w, current)
|
||||
if err != nil {
|
||||
if err.Error() == "unknown" && len(resolved) > 0 {
|
||||
// Remaining words are arguments to the resolved command.
|
||||
return resolved, words[i:], nil
|
||||
}
|
||||
return resolved, words[i:], err
|
||||
}
|
||||
resolved = append(resolved, name)
|
||||
|
||||
// Find sub-commands for the matched entry.
|
||||
var nextLevel []commandEntry
|
||||
for _, e := range current {
|
||||
if e.name == name {
|
||||
nextLevel = e.subs
|
||||
break
|
||||
}
|
||||
}
|
||||
if nextLevel == nil {
|
||||
// Leaf command — rest are arguments.
|
||||
return resolved, words[i+1:], nil
|
||||
}
|
||||
current = nextLevel
|
||||
}
|
||||
return resolved, nil, nil
|
||||
}
|
||||
|
||||
// dispatch processes a command line in the context of the current IOS state.
|
||||
func (s *iosState) dispatch(input string) commandResult {
|
||||
words := strings.Fields(input)
|
||||
if len(words) == 0 {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
// Handle "?" as a help request.
|
||||
if words[0] == "?" {
|
||||
return s.cmdHelp()
|
||||
}
|
||||
|
||||
cmds := commandsForMode(s.mode)
|
||||
resolved, args, err := resolveCommand(words, cmds)
|
||||
if err != nil {
|
||||
if err.Error() == "ambiguous" {
|
||||
return commandResult{output: fmt.Sprintf("%% Ambiguous command: \"%s\"", input)}
|
||||
}
|
||||
return commandResult{output: invalidInput(input)}
|
||||
}
|
||||
|
||||
if len(resolved) == 0 {
|
||||
return commandResult{output: invalidInput(input)}
|
||||
}
|
||||
|
||||
cmd := strings.Join(resolved, " ")
|
||||
|
||||
switch s.mode {
|
||||
case modeUserExec:
|
||||
return s.dispatchUserExec(cmd, args)
|
||||
case modePrivilegedExec:
|
||||
return s.dispatchPrivilegedExec(cmd, args)
|
||||
case modeGlobalConfig:
|
||||
return s.dispatchGlobalConfig(cmd, args)
|
||||
case modeInterfaceConfig:
|
||||
return s.dispatchInterfaceConfig(cmd, args)
|
||||
}
|
||||
return commandResult{output: invalidInput(input)}
|
||||
}
|
||||
|
||||
func (s *iosState) dispatchUserExec(cmd string, args []string) commandResult {
|
||||
switch cmd {
|
||||
case "show version":
|
||||
return commandResult{output: showVersion(s)}
|
||||
case "show clock":
|
||||
return commandResult{output: showClock()}
|
||||
case "show ip route":
|
||||
return commandResult{output: showIPRoute(s)}
|
||||
case "show ip interface brief":
|
||||
return commandResult{output: showIPInterfaceBrief(s)}
|
||||
case "show interfaces":
|
||||
return commandResult{output: showInterfaces(s)}
|
||||
case "show vlan brief":
|
||||
return commandResult{output: showVLANBrief()}
|
||||
case "enable":
|
||||
return commandResult{} // handled in Handle() loop
|
||||
case "exit":
|
||||
return commandResult{exit: true}
|
||||
}
|
||||
return commandResult{output: invalidInput(cmd)}
|
||||
}
|
||||
|
||||
func (s *iosState) dispatchPrivilegedExec(cmd string, args []string) commandResult {
|
||||
switch cmd {
|
||||
case "show version":
|
||||
return commandResult{output: showVersion(s)}
|
||||
case "show clock":
|
||||
return commandResult{output: showClock()}
|
||||
case "show ip route":
|
||||
return commandResult{output: showIPRoute(s)}
|
||||
case "show ip interface brief":
|
||||
return commandResult{output: showIPInterfaceBrief(s)}
|
||||
case "show interfaces":
|
||||
return commandResult{output: showInterfaces(s)}
|
||||
case "show running-config":
|
||||
return commandResult{output: showRunningConfig(s)}
|
||||
case "show startup-config":
|
||||
return commandResult{output: showRunningConfig(s)} // same as running
|
||||
case "show vlan brief":
|
||||
return commandResult{output: showVLANBrief()}
|
||||
case "configure terminal":
|
||||
s.mode = modeGlobalConfig
|
||||
return commandResult{output: "Enter configuration commands, one per line. End with CNTL/Z."}
|
||||
case "write memory":
|
||||
return commandResult{output: "[OK]"}
|
||||
case "copy":
|
||||
return commandResult{output: "[OK]"}
|
||||
case "reload":
|
||||
return commandResult{output: "System configuration has been modified. Save? [yes/no]: ", exit: true}
|
||||
case "disable":
|
||||
s.mode = modeUserExec
|
||||
return commandResult{}
|
||||
case "terminal length":
|
||||
return commandResult{} // accept silently
|
||||
case "exit":
|
||||
return commandResult{exit: true}
|
||||
}
|
||||
return commandResult{output: invalidInput(cmd)}
|
||||
}
|
||||
|
||||
func (s *iosState) dispatchGlobalConfig(cmd string, args []string) commandResult {
|
||||
switch cmd {
|
||||
case "hostname":
|
||||
if len(args) < 1 {
|
||||
return commandResult{output: "% Incomplete command."}
|
||||
}
|
||||
s.hostname = args[0]
|
||||
return commandResult{}
|
||||
case "interface":
|
||||
if len(args) < 1 {
|
||||
return commandResult{output: "% Incomplete command."}
|
||||
}
|
||||
ifName := strings.Join(args, "")
|
||||
s.currentIf = ifName
|
||||
s.mode = modeInterfaceConfig
|
||||
return commandResult{}
|
||||
case "ip route":
|
||||
return commandResult{} // accept silently
|
||||
case "no":
|
||||
return commandResult{} // accept silently
|
||||
case "end":
|
||||
s.mode = modePrivilegedExec
|
||||
return commandResult{}
|
||||
case "exit":
|
||||
s.mode = modePrivilegedExec
|
||||
return commandResult{}
|
||||
}
|
||||
return commandResult{output: invalidInput(cmd)}
|
||||
}
|
||||
|
||||
func (s *iosState) dispatchInterfaceConfig(cmd string, args []string) commandResult {
|
||||
switch cmd {
|
||||
case "ip address":
|
||||
if len(args) < 2 {
|
||||
return commandResult{output: "% Incomplete command."}
|
||||
}
|
||||
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||
iface.ip = args[0]
|
||||
iface.mask = args[1]
|
||||
}
|
||||
return commandResult{}
|
||||
case "description":
|
||||
if len(args) < 1 {
|
||||
return commandResult{output: "% Incomplete command."}
|
||||
}
|
||||
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||
iface.desc = strings.Join(args, " ")
|
||||
}
|
||||
return commandResult{}
|
||||
case "shutdown":
|
||||
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||
iface.shutdown = true
|
||||
iface.status = "administratively down"
|
||||
iface.protocol = "down"
|
||||
}
|
||||
return commandResult{}
|
||||
case "no shutdown":
|
||||
if iface := s.findInterface(s.currentIf); iface != nil {
|
||||
iface.shutdown = false
|
||||
iface.status = "up"
|
||||
iface.protocol = "up"
|
||||
}
|
||||
return commandResult{}
|
||||
case "switchport mode":
|
||||
return commandResult{} // accept silently
|
||||
case "end":
|
||||
s.mode = modePrivilegedExec
|
||||
s.currentIf = ""
|
||||
return commandResult{}
|
||||
case "exit":
|
||||
s.mode = modeGlobalConfig
|
||||
s.currentIf = ""
|
||||
return commandResult{}
|
||||
}
|
||||
return commandResult{output: invalidInput(cmd)}
|
||||
}
|
||||
|
||||
func (s *iosState) cmdHelp() commandResult {
|
||||
cmds := commandsForMode(s.mode)
|
||||
var b strings.Builder
|
||||
for _, e := range cmds {
|
||||
if e.name == "?" {
|
||||
continue
|
||||
}
|
||||
b.WriteString(fmt.Sprintf(" %-20s %s\n", e.name, helpText(e.name)))
|
||||
}
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func helpText(name string) string {
|
||||
switch name {
|
||||
case "show":
|
||||
return "Show running system information"
|
||||
case "enable":
|
||||
return "Turn on privileged commands"
|
||||
case "disable":
|
||||
return "Turn off privileged commands"
|
||||
case "exit":
|
||||
return "Exit from the EXEC"
|
||||
case "configure":
|
||||
return "Enter configuration mode"
|
||||
case "write":
|
||||
return "Write running configuration to memory"
|
||||
case "copy":
|
||||
return "Copy from one file to another"
|
||||
case "reload":
|
||||
return "Halt and perform a cold restart"
|
||||
case "terminal":
|
||||
return "Set terminal line parameters"
|
||||
case "hostname":
|
||||
return "Set system's network name"
|
||||
case "interface":
|
||||
return "Select an interface to configure"
|
||||
case "ip":
|
||||
return "Global IP configuration subcommands"
|
||||
case "no":
|
||||
return "Negate a command or set its defaults"
|
||||
case "end":
|
||||
return "Exit from configure mode"
|
||||
case "description":
|
||||
return "Interface specific description"
|
||||
case "shutdown":
|
||||
return "Shutdown the selected interface"
|
||||
case "switchport":
|
||||
return "Set switching mode characteristics"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func invalidInput(input string) string {
|
||||
return fmt.Sprintf("%% Invalid input detected at '^' marker.\n\n%s\n^", input)
|
||||
}
|
||||
234
internal/shell/cisco/output.go
Normal file
234
internal/shell/cisco/output.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package cisco
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func showVersion(s *iosState) string {
|
||||
days := 14 + rand.Intn(350)
|
||||
hours := rand.Intn(24)
|
||||
mins := rand.Intn(60)
|
||||
|
||||
return fmt.Sprintf(`Cisco IOS Software, %s Software (%s-UNIVERSALK9-M), Version %s, RELEASE SOFTWARE (fc3)
|
||||
Technical Support: http://www.cisco.com/techsupport
|
||||
Copyright (c) 1986-2019 by Cisco Systems, Inc.
|
||||
Compiled Thu 30-Jan-19 10:08 by prod_rel_team
|
||||
|
||||
ROM: Bootstrap program is %s boot loader
|
||||
BOOTLDR: %s Boot Loader (C2960-HBOOT-M) Version 15.0(2r)SE, RELEASE SOFTWARE (fc1)
|
||||
|
||||
%s uptime is %d days, %d hours, %d minutes
|
||||
System returned to ROM by power-on
|
||||
System image file is "flash:/%s-universalk9-mz.SPA.%s.bin"
|
||||
|
||||
This product contains cryptographic features and is subject to United States
|
||||
and local country laws governing import, export, transfer and use.
|
||||
|
||||
cisco %s (%s) processor (revision K0) with 524288K bytes of memory.
|
||||
Processor board ID %s
|
||||
Last reset from power-on
|
||||
2 Gigabit Ethernet interfaces
|
||||
1 Virtual Ethernet interface
|
||||
64K bytes of flash-simulated non-volatile configuration memory.
|
||||
Total of 65536K bytes of APC System Flash (Read/Write)
|
||||
|
||||
Configuration register is 0x2102`,
|
||||
s.model, s.model, s.iosVersion,
|
||||
s.model, s.model,
|
||||
s.hostname, days, hours, mins,
|
||||
s.model, s.iosVersion,
|
||||
s.model, processorForModel(s.model),
|
||||
s.serial,
|
||||
)
|
||||
}
|
||||
|
||||
func processorForModel(model string) string {
|
||||
if strings.HasPrefix(model, "C29") {
|
||||
return "PowerPC405"
|
||||
}
|
||||
return "MIPS"
|
||||
}
|
||||
|
||||
func showClock() string {
|
||||
now := time.Now().UTC()
|
||||
return fmt.Sprintf("*%s UTC", now.Format("15:04:05.000 Mon Jan 2 2006"))
|
||||
}
|
||||
|
||||
func showIPRoute(s *iosState) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("Codes: C - connected, S - static, R - RIP, M - mobile, B - BGP\n")
|
||||
b.WriteString(" D - EIGRP, EX - EIGRP external, O - OSPF, IA - OSPF inter area\n")
|
||||
b.WriteString(" N1 - OSPF NSSA external type 1, N2 - OSPF NSSA external type 2\n")
|
||||
b.WriteString(" E1 - OSPF external type 1, E2 - OSPF external type 2\n")
|
||||
b.WriteString(" i - IS-IS, su - IS-IS summary, L1 - IS-IS level-1, L2 - IS-IS level-2\n")
|
||||
b.WriteString(" ia - IS-IS inter area, * - candidate default, U - per-user static route\n")
|
||||
b.WriteString(" o - ODR, P - periodic downloaded static route\n\n")
|
||||
b.WriteString("Gateway of last resort is 10.0.0.2 to network 0.0.0.0\n\n")
|
||||
|
||||
for _, iface := range s.interfaces {
|
||||
if iface.ip == "unassigned" || iface.status != "up" {
|
||||
continue
|
||||
}
|
||||
network := networkFromIP(iface.ip, iface.mask)
|
||||
maskBits := maskBits(iface.mask)
|
||||
fmt.Fprintf(&b, "C %s/%d is directly connected, %s\n", network, maskBits, iface.name)
|
||||
}
|
||||
b.WriteString("S* 0.0.0.0/0 [1/0] via 10.0.0.2")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func showIPInterfaceBrief(s *iosState) string {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "%-25s %-15s %-4s %-7s %-22s %s\n",
|
||||
"Interface", "IP-Address", "OK?", "Method", "Status", "Protocol")
|
||||
for _, iface := range s.interfaces {
|
||||
ip := iface.ip
|
||||
if ip == "" {
|
||||
ip = "unassigned"
|
||||
}
|
||||
fmt.Fprintf(&b, "%-25s %-15s YES manual %-22s %s\n",
|
||||
iface.name, ip, iface.status, iface.protocol)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func showInterfaces(s *iosState) string {
|
||||
var b strings.Builder
|
||||
for i, iface := range s.interfaces {
|
||||
if i > 0 {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
upDown := "up"
|
||||
if iface.shutdown {
|
||||
upDown = "administratively down"
|
||||
}
|
||||
fmt.Fprintf(&b, "%s is %s, line protocol is %s\n", iface.name, upDown, iface.protocol)
|
||||
fmt.Fprintf(&b, " Hardware is Gigabit Ethernet, address is %s (bia %s)\n", iface.mac, iface.mac)
|
||||
if iface.ip != "unassigned" && iface.ip != "" {
|
||||
fmt.Fprintf(&b, " Internet address is %s/%d\n", iface.ip, maskBits(iface.mask))
|
||||
}
|
||||
fmt.Fprintf(&b, " MTU %d bytes, BW %s sec, DLY 10 usec,\n", iface.mtu, iface.bandwidth)
|
||||
b.WriteString(" reliability 255/255, txload 1/255, rxload 1/255\n")
|
||||
b.WriteString(" Encapsulation ARPA, loopback not set\n")
|
||||
fmt.Fprintf(&b, " %d packets input, %d bytes, 0 no buffer\n", iface.rxPackets, iface.rxBytes)
|
||||
fmt.Fprintf(&b, " %d packets output, %d bytes, 0 underruns", iface.txPackets, iface.txBytes)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func showRunningConfig(s *iosState) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("Building configuration...\n\n")
|
||||
b.WriteString("Current configuration : 1482 bytes\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("! Last configuration change at 14:32:22 UTC Mon Feb 10 2025\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("version 15.0\n")
|
||||
b.WriteString("service timestamps debug datetime msec\n")
|
||||
b.WriteString("service timestamps log datetime msec\n")
|
||||
b.WriteString("no service password-encryption\n")
|
||||
b.WriteString("!\n")
|
||||
fmt.Fprintf(&b, "hostname %s\n", s.hostname)
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("boot-start-marker\n")
|
||||
b.WriteString("boot-end-marker\n")
|
||||
b.WriteString("!\n")
|
||||
if s.enablePass != "" {
|
||||
b.WriteString("enable secret 5 $1$mERr$hx5rVt7rPNoS4wqbXKX7m0\n")
|
||||
}
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("no aaa new-model\n")
|
||||
b.WriteString("!\n")
|
||||
|
||||
for _, iface := range s.interfaces {
|
||||
b.WriteString("!\n")
|
||||
fmt.Fprintf(&b, "interface %s\n", iface.name)
|
||||
if iface.desc != "" {
|
||||
fmt.Fprintf(&b, " description %s\n", iface.desc)
|
||||
}
|
||||
if iface.ip != "unassigned" && iface.ip != "" {
|
||||
fmt.Fprintf(&b, " ip address %s %s\n", iface.ip, iface.mask)
|
||||
} else {
|
||||
b.WriteString(" no ip address\n")
|
||||
}
|
||||
if iface.shutdown {
|
||||
b.WriteString(" shutdown\n")
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("ip forward-protocol nd\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("ip route 0.0.0.0 0.0.0.0 10.0.0.2\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("access-list 10 permit 192.168.1.0 0.0.0.255\n")
|
||||
b.WriteString("access-list 10 deny any\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("line con 0\n")
|
||||
b.WriteString(" logging synchronous\n")
|
||||
b.WriteString("line vty 0 4\n")
|
||||
b.WriteString(" login local\n")
|
||||
b.WriteString(" transport input ssh\n")
|
||||
b.WriteString("!\n")
|
||||
b.WriteString("end")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func showVLANBrief() string {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "VLAN", "Name", "Status", "Ports")
|
||||
b.WriteString("---- -------------------------------- --------- -------------------------------\n")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1", "default", "active", "Gi0/0, Gi0/1, Gi0/2")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "10", "MGMT", "active", "")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "20", "USERS", "active", "")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "99", "NATIVE", "active", "")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1002", "fddi-default", "act/unsup", "")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1003", "token-ring-default", "act/unsup", "")
|
||||
fmt.Fprintf(&b, "%-6s %-32s %-10s %s", "1004", "fddinet-default", "act/unsup", "")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// networkFromIP derives the network address from an IP and mask.
|
||||
func networkFromIP(ip, mask string) string {
|
||||
ipParts := parseIPv4(ip)
|
||||
maskParts := parseIPv4(mask)
|
||||
if ipParts == nil || maskParts == nil {
|
||||
return ip
|
||||
}
|
||||
return fmt.Sprintf("%d.%d.%d.%d",
|
||||
ipParts[0]&maskParts[0],
|
||||
ipParts[1]&maskParts[1],
|
||||
ipParts[2]&maskParts[2],
|
||||
ipParts[3]&maskParts[3],
|
||||
)
|
||||
}
|
||||
|
||||
func maskBits(mask string) int {
|
||||
parts := parseIPv4(mask)
|
||||
if parts == nil {
|
||||
return 24
|
||||
}
|
||||
bits := 0
|
||||
for _, p := range parts {
|
||||
for i := 7; i >= 0; i-- {
|
||||
if p&(1<<uint(i)) != 0 {
|
||||
bits++
|
||||
} else {
|
||||
return bits
|
||||
}
|
||||
}
|
||||
}
|
||||
return bits
|
||||
}
|
||||
|
||||
func parseIPv4(s string) []int {
|
||||
var a, b, c, d int
|
||||
n, _ := fmt.Sscanf(s, "%d.%d.%d.%d", &a, &b, &c, &d)
|
||||
if n != 4 {
|
||||
return nil
|
||||
}
|
||||
return []int{a, b, c, d}
|
||||
}
|
||||
109
internal/shell/cisco/state.go
Normal file
109
internal/shell/cisco/state.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package cisco
|
||||
|
||||
import "fmt"
|
||||
|
||||
// iosMode represents the current CLI mode of the IOS state machine.
|
||||
type iosMode int
|
||||
|
||||
const (
|
||||
modeUserExec iosMode = iota // Router>
|
||||
modePrivilegedExec // Router#
|
||||
modeGlobalConfig // Router(config)#
|
||||
modeInterfaceConfig // Router(config-if)#
|
||||
)
|
||||
|
||||
// ifaceInfo holds interface metadata for show commands.
|
||||
type ifaceInfo struct {
|
||||
name string
|
||||
ip string
|
||||
mask string
|
||||
status string
|
||||
protocol string
|
||||
mac string
|
||||
bandwidth string
|
||||
mtu int
|
||||
rxPackets int
|
||||
txPackets int
|
||||
rxBytes int
|
||||
txBytes int
|
||||
shutdown bool
|
||||
desc string
|
||||
}
|
||||
|
||||
// iosState holds all mutable state for the Cisco IOS shell session.
|
||||
type iosState struct {
|
||||
mode iosMode
|
||||
hostname string
|
||||
model string
|
||||
iosVersion string
|
||||
serial string
|
||||
enablePass string
|
||||
interfaces []ifaceInfo
|
||||
currentIf string
|
||||
}
|
||||
|
||||
func newIOSState(hostname, model, iosVersion, enablePass string) *iosState {
|
||||
return &iosState{
|
||||
mode: modeUserExec,
|
||||
hostname: hostname,
|
||||
model: model,
|
||||
iosVersion: iosVersion,
|
||||
serial: "FTX1524Z0P3",
|
||||
enablePass: enablePass,
|
||||
interfaces: defaultInterfaces(),
|
||||
}
|
||||
}
|
||||
|
||||
func defaultInterfaces() []ifaceInfo {
|
||||
return []ifaceInfo{
|
||||
{
|
||||
name: "GigabitEthernet0/0", ip: "192.168.1.1", mask: "255.255.255.0",
|
||||
status: "up", protocol: "up", mac: "0050.7966.6800",
|
||||
bandwidth: "1000000 Kbit", mtu: 1500,
|
||||
rxPackets: 148253, txPackets: 93127, rxBytes: 19284732, txBytes: 8291043,
|
||||
},
|
||||
{
|
||||
name: "GigabitEthernet0/1", ip: "10.0.0.1", mask: "255.255.255.252",
|
||||
status: "up", protocol: "up", mac: "0050.7966.6801",
|
||||
bandwidth: "1000000 Kbit", mtu: 1500,
|
||||
rxPackets: 52104, txPackets: 48891, rxBytes: 4182934, txBytes: 3901284,
|
||||
},
|
||||
{
|
||||
name: "GigabitEthernet0/2", ip: "unassigned", mask: "",
|
||||
status: "administratively down", protocol: "down", mac: "0050.7966.6802",
|
||||
bandwidth: "1000000 Kbit", mtu: 1500, shutdown: true,
|
||||
},
|
||||
{
|
||||
name: "Vlan1", ip: "172.16.0.1", mask: "255.255.0.0",
|
||||
status: "up", protocol: "up", mac: "0050.7966.6810",
|
||||
bandwidth: "1000000 Kbit", mtu: 1500,
|
||||
rxPackets: 8421, txPackets: 7103, rxBytes: 512384, txBytes: 423901,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// prompt returns the IOS prompt string for the current mode.
|
||||
func (s *iosState) prompt() string {
|
||||
switch s.mode {
|
||||
case modeUserExec:
|
||||
return fmt.Sprintf("%s>", s.hostname)
|
||||
case modePrivilegedExec:
|
||||
return fmt.Sprintf("%s#", s.hostname)
|
||||
case modeGlobalConfig:
|
||||
return fmt.Sprintf("%s(config)#", s.hostname)
|
||||
case modeInterfaceConfig:
|
||||
return fmt.Sprintf("%s(config-if)#", s.hostname)
|
||||
default:
|
||||
return fmt.Sprintf("%s>", s.hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// findInterface returns a pointer to the interface with the given name, or nil.
|
||||
func (s *iosState) findInterface(name string) *ifaceInfo {
|
||||
for i := range s.interfaces {
|
||||
if s.interfaces[i].name == name {
|
||||
return &s.interfaces[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
92
internal/shell/eventrecorder.go
Normal file
92
internal/shell/eventrecorder.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// EventRecorder buffers I/O events in memory and periodically flushes them to
|
||||
// a storage.Store. It is designed to be registered as a RecordingChannel
|
||||
// callback so that SSH I/O is never blocked by database writes.
|
||||
type EventRecorder struct {
|
||||
sessionID string
|
||||
store storage.Store
|
||||
logger *slog.Logger
|
||||
|
||||
mu sync.Mutex
|
||||
buf []storage.SessionEvent
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewEventRecorder creates a recorder that will persist events for the given session.
|
||||
func NewEventRecorder(sessionID string, store storage.Store, logger *slog.Logger) *EventRecorder {
|
||||
return &EventRecorder{
|
||||
sessionID: sessionID,
|
||||
store: store,
|
||||
logger: logger,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordEvent implements the EventCallback signature and appends an event to
|
||||
// the in-memory buffer. It is safe to call concurrently.
|
||||
func (er *EventRecorder) RecordEvent(ts time.Time, direction int, data []byte) {
|
||||
er.mu.Lock()
|
||||
defer er.mu.Unlock()
|
||||
er.buf = append(er.buf, storage.SessionEvent{
|
||||
SessionID: er.sessionID,
|
||||
Timestamp: ts,
|
||||
Direction: direction,
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
// Start begins the background flush goroutine that drains the buffer every 2 seconds.
|
||||
func (er *EventRecorder) Start(ctx context.Context) {
|
||||
ctx, er.cancel = context.WithCancel(ctx)
|
||||
go er.run(ctx)
|
||||
}
|
||||
|
||||
// Close cancels the background goroutine and performs a final flush.
|
||||
func (er *EventRecorder) Close() {
|
||||
if er.cancel != nil {
|
||||
er.cancel()
|
||||
}
|
||||
<-er.done
|
||||
}
|
||||
|
||||
func (er *EventRecorder) run(ctx context.Context) {
|
||||
defer close(er.done)
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
er.flush()
|
||||
return
|
||||
case <-ticker.C:
|
||||
er.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (er *EventRecorder) flush() {
|
||||
er.mu.Lock()
|
||||
if len(er.buf) == 0 {
|
||||
er.mu.Unlock()
|
||||
return
|
||||
}
|
||||
events := er.buf
|
||||
er.buf = nil
|
||||
er.mu.Unlock()
|
||||
|
||||
if err := er.store.AppendSessionEvents(context.Background(), events); err != nil {
|
||||
er.logger.Error("failed to flush session events", "err", err, "session_id", er.sessionID)
|
||||
}
|
||||
}
|
||||
80
internal/shell/eventrecorder_test.go
Normal file
80
internal/shell/eventrecorder_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func TestEventRecorderFlush(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a session so events have a valid session ID.
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
rec := NewEventRecorder(id, store, slog.Default())
|
||||
rec.Start(ctx)
|
||||
|
||||
// Record some events.
|
||||
now := time.Now()
|
||||
rec.RecordEvent(now, 0, []byte("hello"))
|
||||
rec.RecordEvent(now.Add(100*time.Millisecond), 1, []byte("world"))
|
||||
|
||||
// Close should trigger final flush.
|
||||
rec.Close()
|
||||
|
||||
events, err := store.GetSessionEvents(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSessionEvents: %v", err)
|
||||
}
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(events))
|
||||
}
|
||||
if string(events[0].Data) != "hello" {
|
||||
t.Errorf("events[0].Data = %q, want %q", events[0].Data, "hello")
|
||||
}
|
||||
if events[0].Direction != 0 {
|
||||
t.Errorf("events[0].Direction = %d, want 0", events[0].Direction)
|
||||
}
|
||||
if string(events[1].Data) != "world" {
|
||||
t.Errorf("events[1].Data = %q, want %q", events[1].Data, "world")
|
||||
}
|
||||
if events[1].Direction != 1 {
|
||||
t.Errorf("events[1].Direction = %d, want 1", events[1].Direction)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventRecorderPeriodicFlush(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
rec := NewEventRecorder(id, store, slog.Default())
|
||||
rec.Start(ctx)
|
||||
|
||||
// Record an event and wait for the periodic flush (2s + some margin).
|
||||
rec.RecordEvent(time.Now(), 1, []byte("periodic"))
|
||||
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
events, err := store.GetSessionEvents(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSessionEvents: %v", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Errorf("expected periodic flush, got %d events", len(events))
|
||||
}
|
||||
|
||||
rec.Close()
|
||||
}
|
||||
352
internal/shell/fridge/fridge.go
Normal file
352
internal/shell/fridge/fridge.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package fridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// FridgeShell emulates a Samsung Smart Fridge OS interface.
|
||||
type FridgeShell struct{}
|
||||
|
||||
// NewFridgeShell returns a new FridgeShell instance.
|
||||
func NewFridgeShell() *FridgeShell {
|
||||
return &FridgeShell{}
|
||||
}
|
||||
|
||||
func (f *FridgeShell) Name() string { return "fridge" }
|
||||
func (f *FridgeShell) Description() string { return "Samsung Smart Fridge shell emulator" }
|
||||
|
||||
func (f *FridgeShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
state := newFridgeState()
|
||||
|
||||
// Boot banner — convert \n to \r\n for terminal display.
|
||||
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
|
||||
fmt.Fprint(rw, banner)
|
||||
|
||||
for {
|
||||
if _, err := fmt.Fprint(rw, "FridgeOS> "); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
fmt.Fprint(rw, "logout\r\n")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
result := state.dispatch(trimmed)
|
||||
|
||||
var output string
|
||||
if result.output != "" {
|
||||
output = result.output
|
||||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
// Log command and output to store.
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("fridge")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func bootBanner() string {
|
||||
now := time.Now()
|
||||
defrost := now.Add(-3*time.Hour - 22*time.Minute).Format("2006-01-02 15:04")
|
||||
return fmt.Sprintf(`
|
||||
_____ ____ ___ ____ ____ _____ ___ ____
|
||||
| ___| _ \|_ _| _ \ / ___| ____/ _ \/ ___|
|
||||
| |_ | |_) || || | | | | _| _|| | | \___ \
|
||||
| _| | _ < | || |_| | |_| | |__| |_| |___) |
|
||||
|_| |_| \_\___|____/ \____|_____\___/|____/
|
||||
|
||||
Samsung Smart Fridge OS v3.2.1 (FridgeOS-ARM)
|
||||
Model: RF28R7351SR | Serial: SN-2847-FRDG-9182
|
||||
Firmware: 3.2.1-stable | Last defrost: %s
|
||||
|
||||
Type 'help' for available commands.
|
||||
|
||||
`, defrost)
|
||||
}
|
||||
|
||||
type fridgeState struct {
|
||||
inventory []inventoryItem
|
||||
fridgeF int // fridge temp in °F
|
||||
freezerF int // freezer temp in °F
|
||||
}
|
||||
|
||||
type inventoryItem struct {
|
||||
name string
|
||||
expiry string
|
||||
}
|
||||
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
func newFridgeState() *fridgeState {
|
||||
return &fridgeState{
|
||||
inventory: []inventoryItem{
|
||||
{"Whole Milk (1 gal)", time.Now().Add(48 * time.Hour).Format("2006-01-02")},
|
||||
{"Eggs (dozen)", time.Now().Add(7 * 24 * time.Hour).Format("2006-01-02")},
|
||||
{"Leftover Pizza (3 slices)", time.Now().Add(24 * time.Hour).Format("2006-01-02")},
|
||||
{"Orange Juice", time.Now().Add(5 * 24 * time.Hour).Format("2006-01-02")},
|
||||
{"Butter (unsalted)", time.Now().Add(30 * 24 * time.Hour).Format("2006-01-02")},
|
||||
{"Mystery Tupperware", time.Now().Add(-14 * 24 * time.Hour).Format("2006-01-02")},
|
||||
},
|
||||
fridgeF: 37,
|
||||
freezerF: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fridgeState) dispatch(input string) commandResult {
|
||||
parts := strings.Fields(input)
|
||||
if len(parts) == 0 {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
cmd := strings.ToLower(parts[0])
|
||||
args := parts[1:]
|
||||
|
||||
switch cmd {
|
||||
case "help":
|
||||
return s.cmdHelp()
|
||||
case "inventory":
|
||||
return s.cmdInventory(args)
|
||||
case "temp", "temperature":
|
||||
return s.cmdTemp(args)
|
||||
case "status":
|
||||
return s.cmdStatus()
|
||||
case "diagnostics":
|
||||
return s.cmdDiagnostics()
|
||||
case "alerts":
|
||||
return s.cmdAlerts()
|
||||
case "reboot":
|
||||
return s.cmdReboot()
|
||||
case "exit", "logout":
|
||||
return commandResult{output: "Goodbye! Keep your food fresh!", exit: true}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("FridgeOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fridgeState) cmdHelp() commandResult {
|
||||
help := `Available commands:
|
||||
help - Show this help message
|
||||
inventory - List fridge contents
|
||||
inventory add <item> - Add item to inventory
|
||||
inventory remove <item> - Remove item from inventory
|
||||
temp - Show current temperatures
|
||||
temp set <zone> <value> - Set temperature (zone: fridge|freezer)
|
||||
status - Show system status
|
||||
diagnostics - Run system diagnostics
|
||||
alerts - Show active alerts
|
||||
reboot - Reboot FridgeOS
|
||||
exit / logout - Disconnect`
|
||||
return commandResult{output: help}
|
||||
}
|
||||
|
||||
func (s *fridgeState) cmdInventory(args []string) commandResult {
|
||||
if len(args) == 0 || strings.ToLower(args[0]) == "list" {
|
||||
return s.inventoryList()
|
||||
}
|
||||
|
||||
sub := strings.ToLower(args[0])
|
||||
switch sub {
|
||||
case "add":
|
||||
if len(args) < 2 {
|
||||
return commandResult{output: "Usage: inventory add <item>"}
|
||||
}
|
||||
item := strings.Join(args[1:], " ")
|
||||
return s.inventoryAdd(item)
|
||||
case "remove":
|
||||
if len(args) < 2 {
|
||||
return commandResult{output: "Usage: inventory remove <item>"}
|
||||
}
|
||||
item := strings.Join(args[1:], " ")
|
||||
return s.inventoryRemove(item)
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("Unknown inventory subcommand '%s'. Try: list, add, remove", sub)}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fridgeState) inventoryList() commandResult {
|
||||
if len(s.inventory) == 0 {
|
||||
return commandResult{output: "Inventory is empty."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Fridge Inventory ===\n")
|
||||
b.WriteString(fmt.Sprintf("%-30s %s\n", "ITEM", "EXPIRES"))
|
||||
b.WriteString(fmt.Sprintf("%-30s %s\n", "----", "-------"))
|
||||
for _, item := range s.inventory {
|
||||
b.WriteString(fmt.Sprintf("%-30s %s\n", item.name, item.expiry))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\nTotal items: %d", len(s.inventory)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *fridgeState) inventoryAdd(item string) commandResult {
|
||||
expiry := time.Now().Add(7 * 24 * time.Hour).Format("2006-01-02")
|
||||
s.inventory = append(s.inventory, inventoryItem{name: item, expiry: expiry})
|
||||
return commandResult{output: fmt.Sprintf("Added '%s' to inventory (expires: %s).", item, expiry)}
|
||||
}
|
||||
|
||||
func (s *fridgeState) inventoryRemove(item string) commandResult {
|
||||
lower := strings.ToLower(item)
|
||||
for i, inv := range s.inventory {
|
||||
if strings.ToLower(inv.name) == lower || strings.Contains(strings.ToLower(inv.name), lower) {
|
||||
s.inventory = append(s.inventory[:i], s.inventory[i+1:]...)
|
||||
return commandResult{output: fmt.Sprintf("Removed '%s' from inventory.", inv.name)}
|
||||
}
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("Item '%s' not found in inventory.", item)}
|
||||
}
|
||||
|
||||
func (s *fridgeState) cmdTemp(args []string) commandResult {
|
||||
if len(args) == 0 {
|
||||
return commandResult{output: fmt.Sprintf(
|
||||
"=== Temperature Status ===\nFridge: %d°F (%.1f°C)\nFreezer: %d°F (%.1f°C)",
|
||||
s.fridgeF, fToC(s.fridgeF), s.freezerF, fToC(s.freezerF),
|
||||
)}
|
||||
}
|
||||
|
||||
if strings.ToLower(args[0]) != "set" || len(args) < 3 {
|
||||
return commandResult{output: "Usage: temp set <fridge|freezer> <value_in_F>"}
|
||||
}
|
||||
|
||||
zone := strings.ToLower(args[1])
|
||||
var val int
|
||||
if _, err := fmt.Sscanf(args[2], "%d", &val); err != nil {
|
||||
return commandResult{output: "Invalid temperature value. Must be an integer."}
|
||||
}
|
||||
|
||||
switch zone {
|
||||
case "fridge":
|
||||
if val < 33 || val > 45 {
|
||||
return commandResult{output: fmt.Sprintf("WARNING: Temperature %d°F is out of safe range (33-45°F). Setting rejected.", val)}
|
||||
}
|
||||
s.fridgeF = val
|
||||
return commandResult{output: fmt.Sprintf("Fridge temperature set to %d°F (%.1f°C).", val, fToC(val))}
|
||||
case "freezer":
|
||||
if val < -10 || val > 10 {
|
||||
return commandResult{output: fmt.Sprintf("WARNING: Temperature %d°F is out of safe range (-10 to 10°F). Setting rejected.", val)}
|
||||
}
|
||||
s.freezerF = val
|
||||
return commandResult{output: fmt.Sprintf("Freezer temperature set to %d°F (%.1f°C).", val, fToC(val))}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("Unknown zone '%s'. Use 'fridge' or 'freezer'.", zone)}
|
||||
}
|
||||
}
|
||||
|
||||
func fToC(f int) float64 {
|
||||
return float64(f-32) * 5.0 / 9.0
|
||||
}
|
||||
|
||||
func (s *fridgeState) cmdStatus() commandResult {
|
||||
status := `=== FridgeOS System Status ===
|
||||
Compressor: Running
|
||||
Door seal: OK
|
||||
Ice maker: Active
|
||||
Water filter: 82% remaining
|
||||
|
||||
WiFi: Connected (SmartHome-5G)
|
||||
Signal: -42 dBm
|
||||
Internal camera: Online (3 objects detected)
|
||||
Voice assistant: Standby
|
||||
TikTok recipes: Enabled
|
||||
Spotify: "Chill Vibes" playlist paused
|
||||
|
||||
Energy rating: A++
|
||||
Power: 127W
|
||||
SmartHome Hub: Connected (12 devices)
|
||||
|
||||
Firmware: v3.2.1-stable
|
||||
Update available: v3.3.0-beta`
|
||||
return commandResult{output: status}
|
||||
}
|
||||
|
||||
func (s *fridgeState) cmdDiagnostics() commandResult {
|
||||
diag := `Running FridgeOS diagnostics...
|
||||
|
||||
[1/6] Compressor.............. OK
|
||||
[2/6] Temperature sensors..... OK
|
||||
[3/6] Door seal integrity..... OK
|
||||
[4/6] Ice maker............... OK
|
||||
[5/6] Network connectivity.... OK
|
||||
[6/6] Internal camera......... OK
|
||||
|
||||
ALL SYSTEMS NOMINAL`
|
||||
return commandResult{output: diag}
|
||||
}
|
||||
|
||||
func (s *fridgeState) cmdAlerts() commandResult {
|
||||
// Build dynamic alerts based on inventory.
|
||||
var alerts []string
|
||||
for _, item := range s.inventory {
|
||||
expiry, err := time.Parse("2006-01-02", item.expiry)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
days := int(time.Until(expiry).Hours() / 24)
|
||||
if days < 0 {
|
||||
alerts = append(alerts, fmt.Sprintf("CRITICAL: %s expired %d day(s) ago!", item.name, -days))
|
||||
} else if days <= 2 {
|
||||
alerts = append(alerts, fmt.Sprintf("WARNING: %s expires in %d day(s)", item.name, days))
|
||||
}
|
||||
}
|
||||
alerts = append(alerts,
|
||||
"INFO: Ice maker: low water pressure detected",
|
||||
"INFO: Firmware update available: v3.3.0-beta",
|
||||
"INFO: TikTok recipe sync overdue (last sync: 3 days ago)",
|
||||
)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Active Alerts ===\n")
|
||||
for _, a := range alerts {
|
||||
b.WriteString(a + "\n")
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *fridgeState) cmdReboot() commandResult {
|
||||
reboot := `FridgeOS is rebooting...
|
||||
|
||||
Stopping services........... done
|
||||
Saving inventory data....... done
|
||||
Flushing temperature log.... done
|
||||
Unmounting partitions....... done
|
||||
|
||||
Rebooting now. Goodbye!`
|
||||
return commandResult{output: reboot, exit: true}
|
||||
}
|
||||
233
internal/shell/fridge/fridge_test.go
Normal file
233
internal/shell/fridge/fridge_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
package fridge
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (r *rwCloser) Close() error { return nil }
|
||||
|
||||
func runShell(t *testing.T, commands string) string {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "root",
|
||||
Store: store,
|
||||
CommonConfig: shell.ShellCommonConfig{
|
||||
Hostname: "testhost",
|
||||
},
|
||||
}
|
||||
|
||||
rw := &rwCloser{
|
||||
Reader: bytes.NewBufferString(commands),
|
||||
Writer: &bytes.Buffer{},
|
||||
}
|
||||
|
||||
sh := NewFridgeShell()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := sh.Handle(ctx, sess, rw); err != nil {
|
||||
t.Fatalf("Handle: %v", err)
|
||||
}
|
||||
|
||||
return rw.Writer.(*bytes.Buffer).String()
|
||||
}
|
||||
|
||||
func TestFridgeShellName(t *testing.T) {
|
||||
sh := NewFridgeShell()
|
||||
if sh.Name() != "fridge" {
|
||||
t.Errorf("Name() = %q, want %q", sh.Name(), "fridge")
|
||||
}
|
||||
if sh.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootBanner(t *testing.T) {
|
||||
output := runShell(t, "exit\r")
|
||||
if !strings.Contains(output, "FridgeOS-ARM") {
|
||||
t.Error("output should contain FridgeOS-ARM in banner")
|
||||
}
|
||||
if !strings.Contains(output, "Samsung Smart Fridge OS") {
|
||||
t.Error("output should contain Samsung Smart Fridge OS")
|
||||
}
|
||||
if !strings.Contains(output, "FridgeOS>") {
|
||||
t.Error("output should contain FridgeOS> prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelpCommand(t *testing.T) {
|
||||
output := runShell(t, "help\rexit\r")
|
||||
for _, keyword := range []string{"inventory", "temp", "status", "diagnostics", "alerts", "reboot", "exit"} {
|
||||
if !strings.Contains(output, keyword) {
|
||||
t.Errorf("help output should mention %q", keyword)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInventoryList(t *testing.T) {
|
||||
output := runShell(t, "inventory\rexit\r")
|
||||
if !strings.Contains(output, "Fridge Inventory") {
|
||||
t.Error("should show inventory header")
|
||||
}
|
||||
if !strings.Contains(output, "Whole Milk") {
|
||||
t.Error("should list milk")
|
||||
}
|
||||
if !strings.Contains(output, "Eggs") {
|
||||
t.Error("should list eggs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInventoryAdd(t *testing.T) {
|
||||
output := runShell(t, "inventory add Cheese\rinventory\rexit\r")
|
||||
if !strings.Contains(output, "Added 'Cheese'") {
|
||||
t.Error("should confirm adding cheese")
|
||||
}
|
||||
if !strings.Contains(output, "Cheese") {
|
||||
t.Error("inventory list should contain cheese")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInventoryRemove(t *testing.T) {
|
||||
output := runShell(t, "inventory remove milk\rinventory\rexit\r")
|
||||
if !strings.Contains(output, "Removed") {
|
||||
t.Error("should confirm removal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemperature(t *testing.T) {
|
||||
output := runShell(t, "temp\rexit\r")
|
||||
if !strings.Contains(output, "37") {
|
||||
t.Error("should show fridge temp 37°F")
|
||||
}
|
||||
if !strings.Contains(output, "Fridge") {
|
||||
t.Error("should label fridge zone")
|
||||
}
|
||||
if !strings.Contains(output, "Freezer") {
|
||||
t.Error("should label freezer zone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTempSetValid(t *testing.T) {
|
||||
output := runShell(t, "temp set fridge 40\rtemp\rexit\r")
|
||||
if !strings.Contains(output, "set to 40") {
|
||||
t.Errorf("should confirm temp set, got: %s", output)
|
||||
}
|
||||
// Second temp call should show 40.
|
||||
if !strings.Contains(output, "40") {
|
||||
t.Error("temperature should now be 40")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTempSetOutOfRange(t *testing.T) {
|
||||
output := runShell(t, "temp set fridge 100\rexit\r")
|
||||
if !strings.Contains(output, "WARNING") {
|
||||
t.Error("should warn about out-of-range temp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTempSetFreezerOutOfRange(t *testing.T) {
|
||||
output := runShell(t, "temp set freezer 50\rexit\r")
|
||||
if !strings.Contains(output, "WARNING") {
|
||||
t.Error("should warn about out-of-range freezer temp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatus(t *testing.T) {
|
||||
output := runShell(t, "status\rexit\r")
|
||||
for _, keyword := range []string{"Compressor", "WiFi", "Ice maker", "TikTok", "Spotify", "SmartHome"} {
|
||||
if !strings.Contains(output, keyword) {
|
||||
t.Errorf("status should contain %q", keyword)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiagnostics(t *testing.T) {
|
||||
output := runShell(t, "diagnostics\rexit\r")
|
||||
if !strings.Contains(output, "ALL SYSTEMS NOMINAL") {
|
||||
t.Error("diagnostics should end with ALL SYSTEMS NOMINAL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAlerts(t *testing.T) {
|
||||
output := runShell(t, "alerts\rexit\r")
|
||||
if !strings.Contains(output, "Active Alerts") {
|
||||
t.Error("should show alerts header")
|
||||
}
|
||||
if !strings.Contains(output, "Firmware update") {
|
||||
t.Error("should mention firmware update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReboot(t *testing.T) {
|
||||
output := runShell(t, "reboot\r")
|
||||
if !strings.Contains(output, "rebooting") || !strings.Contains(output, "Rebooting") {
|
||||
t.Error("should show reboot message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnknownCommand(t *testing.T) {
|
||||
output := runShell(t, "foobar\rexit\r")
|
||||
if !strings.Contains(output, "unknown command") {
|
||||
t.Error("should show unknown command message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitCommand(t *testing.T) {
|
||||
output := runShell(t, "exit\r")
|
||||
if !strings.Contains(output, "Goodbye") {
|
||||
t.Error("exit should show goodbye message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogoutCommand(t *testing.T) {
|
||||
output := runShell(t, "logout\r")
|
||||
if !strings.Contains(output, "Goodbye") {
|
||||
t.Error("logout should show goodbye message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLogs(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "root",
|
||||
Store: store,
|
||||
CommonConfig: shell.ShellCommonConfig{
|
||||
Hostname: "testhost",
|
||||
},
|
||||
}
|
||||
|
||||
rw := &rwCloser{
|
||||
Reader: bytes.NewBufferString("help\rexit\r"),
|
||||
Writer: &bytes.Buffer{},
|
||||
}
|
||||
|
||||
sh := NewFridgeShell()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
sh.Handle(ctx, sess, rw)
|
||||
|
||||
if len(store.SessionLogs) < 2 {
|
||||
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||
}
|
||||
}
|
||||
123
internal/shell/psql/commands.go
Normal file
123
internal/shell/psql/commands.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// commandResult holds the output of a command and whether the session should end.
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
// dispatchBackslash handles psql backslash meta-commands.
|
||||
func dispatchBackslash(cmd, dbName string) commandResult {
|
||||
// Normalize: trim spaces after the backslash command word.
|
||||
parts := strings.Fields(cmd)
|
||||
if len(parts) == 0 {
|
||||
return commandResult{output: "Invalid command \\. Try \\? for help."}
|
||||
}
|
||||
|
||||
verb := parts[0] // e.g. `\q`, `\dt`, `\d`
|
||||
args := parts[1:]
|
||||
|
||||
switch verb {
|
||||
case `\q`:
|
||||
return commandResult{exit: true}
|
||||
case `\dt`:
|
||||
return commandResult{output: listTables()}
|
||||
case `\d`:
|
||||
if len(args) == 0 {
|
||||
return commandResult{output: listTables()}
|
||||
}
|
||||
return commandResult{output: describeTable(args[0])}
|
||||
case `\l`:
|
||||
return commandResult{output: listDatabases()}
|
||||
case `\du`:
|
||||
return commandResult{output: listRoles()}
|
||||
case `\conninfo`:
|
||||
return commandResult{output: connInfo(dbName)}
|
||||
case `\?`:
|
||||
return commandResult{output: backslashHelp()}
|
||||
case `\h`:
|
||||
return commandResult{output: sqlHelp()}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
|
||||
func dispatchSQL(sql, dbName, pgVersion string) commandResult {
|
||||
// Strip trailing semicolon and whitespace for matching.
|
||||
trimmed := strings.TrimRight(sql, "; \t")
|
||||
trimmed = strings.TrimSpace(trimmed)
|
||||
upper := strings.ToUpper(trimmed)
|
||||
|
||||
switch {
|
||||
case upper == "SELECT VERSION()":
|
||||
ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
|
||||
return commandResult{output: formatSingleValue("version", ver)}
|
||||
case upper == "SELECT CURRENT_DATABASE()":
|
||||
return commandResult{output: formatSingleValue("current_database", dbName)}
|
||||
case upper == "SELECT CURRENT_USER":
|
||||
return commandResult{output: formatSingleValue("current_user", "postgres")}
|
||||
case upper == "SELECT NOW()":
|
||||
now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
|
||||
return commandResult{output: formatSingleValue("now", now)}
|
||||
case upper == "SELECT 1":
|
||||
return commandResult{output: formatSingleValue("?column?", "1")}
|
||||
case strings.HasPrefix(upper, "INSERT"):
|
||||
return commandResult{output: "INSERT 0 1"}
|
||||
case strings.HasPrefix(upper, "UPDATE"):
|
||||
return commandResult{output: "UPDATE 1"}
|
||||
case strings.HasPrefix(upper, "DELETE"):
|
||||
return commandResult{output: "DELETE 1"}
|
||||
case strings.HasPrefix(upper, "CREATE TABLE"):
|
||||
return commandResult{output: "CREATE TABLE"}
|
||||
case strings.HasPrefix(upper, "CREATE DATABASE"):
|
||||
return commandResult{output: "CREATE DATABASE"}
|
||||
case strings.HasPrefix(upper, "DROP TABLE"):
|
||||
return commandResult{output: "DROP TABLE"}
|
||||
case strings.HasPrefix(upper, "ALTER TABLE"):
|
||||
return commandResult{output: "ALTER TABLE"}
|
||||
case upper == "BEGIN":
|
||||
return commandResult{output: "BEGIN"}
|
||||
case upper == "COMMIT":
|
||||
return commandResult{output: "COMMIT"}
|
||||
case upper == "ROLLBACK":
|
||||
return commandResult{output: "ROLLBACK"}
|
||||
case upper == "SHOW SERVER_VERSION":
|
||||
return commandResult{output: formatSingleValue("server_version", pgVersion)}
|
||||
case upper == "SHOW SEARCH_PATH":
|
||||
return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
|
||||
case strings.HasPrefix(upper, "SET "):
|
||||
return commandResult{output: "SET"}
|
||||
default:
|
||||
// Extract the first token for the error message.
|
||||
firstToken := strings.Fields(trimmed)
|
||||
token := trimmed
|
||||
if len(firstToken) > 0 {
|
||||
token = firstToken[0]
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
|
||||
}
|
||||
}
|
||||
|
||||
// formatSingleValue formats a single-row, single-column psql result.
|
||||
func formatSingleValue(colName, value string) string {
|
||||
width := max(len(colName), len(value))
|
||||
|
||||
var b strings.Builder
|
||||
// Header
|
||||
fmt.Fprintf(&b, " %-*s \n", width, colName)
|
||||
// Separator
|
||||
b.WriteString(strings.Repeat("-", width+2))
|
||||
b.WriteString("\n")
|
||||
// Value
|
||||
fmt.Fprintf(&b, " %-*s\n", width, value)
|
||||
// Row count
|
||||
b.WriteString("(1 row)")
|
||||
return b.String()
|
||||
}
|
||||
155
internal/shell/psql/output.go
Normal file
155
internal/shell/psql/output.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package psql
|
||||
|
||||
import "fmt"
|
||||
|
||||
func startupBanner(version string) string {
|
||||
return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version)
|
||||
}
|
||||
|
||||
func listTables() string {
|
||||
return ` List of relations
|
||||
Schema | Name | Type | Owner
|
||||
--------+---------------+-------+----------
|
||||
public | audit_log | table | postgres
|
||||
public | credentials | table | postgres
|
||||
public | sessions | table | postgres
|
||||
public | users | table | postgres
|
||||
(4 rows)`
|
||||
}
|
||||
|
||||
func listDatabases() string {
|
||||
return ` List of databases
|
||||
Name | Owner | Encoding | Collate | Ctype | Access privileges
|
||||
-----------+----------+----------+-------------+-------------+-----------------------
|
||||
app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
|
||||
postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
|
||||
template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
|
||||
| | | | | postgres=CTc/postgres
|
||||
template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
|
||||
| | | | | postgres=CTc/postgres
|
||||
(4 rows)`
|
||||
}
|
||||
|
||||
func listRoles() string {
|
||||
return ` List of roles
|
||||
Role name | Attributes | Member of
|
||||
-----------+------------------------------------------------------------+-----------
|
||||
app_user | | {}
|
||||
postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {}
|
||||
readonly | Cannot login | {}`
|
||||
}
|
||||
|
||||
func describeTable(name string) string {
|
||||
switch name {
|
||||
case "users":
|
||||
return ` Table "public.users"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+-----------------------------------
|
||||
id | integer | | not null | nextval('users_id_seq'::regclass)
|
||||
username | character varying(255) | | not null |
|
||||
email | character varying(255) | | not null |
|
||||
password | character varying(255) | | not null |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
updated_at | timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"users_pkey" PRIMARY KEY, btree (id)
|
||||
"users_email_key" UNIQUE, btree (email)
|
||||
"users_username_key" UNIQUE, btree (username)`
|
||||
case "sessions":
|
||||
return ` Table "public.sessions"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+--------------------------------------
|
||||
id | integer | | not null | nextval('sessions_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
token | character varying(255) | | not null |
|
||||
ip_address | inet | | |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
expires_at | timestamp without time zone | | not null |
|
||||
Indexes:
|
||||
"sessions_pkey" PRIMARY KEY, btree (id)
|
||||
"sessions_token_key" UNIQUE, btree (token)
|
||||
Foreign-key constraints:
|
||||
"sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
case "credentials":
|
||||
return ` Table "public.credentials"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
-----------+-----------------------------+-----------+----------+-----------------------------------------
|
||||
id | integer | | not null | nextval('credentials_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
type | character varying(50) | | not null |
|
||||
value | text | | not null |
|
||||
created_at| timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"credentials_pkey" PRIMARY KEY, btree (id)
|
||||
Foreign-key constraints:
|
||||
"credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
case "audit_log":
|
||||
return ` Table "public.audit_log"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+---------------------------------------
|
||||
id | integer | | not null | nextval('audit_log_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
action | character varying(100) | | not null |
|
||||
details | text | | |
|
||||
ip_address | inet | | |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"audit_log_pkey" PRIMARY KEY, btree (id)
|
||||
Foreign-key constraints:
|
||||
"audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
default:
|
||||
return fmt.Sprintf("Did not find any relation named \"%s\".", name)
|
||||
}
|
||||
}
|
||||
|
||||
func connInfo(dbName string) string {
|
||||
return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName)
|
||||
}
|
||||
|
||||
func backslashHelp() string {
|
||||
return `General
|
||||
\copyright show PostgreSQL usage and distribution terms
|
||||
\crosstabview [COLUMNS] execute query and display result in crosstab
|
||||
\errverbose show most recent error message at maximum verbosity
|
||||
\g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe)
|
||||
\gdesc describe result of query, without executing it
|
||||
\gexec execute query, then execute each value in its result
|
||||
\gset [PREFIX] execute query and store result in psql variables
|
||||
\gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode
|
||||
\q quit psql
|
||||
\watch [SEC] execute query every SEC seconds
|
||||
|
||||
Informational
|
||||
(options: S = show system objects, + = additional detail)
|
||||
\d[S+] list tables, views, and sequences
|
||||
\d[S+] NAME describe table, view, sequence, or index
|
||||
\da[S] [PATTERN] list aggregates
|
||||
\dA[+] [PATTERN] list access methods
|
||||
\dt[S+] [PATTERN] list tables
|
||||
\du[S+] [PATTERN] list roles
|
||||
\l[+] [PATTERN] list databases`
|
||||
}
|
||||
|
||||
func sqlHelp() string {
|
||||
return `Available help:
|
||||
ABORT CREATE LANGUAGE
|
||||
ALTER AGGREGATE CREATE MATERIALIZED VIEW
|
||||
ALTER COLLATION CREATE OPERATOR
|
||||
ALTER CONVERSION CREATE POLICY
|
||||
ALTER DATABASE CREATE PROCEDURE
|
||||
ALTER DEFAULT PRIVILEGES CREATE PUBLICATION
|
||||
ALTER DOMAIN CREATE ROLE
|
||||
ALTER EVENT TRIGGER CREATE RULE
|
||||
ALTER EXTENSION CREATE SCHEMA
|
||||
ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE
|
||||
ALTER FOREIGN TABLE CREATE SERVER
|
||||
ALTER FUNCTION CREATE STATISTICS
|
||||
ALTER GROUP CREATE SUBSCRIPTION
|
||||
ALTER INDEX CREATE TABLE
|
||||
ALTER LANGUAGE CREATE TABLESPACE
|
||||
BEGIN DELETE
|
||||
COMMIT DROP TABLE
|
||||
CREATE DATABASE INSERT
|
||||
CREATE INDEX ROLLBACK
|
||||
SELECT UPDATE`
|
||||
}
|
||||
137
internal/shell/psql/psql.go
Normal file
137
internal/shell/psql/psql.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// PsqlShell emulates a PostgreSQL psql interactive terminal.
|
||||
type PsqlShell struct{}
|
||||
|
||||
// NewPsqlShell returns a new PsqlShell instance.
|
||||
func NewPsqlShell() *PsqlShell {
|
||||
return &PsqlShell{}
|
||||
}
|
||||
|
||||
func (p *PsqlShell) Name() string { return "psql" }
|
||||
func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" }
|
||||
|
||||
func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
dbName := configString(sess.ShellConfig, "db_name", "postgres")
|
||||
pgVersion := configString(sess.ShellConfig, "pg_version", "15.4")
|
||||
|
||||
// Print startup banner.
|
||||
fmt.Fprint(rw, startupBanner(pgVersion))
|
||||
|
||||
var sqlBuf []string // accumulates multi-line SQL
|
||||
|
||||
for {
|
||||
prompt := buildPrompt(dbName, len(sqlBuf) > 0)
|
||||
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
// Empty line in non-buffering state: just re-prompt.
|
||||
if trimmed == "" && len(sqlBuf) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Backslash commands dispatch immediately (even mid-buffer they cancel the buffer).
|
||||
if strings.HasPrefix(trimmed, `\`) {
|
||||
sqlBuf = nil // discard any partial SQL
|
||||
|
||||
result := dispatchBackslash(trimmed, dbName)
|
||||
if result.output != "" {
|
||||
output := strings.ReplaceAll(result.output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("psql")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate SQL lines.
|
||||
sqlBuf = append(sqlBuf, line)
|
||||
|
||||
// Check if the statement is terminated by a semicolon.
|
||||
if !strings.HasSuffix(strings.TrimSpace(line), ";") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Full statement ready — join and dispatch.
|
||||
fullSQL := strings.Join(sqlBuf, " ")
|
||||
sqlBuf = nil
|
||||
|
||||
result := dispatchSQL(fullSQL, dbName, pgVersion)
|
||||
if result.output != "" {
|
||||
output := strings.ReplaceAll(result.output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("psql")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL.
|
||||
func buildPrompt(dbName string, continuation bool) string {
|
||||
if continuation {
|
||||
return dbName + "-# "
|
||||
}
|
||||
return dbName + "=# "
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
330
internal/shell/psql/psql_test.go
Normal file
330
internal/shell/psql/psql_test.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- Prompt tests ---
|
||||
|
||||
func TestBuildPromptNormal(t *testing.T) {
|
||||
got := buildPrompt("postgres", false)
|
||||
if got != "postgres=# " {
|
||||
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptContinuation(t *testing.T) {
|
||||
got := buildPrompt("postgres", true)
|
||||
if got != "postgres-# " {
|
||||
t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptCustomDB(t *testing.T) {
|
||||
got := buildPrompt("mydb", false)
|
||||
if got != "mydb=# " {
|
||||
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Backslash command dispatch tests ---
|
||||
|
||||
func TestBackslashQuit(t *testing.T) {
|
||||
result := dispatchBackslash(`\q`, "postgres")
|
||||
if !result.exit {
|
||||
t.Error("\\q should set exit=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListTables(t *testing.T) {
|
||||
result := dispatchBackslash(`\dt`, "postgres")
|
||||
if !strings.Contains(result.output, "users") {
|
||||
t.Error("\\dt should list tables including 'users'")
|
||||
}
|
||||
if !strings.Contains(result.output, "sessions") {
|
||||
t.Error("\\dt should list tables including 'sessions'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashDescribeTable(t *testing.T) {
|
||||
result := dispatchBackslash(`\d users`, "postgres")
|
||||
if !strings.Contains(result.output, "username") {
|
||||
t.Error("\\d users should describe users table with 'username' column")
|
||||
}
|
||||
if !strings.Contains(result.output, "PRIMARY KEY") {
|
||||
t.Error("\\d users should include index info")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashDescribeUnknownTable(t *testing.T) {
|
||||
result := dispatchBackslash(`\d nonexistent`, "postgres")
|
||||
if !strings.Contains(result.output, "Did not find") {
|
||||
t.Error("\\d nonexistent should return not found message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListDatabases(t *testing.T) {
|
||||
result := dispatchBackslash(`\l`, "postgres")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("\\l should list databases including 'postgres'")
|
||||
}
|
||||
if !strings.Contains(result.output, "template0") {
|
||||
t.Error("\\l should list databases including 'template0'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListRoles(t *testing.T) {
|
||||
result := dispatchBackslash(`\du`, "postgres")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("\\du should list roles including 'postgres'")
|
||||
}
|
||||
if !strings.Contains(result.output, "Superuser") {
|
||||
t.Error("\\du should show Superuser attribute for postgres")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashConnInfo(t *testing.T) {
|
||||
result := dispatchBackslash(`\conninfo`, "mydb")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("\\conninfo should include database name")
|
||||
}
|
||||
if !strings.Contains(result.output, "5432") {
|
||||
t.Error("\\conninfo should include port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashHelp(t *testing.T) {
|
||||
result := dispatchBackslash(`\?`, "postgres")
|
||||
if !strings.Contains(result.output, `\q`) {
|
||||
t.Error("\\? should include \\q in help output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashSQLHelp(t *testing.T) {
|
||||
result := dispatchBackslash(`\h`, "postgres")
|
||||
if !strings.Contains(result.output, "SELECT") {
|
||||
t.Error("\\h should include SQL commands like SELECT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashUnknown(t *testing.T) {
|
||||
result := dispatchBackslash(`\xyz`, "postgres")
|
||||
if !strings.Contains(result.output, "Invalid command") {
|
||||
t.Error("unknown backslash command should return error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- SQL dispatch tests ---
|
||||
|
||||
func TestSQLSelectVersion(t *testing.T) {
|
||||
result := dispatchSQL("SELECT version();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("SELECT version() should contain pg version")
|
||||
}
|
||||
if !strings.Contains(result.output, "(1 row)") {
|
||||
t.Error("SELECT version() should show row count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectCurrentDatabase(t *testing.T) {
|
||||
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("SELECT current_database() should return db name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectCurrentUser(t *testing.T) {
|
||||
result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("SELECT current_user should return postgres")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectNow(t *testing.T) {
|
||||
result := dispatchSQL("SELECT now();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "(1 row)") {
|
||||
t.Error("SELECT now() should show row count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectOne(t *testing.T) {
|
||||
result := dispatchSQL("SELECT 1;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "1") {
|
||||
t.Error("SELECT 1 should return 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLInsert(t *testing.T) {
|
||||
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
|
||||
if result.output != "INSERT 0 1" {
|
||||
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLUpdate(t *testing.T) {
|
||||
result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
|
||||
if result.output != "UPDATE 1" {
|
||||
t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDelete(t *testing.T) {
|
||||
result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
|
||||
if result.output != "DELETE 1" {
|
||||
t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCreateTable(t *testing.T) {
|
||||
result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
|
||||
if result.output != "CREATE TABLE" {
|
||||
t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCreateDatabase(t *testing.T) {
|
||||
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
|
||||
if result.output != "CREATE DATABASE" {
|
||||
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDropTable(t *testing.T) {
|
||||
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
|
||||
if result.output != "DROP TABLE" {
|
||||
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLAlterTable(t *testing.T) {
|
||||
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
|
||||
if result.output != "ALTER TABLE" {
|
||||
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLBeginCommitRollback(t *testing.T) {
|
||||
tests := []struct {
|
||||
sql string
|
||||
want string
|
||||
}{
|
||||
{"BEGIN;", "BEGIN"},
|
||||
{"COMMIT;", "COMMIT"},
|
||||
{"ROLLBACK;", "ROLLBACK"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := dispatchSQL(tt.sql, "postgres", "15.4")
|
||||
if result.output != tt.want {
|
||||
t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLShowServerVersion(t *testing.T) {
|
||||
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("SHOW server_version should contain version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLShowSearchPath(t *testing.T) {
|
||||
result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "public") {
|
||||
t.Error("SHOW search_path should contain public")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSet(t *testing.T) {
|
||||
result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
|
||||
if result.output != "SET" {
|
||||
t.Errorf("SET output = %q, want %q", result.output, "SET")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLUnrecognized(t *testing.T) {
|
||||
result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "ERROR") {
|
||||
t.Error("unrecognized SQL should return error")
|
||||
}
|
||||
if !strings.Contains(result.output, "FOOBAR") {
|
||||
t.Error("error should reference the offending token")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Case insensitivity ---
|
||||
|
||||
func TestSQLCaseInsensitive(t *testing.T) {
|
||||
result := dispatchSQL("select version();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("select version() (lowercase) should work")
|
||||
}
|
||||
|
||||
result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("mixed case SELECT should work")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Startup banner ---
|
||||
|
||||
func TestStartupBanner(t *testing.T) {
|
||||
banner := startupBanner("15.4")
|
||||
if !strings.Contains(banner, "psql (15.4)") {
|
||||
t.Errorf("banner should contain version, got: %s", banner)
|
||||
}
|
||||
if !strings.Contains(banner, "help") {
|
||||
t.Error("banner should mention help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- configString ---
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{"db_name": "mydb"}
|
||||
if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
|
||||
t.Errorf("configString() = %q, want %q", got, "mydb")
|
||||
}
|
||||
if got := configString(cfg, "missing", "default"); got != "default" {
|
||||
t.Errorf("configString() for missing = %q, want %q", got, "default")
|
||||
}
|
||||
if got := configString(nil, "key", "default"); got != "default" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Shell metadata ---
|
||||
|
||||
func TestShellNameAndDescription(t *testing.T) {
|
||||
s := NewPsqlShell()
|
||||
if s.Name() != "psql" {
|
||||
t.Errorf("Name() = %q, want %q", s.Name(), "psql")
|
||||
}
|
||||
if s.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// --- formatSingleValue ---
|
||||
|
||||
func TestFormatSingleValue(t *testing.T) {
|
||||
out := formatSingleValue("?column?", "1")
|
||||
if !strings.Contains(out, "?column?") {
|
||||
t.Error("should contain column name")
|
||||
}
|
||||
if !strings.Contains(out, "1") {
|
||||
t.Error("should contain value")
|
||||
}
|
||||
if !strings.Contains(out, "(1 row)") {
|
||||
t.Error("should contain row count")
|
||||
}
|
||||
}
|
||||
|
||||
// --- \d with no args ---
|
||||
|
||||
func TestBackslashDescribeNoArgs(t *testing.T) {
|
||||
result := dispatchBackslash(`\d`, "postgres")
|
||||
if !strings.Contains(result.output, "users") {
|
||||
t.Error("\\d with no args should list tables")
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,19 @@
|
||||
package shell
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RecordingChannel wraps an io.ReadWriteCloser. In Phase 1.4 it is a
|
||||
// pass-through; Phase 2.3 will add byte-level keystroke recording here
|
||||
// without changing any shell code.
|
||||
// EventCallback is called with a copy of data whenever the channel is read or written.
|
||||
// direction is 0 for input (client→server) and 1 for output (server→client).
|
||||
type EventCallback func(ts time.Time, direction int, data []byte)
|
||||
|
||||
// RecordingChannel wraps an io.ReadWriteCloser and optionally invokes callbacks
|
||||
// on every Read (input) and Write (output).
|
||||
type RecordingChannel struct {
|
||||
inner io.ReadWriteCloser
|
||||
inner io.ReadWriteCloser
|
||||
callbacks []EventCallback
|
||||
}
|
||||
|
||||
// NewRecordingChannel returns a RecordingChannel wrapping rw.
|
||||
@@ -14,6 +21,42 @@ func NewRecordingChannel(rw io.ReadWriteCloser) *RecordingChannel {
|
||||
return &RecordingChannel{inner: rw}
|
||||
}
|
||||
|
||||
func (r *RecordingChannel) Read(p []byte) (int, error) { return r.inner.Read(p) }
|
||||
func (r *RecordingChannel) Write(p []byte) (int, error) { return r.inner.Write(p) }
|
||||
func (r *RecordingChannel) Close() error { return r.inner.Close() }
|
||||
// WithCallback clears existing callbacks, sets the given one, and returns the
|
||||
// RecordingChannel for chaining. Kept for backward compatibility.
|
||||
func (r *RecordingChannel) WithCallback(cb EventCallback) *RecordingChannel {
|
||||
r.callbacks = []EventCallback{cb}
|
||||
return r
|
||||
}
|
||||
|
||||
// AddCallback appends an additional event callback.
|
||||
func (r *RecordingChannel) AddCallback(cb EventCallback) {
|
||||
r.callbacks = append(r.callbacks, cb)
|
||||
}
|
||||
|
||||
func (r *RecordingChannel) Read(p []byte) (int, error) {
|
||||
n, err := r.inner.Read(p)
|
||||
if n > 0 && len(r.callbacks) > 0 {
|
||||
ts := time.Now()
|
||||
cp := make([]byte, n)
|
||||
copy(cp, p[:n])
|
||||
for _, cb := range r.callbacks {
|
||||
cb(ts, 0, cp)
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *RecordingChannel) Write(p []byte) (int, error) {
|
||||
n, err := r.inner.Write(p)
|
||||
if n > 0 && len(r.callbacks) > 0 {
|
||||
ts := time.Now()
|
||||
cp := make([]byte, n)
|
||||
copy(cp, p[:n])
|
||||
for _, cb := range r.callbacks {
|
||||
cb(ts, 1, cp)
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *RecordingChannel) Close() error { return r.inner.Close() }
|
||||
|
||||
@@ -3,7 +3,9 @@ package shell
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// nopCloser wraps a ReadWriter with a no-op Close.
|
||||
@@ -41,3 +43,80 @@ func TestRecordingChannelPassthrough(t *testing.T) {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordingChannelMultiCallback(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
rc := NewRecordingChannel(nopCloser{&buf})
|
||||
|
||||
type event struct {
|
||||
ts time.Time
|
||||
direction int
|
||||
data string
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var events1, events2 []event
|
||||
|
||||
rc.AddCallback(func(ts time.Time, direction int, data []byte) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
events1 = append(events1, event{ts, direction, string(data)})
|
||||
})
|
||||
rc.AddCallback(func(ts time.Time, direction int, data []byte) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
events2 = append(events2, event{ts, direction, string(data)})
|
||||
})
|
||||
|
||||
// Write triggers both callbacks with direction=1.
|
||||
rc.Write([]byte("hello"))
|
||||
|
||||
// Read triggers both callbacks with direction=0.
|
||||
out := make([]byte, 16)
|
||||
rc.Read(out)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if len(events1) != 2 {
|
||||
t.Fatalf("callback1 got %d events, want 2", len(events1))
|
||||
}
|
||||
if len(events2) != 2 {
|
||||
t.Fatalf("callback2 got %d events, want 2", len(events2))
|
||||
}
|
||||
|
||||
// Write event should be direction=1.
|
||||
if events1[0].direction != 1 {
|
||||
t.Errorf("write direction = %d, want 1", events1[0].direction)
|
||||
}
|
||||
// Read event should be direction=0.
|
||||
if events1[1].direction != 0 {
|
||||
t.Errorf("read direction = %d, want 0", events1[1].direction)
|
||||
}
|
||||
|
||||
// Both callbacks should get the same timestamp for a single operation.
|
||||
if events1[0].ts != events2[0].ts {
|
||||
t.Error("callbacks should receive the same timestamp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordingChannelWithCallbackClearsExisting(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
rc := NewRecordingChannel(nopCloser{&buf})
|
||||
|
||||
called1 := false
|
||||
called2 := false
|
||||
|
||||
rc.AddCallback(func(_ time.Time, _ int, _ []byte) { called1 = true })
|
||||
// WithCallback should clear existing and set new.
|
||||
rc.WithCallback(func(_ time.Time, _ int, _ []byte) { called2 = true })
|
||||
|
||||
rc.Write([]byte("x"))
|
||||
|
||||
if called1 {
|
||||
t.Error("first callback should not be called after WithCallback")
|
||||
}
|
||||
if !called2 {
|
||||
t.Error("second callback should be called")
|
||||
}
|
||||
}
|
||||
|
||||
463
internal/shell/roomba/roomba.go
Normal file
463
internal/shell/roomba/roomba.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package roomba
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// RoombaShell emulates an iRobot Roomba vacuum robot interface.
|
||||
type RoombaShell struct{}
|
||||
|
||||
// NewRoombaShell returns a new RoombaShell instance.
|
||||
func NewRoombaShell() *RoombaShell {
|
||||
return &RoombaShell{}
|
||||
}
|
||||
|
||||
func (r *RoombaShell) Name() string { return "roomba" }
|
||||
func (r *RoombaShell) Description() string { return "iRobot Roomba shell emulator" }
|
||||
|
||||
func (r *RoombaShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
state := newRoombaState()
|
||||
|
||||
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
|
||||
fmt.Fprint(rw, banner)
|
||||
|
||||
for {
|
||||
if _, err := fmt.Fprint(rw, "RoombaOS> "); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
fmt.Fprint(rw, "logout\r\n")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
result := state.dispatch(trimmed)
|
||||
|
||||
var output string
|
||||
if result.output != "" {
|
||||
output = result.output
|
||||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("roomba")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func bootBanner() string {
|
||||
return `
|
||||
____ _ ___ ____
|
||||
| _ \ ___ ___ _ __ ___ | |__ __ _ / _ \/ ___|
|
||||
| |_) / _ \ / _ \| '_ ` + "`" + ` _ \| '_ \ / _` + "`" + ` | | | \___ \
|
||||
| _ < (_) | (_) | | | | | | |_) | (_| | |_| |___) |
|
||||
|_| \_\___/ \___/|_| |_| |_|_.__/ \__,_|\___/|____/
|
||||
|
||||
iRobot Roomba j7+ | RoombaOS v4.3.7
|
||||
Serial: RMB-7291-J7P-0482 | Firmware: 4.3.7-stable
|
||||
Battery: 73% | WiFi: Connected (SmartHome-5G)
|
||||
|
||||
Type 'help' for available commands.
|
||||
|
||||
`
|
||||
}
|
||||
|
||||
type room struct {
|
||||
name string
|
||||
areaSqFt int
|
||||
lastCleaned time.Time
|
||||
}
|
||||
|
||||
type scheduleEntry struct {
|
||||
day string
|
||||
time string
|
||||
}
|
||||
|
||||
type historyEntry struct {
|
||||
timestamp time.Time
|
||||
room string
|
||||
duration string
|
||||
note string
|
||||
}
|
||||
|
||||
type roombaState struct {
|
||||
battery int
|
||||
dustbin int
|
||||
status string
|
||||
rooms []room
|
||||
schedule []scheduleEntry
|
||||
cleanHistory []historyEntry
|
||||
}
|
||||
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
func newRoombaState() *roombaState {
|
||||
now := time.Now()
|
||||
return &roombaState{
|
||||
battery: 73,
|
||||
dustbin: 61,
|
||||
status: "Docked",
|
||||
rooms: []room{
|
||||
{"Kitchen", 180, now.Add(-2 * time.Hour)},
|
||||
{"Living Room", 320, now.Add(-5 * time.Hour)},
|
||||
{"Bedroom", 200, now.Add(-26 * time.Hour)},
|
||||
{"Hallway", 60, now.Add(-5 * time.Hour)},
|
||||
{"Bathroom", 75, now.Add(-50 * time.Hour)},
|
||||
{"Cat's Room", 110, now.Add(-3 * time.Hour)},
|
||||
},
|
||||
schedule: []scheduleEntry{
|
||||
{"Monday", "09:00"},
|
||||
{"Wednesday", "09:00"},
|
||||
{"Friday", "09:00"},
|
||||
{"Saturday", "14:00"},
|
||||
},
|
||||
cleanHistory: []historyEntry{
|
||||
{now.Add(-2 * time.Hour), "Kitchen", "23 min", "Completed normally"},
|
||||
{now.Add(-3 * time.Hour), "Cat's Room", "18 min", "Cat detected - rerouting"},
|
||||
{now.Add(-5 * time.Hour), "Living Room", "34 min", "Encountered sock near couch"},
|
||||
{now.Add(-5*time.Hour - 40*time.Minute), "Hallway", "8 min", "Completed normally"},
|
||||
{now.Add(-26 * time.Hour), "Bedroom", "27 min", "Tangled in phone charger"},
|
||||
{now.Add(-50 * time.Hour), "Bathroom", "14 min", "Unidentified sticky substance detected"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) dispatch(input string) commandResult {
|
||||
parts := strings.Fields(input)
|
||||
if len(parts) == 0 {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
cmd := strings.ToLower(parts[0])
|
||||
args := parts[1:]
|
||||
|
||||
switch cmd {
|
||||
case "help":
|
||||
return s.cmdHelp()
|
||||
case "status":
|
||||
return s.cmdStatus()
|
||||
case "clean":
|
||||
return s.cmdClean(args)
|
||||
case "dock":
|
||||
return s.cmdDock()
|
||||
case "map":
|
||||
return s.cmdMap()
|
||||
case "schedule":
|
||||
return s.cmdSchedule(args)
|
||||
case "history":
|
||||
return s.cmdHistory()
|
||||
case "diagnostics":
|
||||
return s.cmdDiagnostics()
|
||||
case "alerts":
|
||||
return s.cmdAlerts()
|
||||
case "reboot":
|
||||
return s.cmdReboot()
|
||||
case "exit", "logout":
|
||||
return commandResult{output: "Disconnecting from RoombaOS. Happy cleaning!", exit: true}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("RoombaOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdHelp() commandResult {
|
||||
help := `Available commands:
|
||||
help - Show this help message
|
||||
status - Show robot status
|
||||
clean - Start full cleaning job
|
||||
clean room <name> - Clean a specific room
|
||||
dock - Return to dock
|
||||
map - Show floor plan and room list
|
||||
schedule - List cleaning schedule
|
||||
schedule add <day> <time> - Add scheduled cleaning
|
||||
schedule remove <day> - Remove scheduled cleaning
|
||||
history - Show recent cleaning history
|
||||
diagnostics - Run system diagnostics
|
||||
alerts - Show active alerts
|
||||
reboot - Reboot RoombaOS
|
||||
exit / logout - Disconnect`
|
||||
return commandResult{output: help}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdStatus() commandResult {
|
||||
var b strings.Builder
|
||||
b.WriteString("=== RoombaOS System Status ===\n")
|
||||
b.WriteString("Model: iRobot Roomba j7+\n")
|
||||
b.WriteString(fmt.Sprintf("Status: %s\n", s.status))
|
||||
b.WriteString(fmt.Sprintf("Battery: %d%%\n", s.battery))
|
||||
b.WriteString(fmt.Sprintf("Dustbin: %d%% full\n", s.dustbin))
|
||||
b.WriteString("Side brush: OK (142 hrs)\n")
|
||||
b.WriteString("Main brush: OK (98 hrs)\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString("WiFi: Connected (SmartHome-5G)\n")
|
||||
b.WriteString("Signal: -38 dBm\n")
|
||||
b.WriteString("Alexa: Linked\n")
|
||||
b.WriteString("Google Home: Linked\n")
|
||||
b.WriteString("iRobot Home App: Connected\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString("Firmware: v4.3.7-stable\n")
|
||||
b.WriteString("LIDAR: Operational\n")
|
||||
b.WriteString("Clean Area Total: 12,847 sq ft (lifetime)")
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdClean(args []string) commandResult {
|
||||
if s.status == "Cleaning" {
|
||||
return commandResult{output: "Already cleaning. Use 'dock' to cancel and return to dock."}
|
||||
}
|
||||
|
||||
if len(args) >= 2 && strings.ToLower(args[0]) == "room" {
|
||||
roomName := strings.Join(args[1:], " ")
|
||||
for _, r := range s.rooms {
|
||||
if strings.EqualFold(r.name, roomName) {
|
||||
s.status = "Cleaning"
|
||||
return commandResult{output: fmt.Sprintf(
|
||||
"Starting targeted clean: %s (%d sq ft)\nEstimated time: %d minutes\nUndocking... navigating to %s...",
|
||||
r.name, r.areaSqFt, r.areaSqFt/8, r.name,
|
||||
)}
|
||||
}
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("Room '%s' not found. Use 'map' to see available rooms.", roomName)}
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
return commandResult{output: "Usage: clean [room <name>]"}
|
||||
}
|
||||
|
||||
s.status = "Cleaning"
|
||||
var totalArea int
|
||||
for _, r := range s.rooms {
|
||||
totalArea += r.areaSqFt
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf(
|
||||
"Starting full house clean\nTotal area: %d sq ft across %d rooms\nEstimated time: %d minutes\nUndocking... beginning clean cycle...",
|
||||
totalArea, len(s.rooms), totalArea/8,
|
||||
)}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdDock() commandResult {
|
||||
if s.status == "Docked" {
|
||||
return commandResult{output: "Already docked."}
|
||||
}
|
||||
if s.status == "Returning to dock" {
|
||||
return commandResult{output: "Already returning to dock."}
|
||||
}
|
||||
s.status = "Returning to dock"
|
||||
return commandResult{output: "Cancelling current job. Returning to dock...\nEstimated arrival: 2 minutes"}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdMap() commandResult {
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Floor Plan ===\n\n")
|
||||
b.WriteString(" +------------+----------+\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" | Kitchen | Bathroom |\n")
|
||||
b.WriteString(" | 180sqft | 75sqft |\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" +------+-----+----+-----+\n")
|
||||
b.WriteString(" | | | |\n")
|
||||
b.WriteString(" | Hall | Living | Cat |\n")
|
||||
b.WriteString(" | 60sf | Room | Rm |\n")
|
||||
b.WriteString(" | | 320sqft |110sf|\n")
|
||||
b.WriteString(" +------+ +-----+\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" | Bed +----------+\n")
|
||||
b.WriteString(" | room | [DOCK]\n")
|
||||
b.WriteString(" |200sf |\n")
|
||||
b.WriteString(" +------+\n")
|
||||
b.WriteString("\nRoom Details:\n")
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "ROOM", "AREA", "LAST CLEANED"))
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "----", "----", "------------"))
|
||||
for _, r := range s.rooms {
|
||||
ago := time.Since(r.lastCleaned).Truncate(time.Minute)
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s ago\n", r.name, fmt.Sprintf("%d sqft", r.areaSqFt), formatDuration(ago)))
|
||||
}
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdSchedule(args []string) commandResult {
|
||||
if len(args) == 0 {
|
||||
return s.scheduleList()
|
||||
}
|
||||
|
||||
sub := strings.ToLower(args[0])
|
||||
switch sub {
|
||||
case "add":
|
||||
if len(args) < 3 {
|
||||
return commandResult{output: "Usage: schedule add <day> <time>\nExample: schedule add Tuesday 10:00"}
|
||||
}
|
||||
return s.scheduleAdd(args[1], args[2])
|
||||
case "remove":
|
||||
if len(args) < 2 {
|
||||
return commandResult{output: "Usage: schedule remove <day>"}
|
||||
}
|
||||
return s.scheduleRemove(args[1])
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("Unknown schedule subcommand '%s'. Try: add, remove", sub)}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleList() commandResult {
|
||||
if len(s.schedule) == 0 {
|
||||
return commandResult{output: "No cleaning schedule configured."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Cleaning Schedule ===\n")
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", "DAY", "TIME"))
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", "---", "----"))
|
||||
for _, e := range s.schedule {
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", e.day, e.time))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d scheduled cleaning(s)", len(s.schedule)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleAdd(day, t string) commandResult {
|
||||
day = capitalizeFirst(strings.ToLower(day))
|
||||
validDays := []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
|
||||
if !slices.Contains(validDays, day) {
|
||||
return commandResult{output: fmt.Sprintf("Invalid day '%s'. Use a day of the week (e.g. Monday, Tuesday).", day)}
|
||||
}
|
||||
|
||||
for _, e := range s.schedule {
|
||||
if strings.EqualFold(e.day, day) {
|
||||
return commandResult{output: fmt.Sprintf("Schedule for %s already exists. Remove it first.", day)}
|
||||
}
|
||||
}
|
||||
|
||||
s.schedule = append(s.schedule, scheduleEntry{day: day, time: t})
|
||||
return commandResult{output: fmt.Sprintf("Scheduled cleaning added: %s at %s", day, t)}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleRemove(day string) commandResult {
|
||||
day = capitalizeFirst(strings.ToLower(day))
|
||||
for i, e := range s.schedule {
|
||||
if strings.EqualFold(e.day, day) {
|
||||
s.schedule = append(s.schedule[:i], s.schedule[i+1:]...)
|
||||
return commandResult{output: fmt.Sprintf("Removed schedule for %s.", day)}
|
||||
}
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("No schedule found for '%s'.", day)}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdHistory() commandResult {
|
||||
if len(s.cleanHistory) == 0 {
|
||||
return commandResult{output: "No cleaning history."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Cleaning History ===\n")
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "TIME", "ROOM", "DURATION", "NOTE"))
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "----", "----", "--------", "----"))
|
||||
for _, h := range s.cleanHistory {
|
||||
ts := h.timestamp.Format("2006-01-02 15:04")
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", ts, h.room, h.duration, h.note))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d session(s) recorded", len(s.cleanHistory)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdDiagnostics() commandResult {
|
||||
diag := `Running RoombaOS diagnostics...
|
||||
|
||||
[1/8] Cliff sensors........... OK
|
||||
[2/8] Bumper sensor........... OK
|
||||
[3/8] Side brush motor........ OK (142 hrs until replacement)
|
||||
[4/8] Main brush motor........ OK (98 hrs until replacement)
|
||||
[5/8] Wheel motors............ OK (L: 1204 hrs, R: 1204 hrs)
|
||||
[6/8] LIDAR module............ OK (last calibrated 3 days ago)
|
||||
[7/8] Dustbin sensor.......... OK
|
||||
[8/8] WiFi module............. OK (signal: -38 dBm)
|
||||
|
||||
ALL SYSTEMS NOMINAL`
|
||||
return commandResult{output: diag}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdAlerts() commandResult {
|
||||
var alerts []string
|
||||
if s.dustbin >= 60 {
|
||||
alerts = append(alerts, fmt.Sprintf("WARNING: Dustbin %d%% full - consider emptying", s.dustbin))
|
||||
}
|
||||
alerts = append(alerts,
|
||||
"WARNING: Side brush replacement due in 12 hours",
|
||||
"INFO: Unidentified sticky substance detected in Kitchen",
|
||||
"INFO: Cat frequently blocking cleaning path in Cat's Room",
|
||||
"INFO: Firmware update available: v4.4.0-beta",
|
||||
"INFO: Filter replacement recommended in 14 days",
|
||||
)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Active Alerts ===\n")
|
||||
for _, a := range alerts {
|
||||
b.WriteString(a + "\n")
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdReboot() commandResult {
|
||||
reboot := `RoombaOS is rebooting...
|
||||
|
||||
Stopping navigation engine..... done
|
||||
Saving room map data........... done
|
||||
Flushing cleaning logs......... done
|
||||
Disconnecting from WiFi........ done
|
||||
|
||||
Rebooting now. Goodbye!`
|
||||
return commandResult{output: reboot, exit: true}
|
||||
}
|
||||
|
||||
func capitalizeFirst(s string) string {
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
return strings.ToUpper(s[:1]) + s[1:]
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
hours := int(d.Hours())
|
||||
minutes := int(d.Minutes()) % 60
|
||||
if hours >= 24 {
|
||||
days := hours / 24
|
||||
hours %= 24
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%dh %dm", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%dm", minutes)
|
||||
}
|
||||
@@ -2,9 +2,10 @@ package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// Shell is the interface that all honeypot shell implementations must satisfy.
|
||||
@@ -23,6 +24,7 @@ type SessionContext struct {
|
||||
Store storage.Store
|
||||
ShellConfig map[string]any
|
||||
CommonConfig ShellCommonConfig
|
||||
OnCommand func(shell string) // called when a command is executed; may be nil
|
||||
}
|
||||
|
||||
// ShellCommonConfig holds settings shared across all shell types.
|
||||
@@ -31,3 +33,59 @@ type ShellCommonConfig struct {
|
||||
Banner string
|
||||
FakeUser string // override username in prompt; empty = use authenticated user
|
||||
}
|
||||
|
||||
// ReadLine reads a line of input byte-by-byte, handling backspace, Ctrl+C, and Ctrl+D.
|
||||
func ReadLine(ctx context.Context, rw io.ReadWriter) (string, error) {
|
||||
var buf []byte
|
||||
b := make([]byte, 1)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := rw.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
ch := b[0]
|
||||
switch {
|
||||
case ch == '\r' || ch == '\n':
|
||||
fmt.Fprint(rw, "\r\n")
|
||||
return string(buf), nil
|
||||
|
||||
case ch == 4: // Ctrl+D
|
||||
if len(buf) == 0 {
|
||||
return "", io.EOF
|
||||
}
|
||||
|
||||
case ch == 3: // Ctrl+C
|
||||
fmt.Fprint(rw, "^C\r\n")
|
||||
return "", nil
|
||||
|
||||
case ch == 127 || ch == 8: // DEL or Backspace
|
||||
if len(buf) > 0 {
|
||||
buf = buf[:len(buf)-1]
|
||||
fmt.Fprint(rw, "\b \b")
|
||||
}
|
||||
|
||||
case ch == 27: // ESC - start of escape sequence
|
||||
// Read and discard the rest of the escape sequence.
|
||||
// Most are 3 bytes: ESC [ X (arrow keys, etc.)
|
||||
next := make([]byte, 1)
|
||||
if n, _ := rw.Read(next); n > 0 && next[0] == '[' {
|
||||
rw.Read(next) // read the final byte
|
||||
}
|
||||
|
||||
case ch >= 32 && ch < 127: // printable ASCII
|
||||
buf = append(buf, ch)
|
||||
rw.Write([]byte{ch})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
101
internal/shell/tetris/data.go
Normal file
101
internal/shell/tetris/data.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package tetris
|
||||
|
||||
import "github.com/charmbracelet/lipgloss"
|
||||
|
||||
// pieceType identifies a tetromino (0–6).
|
||||
type pieceType int
|
||||
|
||||
const (
|
||||
pieceI pieceType = iota
|
||||
pieceO
|
||||
pieceT
|
||||
pieceS
|
||||
pieceZ
|
||||
pieceJ
|
||||
pieceL
|
||||
)
|
||||
|
||||
const numPieceTypes = 7
|
||||
|
||||
// Standard Tetris colors.
|
||||
var pieceColors = [numPieceTypes]lipgloss.Color{
|
||||
lipgloss.Color("#00FFFF"), // I — cyan
|
||||
lipgloss.Color("#FFFF00"), // O — yellow
|
||||
lipgloss.Color("#AA00FF"), // T — purple
|
||||
lipgloss.Color("#00FF00"), // S — green
|
||||
lipgloss.Color("#FF0000"), // Z — red
|
||||
lipgloss.Color("#0000FF"), // J — blue
|
||||
lipgloss.Color("#FF8800"), // L — orange
|
||||
}
|
||||
|
||||
// Each piece has 4 rotations, each rotation is a list of (row, col) offsets
|
||||
// relative to the piece origin.
|
||||
type rotation [4][2]int
|
||||
|
||||
var pieces = [numPieceTypes][4]rotation{
|
||||
// I
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
|
||||
},
|
||||
// O
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
},
|
||||
// T
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
// S
|
||||
{
|
||||
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
// Z
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
|
||||
},
|
||||
// J
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{2, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
|
||||
},
|
||||
// L
|
||||
{
|
||||
{[2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
}
|
||||
|
||||
// spawnCol returns the starting column for a piece, centering it on the board.
|
||||
func spawnCol(pt pieceType, rot int) int {
|
||||
shape := pieces[pt][rot]
|
||||
minC, maxC := shape[0][1], shape[0][1]
|
||||
for _, off := range shape {
|
||||
if off[1] < minC {
|
||||
minC = off[1]
|
||||
}
|
||||
if off[1] > maxC {
|
||||
maxC = off[1]
|
||||
}
|
||||
}
|
||||
width := maxC - minC + 1
|
||||
return (boardCols - width) / 2
|
||||
}
|
||||
210
internal/shell/tetris/game.go
Normal file
210
internal/shell/tetris/game.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package tetris
|
||||
|
||||
import "math/rand/v2"
|
||||
|
||||
const (
|
||||
boardRows = 20
|
||||
boardCols = 10
|
||||
)
|
||||
|
||||
// cell represents a single board cell. Zero value is empty.
|
||||
type cell struct {
|
||||
filled bool
|
||||
piece pieceType // which piece type filled this cell (for color)
|
||||
}
|
||||
|
||||
// gameState holds all mutable state for a Tetris game.
|
||||
type gameState struct {
|
||||
board [boardRows][boardCols]cell
|
||||
current pieceType
|
||||
currentRot int
|
||||
currentRow int
|
||||
currentCol int
|
||||
next pieceType
|
||||
score int
|
||||
level int
|
||||
lines int
|
||||
gameOver bool
|
||||
}
|
||||
|
||||
// newGame creates a new game state, optionally starting at a given level.
|
||||
func newGame(startLevel int) *gameState {
|
||||
g := &gameState{
|
||||
level: startLevel,
|
||||
next: pieceType(rand.IntN(numPieceTypes)),
|
||||
}
|
||||
g.spawnPiece()
|
||||
return g
|
||||
}
|
||||
|
||||
// spawnPiece pulls the next piece and generates a new next.
|
||||
func (g *gameState) spawnPiece() {
|
||||
g.current = g.next
|
||||
g.next = pieceType(rand.IntN(numPieceTypes))
|
||||
g.currentRot = 0
|
||||
g.currentRow = 0
|
||||
g.currentCol = spawnCol(g.current, 0)
|
||||
|
||||
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
|
||||
g.gameOver = true
|
||||
}
|
||||
}
|
||||
|
||||
// canPlace checks whether the piece fits at the given position.
|
||||
func (g *gameState) canPlace(pt pieceType, rot, row, col int) bool {
|
||||
shape := pieces[pt][rot]
|
||||
for _, off := range shape {
|
||||
r, c := row+off[0], col+off[1]
|
||||
if r < 0 || r >= boardRows || c < 0 || c >= boardCols {
|
||||
return false
|
||||
}
|
||||
if g.board[r][c].filled {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// moveLeft moves the current piece left if possible.
|
||||
func (g *gameState) moveLeft() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol-1) {
|
||||
g.currentCol--
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// moveRight moves the current piece right if possible.
|
||||
func (g *gameState) moveRight() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol+1) {
|
||||
g.currentCol++
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// moveDown moves the current piece down one row. Returns false if it cannot.
|
||||
func (g *gameState) moveDown() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow+1, g.currentCol) {
|
||||
g.currentRow++
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// rotate rotates the current piece clockwise with wall kick attempts.
|
||||
func (g *gameState) rotate() bool {
|
||||
newRot := (g.currentRot + 1) % 4
|
||||
|
||||
// Try in-place first.
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol) {
|
||||
g.currentRot = newRot
|
||||
return true
|
||||
}
|
||||
|
||||
// Wall kick: try +-1 column offset.
|
||||
for _, offset := range []int{-1, 1} {
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
|
||||
g.currentRot = newRot
|
||||
g.currentCol += offset
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// I piece: try +-2.
|
||||
if g.current == pieceI {
|
||||
for _, offset := range []int{-2, 2} {
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
|
||||
g.currentRot = newRot
|
||||
g.currentCol += offset
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ghostRow returns the row where the piece would land.
|
||||
func (g *gameState) ghostRow() int {
|
||||
row := g.currentRow
|
||||
for g.canPlace(g.current, g.currentRot, row+1, g.currentCol) {
|
||||
row++
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
// hardDrop drops the piece to the bottom and returns the number of rows dropped.
|
||||
func (g *gameState) hardDrop() int {
|
||||
ghost := g.ghostRow()
|
||||
dropped := ghost - g.currentRow
|
||||
g.currentRow = ghost
|
||||
return dropped
|
||||
}
|
||||
|
||||
// lockPiece writes the current piece into the board.
|
||||
func (g *gameState) lockPiece() {
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := g.currentRow+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
|
||||
g.board[r][c] = cell{filled: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clearLines removes completed rows and returns how many were cleared.
|
||||
func (g *gameState) clearLines() int {
|
||||
cleared := 0
|
||||
for r := boardRows - 1; r >= 0; r-- {
|
||||
full := true
|
||||
for c := range boardCols {
|
||||
if !g.board[r][c].filled {
|
||||
full = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if full {
|
||||
cleared++
|
||||
// Shift everything above down.
|
||||
for rr := r; rr > 0; rr-- {
|
||||
g.board[rr] = g.board[rr-1]
|
||||
}
|
||||
g.board[0] = [boardCols]cell{}
|
||||
r++ // re-check this row since we shifted
|
||||
}
|
||||
}
|
||||
return cleared
|
||||
}
|
||||
|
||||
// NES-style scoring multipliers per lines cleared.
|
||||
var lineScoreMultipliers = [5]int{0, 40, 100, 300, 1200}
|
||||
|
||||
// addScore updates score, lines, and level after clearing rows.
|
||||
func (g *gameState) addScore(linesCleared int) {
|
||||
if linesCleared > 0 && linesCleared <= 4 {
|
||||
g.score += lineScoreMultipliers[linesCleared] * (g.level + 1)
|
||||
}
|
||||
g.lines += linesCleared
|
||||
|
||||
// Level up every 10 lines.
|
||||
newLevel := g.lines / 10
|
||||
if newLevel > g.level {
|
||||
g.level = newLevel
|
||||
}
|
||||
}
|
||||
|
||||
// afterLock locks the piece, clears lines, scores, and spawns the next piece.
|
||||
// Returns the number of lines cleared.
|
||||
func (g *gameState) afterLock() int {
|
||||
g.lockPiece()
|
||||
cleared := g.clearLines()
|
||||
g.addScore(cleared)
|
||||
g.spawnPiece()
|
||||
return cleared
|
||||
}
|
||||
|
||||
// tickInterval returns the gravity interval in milliseconds for the current level.
|
||||
func tickInterval(level int) int {
|
||||
return max(800-level*60, 100)
|
||||
}
|
||||
331
internal/shell/tetris/model.go
Normal file
331
internal/shell/tetris/model.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
type screen int
|
||||
|
||||
const (
|
||||
screenTitle screen = iota
|
||||
screenGame
|
||||
screenGameOver
|
||||
)
|
||||
|
||||
type tickMsg time.Time
|
||||
type lockMsg time.Time
|
||||
|
||||
const lockDelay = 500 * time.Millisecond
|
||||
|
||||
type model struct {
|
||||
sess *shell.SessionContext
|
||||
difficulty string
|
||||
screen screen
|
||||
game *gameState
|
||||
quitting bool
|
||||
height int
|
||||
keypresses int
|
||||
locking bool // true when piece has landed and lock delay is active
|
||||
}
|
||||
|
||||
func newModel(sess *shell.SessionContext, difficulty string) *model {
|
||||
return &model{
|
||||
sess: sess,
|
||||
difficulty: difficulty,
|
||||
screen: screenTitle,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.quitting {
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.height = msg.Height
|
||||
return m, nil
|
||||
case tea.KeyMsg:
|
||||
m.keypresses++
|
||||
if msg.Type == tea.KeyCtrlC {
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.gameScore(), m.gameLevel(), m.gameLines(), m.keypresses), "SESSION ENDED"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
switch m.screen {
|
||||
case screenTitle:
|
||||
return m.updateTitle(msg)
|
||||
case screenGame:
|
||||
return m.updateGame(msg)
|
||||
case screenGameOver:
|
||||
return m.updateGameOver(msg)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) View() string {
|
||||
var content string
|
||||
switch m.screen {
|
||||
case screenTitle:
|
||||
content = m.titleView()
|
||||
case screenGame:
|
||||
content = gameView(m.game)
|
||||
case screenGameOver:
|
||||
content = m.gameOverView()
|
||||
}
|
||||
|
||||
return gameFrame(content, m.height)
|
||||
}
|
||||
|
||||
// --- Title screen ---
|
||||
|
||||
func (m *model) titleView() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ████████╗███████╗████████╗██████╗ ██╗███████╗"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ╚══██╔══╝██╔════╝╚══██╔══╝██╔══██╗██║██╔════╝"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ █████╗ ██║ ██████╔╝██║███████╗"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ ██╔══╝ ██║ ██╔══██╗██║╚════██║"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ ███████╗ ██║ ██║ ██║██║███████║"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚══════╝"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(" Press any key to start"))
|
||||
b.WriteString("\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) updateTitle(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if _, ok := msg.(tea.KeyMsg); ok {
|
||||
m.screen = screenGame
|
||||
var startLevel int
|
||||
if m.difficulty == "hard" {
|
||||
startLevel = 5
|
||||
}
|
||||
m.game = newGame(startLevel)
|
||||
return m, tea.Batch(
|
||||
tea.ClearScreen,
|
||||
m.scheduleTick(),
|
||||
logAction(m.sess, "GAME START", fmt.Sprintf("difficulty=%s", m.difficulty)),
|
||||
)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// --- Game screen ---
|
||||
|
||||
func (m *model) scheduleTick() tea.Cmd {
|
||||
ms := tickInterval(m.game.level)
|
||||
if m.difficulty == "easy" {
|
||||
ms = max(1000-m.game.level*60, 150)
|
||||
}
|
||||
return tea.Tick(time.Duration(ms)*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return tickMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *model) scheduleLock() tea.Cmd {
|
||||
return tea.Tick(lockDelay, func(t time.Time) tea.Msg {
|
||||
return lockMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
// performLock locks the piece, clears lines, and returns commands for logging
|
||||
// and scheduling the next tick. Returns nil if game over (goToGameOver is
|
||||
// included in the returned batch).
|
||||
func (m *model) performLock() tea.Cmd {
|
||||
m.locking = false
|
||||
cleared := m.game.afterLock()
|
||||
if m.game.gameOver {
|
||||
return tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("GAME OVER score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "GAME OVER"),
|
||||
m.goToGameOver(),
|
||||
)
|
||||
}
|
||||
var cmds []tea.Cmd
|
||||
cmds = append(cmds, m.scheduleTick())
|
||||
if cleared > 0 {
|
||||
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LINES %d score=%d", cleared, m.game.score), fmt.Sprintf("total=%d", m.game.lines)))
|
||||
prevLevel := (m.game.lines - cleared) / 10
|
||||
if m.game.level > prevLevel {
|
||||
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LEVEL UP %d", m.game.level), fmt.Sprintf("score=%d", m.game.score)))
|
||||
}
|
||||
}
|
||||
return tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m *model) updateGame(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case lockMsg:
|
||||
if m.game.gameOver || !m.locking {
|
||||
return m, nil
|
||||
}
|
||||
// Lock delay expired — lock the piece now.
|
||||
return m, m.performLock()
|
||||
|
||||
case tickMsg:
|
||||
if m.game.gameOver || m.locking {
|
||||
return m, nil
|
||||
}
|
||||
if !m.game.moveDown() {
|
||||
// Piece landed — start lock delay instead of locking immediately.
|
||||
m.locking = true
|
||||
return m, m.scheduleLock()
|
||||
}
|
||||
return m, m.scheduleTick()
|
||||
|
||||
case tea.KeyMsg:
|
||||
if m.game.gameOver {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "left":
|
||||
m.game.moveLeft()
|
||||
// If piece can now drop further, cancel lock delay.
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case "right":
|
||||
m.game.moveRight()
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case "down":
|
||||
if m.game.moveDown() {
|
||||
m.game.score++ // soft drop bonus
|
||||
if m.locking {
|
||||
m.locking = false
|
||||
}
|
||||
}
|
||||
case "up", "z":
|
||||
m.game.rotate()
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case " ":
|
||||
m.locking = false
|
||||
dropped := m.game.hardDrop()
|
||||
m.game.score += dropped * 2
|
||||
return m, m.performLock()
|
||||
case "q":
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// --- Game over screen ---
|
||||
|
||||
func (m *model) goToGameOver() tea.Cmd {
|
||||
m.screen = screenGameOver
|
||||
return tea.ClearScreen
|
||||
}
|
||||
|
||||
func (m *model) gameOverView() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" GAME OVER"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Score: %s", formatScore(m.game.score))))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Level: %d", m.game.level)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Lines: %d", m.game.lines)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" R - Play again"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" Q - Quit"))
|
||||
b.WriteString("\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) updateGameOver(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if keyMsg, ok := msg.(tea.KeyMsg); ok {
|
||||
switch keyMsg.String() {
|
||||
case "r":
|
||||
startLevel := 0
|
||||
if m.difficulty == "hard" {
|
||||
startLevel = 5
|
||||
}
|
||||
m.game = newGame(startLevel)
|
||||
m.screen = screenGame
|
||||
m.keypresses = 0
|
||||
return m, tea.Batch(
|
||||
tea.ClearScreen,
|
||||
m.scheduleTick(),
|
||||
logAction(m.sess, "RESTART", fmt.Sprintf("difficulty=%s", m.difficulty)),
|
||||
)
|
||||
case "q":
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Helper methods for safe access when game may be nil.
|
||||
func (m *model) gameScore() int {
|
||||
if m.game != nil {
|
||||
return m.game.score
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *model) gameLevel() int {
|
||||
if m.game != nil {
|
||||
return m.game.level
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *model) gameLines() int {
|
||||
if m.game != nil {
|
||||
return m.game.lines
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// logAction returns a tea.Cmd that logs an action to the session store.
|
||||
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if sess.Store != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("tetris")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
286
internal/shell/tetris/style.go
Normal file
286
internal/shell/tetris/style.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const termWidth = 80
|
||||
|
||||
var (
|
||||
colorWhite = lipgloss.Color("#FFFFFF")
|
||||
colorDim = lipgloss.Color("#555555")
|
||||
colorBlack = lipgloss.Color("#000000")
|
||||
colorGhost = lipgloss.Color("#333333")
|
||||
)
|
||||
|
||||
var (
|
||||
baseStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWhite).
|
||||
Background(colorBlack)
|
||||
|
||||
dimStyle = lipgloss.NewStyle().
|
||||
Foreground(colorDim).
|
||||
Background(colorBlack)
|
||||
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#00FFFF")).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
|
||||
sidebarLabelStyle = lipgloss.NewStyle().
|
||||
Foreground(colorDim).
|
||||
Background(colorBlack)
|
||||
|
||||
sidebarValueStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWhite).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
)
|
||||
|
||||
// cellStyle returns a style for a filled cell of a given piece type.
|
||||
func cellStyle(pt pieceType) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(pieceColors[pt]).
|
||||
Background(colorBlack)
|
||||
}
|
||||
|
||||
// ghostStyle returns a dimmed style for the ghost piece.
|
||||
func ghostCellStyle() lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(colorGhost).
|
||||
Background(colorBlack)
|
||||
}
|
||||
|
||||
// renderBoard renders the board, current piece, and ghost piece as a string.
|
||||
func renderBoard(g *gameState) string {
|
||||
// Build a display grid that includes the current piece and ghost.
|
||||
type displayCell struct {
|
||||
filled bool
|
||||
ghost bool
|
||||
piece pieceType
|
||||
}
|
||||
var grid [boardRows][boardCols]displayCell
|
||||
|
||||
// Copy locked cells.
|
||||
for r := range boardRows {
|
||||
for c := range boardCols {
|
||||
if g.board[r][c].filled {
|
||||
grid[r][c] = displayCell{filled: true, piece: g.board[r][c].piece}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ghost piece.
|
||||
ghostR := g.ghostRow()
|
||||
if ghostR != g.currentRow {
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := ghostR+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols && !grid[r][c].filled {
|
||||
grid[r][c] = displayCell{ghost: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Current piece.
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := g.currentRow+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
|
||||
grid[r][c] = displayCell{filled: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
|
||||
// Render grid.
|
||||
var b strings.Builder
|
||||
borderStyle := dimStyle
|
||||
|
||||
for _, row := range grid {
|
||||
b.WriteString(borderStyle.Render("|"))
|
||||
for _, dc := range row {
|
||||
switch {
|
||||
case dc.filled:
|
||||
b.WriteString(cellStyle(dc.piece).Render("[]"))
|
||||
case dc.ghost:
|
||||
b.WriteString(ghostCellStyle().Render("::"))
|
||||
default:
|
||||
b.WriteString(baseStyle.Render(" "))
|
||||
}
|
||||
}
|
||||
b.WriteString(borderStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString(borderStyle.Render("+" + strings.Repeat("--", boardCols) + "+"))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// renderNextPiece renders the "next piece" preview box.
|
||||
func renderNextPiece(pt pieceType) string {
|
||||
shape := pieces[pt][0]
|
||||
// Determine bounding box.
|
||||
minR, maxR := shape[0][0], shape[0][0]
|
||||
minC, maxC := shape[0][1], shape[0][1]
|
||||
for _, off := range shape {
|
||||
if off[0] < minR {
|
||||
minR = off[0]
|
||||
}
|
||||
if off[0] > maxR {
|
||||
maxR = off[0]
|
||||
}
|
||||
if off[1] < minC {
|
||||
minC = off[1]
|
||||
}
|
||||
if off[1] > maxC {
|
||||
maxC = off[1]
|
||||
}
|
||||
}
|
||||
|
||||
rows := maxR - minR + 1
|
||||
cols := maxC - minC + 1
|
||||
|
||||
// Build a small grid.
|
||||
grid := make([][]bool, rows)
|
||||
for i := range grid {
|
||||
grid[i] = make([]bool, cols)
|
||||
}
|
||||
for _, off := range shape {
|
||||
grid[off[0]-minR][off[1]-minC] = true
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
boxWidth := 8 // chars for the box interior
|
||||
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
|
||||
b.WriteString("\n")
|
||||
|
||||
for r := range rows {
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
// Center the piece in the box.
|
||||
pieceWidth := cols * 2
|
||||
leftPad := (boxWidth - pieceWidth) / 2
|
||||
rightPad := boxWidth - pieceWidth - leftPad
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", leftPad)))
|
||||
for c := range cols {
|
||||
if grid[r][c] {
|
||||
b.WriteString(cellStyle(pt).Render("[]"))
|
||||
} else {
|
||||
b.WriteString(baseStyle.Render(" "))
|
||||
}
|
||||
}
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", rightPad)))
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Fill remaining rows in the box (max 4 rows for I piece).
|
||||
for r := rows; r < 2; r++ {
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", boxWidth)))
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// formatScore formats a score with comma separators.
|
||||
func formatScore(n int) string {
|
||||
s := fmt.Sprintf("%d", n)
|
||||
if len(s) <= 3 {
|
||||
return s
|
||||
}
|
||||
var parts []string
|
||||
for len(s) > 3 {
|
||||
parts = append([]string{s[len(s)-3:]}, parts...)
|
||||
s = s[:len(s)-3]
|
||||
}
|
||||
parts = append([]string{s}, parts...)
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// gameView combines the board and sidebar into the game screen.
|
||||
func gameView(g *gameState) string {
|
||||
boardStr := renderBoard(g)
|
||||
boardLines := strings.Split(boardStr, "\n")
|
||||
|
||||
nextStr := renderNextPiece(g.next)
|
||||
nextLines := strings.Split(nextStr, "\n")
|
||||
|
||||
// Build sidebar lines.
|
||||
var sidebar []string
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" NEXT:"))
|
||||
sidebar = append(sidebar, nextLines...)
|
||||
sidebar = append(sidebar, "")
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" SCORE: ")+sidebarValueStyle.Render(formatScore(g.score)))
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" LEVEL: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.level)))
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" LINES: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.lines)))
|
||||
sidebar = append(sidebar, "")
|
||||
sidebar = append(sidebar, dimStyle.Render(" Controls:"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" <- -> Move"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Up/Z Rotate"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Down Soft drop"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Space Hard drop"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Q Quit"))
|
||||
|
||||
// Combine board and sidebar side by side.
|
||||
var b strings.Builder
|
||||
maxLines := max(len(boardLines), len(sidebar))
|
||||
|
||||
for i := range maxLines {
|
||||
boardLine := ""
|
||||
if i < len(boardLines) {
|
||||
boardLine = boardLines[i]
|
||||
}
|
||||
sidebarLine := ""
|
||||
if i < len(sidebar) {
|
||||
sidebarLine = sidebar[i]
|
||||
}
|
||||
|
||||
// Pad board to fixed width (| + 10*2 + | = 22 chars visual).
|
||||
b.WriteString(boardLine)
|
||||
b.WriteString(sidebarLine)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// padLine pads a single line to termWidth.
|
||||
func padLine(line string) string {
|
||||
w := lipgloss.Width(line)
|
||||
if w >= termWidth {
|
||||
return line
|
||||
}
|
||||
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
|
||||
}
|
||||
|
||||
// padLines pads every line in a multi-line string to termWidth.
|
||||
func padLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = padLine(line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// gameFrame wraps content with padding to fill the terminal.
|
||||
func gameFrame(content string, height int) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(content)
|
||||
|
||||
// Pad with blank lines to fill terminal height.
|
||||
if height > 0 {
|
||||
contentLines := strings.Count(content, "\n") + 1
|
||||
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
|
||||
for i := contentLines; i < height; i++ {
|
||||
b.WriteString(blankLine)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return padLines(b.String())
|
||||
}
|
||||
66
internal/shell/tetris/tetris.go
Normal file
66
internal/shell/tetris/tetris.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
// TetrisShell is a Tetris game TUI for the honeypot.
|
||||
type TetrisShell struct{}
|
||||
|
||||
// NewTetrisShell returns a new TetrisShell instance.
|
||||
func NewTetrisShell() *TetrisShell {
|
||||
return &TetrisShell{}
|
||||
}
|
||||
|
||||
func (t *TetrisShell) Name() string { return "tetris" }
|
||||
func (t *TetrisShell) Description() string { return "Tetris game TUI" }
|
||||
|
||||
func (t *TetrisShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
difficulty := configString(sess.ShellConfig, "difficulty", "normal")
|
||||
|
||||
m := newModel(sess, difficulty)
|
||||
p := tea.NewProgram(m,
|
||||
tea.WithInput(rw),
|
||||
tea.WithOutput(rw),
|
||||
tea.WithAltScreen(),
|
||||
)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := p.Run()
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
p.Quit()
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
582
internal/shell/tetris/tetris_test.go
Normal file
582
internal/shell/tetris/tetris_test.go
Normal file
@@ -0,0 +1,582 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// newTestModel creates a model with a test session context.
|
||||
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "player", "tetris", "")
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "player",
|
||||
Store: store,
|
||||
}
|
||||
m := newModel(sess, "normal")
|
||||
return m, store
|
||||
}
|
||||
|
||||
// sendKey sends a single key message to the model and returns the command.
|
||||
func sendKey(m *model, key string) tea.Cmd {
|
||||
var msg tea.KeyMsg
|
||||
switch key {
|
||||
case "enter":
|
||||
msg = tea.KeyMsg{Type: tea.KeyEnter}
|
||||
case "up":
|
||||
msg = tea.KeyMsg{Type: tea.KeyUp}
|
||||
case "down":
|
||||
msg = tea.KeyMsg{Type: tea.KeyDown}
|
||||
case "left":
|
||||
msg = tea.KeyMsg{Type: tea.KeyLeft}
|
||||
case "right":
|
||||
msg = tea.KeyMsg{Type: tea.KeyRight}
|
||||
case "space":
|
||||
msg = tea.KeyMsg{Type: tea.KeySpace}
|
||||
case "ctrl+c":
|
||||
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
|
||||
default:
|
||||
if len(key) == 1 {
|
||||
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}
|
||||
}
|
||||
}
|
||||
_, cmd := m.Update(msg)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// sendTick sends a tick message to the model.
|
||||
func sendTick(m *model) tea.Cmd {
|
||||
_, cmd := m.Update(tickMsg(time.Now()))
|
||||
return cmd
|
||||
}
|
||||
|
||||
// execCmds recursively executes tea.Cmd functions (including batches).
|
||||
func execCmds(cmd tea.Cmd) {
|
||||
if cmd == nil {
|
||||
return
|
||||
}
|
||||
msg := cmd()
|
||||
if batch, ok := msg.(tea.BatchMsg); ok {
|
||||
for _, c := range batch {
|
||||
execCmds(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTetrisShellName(t *testing.T) {
|
||||
sh := NewTetrisShell()
|
||||
if sh.Name() != "tetris" {
|
||||
t.Errorf("Name() = %q, want %q", sh.Name(), "tetris")
|
||||
}
|
||||
if sh.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{
|
||||
"difficulty": "hard",
|
||||
}
|
||||
if got := configString(cfg, "difficulty", "normal"); got != "hard" {
|
||||
t.Errorf("configString() = %q, want %q", got, "hard")
|
||||
}
|
||||
if got := configString(cfg, "missing", "normal"); got != "normal" {
|
||||
t.Errorf("configString() = %q, want %q", got, "normal")
|
||||
}
|
||||
if got := configString(nil, "difficulty", "normal"); got != "normal" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "normal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTitleScreenRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "████") {
|
||||
t.Error("title screen should show TETRIS logo")
|
||||
}
|
||||
if !strings.Contains(view, "Press any key") {
|
||||
t.Error("title screen should show 'Press any key'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTitleToGame(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
if m.screen != screenTitle {
|
||||
t.Fatalf("expected screenTitle, got %d", m.screen)
|
||||
}
|
||||
|
||||
sendKey(m, "enter")
|
||||
if m.screen != screenGame {
|
||||
t.Errorf("expected screenGame after keypress, got %d", m.screen)
|
||||
}
|
||||
if m.game == nil {
|
||||
t.Fatal("game should be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGameRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "|") {
|
||||
t.Error("game view should contain board borders")
|
||||
}
|
||||
if !strings.Contains(view, "SCORE") {
|
||||
t.Error("game view should show SCORE")
|
||||
}
|
||||
if !strings.Contains(view, "LEVEL") {
|
||||
t.Error("game view should show LEVEL")
|
||||
}
|
||||
if !strings.Contains(view, "LINES") {
|
||||
t.Error("game view should show LINES")
|
||||
}
|
||||
if !strings.Contains(view, "NEXT") {
|
||||
t.Error("game view should show NEXT")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Pure game logic tests ---
|
||||
|
||||
func TestNewGame(t *testing.T) {
|
||||
g := newGame(0)
|
||||
if g.gameOver {
|
||||
t.Error("new game should not be game over")
|
||||
}
|
||||
if g.score != 0 {
|
||||
t.Errorf("initial score = %d, want 0", g.score)
|
||||
}
|
||||
if g.level != 0 {
|
||||
t.Errorf("initial level = %d, want 0", g.level)
|
||||
}
|
||||
if g.lines != 0 {
|
||||
t.Errorf("initial lines = %d, want 0", g.lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGameHardLevel(t *testing.T) {
|
||||
g := newGame(5)
|
||||
if g.level != 5 {
|
||||
t.Errorf("hard start level = %d, want 5", g.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveLeft(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startCol := g.currentCol
|
||||
g.moveLeft()
|
||||
if g.currentCol != startCol-1 {
|
||||
t.Errorf("after moveLeft: col = %d, want %d", g.currentCol, startCol-1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveRight(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startCol := g.currentCol
|
||||
g.moveRight()
|
||||
if g.currentCol != startCol+1 {
|
||||
t.Errorf("after moveRight: col = %d, want %d", g.currentCol, startCol+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveDown(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRow := g.currentRow
|
||||
moved := g.moveDown()
|
||||
if !moved {
|
||||
t.Error("moveDown should succeed from starting position")
|
||||
}
|
||||
if g.currentRow != startRow+1 {
|
||||
t.Errorf("after moveDown: row = %d, want %d", g.currentRow, startRow+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCannotMoveLeftBeyondWall(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Move all the way left.
|
||||
for range boardCols {
|
||||
g.moveLeft()
|
||||
}
|
||||
col := g.currentCol
|
||||
g.moveLeft() // should not move further
|
||||
if g.currentCol != col {
|
||||
t.Errorf("should not move past left wall: col = %d, was %d", g.currentCol, col)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCannotMoveRightBeyondWall(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Move all the way right.
|
||||
for range boardCols {
|
||||
g.moveRight()
|
||||
}
|
||||
col := g.currentCol
|
||||
g.moveRight() // should not move further
|
||||
if g.currentCol != col {
|
||||
t.Errorf("should not move past right wall: col = %d, was %d", g.currentCol, col)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotate(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRot := g.currentRot
|
||||
g.rotate()
|
||||
// Rotation should change (possibly with wall kick).
|
||||
if g.currentRot == startRot {
|
||||
// Rotation might legitimately fail in some edge cases, so just check
|
||||
// that the game state is valid.
|
||||
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
|
||||
t.Error("piece should be in a valid position after rotate attempt")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHardDrop(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRow := g.currentRow
|
||||
dropped := g.hardDrop()
|
||||
if dropped == 0 {
|
||||
t.Error("hard drop should move piece down at least some rows from top")
|
||||
}
|
||||
if g.currentRow <= startRow {
|
||||
t.Errorf("after hardDrop: row = %d should be > %d", g.currentRow, startRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGhostRow(t *testing.T) {
|
||||
g := newGame(0)
|
||||
ghost := g.ghostRow()
|
||||
if ghost < g.currentRow {
|
||||
t.Errorf("ghost row %d should be >= current row %d", ghost, g.currentRow)
|
||||
}
|
||||
// Ghost should be at a position where moving down one more is impossible.
|
||||
if g.canPlace(g.current, g.currentRot, ghost+1, g.currentCol) {
|
||||
t.Error("ghost row should be the lowest valid position")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockPiece(t *testing.T) {
|
||||
g := newGame(0)
|
||||
g.hardDrop()
|
||||
pt := g.current
|
||||
row, col, rot := g.currentRow, g.currentCol, g.currentRot
|
||||
g.lockPiece()
|
||||
|
||||
// Verify that the piece's cells are now filled.
|
||||
shape := pieces[pt][rot]
|
||||
for _, off := range shape {
|
||||
r, c := row+off[0], col+off[1]
|
||||
if !g.board[r][c].filled {
|
||||
t.Errorf("cell (%d, %d) should be filled after lockPiece", r, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearLines(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Fill the bottom row completely.
|
||||
for c := range boardCols {
|
||||
g.board[boardRows-1][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
cleared := g.clearLines()
|
||||
if cleared != 1 {
|
||||
t.Errorf("clearLines() = %d, want 1", cleared)
|
||||
}
|
||||
// Bottom row should now be empty (shifted from above).
|
||||
for c := range boardCols {
|
||||
if g.board[boardRows-1][c].filled {
|
||||
t.Errorf("bottom row col %d should be empty after clearing", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearMultipleLines(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Fill the bottom 4 rows.
|
||||
for r := boardRows - 4; r < boardRows; r++ {
|
||||
for c := range boardCols {
|
||||
g.board[r][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
}
|
||||
cleared := g.clearLines()
|
||||
if cleared != 4 {
|
||||
t.Errorf("clearLines() = %d, want 4", cleared)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScoring(t *testing.T) {
|
||||
tests := []struct {
|
||||
lines int
|
||||
level int
|
||||
want int
|
||||
}{
|
||||
{1, 0, 40},
|
||||
{2, 0, 100},
|
||||
{3, 0, 300},
|
||||
{4, 0, 1200},
|
||||
{1, 1, 80},
|
||||
{4, 2, 3600},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
g := newGame(tt.level)
|
||||
g.addScore(tt.lines)
|
||||
if g.score != tt.want {
|
||||
t.Errorf("score for %d lines at level %d = %d, want %d", tt.lines, tt.level, g.score, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLevelUp(t *testing.T) {
|
||||
g := newGame(0)
|
||||
g.lines = 9
|
||||
g.addScore(1) // This should push lines to 10, triggering level 1.
|
||||
if g.level != 1 {
|
||||
t.Errorf("level = %d, want 1 after 10 lines", g.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTickInterval(t *testing.T) {
|
||||
if got := tickInterval(0); got != 800 {
|
||||
t.Errorf("tickInterval(0) = %d, want 800", got)
|
||||
}
|
||||
if got := tickInterval(5); got != 500 {
|
||||
t.Errorf("tickInterval(5) = %d, want 500", got)
|
||||
}
|
||||
// Floor at 100ms.
|
||||
if got := tickInterval(20); got != 100 {
|
||||
t.Errorf("tickInterval(20) = %d, want 100", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatScore(t *testing.T) {
|
||||
tests := []struct {
|
||||
n int
|
||||
want string
|
||||
}{
|
||||
{0, "0"},
|
||||
{100, "100"},
|
||||
{1250, "1,250"},
|
||||
{1000000, "1,000,000"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := formatScore(tt.n); got != tt.want {
|
||||
t.Errorf("formatScore(%d) = %q, want %q", tt.n, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGameOverScreen(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Force game over.
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "GAME OVER") {
|
||||
t.Error("game over screen should show GAME OVER")
|
||||
}
|
||||
if !strings.Contains(view, "Score") {
|
||||
t.Error("game over screen should show score")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestartFromGameOver(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
sendKey(m, "r")
|
||||
if m.screen != screenGame {
|
||||
t.Errorf("expected screenGame after restart, got %d", m.screen)
|
||||
}
|
||||
if m.game.gameOver {
|
||||
t.Error("game should not be over after restart")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuitFromGame(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
sendKey(m, "q")
|
||||
if !m.quitting {
|
||||
t.Error("should be quitting after pressing q")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuitFromGameOver(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
sendKey(m, "q")
|
||||
if !m.quitting {
|
||||
t.Error("should be quitting after pressing q in game over")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftDropScoring(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
scoreBefore := m.game.score
|
||||
sendKey(m, "down")
|
||||
if m.game.score != scoreBefore+1 {
|
||||
t.Errorf("score after soft drop = %d, want %d", m.game.score, scoreBefore+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHardDropScoring(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Hard drop gives 2 points per row dropped.
|
||||
sendKey(m, "space")
|
||||
if m.game.score < 2 {
|
||||
t.Errorf("score after hard drop = %d, should be at least 2", m.game.score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTickMovesDown(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
rowBefore := m.game.currentRow
|
||||
sendTick(m)
|
||||
// Piece should either move down by 1, or lock and spawn a new piece at top.
|
||||
movedDown := m.game.currentRow == rowBefore+1
|
||||
respawned := m.game.currentRow < rowBefore
|
||||
if !movedDown && !respawned && !m.game.gameOver {
|
||||
t.Errorf("tick should move piece down or lock+respawn: row was %d, now %d", rowBefore, m.game.currentRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLogs(t *testing.T) {
|
||||
m, store := newTestModel(t)
|
||||
|
||||
// Press key to start game — returns a logAction cmd.
|
||||
cmd := sendKey(m, "enter")
|
||||
if cmd != nil {
|
||||
execCmds(cmd)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
found := false
|
||||
for _, log := range store.SessionLogs {
|
||||
if strings.Contains(log.Input, "GAME START") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected GAME START in session logs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeypressCounter(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
sendKey(m, "left")
|
||||
sendKey(m, "right")
|
||||
sendKey(m, "down")
|
||||
|
||||
if m.keypresses != 4 { // enter + 3 game keys
|
||||
t.Errorf("keypresses = %d, want 4", m.keypresses)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockDelay(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Drop piece to the bottom via ticks until it can't move down.
|
||||
for range boardRows + 5 {
|
||||
if m.locking {
|
||||
break
|
||||
}
|
||||
sendTick(m)
|
||||
}
|
||||
|
||||
if !m.locking {
|
||||
t.Fatal("piece should be in locking state after hitting bottom")
|
||||
}
|
||||
|
||||
// During lock delay, we should still be able to move left/right.
|
||||
colBefore := m.game.currentCol
|
||||
sendKey(m, "left")
|
||||
if m.game.currentCol >= colBefore {
|
||||
// Might not have moved if against wall, try right.
|
||||
sendKey(m, "right")
|
||||
}
|
||||
|
||||
// Sending a lockMsg should finalize the piece.
|
||||
m.Update(lockMsg(time.Now()))
|
||||
// After lock, a new piece should have spawned (row near top).
|
||||
if m.game.currentRow > 1 && !m.game.gameOver {
|
||||
t.Errorf("after lock delay, new piece should spawn near top, got row %d", m.game.currentRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockDelayCancelledByDrop(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Build a ledge: fill rows 18-19 but leave column 0 empty.
|
||||
for r := boardRows - 2; r < boardRows; r++ {
|
||||
for c := 1; c < boardCols; c++ {
|
||||
m.game.board[r][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
}
|
||||
|
||||
// Move piece to column 0 area and drop it onto the ledge.
|
||||
for range boardCols {
|
||||
m.game.moveLeft()
|
||||
}
|
||||
// Tick down until locking.
|
||||
for range boardRows + 5 {
|
||||
if m.locking {
|
||||
break
|
||||
}
|
||||
sendTick(m)
|
||||
}
|
||||
|
||||
// If piece is on the ledge and we slide it to col 0 (open column),
|
||||
// the lock delay should cancel since it can fall further.
|
||||
// This test just validates the locking flag logic works.
|
||||
if m.locking {
|
||||
// Try moving — if piece can drop further, locking should cancel.
|
||||
sendKey(m, "left")
|
||||
// Whether locking cancels depends on the board state; just verify no crash.
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnCol(t *testing.T) {
|
||||
// All pieces should spawn roughly centered.
|
||||
for pt := range pieceType(numPieceTypes) {
|
||||
col := spawnCol(pt, 0)
|
||||
if col < 0 || col > boardCols-1 {
|
||||
t.Errorf("spawnCol(%d, 0) = %d, out of range", pt, col)
|
||||
}
|
||||
// Verify piece fits at spawn position.
|
||||
shape := pieces[pt][0]
|
||||
for _, off := range shape {
|
||||
c := col + off[1]
|
||||
if c < 0 || c >= boardCols {
|
||||
t.Errorf("piece %d overflows board at spawn: col+offset = %d", pt, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
217
internal/storage/instrumented.go
Normal file
217
internal/storage/instrumented.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// InstrumentedStore wraps a Store and records query duration and errors
|
||||
// as Prometheus metrics for each method call.
|
||||
type InstrumentedStore struct {
|
||||
store Store
|
||||
queryDuration *prometheus.HistogramVec
|
||||
queryErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store.
|
||||
func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore {
|
||||
return &InstrumentedStore{
|
||||
store: store,
|
||||
queryDuration: queryDuration,
|
||||
queryErrors: queryErrors,
|
||||
}
|
||||
}
|
||||
|
||||
func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) {
|
||||
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||
v, err := fn()
|
||||
timer.ObserveDuration()
|
||||
if err != nil {
|
||||
s.queryErrors.WithLabelValues(method).Inc()
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
func observeErr(s *InstrumentedStore, method string, fn func() error) error {
|
||||
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||
err := fn()
|
||||
timer.ObserveDuration()
|
||||
if err != nil {
|
||||
s.queryErrors.WithLabelValues(method).Inc()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
|
||||
return observeErr(s, "RecordLoginAttempt", func() error {
|
||||
return s.store.RecordLoginAttempt(ctx, username, password, ip, country)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
||||
return observe(s, "CreateSession", func() (string, error) {
|
||||
return s.store.CreateSession(ctx, ip, username, shellName, country)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
|
||||
return observeErr(s, "EndSession", func() error {
|
||||
return s.store.EndSession(ctx, sessionID, disconnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
|
||||
return observeErr(s, "UpdateHumanScore", func() error {
|
||||
return s.store.UpdateHumanScore(ctx, sessionID, score)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
||||
return observeErr(s, "SetExecCommand", func() error {
|
||||
return s.store.SetExecCommand(ctx, sessionID, command)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
||||
return observeErr(s, "AppendSessionLog", func() error {
|
||||
return s.store.AppendSessionLog(ctx, sessionID, input, output)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
|
||||
return observe(s, "DeleteRecordsBefore", func() (int64, error) {
|
||||
return s.store.DeleteRecordsBefore(ctx, cutoff)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
return observe(s, "GetDashboardStats", func() (*DashboardStats, error) {
|
||||
return s.store.GetDashboardStats(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopUsernames", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopUsernames(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopPasswords", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopPasswords(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopIPs", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopIPs(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopCountries", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopCountries(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopExecCommands(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||
return observe(s, "GetRecentSessions", func() ([]Session, error) {
|
||||
return s.store.GetRecentSessions(ctx, limit, activeOnly)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
return observe(s, "GetFilteredSessions", func() ([]Session, error) {
|
||||
return s.store.GetFilteredSessions(ctx, limit, activeOnly, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
||||
return observe(s, "GetSession", func() (*Session, error) {
|
||||
return s.store.GetSession(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
|
||||
return observe(s, "GetSessionLogs", func() ([]SessionLog, error) {
|
||||
return s.store.GetSessionLogs(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
|
||||
return observeErr(s, "AppendSessionEvents", func() error {
|
||||
return s.store.AppendSessionEvents(ctx, events)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
|
||||
return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) {
|
||||
return s.store.GetSessionEvents(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
return observe(s, "CloseActiveSessions", func() (int64, error) {
|
||||
return s.store.CloseActiveSessions(ctx, disconnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) {
|
||||
return s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) {
|
||||
return s.store.GetHourlyPattern(ctx, since, until)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
|
||||
return observe(s, "GetCountryStats", func() ([]CountryCount, error) {
|
||||
return s.store.GetCountryStats(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) {
|
||||
return s.store.GetFilteredDashboardStats(ctx, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopUsernames(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopPasswords(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopIPs(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopCountries(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) Close() error {
|
||||
return s.store.Close()
|
||||
}
|
||||
163
internal/storage/instrumented_test.go
Normal file
163
internal/storage/instrumented_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
)
|
||||
|
||||
func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
store := NewMemoryStore()
|
||||
return NewInstrumentedStore(store, dur, errs), dur, errs
|
||||
}
|
||||
|
||||
func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 {
|
||||
m := &dto.Metric{}
|
||||
h.WithLabelValues(method).(prometheus.Histogram).Write(m)
|
||||
return m.GetHistogram().GetSampleCount()
|
||||
}
|
||||
|
||||
func getCounterValue(c *prometheus.CounterVec, method string) float64 {
|
||||
m := &dto.Metric{}
|
||||
c.WithLabelValues(method).Write(m)
|
||||
return m.GetCounter().GetValue()
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreDelegation(t *testing.T) {
|
||||
s, dur, _ := newTestInstrumented()
|
||||
ctx := context.Background()
|
||||
|
||||
// RecordLoginAttempt should delegate and record duration.
|
||||
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
|
||||
// CreateSession should delegate and return a valid ID.
|
||||
id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Fatal("CreateSession returned empty ID")
|
||||
}
|
||||
if c := getHistogramCount(dur, "CreateSession"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
|
||||
// GetDashboardStats should delegate.
|
||||
stats, err := s.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetDashboardStats: %v", err)
|
||||
}
|
||||
if stats == nil {
|
||||
t.Fatal("GetDashboardStats returned nil")
|
||||
}
|
||||
if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreErrorCounting(t *testing.T) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test_ec_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test_ec_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
es := &errorStore{}
|
||||
s := NewInstrumentedStore(es, dur, errs)
|
||||
ctx := context.Background()
|
||||
|
||||
// Error should be counted.
|
||||
err := s.EndSession(ctx, "nonexistent", time.Now())
|
||||
if !errors.Is(err, errFake) {
|
||||
t.Fatalf("expected errFake, got %v", err)
|
||||
}
|
||||
if c := getHistogramCount(dur, "EndSession"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
if c := getCounterValue(errs, "EndSession"); c != 1 {
|
||||
t.Fatalf("expected error count 1, got %f", c)
|
||||
}
|
||||
|
||||
// Successful call should not increment error counter.
|
||||
s2, _, errs2 := newTestInstrumented()
|
||||
err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 {
|
||||
t.Fatalf("expected error count 0, got %f", c)
|
||||
}
|
||||
}
|
||||
|
||||
// errorStore is a Store that returns errors for all methods.
|
||||
type errorStore struct {
|
||||
MemoryStore
|
||||
}
|
||||
|
||||
var errFake = errors.New("fake error")
|
||||
|
||||
func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error {
|
||||
return errFake
|
||||
}
|
||||
|
||||
func (s *errorStore) EndSession(context.Context, string, time.Time) error {
|
||||
return errFake
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreObserveErr(t *testing.T) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test2_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test2_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
es := &errorStore{}
|
||||
s := NewInstrumentedStore(es, dur, errs)
|
||||
ctx := context.Background()
|
||||
|
||||
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if !errors.Is(err, errFake) {
|
||||
t.Fatalf("expected errFake, got %v", err)
|
||||
}
|
||||
if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected error count 1, got %f", c)
|
||||
}
|
||||
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreClose(t *testing.T) {
|
||||
s, _, _ := newTestInstrumented()
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ type MemoryStore struct {
|
||||
LoginAttempts []LoginAttempt
|
||||
Sessions map[string]*Session
|
||||
SessionLogs []SessionLog
|
||||
SessionEvents []SessionEvent
|
||||
}
|
||||
|
||||
// NewMemoryStore returns a new empty MemoryStore.
|
||||
@@ -24,7 +25,7 @@ func NewMemoryStore() *MemoryStore {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip string) error {
|
||||
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip, country string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -34,6 +35,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
|
||||
if a.Username == username && a.Password == password && a.IP == ip {
|
||||
a.Count++
|
||||
a.LastSeen = now
|
||||
a.Country = country
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -43,6 +45,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
|
||||
Username: username,
|
||||
Password: password,
|
||||
IP: ip,
|
||||
Country: country,
|
||||
Count: 1,
|
||||
FirstSeen: now,
|
||||
LastSeen: now,
|
||||
@@ -50,7 +53,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName string) (string, error) {
|
||||
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName, country string) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -59,6 +62,7 @@ func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName s
|
||||
m.Sessions[id] = &Session{
|
||||
ID: id,
|
||||
IP: ip,
|
||||
Country: country,
|
||||
Username: username,
|
||||
ShellName: shellName,
|
||||
ConnectedAt: now,
|
||||
@@ -87,6 +91,16 @@ func (m *MemoryStore) UpdateHumanScore(_ context.Context, sessionID string, scor
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) SetExecCommand(_ context.Context, sessionID string, command string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if s, ok := m.Sessions[sessionID]; ok {
|
||||
s.ExecCommand = &command
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -101,6 +115,55 @@ func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, outp
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetSession(_ context.Context, sessionID string) (*Session, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
s, ok := m.Sessions[sessionID]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
copy := *s
|
||||
return ©, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetSessionLogs(_ context.Context, sessionID string) ([]SessionLog, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var logs []SessionLog
|
||||
for _, l := range m.SessionLogs {
|
||||
if l.SessionID == sessionID {
|
||||
logs = append(logs, l)
|
||||
}
|
||||
}
|
||||
sort.Slice(logs, func(i, j int) bool {
|
||||
return logs[i].Timestamp.Before(logs[j].Timestamp)
|
||||
})
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) AppendSessionEvents(_ context.Context, events []SessionEvent) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.SessionEvents = append(m.SessionEvents, events...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetSessionEvents(_ context.Context, sessionID string) ([]SessionEvent, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var events []SessionEvent
|
||||
for _, e := range m.SessionEvents {
|
||||
if e.SessionID == sessionID {
|
||||
events = append(events, e)
|
||||
}
|
||||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) DeleteRecordsBefore(_ context.Context, cutoff time.Time) (int64, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -136,6 +199,16 @@ func (m *MemoryStore) DeleteRecordsBefore(_ context.Context, cutoff time.Time) (
|
||||
}
|
||||
m.SessionLogs = keptLogs
|
||||
|
||||
keptEvents := m.SessionEvents[:0]
|
||||
for _, e := range m.SessionEvents {
|
||||
if _, ok := m.Sessions[e.SessionID]; ok {
|
||||
keptEvents = append(keptEvents, e)
|
||||
} else {
|
||||
total++
|
||||
}
|
||||
}
|
||||
m.SessionEvents = keptEvents
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
@@ -174,7 +247,60 @@ func (m *MemoryStore) GetTopPasswords(_ context.Context, limit int) ([]TopEntry,
|
||||
func (m *MemoryStore) GetTopIPs(_ context.Context, limit int) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.topN("ip", limit), nil
|
||||
|
||||
type ipInfo struct {
|
||||
count int64
|
||||
country string
|
||||
}
|
||||
agg := make(map[string]*ipInfo)
|
||||
for _, a := range m.LoginAttempts {
|
||||
info, ok := agg[a.IP]
|
||||
if !ok {
|
||||
info = &ipInfo{}
|
||||
agg[a.IP] = info
|
||||
}
|
||||
info.count += int64(a.Count)
|
||||
if a.Country != "" {
|
||||
info.country = a.Country
|
||||
}
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(agg))
|
||||
for ip, info := range agg {
|
||||
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetTopCountries(_ context.Context, limit int) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if a.Country == "" {
|
||||
continue
|
||||
}
|
||||
counts[a.Country] += int64(a.Count)
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// topN aggregates login attempts by the given field and returns the top N. Must be called with m.mu held.
|
||||
@@ -210,20 +336,372 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.collectSessions(limit, activeOnly, DashboardFilter{}), nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredSessions(_ context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.collectSessions(limit, activeOnly, f), nil
|
||||
}
|
||||
|
||||
// collectSessions gathers sessions matching filter criteria. Must be called with m.mu held.
|
||||
func (m *MemoryStore) collectSessions(limit int, activeOnly bool, f DashboardFilter) []Session {
|
||||
// Compute event counts and input bytes per session.
|
||||
eventCounts := make(map[string]int)
|
||||
inputBytes := make(map[string]int64)
|
||||
for _, e := range m.SessionEvents {
|
||||
eventCounts[e.SessionID]++
|
||||
if e.Direction == 0 {
|
||||
inputBytes[e.SessionID] += int64(len(e.Data))
|
||||
}
|
||||
}
|
||||
|
||||
var sessions []Session
|
||||
for _, s := range m.Sessions {
|
||||
if activeOnly && s.DisconnectedAt != nil {
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, *s)
|
||||
if !matchesSessionFilter(s, f) {
|
||||
continue
|
||||
}
|
||||
sess := *s
|
||||
sess.EventCount = eventCounts[s.ID]
|
||||
sess.InputBytes = inputBytes[s.ID]
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
|
||||
})
|
||||
|
||||
if f.SortBy == "input_bytes" {
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].InputBytes > sessions[j].InputBytes
|
||||
})
|
||||
} else {
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
if limit > 0 && len(sessions) > limit {
|
||||
sessions = sessions[:limit]
|
||||
}
|
||||
return sessions, nil
|
||||
return sessions
|
||||
}
|
||||
|
||||
// matchesSessionFilter returns true if the session matches the given filter.
|
||||
func matchesSessionFilter(s *Session, f DashboardFilter) bool {
|
||||
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
|
||||
return false
|
||||
}
|
||||
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
|
||||
return false
|
||||
}
|
||||
if f.IP != "" && s.IP != f.IP {
|
||||
return false
|
||||
}
|
||||
if f.Country != "" && s.Country != f.Country {
|
||||
return false
|
||||
}
|
||||
if f.Username != "" && s.Username != f.Username {
|
||||
return false
|
||||
}
|
||||
if f.HumanScoreAboveZero {
|
||||
if s.HumanScore == nil || *s.HumanScore <= 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetTopExecCommands(_ context.Context, limit int) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, s := range m.Sessions {
|
||||
if s.ExecCommand != nil {
|
||||
counts[*s.ExecCommand]++
|
||||
}
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var count int64
|
||||
t := disconnectedAt.UTC()
|
||||
for _, s := range m.Sessions {
|
||||
if s.DisconnectedAt == nil {
|
||||
s.DisconnectedAt = &t
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetAttemptsOverTime(_ context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var cutoff time.Time
|
||||
if since != nil {
|
||||
cutoff = *since
|
||||
} else {
|
||||
cutoff = time.Now().UTC().AddDate(0, 0, -days)
|
||||
}
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if a.LastSeen.Before(cutoff) {
|
||||
continue
|
||||
}
|
||||
if until != nil && a.LastSeen.After(*until) {
|
||||
continue
|
||||
}
|
||||
day := a.LastSeen.Format("2006-01-02")
|
||||
counts[day] += int64(a.Count)
|
||||
}
|
||||
|
||||
points := make([]TimeSeriesPoint, 0, len(counts))
|
||||
for day, count := range counts {
|
||||
t, _ := time.Parse("2006-01-02", day)
|
||||
points = append(points, TimeSeriesPoint{Timestamp: t, Count: count})
|
||||
}
|
||||
sort.Slice(points, func(i, j int) bool {
|
||||
return points[i].Timestamp.Before(points[j].Timestamp)
|
||||
})
|
||||
return points, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetHourlyPattern(_ context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
hourCounts := make(map[int]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if since != nil && a.LastSeen.Before(*since) {
|
||||
continue
|
||||
}
|
||||
if until != nil && a.LastSeen.After(*until) {
|
||||
continue
|
||||
}
|
||||
hour := a.LastSeen.Hour()
|
||||
hourCounts[hour] += int64(a.Count)
|
||||
}
|
||||
|
||||
counts := make([]HourlyCount, 0, len(hourCounts))
|
||||
for h, c := range hourCounts {
|
||||
counts = append(counts, HourlyCount{Hour: h, Count: c})
|
||||
}
|
||||
sort.Slice(counts, func(i, j int) bool {
|
||||
return counts[i].Hour < counts[j].Hour
|
||||
})
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetCountryStats(_ context.Context) ([]CountryCount, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for _, a := range m.LoginAttempts {
|
||||
if a.Country == "" {
|
||||
continue
|
||||
}
|
||||
counts[a.Country] += int64(a.Count)
|
||||
}
|
||||
|
||||
result := make([]CountryCount, 0, len(counts))
|
||||
for country, count := range counts {
|
||||
result = append(result, CountryCount{Country: country, Count: count})
|
||||
}
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i].Count > result[j].Count
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// matchesFilter returns true if the login attempt matches the given filter. Must be called with m.mu held.
|
||||
func matchesFilter(a *LoginAttempt, f DashboardFilter) bool {
|
||||
if f.Since != nil && a.LastSeen.Before(*f.Since) {
|
||||
return false
|
||||
}
|
||||
if f.Until != nil && a.LastSeen.After(*f.Until) {
|
||||
return false
|
||||
}
|
||||
if f.IP != "" && a.IP != f.IP {
|
||||
return false
|
||||
}
|
||||
if f.Country != "" && a.Country != f.Country {
|
||||
return false
|
||||
}
|
||||
if f.Username != "" && a.Username != f.Username {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredDashboardStats(_ context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
stats := &DashboardStats{}
|
||||
ips := make(map[string]struct{})
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
stats.TotalAttempts += int64(a.Count)
|
||||
ips[a.IP] = struct{}{}
|
||||
}
|
||||
stats.UniqueIPs = int64(len(ips))
|
||||
|
||||
for _, s := range m.Sessions {
|
||||
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
|
||||
continue
|
||||
}
|
||||
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
|
||||
continue
|
||||
}
|
||||
if f.IP != "" && s.IP != f.IP {
|
||||
continue
|
||||
}
|
||||
if f.Country != "" && s.Country != f.Country {
|
||||
continue
|
||||
}
|
||||
stats.TotalSessions++
|
||||
if s.DisconnectedAt == nil {
|
||||
stats.ActiveSessions++
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopUsernames(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.filteredTopN("username", limit, f), nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopPasswords(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.filteredTopN("password", limit, f), nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopIPs(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
type ipInfo struct {
|
||||
count int64
|
||||
country string
|
||||
}
|
||||
agg := make(map[string]*ipInfo)
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
info, ok := agg[a.IP]
|
||||
if !ok {
|
||||
info = &ipInfo{}
|
||||
agg[a.IP] = info
|
||||
}
|
||||
info.count += int64(a.Count)
|
||||
if a.Country != "" {
|
||||
info.country = a.Country
|
||||
}
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(agg))
|
||||
for ip, info := range agg {
|
||||
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredTopCountries(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
counts := make(map[string]int64)
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if a.Country == "" {
|
||||
continue
|
||||
}
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
counts[a.Country] += int64(a.Count)
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// filteredTopN aggregates login attempts by the given field with filter applied and returns the top N. Must be called with m.mu held.
|
||||
func (m *MemoryStore) filteredTopN(field string, limit int, f DashboardFilter) []TopEntry {
|
||||
counts := make(map[string]int64)
|
||||
for i := range m.LoginAttempts {
|
||||
a := &m.LoginAttempts[i]
|
||||
if !matchesFilter(a, f) {
|
||||
continue
|
||||
}
|
||||
var key string
|
||||
switch field {
|
||||
case "username":
|
||||
key = a.Username
|
||||
case "password":
|
||||
key = a.Password
|
||||
case "ip":
|
||||
key = a.IP
|
||||
}
|
||||
counts[key] += int64(a.Count)
|
||||
}
|
||||
|
||||
entries := make([]TopEntry, 0, len(counts))
|
||||
for k, v := range counts {
|
||||
entries = append(entries, TopEntry{Value: k, Count: v})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].Count > entries[j].Count
|
||||
})
|
||||
if limit > 0 && len(entries) > limit {
|
||||
entries = entries[:limit]
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
func (m *MemoryStore) Close() error {
|
||||
|
||||
9
internal/storage/migrations/002_session_events.sql
Normal file
9
internal/storage/migrations/002_session_events.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
CREATE TABLE session_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
timestamp TEXT NOT NULL,
|
||||
direction INTEGER NOT NULL,
|
||||
data BLOB NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX idx_session_events_session_id ON session_events(session_id);
|
||||
3
internal/storage/migrations/003_add_country.sql
Normal file
3
internal/storage/migrations/003_add_country.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE login_attempts ADD COLUMN country TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE sessions ADD COLUMN country TEXT NOT NULL DEFAULT '';
|
||||
CREATE INDEX idx_login_attempts_country ON login_attempts(country);
|
||||
1
internal/storage/migrations/004_add_exec_command.sql
Normal file
1
internal/storage/migrations/004_add_exec_command.sql
Normal file
@@ -0,0 +1 @@
|
||||
ALTER TABLE sessions ADD COLUMN exec_command TEXT;
|
||||
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
CREATE INDEX idx_login_attempts_username ON login_attempts(username);
|
||||
CREATE INDEX idx_login_attempts_password ON login_attempts(password);
|
||||
CREATE INDEX idx_sessions_disconnected_at ON sessions(disconnected_at);
|
||||
@@ -25,8 +25,8 @@ func TestMigrateCreatesTablesAndVersion(t *testing.T) {
|
||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||
t.Fatalf("query version: %v", err)
|
||||
}
|
||||
if version != 1 {
|
||||
t.Errorf("version = %d, want 1", version)
|
||||
if version != 5 {
|
||||
t.Errorf("version = %d, want 5", version)
|
||||
}
|
||||
|
||||
// Verify tables exist by inserting into them.
|
||||
@@ -64,8 +64,8 @@ func TestMigrateIdempotent(t *testing.T) {
|
||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||
t.Fatalf("query version: %v", err)
|
||||
}
|
||||
if version != 1 {
|
||||
t.Errorf("version = %d after double migrate, want 1", version)
|
||||
if version != 5 {
|
||||
t.Errorf("version = %d after double migrate, want 5", version)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ func TestRunRetentionDeletesOldRecords(t *testing.T) {
|
||||
}
|
||||
|
||||
// Insert a recent login attempt.
|
||||
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
|
||||
t.Fatalf("insert recent attempt: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -34,28 +35,29 @@ func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
|
||||
return &SQLiteStore{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip string) error {
|
||||
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
INSERT INTO login_attempts (username, password, ip, count, first_seen, last_seen)
|
||||
VALUES (?, ?, ?, 1, ?, ?)
|
||||
INSERT INTO login_attempts (username, password, ip, country, count, first_seen, last_seen)
|
||||
VALUES (?, ?, ?, ?, 1, ?, ?)
|
||||
ON CONFLICT(username, password, ip) DO UPDATE SET
|
||||
count = count + 1,
|
||||
last_seen = ?`,
|
||||
username, password, ip, now, now, now)
|
||||
last_seen = ?,
|
||||
country = ?`,
|
||||
username, password, ip, country, now, now, now, country)
|
||||
if err != nil {
|
||||
return fmt.Errorf("recording login attempt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName string) (string, error) {
|
||||
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
INSERT INTO sessions (id, ip, username, shell_name, connected_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
id, ip, username, shellName, now)
|
||||
INSERT INTO sessions (id, ip, username, shell_name, country, connected_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
id, ip, username, shellName, country, now)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating session: %w", err)
|
||||
}
|
||||
@@ -82,6 +84,16 @@ func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, sc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET exec_command = ? WHERE id = ?`,
|
||||
command, sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setting exec command: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
_, err := s.db.ExecContext(ctx, `
|
||||
@@ -94,6 +106,115 @@ func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, ou
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
||||
var sess Session
|
||||
var connectedAt string
|
||||
var disconnectedAt sql.NullString
|
||||
var humanScore sql.NullFloat64
|
||||
var execCommand sql.NullString
|
||||
|
||||
err := s.db.QueryRowContext(ctx, `
|
||||
SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command
|
||||
FROM sessions WHERE id = ?`, sessionID).Scan(
|
||||
&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
|
||||
&connectedAt, &disconnectedAt, &humanScore, &execCommand,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying session: %w", err)
|
||||
}
|
||||
|
||||
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||
if disconnectedAt.Valid {
|
||||
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
|
||||
sess.DisconnectedAt = &t
|
||||
}
|
||||
if humanScore.Valid {
|
||||
sess.HumanScore = &humanScore.Float64
|
||||
}
|
||||
if execCommand.Valid {
|
||||
sess.ExecCommand = &execCommand.String
|
||||
}
|
||||
return &sess, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT id, session_id, timestamp, input, output
|
||||
FROM session_logs WHERE session_id = ?
|
||||
ORDER BY timestamp`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying session logs: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var logs []SessionLog
|
||||
for rows.Next() {
|
||||
var l SessionLog
|
||||
var ts string
|
||||
if err := rows.Scan(&l.ID, &l.SessionID, &ts, &l.Input, &l.Output); err != nil {
|
||||
return nil, fmt.Errorf("scanning session log: %w", err)
|
||||
}
|
||||
l.Timestamp, _ = time.Parse(time.RFC3339, ts)
|
||||
logs = append(logs, l)
|
||||
}
|
||||
return logs, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
|
||||
if len(events) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
INSERT INTO session_events (session_id, timestamp, direction, data)
|
||||
VALUES (?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing statement: %w", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, e := range events {
|
||||
_, err := stmt.ExecContext(ctx, e.SessionID, e.Timestamp.UTC().Format(time.RFC3339Nano), e.Direction, e.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("inserting session event: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT session_id, timestamp, direction, data
|
||||
FROM session_events WHERE session_id = ?
|
||||
ORDER BY id`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying session events: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var events []SessionEvent
|
||||
for rows.Next() {
|
||||
var e SessionEvent
|
||||
var ts string
|
||||
if err := rows.Scan(&e.SessionID, &ts, &e.Direction, &e.Data); err != nil {
|
||||
return nil, fmt.Errorf("scanning session event: %w", err)
|
||||
}
|
||||
e.Timestamp, _ = time.Parse(time.RFC3339Nano, ts)
|
||||
events = append(events, e)
|
||||
}
|
||||
return events, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
|
||||
cutoffStr := cutoff.UTC().Format(time.RFC3339)
|
||||
|
||||
@@ -105,15 +226,26 @@ func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time)
|
||||
|
||||
var total int64
|
||||
|
||||
// Delete session logs for old sessions.
|
||||
// Delete session events for old sessions.
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
DELETE FROM session_events WHERE session_id IN (
|
||||
SELECT id FROM sessions WHERE connected_at < ?
|
||||
)`, cutoffStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("deleting session events: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
total += n
|
||||
|
||||
// Delete session logs for old sessions.
|
||||
res, err = tx.ExecContext(ctx, `
|
||||
DELETE FROM session_logs WHERE session_id IN (
|
||||
SELECT id FROM sessions WHERE connected_at < ?
|
||||
)`, cutoffStr)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("deleting session logs: %w", err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
n, _ = res.RowsAffected()
|
||||
total += n
|
||||
|
||||
// Delete old sessions.
|
||||
@@ -172,10 +304,60 @@ func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntr
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return s.queryTopN(ctx, "ip", limit)
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT ip, country, SUM(count) AS total
|
||||
FROM login_attempts
|
||||
GROUP BY ip
|
||||
ORDER BY total DESC
|
||||
LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying top IPs: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning top IPs: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT country, SUM(count) AS total
|
||||
FROM login_attempts
|
||||
WHERE country != ''
|
||||
GROUP BY country
|
||||
ORDER BY total DESC
|
||||
LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying top countries: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning top countries: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
|
||||
switch column {
|
||||
case "username", "password", "ip":
|
||||
// valid columns
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid column: %s", column)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT %s, SUM(count) AS total
|
||||
FROM login_attempts
|
||||
@@ -201,40 +383,401 @@ func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) (
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||
query := `SELECT id, ip, username, shell_name, connected_at, disconnected_at, human_score FROM sessions`
|
||||
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
|
||||
if activeOnly {
|
||||
query += ` WHERE disconnected_at IS NULL`
|
||||
query += ` WHERE s.disconnected_at IS NULL`
|
||||
}
|
||||
query += ` ORDER BY connected_at DESC LIMIT ?`
|
||||
query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, limit)
|
||||
return s.scanSessions(ctx, query, limit)
|
||||
}
|
||||
|
||||
// buildSessionWhereClause builds a dynamic WHERE clause for session filtering.
|
||||
func buildSessionWhereClause(f DashboardFilter, activeOnly bool) (string, []any) {
|
||||
var clauses []string
|
||||
var args []any
|
||||
|
||||
if activeOnly {
|
||||
clauses = append(clauses, "s.disconnected_at IS NULL")
|
||||
}
|
||||
if f.Since != nil {
|
||||
clauses = append(clauses, "s.connected_at >= ?")
|
||||
args = append(args, f.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.Until != nil {
|
||||
clauses = append(clauses, "s.connected_at <= ?")
|
||||
args = append(args, f.Until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.IP != "" {
|
||||
clauses = append(clauses, "s.ip = ?")
|
||||
args = append(args, f.IP)
|
||||
}
|
||||
if f.Country != "" {
|
||||
clauses = append(clauses, "s.country = ?")
|
||||
args = append(args, f.Country)
|
||||
}
|
||||
if f.Username != "" {
|
||||
clauses = append(clauses, "s.username = ?")
|
||||
args = append(args, f.Username)
|
||||
}
|
||||
if f.HumanScoreAboveZero {
|
||||
clauses = append(clauses, "s.human_score > 0")
|
||||
}
|
||||
|
||||
if len(clauses) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return " WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
// validSessionSorts maps allowed SortBy values to SQL ORDER BY clauses.
|
||||
var validSessionSorts = map[string]string{
|
||||
"connected_at": "s.connected_at DESC",
|
||||
"input_bytes": "input_bytes DESC",
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
where, args := buildSessionWhereClause(f, activeOnly)
|
||||
args = append(args, limit)
|
||||
|
||||
orderBy := validSessionSorts["connected_at"]
|
||||
if mapped, ok := validSessionSorts[f.SortBy]; ok {
|
||||
orderBy = mapped
|
||||
}
|
||||
|
||||
//nolint:gosec // where/order clauses built from allowlisted constants, not raw user input
|
||||
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + where + ` GROUP BY s.id ORDER BY ` + orderBy + ` LIMIT ?`
|
||||
|
||||
return s.scanSessions(ctx, query, args...)
|
||||
}
|
||||
|
||||
// scanSessions executes a session query and scans the results.
|
||||
func (s *SQLiteStore) scanSessions(ctx context.Context, query string, args ...any) ([]Session, error) {
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying recent sessions: %w", err)
|
||||
return nil, fmt.Errorf("querying sessions: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var sessions []Session
|
||||
for rows.Next() {
|
||||
var s Session
|
||||
var sess Session
|
||||
var connectedAt string
|
||||
var disconnectedAt sql.NullString
|
||||
var humanScore sql.NullFloat64
|
||||
if err := rows.Scan(&s.ID, &s.IP, &s.Username, &s.ShellName, &connectedAt, &disconnectedAt, &humanScore); err != nil {
|
||||
var execCommand sql.NullString
|
||||
if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &sess.EventCount, &sess.InputBytes); err != nil {
|
||||
return nil, fmt.Errorf("scanning session: %w", err)
|
||||
}
|
||||
s.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||
if disconnectedAt.Valid {
|
||||
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
|
||||
s.DisconnectedAt = &t
|
||||
sess.DisconnectedAt = &t
|
||||
}
|
||||
if humanScore.Valid {
|
||||
s.HumanScore = &humanScore.Float64
|
||||
sess.HumanScore = &humanScore.Float64
|
||||
}
|
||||
sessions = append(sessions, s)
|
||||
if execCommand.Valid {
|
||||
sess.ExecCommand = &execCommand.String
|
||||
}
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT exec_command, COUNT(*) as total
|
||||
FROM sessions
|
||||
WHERE exec_command IS NOT NULL
|
||||
GROUP BY exec_command
|
||||
ORDER BY total DESC
|
||||
LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying top exec commands: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning top exec commands: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
res, err := s.db.ExecContext(ctx, `
|
||||
UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`,
|
||||
disconnectedAt.UTC().Format(time.RFC3339))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("closing active sessions: %w", err)
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
query := `SELECT DATE(last_seen) AS d, SUM(count) FROM login_attempts WHERE 1=1`
|
||||
var args []any
|
||||
|
||||
if since != nil {
|
||||
query += ` AND last_seen >= ?`
|
||||
args = append(args, since.UTC().Format(time.RFC3339))
|
||||
} else {
|
||||
query += ` AND last_seen >= ?`
|
||||
args = append(args, time.Now().UTC().AddDate(0, 0, -days).Format("2006-01-02"))
|
||||
}
|
||||
if until != nil {
|
||||
query += ` AND last_seen <= ?`
|
||||
args = append(args, until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
query += ` GROUP BY d ORDER BY d`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying attempts over time: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var points []TimeSeriesPoint
|
||||
for rows.Next() {
|
||||
var dateStr string
|
||||
var p TimeSeriesPoint
|
||||
if err := rows.Scan(&dateStr, &p.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning time series point: %w", err)
|
||||
}
|
||||
p.Timestamp, _ = time.Parse("2006-01-02", dateStr)
|
||||
points = append(points, p)
|
||||
}
|
||||
return points, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
query := `SELECT CAST(STRFTIME('%H', last_seen) AS INTEGER) AS h, SUM(count) FROM login_attempts WHERE 1=1`
|
||||
var args []any
|
||||
|
||||
if since != nil {
|
||||
query += ` AND last_seen >= ?`
|
||||
args = append(args, since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if until != nil {
|
||||
query += ` AND last_seen <= ?`
|
||||
args = append(args, until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
query += ` GROUP BY h ORDER BY h`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying hourly pattern: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var counts []HourlyCount
|
||||
for rows.Next() {
|
||||
var c HourlyCount
|
||||
if err := rows.Scan(&c.Hour, &c.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning hourly count: %w", err)
|
||||
}
|
||||
counts = append(counts, c)
|
||||
}
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT country, SUM(count) AS total
|
||||
FROM login_attempts
|
||||
WHERE country != ''
|
||||
GROUP BY country
|
||||
ORDER BY total DESC`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying country stats: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var counts []CountryCount
|
||||
for rows.Next() {
|
||||
var c CountryCount
|
||||
if err := rows.Scan(&c.Country, &c.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning country count: %w", err)
|
||||
}
|
||||
counts = append(counts, c)
|
||||
}
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
// buildAttemptWhereClause builds a dynamic WHERE clause for login_attempts filtering.
|
||||
func buildAttemptWhereClause(f DashboardFilter) (string, []any) {
|
||||
var clauses []string
|
||||
var args []any
|
||||
|
||||
if f.Since != nil {
|
||||
clauses = append(clauses, "last_seen >= ?")
|
||||
args = append(args, f.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.Until != nil {
|
||||
clauses = append(clauses, "last_seen <= ?")
|
||||
args = append(args, f.Until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.IP != "" {
|
||||
clauses = append(clauses, "ip = ?")
|
||||
args = append(args, f.IP)
|
||||
}
|
||||
if f.Country != "" {
|
||||
clauses = append(clauses, "country = ?")
|
||||
args = append(args, f.Country)
|
||||
}
|
||||
if f.Username != "" {
|
||||
clauses = append(clauses, "username = ?")
|
||||
args = append(args, f.Username)
|
||||
}
|
||||
|
||||
if len(clauses) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return " WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
stats := &DashboardStats{}
|
||||
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
`SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip) FROM login_attempts`+where, args...).
|
||||
Scan(&stats.TotalAttempts, &stats.UniqueIPs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered attempt stats: %w", err)
|
||||
}
|
||||
|
||||
// Sessions don't have username/password, so only filter by time, IP, country.
|
||||
sessQuery := `SELECT COUNT(*) FROM sessions WHERE 1=1`
|
||||
var sessArgs []any
|
||||
if f.Since != nil {
|
||||
sessQuery += ` AND connected_at >= ?`
|
||||
sessArgs = append(sessArgs, f.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.Until != nil {
|
||||
sessQuery += ` AND connected_at <= ?`
|
||||
sessArgs = append(sessArgs, f.Until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.IP != "" {
|
||||
sessQuery += ` AND ip = ?`
|
||||
sessArgs = append(sessArgs, f.IP)
|
||||
}
|
||||
if f.Country != "" {
|
||||
sessQuery += ` AND country = ?`
|
||||
sessArgs = append(sessArgs, f.Country)
|
||||
}
|
||||
|
||||
err = s.db.QueryRowContext(ctx, sessQuery, sessArgs...).Scan(&stats.TotalSessions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered total sessions: %w", err)
|
||||
}
|
||||
|
||||
err = s.db.QueryRowContext(ctx, sessQuery+` AND disconnected_at IS NULL`, sessArgs...).Scan(&stats.ActiveSessions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered active sessions: %w", err)
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return s.queryFilteredTopN(ctx, "username", limit, f)
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return s.queryFilteredTopN(ctx, "password", limit, f)
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
args = append(args, limit)
|
||||
//nolint:gosec // where clause built from trusted constants, not user input
|
||||
query := `SELECT ip, country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY ip ORDER BY total DESC LIMIT ?`
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered top IPs: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning filtered top IPs: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
countryClause := "country != ''"
|
||||
if where == "" {
|
||||
where = " WHERE " + countryClause
|
||||
} else {
|
||||
where += " AND " + countryClause
|
||||
}
|
||||
args = append(args, limit)
|
||||
//nolint:gosec // where clause built from trusted constants, not user input
|
||||
query := `SELECT country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY country ORDER BY total DESC LIMIT ?`
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered top countries: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning filtered top countries: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) queryFilteredTopN(ctx context.Context, column string, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
switch column {
|
||||
case "username", "password":
|
||||
// valid columns
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid column: %s", column)
|
||||
}
|
||||
|
||||
where, args := buildAttemptWhereClause(f)
|
||||
args = append(args, limit)
|
||||
query := fmt.Sprintf(`
|
||||
SELECT %s, SUM(count) AS total
|
||||
FROM login_attempts%s
|
||||
GROUP BY %s
|
||||
ORDER BY total DESC
|
||||
LIMIT ?`, column, where, column)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying filtered top %s: %w", column, err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var entries []TopEntry
|
||||
for rows.Next() {
|
||||
var e TopEntry
|
||||
if err := rows.Scan(&e.Value, &e.Count); err != nil {
|
||||
return nil, fmt.Errorf("scanning filtered top %s: %w", column, err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
return entries, rows.Err()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
@@ -23,17 +23,17 @@ func TestRecordLoginAttempt(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// First attempt creates a new record.
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("first attempt: %v", err)
|
||||
}
|
||||
|
||||
// Second attempt with same credentials increments count.
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("second attempt: %v", err)
|
||||
}
|
||||
|
||||
// Different IP is a separate record.
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
|
||||
t.Fatalf("different IP: %v", err)
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestCreateAndEndSession(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
@@ -100,7 +100,7 @@ func TestUpdateHumanScore(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
@@ -123,7 +123,7 @@ func TestAppendSessionLog(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
@@ -159,7 +159,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
|
||||
}
|
||||
|
||||
// Insert a recent login attempt.
|
||||
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
|
||||
t.Fatalf("insert recent attempt: %v", err)
|
||||
}
|
||||
|
||||
@@ -178,7 +178,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
|
||||
}
|
||||
|
||||
// Insert a recent session.
|
||||
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", ""); err != nil {
|
||||
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", "", ""); err != nil {
|
||||
t.Fatalf("insert recent session: %v", err)
|
||||
}
|
||||
|
||||
@@ -204,12 +204,81 @@ func TestDeleteRecordsBefore(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTopExecCommands(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create sessions with exec commands.
|
||||
for range 3 {
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||
t.Fatalf("setting exec command: %v", err)
|
||||
}
|
||||
}
|
||||
for range 2 {
|
||||
id, err := store.CreateSession(ctx, "10.0.0.2", "admin", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "cat /etc/passwd"); err != nil {
|
||||
t.Fatalf("setting exec command: %v", err)
|
||||
}
|
||||
}
|
||||
// Session without exec command — should not appear.
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.3", "test", "bash", ""); err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
|
||||
entries, err := store.GetTopExecCommands(ctx, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetTopExecCommands: %v", err)
|
||||
}
|
||||
if len(entries) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(entries))
|
||||
}
|
||||
if entries[0].Value != "uname -a" || entries[0].Count != 3 {
|
||||
t.Errorf("entries[0] = %+v, want uname -a:3", entries[0])
|
||||
}
|
||||
if entries[1].Value != "cat /etc/passwd" || entries[1].Count != 2 {
|
||||
t.Errorf("entries[1] = %+v, want cat /etc/passwd:2", entries[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRecentSessionsEventCount(t *testing.T) {
|
||||
store := newTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
|
||||
// Add some events.
|
||||
events := []SessionEvent{
|
||||
{SessionID: id, Timestamp: time.Now(), Direction: 0, Data: []byte("ls\n")},
|
||||
{SessionID: id, Timestamp: time.Now(), Direction: 1, Data: []byte("file1\n")},
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||
t.Fatalf("appending events: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].EventCount != 2 {
|
||||
t.Errorf("EventCount = %d, want 2", sessions[0].EventCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSQLiteStoreCreatesFile(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "subdir", "test.db")
|
||||
// Parent directory doesn't exist yet; SQLite should create it.
|
||||
// Actually, SQLite doesn't create parent dirs, but the file itself.
|
||||
// Use a path in the temp dir directly.
|
||||
dbPath = filepath.Join(t.TempDir(), "test.db")
|
||||
dbPath := filepath.Join(t.TempDir(), "test.db")
|
||||
store, err := NewSQLiteStore(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("creating store: %v", err)
|
||||
@@ -218,7 +287,7 @@ func TestNewSQLiteStoreCreatesFile(t *testing.T) {
|
||||
|
||||
// Verify we can use the store.
|
||||
ctx := context.Background()
|
||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1", ""); err != nil {
|
||||
t.Fatalf("recording attempt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ type LoginAttempt struct {
|
||||
Username string
|
||||
Password string
|
||||
IP string
|
||||
Country string
|
||||
Count int
|
||||
FirstSeen time.Time
|
||||
LastSeen time.Time
|
||||
@@ -20,11 +21,15 @@ type LoginAttempt struct {
|
||||
type Session struct {
|
||||
ID string
|
||||
IP string
|
||||
Country string
|
||||
Username string
|
||||
ShellName string
|
||||
ConnectedAt time.Time
|
||||
DisconnectedAt *time.Time
|
||||
HumanScore *float64
|
||||
ExecCommand *string
|
||||
EventCount int
|
||||
InputBytes int64
|
||||
}
|
||||
|
||||
// SessionLog represents a single log entry for a session.
|
||||
@@ -36,6 +41,14 @@ type SessionLog struct {
|
||||
Output string
|
||||
}
|
||||
|
||||
// SessionEvent represents a single I/O event recorded during a session.
|
||||
type SessionEvent struct {
|
||||
SessionID string
|
||||
Timestamp time.Time
|
||||
Direction int // 0=input (client→server), 1=output (server→client)
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// DashboardStats holds aggregate counts for the web dashboard.
|
||||
type DashboardStats struct {
|
||||
TotalAttempts int64
|
||||
@@ -44,20 +57,50 @@ type DashboardStats struct {
|
||||
ActiveSessions int64
|
||||
}
|
||||
|
||||
// TimeSeriesPoint represents a single data point in a time series.
|
||||
type TimeSeriesPoint struct {
|
||||
Timestamp time.Time
|
||||
Count int64
|
||||
}
|
||||
|
||||
// HourlyCount represents the total attempts for a given hour of day.
|
||||
type HourlyCount struct {
|
||||
Hour int // 0-23
|
||||
Count int64
|
||||
}
|
||||
|
||||
// CountryCount represents the total attempts from a given country.
|
||||
type CountryCount struct {
|
||||
Country string
|
||||
Count int64
|
||||
}
|
||||
|
||||
// DashboardFilter contains optional filters for dashboard queries.
|
||||
type DashboardFilter struct {
|
||||
Since *time.Time
|
||||
Until *time.Time
|
||||
IP string
|
||||
Country string
|
||||
Username string
|
||||
HumanScoreAboveZero bool
|
||||
SortBy string
|
||||
}
|
||||
|
||||
// TopEntry represents a value and its count for top-N queries.
|
||||
type TopEntry struct {
|
||||
Value string
|
||||
Count int64
|
||||
Value string
|
||||
Country string // populated by GetTopIPs
|
||||
Count int64
|
||||
}
|
||||
|
||||
// Store is the interface for persistent storage of honeypot data.
|
||||
type Store interface {
|
||||
// RecordLoginAttempt upserts a login attempt, incrementing the count
|
||||
// for existing (username, password, ip) combinations.
|
||||
RecordLoginAttempt(ctx context.Context, username, password, ip string) error
|
||||
RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error
|
||||
|
||||
// CreateSession creates a new session record and returns its UUID.
|
||||
CreateSession(ctx context.Context, ip, username, shellName string) (string, error)
|
||||
CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error)
|
||||
|
||||
// EndSession sets the disconnected_at timestamp for a session.
|
||||
EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error
|
||||
@@ -65,6 +108,9 @@ type Store interface {
|
||||
// UpdateHumanScore sets the human detection score for a session.
|
||||
UpdateHumanScore(ctx context.Context, sessionID string, score float64) error
|
||||
|
||||
// SetExecCommand sets the exec command for a session.
|
||||
SetExecCommand(ctx context.Context, sessionID string, command string) error
|
||||
|
||||
// AppendSessionLog adds a log entry to a session.
|
||||
AppendSessionLog(ctx context.Context, sessionID, input, output string) error
|
||||
|
||||
@@ -84,10 +130,61 @@ type Store interface {
|
||||
// GetTopIPs returns the top N IPs by total attempt count.
|
||||
GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error)
|
||||
|
||||
// GetTopCountries returns the top N countries by total attempt count.
|
||||
GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error)
|
||||
|
||||
// GetTopExecCommands returns the top N exec commands by session count.
|
||||
GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error)
|
||||
|
||||
// GetRecentSessions returns the most recent sessions ordered by connected_at DESC.
|
||||
// If activeOnly is true, only sessions with no disconnected_at are returned.
|
||||
GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error)
|
||||
|
||||
// GetFilteredSessions returns sessions matching the given filter, ordered
|
||||
// by the filter's SortBy field (default: connected_at DESC).
|
||||
GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error)
|
||||
|
||||
// GetSession returns a single session by ID.
|
||||
GetSession(ctx context.Context, sessionID string) (*Session, error)
|
||||
|
||||
// GetSessionLogs returns all log entries for a session ordered by timestamp.
|
||||
GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error)
|
||||
|
||||
// AppendSessionEvents batch-inserts session events.
|
||||
AppendSessionEvents(ctx context.Context, events []SessionEvent) error
|
||||
|
||||
// GetSessionEvents returns all events for a session ordered by id.
|
||||
GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error)
|
||||
|
||||
// CloseActiveSessions sets disconnected_at for all sessions that are
|
||||
// still marked as active. This should be called at startup to clean up
|
||||
// sessions left over from a previous unclean shutdown.
|
||||
CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error)
|
||||
|
||||
// GetAttemptsOverTime returns daily attempt counts for the last N days.
|
||||
GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error)
|
||||
|
||||
// GetHourlyPattern returns total attempts grouped by hour of day (0-23).
|
||||
GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error)
|
||||
|
||||
// GetCountryStats returns total attempts per country, ordered by count DESC.
|
||||
GetCountryStats(ctx context.Context) ([]CountryCount, error)
|
||||
|
||||
// GetFilteredDashboardStats returns aggregate counts with optional filters applied.
|
||||
GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error)
|
||||
|
||||
// GetFilteredTopUsernames returns top usernames with optional filters applied.
|
||||
GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// GetFilteredTopPasswords returns top passwords with optional filters applied.
|
||||
GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// GetFilteredTopIPs returns top IPs with optional filters applied.
|
||||
GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// GetFilteredTopCountries returns top countries with optional filters applied.
|
||||
GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
|
||||
|
||||
// Close releases any resources held by the store.
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -37,24 +37,24 @@ func seedData(t *testing.T, store Store) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Login attempts: root/toor from two IPs, admin/admin from one IP.
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
|
||||
for range 5 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2"); err != nil {
|
||||
for range 3 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.1"); err != nil {
|
||||
for range 2 {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Sessions: one active, one ended.
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
@@ -62,7 +62,7 @@ func seedData(t *testing.T, store Store) {
|
||||
t.Fatalf("ending session: %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash"); err != nil {
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -194,6 +194,456 @@ func TestGetTopIPs(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetSession(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
s, err := store.GetSession(context.Background(), "nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if s != nil {
|
||||
t.Errorf("expected nil, got %+v", s)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
s, err := store.GetSession(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if s == nil {
|
||||
t.Fatal("expected session, got nil")
|
||||
}
|
||||
if s.ID != id || s.IP != "10.0.0.1" || s.Username != "root" || s.ShellName != "bash" {
|
||||
t.Errorf("unexpected session: %+v", s)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetSessionLogs(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.AppendSessionLog(ctx, id, "ls", "file1\nfile2"); err != nil {
|
||||
t.Fatalf("AppendSessionLog: %v", err)
|
||||
}
|
||||
if err := store.AppendSessionLog(ctx, id, "pwd", "/home/root"); err != nil {
|
||||
t.Fatalf("AppendSessionLog: %v", err)
|
||||
}
|
||||
|
||||
logs, err := store.GetSessionLogs(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSessionLogs: %v", err)
|
||||
}
|
||||
if len(logs) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(logs))
|
||||
}
|
||||
if logs[0].Input != "ls" {
|
||||
t.Errorf("logs[0].Input = %q, want %q", logs[0].Input, "ls")
|
||||
}
|
||||
if logs[1].Input != "pwd" {
|
||||
t.Errorf("logs[1].Input = %q, want %q", logs[1].Input, "pwd")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionEvents(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
events, err := store.GetSessionEvents(context.Background(), "nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSessionEvents: %v", err)
|
||||
}
|
||||
if len(events) != 0 {
|
||||
t.Errorf("expected empty, got %d", len(events))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("append and retrieve", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
events := []SessionEvent{
|
||||
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")},
|
||||
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")},
|
||||
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")},
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
got, err := store.GetSessionEvents(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSessionEvents: %v", err)
|
||||
}
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("len = %d, want 3", len(got))
|
||||
}
|
||||
if got[0].Direction != 0 || string(got[0].Data) != "ls\n" {
|
||||
t.Errorf("got[0] = %+v", got[0])
|
||||
}
|
||||
if got[1].Direction != 1 || string(got[1].Data) != "file1\nfile2\n" {
|
||||
t.Errorf("got[1] = %+v", got[1])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("append empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
if err := store.AppendSessionEvents(context.Background(), nil); err != nil {
|
||||
t.Fatalf("AppendSessionEvents(nil): %v", err)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestCloseActiveSessions(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("no active sessions", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
n, err := store.CloseActiveSessions(ctx, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("CloseActiveSessions: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("closed %d, want 0", n)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("closes only active sessions", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 3 sessions: end one, leave two active.
|
||||
id1, _ := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
|
||||
store.CreateSession(ctx, "10.0.0.3", "test", "bash", "")
|
||||
store.EndSession(ctx, id1, time.Now())
|
||||
|
||||
n, err := store.CloseActiveSessions(ctx, time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("CloseActiveSessions: %v", err)
|
||||
}
|
||||
if n != 2 {
|
||||
t.Errorf("closed %d, want 2", n)
|
||||
}
|
||||
|
||||
// Verify no active sessions remain.
|
||||
active, err := store.GetRecentSessions(ctx, 10, true)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(active) != 0 {
|
||||
t.Errorf("active sessions = %d, want 0", len(active))
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetExecCommand(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("set and retrieve", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
// Initially nil.
|
||||
s, err := store.GetSession(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if s.ExecCommand != nil {
|
||||
t.Errorf("expected nil ExecCommand, got %q", *s.ExecCommand)
|
||||
}
|
||||
|
||||
// Set exec command.
|
||||
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||
t.Fatalf("SetExecCommand: %v", err)
|
||||
}
|
||||
|
||||
s, err = store.GetSession(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatalf("GetSession: %v", err)
|
||||
}
|
||||
if s.ExecCommand == nil {
|
||||
t.Fatal("expected non-nil ExecCommand")
|
||||
}
|
||||
if *s.ExecCommand != "uname -a" {
|
||||
t.Errorf("ExecCommand = %q, want %q", *s.ExecCommand, "uname -a")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("appears in recent sessions", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "id"); err != nil {
|
||||
t.Fatalf("SetExecCommand: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ExecCommand == nil || *sessions[0].ExecCommand != "id" {
|
||||
t.Errorf("ExecCommand = %v, want \"id\"", sessions[0].ExecCommand)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func seedChartData(t *testing.T, store Store) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
// Record attempts with country data from different IPs.
|
||||
for range 5 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
for range 3 {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
for range 2 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "123456", "10.0.0.3", "CN"); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAttemptsOverTime(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAttemptsOverTime: %v", err)
|
||||
}
|
||||
if len(points) != 0 {
|
||||
t.Errorf("expected empty, got %v", points)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAttemptsOverTime: %v", err)
|
||||
}
|
||||
// All data was inserted today, so should be one point.
|
||||
if len(points) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(points))
|
||||
}
|
||||
// 5 + 3 + 2 = 10 total.
|
||||
if points[0].Count != 10 {
|
||||
t.Errorf("count = %d, want 10", points[0].Count)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetHourlyPattern(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetHourlyPattern: %v", err)
|
||||
}
|
||||
if len(counts) != 0 {
|
||||
t.Errorf("expected empty, got %v", counts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetHourlyPattern: %v", err)
|
||||
}
|
||||
// All data was inserted at the same hour.
|
||||
if len(counts) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(counts))
|
||||
}
|
||||
if counts[0].Count != 10 {
|
||||
t.Errorf("count = %d, want 10", counts[0].Count)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetCountryStats(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
counts, err := store.GetCountryStats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetCountryStats: %v", err)
|
||||
}
|
||||
if len(counts) != 0 {
|
||||
t.Errorf("expected empty, got %v", counts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
counts, err := store.GetCountryStats(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("GetCountryStats: %v", err)
|
||||
}
|
||||
if len(counts) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(counts))
|
||||
}
|
||||
// CN: 5 + 2 = 7, RU: 3 - ordered by count DESC.
|
||||
if counts[0].Country != "CN" || counts[0].Count != 7 {
|
||||
t.Errorf("counts[0] = %+v, want CN/7", counts[0])
|
||||
}
|
||||
if counts[1].Country != "RU" || counts[1].Count != 3 {
|
||||
t.Errorf("counts[1] = %+v, want RU/3", counts[1])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("excludes empty country", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.2", "US"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
|
||||
counts, err := store.GetCountryStats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCountryStats: %v", err)
|
||||
}
|
||||
if len(counts) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(counts))
|
||||
}
|
||||
if counts[0].Country != "US" {
|
||||
t.Errorf("country = %q, want US", counts[0].Country)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredDashboardStats(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("no filter", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
if stats.TotalAttempts != 10 {
|
||||
t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by country", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Country: "CN"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
// CN: 5 + 2 = 7
|
||||
if stats.TotalAttempts != 7 {
|
||||
t.Errorf("TotalAttempts = %d, want 7", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by IP", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{IP: "10.0.0.1"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
if stats.TotalAttempts != 5 {
|
||||
t.Errorf("TotalAttempts = %d, want 5", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("filter by username", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Username: "admin"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredDashboardStats: %v", err)
|
||||
}
|
||||
if stats.TotalAttempts != 3 {
|
||||
t.Errorf("TotalAttempts = %d, want 3", stats.TotalAttempts)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredTopUsernames(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
store := newStore(t)
|
||||
seedChartData(t, store)
|
||||
|
||||
// Filter by country CN should only show root.
|
||||
entries, err := store.GetFilteredTopUsernames(context.Background(), 10, DashboardFilter{Country: "CN"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredTopUsernames: %v", err)
|
||||
}
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(entries))
|
||||
}
|
||||
if entries[0].Value != "root" || entries[0].Count != 7 {
|
||||
t.Errorf("entries[0] = %+v, want root/7", entries[0])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetRecentSessions(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
@@ -250,3 +700,192 @@ func TestGetRecentSessions(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestInputBytes(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("counts only input direction", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
events := []SessionEvent{
|
||||
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")}, // 3 bytes input
|
||||
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")}, // 11 bytes output
|
||||
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")}, // 4 bytes input
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
// Only direction=0 data: "ls\n" (3) + "pwd\n" (4) = 7
|
||||
if sessions[0].InputBytes != 7 {
|
||||
t.Errorf("InputBytes = %d, want 7", sessions[0].InputBytes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("zero when no events", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].InputBytes != 0 {
|
||||
t.Errorf("InputBytes = %d, want 0", sessions[0].InputBytes)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredSessions(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("filter by human score", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create two sessions, one with human score > 0.
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id1, 0.75); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
_, err = store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{HumanScoreAboveZero: true})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sort by input bytes", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Session with more input (created first).
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if err := store.AppendSessionEvents(ctx, []SessionEvent{
|
||||
{SessionID: id1, Timestamp: now, Direction: 0, Data: []byte("ls -la /tmp\n")},
|
||||
{SessionID: id1, Timestamp: now.Add(time.Millisecond), Direction: 0, Data: []byte("cat /etc/passwd\n")},
|
||||
}); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
// Session with less input (created after id1, so would be first by connected_at).
|
||||
// Sleep >1s to ensure different RFC3339 timestamps in SQLite.
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, []SessionEvent{
|
||||
{SessionID: id2, Timestamp: now.Add(2 * time.Second), Direction: 0, Data: []byte("x\n")},
|
||||
}); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
// Default sort (connected_at DESC) should show id2 first.
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id2 {
|
||||
t.Errorf("default sort: expected %s first, got %s", id2, sessions[0].ID)
|
||||
}
|
||||
|
||||
// Sort by input_bytes should show id1 first (more input).
|
||||
sessions, err = store.GetFilteredSessions(ctx, 50, false, DashboardFilter{SortBy: "input_bytes"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("input_bytes sort: expected %s first, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("combined filters", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id1, 0.5); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
// Different country, also has score.
|
||||
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id2, 0.8); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
// Same country CN but no score.
|
||||
_, err = store.CreateSession(ctx, "10.0.0.3", "test", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
// Filter: CN + human score > 0 -> only id1.
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{
|
||||
Country: "CN",
|
||||
HumanScoreAboveZero: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,22 +1,37 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// dbContext returns a context detached from the HTTP request lifecycle with a
|
||||
// 30-second timeout. This prevents HTMX polling from canceling in-flight DB
|
||||
// queries when the browser aborts the previous XHR.
|
||||
func dbContext(r *http.Request) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
|
||||
}
|
||||
|
||||
type dashboardData struct {
|
||||
Stats *storage.DashboardStats
|
||||
TopUsernames []storage.TopEntry
|
||||
TopPasswords []storage.TopEntry
|
||||
TopIPs []storage.TopEntry
|
||||
ActiveSessions []storage.Session
|
||||
RecentSessions []storage.Session
|
||||
Stats *storage.DashboardStats
|
||||
TopUsernames []storage.TopEntry
|
||||
TopPasswords []storage.TopEntry
|
||||
TopIPs []storage.TopEntry
|
||||
TopCountries []storage.TopEntry
|
||||
TopExecCommands []storage.TopEntry
|
||||
ActiveSessions []storage.Session
|
||||
RecentSessions []storage.Session
|
||||
}
|
||||
|
||||
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
@@ -46,6 +61,20 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
topCountries, err := s.store.GetTopCountries(ctx, 10)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get top countries", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topExecCommands, err := s.store.GetTopExecCommands(ctx, 10)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get top exec commands", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
activeSessions, err := s.store.GetRecentSessions(ctx, 50, true)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get active sessions", "err", err)
|
||||
@@ -61,22 +90,27 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
data := dashboardData{
|
||||
Stats: stats,
|
||||
TopUsernames: topUsernames,
|
||||
TopPasswords: topPasswords,
|
||||
TopIPs: topIPs,
|
||||
ActiveSessions: activeSessions,
|
||||
RecentSessions: recentSessions,
|
||||
Stats: stats,
|
||||
TopUsernames: topUsernames,
|
||||
TopPasswords: topPasswords,
|
||||
TopIPs: topIPs,
|
||||
TopCountries: topCountries,
|
||||
TopExecCommands: topExecCommands,
|
||||
ActiveSessions: activeSessions,
|
||||
RecentSessions: recentSessions,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.ExecuteTemplate(w, "layout.html", data); err != nil {
|
||||
if err := s.tmpl.dashboard.ExecuteTemplate(w, "layout.html", data); err != nil {
|
||||
s.logger.Error("failed to render dashboard", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||
stats, err := s.store.GetDashboardStats(r.Context())
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get dashboard stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -84,13 +118,16 @@ func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.ExecuteTemplate(w, "stats", stats); err != nil {
|
||||
if err := s.tmpl.dashboard.ExecuteTemplate(w, "stats", stats); err != nil {
|
||||
s.logger.Error("failed to render stats fragment", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Request) {
|
||||
sessions, err := s.store.GetRecentSessions(r.Context(), 50, true)
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
sessions, err := s.store.GetRecentSessions(ctx, 50, true)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get active sessions", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -98,7 +135,307 @@ func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.ExecuteTemplate(w, "active_sessions", sessions); err != nil {
|
||||
if err := s.tmpl.dashboard.ExecuteTemplate(w, "active_sessions", sessions); err != nil {
|
||||
s.logger.Error("failed to render active sessions fragment", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
f := parseDashboardFilter(r)
|
||||
sessions, err := s.store.GetFilteredSessions(ctx, 50, false, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered sessions", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.dashboard.ExecuteTemplate(w, "recent_sessions", sessions); err != nil {
|
||||
s.logger.Error("failed to render recent sessions fragment", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type sessionDetailData struct {
|
||||
Session *storage.Session
|
||||
Logs []storage.SessionLog
|
||||
EventCount int
|
||||
}
|
||||
|
||||
func (s *Server) handleSessionDetail(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
sessionID := r.PathValue("id")
|
||||
|
||||
session, err := s.store.GetSession(ctx, sessionID)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get session", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
logs, err := s.store.GetSessionLogs(ctx, sessionID)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get session logs", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
events, err := s.store.GetSessionEvents(ctx, sessionID)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get session events", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data := sessionDetailData{
|
||||
Session: session,
|
||||
Logs: logs,
|
||||
EventCount: len(events),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.sessionDetail.ExecuteTemplate(w, "layout.html", data); err != nil {
|
||||
s.logger.Error("failed to render session detail", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type apiEvent struct {
|
||||
T int64 `json:"t"`
|
||||
D int `json:"d"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type apiEventsResponse struct {
|
||||
Events []apiEvent `json:"events"`
|
||||
}
|
||||
|
||||
// parseDateParam parses a "YYYY-MM-DD" query parameter into a *time.Time.
|
||||
func parseDateParam(r *http.Request, name string) *time.Time {
|
||||
v := r.URL.Query().Get(name)
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
t, err := time.Parse("2006-01-02", v)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// For "until" dates, set to end of day.
|
||||
if name == "until" {
|
||||
t = t.Add(24*time.Hour - time.Second)
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
func parseDashboardFilter(r *http.Request) storage.DashboardFilter {
|
||||
return storage.DashboardFilter{
|
||||
Since: parseDateParam(r, "since"),
|
||||
Until: parseDateParam(r, "until"),
|
||||
IP: r.URL.Query().Get("ip"),
|
||||
Country: r.URL.Query().Get("country"),
|
||||
Username: r.URL.Query().Get("username"),
|
||||
HumanScoreAboveZero: r.URL.Query().Get("human_score") == "1",
|
||||
SortBy: r.URL.Query().Get("sort"),
|
||||
}
|
||||
}
|
||||
|
||||
type apiTimeSeriesPoint struct {
|
||||
Date string `json:"date"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type apiAttemptsOverTimeResponse struct {
|
||||
Points []apiTimeSeriesPoint `json:"points"`
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIAttemptsOverTime(w http.ResponseWriter, r *http.Request) {
|
||||
days := 30
|
||||
if v := r.URL.Query().Get("days"); v != "" {
|
||||
if d, err := strconv.Atoi(v); err == nil && d > 0 && d <= 365 {
|
||||
days = d
|
||||
}
|
||||
}
|
||||
|
||||
since := parseDateParam(r, "since")
|
||||
until := parseDateParam(r, "until")
|
||||
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
points, err := s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get attempts over time", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiAttemptsOverTimeResponse{Points: make([]apiTimeSeriesPoint, len(points))}
|
||||
for i, p := range points {
|
||||
resp.Points[i] = apiTimeSeriesPoint{
|
||||
Date: p.Timestamp.Format("2006-01-02"),
|
||||
Count: p.Count,
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error("failed to encode attempts over time", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type apiHourlyCount struct {
|
||||
Hour int `json:"hour"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type apiHourlyPatternResponse struct {
|
||||
Hours []apiHourlyCount `json:"hours"`
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIHourlyPattern(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
since := parseDateParam(r, "since")
|
||||
until := parseDateParam(r, "until")
|
||||
|
||||
counts, err := s.store.GetHourlyPattern(ctx, since, until)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get hourly pattern", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiHourlyPatternResponse{Hours: make([]apiHourlyCount, len(counts))}
|
||||
for i, c := range counts {
|
||||
resp.Hours[i] = apiHourlyCount{Hour: c.Hour, Count: c.Count}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error("failed to encode hourly pattern", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type apiCountryCount struct {
|
||||
Country string `json:"country"`
|
||||
Count int64 `json:"count"`
|
||||
}
|
||||
|
||||
type apiCountryStatsResponse struct {
|
||||
Countries []apiCountryCount `json:"countries"`
|
||||
}
|
||||
|
||||
func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
counts, err := s.store.GetCountryStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get country stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiCountryStatsResponse{Countries: make([]apiCountryCount, len(counts))}
|
||||
for i, c := range counts {
|
||||
resp.Countries[i] = apiCountryCount{Country: c.Country, Count: c.Count}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error("failed to encode country stats", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
f := parseDashboardFilter(r)
|
||||
|
||||
stats, err := s.store.GetFilteredDashboardStats(ctx, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topUsernames, err := s.store.GetFilteredTopUsernames(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top usernames", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topPasswords, err := s.store.GetFilteredTopPasswords(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top passwords", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topIPs, err := s.store.GetFilteredTopIPs(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top IPs", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
topCountries, err := s.store.GetFilteredTopCountries(ctx, 10, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered top countries", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data := dashboardData{
|
||||
Stats: stats,
|
||||
TopUsernames: topUsernames,
|
||||
TopPasswords: topPasswords,
|
||||
TopIPs: topIPs,
|
||||
TopCountries: topCountries,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.dashboard.ExecuteTemplate(w, "dashboard_content", data); err != nil {
|
||||
s.logger.Error("failed to render dashboard content fragment", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleAPISessionEvents(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
sessionID := r.PathValue("id")
|
||||
|
||||
events, err := s.store.GetSessionEvents(ctx, sessionID)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get session events", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := apiEventsResponse{Events: make([]apiEvent, len(events))}
|
||||
var baseTime int64
|
||||
for i, e := range events {
|
||||
ms := e.Timestamp.UnixMilli()
|
||||
if i == 0 {
|
||||
baseTime = ms
|
||||
}
|
||||
resp.Events[i] = apiEvent{
|
||||
T: ms - baseTime,
|
||||
D: e.Direction,
|
||||
Data: base64.StdEncoding.EncodeToString(e.Data),
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.Error("failed to encode session events", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
14
internal/web/static/chart.min.js
vendored
Normal file
14
internal/web/static/chart.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
275
internal/web/static/dashboard.js
Normal file
275
internal/web/static/dashboard.js
Normal file
@@ -0,0 +1,275 @@
|
||||
(function() {
|
||||
'use strict';
|
||||
|
||||
// Chart.js theme for Pico dark mode
|
||||
Chart.defaults.color = '#b0b0b8';
|
||||
Chart.defaults.borderColor = '#3a3a4a';
|
||||
|
||||
var attemptsChart = null;
|
||||
var hourlyChart = null;
|
||||
|
||||
function getFilterParams() {
|
||||
var form = document.getElementById('filter-form');
|
||||
if (!form) return '';
|
||||
var params = new URLSearchParams();
|
||||
var since = form.elements['since'].value;
|
||||
var until = form.elements['until'].value;
|
||||
if (since) params.set('since', since);
|
||||
if (until) params.set('until', until);
|
||||
var humanScore = form.elements['human_score'];
|
||||
if (humanScore && humanScore.checked) params.set('human_score', '1');
|
||||
var sortBy = form.elements['sort'];
|
||||
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
|
||||
return params.toString();
|
||||
}
|
||||
|
||||
function initAttemptsChart() {
|
||||
var canvas = document.getElementById('chart-attempts');
|
||||
if (!canvas) return;
|
||||
var ctx = canvas.getContext('2d');
|
||||
|
||||
var qs = getFilterParams();
|
||||
var url = '/api/charts/attempts-over-time' + (qs ? '?' + qs : '');
|
||||
|
||||
fetch(url)
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(data) {
|
||||
var labels = data.points.map(function(p) { return p.date; });
|
||||
var values = data.points.map(function(p) { return p.count; });
|
||||
|
||||
if (attemptsChart) {
|
||||
attemptsChart.data.labels = labels;
|
||||
attemptsChart.data.datasets[0].data = values;
|
||||
attemptsChart.update();
|
||||
return;
|
||||
}
|
||||
|
||||
attemptsChart = new Chart(ctx, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels: labels,
|
||||
datasets: [{
|
||||
label: 'Attempts',
|
||||
data: values,
|
||||
borderColor: '#6366f1',
|
||||
backgroundColor: 'rgba(99, 102, 241, 0.1)',
|
||||
fill: true,
|
||||
tension: 0.3,
|
||||
pointRadius: 2
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: true,
|
||||
plugins: { legend: { display: false } },
|
||||
scales: {
|
||||
x: { grid: { display: false } },
|
||||
y: { beginAtZero: true }
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function initHourlyChart() {
|
||||
var canvas = document.getElementById('chart-hourly');
|
||||
if (!canvas) return;
|
||||
var ctx = canvas.getContext('2d');
|
||||
|
||||
var qs = getFilterParams();
|
||||
var url = '/api/charts/hourly-pattern' + (qs ? '?' + qs : '');
|
||||
|
||||
fetch(url)
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(data) {
|
||||
// Fill all 24 hours, defaulting to 0
|
||||
var hourMap = {};
|
||||
data.hours.forEach(function(h) { hourMap[h.hour] = h.count; });
|
||||
var labels = [];
|
||||
var values = [];
|
||||
for (var i = 0; i < 24; i++) {
|
||||
labels.push(i + ':00');
|
||||
values.push(hourMap[i] || 0);
|
||||
}
|
||||
|
||||
if (hourlyChart) {
|
||||
hourlyChart.data.labels = labels;
|
||||
hourlyChart.data.datasets[0].data = values;
|
||||
hourlyChart.update();
|
||||
return;
|
||||
}
|
||||
|
||||
hourlyChart = new Chart(ctx, {
|
||||
type: 'bar',
|
||||
data: {
|
||||
labels: labels,
|
||||
datasets: [{
|
||||
label: 'Attempts',
|
||||
data: values,
|
||||
backgroundColor: 'rgba(99, 102, 241, 0.6)',
|
||||
borderColor: '#6366f1',
|
||||
borderWidth: 1
|
||||
}]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: true,
|
||||
plugins: { legend: { display: false } },
|
||||
scales: {
|
||||
x: { grid: { display: false } },
|
||||
y: { beginAtZero: true }
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function initWorldMap() {
|
||||
var container = document.getElementById('world-map');
|
||||
if (!container) return;
|
||||
|
||||
fetch('/static/world.svg')
|
||||
.then(function(r) { return r.text(); })
|
||||
.then(function(svgText) {
|
||||
container.innerHTML = svgText;
|
||||
|
||||
fetch('/api/charts/country-stats')
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(data) {
|
||||
colorMap(container, data.countries);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function colorMap(container, countries) {
|
||||
if (!countries || countries.length === 0) return;
|
||||
|
||||
var maxCount = countries[0].count; // already sorted DESC
|
||||
var logMax = Math.log(maxCount + 1);
|
||||
|
||||
// Build lookup
|
||||
var lookup = {};
|
||||
countries.forEach(function(c) {
|
||||
lookup[c.country.toLowerCase()] = c.count;
|
||||
});
|
||||
|
||||
// Create tooltip element
|
||||
var tooltip = document.createElement('div');
|
||||
tooltip.id = 'map-tooltip';
|
||||
tooltip.style.cssText = 'position:fixed;display:none;background:#1a1a2e;color:#e0e0e8;padding:4px 8px;border-radius:4px;font-size:13px;pointer-events:none;z-index:1000;border:1px solid #3a3a4a;';
|
||||
document.body.appendChild(tooltip);
|
||||
|
||||
var svg = container.querySelector('svg');
|
||||
if (!svg) return;
|
||||
|
||||
// Remove SVG title to prevent browser native tooltip
|
||||
var svgTitle = svg.querySelector('title');
|
||||
if (svgTitle) svgTitle.remove();
|
||||
|
||||
// Select both <path id="xx"> and <g id="xx"> country elements
|
||||
var elements = svg.querySelectorAll('path[id], g[id]');
|
||||
elements.forEach(function(el) {
|
||||
var id = el.id.toLowerCase();
|
||||
if (id.charAt(0) === '_') return; // skip non-country paths
|
||||
|
||||
var count = lookup[id];
|
||||
if (count) {
|
||||
var intensity = Math.log(count + 1) / logMax;
|
||||
var r = Math.round(30 + intensity * 69); // 30 -> 99
|
||||
var g = Math.round(30 + intensity * 72); // 30 -> 102
|
||||
var b = Math.round(62 + intensity * 179); // 62 -> 241
|
||||
var color = 'rgb(' + r + ',' + g + ',' + b + ')';
|
||||
// For <g> elements, color child paths; for <path>, color directly
|
||||
if (el.tagName.toLowerCase() === 'g') {
|
||||
el.querySelectorAll('path').forEach(function(p) {
|
||||
p.style.fill = color;
|
||||
});
|
||||
} else {
|
||||
el.style.fill = color;
|
||||
}
|
||||
}
|
||||
|
||||
el.addEventListener('mouseenter', function(e) {
|
||||
var cc = id.toUpperCase();
|
||||
var n = lookup[id] || 0;
|
||||
tooltip.textContent = cc + ': ' + n.toLocaleString() + ' attempts';
|
||||
tooltip.style.display = 'block';
|
||||
});
|
||||
|
||||
el.addEventListener('mousemove', function(e) {
|
||||
tooltip.style.left = (e.clientX + 12) + 'px';
|
||||
tooltip.style.top = (e.clientY - 10) + 'px';
|
||||
});
|
||||
|
||||
el.addEventListener('mouseleave', function() {
|
||||
tooltip.style.display = 'none';
|
||||
});
|
||||
|
||||
el.addEventListener('click', function() {
|
||||
var input = document.querySelector('#filter-form input[name="country"]');
|
||||
if (input) {
|
||||
input.value = id.toUpperCase();
|
||||
applyFilters();
|
||||
}
|
||||
});
|
||||
|
||||
el.style.cursor = 'pointer';
|
||||
});
|
||||
}
|
||||
|
||||
function applyFilters() {
|
||||
// Re-fetch charts with filter params
|
||||
initAttemptsChart();
|
||||
initHourlyChart();
|
||||
|
||||
// Re-fetch dashboard content via htmx
|
||||
var form = document.getElementById('filter-form');
|
||||
if (!form) return;
|
||||
|
||||
var params = new URLSearchParams();
|
||||
['since', 'until', 'ip', 'country', 'username'].forEach(function(name) {
|
||||
var val = form.elements[name].value;
|
||||
if (val) params.set(name, val);
|
||||
});
|
||||
|
||||
var humanScore = form.elements['human_score'];
|
||||
if (humanScore && humanScore.checked) params.set('human_score', '1');
|
||||
var sortBy = form.elements['sort'];
|
||||
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
|
||||
|
||||
var target = document.getElementById('dashboard-content');
|
||||
if (target) {
|
||||
var url = '/fragments/dashboard-content?' + params.toString();
|
||||
htmx.ajax('GET', url, {target: '#dashboard-content', swap: 'innerHTML'});
|
||||
}
|
||||
|
||||
// Server-side filter for recent sessions table
|
||||
var sessionsUrl = '/fragments/recent-sessions?' + params.toString();
|
||||
htmx.ajax('GET', sessionsUrl, {target: '#recent-sessions-table tbody', swap: 'innerHTML'});
|
||||
}
|
||||
|
||||
window.clearFilters = function() {
|
||||
var form = document.getElementById('filter-form');
|
||||
if (form) {
|
||||
form.reset();
|
||||
applyFilters();
|
||||
}
|
||||
};
|
||||
|
||||
window.applyFilters = applyFilters;
|
||||
|
||||
// Initialize on DOM ready
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
initAttemptsChart();
|
||||
initHourlyChart();
|
||||
initWorldMap();
|
||||
|
||||
var form = document.getElementById('filter-form');
|
||||
if (form) {
|
||||
form.addEventListener('submit', function(e) {
|
||||
e.preventDefault();
|
||||
applyFilters();
|
||||
});
|
||||
}
|
||||
});
|
||||
})();
|
||||
83
internal/web/static/replay.js
Normal file
83
internal/web/static/replay.js
Normal file
@@ -0,0 +1,83 @@
|
||||
// ReplayPlayer drives xterm.js playback of recorded session events.
|
||||
function ReplayPlayer(containerId, sessionId) {
|
||||
this.terminal = new Terminal({
|
||||
cols: 80,
|
||||
rows: 24,
|
||||
convertEol: true,
|
||||
disableStdin: true,
|
||||
theme: {
|
||||
background: '#000000',
|
||||
foreground: '#ffffff'
|
||||
}
|
||||
});
|
||||
this.terminal.open(document.getElementById(containerId));
|
||||
|
||||
this.sessionId = sessionId;
|
||||
this.events = [];
|
||||
this.index = 0;
|
||||
this.speed = 1;
|
||||
this.timers = [];
|
||||
this.playing = false;
|
||||
|
||||
// Fetch events immediately.
|
||||
var self = this;
|
||||
fetch('/api/sessions/' + sessionId + '/events')
|
||||
.then(function(r) { return r.json(); })
|
||||
.then(function(data) {
|
||||
self.events = data.events || [];
|
||||
});
|
||||
}
|
||||
|
||||
ReplayPlayer.prototype.play = function() {
|
||||
if (this.playing) return;
|
||||
if (this.events.length === 0) return;
|
||||
this.playing = true;
|
||||
this._schedule();
|
||||
};
|
||||
|
||||
ReplayPlayer.prototype.pause = function() {
|
||||
this.playing = false;
|
||||
for (var i = 0; i < this.timers.length; i++) {
|
||||
clearTimeout(this.timers[i]);
|
||||
}
|
||||
this.timers = [];
|
||||
};
|
||||
|
||||
ReplayPlayer.prototype.reset = function() {
|
||||
this.pause();
|
||||
this.index = 0;
|
||||
this.terminal.reset();
|
||||
};
|
||||
|
||||
ReplayPlayer.prototype.setSpeed = function(speed) {
|
||||
this.speed = speed;
|
||||
if (this.playing) {
|
||||
this.pause();
|
||||
this.play();
|
||||
}
|
||||
};
|
||||
|
||||
ReplayPlayer.prototype._schedule = function() {
|
||||
var self = this;
|
||||
var baseT = this.index < this.events.length ? this.events[this.index].t : 0;
|
||||
|
||||
for (var i = this.index; i < this.events.length; i++) {
|
||||
(function(idx) {
|
||||
var evt = self.events[idx];
|
||||
var delay = (evt.t - baseT) / self.speed;
|
||||
var timer = setTimeout(function() {
|
||||
if (!self.playing) return;
|
||||
// Only write output events (d=1) to terminal; input is echoed in output.
|
||||
if (evt.d === 1) {
|
||||
var raw = atob(evt.data);
|
||||
self.terminal.write(raw);
|
||||
}
|
||||
self.index = idx + 1;
|
||||
if (self.index >= self.events.length) {
|
||||
self.playing = false;
|
||||
}
|
||||
}, delay);
|
||||
self.timers.push(timer);
|
||||
})(i);
|
||||
}
|
||||
};
|
||||
1
internal/web/static/world.svg
Normal file
1
internal/web/static/world.svg
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 55 KiB |
209
internal/web/static/xterm.css
Normal file
209
internal/web/static/xterm.css
Normal file
@@ -0,0 +1,209 @@
|
||||
/**
|
||||
* Copyright (c) 2014 The xterm.js authors. All rights reserved.
|
||||
* Copyright (c) 2012-2013, Christopher Jeffrey (MIT License)
|
||||
* https://github.com/chjj/term.js
|
||||
* @license MIT
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in
|
||||
* all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
* THE SOFTWARE.
|
||||
*
|
||||
* Originally forked from (with the author's permission):
|
||||
* Fabrice Bellard's javascript vt100 for jslinux:
|
||||
* http://bellard.org/jslinux/
|
||||
* Copyright (c) 2011 Fabrice Bellard
|
||||
* The original design remains. The terminal itself
|
||||
* has been extended to include xterm CSI codes, among
|
||||
* other features.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Default styles for xterm.js
|
||||
*/
|
||||
|
||||
.xterm {
|
||||
cursor: text;
|
||||
position: relative;
|
||||
user-select: none;
|
||||
-ms-user-select: none;
|
||||
-webkit-user-select: none;
|
||||
}
|
||||
|
||||
.xterm.focus,
|
||||
.xterm:focus {
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.xterm .xterm-helpers {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
/**
|
||||
* The z-index of the helpers must be higher than the canvases in order for
|
||||
* IMEs to appear on top.
|
||||
*/
|
||||
z-index: 5;
|
||||
}
|
||||
|
||||
.xterm .xterm-helper-textarea {
|
||||
padding: 0;
|
||||
border: 0;
|
||||
margin: 0;
|
||||
/* Move textarea out of the screen to the far left, so that the cursor is not visible */
|
||||
position: absolute;
|
||||
opacity: 0;
|
||||
left: -9999em;
|
||||
top: 0;
|
||||
width: 0;
|
||||
height: 0;
|
||||
z-index: -5;
|
||||
/** Prevent wrapping so the IME appears against the textarea at the correct position */
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
resize: none;
|
||||
}
|
||||
|
||||
.xterm .composition-view {
|
||||
/* TODO: Composition position got messed up somewhere */
|
||||
background: #000;
|
||||
color: #FFF;
|
||||
display: none;
|
||||
position: absolute;
|
||||
white-space: nowrap;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.xterm .composition-view.active {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.xterm .xterm-viewport {
|
||||
/* On OS X this is required in order for the scroll bar to appear fully opaque */
|
||||
background-color: #000;
|
||||
overflow-y: scroll;
|
||||
cursor: default;
|
||||
position: absolute;
|
||||
right: 0;
|
||||
left: 0;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
}
|
||||
|
||||
.xterm .xterm-screen {
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.xterm .xterm-screen canvas {
|
||||
position: absolute;
|
||||
left: 0;
|
||||
top: 0;
|
||||
}
|
||||
|
||||
.xterm .xterm-scroll-area {
|
||||
visibility: hidden;
|
||||
}
|
||||
|
||||
.xterm-char-measure-element {
|
||||
display: inline-block;
|
||||
visibility: hidden;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: -9999em;
|
||||
line-height: normal;
|
||||
}
|
||||
|
||||
.xterm.enable-mouse-events {
|
||||
/* When mouse events are enabled (eg. tmux), revert to the standard pointer cursor */
|
||||
cursor: default;
|
||||
}
|
||||
|
||||
.xterm.xterm-cursor-pointer,
|
||||
.xterm .xterm-cursor-pointer {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.xterm.column-select.focus {
|
||||
/* Column selection mode */
|
||||
cursor: crosshair;
|
||||
}
|
||||
|
||||
.xterm .xterm-accessibility,
|
||||
.xterm .xterm-message {
|
||||
position: absolute;
|
||||
left: 0;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
right: 0;
|
||||
z-index: 10;
|
||||
color: transparent;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.xterm .live-region {
|
||||
position: absolute;
|
||||
left: -9999px;
|
||||
width: 1px;
|
||||
height: 1px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.xterm-dim {
|
||||
/* Dim should not apply to background, so the opacity of the foreground color is applied
|
||||
* explicitly in the generated class and reset to 1 here */
|
||||
opacity: 1 !important;
|
||||
}
|
||||
|
||||
.xterm-underline-1 { text-decoration: underline; }
|
||||
.xterm-underline-2 { text-decoration: double underline; }
|
||||
.xterm-underline-3 { text-decoration: wavy underline; }
|
||||
.xterm-underline-4 { text-decoration: dotted underline; }
|
||||
.xterm-underline-5 { text-decoration: dashed underline; }
|
||||
|
||||
.xterm-overline {
|
||||
text-decoration: overline;
|
||||
}
|
||||
|
||||
.xterm-overline.xterm-underline-1 { text-decoration: overline underline; }
|
||||
.xterm-overline.xterm-underline-2 { text-decoration: overline double underline; }
|
||||
.xterm-overline.xterm-underline-3 { text-decoration: overline wavy underline; }
|
||||
.xterm-overline.xterm-underline-4 { text-decoration: overline dotted underline; }
|
||||
.xterm-overline.xterm-underline-5 { text-decoration: overline dashed underline; }
|
||||
|
||||
.xterm-strikethrough {
|
||||
text-decoration: line-through;
|
||||
}
|
||||
|
||||
.xterm-screen .xterm-decoration-container .xterm-decoration {
|
||||
z-index: 6;
|
||||
position: absolute;
|
||||
}
|
||||
|
||||
.xterm-screen .xterm-decoration-container .xterm-decoration.xterm-decoration-top-layer {
|
||||
z-index: 7;
|
||||
}
|
||||
|
||||
.xterm-decoration-overview-ruler {
|
||||
z-index: 8;
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: 0;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.xterm-decoration-top {
|
||||
z-index: 2;
|
||||
position: relative;
|
||||
}
|
||||
8
internal/web/static/xterm.min.js
vendored
Normal file
8
internal/web/static/xterm.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -2,6 +2,7 @@ package web
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"time"
|
||||
)
|
||||
@@ -9,8 +10,13 @@ import (
|
||||
//go:embed templates/*.html templates/fragments/*.html
|
||||
var templateFS embed.FS
|
||||
|
||||
func loadTemplates() (*template.Template, error) {
|
||||
funcMap := template.FuncMap{
|
||||
type templateSet struct {
|
||||
dashboard *template.Template
|
||||
sessionDetail *template.Template
|
||||
}
|
||||
|
||||
func templateFuncMap() template.FuncMap {
|
||||
return template.FuncMap{
|
||||
"formatTime": func(t time.Time) string {
|
||||
return t.Format("2006-01-02 15:04:05 UTC")
|
||||
},
|
||||
@@ -26,12 +32,71 @@ func loadTemplates() (*template.Template, error) {
|
||||
}
|
||||
return *t
|
||||
},
|
||||
"derefFloat": func(f *float64) float64 {
|
||||
if f == nil {
|
||||
return 0
|
||||
}
|
||||
return *f
|
||||
},
|
||||
"formatScore": func(f *float64) string {
|
||||
if f == nil {
|
||||
return "-"
|
||||
}
|
||||
return fmt.Sprintf("%.0f%%", *f*100)
|
||||
},
|
||||
"derefString": func(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
},
|
||||
"truncateCommand": func(s string) string {
|
||||
if len(s) > 50 {
|
||||
return s[:50] + "..."
|
||||
}
|
||||
return s
|
||||
},
|
||||
"formatBytes": func(b int64) string {
|
||||
const (
|
||||
kb = 1024
|
||||
mb = 1024 * kb
|
||||
)
|
||||
switch {
|
||||
case b >= mb:
|
||||
return fmt.Sprintf("%.1f MB", float64(b)/float64(mb))
|
||||
case b >= kb:
|
||||
return fmt.Sprintf("%.1f KB", float64(b)/float64(kb))
|
||||
default:
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return template.New("").Funcs(funcMap).ParseFS(templateFS,
|
||||
func loadTemplates() (*templateSet, error) {
|
||||
funcMap := templateFuncMap()
|
||||
|
||||
dashboard, err := template.New("").Funcs(funcMap).ParseFS(templateFS,
|
||||
"templates/layout.html",
|
||||
"templates/dashboard.html",
|
||||
"templates/fragments/stats.html",
|
||||
"templates/fragments/active_sessions.html",
|
||||
"templates/fragments/recent_sessions.html",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing dashboard templates: %w", err)
|
||||
}
|
||||
|
||||
sessionDetail, err := template.New("").Funcs(funcMap).ParseFS(templateFS,
|
||||
"templates/layout.html",
|
||||
"templates/session_detail.html",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing session detail templates: %w", err)
|
||||
}
|
||||
|
||||
return &templateSet{
|
||||
dashboard: dashboard,
|
||||
sessionDetail: sessionDetail,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,86 @@
|
||||
{{template "stats" .Stats}}
|
||||
</section>
|
||||
|
||||
<details>
|
||||
<summary>Filters</summary>
|
||||
<form id="filter-form">
|
||||
<div class="grid">
|
||||
<label>Since <input type="date" name="since"></label>
|
||||
<label>Until <input type="date" name="until"></label>
|
||||
<label>IP <input type="text" name="ip" placeholder="10.0.0.1"></label>
|
||||
<label>Country <input type="text" name="country" placeholder="CN" maxlength="2"></label>
|
||||
<label>Username <input type="text" name="username" placeholder="root"></label>
|
||||
</div>
|
||||
<div class="grid">
|
||||
<label><input type="checkbox" name="human_score" value="1"> Human score > 0</label>
|
||||
<label>Sort by <select name="sort"><option value="connected_at">Recent</option><option value="input_bytes">Input Bytes</option></select></label>
|
||||
</div>
|
||||
<button type="submit">Apply</button>
|
||||
<button type="button" class="secondary" onclick="clearFilters()">Clear</button>
|
||||
</form>
|
||||
</details>
|
||||
|
||||
<section>
|
||||
<h3>Attack Trends</h3>
|
||||
<div class="grid">
|
||||
<article>
|
||||
<header>Attempts Over Time</header>
|
||||
<canvas id="chart-attempts"></canvas>
|
||||
</article>
|
||||
<article>
|
||||
<header>Hourly Pattern (UTC)</header>
|
||||
<canvas id="chart-hourly"></canvas>
|
||||
</article>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Attack Origins</h3>
|
||||
<article>
|
||||
<div id="world-map"></div>
|
||||
</article>
|
||||
</section>
|
||||
|
||||
<div id="dashboard-content">
|
||||
{{template "dashboard_content" .}}
|
||||
</div>
|
||||
|
||||
<section>
|
||||
<h3>Active Sessions</h3>
|
||||
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
|
||||
{{template "active_sessions" .ActiveSessions}}
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Recent Sessions</h3>
|
||||
<table id="recent-sessions-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>IP</th>
|
||||
<th>Country</th>
|
||||
<th>Username</th>
|
||||
<th>Type</th>
|
||||
<th>Score</th>
|
||||
<th>Input</th>
|
||||
<th>Connected</th>
|
||||
<th>Disconnected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{template "recent_sessions" .RecentSessions}}
|
||||
</tbody>
|
||||
</table>
|
||||
</section>
|
||||
{{end}}
|
||||
|
||||
{{define "scripts"}}
|
||||
<script src="/static/chart.min.js"></script>
|
||||
<script src="/static/dashboard.js"></script>
|
||||
{{end}}
|
||||
|
||||
{{define "dashboard_content"}}
|
||||
<section>
|
||||
<h3>Top Credentials & IPs</h3>
|
||||
<div class="top-grid">
|
||||
@@ -40,10 +120,25 @@
|
||||
<header>Top IPs</header>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>IP</th><th>Attempts</th></tr>
|
||||
<tr><th>IP</th><th>Country</th><th>Attempts</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .TopIPs}}
|
||||
<tr><td>{{.Value}}</td><td>{{.Country}}</td><td>{{.Count}}</td></tr>
|
||||
{{else}}
|
||||
<tr><td colspan="3">No data</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</article>
|
||||
<article>
|
||||
<header>Top Countries</header>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>Country</th><th>Attempts</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .TopCountries}}
|
||||
<tr><td>{{.Value}}</td><td>{{.Count}}</td></tr>
|
||||
{{else}}
|
||||
<tr><td colspan="2">No data</td></tr>
|
||||
@@ -51,43 +146,21 @@
|
||||
</tbody>
|
||||
</table>
|
||||
</article>
|
||||
<article>
|
||||
<header>Top Exec Commands</header>
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>Command</th><th>Count</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .TopExecCommands}}
|
||||
<tr><td><code>{{truncateCommand .Value}}</code></td><td>{{.Count}}</td></tr>
|
||||
{{else}}
|
||||
<tr><td colspan="2">No data</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</article>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Active Sessions</h3>
|
||||
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
|
||||
{{template "active_sessions" .ActiveSessions}}
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h3>Recent Sessions</h3>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>IP</th>
|
||||
<th>Username</th>
|
||||
<th>Shell</th>
|
||||
<th>Connected</th>
|
||||
<th>Disconnected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .RecentSessions}}
|
||||
<tr>
|
||||
<td><code>{{truncateID .ID}}</code></td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{.ShellName}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="6">No sessions</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</section>
|
||||
{{end}}
|
||||
|
||||
@@ -4,22 +4,28 @@
|
||||
<tr>
|
||||
<th>ID</th>
|
||||
<th>IP</th>
|
||||
<th>Country</th>
|
||||
<th>Username</th>
|
||||
<th>Shell</th>
|
||||
<th>Type</th>
|
||||
<th>Score</th>
|
||||
<th>Input</th>
|
||||
<th>Connected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .}}
|
||||
<tr>
|
||||
<td><code>{{truncateID .ID}}</code></td>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Country}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{.ShellName}}</td>
|
||||
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatBytes .InputBytes}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="5">No active sessions</td></tr>
|
||||
<tr><td colspan="8">No active sessions</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
17
internal/web/templates/fragments/recent_sessions.html
Normal file
17
internal/web/templates/fragments/recent_sessions.html
Normal file
@@ -0,0 +1,17 @@
|
||||
{{define "recent_sessions"}}
|
||||
{{range .}}
|
||||
<tr>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Country}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatBytes .InputBytes}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="9">No sessions</td></tr>
|
||||
{{end}}
|
||||
{{end}}
|
||||
@@ -29,9 +29,16 @@
|
||||
}
|
||||
.top-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
|
||||
grid-template-columns: repeat(auto-fit, minmax(380px, 1fr));
|
||||
gap: 1rem;
|
||||
}
|
||||
.top-grid article {
|
||||
overflow: hidden;
|
||||
min-width: 0;
|
||||
}
|
||||
#world-map svg { width: 100%; height: auto; }
|
||||
#world-map svg path { fill: #2a2a3e; stroke: #555; stroke-width: 0.5; transition: fill 0.2s; }
|
||||
#world-map svg path:hover, #world-map svg g:hover path { stroke: #fff; stroke-width: 1; }
|
||||
nav h1 {
|
||||
margin: 0;
|
||||
}
|
||||
@@ -52,5 +59,6 @@
|
||||
<main class="container">
|
||||
{{block "content" .}}{{end}}
|
||||
</main>
|
||||
{{block "scripts" .}}{{end}}
|
||||
</body>
|
||||
</html>
|
||||
|
||||
81
internal/web/templates/session_detail.html
Normal file
81
internal/web/templates/session_detail.html
Normal file
@@ -0,0 +1,81 @@
|
||||
{{define "content"}}
|
||||
<section>
|
||||
<h3>Session {{.Session.ID}}</h3>
|
||||
<div class="top-grid">
|
||||
<article>
|
||||
<header>Session Info</header>
|
||||
<table>
|
||||
<tbody>
|
||||
<tr><td><strong>IP</strong></td><td>{{.Session.IP}}</td></tr>
|
||||
<tr><td><strong>Country</strong></td><td>{{.Session.Country}}</td></tr>
|
||||
<tr><td><strong>Username</strong></td><td>{{.Session.Username}}</td></tr>
|
||||
<tr><td><strong>Shell</strong></td><td>{{.Session.ShellName}}</td></tr>
|
||||
{{if .Session.ExecCommand}}<tr><td><strong>Exec Command</strong></td><td><code>{{derefString .Session.ExecCommand}}</code></td></tr>{{end}}
|
||||
<tr><td><strong>Score</strong></td><td>{{formatScore .Session.HumanScore}}</td></tr>
|
||||
<tr><td><strong>Connected</strong></td><td>{{formatTime .Session.ConnectedAt}}</td></tr>
|
||||
<tr>
|
||||
<td><strong>Disconnected</strong></td>
|
||||
<td>{{if .Session.DisconnectedAt}}{{formatTime (derefTime .Session.DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</article>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
{{if gt .EventCount 0}}
|
||||
<section>
|
||||
<h3>Session Replay</h3>
|
||||
<div style="margin-bottom: 1rem;">
|
||||
<button id="btn-play" onclick="replayPlayer.play()">Play</button>
|
||||
<button id="btn-pause" onclick="replayPlayer.pause()">Pause</button>
|
||||
<button id="btn-reset" onclick="replayPlayer.reset()">Reset</button>
|
||||
<label for="speed-select" style="margin-left: 1rem;">Speed:</label>
|
||||
<select id="speed-select" onchange="replayPlayer.setSpeed(parseFloat(this.value))">
|
||||
<option value="0.5">0.5x</option>
|
||||
<option value="1" selected>1x</option>
|
||||
<option value="2">2x</option>
|
||||
<option value="5">5x</option>
|
||||
<option value="10">10x</option>
|
||||
</select>
|
||||
</div>
|
||||
<div id="terminal" style="background: #000; padding: 4px; border-radius: 4px;"></div>
|
||||
</section>
|
||||
<link rel="stylesheet" href="/static/xterm.css">
|
||||
<script src="/static/xterm.min.js"></script>
|
||||
<script src="/static/replay.js"></script>
|
||||
<script>
|
||||
var replayPlayer = new ReplayPlayer("terminal", "{{.Session.ID}}");
|
||||
</script>
|
||||
{{else}}
|
||||
<section>
|
||||
<p>No recorded events for this session.</p>
|
||||
</section>
|
||||
{{end}}
|
||||
|
||||
{{if .Logs}}
|
||||
<section>
|
||||
<h3>Command Log</h3>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Time</th>
|
||||
<th>Input</th>
|
||||
<th>Output</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .Logs}}
|
||||
<tr>
|
||||
<td>{{formatTime .Timestamp}}</td>
|
||||
<td><code>{{.Input}}</code></td>
|
||||
<td><pre style="margin:0; white-space:pre-wrap;">{{.Output}}</pre></td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
</section>
|
||||
{{end}}
|
||||
|
||||
<p><a href="/">← Back to dashboard</a></p>
|
||||
{{end}}
|
||||
@@ -1,12 +1,13 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"embed"
|
||||
"html/template"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
//go:embed static/*
|
||||
@@ -17,11 +18,13 @@ type Server struct {
|
||||
store storage.Store
|
||||
logger *slog.Logger
|
||||
mux *http.ServeMux
|
||||
tmpl *template.Template
|
||||
tmpl *templateSet
|
||||
}
|
||||
|
||||
// NewServer creates a new web Server with routes registered.
|
||||
func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
|
||||
// If metricsHandler is non-nil, a /metrics endpoint is registered.
|
||||
// If metricsToken is non-empty, the metrics endpoint requires Bearer token auth.
|
||||
func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Handler, metricsToken string) (*Server, error) {
|
||||
tmpl, err := loadTemplates()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -35,9 +38,24 @@ func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
|
||||
}
|
||||
|
||||
s.mux.Handle("GET /static/", http.FileServerFS(staticFS))
|
||||
s.mux.HandleFunc("GET /sessions/{id}", s.handleSessionDetail)
|
||||
s.mux.HandleFunc("GET /api/sessions/{id}/events", s.handleAPISessionEvents)
|
||||
s.mux.HandleFunc("GET /api/charts/attempts-over-time", s.handleAPIAttemptsOverTime)
|
||||
s.mux.HandleFunc("GET /api/charts/hourly-pattern", s.handleAPIHourlyPattern)
|
||||
s.mux.HandleFunc("GET /api/charts/country-stats", s.handleAPICountryStats)
|
||||
s.mux.HandleFunc("GET /", s.handleDashboard)
|
||||
s.mux.HandleFunc("GET /fragments/stats", s.handleFragmentStats)
|
||||
s.mux.HandleFunc("GET /fragments/active-sessions", s.handleFragmentActiveSessions)
|
||||
s.mux.HandleFunc("GET /fragments/dashboard-content", s.handleFragmentDashboardContent)
|
||||
s.mux.HandleFunc("GET /fragments/recent-sessions", s.handleFragmentRecentSessions)
|
||||
|
||||
if metricsHandler != nil {
|
||||
h := metricsHandler
|
||||
if metricsToken != "" {
|
||||
h = requireBearerToken(metricsToken, h)
|
||||
}
|
||||
s.mux.Handle("GET /metrics", h)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
@@ -46,3 +64,20 @@ func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.mux.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// requireBearerToken wraps a handler to require a valid Bearer token.
|
||||
func requireBearerToken(token string, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(auth, "Bearer ") {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
provided := auth[len("Bearer "):]
|
||||
if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,20 +2,23 @@ package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func newTestServer(t *testing.T) *Server {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
logger := slog.Default()
|
||||
srv, err := NewServer(store, logger)
|
||||
srv, err := NewServer(store, logger, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating server: %v", err)
|
||||
}
|
||||
@@ -27,34 +30,58 @@ func newSeededTestServer(t *testing.T) *Server {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
|
||||
for range 5 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
}
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2"); err != nil {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
|
||||
t.Fatalf("seeding attempt: %v", err)
|
||||
}
|
||||
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash"); err != nil {
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash"); err != nil {
|
||||
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
|
||||
logger := slog.Default()
|
||||
srv, err := NewServer(store, logger)
|
||||
srv, err := NewServer(store, logger, nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating server: %v", err)
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestDbContextNotCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
dbCtx, dbCancel := dbContext(req)
|
||||
defer dbCancel()
|
||||
|
||||
// Cancel the original request context.
|
||||
cancel()
|
||||
|
||||
// The DB context should still be usable.
|
||||
select {
|
||||
case <-dbCtx.Done():
|
||||
t.Fatal("dbContext should not be canceled when request context is canceled")
|
||||
default:
|
||||
}
|
||||
|
||||
// Verify the DB context has a deadline (from the timeout).
|
||||
if _, ok := dbCtx.Deadline(); !ok {
|
||||
t.Error("dbContext should have a deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardHandler(t *testing.T) {
|
||||
t.Run("empty store", func(t *testing.T) {
|
||||
srv := newTestServer(t)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
@@ -73,7 +100,7 @@ func TestDashboardHandler(t *testing.T) {
|
||||
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
@@ -93,7 +120,7 @@ func TestDashboardHandler(t *testing.T) {
|
||||
|
||||
func TestFragmentStats(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest("GET", "/fragments/stats", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/fragments/stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
@@ -113,7 +140,7 @@ func TestFragmentStats(t *testing.T) {
|
||||
|
||||
func TestFragmentActiveSessions(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest("GET", "/fragments/active-sessions", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, "/fragments/active-sessions", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
@@ -131,6 +158,396 @@ func TestFragmentActiveSessions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionDetailHandler(t *testing.T) {
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
srv := newTestServer(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/sessions/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/sessions/"+id, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "10.0.0.1") {
|
||||
t.Error("response should contain IP")
|
||||
}
|
||||
if !strings.Contains(body, "root") {
|
||||
t.Error("response should contain username")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPISessionEvents(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
events := []storage.SessionEvent{
|
||||
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")},
|
||||
{SessionID: id, Timestamp: now.Add(500 * time.Millisecond), Direction: 1, Data: []byte("file1\n")},
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/"+id+"/events", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
ct := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(ct, "application/json") {
|
||||
t.Errorf("Content-Type = %q, want application/json", ct)
|
||||
}
|
||||
|
||||
var resp apiEventsResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decoding response: %v", err)
|
||||
}
|
||||
if len(resp.Events) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(resp.Events))
|
||||
}
|
||||
// First event should have t=0 (relative).
|
||||
if resp.Events[0].T != 0 {
|
||||
t.Errorf("events[0].T = %d, want 0", resp.Events[0].T)
|
||||
}
|
||||
// Second event should have t=500 (500ms later).
|
||||
if resp.Events[1].T != 500 {
|
||||
t.Errorf("events[1].T = %d, want 500", resp.Events[1].T)
|
||||
}
|
||||
if resp.Events[0].D != 0 {
|
||||
t.Errorf("events[0].D = %d, want 0", resp.Events[0].D)
|
||||
}
|
||||
if resp.Events[1].D != 1 {
|
||||
t.Errorf("events[1].D = %d, want 1", resp.Events[1].D)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetricsEndpoint(t *testing.T) {
|
||||
t.Run("enabled", func(t *testing.T) {
|
||||
m := metrics.New("test")
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), m.Handler(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, `oubliette_build_info{version="test"} 1`) {
|
||||
t.Errorf("response should contain build_info metric, got:\n%s", body)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disabled", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
// Without a metrics handler, /metrics falls through to the dashboard.
|
||||
body := w.Body.String()
|
||||
if strings.Contains(body, "oubliette_build_info") {
|
||||
t.Error("response should not contain prometheus metrics when disabled")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMetricsBearerToken(t *testing.T) {
|
||||
m := metrics.New("test")
|
||||
|
||||
t.Run("valid token", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
req.Header.Set("Authorization", "Bearer secret")
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrong token", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
req.Header.Set("Authorization", "Bearer wrong")
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing header", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no token configured", func(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
srv, err := NewServer(store, slog.Default(), m.Handler(), "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTruncateCommand(t *testing.T) {
|
||||
funcMap := templateFuncMap()
|
||||
fn := funcMap["truncateCommand"].(func(string) string)
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"short", "short"},
|
||||
{"exactly fifty characters long! that is what it i.", "exactly fifty characters long! that is what it i."},
|
||||
{"this string is definitely longer than fifty characters and should be truncated", "this string is definitely longer than fifty charac..."},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := fn(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("truncateCommand(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardExecCommands(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("creating session: %v", err)
|
||||
}
|
||||
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
|
||||
t.Fatalf("setting exec command: %v", err)
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, "Top Exec Commands") {
|
||||
t.Error("response should contain 'Top Exec Commands'")
|
||||
}
|
||||
if !strings.Contains(body, "uname -a") {
|
||||
t.Error("response should contain exec command 'uname -a'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIAttemptsOverTime(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/charts/attempts-over-time", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
ct := w.Header().Get("Content-Type")
|
||||
if !strings.Contains(ct, "application/json") {
|
||||
t.Errorf("Content-Type = %q, want application/json", ct)
|
||||
}
|
||||
|
||||
var resp apiAttemptsOverTimeResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decoding response: %v", err)
|
||||
}
|
||||
// Seeded data inserted today -> at least 1 point.
|
||||
if len(resp.Points) == 0 {
|
||||
t.Error("expected at least one data point")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIHourlyPattern(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/charts/hourly-pattern", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
var resp apiHourlyPatternResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decoding response: %v", err)
|
||||
}
|
||||
if len(resp.Hours) == 0 {
|
||||
t.Error("expected at least one hourly data point")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPICountryStats(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/charts/country-stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
|
||||
var resp apiCountryStatsResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decoding response: %v", err)
|
||||
}
|
||||
if len(resp.Countries) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(resp.Countries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFragmentDashboardContent(t *testing.T) {
|
||||
srv := newSeededTestServer(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if strings.Contains(body, "<!DOCTYPE html>") {
|
||||
t.Error("dashboard content fragment should not contain full HTML document")
|
||||
}
|
||||
if !strings.Contains(body, "Top Usernames") {
|
||||
t.Error("dashboard content fragment should contain 'Top Usernames'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFragmentDashboardContentWithFilter(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
ctx := context.Background()
|
||||
for range 5 {
|
||||
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
}
|
||||
for range 3 {
|
||||
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
|
||||
t.Fatalf("seeding: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
srv, err := NewServer(store, slog.Default(), nil, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewServer: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content?country=CN", nil)
|
||||
w := httptest.NewRecorder()
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", w.Code)
|
||||
}
|
||||
body := w.Body.String()
|
||||
// When filtered by CN, should show root but not admin.
|
||||
if !strings.Contains(body, "root") {
|
||||
t.Error("response should contain 'root' when filtered by CN")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticAssets(t *testing.T) {
|
||||
srv := newTestServer(t)
|
||||
|
||||
@@ -140,11 +557,14 @@ func TestStaticAssets(t *testing.T) {
|
||||
}{
|
||||
{"/static/pico.min.css", "text/css"},
|
||||
{"/static/htmx.min.js", "text/javascript"},
|
||||
{"/static/chart.min.js", "text/javascript"},
|
||||
{"/static/dashboard.js", "text/javascript"},
|
||||
{"/static/world.svg", "image/svg+xml"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", tt.path, nil)
|
||||
req := httptest.NewRequest(http.MethodGet, tt.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
srv.ServeHTTP(w, req)
|
||||
|
||||
@@ -18,6 +18,32 @@ password = "toor"
|
||||
username = "admin"
|
||||
password = "admin"
|
||||
|
||||
# Route specific credentials to a named shell (optional).
|
||||
# [[auth.static_credentials]]
|
||||
# username = "samsung"
|
||||
# password = "fridge"
|
||||
# shell = "fridge"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "teller"
|
||||
# password = "banking"
|
||||
# shell = "banking"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "admin"
|
||||
# password = "cisco"
|
||||
# shell = "cisco"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "irobot"
|
||||
# password = "roomba"
|
||||
# shell = "roomba"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "player"
|
||||
# password = "tetris"
|
||||
# shell = "tetris"
|
||||
|
||||
[storage]
|
||||
db_path = "oubliette.db"
|
||||
retention_days = 90
|
||||
@@ -26,8 +52,51 @@ retention_interval = "1h"
|
||||
# [web]
|
||||
# enabled = true
|
||||
# listen_addr = ":8080"
|
||||
# metrics_enabled = true
|
||||
# metrics_token = "" # bearer token for /metrics; empty = no auth
|
||||
|
||||
[shell]
|
||||
hostname = "ubuntu-server"
|
||||
# banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
|
||||
# fake_user = "" # override username in prompt; empty = use authenticated user
|
||||
|
||||
# Map usernames to specific shells (regardless of how auth succeeded).
|
||||
# Credential-specific shell overrides take priority over username routes.
|
||||
# [shell.username_routes]
|
||||
# postgres = "psql"
|
||||
# admin = "bash"
|
||||
|
||||
# Per-shell configuration (optional).
|
||||
# [shell.banking]
|
||||
# bank_name = "SECUREBANK"
|
||||
# terminal_id = "SB-0001" # random if not set
|
||||
# region = "NORTHEAST"
|
||||
|
||||
# [shell.adventure]
|
||||
# dungeon_name = "THE OUBLIETTE"
|
||||
|
||||
# [shell.cisco]
|
||||
# hostname = "Router"
|
||||
# model = "C2960"
|
||||
# ios_version = "15.0(2)SE11"
|
||||
# enable_password = "" # empty = accept after 1 failed attempt
|
||||
|
||||
# [shell.psql]
|
||||
# db_name = "postgres"
|
||||
# pg_version = "15.4"
|
||||
|
||||
# [shell.roomba]
|
||||
# No configuration options currently.
|
||||
|
||||
# [shell.tetris]
|
||||
# difficulty = "normal" # "easy" (slower start), "normal" (standard), "hard" (start at level 5)
|
||||
|
||||
# [detection]
|
||||
# enabled = true
|
||||
# threshold = 0.6 # 0.0–1.0, sessions above this trigger notifications
|
||||
# update_interval = "5s" # how often to recompute the score during a session
|
||||
|
||||
# [[notify.webhooks]]
|
||||
# url = "https://ntfy.example.com/honeypot"
|
||||
# headers = { Authorization = "Bearer your-token" }
|
||||
# events = ["human_detected", "session_started"] # empty = all events
|
||||
|
||||
18
scripts/fetch-geoip.sh
Executable file
18
scripts/fetch-geoip.sh
Executable file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env bash
|
||||
# Downloads the DB-IP Lite country MMDB database for development.
|
||||
# The Nix build fetches this automatically; this script is for local dev only.
|
||||
set -euo pipefail
|
||||
|
||||
URL="https://download.db-ip.com/free/dbip-country-lite-2026-02.mmdb.gz"
|
||||
DEST="internal/geoip/dbip-country-lite.mmdb"
|
||||
|
||||
cd "$(git rev-parse --show-toplevel)"
|
||||
|
||||
if [ -f "$DEST" ]; then
|
||||
echo "GeoIP database already exists at $DEST"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Downloading DB-IP Lite country database..."
|
||||
curl -fSL "$URL" | gunzip > "$DEST"
|
||||
echo "Saved to $DEST"
|
||||
Reference in New Issue
Block a user