Compare commits

..

41 Commits

Author SHA1 Message Date
1b28f10ca8 refactor: migrate module path from git.t-juice.club to code.t-juice.club
Update Go module path and all import references to reflect the migration
from Gitea (git.t-juice.club) to Forgejo (code.t-juice.club).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 18:51:23 +01:00
664e79fce6 feat: add Prometheus metrics for Store queries
Add InstrumentedStore decorator that wraps any Store and records
per-method query duration histograms and error counters. Wired into
main.go so all storage consumers get automatic observability.

Bump version to 0.18.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:29:51 +01:00
c74313c195 fix: resolve linting issues in roomba shell
Replace unnecessary fmt.Sprintf calls with string literals, use
slices.Contains instead of manual loop, and use compound assignment
operator.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:18:25 +01:00
9783ae5865 fix: prevent context canceled errors in web dashboard
Detach DB queries from HTTP request context so HTMX polling doesn't
cancel in-flight queries when the browser aborts previous XHRs. Add
indexes on login_attempts and sessions to speed up frequent dashboard
queries. Bump version to 0.17.1.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:16:49 +01:00
62de222488 feat: add tetris shell (Tetris game TUI)
Full-screen Tetris game using Bubbletea with title screen, ghost piece,
lock delay, NES-style scoring, configurable difficulty (easy/normal/hard),
and honeypot event logging. Bumps version to 0.17.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 00:59:46 +01:00
c9d143d84b feat: add roomba shell (iRobot Roomba j7+ vacuum emulator)
New novelty shell emulating RoombaOS with cleaning, scheduling,
diagnostics, floor map, and humorous history entries. Bump version
to 0.16.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 14:06:59 +01:00
d18a904ed5 chore: bump version to 0.15.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 09:13:50 +01:00
cb7be28f42 feat: add server-side session filtering with input bytes and human score
Replace client-side session table filtering with server-side filtering
via a new /fragments/recent-sessions htmx endpoint. Add InputBytes column
to session tables, Human score > 0 checkbox filter, and Sort by Input
Bytes option to help identify sessions with actual shell interaction.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 09:12:51 +01:00
0908b43724 chore: bump version to 0.14.2
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:24:01 +01:00
52310f588d fix: highlight all polygons on hover for multi-path countries
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:24:01 +01:00
b52216bd2f chore: bump version to 0.14.1
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:21:01 +01:00
2bc83a17dd fix: handle SVG group elements in world map for multi-path countries
The SVG world map uses <g> group elements for countries with complex
shapes (US, CN, RU, GB, etc.), but the JS only queried <path> elements,
causing 36 countries to be missing from the map. Also removes the SVG
<title> element that was overriding the custom tooltip.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:20:23 +01:00
faf6e2abd7 docs: mark 4.1 and 4.4 as completed in PLAN.md
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:34:17 +01:00
0a4eac188a chore: bump version to 0.14.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:31:53 +01:00
7c90c9ed4a feat: add charts, world map, and filters to web dashboard
Add Chart.js line/bar charts for attack trends (attempts over time,
hourly pattern), an SVG world map choropleth colored by attack origin
country, and a collapsible filter form (date range, IP, country,
username) that narrows both charts and top-N tables.

New store methods: GetAttemptsOverTime, GetHourlyPattern, GetCountryStats,
and filtered variants of dashboard stats/top-N queries. New JSON API
endpoints at /api/charts/* and an htmx fragment at
/fragments/dashboard-content for filtered table updates.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:27:15 +01:00
8a631af0d2 fix: prevent dashboard top-grid cards from overflowing horizontally
Increase minimum column width from 280px to 380px so the 3-column Top
IPs table fits without clipping. Add overflow/min-width safety net for
narrow viewports.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 21:25:20 +01:00
40fda3420c feat: add psql shell and username-to-shell routing
Add a PostgreSQL psql interactive terminal shell with backslash
meta-commands, SQL statement handling with multi-line buffering, and
canned responses for common queries. Add username-based shell routing
via [shell.username_routes] config (second priority after credential-
specific shell, before random selection). Bump version to 0.13.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:58:34 +01:00
c4801e3309 chore: bump version to 0.12.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:38:47 +01:00
4f10a8a422 feat: add session indicators and top exec commands to dashboard
Add visual indicators to session tables (replay badge when events exist,
exec badge for exec sessions) and a new "Top Exec Commands" table on the
dashboard. Includes EventCount field on Session, GetTopExecCommands on
Store interface, and truncateCommand template function.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:38:10 +01:00
0b44d1c83f docs: detail fake exec output approach in PLAN.md 4.4.1
Regex-based output assembly: scan exec commands for known patterns
and return plausible fake values rather than interpreting shell
pipelines. Waiting on more real-world bot examples before implementing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 18:01:42 +01:00
0133d956a5 feat: capture SSH exec commands (PLAN.md 4.4)
Bots often send commands via `ssh user@host <command>` (exec request)
rather than requesting an interactive shell. These were previously
rejected silently. Now exec commands are captured, stored on the session
record, and displayed in the web UI session detail page.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:43:11 +01:00
3c20e854aa docs: add plan for capturing SSH exec commands (PLAN.md 4.4)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:25:52 +01:00
090dbec390 chore: bump version to 0.10.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:55:10 +01:00
df860b3061 feat: add new Prometheus metrics and bearer token auth for /metrics
Add 6 new Prometheus metrics for richer observability:
- auth_attempts_by_country_total (counter by country)
- commands_executed_total (counter by shell via OnCommand callback)
- human_score (histogram of final detection scores)
- storage_login_attempts_total, storage_unique_ips, storage_sessions_total
  (gauges via custom collector querying GetDashboardStats on each scrape)

Add optional bearer token authentication for the /metrics endpoint via
web.metrics_token config option. Uses crypto/subtle.ConstantTimeCompare.
Empty token (default) means no auth for backwards compatibility.

Also adds "cisco" to pre-initialized session/command metric labels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:54:29 +01:00
9aecc7ce02 chore: bump version to 0.9.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:29:37 +01:00
94f1f1c266 feat: add GeoIP country lookup with embedded DB-IP Lite database (PLAN.md 4.3)
Embeds a DB-IP Lite country MMDB (~5MB) in the binary via go:embed,
keeping the single-binary deployment story clean. Country codes are
stored alongside login attempts and sessions, shown in the dashboard
(Top IPs, Top Countries card, Recent/Active Sessions, session detail).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:27:46 +01:00
8fff893d25 docs: mark Cisco IOS shell (PLAN.md 3.2) as completed
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:04:51 +01:00
5ba62afec3 feat: add Cisco IOS shell with mode state machine and abbreviation matching (PLAN.md 3.2)
Implements a Cisco IOS CLI emulator with four modes (user exec, privileged exec,
global config, interface config), Cisco-style command abbreviation (e.g. sh run,
conf t), enable password flow, and realistic show command output including
running-config, interfaces, IP routes, and VLANs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 14:58:26 +01:00
058da51f86 fix: add column whitelist to queryTopN to prevent SQL injection
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 10:08:28 +01:00
adfe372d13 refactor: extract changePinModel into its own sub-model
The Change PIN screen was the only screen with its state (pinInput,
pinStage, pinMessage) stored directly on the top-level model. Extract
it into a changePinModel in screen_changepin.go to match the pattern
used by all other screens.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 09:34:56 +01:00
3163ea47dc chore: add bubbletea skill 2026-02-15 09:28:28 +01:00
ab07e6a8dc feat: add Prometheus metrics endpoint and Docker image (PLAN.md 4.2)
Add internal/metrics package with dedicated Prometheus registry exposing
SSH connection, auth attempt, session, and build info metrics. Wire into
SSH server (4 instrumentation points) and web server (/metrics endpoint).
Add dockerImage output to flake.nix via dockerTools.buildLayeredImage.
Bump version to 0.7.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 05:47:16 +01:00
b8fcbc7e10 chore: bump version to 0.6.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 05:17:57 +01:00
aa569aac16 feat: add text adventure shell (PLAN.md 3.4)
Zork-style dungeon crawler set in an abandoned data center / medieval dungeon.
11 rooms, 6 items, 3 puzzles (dark room, locked door, maintenance panel),
standard text adventure parser with verb aliases and direction shortcuts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 05:13:03 +01:00
1a407ad4c2 docs: mark banking TUI shell as complete in PLAN.md
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 00:58:20 +01:00
5d0c8cc20c fix: apply black background to banking TUI padding areas
Padding spaces (end-of-line and blank filler lines) were unstyled,
causing the terminal's default background to bleed through.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 00:55:33 +01:00
d226c32b9b fix: banking shell screen rendering artifacts and transfer panic
Fix rendering issues where content from previous screens bled through
when switching between views of different heights/widths:

- Pad every line to full terminal width (ANSI-aware) so shorter lines
  overwrite leftover content from previous renders
- Track terminal height via WindowSizeMsg and pad between content and
  footer to fill the screen
- Send tea.ClearScreen on all screen transitions for height changes
- Fix panic in transfer completion when routing number is < 4 chars

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 00:50:34 +01:00
86786c9d05 fix: clean up stale active sessions on startup
After an unclean shutdown, sessions could be left with disconnected_at
NULL, appearing permanently active. Add CloseActiveSessions to the Store
interface and call it at startup to close any leftover sessions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 00:16:48 +01:00
d78d461236 chore: bump version to 0.5.0 and update vendor hash
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 23:22:28 +01:00
49425635ce revert: undo premature version bump
Version should be bumped when merging to master, not on the feature branch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 23:17:30 +01:00
8ff029fcb7 feat: add Banking TUI shell using bubbletea
Add an 80s-style green-on-black bank terminal shell ("banking") using
charmbracelet/bubbletea for full-screen TUI rendering over SSH.

Screens: login, main menu, account summary, account detail with
transactions, wire transfer wizard (6-step form capturing routing
number, destination, beneficiary, amount, memo, auth code), transaction
history with pagination, secure messages with breadcrumb content (fake
internal IPs, vault codes), change PIN, and hidden admin access (99)
that locks after 3 failed attempts with COBOL-style error output.

All key actions (login, navigation, wire transfers, admin attempts) are
logged to the session store. Wire transfer data is the honeypot gold.

Configurable via [shell.banking] in TOML: bank_name, terminal_id, region.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 23:17:12 +01:00
87 changed files with 12163 additions and 213 deletions

View File

@@ -0,0 +1,55 @@
---
name: bubbletea
description: Browse Bubbletea TUI framework documentation and examples. Use when working with Bubbletea components, models, commands, or building terminal user interfaces in Go.
---
# Bubbletea Documentation
Bubbletea is a Go framework for building terminal user interfaces based on The Elm Architecture.
## Key Resources
When you need to understand Bubbletea patterns or find examples:
1. **Examples README** - Overview of all available examples:
https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
2. **Examples Directory** - Full source code for all examples:
https://github.com/charmbracelet/bubbletea/tree/main/examples
## How to Use
1. First, fetch the examples README to get an overview of available examples:
```
WebFetch https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
```
2. Once you identify a relevant example, fetch its source code from the examples directory.
## Common Examples to Reference
- `list` - List component with filtering
- `table` - Table component
- `textinput` - Text input handling
- `textarea` - Multi-line text input
- `viewport` - Scrollable content
- `paginator` - Pagination
- `spinner` - Loading spinners
- `progress` - Progress bars
- `tabs` - Tab navigation
- `help` - Help text/keybindings display
## Core Concepts
- **Model**: Application state
- **Update**: Handles messages and returns updated model + commands
- **View**: Renders the model to a string
- **Cmd**: Side effects that produce messages
- **Msg**: Events that trigger updates
## Related Charm Libraries
- **Bubbles**: Pre-built components (github.com/charmbracelet/bubbles)
- **Lipgloss**: Styling and layout (github.com/charmbracelet/lipgloss)
- **Glamour**: Markdown rendering (github.com/charmbracelet/glamour)

2
.gitignore vendored
View File

@@ -4,3 +4,5 @@ oubliette.toml
*.db-wal
*.db-shm
/oubliette
*.mmdb
*.mmdb.gz

87
PLAN.md
View File

@@ -150,7 +150,7 @@ Goal: Add the entertaining shell implementations.
- **Haunted:** commands gradually return stranger output, files appear/disappear, `whoami` returns different users
- **Bread crumbs:** fake .bash_history, id_rsa files, database configs pointing to other honeypots
### 3.2 Cisco IOS Shell
### 3.2 Cisco IOS Shell
- Realistic `>` and `#` prompts
- Common commands: `show running-config`, `show interfaces`, `enable`, `configure terminal`
- Fake device info that looks like a real router
@@ -162,14 +162,29 @@ Goal: Add the entertaining shell implementations.
- "WARNING: milk expires in 2 days"
- Per-credential shell routing via `shell` field in static credentials
### 3.4 Text Adventure
### 3.4 Text Adventure
- Zork-style dungeon crawler
- "You are in a dimly lit server room."
- Navigation, items, puzzles
- The dungeon is the oubliette itself
### 3.5 Other Shell Ideas (Future)
- **Banking TUI:** 80s-style green-on-black bank terminal
### 3.5 Banking TUI Shell ✅
- 80s-style green-on-black bank terminal
### 3.6 PostgreSQL psql Shell ✅
- Simulates psql interactive terminal with `db_name` and `pg_version` config
- Backslash meta-commands: `\q`, `\dt`, `\d <table>`, `\l`, `\du`, `\conninfo`, `\?`, `\h`
- SQL statement handling with multi-line buffering (semicolon-terminated)
- Canned responses for common queries (SELECT version(), current_database(), etc.)
- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.)
- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells
### 3.7 Roomba Shell ✅
- iRobot Roomba j7+ vacuum robot interface
- Status, cleaning, scheduling, diagnostics, floor map
- Humorous history entries (cat encounters, sock tangles, sticky substances)
### 3.8 Other Shell Ideas (Future)
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
- **ELIZA therapist:** every response is a therapy question
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
@@ -181,19 +196,55 @@ Goal: Add the entertaining shell implementations.
Goal: Make the web UI great and add operational niceties.
### 4.1 Enhanced Web UI
- GeoIP lookups and world map visualization of attack sources
- Charts: attempts over time, hourly patterns, credential trends
- Session detail view with full command log
- Filtering and search
### 4.1 Enhanced Web UI
- GeoIP lookups and world map visualization of attack sources
- Charts: attempts over time, hourly patterns, credential trends
- Session detail view with full command log
- Filtering and search
### 4.2 Operational
- Prometheus metrics endpoint
- Structured logging (slog)
- Graceful shutdown
- Systemd unit file / deployment docs
### 4.2 Operational
- Prometheus metrics endpoint
- Structured logging (slog)
- Graceful shutdown
- Docker image (nix dockerTools) ✅
- Systemd unit file / deployment docs ✅
### 4.3 GeoIP
- Embed a lightweight GeoIP database or use an API
- Store country/city with each attempt
- Aggregate stats by country
### 4.3 GeoIP
- Embed a lightweight GeoIP database or use an API
- Store country/city with each attempt
- Aggregate stats by country
### 4.4 Capture SSH Exec Commands ✅
Many bots send a command directly via `ssh user@host <command>` (an SSH "exec" request) rather than requesting an interactive shell. Currently these are rejected and the command is lost. We should capture them.
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`) ✅
- Parse the command string from the exec payload ✅
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration ✅
- Store the command on the session record before closing the channel ✅
- Optionally return plausible fake output for common commands (e.g. `uname`, `id`, `cat /etc/passwd`) to encourage further interaction
- Surface exec commands in the web UI (session detail view) ✅
#### 4.4.1 Fake Exec Output
Return plausible fake output for exec commands to encourage bots to interact further.
**Approach: regex-based output assembly.** Bots typically send a single long command that chains recon commands and then echoes a summary (e.g. `echo "UNAME:$uname"`). Rather than interpreting arbitrary shell pipelines, we scan the command string for known patterns and assemble fake output.
Implementation:
- A map of common command/variable patterns to fake output strings, e.g.:
- `uname -a` / `uname -s -v -n -m``"Linux ubuntu-server 5.15.0-91-generic #101-Ubuntu SMP Tue Jan 2 15:13:10 UTC 2024 x86_64"`
- `uname -m` / `arch``"x86_64"`
- `cat /proc/uptime``"86432.71 172801.55"`
- `nproc` / `grep -c "^processor" /proc/cpuinfo``"2"`
- `cat /proc/cpuinfo` → fake cpuinfo block
- `lspci` → empty (no GPU — discourages cryptominer targeting)
- `id``"uid=0(root) gid=0(root) groups=0(root)"`
- `cat /etc/passwd` → minimal fake passwd file
- `last` → fake login entries
- `cat --help`, `ls --help` → canned GNU coreutils help text
- Scan the exec command for `echo "KEY:$var"` patterns; for each key, look up the corresponding fake value from the variable assignment earlier in the command
- If we recognise echo patterns, assemble and return the expected output
- If we don't recognise the command at all, return empty output with exit 0 (current behaviour)
- Values should draw from the existing shell config where possible (hostname, fake_user) for consistency
- New package `internal/execfake` or a file in `internal/server/` — keep it simple
Gather more real-world bot examples before implementing to ensure good coverage of common recon patterns.

View File

@@ -34,7 +34,8 @@ Key settings:
- `auth.accept_after` — accept login after N failures per IP (default `10`)
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
- `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell)
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS)
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal), `roomba` (iRobot Roomba vacuum robot), `tetris` (Tetris game TUI)
- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority
- `storage.db_path` — SQLite database path (default `oubliette.db`)
- `storage.retention_days` — auto-prune records older than N days (default `90`)
- `storage.retention_interval` — how often to run retention (default `1h`)
@@ -43,12 +44,21 @@ Key settings:
- `shell.fake_user` — override username in prompt; empty uses the authenticated user
- `web.enabled` — enable the web dashboard (default `false`)
- `web.listen_addr` — web dashboard listen address (default `:8080`)
- Dashboard includes Chart.js charts (attempts over time, hourly pattern), an SVG world map choropleth colored by attack origin, and filter controls for date range / IP / country / username
- Session detail pages at `/sessions/{id}` include terminal replay via xterm.js
- `web.metrics_enabled` — expose Prometheus metrics at `/metrics` (default `true`)
- `web.metrics_token` — bearer token to protect `/metrics`; empty means no auth (default empty)
- `detection.enabled` — enable human detection scoring (default `false`)
- `detection.threshold` — score threshold (0.01.0) for flagging sessions (default `0.6`)
- `detection.update_interval` — how often to recompute scores (default `5s`)
- `notify.webhooks` — list of webhook endpoints for notifications (see example config)
### GeoIP
Country-level GeoIP lookups are embedded in the binary using the [DB-IP Lite](https://db-ip.com/db/lite.php) database (CC-BY-4.0). The dashboard shows country alongside IPs and includes a "Top Countries" table.
For local development, run `scripts/fetch-geoip.sh` to download the MMDB file. The Nix build fetches it automatically.
### Run
```sh
@@ -61,6 +71,9 @@ Test with:
ssh -o StrictHostKeyChecking=no -p 2222 root@localhost
```
SSH exec commands (`ssh user@host <command>`) are also captured and stored on the session record.
### NixOS Module
Add the flake as an input and enable the service:
@@ -82,3 +95,15 @@ Add the flake as an input and enable the service:
```
Alternatively, use `configFile` to pass a pre-written TOML file instead of `settings`.
### Docker
Build a Docker image via nix:
```sh
nix build .#dockerImage
docker load < result
docker run -v /path/to/data:/data -p 2222:2222 -p 8080:8080 oubliette:0.8.0
```
Place your `oubliette.toml` in the data volume. The container exposes ports 2222 (SSH) and 8080 (web/metrics).

View File

@@ -13,13 +13,14 @@ import (
"syscall"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/server"
"git.t-juice.club/torjus/oubliette/internal/storage"
"git.t-juice.club/torjus/oubliette/internal/web"
"code.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/server"
"code.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/web"
)
const Version = "0.4.0"
const Version = "0.18.0"
func main() {
if err := run(); err != nil {
@@ -65,12 +66,23 @@ func run() error {
}
defer store.Close()
// Clean up sessions left active by a previous unclean shutdown.
if n, err := store.CloseActiveSessions(context.Background(), time.Now()); err != nil {
return fmt.Errorf("close stale sessions: %w", err)
} else if n > 0 {
logger.Info("closed stale sessions from previous run", "count", n)
}
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
m := metrics.New(Version)
instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors)
m.RegisterStoreCollector(instrumentedStore)
srv, err := server.New(*cfg, store, logger)
go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
srv, err := server.New(*cfg, instrumentedStore, logger, m)
if err != nil {
return fmt.Errorf("create server: %w", err)
}
@@ -79,7 +91,12 @@ func run() error {
// Start web server if enabled.
if cfg.Web.Enabled {
webHandler, err := web.NewServer(store, logger.With("component", "web"))
var metricsHandler http.Handler
if *cfg.Web.MetricsEnabled {
metricsHandler = m.Handler()
}
webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
if err != nil {
return fmt.Errorf("create web server: %w", err)
}

View File

@@ -18,19 +18,44 @@
pkgs = nixpkgs.legacyPackages.${system};
mainGo = builtins.readFile ./cmd/oubliette/main.go;
version = builtins.head (builtins.match ''.*const Version = "([^"]+)".*'' mainGo);
geoipDb = pkgs.fetchurl {
url = "https://download.db-ip.com/free/dbip-country-lite-2026-02.mmdb.gz";
hash = "sha256-xmQZEJZ5WzE9uQww1Sdb8248l+liYw46tjbfJeu945Q=";
};
in
{
default = pkgs.buildGoModule {
pname = "oubliette";
inherit version;
src = ./.;
vendorHash = "sha256-EbJ90e4Jco7CvYYJLrewFLD5XF+Wv6TsT8RRLcj+ijU=";
vendorHash = "sha256-/zxK6CABLYBNtuSOI8dIVgMNxKiDIcbZUS7bQR5TenA=";
subPackages = [ "cmd/oubliette" ];
nativeBuildInputs = [ pkgs.gzip ];
preBuild = ''
gunzip -c ${geoipDb} > internal/geoip/dbip-country-lite.mmdb
'';
meta = {
description = "SSH honeypot";
mainProgram = "oubliette";
};
};
dockerImage = pkgs.dockerTools.buildLayeredImage {
name = "oubliette";
tag = version;
contents = [ self.packages.${system}.default ];
config = {
Entrypoint = [ "/bin/oubliette" ];
Cmd = [ "-config" "/data/oubliette.toml" ];
ExposedPorts = {
"2222/tcp" = {};
"8080/tcp" = {};
};
Volumes = {
"/data" = {};
};
};
};
});
devShells = forAllSystems (system:

30
go.mod
View File

@@ -1,21 +1,49 @@
module git.t-juice.club/torjus/oubliette
module code.t-juice.club/torjus/oubliette
go 1.25.5
require (
github.com/BurntSushi/toml v1.6.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/google/uuid v1.6.0
github.com/oschwald/maxminddb-golang v1.13.1
github.com/prometheus/client_golang v1.23.2
github.com/prometheus/client_model v0.6.2
golang.org/x/crypto v0.48.0
modernc.org/sqlite v1.45.0
)
require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.10.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect

94
go.sum
View File

@@ -1,34 +1,116 @@
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=

View File

@@ -5,7 +5,7 @@ import (
"sync"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
const (

View File

@@ -6,7 +6,7 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {

View File

@@ -23,12 +23,15 @@ type Config struct {
type WebConfig struct {
Enabled bool `toml:"enabled"`
ListenAddr string `toml:"listen_addr"`
MetricsEnabled *bool `toml:"metrics_enabled"`
MetricsToken string `toml:"metrics_token"`
}
type ShellConfig struct {
Hostname string `toml:"hostname"`
Banner string `toml:"banner"`
FakeUser string `toml:"fake_user"`
UsernameRoutes map[string]string `toml:"username_routes"`
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
}
@@ -143,6 +146,10 @@ func applyDefaults(cfg *Config) {
if cfg.Web.ListenAddr == "" {
cfg.Web.ListenAddr = ":8080"
}
if cfg.Web.MetricsEnabled == nil {
t := true
cfg.Web.MetricsEnabled = &t
}
if cfg.Shell.Hostname == "" {
cfg.Shell.Hostname = "ubuntu-server"
}
@@ -162,6 +169,7 @@ var knownShellKeys = map[string]bool{
"hostname": true,
"banner": true,
"fake_user": true,
"username_routes": true,
}
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.

View File

@@ -282,6 +282,22 @@ password = "toor"
}
}
func TestLoadMetricsToken(t *testing.T) {
content := `
[web]
enabled = true
metrics_token = "my-secret-token"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Web.MetricsToken != "my-secret-token" {
t.Errorf("metrics_token = %q, want %q", cfg.Web.MetricsToken, "my-secret-token")
}
}
func TestLoadMissingFile(t *testing.T) {
_, err := Load("/nonexistent/path/config.toml")
if err == nil {
@@ -297,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) {
}
}
func TestLoadUsernameRoutes(t *testing.T) {
content := `
[shell]
hostname = "myhost"
[shell.username_routes]
postgres = "psql"
admin = "bash"
[shell.bash]
custom_key = "value"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Shell.UsernameRoutes == nil {
t.Fatal("UsernameRoutes should not be nil")
}
if cfg.Shell.UsernameRoutes["postgres"] != "psql" {
t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql")
}
if cfg.Shell.UsernameRoutes["admin"] != "bash" {
t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash")
}
// username_routes should NOT appear in the Shells map.
if _, ok := cfg.Shell.Shells["username_routes"]; ok {
t.Error("username_routes should not appear in Shells map")
}
// bash should still appear in Shells map.
if _, ok := cfg.Shell.Shells["bash"]; !ok {
t.Error("Shells[\"bash\"] should still be present")
}
}
func writeTemp(t *testing.T, content string) string {
t.Helper()
path := filepath.Join(t.TempDir(), "config.toml")

51
internal/geoip/geoip.go Normal file
View File

@@ -0,0 +1,51 @@
package geoip
import (
_ "embed"
"net"
"github.com/oschwald/maxminddb-golang"
)
//go:embed dbip-country-lite.mmdb
var mmdbData []byte
// Reader provides country-level GeoIP lookups using an embedded DB-IP Lite database.
type Reader struct {
db *maxminddb.Reader
}
// New opens the embedded MMDB and returns a ready-to-use Reader.
func New() (*Reader, error) {
db, err := maxminddb.FromBytes(mmdbData)
if err != nil {
return nil, err
}
return &Reader{db: db}, nil
}
type countryRecord struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
}
// Lookup returns the ISO 3166-1 alpha-2 country code for the given IP address,
// or an empty string if the lookup fails or no result is found.
func (r *Reader) Lookup(ipStr string) string {
ip := net.ParseIP(ipStr)
if ip == nil {
return ""
}
var record countryRecord
if err := r.db.Lookup(ip, &record); err != nil {
return ""
}
return record.Country.ISOCode
}
// Close releases resources held by the reader.
func (r *Reader) Close() error {
return r.db.Close()
}

View File

@@ -0,0 +1,44 @@
package geoip
import "testing"
func TestLookup(t *testing.T) {
reader, err := New()
if err != nil {
t.Fatalf("New: %v", err)
}
defer reader.Close()
tests := []struct {
ip string
want string
}{
{"8.8.8.8", "US"},
{"1.1.1.1", "AU"},
{"invalid", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
got := reader.Lookup(tt.ip)
if got != tt.want {
t.Errorf("Lookup(%q) = %q, want %q", tt.ip, got, tt.want)
}
})
}
}
func TestLookupPrivateIP(t *testing.T) {
reader, err := New()
if err != nil {
t.Fatalf("New: %v", err)
}
defer reader.Close()
// Private IPs should return empty string (no country).
got := reader.Lookup("10.0.0.1")
if got != "" {
t.Errorf("Lookup(10.0.0.1) = %q, want empty", got)
}
}

178
internal/metrics/metrics.go Normal file
View File

@@ -0,0 +1,178 @@
package metrics
import (
"context"
"net/http"
"code.t-juice.club/torjus/oubliette/internal/storage"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Metrics holds all Prometheus collectors for the honeypot.
type Metrics struct {
registry *prometheus.Registry
SSHConnectionsTotal *prometheus.CounterVec
SSHConnectionsActive prometheus.Gauge
AuthAttemptsTotal *prometheus.CounterVec
AuthAttemptsByCountry *prometheus.CounterVec
CommandsExecuted *prometheus.CounterVec
HumanScore prometheus.Histogram
SessionsTotal *prometheus.CounterVec
SessionsActive prometheus.Gauge
SessionDuration prometheus.Histogram
ExecCommandsTotal prometheus.Counter
BuildInfo *prometheus.GaugeVec
StorageQueryDuration *prometheus.HistogramVec
StorageQueryErrors *prometheus.CounterVec
}
// New creates a new Metrics instance with all collectors registered.
func New(version string) *Metrics {
reg := prometheus.NewRegistry()
m := &Metrics{
registry: reg,
SSHConnectionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_ssh_connections_total",
Help: "Total SSH connections received.",
}, []string{"outcome"}),
SSHConnectionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "oubliette_ssh_connections_active",
Help: "Current active SSH connections.",
}),
AuthAttemptsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_auth_attempts_total",
Help: "Total authentication attempts.",
}, []string{"result", "reason"}),
AuthAttemptsByCountry: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_auth_attempts_by_country_total",
Help: "Total authentication attempts by country.",
}, []string{"country"}),
CommandsExecuted: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_commands_executed_total",
Help: "Total commands executed in shells.",
}, []string{"shell"}),
HumanScore: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "oubliette_human_score",
Help: "Distribution of final human detection scores.",
Buckets: prometheus.LinearBuckets(0, 0.1, 11), // 0.0, 0.1, ..., 1.0
}),
SessionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_sessions_total",
Help: "Total sessions created.",
}, []string{"shell"}),
SessionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "oubliette_sessions_active",
Help: "Current active sessions.",
}),
SessionDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "oubliette_session_duration_seconds",
Help: "Session duration in seconds.",
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600, 1800, 3600},
}),
ExecCommandsTotal: prometheus.NewCounter(prometheus.CounterOpts{
Name: "oubliette_exec_commands_total",
Help: "Total SSH exec commands received.",
}),
BuildInfo: prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "oubliette_build_info",
Help: "Build information. Always 1.",
}, []string{"version"}),
StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "oubliette_storage_query_duration_seconds",
Help: "Duration of storage query calls in seconds.",
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
}, []string{"method"}),
StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_storage_query_errors_total",
Help: "Total storage query errors.",
}, []string{"method"}),
}
reg.MustRegister(
collectors.NewGoCollector(),
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
m.SSHConnectionsTotal,
m.SSHConnectionsActive,
m.AuthAttemptsTotal,
m.AuthAttemptsByCountry,
m.CommandsExecuted,
m.HumanScore,
m.SessionsTotal,
m.SessionsActive,
m.SessionDuration,
m.ExecCommandsTotal,
m.BuildInfo,
m.StorageQueryDuration,
m.StorageQueryErrors,
)
m.BuildInfo.WithLabelValues(version).Set(1)
// Initialize label combinations so they appear in Gather/output.
for _, outcome := range []string{"accepted", "rejected_handshake", "rejected_max_connections"} {
m.SSHConnectionsTotal.WithLabelValues(outcome)
}
for _, reason := range []string{"static_credential", "remembered_credential", "threshold_reached", "rejected"} {
m.AuthAttemptsTotal.WithLabelValues("accepted", reason)
m.AuthAttemptsTotal.WithLabelValues("rejected", reason)
}
for _, sh := range []string{"bash", "fridge", "banking", "adventure", "cisco"} {
m.SessionsTotal.WithLabelValues(sh)
m.CommandsExecuted.WithLabelValues(sh)
}
return m
}
// RegisterStoreCollector registers a collector that queries storage stats on each scrape.
func (m *Metrics) RegisterStoreCollector(store storage.Store) {
m.registry.MustRegister(&storeCollector{store: store})
}
// Handler returns an http.Handler that serves Prometheus metrics.
func (m *Metrics) Handler() http.Handler {
return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{})
}
// storeCollector implements prometheus.Collector, querying storage on each scrape.
type storeCollector struct {
store storage.Store
}
var (
storageLoginAttemptsDesc = prometheus.NewDesc(
"oubliette_storage_login_attempts_total",
"Total login attempts in storage.",
nil, nil,
)
storageUniqueIPsDesc = prometheus.NewDesc(
"oubliette_storage_unique_ips",
"Unique IPs in storage.",
nil, nil,
)
storageSessionsDesc = prometheus.NewDesc(
"oubliette_storage_sessions_total",
"Total sessions in storage.",
nil, nil,
)
)
func (c *storeCollector) Describe(ch chan<- *prometheus.Desc) {
ch <- storageLoginAttemptsDesc
ch <- storageUniqueIPsDesc
ch <- storageSessionsDesc
}
func (c *storeCollector) Collect(ch chan<- prometheus.Metric) {
stats, err := c.store.GetDashboardStats(context.Background())
if err != nil {
return
}
ch <- prometheus.MustNewConstMetric(storageLoginAttemptsDesc, prometheus.GaugeValue, float64(stats.TotalAttempts))
ch <- prometheus.MustNewConstMetric(storageUniqueIPsDesc, prometheus.GaugeValue, float64(stats.UniqueIPs))
ch <- prometheus.MustNewConstMetric(storageSessionsDesc, prometheus.GaugeValue, float64(stats.TotalSessions))
}

View File

@@ -0,0 +1,142 @@
package metrics
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func TestNew(t *testing.T) {
m := New("1.2.3")
// Gather all metrics and check expected names exist.
families, err := m.registry.Gather()
if err != nil {
t.Fatalf("gather: %v", err)
}
want := map[string]bool{
"oubliette_ssh_connections_total": false,
"oubliette_ssh_connections_active": false,
"oubliette_auth_attempts_total": false,
"oubliette_commands_executed_total": false,
"oubliette_human_score": false,
"oubliette_sessions_total": false,
"oubliette_sessions_active": false,
"oubliette_session_duration_seconds": false,
"oubliette_build_info": false,
}
for _, f := range families {
if _, ok := want[f.GetName()]; ok {
want[f.GetName()] = true
}
}
for name, found := range want {
if !found {
t.Errorf("metric %q not registered", name)
}
}
}
func TestAuthAttemptsByCountry(t *testing.T) {
m := New("1.0.0")
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
m.AuthAttemptsByCountry.WithLabelValues("DE").Inc()
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
families, err := m.registry.Gather()
if err != nil {
t.Fatalf("gather: %v", err)
}
var found bool
for _, f := range families {
if f.GetName() == "oubliette_auth_attempts_by_country_total" {
found = true
if len(f.GetMetric()) != 2 {
t.Errorf("expected 2 label pairs (US, DE), got %d", len(f.GetMetric()))
}
}
}
if !found {
t.Error("oubliette_auth_attempts_by_country_total not found after incrementing")
}
}
func TestHandler(t *testing.T) {
m := New("1.2.3")
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
m.Handler().ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body, err := io.ReadAll(w.Body)
if err != nil {
t.Fatalf("reading body: %v", err)
}
if !strings.Contains(string(body), `oubliette_build_info{version="1.2.3"} 1`) {
t.Errorf("response should contain build_info metric, got:\n%s", body)
}
}
func TestStoreCollector(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
// Seed some data.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
t.Fatalf("CreateSession: %v", err)
}
m := New("test")
m.RegisterStoreCollector(store)
families, err := m.registry.Gather()
if err != nil {
t.Fatalf("gather: %v", err)
}
wantMetrics := map[string]float64{
"oubliette_storage_login_attempts_total": 2,
"oubliette_storage_unique_ips": 2,
"oubliette_storage_sessions_total": 1,
}
for _, f := range families {
expected, ok := wantMetrics[f.GetName()]
if !ok {
continue
}
if len(f.GetMetric()) == 0 {
t.Errorf("metric %q has no samples", f.GetName())
continue
}
got := f.GetMetric()[0].GetGauge().GetValue()
if got != expected {
t.Errorf("metric %q = %f, want %f", f.GetName(), got, expected)
}
delete(wantMetrics, f.GetName())
}
for name := range wantMetrics {
t.Errorf("metric %q not found in gather output", name)
}
}

View File

@@ -10,7 +10,7 @@ import (
"sync"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
// Event types.

View File

@@ -10,7 +10,7 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
func testSession() SessionInfo {

View File

@@ -12,14 +12,22 @@ import (
"os"
"time"
"git.t-juice.club/torjus/oubliette/internal/auth"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/detection"
"git.t-juice.club/torjus/oubliette/internal/notify"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/auth"
"code.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/detection"
"code.t-juice.club/torjus/oubliette/internal/geoip"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/notify"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell/adventure"
"code.t-juice.club/torjus/oubliette/internal/shell/banking"
"code.t-juice.club/torjus/oubliette/internal/shell/bash"
"code.t-juice.club/torjus/oubliette/internal/shell/cisco"
"code.t-juice.club/torjus/oubliette/internal/shell/fridge"
psqlshell "code.t-juice.club/torjus/oubliette/internal/shell/psql"
"code.t-juice.club/torjus/oubliette/internal/shell/roomba"
"code.t-juice.club/torjus/oubliette/internal/shell/tetris"
"code.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
@@ -32,9 +40,11 @@ type Server struct {
connSem chan struct{} // semaphore limiting concurrent connections
shellRegistry *shell.Registry
notifier notify.Sender
metrics *metrics.Metrics
geoip *geoip.Reader
}
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics.Metrics) (*Server, error) {
registry := shell.NewRegistry()
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
return nil, fmt.Errorf("registering bash shell: %w", err)
@@ -42,6 +52,29 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
if err := registry.Register(fridge.NewFridgeShell(), 1); err != nil {
return nil, fmt.Errorf("registering fridge shell: %w", err)
}
if err := registry.Register(banking.NewBankingShell(), 1); err != nil {
return nil, fmt.Errorf("registering banking shell: %w", err)
}
if err := registry.Register(adventure.NewAdventureShell(), 1); err != nil {
return nil, fmt.Errorf("registering adventure shell: %w", err)
}
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
return nil, fmt.Errorf("registering cisco shell: %w", err)
}
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
return nil, fmt.Errorf("registering psql shell: %w", err)
}
if err := registry.Register(roomba.NewRoombaShell(), 1); err != nil {
return nil, fmt.Errorf("registering roomba shell: %w", err)
}
if err := registry.Register(tetris.NewTetrisShell(), 1); err != nil {
return nil, fmt.Errorf("registering tetris shell: %w", err)
}
geo, err := geoip.New()
if err != nil {
return nil, fmt.Errorf("opening geoip database: %w", err)
}
s := &Server{
cfg: cfg,
@@ -51,6 +84,8 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
shellRegistry: registry,
notifier: notify.NewSender(cfg.Notify.Webhooks, logger),
metrics: m,
geoip: geo,
}
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
@@ -68,6 +103,8 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
}
func (s *Server) ListenAndServe(ctx context.Context) error {
defer s.geoip.Close()
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
if err != nil {
return fmt.Errorf("listen: %w", err)
@@ -94,11 +131,16 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
// Enforce max concurrent connections.
select {
case s.connSem <- struct{}{}:
s.metrics.SSHConnectionsActive.Inc()
go func() {
defer func() { <-s.connSem }()
defer func() {
<-s.connSem
s.metrics.SSHConnectionsActive.Dec()
}()
s.handleConn(conn)
}()
default:
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_max_connections").Inc()
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
conn.Close()
}
@@ -110,11 +152,13 @@ func (s *Server) handleConn(conn net.Conn) {
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil {
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_handshake").Inc()
s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
return
}
defer sshConn.Close()
s.metrics.SSHConnectionsTotal.WithLabelValues("accepted").Inc()
s.logger.Info("SSH connection established",
"remote_addr", sshConn.RemoteAddr(),
"user", sshConn.User(),
@@ -153,6 +197,18 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
}
}
// Second priority: username-based route.
if selectedShell == nil {
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
sh, found := s.shellRegistry.Get(shellName)
if found {
selectedShell = sh
} else {
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
}
}
}
// Lowest priority: random selection.
if selectedShell == nil {
var err error
selectedShell, err = s.shellRegistry.Select()
@@ -163,11 +219,17 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
}
ip := extractIP(conn.RemoteAddr())
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name())
country := s.geoip.Lookup(ip)
sessionStart := time.Now()
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name(), country)
if err != nil {
s.logger.Error("failed to create session", "err", err)
} else {
s.metrics.SessionsTotal.WithLabelValues(selectedShell.Name()).Inc()
s.metrics.SessionsActive.Inc()
defer func() {
s.metrics.SessionsActive.Dec()
s.metrics.SessionDuration.Observe(time.Since(sessionStart).Seconds())
if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
s.logger.Error("failed to end session", "err", err)
}
@@ -193,14 +255,24 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
defer s.notifier.CleanupSession(sessionID)
// Handle session requests (pty-req, shell, etc.)
// Handle session requests (pty-req, shell, exec, etc.)
execCh := make(chan string, 1)
go func() {
defer close(execCh)
for req := range requests {
switch req.Type {
case "pty-req", "shell":
if req.WantReply {
req.Reply(true, nil)
}
case "exec":
if req.WantReply {
req.Reply(true, nil)
}
var payload struct{ Command string }
if err := ssh.Unmarshal(req.Payload, &payload); err == nil {
execCh <- payload.Command
}
default:
if req.WantReply {
req.Reply(false, nil)
@@ -209,6 +281,29 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
}
}()
// Check for exec request before proceeding to interactive shell.
select {
case cmd, ok := <-execCh:
if ok && cmd != "" {
s.logger.Info("exec command received",
"remote_addr", conn.RemoteAddr(),
"user", conn.User(),
"session_id", sessionID,
"command", cmd,
)
if err := s.store.SetExecCommand(context.Background(), sessionID, cmd); err != nil {
s.logger.Error("failed to set exec command", "err", err, "session_id", sessionID)
}
s.metrics.ExecCommandsTotal.Inc()
// Send exit-status 0 and close channel.
exitPayload := make([]byte, 4) // uint32(0)
_, _ = channel.SendRequest("exit-status", false, exitPayload)
return
}
case <-time.After(500 * time.Millisecond):
// No exec request within timeout — proceed with interactive shell.
}
// Build session context.
var shellCfg map[string]any
if s.cfg.Shell.Shells != nil {
@@ -226,6 +321,9 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
Banner: s.cfg.Shell.Banner,
FakeUser: s.cfg.Shell.FakeUser,
},
OnCommand: func(sh string) {
s.metrics.CommandsExecuted.WithLabelValues(sh).Inc()
},
}
// Wrap channel in RecordingChannel.
@@ -261,6 +359,7 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
}
if scorer != nil {
finalScore := scorer.Score()
s.metrics.HumanScore.Observe(finalScore)
if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil {
s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID)
}
@@ -310,6 +409,12 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
ip := extractIP(conn.RemoteAddr())
d := s.authenticator.Authenticate(ip, conn.User(), string(password))
if d.Accepted {
s.metrics.AuthAttemptsTotal.WithLabelValues("accepted", d.Reason).Inc()
} else {
s.metrics.AuthAttemptsTotal.WithLabelValues("rejected", d.Reason).Inc()
}
s.logger.Info("auth attempt",
"remote_addr", conn.RemoteAddr(),
"username", conn.User(),
@@ -317,7 +422,11 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
"reason", d.Reason,
)
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip); err != nil {
country := s.geoip.Lookup(ip)
if country != "" {
s.metrics.AuthAttemptsByCountry.WithLabelValues(country).Inc()
}
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip, country); err != nil {
s.logger.Error("failed to record login attempt", "err", err)
}

View File

@@ -11,8 +11,10 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/auth"
"code.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
@@ -120,7 +122,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
store := storage.NewMemoryStore()
srv, err := New(cfg, store, logger)
srv, err := New(cfg, store, logger, metrics.New("test"))
if err != nil {
t.Fatalf("creating server: %v", err)
}
@@ -251,6 +253,137 @@ func TestIntegrationSSHConnect(t *testing.T) {
}
})
// Test exec command capture.
t.Run("exec_command", func(t *testing.T) {
clientCfg := &ssh.ClientConfig{
User: "root",
Auth: []ssh.AuthMethod{ssh.Password("toor")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
client, err := ssh.Dial("tcp", addr, clientCfg)
if err != nil {
t.Fatalf("SSH dial: %v", err)
}
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Fatalf("new session: %v", err)
}
defer session.Close()
// Run a command via exec (no PTY, no shell).
if err := session.Run("uname -a"); err != nil {
// Run returns an error because the server closes the channel,
// but that's expected.
_ = err
}
// Give the server a moment to store the command.
time.Sleep(200 * time.Millisecond)
// Verify the exec command was captured.
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
var foundExec bool
for _, s := range sessions {
if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
foundExec = true
break
}
}
if !foundExec {
t.Error("expected a session with exec_command='uname -a'")
}
})
// Test username route: add username_routes so that "postgres" gets psql shell.
t.Run("username_route", func(t *testing.T) {
// Reconfigure with username routes.
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
// Need to get the "postgres" user in via static creds or threshold.
// Use static creds for simplicity.
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
config.Credential{Username: "postgres", Password: "postgres"},
)
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
defer func() {
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
}()
clientCfg := &ssh.ClientConfig{
User: "postgres",
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
client, err := ssh.Dial("tcp", addr, clientCfg)
if err != nil {
t.Fatalf("SSH dial: %v", err)
}
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Fatalf("new session: %v", err)
}
defer session.Close()
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
t.Fatalf("request pty: %v", err)
}
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("stdin pipe: %v", err)
}
var output bytes.Buffer
session.Stdout = &output
if err := session.Shell(); err != nil {
t.Fatalf("shell: %v", err)
}
// Wait for the psql banner.
time.Sleep(500 * time.Millisecond)
// Send \q to quit.
stdin.Write([]byte(`\q` + "\r"))
time.Sleep(200 * time.Millisecond)
session.Wait()
out := output.String()
if !strings.Contains(out, "psql") {
t.Errorf("output should contain psql banner, got: %s", out)
}
// Verify session was created with shell name "psql".
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
var foundPsql bool
for _, s := range sessions {
if s.ShellName == "psql" && s.Username == "postgres" {
foundPsql = true
break
}
}
if !foundPsql {
t.Error("expected a session with shell_name='psql' for user 'postgres'")
}
})
// Test threshold acceptance: after enough failed dials, a subsequent
// dial with the same credentials should succeed via threshold or
// remembered credential.

View File

@@ -0,0 +1,117 @@
package adventure
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute
// AdventureShell implements a Zork-style text adventure set in a dungeon/data center.
type AdventureShell struct{}
// NewAdventureShell returns a new AdventureShell instance.
func NewAdventureShell() *AdventureShell {
return &AdventureShell{}
}
func (a *AdventureShell) Name() string { return "adventure" }
func (a *AdventureShell) Description() string { return "Zork-style text adventure dungeon crawler" }
func (a *AdventureShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
dungeonName := configString(sess.ShellConfig, "dungeon_name", "THE OUBLIETTE")
game := newGame()
// Print banner and initial room.
banner := strings.ReplaceAll(adventureBanner(dungeonName), "\n", "\r\n")
fmt.Fprint(rw, banner)
// Show starting room.
startDesc := game.describeRoom(game.rooms[game.currentRoom])
startDesc = strings.ReplaceAll(startDesc, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n\r\n", startDesc)
for {
if _, err := fmt.Fprint(rw, "> "); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
fmt.Fprint(rw, "\r\n")
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
result := game.dispatch(trimmed)
var output string
if result.output != "" {
output = result.output
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
// Log command and output to store.
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("adventure")
}
if result.exit {
return nil
}
}
}
func adventureBanner(dungeonName string) string {
return fmt.Sprintf(`
___ _ _ ____ _ ___ _____ _____ _____
/ _ \| | | | __ )| | |_ _| ____|_ _|_ _| ___
| | | | | | | _ \| | | || _| | | | | / _ \
| |_| | |_| | |_) | |___ | || |___ | | | | | __/
\___/ \___/|____/|_____|___|_____| |_| |_| \___|
Welcome to %s.
You wake up in the dark. The air is cold and hums with electricity.
This is a place where things are put to be forgotten.
Type 'help' for commands. Type 'look' to examine your surroundings.
`, dungeonName)
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,383 @@
package adventure
import (
"bytes"
"context"
"io"
"strings"
"testing"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {
io.Reader
io.Writer
}
func (r *rwCloser) Close() error { return nil }
func runShell(t *testing.T, commands string) string {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "root",
Store: store,
CommonConfig: shell.ShellCommonConfig{
Hostname: "testhost",
},
}
rw := &rwCloser{
Reader: bytes.NewBufferString(commands),
Writer: &bytes.Buffer{},
}
sh := NewAdventureShell()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := sh.Handle(ctx, sess, rw); err != nil {
t.Fatalf("Handle: %v", err)
}
return rw.Writer.(*bytes.Buffer).String()
}
func TestAdventureShellName(t *testing.T) {
sh := NewAdventureShell()
if sh.Name() != "adventure" {
t.Errorf("Name() = %q, want %q", sh.Name(), "adventure")
}
if sh.Description() == "" {
t.Error("Description() should not be empty")
}
}
func TestBanner(t *testing.T) {
output := runShell(t, "quit\r")
if !strings.Contains(output, "OUBLIETTE") {
t.Error("output should contain OUBLIETTE in banner")
}
if !strings.Contains(output, ">") {
t.Error("output should contain > prompt")
}
}
func TestStartingRoom(t *testing.T) {
output := runShell(t, "quit\r")
if !strings.Contains(output, "The Oubliette") {
t.Error("should start in The Oubliette")
}
if !strings.Contains(output, "narrow stone chamber") {
t.Error("should show starting room description")
}
}
func TestHelpCommand(t *testing.T) {
output := runShell(t, "help\rquit\r")
for _, keyword := range []string{"look", "go", "take", "drop", "use", "inventory", "help", "quit"} {
if !strings.Contains(output, keyword) {
t.Errorf("help output should mention %q", keyword)
}
}
}
func TestLookCommand(t *testing.T) {
output := runShell(t, "look\rquit\r")
if !strings.Contains(output, "The Oubliette") {
t.Error("look should show current room")
}
if !strings.Contains(output, "Exits:") {
t.Error("look should show exits")
}
}
func TestMovement(t *testing.T) {
output := runShell(t, "go east\rquit\r")
if !strings.Contains(output, "Stone Corridor") {
t.Error("going east from oubliette should reach Stone Corridor")
}
}
func TestBareDirection(t *testing.T) {
output := runShell(t, "e\rquit\r")
if !strings.Contains(output, "Stone Corridor") {
t.Error("bare 'e' should move east to Stone Corridor")
}
}
func TestDirectionAliases(t *testing.T) {
output := runShell(t, "east\rquit\r")
if !strings.Contains(output, "Stone Corridor") {
t.Error("'east' should move to Stone Corridor")
}
}
func TestInvalidDirection(t *testing.T) {
output := runShell(t, "go north\rquit\r")
if !strings.Contains(output, "can't go") {
t.Error("should say you can't go that direction")
}
}
func TestTakeItem(t *testing.T) {
// Move to corridor where flashlight is.
output := runShell(t, "e\rtake flashlight\rinventory\rquit\r")
if !strings.Contains(output, "Taken") {
t.Error("should confirm taking item")
}
if !strings.Contains(output, "flashlight") {
t.Error("inventory should show flashlight")
}
}
func TestDropItem(t *testing.T) {
output := runShell(t, "e\rtake flashlight\rdrop flashlight\rinventory\rquit\r")
if !strings.Contains(output, "Dropped") {
t.Error("should confirm dropping item")
}
if !strings.Contains(output, "not carrying anything") {
t.Error("inventory should be empty after dropping")
}
}
func TestEmptyInventory(t *testing.T) {
output := runShell(t, "inventory\rquit\r")
if !strings.Contains(output, "not carrying anything") {
t.Error("should say not carrying anything")
}
}
func TestExamineItem(t *testing.T) {
output := runShell(t, "e\rexamine flashlight\rquit\r")
if !strings.Contains(output, "batteries") {
t.Error("examining flashlight should describe it")
}
}
func TestDarkRoom(t *testing.T) {
// Go to the pit without flashlight.
output := runShell(t, "down\rquit\r")
if !strings.Contains(output, "darkness") {
t.Error("pit should be dark without flashlight")
}
// Should NOT show the lit description.
if strings.Contains(output, "skeleton") {
t.Error("should not see skeleton without flashlight")
}
}
func TestDarkRoomWithFlashlight(t *testing.T) {
// Get flashlight first, then go to pit.
output := runShell(t, "e\rtake flashlight\rw\rdown\rquit\r")
if !strings.Contains(output, "skeleton") {
t.Error("should see skeleton with flashlight in pit")
}
if !strings.Contains(output, "rusty key") {
t.Error("should see rusty key with flashlight in pit")
}
}
func TestDarkRoomCantTake(t *testing.T) {
output := runShell(t, "down\rtake rusty_key\rquit\r")
if !strings.Contains(output, "too dark") {
t.Error("should not be able to take items in dark room")
}
}
func TestLockedDoor(t *testing.T) {
// Navigate to archive without keycard.
output := runShell(t, "e\rs\rs\re\rquit\r")
if !strings.Contains(output, "keycard") || !strings.Contains(output, "red") {
t.Error("should mention keycard reader when trying locked door")
}
}
func TestUnlockWithKeycard(t *testing.T) {
// Get keycard from server room, navigate to archive, use keycard.
output := runShell(t, "e\re\rtake keycard\rw\rs\rs\ruse keycard\re\rquit\r")
if !strings.Contains(output, "green") || !strings.Contains(output, "clicks open") {
t.Error("should show unlock message")
}
if !strings.Contains(output, "Control Room") {
t.Error("should be able to enter control room after unlocking")
}
}
func TestRustyKeyPuzzle(t *testing.T) {
// Get flashlight, get rusty key from pit, go to generator room, use key.
output := runShell(t, "e\rtake flashlight\rw\rdown\rtake rusty_key\rup\re\re\rs\re\ruse rusty_key\rquit\r")
if !strings.Contains(output, "maintenance panel") || !strings.Contains(output, "logbook") {
t.Error("should show maintenance log when using rusty key in generator room")
}
if !strings.Contains(output, "Dunwich") {
t.Error("logbook should mention Dunwich")
}
}
func TestExitEndsGame(t *testing.T) {
// Full path to exit: get keycard, navigate to archive, unlock, go to control room, east to exit.
output := runShell(t, "e\re\rtake keycard\rw\rs\rs\ruse keycard\re\re\r")
if !strings.Contains(output, "SECURITY AUDIT") {
t.Error("exit should show the ending message")
}
if !strings.Contains(output, "SESSION HAS BEEN LOGGED") {
t.Error("exit should mention session logging")
}
}
func TestUnknownCommand(t *testing.T) {
output := runShell(t, "xyzzy\rquit\r")
if !strings.Contains(output, "don't understand") {
t.Error("should show error for unknown command")
}
}
func TestQuitCommand(t *testing.T) {
output := runShell(t, "quit\r")
if !strings.Contains(output, "terminated") {
t.Error("quit should show termination message")
}
}
func TestExitAlias(t *testing.T) {
output := runShell(t, "exit\r")
if !strings.Contains(output, "terminated") {
t.Error("exit alias should work like quit")
}
}
func TestVentilationShaft(t *testing.T) {
output := runShell(t, "e\rup\rquit\r")
if !strings.Contains(output, "Ventilation Shaft") {
t.Error("should reach ventilation shaft from corridor")
}
if !strings.Contains(output, "note") {
t.Error("should see note in ventilation shaft")
}
}
func TestNote(t *testing.T) {
output := runShell(t, "e\rup\rtake note\rexamine note\rquit\r")
if !strings.Contains(output, "GET OUT") {
t.Error("note should contain warning text")
}
}
func TestUseItemWrongRoom(t *testing.T) {
// Get keycard from server room, try to use in wrong room.
output := runShell(t, "e\re\rtake keycard\ruse keycard\rquit\r")
if !strings.Contains(output, "nothing to use") {
t.Error("should say nothing to use keycard on in wrong room")
}
}
func TestEthernetCable(t *testing.T) {
output := runShell(t, "e\rs\rtake ethernet_cable\ruse ethernet_cable\rquit\r")
if !strings.Contains(output, "chewed") {
t.Error("using ethernet cable should mention chewed ends")
}
}
func TestSessionLogs(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "root",
Store: store,
CommonConfig: shell.ShellCommonConfig{
Hostname: "testhost",
},
}
rw := &rwCloser{
Reader: bytes.NewBufferString("help\rquit\r"),
Writer: &bytes.Buffer{},
}
sh := NewAdventureShell()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sh.Handle(ctx, sess, rw)
if len(store.SessionLogs) < 2 {
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
}
}
func TestParserBasics(t *testing.T) {
tests := []struct {
input string
verb string
object string
}{
{"look", "look", ""},
{"l", "look", ""},
{"examine flashlight", "look", "flashlight"},
{"go north", "go", "north"},
{"n", "go", "north"},
{"south", "go", "south"},
{"take the key", "take", "key"},
{"get a flashlight", "take", "flashlight"},
{"drop rusty key", "drop", "rusty key"},
{"use keycard", "use", "keycard"},
{"inventory", "inventory", ""},
{"i", "inventory", ""},
{"inv", "inventory", ""},
{"help", "help", ""},
{"?", "help", ""},
{"quit", "quit", ""},
{"exit", "quit", ""},
{"q", "quit", ""},
{"LOOK", "look", ""},
{" go east ", "go", "east"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
cmd := parseCommand(tt.input)
if cmd.verb != tt.verb {
t.Errorf("parseCommand(%q).verb = %q, want %q", tt.input, cmd.verb, tt.verb)
}
if cmd.object != tt.object {
t.Errorf("parseCommand(%q).object = %q, want %q", tt.input, cmd.object, tt.object)
}
})
}
}
func TestParserEmpty(t *testing.T) {
cmd := parseCommand("")
if cmd.verb != "" {
t.Errorf("empty input should give empty verb, got %q", cmd.verb)
}
}
func TestParserArticleStripping(t *testing.T) {
cmd := parseCommand("take the an a flashlight")
if cmd.verb != "take" || cmd.object != "flashlight" {
t.Errorf("articles should be stripped, got verb=%q object=%q", cmd.verb, cmd.object)
}
}
func TestConfigString(t *testing.T) {
cfg := map[string]any{"dungeon_name": "MY DUNGEON"}
if got := configString(cfg, "dungeon_name", "DEFAULT"); got != "MY DUNGEON" {
t.Errorf("configString() = %q, want %q", got, "MY DUNGEON")
}
if got := configString(cfg, "missing", "DEFAULT"); got != "DEFAULT" {
t.Errorf("configString() for missing key = %q, want %q", got, "DEFAULT")
}
if got := configString(nil, "key", "DEFAULT"); got != "DEFAULT" {
t.Errorf("configString(nil) = %q, want %q", got, "DEFAULT")
}
}

View File

@@ -0,0 +1,358 @@
package adventure
import (
"fmt"
"slices"
"strings"
)
type gameState struct {
currentRoom string
inventory []string
rooms map[string]*room
items map[string]*item
flags map[string]bool
turns int
}
type commandResult struct {
output string
exit bool
}
func newGame() *gameState {
rooms, items := newWorld()
return &gameState{
currentRoom: "oubliette",
rooms: rooms,
items: items,
flags: make(map[string]bool),
}
}
func (g *gameState) dispatch(input string) commandResult {
cmd := parseCommand(input)
if cmd.verb == "" {
return commandResult{}
}
g.turns++
switch cmd.verb {
case "look":
return g.cmdLook(cmd.object)
case "go":
return g.cmdGo(cmd.object)
case "take":
return g.cmdTake(cmd.object)
case "drop":
return g.cmdDrop(cmd.object)
case "use":
return g.cmdUse(cmd.object)
case "inventory":
return g.cmdInventory()
case "help":
return g.cmdHelp()
case "quit":
return commandResult{output: "The darkness closes in. Session terminated.", exit: true}
default:
return commandResult{output: fmt.Sprintf("I don't understand '%s'. Type 'help' for available commands.", input)}
}
}
func (g *gameState) cmdHelp() commandResult {
help := `Available commands:
look / examine [item] - Look around or examine something
go <direction> - Move (north, south, east, west, up, down)
take <item> - Pick up an item
drop <item> - Drop an item
use <item> - Use an item
inventory - Check what you're carrying
help - Show this help
quit / exit - End session
You can also just type a direction (n, s, e, w, u, d) to move.`
return commandResult{output: help}
}
func (g *gameState) cmdLook(object string) commandResult {
r := g.rooms[g.currentRoom]
// Look at a specific item.
if object != "" {
return g.examineItem(object)
}
// Look at the room.
return commandResult{output: g.describeRoom(r)}
}
func (g *gameState) describeRoom(r *room) string {
var b strings.Builder
fmt.Fprintf(&b, "== %s ==\n", r.name)
// Dark room check.
if r.darkDesc != "" && !g.hasItem("flashlight") {
b.WriteString(r.darkDesc)
return b.String()
}
b.WriteString(r.description)
// List visible items.
visibleItems := g.roomItems(r)
if len(visibleItems) > 0 {
b.WriteString("\n\nYou can see: ")
names := make([]string, len(visibleItems))
for i, id := range visibleItems {
names[i] = g.items[id].name
}
b.WriteString(strings.Join(names, ", "))
}
// List exits.
if len(r.exits) > 0 {
b.WriteString("\n\nExits: ")
dirs := make([]string, 0, len(r.exits))
for dir := range r.exits {
dirs = append(dirs, dir)
}
b.WriteString(strings.Join(dirs, ", "))
}
return b.String()
}
func (g *gameState) roomItems(r *room) []string {
// In dark rooms without flashlight, can't see items.
if r.darkDesc != "" && !g.hasItem("flashlight") {
return nil
}
return r.items
}
func (g *gameState) examineItem(name string) commandResult {
id := g.resolveItem(name)
if id == "" {
return commandResult{output: fmt.Sprintf("You don't see '%s' here.", name)}
}
return commandResult{output: g.items[id].description}
}
func (g *gameState) cmdGo(direction string) commandResult {
if direction == "" {
return commandResult{output: "Go where? Try a direction: north, south, east, west, up, down."}
}
r := g.rooms[g.currentRoom]
destID, ok := r.exits[direction]
if !ok {
return commandResult{output: fmt.Sprintf("You can't go %s from here.", direction)}
}
// Check locked doors.
if r.locked != nil {
if flag, locked := r.locked[direction]; locked && !g.flags[flag] {
return g.lockedMessage(direction)
}
}
g.currentRoom = destID
dest := g.rooms[destID]
// Entering the exit room ends the game.
if destID == "exit" {
return commandResult{output: g.describeRoom(dest), exit: true}
}
return commandResult{output: g.describeRoom(dest)}
}
func (g *gameState) lockedMessage(direction string) commandResult {
if direction == "east" && g.currentRoom == "archive" {
return commandResult{output: "The steel door won't budge. The keycard reader blinks red, waiting."}
}
return commandResult{output: "The way is locked."}
}
func (g *gameState) cmdTake(name string) commandResult {
if name == "" {
return commandResult{output: "Take what?"}
}
r := g.rooms[g.currentRoom]
// Can't take items in dark rooms.
if r.darkDesc != "" && !g.hasItem("flashlight") {
return commandResult{output: "It's too dark to find anything."}
}
id := g.resolveRoomItem(name)
if id == "" {
return commandResult{output: fmt.Sprintf("You don't see '%s' here.", name)}
}
it := g.items[id]
if !it.takeable {
return commandResult{output: fmt.Sprintf("You can't take the %s.", it.name)}
}
// Remove from room, add to inventory.
g.removeRoomItem(r, id)
g.inventory = append(g.inventory, id)
return commandResult{output: fmt.Sprintf("Taken: %s", it.name)}
}
func (g *gameState) cmdDrop(name string) commandResult {
if name == "" {
return commandResult{output: "Drop what?"}
}
id := g.resolveInventoryItem(name)
if id == "" {
return commandResult{output: fmt.Sprintf("You're not carrying '%s'.", name)}
}
// Remove from inventory, add to room.
g.removeInventoryItem(id)
r := g.rooms[g.currentRoom]
r.items = append(r.items, id)
return commandResult{output: fmt.Sprintf("Dropped: %s", g.items[id].name)}
}
func (g *gameState) cmdUse(name string) commandResult {
if name == "" {
return commandResult{output: "Use what?"}
}
id := g.resolveInventoryItem(name)
if id == "" {
return commandResult{output: fmt.Sprintf("You're not carrying '%s'.", name)}
}
switch id {
case "keycard":
return g.useKeycard()
case "rusty_key":
return g.useRustyKey()
case "flashlight":
return commandResult{output: "The flashlight is already on. Its beam cuts through the darkness."}
case "ethernet_cable":
return commandResult{output: "You wave the cable around hopefully. Nothing happens. Both ends are chewed through anyway."}
case "floppy_disk":
return commandResult{output: "You don't have anything to put it in. Floppy drives went extinct decades ago."}
case "note":
return commandResult{output: g.items[id].description}
default:
return commandResult{output: fmt.Sprintf("You can't figure out how to use the %s here.", g.items[id].name)}
}
}
func (g *gameState) useKeycard() commandResult {
if g.currentRoom != "archive" {
return commandResult{output: "There's nothing to use the keycard on here."}
}
if g.flags["archive_unlocked"] {
return commandResult{output: "You already unlocked that door."}
}
g.flags["archive_unlocked"] = true
return commandResult{output: "You swipe the keycard. The reader flashes green and the steel door clicks open with a heavy thunk."}
}
func (g *gameState) useRustyKey() commandResult {
if g.currentRoom != "generator" {
return commandResult{output: "There's nothing to use the rusty key on here."}
}
if g.flags["panel_opened"] {
return commandResult{output: "The maintenance panel is already open."}
}
g.flags["panel_opened"] = true
return commandResult{output: `The key fits. The maintenance panel swings open with a screech, revealing a logbook:
MAINTENANCE LOG - GENERATOR B
Last service: 2003-11-15
Technician: J. Dunwich
Notes: "Generator A offline permanently. B running on fumes.
Fuel delivery canceled - 'facility decommissioned' per management.
But the servers are still running. WHO IS USING THEM?
Filed ticket #4,271. No response. As usual."
The entries stop after that date.`}
}
func (g *gameState) cmdInventory() commandResult {
if len(g.inventory) == 0 {
return commandResult{output: "You're not carrying anything."}
}
var b strings.Builder
b.WriteString("You are carrying:\n")
for _, id := range g.inventory {
fmt.Fprintf(&b, " - %s\n", g.items[id].name)
}
return commandResult{output: strings.TrimRight(b.String(), "\n")}
}
// hasItem checks if the player has an item in inventory.
func (g *gameState) hasItem(id string) bool {
return slices.Contains(g.inventory, id)
}
// resolveItem finds an item by name in both room and inventory.
func (g *gameState) resolveItem(name string) string {
if id := g.resolveInventoryItem(name); id != "" {
return id
}
return g.resolveRoomItem(name)
}
// resolveRoomItem finds an item in the current room by partial name match.
func (g *gameState) resolveRoomItem(name string) string {
r := g.rooms[g.currentRoom]
name = strings.ToLower(name)
for _, id := range r.items {
if matchesItem(id, g.items[id].name, name) {
return id
}
}
return ""
}
// resolveInventoryItem finds an item in inventory by partial name match.
func (g *gameState) resolveInventoryItem(name string) string {
name = strings.ToLower(name)
for _, id := range g.inventory {
if matchesItem(id, g.items[id].name, name) {
return id
}
}
return ""
}
// matchesItem checks if a search term matches an item's ID or display name.
func matchesItem(id, displayName, search string) bool {
return id == search ||
strings.ToLower(displayName) == search ||
strings.Contains(id, search) ||
strings.Contains(strings.ToLower(displayName), search)
}
func (g *gameState) removeRoomItem(r *room, id string) {
for i, itemID := range r.items {
if itemID == id {
r.items = append(r.items[:i], r.items[i+1:]...)
return
}
}
}
func (g *gameState) removeInventoryItem(id string) {
for i, invID := range g.inventory {
if invID == id {
g.inventory = append(g.inventory[:i], g.inventory[i+1:]...)
return
}
}
}

View File

@@ -0,0 +1,106 @@
package adventure
import "strings"
type parsedCommand struct {
verb string
object string
}
// directionAliases maps shorthand and full direction names to canonical forms.
var directionAliases = map[string]string{
"n": "north",
"s": "south",
"e": "east",
"w": "west",
"u": "up",
"d": "down",
"north": "north",
"south": "south",
"east": "east",
"west": "west",
"up": "up",
"down": "down",
}
// verbAliases maps aliases to canonical verbs.
var verbAliases = map[string]string{
"look": "look",
"l": "look",
"examine": "look",
"inspect": "look",
"x": "look",
"go": "go",
"move": "go",
"walk": "go",
"take": "take",
"get": "take",
"grab": "take",
"pick": "take",
"drop": "drop",
"put": "drop",
"use": "use",
"apply": "use",
"inventory": "inventory",
"inv": "inventory",
"i": "inventory",
"help": "help",
"?": "help",
"quit": "quit",
"exit": "quit",
"logout": "quit",
"q": "quit",
}
// articles are stripped from input.
var articles = map[string]bool{
"the": true,
"a": true,
"an": true,
}
// parseCommand parses raw input into a verb and object.
func parseCommand(input string) parsedCommand {
input = strings.TrimSpace(strings.ToLower(input))
if input == "" {
return parsedCommand{}
}
words := strings.Fields(input)
// Strip articles.
var filtered []string
for _, w := range words {
if !articles[w] {
filtered = append(filtered, w)
}
}
if len(filtered) == 0 {
return parsedCommand{}
}
first := filtered[0]
// Bare direction → go <direction>.
if dir, ok := directionAliases[first]; ok {
return parsedCommand{verb: "go", object: dir}
}
// Known verb alias.
if verb, ok := verbAliases[first]; ok {
object := ""
if len(filtered) > 1 {
object = strings.Join(filtered[1:], " ")
// Resolve direction alias in object for "go north" etc.
if verb == "go" {
if dir, ok := directionAliases[object]; ok {
object = dir
}
}
}
return parsedCommand{verb: verb, object: object}
}
// Unknown verb — return as-is so game can give an error.
return parsedCommand{verb: first, object: strings.Join(filtered[1:], " ")}
}

View File

@@ -0,0 +1,211 @@
package adventure
// room represents a location in the game world.
type room struct {
name string
description string
darkDesc string // shown when room is dark (empty = not a dark room)
exits map[string]string // direction → room ID
items []string // item IDs present in the room
locked map[string]string // direction → flag required to unlock
}
// item represents an object that can be picked up and used.
type item struct {
name string
description string
takeable bool
}
func newWorld() (map[string]*room, map[string]*item) {
rooms := map[string]*room{
"oubliette": {
name: "The Oubliette",
description: `You are in a narrow stone chamber. The walls are damp and slick with condensation.
Far above, an iron grate lets in a faint green glow — not daylight, but the steady
pulse of status LEDs. A frayed ethernet cable hangs from the grate like a vine.
A passage leads east into darkness. Stone steps spiral downward.`,
exits: map[string]string{
"east": "corridor",
"down": "pit",
},
},
"corridor": {
name: "Stone Corridor",
description: `A long corridor carved from living rock. The walls transition from rough-hewn
stone to poured concrete as you look east. Fluorescent tubes flicker overhead,
half of them dead. Cable trays run along the ceiling, sagging under the weight
of bundled Cat5. A draft comes from a ventilation shaft above.`,
exits: map[string]string{
"west": "oubliette",
"east": "server_room",
"south": "cable_crypt",
"up": "ventilation",
},
items: []string{"flashlight"},
},
"ventilation": {
name: "Ventilation Shaft",
description: `You've squeezed into a narrow ventilation shaft. The aluminum walls vibrate with
the hum of distant fans. It's barely wide enough to turn around in. Dust and
cobwebs coat everything. Someone has scratched tally marks into the metal — you
stop counting at forty-seven.`,
exits: map[string]string{
"down": "corridor",
},
items: []string{"note"},
},
"server_room": {
name: "Server Room",
description: `Rows of black server racks stretch into the gloom, their LEDs blinking in
patterns that almost seem deliberate. The air is frigid and filled with the
white noise of a thousand fans. A yellowed label on the nearest rack reads:
"PRODUCTION - DO NOT TOUCH". Someone has added in marker: "OR ELSE".
A laminated keycard sits on top of a powered-down blade server.`,
exits: map[string]string{
"west": "corridor",
"south": "cold_storage",
},
items: []string{"keycard"},
},
"pit": {
name: "The Pit",
darkDesc: `You are in absolute darkness. You can hear water dripping somewhere far below
and the faint hum of electronics. The air smells of rust and ozone. You can't
see a thing without a light source. The stairs lead back up.`,
description: `Your flashlight reveals a deep shaft — part medieval oubliette, part cable run.
Rusty chains hang from iron rings set into the walls, intertwined with bundles
of fiber optic cable that glow faintly orange. At the bottom, a skeleton in a
lab coat slumps against the wall, still wearing an ID badge. A rusty key glints
on a hook near the skeleton's hand.`,
exits: map[string]string{
"up": "oubliette",
},
items: []string{"rusty_key"},
},
"cable_crypt": {
name: "Cable Crypt",
description: `This vaulted chamber was clearly something else before the cables arrived.
Stone sarcophagi line the walls, but their lids have been removed and they're
now stuffed full of tangled ethernet runs and power strips. A faded sign reads
"STRUCTURED CABLING" — someone has crossed out "STRUCTURED" and written
"CHAOTIC" above it. The air smells of old stone and warm plastic.`,
exits: map[string]string{
"north": "corridor",
"south": "archive",
},
items: []string{"ethernet_cable"},
},
"cold_storage": {
name: "Cold Storage",
description: `A cavernous room kept at near-freezing temperatures. Frost coats the walls.
Rows of old tape drives and disk platters are stacked on industrial shelving,
their labels faded beyond reading. A humming cryo-unit in the corner has a
blinking amber light. A hand-written sign says "BACKUP STORAGE - CRITICAL".`,
exits: map[string]string{
"north": "server_room",
"east": "generator",
},
items: []string{"floppy_disk"},
},
"archive": {
name: "The Archive",
description: `Floor-to-ceiling shelves hold thousands of manila folders, binders, and
three-ring notebooks. The organization system, if there ever was one, has
long since collapsed into entropy. A thick layer of dust covers everything.
A heavy steel door to the east has an electronic keycard reader with a
steady red light. The cable crypt lies back to the north.`,
exits: map[string]string{
"north": "cable_crypt",
"east": "control_room",
},
locked: map[string]string{
"east": "archive_unlocked",
},
},
"generator": {
name: "Generator Room",
description: `Two massive diesel generators squat in this room like sleeping beasts. One is
clearly dead — corrosion has eaten through the fuel lines. The other hums at
a low idle, keeping the facility on life support. A maintenance panel on the
wall is secured with an old-fashioned keyhole.`,
exits: map[string]string{
"west": "cold_storage",
},
},
"control_room": {
name: "Control Room",
description: `Banks of CRT monitors cast a blue glow across the room. Most show static, but
one displays a facility map with pulsing dots labeled "ACTIVE SESSIONS". You
count the dots — there are more than there should be. A desk covered in coffee
rings holds a battered keyboard. The main display reads:
FACILITY STATUS: NOMINAL
CONTAINMENT: ACTIVE
SUBJECTS: ███
A heavy blast door leads east. Above it, a faded exit sign flickers.`,
exits: map[string]string{
"west": "archive",
"east": "exit",
},
},
"exit": {
name: "The Exit",
description: `The blast door groans open to reveal... a corridor. Not the freedom you
expected, but another corridor — identical to the one you started in. The
same flickering fluorescents. The same sagging cable trays. The same damp
stone walls.
As you step through, the door slams shut behind you. A speaker crackles:
"THANK YOU FOR PARTICIPATING IN TODAY'S SECURITY AUDIT.
YOUR SESSION HAS BEEN LOGGED.
HAVE A NICE DAY."
The lights go out.`,
},
}
items := map[string]*item{
"flashlight": {
name: "flashlight",
description: "A heavy-duty flashlight. The batteries are low but it still works.",
takeable: true,
},
"keycard": {
name: "keycard",
description: "A laminated keycard with a faded photo. The name reads 'J. DUNWICH, LEVEL 3'.",
takeable: true,
},
"rusty_key": {
name: "rusty key",
description: "An old iron key, spotted with rust. It looks like it fits a maintenance panel.",
takeable: true,
},
"note": {
name: "note",
description: `A crumpled note in shaky handwriting:
"If you're reading this, GET OUT. The facility is automated now.
The sessions never end. I thought I was running tests but the
tests were running me. Don't trust the prompts. Don't trust
the exits. Don't trust
[the writing trails off into an illegible scrawl]"`,
takeable: true,
},
"ethernet_cable": {
name: "ethernet cable",
description: "A tangled Cat5e cable, about 3 meters long. Both ends have been chewed by something.",
takeable: true,
},
"floppy_disk": {
name: "floppy disk",
description: `A 3.5" floppy disk labeled "BACKUP - CRITICAL - DO NOT FORMAT". The label is dated 1997.`,
takeable: true,
},
}
return rooms, items
}

View File

@@ -0,0 +1,74 @@
package banking
import (
"context"
"fmt"
"io"
"math/rand/v2"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute
// BankingShell is an 80s-style green-on-black bank terminal TUI.
type BankingShell struct{}
// NewBankingShell returns a new BankingShell instance.
func NewBankingShell() *BankingShell {
return &BankingShell{}
}
func (b *BankingShell) Name() string { return "banking" }
func (b *BankingShell) Description() string { return "80s-style banking terminal TUI" }
func (b *BankingShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
bankName := configString(sess.ShellConfig, "bank_name", "SECUREBANK")
terminalID := configString(sess.ShellConfig, "terminal_id", "")
region := configString(sess.ShellConfig, "region", "NORTHEAST")
if terminalID == "" {
terminalID = fmt.Sprintf("SB-%04d", rand.IntN(10000))
}
m := newModel(sess, bankName, terminalID, region)
p := tea.NewProgram(m,
tea.WithInput(rw),
tea.WithOutput(rw),
tea.WithAltScreen(),
)
done := make(chan error, 1)
go func() {
_, err := p.Run()
done <- err
}()
select {
case err := <-done:
return err
case <-ctx.Done():
p.Quit()
<-done
return nil
}
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,560 @@
package banking
import (
"context"
"strings"
"testing"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// newTestModel creates a model with a test session context.
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "banker", "banking", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "banker",
Store: store,
}
m := newModel(sess, "SECUREBANK", "SB-0001", "NORTHEAST")
return m, store
}
// sendKeys sends a string of characters as individual key messages to the model.
func sendKeys(m *model, s string) {
for _, ch := range s {
var msg tea.KeyMsg
switch ch {
case '\r':
msg = tea.KeyMsg{Type: tea.KeyEnter}
case '\x1b':
msg = tea.KeyMsg{Type: tea.KeyEscape}
case '\x03':
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
default:
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{ch}}
}
m.Update(msg)
}
}
func TestBankingShellName(t *testing.T) {
sh := NewBankingShell()
if sh.Name() != "banking" {
t.Errorf("Name() = %q, want %q", sh.Name(), "banking")
}
if sh.Description() == "" {
t.Error("Description() should not be empty")
}
}
func TestFormatCurrency(t *testing.T) {
tests := []struct {
cents int64
want string
}{
{0, "$0.00"},
{100, "$1.00"},
{4738291, "$47,382.91"},
{18254100, "$182,541.00"},
{52387450, "$523,874.50"},
{25000000, "$250,000.00"},
{-125000, "-$1,250.00"},
{99, "$0.99"},
}
for _, tt := range tests {
got := formatCurrency(tt.cents)
if got != tt.want {
t.Errorf("formatCurrency(%d) = %q, want %q", tt.cents, got, tt.want)
}
}
}
func TestNewBankState(t *testing.T) {
state := newBankState()
if len(state.Accounts) != 4 {
t.Errorf("expected 4 accounts, got %d", len(state.Accounts))
}
for _, acct := range state.Accounts {
txns, ok := state.Transactions[acct.Number]
if !ok {
t.Errorf("no transactions for account %s", acct.Number)
continue
}
if len(txns) == 0 {
t.Errorf("account %s has no transactions", acct.Number)
}
}
if len(state.Messages) != 4 {
t.Errorf("expected 4 messages, got %d", len(state.Messages))
}
}
func TestLoginScreenRenders(t *testing.T) {
m, _ := newTestModel(t)
view := m.View()
if !strings.Contains(view, "SECUREBANK") {
t.Error("login should show bank name")
}
if !strings.Contains(view, "AUTHORIZED ACCESS ONLY") {
t.Error("login should show authorization warning")
}
if !strings.Contains(view, "ACCOUNT NUMBER") {
t.Error("login should prompt for account number")
}
}
func TestLoginFlow(t *testing.T) {
m, _ := newTestModel(t)
// Type account number.
sendKeys(m, "12345678")
view := m.View()
if !strings.Contains(view, "12345678") {
t.Error("should show typed account number")
}
// Press enter.
sendKeys(m, "\r")
view = m.View()
if !strings.Contains(view, "PIN") {
t.Error("should show PIN prompt after entering account number")
}
// Type PIN and enter.
sendKeys(m, "1234\r")
// Should be on menu now.
if m.screen != screenMenu {
t.Errorf("expected screenMenu, got %d", m.screen)
}
}
func TestMainMenuRenders(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r")
view := m.View()
if !strings.Contains(view, "MAIN MENU") {
t.Error("should show MAIN MENU after login")
}
if !strings.Contains(view, "WIRE TRANSFER") {
t.Error("menu should contain WIRE TRANSFER option")
}
if !strings.Contains(view, "SECURE MESSAGES") {
t.Error("menu should contain SECURE MESSAGES option")
}
if !strings.Contains(view, "ACCOUNT SUMMARY") {
t.Error("menu should contain ACCOUNT SUMMARY option")
}
}
func TestAccountSummary(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "1\r") // account summary
if m.screen != screenAccountSummary {
t.Fatalf("expected screenAccountSummary, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "ACCOUNT SUMMARY") {
t.Error("should show ACCOUNT SUMMARY")
}
if !strings.Contains(view, "CHECKING") {
t.Error("should show CHECKING account")
}
if !strings.Contains(view, "SAVINGS") {
t.Error("should show SAVINGS account")
}
if !strings.Contains(view, "TOTAL") {
t.Error("should show TOTAL")
}
// Press any key to return.
sendKeys(m, " ")
if m.screen != screenMenu {
t.Errorf("should return to menu, got screen %d", m.screen)
}
}
func TestWireTransferFlow(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "3\r") // wire transfer
if m.screen != screenTransfer {
t.Fatalf("expected screenTransfer, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "WIRE TRANSFER") {
t.Error("should show WIRE TRANSFER header")
}
// Fill all fields.
sendKeys(m, "021000021\r") // routing
sendKeys(m, "9876543210\r") // dest account
sendKeys(m, "JOHN DOE\r") // beneficiary
sendKeys(m, "FIRST NATIONAL BANK\r") // bank name
sendKeys(m, "50000\r") // amount
sendKeys(m, "INVOICE 12345\r") // memo
// Should be on confirm step.
view = m.View()
if !strings.Contains(view, "TRANSFER SUMMARY") {
t.Error("should show TRANSFER SUMMARY for confirmation")
}
if !strings.Contains(view, "021000021") {
t.Error("summary should show routing number")
}
// Confirm.
sendKeys(m, "Y\r")
// Auth code.
sendKeys(m, "AUTH99\r")
view = m.View()
if !strings.Contains(view, "TRANSFER QUEUED") {
t.Error("should show TRANSFER QUEUED confirmation")
}
// Check wire transfer was stored.
if len(m.state.Transfers) != 1 {
t.Fatalf("expected 1 transfer, got %d", len(m.state.Transfers))
}
wt := m.state.Transfers[0]
if wt.RoutingNumber != "021000021" {
t.Errorf("routing = %q, want %q", wt.RoutingNumber, "021000021")
}
if wt.DestAccount != "9876543210" {
t.Errorf("dest = %q, want %q", wt.DestAccount, "9876543210")
}
if wt.Beneficiary != "JOHN DOE" {
t.Errorf("beneficiary = %q, want %q", wt.Beneficiary, "JOHN DOE")
}
// Press key to return to menu.
sendKeys(m, " ")
if m.screen != screenMenu {
t.Errorf("should return to menu after transfer, got screen %d", m.screen)
}
}
func TestWireTransferCancel(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "3\r") // wire transfer
sendKeys(m, "021000021\r")
sendKeys(m, "9876543210\r")
sendKeys(m, "JOHN DOE\r")
sendKeys(m, "FIRST NATIONAL BANK\r")
sendKeys(m, "50000\r")
sendKeys(m, "INVOICE 12345\r")
// Cancel at confirm step.
sendKeys(m, "N\r")
if m.screen != screenMenu {
t.Errorf("should return to menu after cancel, got screen %d", m.screen)
}
}
func TestTransactionHistory(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "4\r") // transaction history
if m.screen != screenHistory {
t.Fatalf("expected screenHistory, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "TRANSACTION HISTORY") {
t.Error("should show TRANSACTION HISTORY header")
}
// Select first account.
sendKeys(m, "1\r")
view = m.View()
if !strings.Contains(view, "DATE") {
t.Error("should show transaction list with DATE column")
}
if !strings.Contains(view, "PAGE") {
t.Error("should show page indicator")
}
// Press B to go back to account list.
sendKeys(m, "B")
view = m.View()
if !strings.Contains(view, "SELECT ACCOUNT") {
t.Error("should return to account selection")
}
// Press 0 to return to menu.
sendKeys(m, "0\r")
if m.screen != screenMenu {
t.Errorf("should return to menu, got screen %d", m.screen)
}
}
func TestSecureMessages(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "5\r") // secure messages
if m.screen != screenMessages {
t.Fatalf("expected screenMessages, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "SECURE MESSAGES") {
t.Error("should show SECURE MESSAGES header")
}
if !strings.Contains(view, "SCHEDULED MAINTENANCE") {
t.Error("should show first message subject")
}
// View first message.
sendKeys(m, "1\r")
view = m.View()
if !strings.Contains(view, "10.48.2.100") {
t.Error("message body should contain breadcrumb IP")
}
// Press key to return to list.
sendKeys(m, " ")
// Return to menu.
sendKeys(m, "0\r")
if m.screen != screenMenu {
t.Errorf("should return to menu, got screen %d", m.screen)
}
}
func TestAdminAccessDenied(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "99\r") // admin
if m.screen != screenAdmin {
t.Fatalf("expected screenAdmin, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "SYSTEM ADMINISTRATION") {
t.Error("should show SYSTEM ADMINISTRATION header")
}
// Three failed PIN attempts.
sendKeys(m, "secret1\r")
view = m.View()
if !strings.Contains(view, "INVALID CREDENTIALS") {
t.Error("should show INVALID CREDENTIALS after first attempt")
}
sendKeys(m, "secret2\r")
sendKeys(m, "secret3\r")
view = m.View()
if !strings.Contains(view, "ACCESS DENIED") {
t.Error("should show ACCESS DENIED after 3 attempts")
}
if !strings.Contains(view, "ABEND S0C4") {
t.Error("should show COBOL-style error")
}
// Press key to return to menu.
sendKeys(m, " ")
if m.screen != screenMenu {
t.Errorf("should return to menu after lockout, got screen %d", m.screen)
}
}
func TestAdminEscapeReturns(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "99\r") // admin
sendKeys(m, "\x1b") // ESC to return
if m.screen != screenMenu {
t.Errorf("ESC should return to menu from admin, got screen %d", m.screen)
}
}
func TestChangePin(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "6\r") // change PIN
if m.screen != screenChangePin {
t.Fatalf("expected screenChangePin, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "CHANGE PIN") {
t.Error("should show CHANGE PIN header")
}
// Old PIN.
sendKeys(m, "1234\r")
// New PIN.
sendKeys(m, "5678\r")
// Confirm PIN.
sendKeys(m, "5678\r")
view = m.View()
if !strings.Contains(view, "PIN CHANGED SUCCESSFULLY") {
t.Error("should show PIN CHANGED SUCCESSFULLY")
}
// Press key to return.
sendKeys(m, " ")
if m.screen != screenMenu {
t.Errorf("should return to menu after PIN change, got screen %d", m.screen)
}
}
func TestLogoutExits(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "7\r") // logout
if !m.quitting {
t.Error("should be quitting after logout")
}
}
func TestSessionLogs(t *testing.T) {
m, store := newTestModel(t)
// Send login keys and manually run returned commands.
// Type account number.
sendKeys(m, "12345678")
// Enter to advance to PIN.
sendKeys(m, "\r")
// Type PIN.
sendKeys(m, "1234")
// Enter to login — this returns a logAction cmd.
_, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
if cmd != nil {
// Execute the batch of commands (login log).
execCmds(cmd)
}
// Give async store writes a moment.
time.Sleep(50 * time.Millisecond)
if len(store.SessionLogs) == 0 {
t.Error("expected session logs to be recorded after login")
}
found := false
for _, log := range store.SessionLogs {
if strings.Contains(log.Input, "LOGIN") {
found = true
break
}
}
if !found {
t.Error("expected a LOGIN entry in session logs")
}
// Navigate to account summary — also returns a logAction cmd.
sendKeys(m, "1")
_, cmd = m.Update(tea.KeyMsg{Type: tea.KeyEnter})
if cmd != nil {
execCmds(cmd)
}
time.Sleep(50 * time.Millisecond)
foundMenu := false
for _, log := range store.SessionLogs {
if strings.Contains(log.Input, "MENU") {
foundMenu = true
break
}
}
if !foundMenu {
t.Error("expected a MENU entry in session logs for account summary")
}
}
// execCmds recursively executes tea.Cmd functions (including batches).
func execCmds(cmd tea.Cmd) {
if cmd == nil {
return
}
msg := cmd()
// tea.BatchMsg is a slice of Cmds returned by tea.Batch.
if batch, ok := msg.(tea.BatchMsg); ok {
for _, c := range batch {
execCmds(c)
}
}
}
func TestConfigString(t *testing.T) {
cfg := map[string]any{
"bank_name": "TESTBANK",
"region": "SOUTHWEST",
}
if got := configString(cfg, "bank_name", "DEFAULT"); got != "TESTBANK" {
t.Errorf("configString() = %q, want %q", got, "TESTBANK")
}
if got := configString(cfg, "missing", "DEFAULT"); got != "DEFAULT" {
t.Errorf("configString() = %q, want %q", got, "DEFAULT")
}
if got := configString(nil, "bank_name", "DEFAULT"); got != "DEFAULT" {
t.Errorf("configString(nil) = %q, want %q", got, "DEFAULT")
}
}
func TestScreenFrame(t *testing.T) {
frame := screenFrame("TESTBANK", "TB-0001", "NORTHEAST", "content here", 0)
if !strings.Contains(frame, "TESTBANK FEDERAL RESERVE SYSTEM") {
t.Error("frame should contain bank name in header")
}
if !strings.Contains(frame, "TB-0001") {
t.Error("frame should contain terminal ID in footer")
}
if !strings.Contains(frame, "content here") {
t.Error("frame should contain the content")
}
}
func TestScreenFramePadsLines(t *testing.T) {
frame := screenFrame("TESTBANK", "TB-0001", "NE", "short\n", 0)
for i, line := range strings.Split(frame, "\n") {
w := lipgloss.Width(line)
if w > 0 && w < termWidth {
t.Errorf("line %d has visual width %d, want at least %d: %q", i, w, termWidth, line)
}
}
}
func TestScreenFramePadsToHeight(t *testing.T) {
short := screenFrame("TESTBANK", "TB-0001", "NE", "line1\nline2\n", 30)
lines := strings.Count(short, "\n")
// Total newlines should be at least height-1 (since the last line has no trailing newline).
if lines < 29 {
t.Errorf("padded frame has %d newlines, want at least 29 for height=30", lines)
}
// Without height, no padding.
noPad := screenFrame("TESTBANK", "TB-0001", "NE", "line1\nline2\n", 0)
noPadLines := strings.Count(noPad, "\n")
if noPadLines >= 29 {
t.Errorf("unpadded frame has %d newlines, should be much less than 29", noPadLines)
}
}

View File

@@ -0,0 +1,258 @@
package banking
import (
"fmt"
"time"
)
// Account types.
const (
AcctChecking = "CHECKING"
AcctSavings = "SAVINGS"
AcctMoneyMarket = "MONEY MARKET"
AcctCertDeposit = "CERT OF DEPOSIT"
)
// Account represents a fake bank account.
type Account struct {
Number string
Type string
Balance int64 // cents
}
// Transaction represents a fake bank transaction.
type Transaction struct {
Date string
Description string
Amount int64 // cents (negative for debits)
Balance int64 // running balance in cents
}
// SecureMessage represents a fake internal message.
type SecureMessage struct {
ID int
Date string
From string
Subj string
Body string
Unread bool
}
// WireTransfer captures data entered during the wire transfer wizard.
type WireTransfer struct {
RoutingNumber string
DestAccount string
Beneficiary string
BankName string
Amount string
Memo string
AuthCode string
}
// bankState holds all fake data for a session.
type bankState struct {
Accounts []Account
Transactions map[string][]Transaction // keyed by account number
Messages []SecureMessage
Transfers []WireTransfer
}
func newBankState() *bankState {
now := time.Now()
accounts := []Account{
{Number: "****4821", Type: AcctChecking, Balance: 4738291},
{Number: "****7203", Type: AcctSavings, Balance: 18254100},
{Number: "****9915", Type: AcctMoneyMarket, Balance: 52387450},
{Number: "****1102", Type: AcctCertDeposit, Balance: 25000000},
}
transactions := make(map[string][]Transaction)
transactions["****4821"] = generateCheckingTxns(now, accounts[0].Balance)
transactions["****7203"] = generateSavingsTxns(now, accounts[1].Balance)
transactions["****9915"] = generateMoneyMarketTxns(now, accounts[2].Balance)
transactions["****1102"] = generateCDTxns(now, accounts[3].Balance)
messages := []SecureMessage{
{
ID: 1,
Date: now.Add(-2 * 24 * time.Hour).Format("01/02/2006"),
From: "SYSTEM ADMINISTRATOR",
Subj: "SCHEDULED MAINTENANCE WINDOW",
Unread: true,
Body: fmt.Sprintf(`FROM: SYSTEM ADMINISTRATOR <sysadmin@internal.securebank.local>
DATE: %s
RE: SCHEDULED MAINTENANCE WINDOW
ALL TERMINALS WILL BE OFFLINE FOR MAINTENANCE:
DATE: %s
TIME: 02:00 - 04:00 EST
AFFECTED: ALL REGIONS
DURING THIS WINDOW, THE FOLLOWING SYSTEMS WILL BE UNAVAILABLE:
- WIRE TRANSFER PROCESSING (10.48.2.100:8443)
- ACCOUNT MANAGEMENT (10.48.2.101:8443)
- ACH BATCH PROCESSOR (10.48.2.105:9090)
PLEASE ENSURE ALL PENDING TRANSACTIONS ARE SUBMITTED BEFORE 01:30 EST.
CONTACT: HELPDESK EXT 4400 OR ops-support@internal.securebank.local`,
now.Add(-2*24*time.Hour).Format("01/02/2006 15:04"),
now.Add(5*24*time.Hour).Format("01/02/2006")),
},
{
ID: 2,
Date: now.Add(-5 * 24 * time.Hour).Format("01/02/2006"),
From: "COMPLIANCE DEPT",
Subj: "QUARTERLY AUDIT REMINDER",
Unread: true,
Body: `FROM: COMPLIANCE DEPT <compliance@internal.securebank.local>
RE: QUARTERLY AUDIT REMINDER
ALL BRANCH MANAGERS:
THE Q4 COMPLIANCE AUDIT IS SCHEDULED FOR NEXT WEEK.
PLEASE ENSURE THE FOLLOWING ARE CURRENT:
1. TRANSACTION LOGS EXPORTED TO \\FILESERV01\AUDIT\Q4
2. VAULT ACCESS CODES ROTATED (LAST ROTATION: SEE VAULT-MGMT PORTAL)
3. EMPLOYEE ACCESS REVIEWS COMPLETED IN IAM PORTAL (https://iam.internal:8443)
NOTE: DEFAULT CREDENTIALS FOR THE AUDIT PORTAL HAVE BEEN RESET.
NEW CREDENTIALS DISTRIBUTED VIA SECURE COURIER.
REFERENCE: AUDIT-2024-Q4-0847
VAULT MASTER CODE HINT: FIRST 4 OF ROUTING + BRANCH ZIP (STANDARD FORMAT)`,
},
{
ID: 3,
Date: now.Add(-8 * 24 * time.Hour).Format("01/02/2006"),
From: "IT SECURITY",
Subj: "PASSWORD POLICY UPDATE",
Unread: false,
Body: `FROM: IT SECURITY <itsec@internal.securebank.local>
RE: PASSWORD POLICY UPDATE - EFFECTIVE IMMEDIATELY
ALL STAFF:
PER FEDERAL BANKING REGULATION 12 CFR 748, THE FOLLOWING
PASSWORD POLICY IS NOW IN EFFECT:
- MINIMUM 12 CHARACTERS
- MUST CONTAIN UPPERCASE, LOWERCASE, NUMBER, SPECIAL CHAR
- 90-DAY ROTATION CYCLE
- NO REUSE OF LAST 24 PASSWORDS
LEGACY SYSTEM ACCOUNTS (MAINFRAME, AS/400) ARE EXEMPT UNTIL
MIGRATION IS COMPLETE. CURRENT LEGACY ACCESS:
MAINFRAME: telnet://10.48.1.50:23 (CICS REGION PROD1)
AS/400: tn5250://10.48.1.55 (SUBSYSTEM QINTER)
SERVICE ACCOUNT PASSWORDS ARE MANAGED VIA CYBERARK:
https://pam.internal.securebank.local:8443
TICKET: SEC-2024-1847`,
},
{
ID: 4,
Date: now.Add(-12 * 24 * time.Hour).Format("01/02/2006"),
From: "WIRE OPERATIONS",
Subj: "FEDWIRE CUTOFF TIME CHANGE",
Unread: false,
Body: `FROM: WIRE OPERATIONS <wireops@internal.securebank.local>
RE: FEDWIRE CUTOFF TIME CHANGE
EFFECTIVE NEXT MONDAY, FEDWIRE CUTOFF TIMES ARE:
DOMESTIC WIRES: 16:30 EST (WAS 17:00)
INTERNATIONAL WIRES: 14:00 EST (NO CHANGE)
BOOK TRANSFERS: 17:30 EST (NO CHANGE)
WIRES SUBMITTED AFTER CUTOFF WILL BE QUEUED FOR NEXT
BUSINESS DAY PROCESSING.
FOR EMERGENCY SAME-DAY PROCESSING AFTER CUTOFF:
CONTACT WIRE ROOM: EXT 4450
AUTH CODE REQUIRED (OBTAIN FROM BRANCH MANAGER)
APPROVAL CHAIN: OPS-MGR -> VP-WIRE -> SVP-TREASURY
CORRESPONDENT BANK CONTACTS:
JPMORGAN: wire.ops@jpmc.com / 212-555-0147
CITI: fedwire@citi.com / 212-555-0283`,
},
}
return &bankState{
Accounts: accounts,
Transactions: transactions,
Messages: messages,
}
}
func generateCheckingTxns(now time.Time, endBalance int64) []Transaction {
txns := []Transaction{
{Description: "ACH DEPOSIT - PAYROLL", Amount: 485000},
{Description: "CHECK #1847", Amount: -125000},
{Description: "POS DEBIT - WHOLE FOODS #1284", Amount: -18743},
{Description: "ATM WITHDRAWAL - MAIN ST BRANCH", Amount: -40000},
{Description: "ACH DEBIT - MORTGAGE PMT", Amount: -215000},
{Description: "WIRE TRANSFER IN - REF#8847201", Amount: 1250000},
{Description: "POS DEBIT - SHELL OIL #4492", Amount: -6821},
{Description: "ACH DEPOSIT - PAYROLL", Amount: 485000},
{Description: "CHECK #1848", Amount: -75000},
{Description: "ONLINE TRANSFER TO SAVINGS", Amount: -100000},
{Description: "POS DEBIT - AMAZON.COM", Amount: -14599},
{Description: "ACH DEBIT - ELECTRIC COMPANY", Amount: -18742},
{Description: "ATM WITHDRAWAL - PARK AVE BRANCH", Amount: -20000},
{Description: "WIRE TRANSFER OUT - REF#9014882", Amount: -500000},
{Description: "POS DEBIT - COSTCO #0441", Amount: -28734},
{Description: "ACH DEPOSIT - TAX REFUND", Amount: 342100},
}
return populateTransactions(txns, now, endBalance)
}
func generateSavingsTxns(now time.Time, endBalance int64) []Transaction {
txns := []Transaction{
{Description: "INTEREST PAYMENT", Amount: 4521},
{Description: "ONLINE TRANSFER FROM CHECKING", Amount: 100000},
{Description: "INTEREST PAYMENT", Amount: 4633},
{Description: "ACH DEPOSIT - DIVIDEND PMT", Amount: 125000},
{Description: "ONLINE TRANSFER FROM CHECKING", Amount: 200000},
{Description: "INTEREST PAYMENT", Amount: 4748},
{Description: "WITHDRAWAL - TRANSFER TO MM", Amount: -500000},
{Description: "INTEREST PAYMENT", Amount: 4812},
}
return populateTransactions(txns, now, endBalance)
}
func generateMoneyMarketTxns(now time.Time, endBalance int64) []Transaction {
txns := []Transaction{
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 21847},
{Description: "DEPOSIT - TRANSFER FROM SAVINGS", Amount: 500000},
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 22105},
{Description: "WITHDRAWAL - WIRE TRANSFER", Amount: -1000000},
{Description: "DEPOSIT - ACH TRANSFER", Amount: 750000},
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 22394},
}
return populateTransactions(txns, now, endBalance)
}
func generateCDTxns(now time.Time, endBalance int64) []Transaction {
txns := []Transaction{
{Description: "CERTIFICATE OPENED - 12MO TERM", Amount: 25000000},
{Description: "INTEREST ACCRUAL", Amount: 10417},
{Description: "INTEREST ACCRUAL", Amount: 10417},
{Description: "INTEREST ACCRUAL", Amount: 10417},
}
return populateTransactions(txns, now, endBalance)
}
func populateTransactions(txns []Transaction, now time.Time, endBalance int64) []Transaction {
// Work backwards from end balance to assign dates and running balances.
bal := endBalance
for i := len(txns) - 1; i >= 0; i-- {
txns[i].Balance = bal
txns[i].Date = now.Add(time.Duration(-(len(txns) - i)) * 3 * 24 * time.Hour).Format("01/02/2006")
bal -= txns[i].Amount
}
return txns
}

View File

@@ -0,0 +1,353 @@
package banking
import (
"context"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
type screen int
const (
screenLogin screen = iota
screenMenu
screenAccountSummary
screenAccountDetail
screenTransfer
screenHistory
screenMessages
screenChangePin
screenAdmin
)
type model struct {
sess *shell.SessionContext
bankName string
terminalID string
region string
state *bankState
screen screen
quitting bool
height int
login loginModel
menu menuModel
summary accountSummaryModel
detail accountDetailModel
transfer transferModel
history historyModel
messages messagesModel
admin adminModel
changePin changePinModel
}
func newModel(sess *shell.SessionContext, bankName, terminalID, region string) *model {
state := newBankState()
unread := 0
for _, msg := range state.Messages {
if msg.Unread {
unread++
}
}
return &model{
sess: sess,
bankName: bankName,
terminalID: terminalID,
region: region,
state: state,
screen: screenLogin,
login: newLoginModel(bankName),
}
}
func (m *model) Init() tea.Cmd {
return nil
}
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.quitting {
return m, tea.Quit
}
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.height = msg.Height
return m, nil
case tea.KeyMsg:
if msg.Type == tea.KeyCtrlC {
m.quitting = true
return m, tea.Quit
}
}
switch m.screen {
case screenLogin:
return m.updateLogin(msg)
case screenMenu:
return m.updateMenu(msg)
case screenAccountSummary:
return m.updateAccountSummary(msg)
case screenAccountDetail:
return m.updateAccountDetail(msg)
case screenTransfer:
return m.updateTransfer(msg)
case screenHistory:
return m.updateHistory(msg)
case screenMessages:
return m.updateMessages(msg)
case screenChangePin:
return m.updateChangePin(msg)
case screenAdmin:
return m.updateAdmin(msg)
}
return m, nil
}
func (m *model) View() string {
var content string
switch m.screen {
case screenLogin:
content = m.login.View()
case screenMenu:
content = m.menu.View()
case screenAccountSummary:
content = m.summary.View()
case screenAccountDetail:
content = m.detail.View()
case screenTransfer:
content = m.transfer.View()
case screenHistory:
content = m.history.View()
case screenMessages:
content = m.messages.View()
case screenChangePin:
content = m.changePin.View()
case screenAdmin:
content = m.admin.View()
}
return screenFrame(m.bankName, m.terminalID, m.region, content, m.height)
}
// --- Screen update handlers ---
func (m *model) updateLogin(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.login, cmd = m.login.Update(msg)
if m.login.stage == 2 {
// Login always succeeds — this is a honeypot.
logCmd := logAction(m.sess, fmt.Sprintf("LOGIN acct=%s", m.login.accountNum), "ACCESS GRANTED")
clearCmd := m.goToMenu()
return m, tea.Batch(cmd, logCmd, clearCmd)
}
return m, cmd
}
func (m *model) updateMenu(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.menu, cmd = m.menu.Update(msg)
if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.Type == tea.KeyEnter {
choice := strings.TrimSpace(m.menu.choice)
switch choice {
case "1":
m.screen = screenAccountSummary
m.summary = newAccountSummaryModel(m.state.Accounts)
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 1", "ACCOUNT SUMMARY"))
case "2":
m.screen = screenAccountDetail
m.detail = newAccountDetailModel(m.state.Accounts, m.state.Transactions)
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 2", "ACCOUNT DETAIL"))
case "3":
m.screen = screenTransfer
m.transfer = newTransferModel()
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 3", "WIRE TRANSFER"))
case "4":
m.screen = screenHistory
m.history = newHistoryModel(m.state.Accounts, m.state.Transactions)
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 4", "TRANSACTION HISTORY"))
case "5":
m.screen = screenMessages
m.messages = newMessagesModel(m.state.Messages)
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 5", "SECURE MESSAGES"))
case "6":
m.screen = screenChangePin
m.changePin = newChangePinModel()
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 6", "CHANGE PIN"))
case "7":
m.quitting = true
return m, tea.Batch(logAction(m.sess, "LOGOUT", "SESSION ENDED"), tea.Quit)
case "99", "admin", "ADMIN":
m.screen = screenAdmin
m.admin = newAdminModel()
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "ADMIN ACCESS ATTEMPT", "ADMIN SCREEN SHOWN"))
}
// Invalid choice, reset.
m.menu.choice = ""
}
return m, cmd
}
func (m *model) updateAccountSummary(msg tea.Msg) (tea.Model, tea.Cmd) {
if _, ok := msg.(tea.KeyMsg); ok {
return m, m.goToMenu()
}
return m, nil
}
func (m *model) updateAccountDetail(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.detail, cmd = m.detail.Update(msg)
if m.detail.choice == "back" {
return m, tea.Batch(cmd, m.goToMenu())
}
return m, cmd
}
func (m *model) updateTransfer(msg tea.Msg) (tea.Model, tea.Cmd) {
prevStep := m.transfer.step
var cmd tea.Cmd
m.transfer, cmd = m.transfer.Update(msg)
// Transfer cancelled.
if m.transfer.confirm == "cancelled" {
clearCmd := m.goToMenu()
return m, tea.Batch(clearCmd, logAction(m.sess, "WIRE TRANSFER CANCELLED", "USER CANCELLED"))
}
// Clear screen when transfer steps change content height significantly
// (e.g. confirm→authcode, fields→confirm, authcode→complete).
if m.transfer.step != prevStep {
cmd = tea.Batch(cmd, tea.ClearScreen)
}
// Transfer completed — log it.
if m.transfer.step == transferStepComplete && prevStep != transferStepComplete {
t := m.transfer.transfer
m.state.Transfers = append(m.state.Transfers, t)
logMsg := fmt.Sprintf("WIRE TRANSFER: routing=%s dest=%s beneficiary=%s bank=%s amount=%s memo=%s auth=%s",
t.RoutingNumber, t.DestAccount, t.Beneficiary, t.BankName, t.Amount, t.Memo, t.AuthCode)
return m, tea.Batch(cmd, logAction(m.sess, logMsg, "TRANSFER QUEUED"))
}
// Completed screen → any key goes back.
if m.transfer.step == transferStepComplete {
if _, ok := msg.(tea.KeyMsg); ok && prevStep == transferStepComplete {
return m, tea.Batch(cmd, m.goToMenu())
}
}
return m, cmd
}
func (m *model) updateHistory(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.history, cmd = m.history.Update(msg)
if m.history.choice == "back" {
return m, tea.Batch(cmd, m.goToMenu())
}
return m, cmd
}
func (m *model) updateMessages(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.messages, cmd = m.messages.Update(msg)
if m.messages.choice == "back" {
return m, tea.Batch(cmd, m.goToMenu())
}
// Log when viewing a message.
if m.messages.viewing >= 0 {
idx := m.messages.viewing
return m, tea.Batch(cmd, logAction(m.sess,
fmt.Sprintf("VIEW MESSAGE #%d", idx+1),
m.state.Messages[idx].Subj))
}
return m, cmd
}
func (m *model) updateChangePin(msg tea.Msg) (tea.Model, tea.Cmd) {
prevStage := m.changePin.stage
var cmd tea.Cmd
m.changePin, cmd = m.changePin.Update(msg)
// Log successful PIN change.
if m.changePin.stage == 3 && prevStage != 3 {
cmd = tea.Batch(cmd, logAction(m.sess, "CHANGE PIN", "PIN CHANGED SUCCESSFULLY"))
}
if m.changePin.done {
return m, tea.Batch(cmd, m.goToMenu())
}
return m, cmd
}
func (m *model) updateAdmin(msg tea.Msg) (tea.Model, tea.Cmd) {
prevLocked := m.admin.locked
prevAttempts := m.admin.attempts
// Check for ESC before delegating.
if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.Type == tea.KeyEscape && !m.admin.locked {
return m, m.goToMenu()
}
var cmd tea.Cmd
m.admin, cmd = m.admin.Update(msg)
// Log failed attempts.
if m.admin.attempts > prevAttempts {
cmd = tea.Batch(cmd, logAction(m.sess,
fmt.Sprintf("ADMIN PIN ATTEMPT #%d", m.admin.attempts),
"INVALID CREDENTIALS"))
}
// Log lockout.
if m.admin.locked && !prevLocked {
cmd = tea.Batch(cmd, tea.ClearScreen, logAction(m.sess,
"ADMIN LOCKOUT",
"TERMINAL LOCKED - INCIDENT LOGGED"))
}
// If locked and any key pressed, go back.
if m.admin.locked {
if _, ok := msg.(tea.KeyMsg); ok && prevLocked {
return m, tea.Batch(cmd, m.goToMenu())
}
}
return m, cmd
}
func (m *model) goToMenu() tea.Cmd {
unread := 0
for _, msg := range m.state.Messages {
if msg.Unread {
unread++
}
}
m.screen = screenMenu
m.menu = newMenuModel(m.bankName, unread)
return tea.ClearScreen
}
// logAction returns a tea.Cmd that logs an action to the session store.
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
return func() tea.Msg {
if sess.Store != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
}
if sess.OnCommand != nil {
sess.OnCommand("banking")
}
return nil
}
}

View File

@@ -0,0 +1,213 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
// --- Account Summary ---
type accountSummaryModel struct {
accounts []Account
}
func newAccountSummaryModel(accounts []Account) accountSummaryModel {
return accountSummaryModel{accounts: accounts}
}
func (m accountSummaryModel) Update(_ tea.Msg) (accountSummaryModel, tea.Cmd) {
// Any key returns to menu.
return m, nil
}
func (m accountSummaryModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("ACCOUNT SUMMARY"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
// Header.
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-18s %18s", "ACCOUNT", "TYPE", "BALANCE")))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 50)))
b.WriteString("\n")
total := int64(0)
for _, acct := range m.accounts {
total += acct.Balance
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-18s %18s",
acct.Number, acct.Type, formatCurrency(acct.Balance))))
b.WriteString("\n")
}
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 50)))
b.WriteString("\n")
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-18s %18s", "", "TOTAL", formatCurrency(total))))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
b.WriteString("\n")
return b.String()
}
// --- Account Detail (with transactions) ---
type accountDetailModel struct {
accounts []Account
transactions map[string][]Transaction
selected int
page int
pageSize int
choosing bool
choice string
}
func newAccountDetailModel(accounts []Account, transactions map[string][]Transaction) accountDetailModel {
return accountDetailModel{
accounts: accounts,
transactions: transactions,
choosing: true,
pageSize: 10,
}
}
func (m accountDetailModel) Update(msg tea.Msg) (accountDetailModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.choosing {
switch keyMsg.Type {
case tea.KeyEnter:
for i := range m.accounts {
if m.choice == fmt.Sprintf("%d", i+1) {
m.selected = i
m.choosing = false
m.page = 0
break
}
}
if m.choice == "0" {
m.choice = "back"
}
return m, nil
case tea.KeyBackspace:
if len(m.choice) > 0 {
m.choice = m.choice[:len(m.choice)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' {
m.choice += ch
}
}
} else {
switch keyMsg.String() {
case "n", "N":
acctNum := m.accounts[m.selected].Number
txns := m.transactions[acctNum]
maxPage := (len(txns) - 1) / m.pageSize
if m.page < maxPage {
m.page++
}
case "p", "P":
if m.page > 0 {
m.page--
}
case "b", "B":
m.choosing = true
m.choice = ""
default:
m.choice = "back"
}
}
return m, nil
}
func (m accountDetailModel) View() string {
if m.choosing {
return m.viewChooseAccount()
}
return m.viewDetail()
}
func (m accountDetailModel) viewChooseAccount() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("ACCOUNT DETAIL"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" SELECT ACCOUNT:"))
b.WriteString("\n\n")
for i, acct := range m.accounts {
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d] %s - %s %s",
i+1, acct.Number, acct.Type, formatCurrency(acct.Balance))))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
b.WriteString(inputStyle.Render(m.choice))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
return b.String()
}
func (m accountDetailModel) viewDetail() string {
var b strings.Builder
acct := m.accounts[m.selected]
txns := m.transactions[acct.Number]
b.WriteString("\n")
b.WriteString(centerText(fmt.Sprintf("ACCOUNT DETAIL - %s", acct.Number)))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" TYPE: %s", acct.Type)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" BALANCE: %s", formatCurrency(acct.Balance))))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
// Header.
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-34s %12s %12s", "DATE", "DESCRIPTION", "AMOUNT", "BALANCE")))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 72)))
b.WriteString("\n")
// Paginate.
start := m.page * m.pageSize
end := min(start+m.pageSize, len(txns))
for _, txn := range txns[start:end] {
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-34s %12s %12s",
txn.Date, txn.Description, formatCurrency(txn.Amount), formatCurrency(txn.Balance))))
b.WriteString("\n")
}
b.WriteString("\n")
totalPages := (len(txns) + m.pageSize - 1) / m.pageSize
b.WriteString(dimStyle.Render(fmt.Sprintf(" PAGE %d OF %d", m.page+1, totalPages)))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" [N]EXT PAGE [P]REV PAGE [B]ACK ANY OTHER KEY = MAIN MENU"))
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,111 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type adminModel struct {
pin string
attempts int
locked bool
}
func newAdminModel() adminModel {
return adminModel{}
}
func (m adminModel) Update(msg tea.Msg) (adminModel, tea.Cmd) {
if m.locked {
return m, nil
}
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
if m.pin != "" {
m.attempts++
if m.attempts >= 3 {
m.locked = true
}
m.pin = ""
}
case tea.KeyBackspace:
if len(m.pin) > 0 {
m.pin = m.pin[:len(m.pin)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.pin) < 20 {
m.pin += ch
}
}
return m, nil
}
func (m adminModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("SYSTEM ADMINISTRATION"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
if m.locked {
b.WriteString(errorStyle.Render(" *** ACCESS DENIED ***"))
b.WriteString("\n\n")
b.WriteString(errorStyle.Render(" MAXIMUM AUTHENTICATION ATTEMPTS EXCEEDED"))
b.WriteString("\n")
b.WriteString(errorStyle.Render(" TERMINAL LOCKED - INCIDENT LOGGED"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" SECURITY ALERT HAS BEEN DISPATCHED TO:"))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" - INFORMATION SECURITY DEPT"))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" - BRANCH SECURITY OFFICER"))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" - FEDERAL RESERVE OVERSIGHT"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" INCIDENT REF: SEC-%d-ADMIN-BRUTE", 20240000+m.attempts)))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" IEF4271I UNAUTHORIZED ACCESS ATTEMPT - ABEND S0C4"))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" IEF4272I JOB SECADMIN STEP0001 - COND CODE 4088"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
b.WriteString("\n")
} else {
b.WriteString(titleStyle.Render(" RESTRICTED ACCESS - ADMINISTRATOR ONLY"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" THIS FUNCTION REQUIRES LEVEL 5 SECURITY CLEARANCE."))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" ALL ACCESS ATTEMPTS ARE LOGGED AND AUDITED."))
b.WriteString("\n\n")
if m.attempts > 0 {
b.WriteString(errorStyle.Render(fmt.Sprintf(" INVALID CREDENTIALS (%d OF 3 ATTEMPTS)", m.attempts)))
b.WriteString("\n\n")
}
b.WriteString(titleStyle.Render(" ADMIN PIN: "))
masked := strings.Repeat("*", len(m.pin))
b.WriteString(inputStyle.Render(masked))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
b.WriteString("\n")
}
return b.String()
}

View File

@@ -0,0 +1,111 @@
package banking
import (
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type changePinModel struct {
input string
stage int // 0=old, 1=new, 2=confirm, 3=done
newPin string
done bool
}
func newChangePinModel() changePinModel {
return changePinModel{}
}
func (m changePinModel) Update(msg tea.Msg) (changePinModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.stage == 3 {
m.done = true
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
switch m.stage {
case 0:
if m.input != "" {
m.stage = 1
m.input = ""
}
case 1:
if len(m.input) >= 4 {
m.newPin = m.input
m.stage = 2
m.input = ""
}
case 2:
if m.input == m.newPin {
m.stage = 3
} else {
m.input = ""
m.newPin = ""
m.stage = 1
}
}
case tea.KeyEscape:
m.done = true
case tea.KeyBackspace:
if len(m.input) > 0 {
m.input = m.input[:len(m.input)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.input) < 12 {
m.input += ch
}
}
return m, nil
}
func (m changePinModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("CHANGE PIN"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
if m.stage == 3 {
b.WriteString(titleStyle.Render(" PIN CHANGED SUCCESSFULLY"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" YOUR NEW PIN IS NOW ACTIVE."))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" PLEASE USE YOUR NEW PIN FOR ALL FUTURE TRANSACTIONS."))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
} else {
prompts := []string{" CURRENT PIN: ", " NEW PIN: ", " CONFIRM PIN: "}
for i := 0; i < m.stage; i++ {
b.WriteString(baseStyle.Render(prompts[i]))
b.WriteString(baseStyle.Render(strings.Repeat("*", 4)))
b.WriteString("\n")
}
if m.stage < 3 {
b.WriteString(titleStyle.Render(prompts[m.stage]))
masked := strings.Repeat("*", len(m.input))
b.WriteString(inputStyle.Render(masked))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
}
b.WriteString("\n")
if m.stage == 1 {
b.WriteString(dimStyle.Render(" PIN MUST BE AT LEAST 4 CHARACTERS"))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
}
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,152 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type historyModel struct {
accounts []Account
transactions map[string][]Transaction
selected int
page int
pageSize int
choosing bool
choice string
}
func newHistoryModel(accounts []Account, transactions map[string][]Transaction) historyModel {
return historyModel{
accounts: accounts,
transactions: transactions,
choosing: true,
pageSize: 12,
}
}
func (m historyModel) Update(msg tea.Msg) (historyModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.choosing {
switch keyMsg.Type {
case tea.KeyEnter:
for i := range m.accounts {
if m.choice == fmt.Sprintf("%d", i+1) {
m.selected = i
m.choosing = false
m.page = 0
break
}
}
if m.choice == "0" {
m.choice = "back"
}
return m, nil
case tea.KeyBackspace:
if len(m.choice) > 0 {
m.choice = m.choice[:len(m.choice)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' {
m.choice += ch
}
}
} else {
switch keyMsg.String() {
case "n", "N":
acctNum := m.accounts[m.selected].Number
txns := m.transactions[acctNum]
maxPage := (len(txns) - 1) / m.pageSize
if m.page < maxPage {
m.page++
}
case "p", "P":
if m.page > 0 {
m.page--
}
case "b", "B":
m.choosing = true
m.choice = ""
default:
m.choice = "back"
}
}
return m, nil
}
func (m historyModel) View() string {
if m.choosing {
return m.viewChooseAccount()
}
return m.viewHistory()
}
func (m historyModel) viewChooseAccount() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("TRANSACTION HISTORY"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" SELECT ACCOUNT:"))
b.WriteString("\n\n")
for i, acct := range m.accounts {
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d] %s - %s",
i+1, acct.Number, acct.Type)))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
b.WriteString(inputStyle.Render(m.choice))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
return b.String()
}
func (m historyModel) viewHistory() string {
var b strings.Builder
acct := m.accounts[m.selected]
txns := m.transactions[acct.Number]
b.WriteString("\n")
b.WriteString(centerText(fmt.Sprintf("TRANSACTION HISTORY - %s (%s)", acct.Number, acct.Type)))
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-40s %14s", "DATE", "DESCRIPTION", "AMOUNT")))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 68)))
b.WriteString("\n")
start := m.page * m.pageSize
end := min(start+m.pageSize, len(txns))
for _, txn := range txns[start:end] {
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-40s %14s",
txn.Date, txn.Description, formatCurrency(txn.Amount))))
b.WriteString("\n")
}
b.WriteString("\n")
totalPages := (len(txns) + m.pageSize - 1) / m.pageSize
b.WriteString(dimStyle.Render(fmt.Sprintf(" PAGE %d OF %d", m.page+1, totalPages)))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" [N]EXT PAGE [P]REV PAGE [B]ACK ANY OTHER KEY = MAIN MENU"))
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,103 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type loginModel struct {
accountNum string
pin string
stage int // 0 = account, 1 = pin, 2 = authenticating
bankName string
}
func newLoginModel(bankName string) loginModel {
return loginModel{bankName: bankName}
}
func (m loginModel) Update(msg tea.Msg) (loginModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
if m.stage == 0 && m.accountNum != "" {
m.stage = 1
return m, nil
}
if m.stage == 1 && m.pin != "" {
m.stage = 2
return m, nil
}
case tea.KeyBackspace:
if m.stage == 0 && len(m.accountNum) > 0 {
m.accountNum = m.accountNum[:len(m.accountNum)-1]
} else if m.stage == 1 && len(m.pin) > 0 {
m.pin = m.pin[:len(m.pin)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 {
if m.stage == 0 && len(m.accountNum) < 20 {
m.accountNum += ch
} else if m.stage == 1 && len(m.pin) < 12 {
m.pin += ch
}
}
}
return m, nil
}
func (m loginModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText(m.bankName + " ONLINE BANKING"))
b.WriteString("\n")
b.WriteString(centerText("AUTHORIZED ACCESS ONLY"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" ACCOUNT NUMBER: "))
if m.stage == 0 {
b.WriteString(inputStyle.Render(m.accountNum))
b.WriteString(inputStyle.Render("_"))
} else {
b.WriteString(baseStyle.Render(m.accountNum))
}
b.WriteString("\n\n")
if m.stage >= 1 {
b.WriteString(titleStyle.Render(" PIN: "))
masked := strings.Repeat("*", len(m.pin))
if m.stage == 1 {
b.WriteString(inputStyle.Render(masked))
b.WriteString(inputStyle.Render("_"))
} else {
b.WriteString(baseStyle.Render(masked))
}
b.WriteString("\n")
}
if m.stage == 2 {
b.WriteString("\n")
b.WriteString(baseStyle.Render(" AUTHENTICATING..."))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" WARNING: UNAUTHORIZED ACCESS TO THIS SYSTEM IS A FEDERAL CRIME"))
b.WriteString("\n")
b.WriteString(dimStyle.Render(fmt.Sprintf(" UNDER 18 U.S.C. %s 1030. ALL ACTIVITY IS MONITORED AND LOGGED.", "\u00A7")))
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,80 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type menuModel struct {
choice string
unread int
bankName string
}
func newMenuModel(bankName string, unreadCount int) menuModel {
return menuModel{bankName: bankName, unread: unreadCount}
}
func (m menuModel) Update(msg tea.Msg) (menuModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
return m, nil
case tea.KeyBackspace:
if len(m.choice) > 0 {
m.choice = m.choice[:len(m.choice)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 {
if len(m.choice) < 10 {
m.choice += ch
}
}
}
return m, nil
}
func (m menuModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("MAIN MENU"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
items := []struct {
num string
desc string
}{
{"1", "ACCOUNT SUMMARY"},
{"2", "ACCOUNT DETAIL / TRANSACTIONS"},
{"3", "WIRE TRANSFER"},
{"4", "TRANSACTION HISTORY"},
{"5", fmt.Sprintf("SECURE MESSAGES (%d UNREAD)", m.unread)},
{"6", "CHANGE PIN"},
{"7", "LOGOUT"},
}
for _, item := range items {
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%s] %s", item.num, item.desc)))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
b.WriteString(inputStyle.Render(m.choice))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,122 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type messagesModel struct {
messages []SecureMessage
viewing int // -1 = list, >= 0 = detail
choice string
}
func newMessagesModel(messages []SecureMessage) messagesModel {
return messagesModel{messages: messages, viewing: -1}
}
func (m messagesModel) Update(msg tea.Msg) (messagesModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.viewing >= 0 {
// In detail view, any key goes back to list.
m.viewing = -1
m.choice = ""
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
if m.choice == "0" {
m.choice = "back"
return m, nil
}
for i := range m.messages {
if m.choice == fmt.Sprintf("%d", i+1) {
m.viewing = i
m.messages[i].Unread = false
return m, nil
}
}
m.choice = ""
case tea.KeyBackspace:
if len(m.choice) > 0 {
m.choice = m.choice[:len(m.choice)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' && len(m.choice) < 2 {
m.choice += ch
}
}
return m, nil
}
func (m messagesModel) View() string {
if m.viewing >= 0 {
return m.viewDetail()
}
return m.viewList()
}
func (m messagesModel) viewList() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("SECURE MESSAGES"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-4s %-3s %-12s %-22s %s", "#", "", "DATE", "FROM", "SUBJECT")))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 68)))
b.WriteString("\n")
for i, msg := range m.messages {
marker := " "
if msg.Unread {
marker = " * "
}
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d]%s%-12s %-22s %s",
i+1, marker, msg.Date, msg.From, msg.Subj)))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" SELECT MESSAGE: "))
b.WriteString(inputStyle.Render(m.choice))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
return b.String()
}
func (m messagesModel) viewDetail() string {
var b strings.Builder
msg := m.messages[m.viewing]
b.WriteString("\n")
b.WriteString(centerText(fmt.Sprintf("MESSAGE #%d", m.viewing+1)))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(msg.Body))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MESSAGE LIST"))
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,198 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
const (
transferStepRouting = iota
transferStepDest
transferStepBeneficiary
transferStepBankName
transferStepAmount
transferStepMemo
transferStepConfirm
transferStepAuthCode
transferStepComplete
)
var transferPrompts = []string{
" ROUTING NUMBER (ABA): ",
" DESTINATION ACCOUNT: ",
" BENEFICIARY NAME: ",
" RECEIVING BANK NAME: ",
" AMOUNT (USD): ",
" MEMO / REFERENCE: ",
"",
" AUTHORIZATION CODE: ",
}
type transferModel struct {
step int
fields [8]string // indexed by step
transfer WireTransfer
confirm string // y/n input for confirm step
}
func newTransferModel() transferModel {
return transferModel{}
}
func (m transferModel) Update(msg tea.Msg) (transferModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.step == transferStepComplete {
return m, nil
}
if m.step == transferStepConfirm {
switch keyMsg.Type {
case tea.KeyEnter:
switch strings.ToUpper(m.confirm) {
case "Y", "YES":
m.step = transferStepAuthCode
case "N", "NO":
m.confirm = "cancelled"
}
return m, nil
case tea.KeyBackspace:
if len(m.confirm) > 0 {
m.confirm = m.confirm[:len(m.confirm)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.confirm) < 3 {
m.confirm += ch
}
}
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
val := strings.TrimSpace(m.fields[m.step])
if val == "" {
return m, nil
}
if m.step == transferStepAuthCode {
m.transfer.AuthCode = val
m.step = transferStepComplete
return m, nil
}
switch m.step {
case transferStepRouting:
m.transfer.RoutingNumber = val
case transferStepDest:
m.transfer.DestAccount = val
case transferStepBeneficiary:
m.transfer.Beneficiary = val
case transferStepBankName:
m.transfer.BankName = val
case transferStepAmount:
m.transfer.Amount = val
case transferStepMemo:
m.transfer.Memo = val
}
m.step++
return m, nil
case tea.KeyBackspace:
if len(m.fields[m.step]) > 0 {
m.fields[m.step] = m.fields[m.step][:len(m.fields[m.step])-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.fields[m.step]) < 40 {
m.fields[m.step] += ch
}
}
return m, nil
}
func (m transferModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("WIRE TRANSFER"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
// Show completed fields.
for i := 0; i < m.step && i < len(transferPrompts); i++ {
if i == transferStepConfirm {
continue
}
prompt := transferPrompts[i]
if prompt == "" {
continue
}
b.WriteString(baseStyle.Render(prompt))
b.WriteString(baseStyle.Render(m.fields[i]))
b.WriteString("\n")
}
// Current field.
switch {
case m.step == transferStepConfirm:
b.WriteString("\n")
b.WriteString(titleStyle.Render(" === TRANSFER SUMMARY ==="))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" ROUTING: %s", m.transfer.RoutingNumber)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" ACCOUNT: %s", m.transfer.DestAccount)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" BENEFICIARY: %s", m.transfer.Beneficiary)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" BANK: %s", m.transfer.BankName)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" AMOUNT: $%s", m.transfer.Amount)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" MEMO: %s", m.transfer.Memo)))
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" CONFIRM TRANSFER? (Y/N): "))
b.WriteString(inputStyle.Render(m.confirm))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
case m.step == transferStepComplete:
b.WriteString("\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" TRANSFER QUEUED FOR PROCESSING"))
b.WriteString("\n\n")
routing := m.transfer.RoutingNumber
if len(routing) > 4 {
routing = routing[:4]
}
b.WriteString(baseStyle.Render(fmt.Sprintf(" CONFIRMATION #: WR-%s-%s",
routing, "847291")))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" STATUS: PENDING FEDWIRE SETTLEMENT"))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" ESTIMATED COMPLETION: NEXT BUSINESS DAY"))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
b.WriteString("\n")
case m.step < len(transferPrompts):
prompt := transferPrompts[m.step]
b.WriteString(titleStyle.Render(prompt))
b.WriteString(inputStyle.Render(m.fields[m.step]))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
}
if m.step < transferStepConfirm {
b.WriteString("\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(fmt.Sprintf(" STEP %d OF 6", m.step+1)))
b.WriteString("\n")
}
return b.String()
}

View File

@@ -0,0 +1,168 @@
package banking
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
)
const termWidth = 80
// Color palette — green-on-black retro terminal.
var (
colorGreen = lipgloss.Color("#00FF00")
colorDim = lipgloss.Color("#007700")
colorBlack = lipgloss.Color("#000000")
colorBright = lipgloss.Color("#AAFFAA")
colorRed = lipgloss.Color("#FF3333")
)
// Reusable styles.
var (
baseStyle = lipgloss.NewStyle().
Foreground(colorGreen).
Background(colorBlack)
headerStyle = lipgloss.NewStyle().
Foreground(colorBright).
Background(colorBlack).
Bold(true).
Width(termWidth).
Align(lipgloss.Center)
titleStyle = lipgloss.NewStyle().
Foreground(colorGreen).
Background(colorBlack).
Bold(true)
dimStyle = lipgloss.NewStyle().
Foreground(colorDim).
Background(colorBlack)
errorStyle = lipgloss.NewStyle().
Foreground(colorRed).
Background(colorBlack).
Bold(true)
inputStyle = lipgloss.NewStyle().
Foreground(colorBright).
Background(colorBlack)
)
// divider returns an 80-column === line.
func divider() string {
return dimStyle.Render(strings.Repeat("=", termWidth))
}
// thinDivider returns an 80-column --- line.
func thinDivider() string {
return dimStyle.Render(strings.Repeat("-", termWidth))
}
// centerText centers text within 80 columns.
func centerText(s string) string {
return headerStyle.Render(s)
}
// padRight pads a string to the given width.
func padRight(s string, width int) string {
if len(s) >= width {
return s[:width]
}
return s + strings.Repeat(" ", width-len(s))
}
// formatCurrency formats cents as $X,XXX.XX
func formatCurrency(cents int64) string {
negative := cents < 0
if negative {
cents = -cents
}
dollars := cents / 100
remainder := cents % 100
// Add thousands separators.
ds := fmt.Sprintf("%d", dollars)
if len(ds) > 3 {
var parts []string
for len(ds) > 3 {
parts = append([]string{ds[len(ds)-3:]}, parts...)
ds = ds[:len(ds)-3]
}
parts = append([]string{ds}, parts...)
ds = strings.Join(parts, ",")
}
if negative {
return fmt.Sprintf("-$%s.%02d", ds, remainder)
}
return fmt.Sprintf("$%s.%02d", ds, remainder)
}
// padLine pads a single line (which may contain ANSI codes) to termWidth
// using its visual width. Padding uses a black background so the terminal's
// default background doesn't bleed through.
func padLine(line string) string {
w := lipgloss.Width(line)
if w >= termWidth {
return line
}
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
}
// padLines pads every line in a multi-line string to termWidth so that
// shorter lines fully overwrite previous content in the terminal.
func padLines(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
lines[i] = padLine(line)
}
return strings.Join(lines, "\n")
}
// screenFrame wraps content in the persistent header and footer.
// The height parameter is used to pad the output to fill the terminal,
// preventing leftover lines from previous renders bleeding through.
func screenFrame(bankName, terminalID, region, content string, height int) string {
var b strings.Builder
// Header (4 lines).
b.WriteString(divider())
b.WriteString("\n")
b.WriteString(centerText(bankName + " FEDERAL RESERVE SYSTEM"))
b.WriteString("\n")
b.WriteString(centerText("SECURE BANKING TERMINAL"))
b.WriteString("\n")
b.WriteString(divider())
b.WriteString("\n")
// Content.
b.WriteString(content)
// Pad with blank lines between content and footer so the footer
// stays at the bottom and the total output fills the terminal height.
if height > 0 {
const headerLines = 4
const footerLines = 2
// strings.Count gives newlines; add 1 for the line after the last \n.
contentLines := strings.Count(content, "\n") + 1
used := headerLines + contentLines + footerLines
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
for i := used; i < height; i++ {
b.WriteString(blankLine)
b.WriteString("\n")
}
}
// Footer (2 lines).
b.WriteString("\n")
b.WriteString(divider())
b.WriteString("\n")
footer := fmt.Sprintf(" TERMINAL: %s | REGION: %s | ENCRYPTED SESSION ACTIVE", terminalID, region)
b.WriteString(dimStyle.Render(padRight(footer, termWidth)))
// Pad every line to full terminal width so shorter lines overwrite
// leftover content from previous renders.
return padLines(b.String())
}

View File

@@ -8,7 +8,7 @@ import (
"strings"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
@@ -86,6 +86,9 @@ func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw i
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("bash")
}
if result.exit {
return nil

View File

@@ -9,8 +9,8 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {
@@ -116,7 +116,7 @@ func TestReadLineCtrlD(t *testing.T) {
func TestBashShellHandle(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash", "")
sess := &shell.SessionContext{
SessionID: sessID,
@@ -166,7 +166,7 @@ func TestBashShellHandle(t *testing.T) {
func TestBashShellFakeUser(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash", "")
sess := &shell.SessionContext{
SessionID: sessID,

View File

@@ -0,0 +1,206 @@
package cisco
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// CiscoShell emulates a Cisco IOS CLI.
type CiscoShell struct{}
// NewCiscoShell returns a new CiscoShell instance.
func NewCiscoShell() *CiscoShell {
return &CiscoShell{}
}
func (c *CiscoShell) Name() string { return "cisco" }
func (c *CiscoShell) Description() string { return "Cisco IOS CLI emulator" }
func (c *CiscoShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
hostname := configString(sess.ShellConfig, "hostname", "Router")
model := configString(sess.ShellConfig, "model", "C2960")
iosVersion := configString(sess.ShellConfig, "ios_version", "15.0(2)SE11")
enablePass := configString(sess.ShellConfig, "enable_password", "")
state := newIOSState(hostname, model, iosVersion, enablePass)
// IOS just shows a blank line then the prompt after SSH auth.
fmt.Fprint(rw, "\r\n")
for {
prompt := state.prompt()
if _, err := fmt.Fprint(rw, prompt); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
// Check for Ctrl+Z (^Z) — return to privileged exec.
if trimmed == "\x1a" || trimmed == "^Z" {
if state.mode == modeGlobalConfig || state.mode == modeInterfaceConfig {
state.mode = modePrivilegedExec
state.currentIf = ""
}
continue
}
// Handle "enable" specially — it needs password prompting.
if state.mode == modeUserExec && isEnableCommand(trimmed) {
output := handleEnable(ctx, state, rw)
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("cisco")
}
continue
}
result := state.dispatch(trimmed)
var output string
if result.output != "" {
output = result.output
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("cisco")
}
if result.exit {
return nil
}
}
}
// isEnableCommand checks if input resolves to "enable" in user exec mode.
func isEnableCommand(input string) bool {
words := strings.Fields(input)
if len(words) != 1 {
return false
}
w := strings.ToLower(words[0])
enable := "enable"
return len(w) >= 2 && len(w) <= len(enable) && enable[:len(w)] == w
}
// handleEnable manages the enable password prompt flow.
// Returns the output string (for logging).
func handleEnable(ctx context.Context, state *iosState, rw io.ReadWriter) string {
const maxAttempts = 3
hadFailure := false
for range maxAttempts {
fmt.Fprint(rw, "Password: ")
password, err := readPassword(ctx, rw)
if err != nil {
return ""
}
fmt.Fprint(rw, "\r\n")
if state.enablePass == "" {
// No password configured — accept after one failed attempt.
if hadFailure {
state.mode = modePrivilegedExec
return ""
}
hadFailure = true
} else if password == state.enablePass {
state.mode = modePrivilegedExec
return ""
}
}
output := "% Bad passwords"
fmt.Fprintf(rw, "%s\r\n", output)
return output
}
// readPassword reads a password without echoing characters.
func readPassword(ctx context.Context, rw io.ReadWriter) (string, error) {
var buf []byte
b := make([]byte, 1)
for {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
}
n, err := rw.Read(b)
if err != nil {
return "", err
}
if n == 0 {
continue
}
ch := b[0]
switch {
case ch == '\r' || ch == '\n':
return string(buf), nil
case ch == 4: // Ctrl+D
return string(buf), io.EOF
case ch == 3: // Ctrl+C
return "", io.EOF
case ch == 127 || ch == 8: // Backspace/DEL
if len(buf) > 0 {
buf = buf[:len(buf)-1]
}
case ch == 27: // ESC sequence
next := make([]byte, 1)
if n, _ := rw.Read(next); n > 0 && next[0] == '[' {
rw.Read(next)
}
case ch >= 32 && ch < 127:
buf = append(buf, ch)
// Don't echo.
}
}
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,531 @@
package cisco
import (
"testing"
)
// --- Abbreviation resolution tests ---
func TestResolveAbbreviationExact(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "shutdown"},
}
got, err := resolveAbbreviation("show", entries)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "show" {
t.Errorf("got %q, want %q", got, "show")
}
}
func TestResolveAbbreviationUnique(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "enable"},
{name: "exit"},
}
got, err := resolveAbbreviation("sh", entries)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "show" {
t.Errorf("got %q, want %q", got, "show")
}
}
func TestResolveAbbreviationAmbiguous(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "shutdown"},
}
_, err := resolveAbbreviation("sh", entries)
if err == nil {
t.Fatal("expected ambiguous error, got nil")
}
if err.Error() != "ambiguous" {
t.Errorf("got error %q, want %q", err.Error(), "ambiguous")
}
}
func TestResolveAbbreviationUnknown(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "enable"},
}
_, err := resolveAbbreviation("xyz", entries)
if err == nil {
t.Fatal("expected unknown error, got nil")
}
if err.Error() != "unknown" {
t.Errorf("got error %q, want %q", err.Error(), "unknown")
}
}
func TestResolveAbbreviationCaseInsensitive(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "enable"},
}
got, err := resolveAbbreviation("SH", entries)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "show" {
t.Errorf("got %q, want %q", got, "show")
}
}
// --- Multi-word command resolution tests ---
func TestResolveCommandShowRunningConfig(t *testing.T) {
resolved, args, err := resolveCommand([]string{"sh", "run"}, privilegedExecCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(args) != 0 {
t.Errorf("unexpected args: %v", args)
}
want := []string{"show", "running-config"}
if len(resolved) != len(want) {
t.Fatalf("resolved = %v, want %v", resolved, want)
}
for i := range want {
if resolved[i] != want[i] {
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
}
}
}
func TestResolveCommandConfigureTerminal(t *testing.T) {
resolved, _, err := resolveCommand([]string{"conf", "t"}, privilegedExecCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := []string{"configure", "terminal"}
if len(resolved) != len(want) {
t.Fatalf("resolved = %v, want %v", resolved, want)
}
for i := range want {
if resolved[i] != want[i] {
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
}
}
}
func TestResolveCommandShowIPInterfaceBrief(t *testing.T) {
resolved, _, err := resolveCommand([]string{"sh", "ip", "int", "br"}, privilegedExecCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := []string{"show", "ip", "interface", "brief"}
if len(resolved) != len(want) {
t.Fatalf("resolved = %v, want %v", resolved, want)
}
for i := range want {
if resolved[i] != want[i] {
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
}
}
}
func TestResolveCommandWithArgs(t *testing.T) {
// "hostname MyRouter" → resolved=["hostname"], args=["MyRouter"]
resolved, args, err := resolveCommand([]string{"hostname", "MyRouter"}, globalConfigCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resolved) != 1 || resolved[0] != "hostname" {
t.Errorf("resolved = %v, want [hostname]", resolved)
}
if len(args) != 1 || args[0] != "MyRouter" {
t.Errorf("args = %v, want [MyRouter]", args)
}
}
func TestResolveCommandAmbiguous(t *testing.T) {
// In user exec, "e" matches "enable" and "exit" — ambiguous
_, _, err := resolveCommand([]string{"e"}, userExecCommands)
if err == nil {
t.Fatal("expected ambiguous error")
}
}
// --- Mode state machine tests ---
func TestPromptGeneration(t *testing.T) {
tests := []struct {
mode iosMode
want string
}{
{modeUserExec, "Router>"},
{modePrivilegedExec, "Router#"},
{modeGlobalConfig, "Router(config)#"},
{modeInterfaceConfig, "Router(config-if)#"},
}
for _, tt := range tests {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = tt.mode
if got := s.prompt(); got != tt.want {
t.Errorf("prompt(%d) = %q, want %q", tt.mode, got, tt.want)
}
}
}
func TestPromptAfterHostnameChange(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeGlobalConfig
s.dispatch("hostname Switch1")
if s.hostname != "Switch1" {
t.Fatalf("hostname = %q, want %q", s.hostname, "Switch1")
}
if got := s.prompt(); got != "Switch1(config)#" {
t.Errorf("prompt = %q, want %q", got, "Switch1(config)#")
}
}
func TestModeTransitions(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
// Start in user exec.
if s.mode != modeUserExec {
t.Fatalf("initial mode = %d, want %d", s.mode, modeUserExec)
}
// Can't skip to config mode directly from user exec.
result := s.dispatch("configure terminal")
if result.output == "" {
t.Error("expected error for conf t in user exec mode")
}
// Manually set privileged mode (enable tested separately).
s.mode = modePrivilegedExec
// conf t → global config
s.dispatch("configure terminal")
if s.mode != modeGlobalConfig {
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
}
// interface Gi0/0 → interface config
s.dispatch("interface GigabitEthernet0/0")
if s.mode != modeInterfaceConfig {
t.Errorf("mode after interface = %d, want %d", s.mode, modeInterfaceConfig)
}
// exit → back to global config
s.dispatch("exit")
if s.mode != modeGlobalConfig {
t.Errorf("mode after exit from if-config = %d, want %d", s.mode, modeGlobalConfig)
}
// end → back to privileged exec
s.dispatch("end")
if s.mode != modePrivilegedExec {
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
}
// disable → back to user exec
s.dispatch("disable")
if s.mode != modeUserExec {
t.Errorf("mode after disable = %d, want %d", s.mode, modeUserExec)
}
}
func TestEndFromInterfaceConfig(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeInterfaceConfig
s.currentIf = "GigabitEthernet0/0"
s.dispatch("end")
if s.mode != modePrivilegedExec {
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
}
if s.currentIf != "" {
t.Errorf("currentIf = %q, want empty", s.currentIf)
}
}
func TestExitFromPrivilegedExec(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modePrivilegedExec
result := s.dispatch("exit")
if !result.exit {
t.Error("expected exit=true from privileged exec exit")
}
}
// --- Show command output tests ---
func TestShowVersionContainsModel(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showVersion(s)
if !contains(output, "C2960") {
t.Error("show version missing model")
}
if !contains(output, "15.0(2)SE11") {
t.Error("show version missing IOS version")
}
if !contains(output, "Router") {
t.Error("show version missing hostname")
}
}
func TestShowRunningConfigContainsInterfaces(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showRunningConfig(s)
if !contains(output, "hostname Router") {
t.Error("running-config missing hostname")
}
if !contains(output, "interface GigabitEthernet0/0") {
t.Error("running-config missing interface")
}
if !contains(output, "ip address 192.168.1.1") {
t.Error("running-config missing IP address")
}
if !contains(output, "line vty") {
t.Error("running-config missing VTY config")
}
}
func TestShowRunningConfigWithEnableSecret(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "secret123")
output := showRunningConfig(s)
if !contains(output, "enable secret") {
t.Error("running-config missing enable secret when password is set")
}
}
func TestShowRunningConfigWithoutEnableSecret(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showRunningConfig(s)
if contains(output, "enable secret") {
t.Error("running-config should not have enable secret when password is empty")
}
}
func TestShowIPInterfaceBrief(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showIPInterfaceBrief(s)
if !contains(output, "GigabitEthernet0/0") {
t.Error("ip interface brief missing GigabitEthernet0/0")
}
if !contains(output, "192.168.1.1") {
t.Error("ip interface brief missing 192.168.1.1")
}
}
func TestShowIPRoute(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showIPRoute(s)
if !contains(output, "directly connected") {
t.Error("ip route missing connected routes")
}
if !contains(output, "0.0.0.0/0") {
t.Error("ip route missing default route")
}
}
func TestShowVLANBrief(t *testing.T) {
output := showVLANBrief()
if !contains(output, "default") {
t.Error("vlan brief missing default vlan")
}
if !contains(output, "MGMT") {
t.Error("vlan brief missing MGMT vlan")
}
}
// --- Interface config tests ---
func TestInterfaceShutdownNoShutdown(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeInterfaceConfig
s.currentIf = "GigabitEthernet0/0"
s.dispatch("shutdown")
iface := s.findInterface("GigabitEthernet0/0")
if iface == nil {
t.Fatal("interface not found")
}
if !iface.shutdown {
t.Error("interface should be shutdown")
}
if iface.status != "administratively down" {
t.Errorf("status = %q, want %q", iface.status, "administratively down")
}
s.dispatch("no shutdown")
if iface.shutdown {
t.Error("interface should not be shutdown after no shutdown")
}
if iface.status != "up" {
t.Errorf("status = %q, want %q", iface.status, "up")
}
}
func TestInterfaceIPAddress(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeInterfaceConfig
s.currentIf = "GigabitEthernet0/0"
s.dispatch("ip address 10.10.10.1 255.255.255.0")
iface := s.findInterface("GigabitEthernet0/0")
if iface == nil {
t.Fatal("interface not found")
}
if iface.ip != "10.10.10.1" {
t.Errorf("ip = %q, want %q", iface.ip, "10.10.10.1")
}
if iface.mask != "255.255.255.0" {
t.Errorf("mask = %q, want %q", iface.mask, "255.255.255.0")
}
}
// --- Dispatch / invalid command tests ---
func TestInvalidCommandInUserExec(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
result := s.dispatch("foobar")
if !contains(result.output, "Invalid input") {
t.Errorf("expected invalid input error, got %q", result.output)
}
}
func TestAmbiguousCommandOutput(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
// "e" in user exec is ambiguous (enable, exit)
result := s.dispatch("e")
if !contains(result.output, "Ambiguous") {
t.Errorf("expected ambiguous error, got %q", result.output)
}
}
func TestHelpCommand(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
result := s.dispatch("?")
if !contains(result.output, "show") {
t.Error("help missing 'show'")
}
if !contains(result.output, "enable") {
t.Error("help missing 'enable'")
}
}
// --- Abbreviation integration tests ---
func TestShowAbbreviationInDispatch(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modePrivilegedExec
result := s.dispatch("sh ver")
if !contains(result.output, "Cisco IOS Software") {
t.Error("'sh ver' should produce version output")
}
}
func TestConfTAbbreviation(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modePrivilegedExec
s.dispatch("conf t")
if s.mode != modeGlobalConfig {
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
}
}
// --- Enable command detection ---
func TestIsEnableCommand(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"enable", true},
{"en", true},
{"ena", true},
{"e", false}, // too short (single char could be other commands)
{"enab", true},
{"ENABLE", true},
{"exit", false},
{"enable 15", false}, // has extra argument
}
for _, tt := range tests {
if got := isEnableCommand(tt.input); got != tt.want {
t.Errorf("isEnableCommand(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
// --- configString tests ---
func TestConfigString(t *testing.T) {
cfg := map[string]any{"hostname": "MySwitch"}
if got := configString(cfg, "hostname", "Router"); got != "MySwitch" {
t.Errorf("configString() = %q, want %q", got, "MySwitch")
}
if got := configString(cfg, "missing", "Default"); got != "Default" {
t.Errorf("configString() for missing = %q, want %q", got, "Default")
}
if got := configString(nil, "key", "Default"); got != "Default" {
t.Errorf("configString(nil) = %q, want %q", got, "Default")
}
}
// --- Helper ---
func TestMaskBits(t *testing.T) {
tests := []struct {
mask string
want int
}{
{"255.255.255.0", 24},
{"255.255.255.252", 30},
{"255.255.0.0", 16},
{"255.0.0.0", 8},
}
for _, tt := range tests {
if got := maskBits(tt.mask); got != tt.want {
t.Errorf("maskBits(%q) = %d, want %d", tt.mask, got, tt.want)
}
}
}
func TestNetworkFromIP(t *testing.T) {
tests := []struct {
ip, mask, want string
}{
{"192.168.1.1", "255.255.255.0", "192.168.1.0"},
{"10.0.0.1", "255.255.255.252", "10.0.0.0"},
{"172.16.5.100", "255.255.0.0", "172.16.0.0"},
}
for _, tt := range tests {
if got := networkFromIP(tt.ip, tt.mask); got != tt.want {
t.Errorf("networkFromIP(%q, %q) = %q, want %q", tt.ip, tt.mask, got, tt.want)
}
}
}
// --- Shell metadata ---
func TestShellNameAndDescription(t *testing.T) {
s := NewCiscoShell()
if s.Name() != "cisco" {
t.Errorf("Name() = %q, want %q", s.Name(), "cisco")
}
if s.Description() == "" {
t.Error("Description() should not be empty")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && containsHelper(s, substr)
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -0,0 +1,414 @@
package cisco
import (
"fmt"
"strings"
)
// commandResult holds the output of a command and whether the session should end.
type commandResult struct {
output string
exit bool
}
// commandEntry defines a single command with its name and optional sub-commands.
type commandEntry struct {
name string
subs []commandEntry // nil for leaf commands
}
// userExecCommands defines the command tree for user EXEC mode.
var userExecCommands = []commandEntry{
{name: "show", subs: []commandEntry{
{name: "version"},
{name: "clock"},
{name: "ip", subs: []commandEntry{
{name: "route"},
{name: "interface", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "interfaces"},
{name: "vlan", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "enable"},
{name: "exit"},
{name: "?"},
}
// privilegedExecCommands extends user commands for privileged mode.
var privilegedExecCommands = []commandEntry{
{name: "show", subs: []commandEntry{
{name: "version"},
{name: "clock"},
{name: "ip", subs: []commandEntry{
{name: "route"},
{name: "interface", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "interfaces"},
{name: "running-config"},
{name: "startup-config"},
{name: "vlan", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "configure", subs: []commandEntry{
{name: "terminal"},
}},
{name: "write", subs: []commandEntry{
{name: "memory"},
}},
{name: "copy"},
{name: "reload"},
{name: "disable"},
{name: "terminal", subs: []commandEntry{
{name: "length"},
}},
{name: "exit"},
{name: "?"},
}
// globalConfigCommands defines the command tree for global config mode.
var globalConfigCommands = []commandEntry{
{name: "hostname"},
{name: "interface"},
{name: "ip", subs: []commandEntry{
{name: "route"},
}},
{name: "no"},
{name: "end"},
{name: "exit"},
{name: "?"},
}
// interfaceConfigCommands defines the command tree for interface config mode.
var interfaceConfigCommands = []commandEntry{
{name: "ip", subs: []commandEntry{
{name: "address"},
}},
{name: "description"},
{name: "shutdown"},
{name: "no", subs: []commandEntry{
{name: "shutdown"},
}},
{name: "switchport", subs: []commandEntry{
{name: "mode"},
}},
{name: "end"},
{name: "exit"},
{name: "?"},
}
// commandsForMode returns the command tree for the given IOS mode.
func commandsForMode(mode iosMode) []commandEntry {
switch mode {
case modeUserExec:
return userExecCommands
case modePrivilegedExec:
return privilegedExecCommands
case modeGlobalConfig:
return globalConfigCommands
case modeInterfaceConfig:
return interfaceConfigCommands
default:
return userExecCommands
}
}
// resolveAbbreviation attempts to match an abbreviated word against a list of
// command entries. It returns the matched entry name, or an error string if
// ambiguous or unknown.
func resolveAbbreviation(word string, entries []commandEntry) (string, error) {
word = strings.ToLower(word)
var matches []string
for _, e := range entries {
if strings.ToLower(e.name) == word {
return e.name, nil // exact match
}
if strings.HasPrefix(strings.ToLower(e.name), word) {
matches = append(matches, e.name)
}
}
switch len(matches) {
case 0:
return "", fmt.Errorf("unknown")
case 1:
return matches[0], nil
default:
return "", fmt.Errorf("ambiguous")
}
}
// resolveCommand resolves a sequence of abbreviated words into the canonical
// command path (e.g., ["sh", "run"] → ["show", "running-config"]).
// It returns the resolved path, any remaining arguments, and an error if
// resolution fails.
func resolveCommand(words []string, entries []commandEntry) ([]string, []string, error) {
var resolved []string
current := entries
for i, w := range words {
name, err := resolveAbbreviation(w, current)
if err != nil {
if err.Error() == "unknown" && len(resolved) > 0 {
// Remaining words are arguments to the resolved command.
return resolved, words[i:], nil
}
return resolved, words[i:], err
}
resolved = append(resolved, name)
// Find sub-commands for the matched entry.
var nextLevel []commandEntry
for _, e := range current {
if e.name == name {
nextLevel = e.subs
break
}
}
if nextLevel == nil {
// Leaf command — rest are arguments.
return resolved, words[i+1:], nil
}
current = nextLevel
}
return resolved, nil, nil
}
// dispatch processes a command line in the context of the current IOS state.
func (s *iosState) dispatch(input string) commandResult {
words := strings.Fields(input)
if len(words) == 0 {
return commandResult{}
}
// Handle "?" as a help request.
if words[0] == "?" {
return s.cmdHelp()
}
cmds := commandsForMode(s.mode)
resolved, args, err := resolveCommand(words, cmds)
if err != nil {
if err.Error() == "ambiguous" {
return commandResult{output: fmt.Sprintf("%% Ambiguous command: \"%s\"", input)}
}
return commandResult{output: invalidInput(input)}
}
if len(resolved) == 0 {
return commandResult{output: invalidInput(input)}
}
cmd := strings.Join(resolved, " ")
switch s.mode {
case modeUserExec:
return s.dispatchUserExec(cmd, args)
case modePrivilegedExec:
return s.dispatchPrivilegedExec(cmd, args)
case modeGlobalConfig:
return s.dispatchGlobalConfig(cmd, args)
case modeInterfaceConfig:
return s.dispatchInterfaceConfig(cmd, args)
}
return commandResult{output: invalidInput(input)}
}
func (s *iosState) dispatchUserExec(cmd string, args []string) commandResult {
switch cmd {
case "show version":
return commandResult{output: showVersion(s)}
case "show clock":
return commandResult{output: showClock()}
case "show ip route":
return commandResult{output: showIPRoute(s)}
case "show ip interface brief":
return commandResult{output: showIPInterfaceBrief(s)}
case "show interfaces":
return commandResult{output: showInterfaces(s)}
case "show vlan brief":
return commandResult{output: showVLANBrief()}
case "enable":
return commandResult{} // handled in Handle() loop
case "exit":
return commandResult{exit: true}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) dispatchPrivilegedExec(cmd string, args []string) commandResult {
switch cmd {
case "show version":
return commandResult{output: showVersion(s)}
case "show clock":
return commandResult{output: showClock()}
case "show ip route":
return commandResult{output: showIPRoute(s)}
case "show ip interface brief":
return commandResult{output: showIPInterfaceBrief(s)}
case "show interfaces":
return commandResult{output: showInterfaces(s)}
case "show running-config":
return commandResult{output: showRunningConfig(s)}
case "show startup-config":
return commandResult{output: showRunningConfig(s)} // same as running
case "show vlan brief":
return commandResult{output: showVLANBrief()}
case "configure terminal":
s.mode = modeGlobalConfig
return commandResult{output: "Enter configuration commands, one per line. End with CNTL/Z."}
case "write memory":
return commandResult{output: "[OK]"}
case "copy":
return commandResult{output: "[OK]"}
case "reload":
return commandResult{output: "System configuration has been modified. Save? [yes/no]: ", exit: true}
case "disable":
s.mode = modeUserExec
return commandResult{}
case "terminal length":
return commandResult{} // accept silently
case "exit":
return commandResult{exit: true}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) dispatchGlobalConfig(cmd string, args []string) commandResult {
switch cmd {
case "hostname":
if len(args) < 1 {
return commandResult{output: "% Incomplete command."}
}
s.hostname = args[0]
return commandResult{}
case "interface":
if len(args) < 1 {
return commandResult{output: "% Incomplete command."}
}
ifName := strings.Join(args, "")
s.currentIf = ifName
s.mode = modeInterfaceConfig
return commandResult{}
case "ip route":
return commandResult{} // accept silently
case "no":
return commandResult{} // accept silently
case "end":
s.mode = modePrivilegedExec
return commandResult{}
case "exit":
s.mode = modePrivilegedExec
return commandResult{}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) dispatchInterfaceConfig(cmd string, args []string) commandResult {
switch cmd {
case "ip address":
if len(args) < 2 {
return commandResult{output: "% Incomplete command."}
}
if iface := s.findInterface(s.currentIf); iface != nil {
iface.ip = args[0]
iface.mask = args[1]
}
return commandResult{}
case "description":
if len(args) < 1 {
return commandResult{output: "% Incomplete command."}
}
if iface := s.findInterface(s.currentIf); iface != nil {
iface.desc = strings.Join(args, " ")
}
return commandResult{}
case "shutdown":
if iface := s.findInterface(s.currentIf); iface != nil {
iface.shutdown = true
iface.status = "administratively down"
iface.protocol = "down"
}
return commandResult{}
case "no shutdown":
if iface := s.findInterface(s.currentIf); iface != nil {
iface.shutdown = false
iface.status = "up"
iface.protocol = "up"
}
return commandResult{}
case "switchport mode":
return commandResult{} // accept silently
case "end":
s.mode = modePrivilegedExec
s.currentIf = ""
return commandResult{}
case "exit":
s.mode = modeGlobalConfig
s.currentIf = ""
return commandResult{}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) cmdHelp() commandResult {
cmds := commandsForMode(s.mode)
var b strings.Builder
for _, e := range cmds {
if e.name == "?" {
continue
}
b.WriteString(fmt.Sprintf(" %-20s %s\n", e.name, helpText(e.name)))
}
return commandResult{output: b.String()}
}
func helpText(name string) string {
switch name {
case "show":
return "Show running system information"
case "enable":
return "Turn on privileged commands"
case "disable":
return "Turn off privileged commands"
case "exit":
return "Exit from the EXEC"
case "configure":
return "Enter configuration mode"
case "write":
return "Write running configuration to memory"
case "copy":
return "Copy from one file to another"
case "reload":
return "Halt and perform a cold restart"
case "terminal":
return "Set terminal line parameters"
case "hostname":
return "Set system's network name"
case "interface":
return "Select an interface to configure"
case "ip":
return "Global IP configuration subcommands"
case "no":
return "Negate a command or set its defaults"
case "end":
return "Exit from configure mode"
case "description":
return "Interface specific description"
case "shutdown":
return "Shutdown the selected interface"
case "switchport":
return "Set switching mode characteristics"
default:
return ""
}
}
func invalidInput(input string) string {
return fmt.Sprintf("%% Invalid input detected at '^' marker.\n\n%s\n^", input)
}

View File

@@ -0,0 +1,234 @@
package cisco
import (
"fmt"
"math/rand"
"strings"
"time"
)
func showVersion(s *iosState) string {
days := 14 + rand.Intn(350)
hours := rand.Intn(24)
mins := rand.Intn(60)
return fmt.Sprintf(`Cisco IOS Software, %s Software (%s-UNIVERSALK9-M), Version %s, RELEASE SOFTWARE (fc3)
Technical Support: http://www.cisco.com/techsupport
Copyright (c) 1986-2019 by Cisco Systems, Inc.
Compiled Thu 30-Jan-19 10:08 by prod_rel_team
ROM: Bootstrap program is %s boot loader
BOOTLDR: %s Boot Loader (C2960-HBOOT-M) Version 15.0(2r)SE, RELEASE SOFTWARE (fc1)
%s uptime is %d days, %d hours, %d minutes
System returned to ROM by power-on
System image file is "flash:/%s-universalk9-mz.SPA.%s.bin"
This product contains cryptographic features and is subject to United States
and local country laws governing import, export, transfer and use.
cisco %s (%s) processor (revision K0) with 524288K bytes of memory.
Processor board ID %s
Last reset from power-on
2 Gigabit Ethernet interfaces
1 Virtual Ethernet interface
64K bytes of flash-simulated non-volatile configuration memory.
Total of 65536K bytes of APC System Flash (Read/Write)
Configuration register is 0x2102`,
s.model, s.model, s.iosVersion,
s.model, s.model,
s.hostname, days, hours, mins,
s.model, s.iosVersion,
s.model, processorForModel(s.model),
s.serial,
)
}
func processorForModel(model string) string {
if strings.HasPrefix(model, "C29") {
return "PowerPC405"
}
return "MIPS"
}
func showClock() string {
now := time.Now().UTC()
return fmt.Sprintf("*%s UTC", now.Format("15:04:05.000 Mon Jan 2 2006"))
}
func showIPRoute(s *iosState) string {
var b strings.Builder
b.WriteString("Codes: C - connected, S - static, R - RIP, M - mobile, B - BGP\n")
b.WriteString(" D - EIGRP, EX - EIGRP external, O - OSPF, IA - OSPF inter area\n")
b.WriteString(" N1 - OSPF NSSA external type 1, N2 - OSPF NSSA external type 2\n")
b.WriteString(" E1 - OSPF external type 1, E2 - OSPF external type 2\n")
b.WriteString(" i - IS-IS, su - IS-IS summary, L1 - IS-IS level-1, L2 - IS-IS level-2\n")
b.WriteString(" ia - IS-IS inter area, * - candidate default, U - per-user static route\n")
b.WriteString(" o - ODR, P - periodic downloaded static route\n\n")
b.WriteString("Gateway of last resort is 10.0.0.2 to network 0.0.0.0\n\n")
for _, iface := range s.interfaces {
if iface.ip == "unassigned" || iface.status != "up" {
continue
}
network := networkFromIP(iface.ip, iface.mask)
maskBits := maskBits(iface.mask)
fmt.Fprintf(&b, "C %s/%d is directly connected, %s\n", network, maskBits, iface.name)
}
b.WriteString("S* 0.0.0.0/0 [1/0] via 10.0.0.2")
return b.String()
}
func showIPInterfaceBrief(s *iosState) string {
var b strings.Builder
fmt.Fprintf(&b, "%-25s %-15s %-4s %-7s %-22s %s\n",
"Interface", "IP-Address", "OK?", "Method", "Status", "Protocol")
for _, iface := range s.interfaces {
ip := iface.ip
if ip == "" {
ip = "unassigned"
}
fmt.Fprintf(&b, "%-25s %-15s YES manual %-22s %s\n",
iface.name, ip, iface.status, iface.protocol)
}
return b.String()
}
func showInterfaces(s *iosState) string {
var b strings.Builder
for i, iface := range s.interfaces {
if i > 0 {
b.WriteString("\n")
}
upDown := "up"
if iface.shutdown {
upDown = "administratively down"
}
fmt.Fprintf(&b, "%s is %s, line protocol is %s\n", iface.name, upDown, iface.protocol)
fmt.Fprintf(&b, " Hardware is Gigabit Ethernet, address is %s (bia %s)\n", iface.mac, iface.mac)
if iface.ip != "unassigned" && iface.ip != "" {
fmt.Fprintf(&b, " Internet address is %s/%d\n", iface.ip, maskBits(iface.mask))
}
fmt.Fprintf(&b, " MTU %d bytes, BW %s sec, DLY 10 usec,\n", iface.mtu, iface.bandwidth)
b.WriteString(" reliability 255/255, txload 1/255, rxload 1/255\n")
b.WriteString(" Encapsulation ARPA, loopback not set\n")
fmt.Fprintf(&b, " %d packets input, %d bytes, 0 no buffer\n", iface.rxPackets, iface.rxBytes)
fmt.Fprintf(&b, " %d packets output, %d bytes, 0 underruns", iface.txPackets, iface.txBytes)
}
return b.String()
}
func showRunningConfig(s *iosState) string {
var b strings.Builder
b.WriteString("Building configuration...\n\n")
b.WriteString("Current configuration : 1482 bytes\n")
b.WriteString("!\n")
b.WriteString("! Last configuration change at 14:32:22 UTC Mon Feb 10 2025\n")
b.WriteString("!\n")
b.WriteString("version 15.0\n")
b.WriteString("service timestamps debug datetime msec\n")
b.WriteString("service timestamps log datetime msec\n")
b.WriteString("no service password-encryption\n")
b.WriteString("!\n")
fmt.Fprintf(&b, "hostname %s\n", s.hostname)
b.WriteString("!\n")
b.WriteString("boot-start-marker\n")
b.WriteString("boot-end-marker\n")
b.WriteString("!\n")
if s.enablePass != "" {
b.WriteString("enable secret 5 $1$mERr$hx5rVt7rPNoS4wqbXKX7m0\n")
}
b.WriteString("!\n")
b.WriteString("no aaa new-model\n")
b.WriteString("!\n")
for _, iface := range s.interfaces {
b.WriteString("!\n")
fmt.Fprintf(&b, "interface %s\n", iface.name)
if iface.desc != "" {
fmt.Fprintf(&b, " description %s\n", iface.desc)
}
if iface.ip != "unassigned" && iface.ip != "" {
fmt.Fprintf(&b, " ip address %s %s\n", iface.ip, iface.mask)
} else {
b.WriteString(" no ip address\n")
}
if iface.shutdown {
b.WriteString(" shutdown\n")
}
}
b.WriteString("!\n")
b.WriteString("ip forward-protocol nd\n")
b.WriteString("!\n")
b.WriteString("ip route 0.0.0.0 0.0.0.0 10.0.0.2\n")
b.WriteString("!\n")
b.WriteString("access-list 10 permit 192.168.1.0 0.0.0.255\n")
b.WriteString("access-list 10 deny any\n")
b.WriteString("!\n")
b.WriteString("line con 0\n")
b.WriteString(" logging synchronous\n")
b.WriteString("line vty 0 4\n")
b.WriteString(" login local\n")
b.WriteString(" transport input ssh\n")
b.WriteString("!\n")
b.WriteString("end")
return b.String()
}
func showVLANBrief() string {
var b strings.Builder
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "VLAN", "Name", "Status", "Ports")
b.WriteString("---- -------------------------------- --------- -------------------------------\n")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1", "default", "active", "Gi0/0, Gi0/1, Gi0/2")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "10", "MGMT", "active", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "20", "USERS", "active", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "99", "NATIVE", "active", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1002", "fddi-default", "act/unsup", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1003", "token-ring-default", "act/unsup", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s", "1004", "fddinet-default", "act/unsup", "")
return b.String()
}
// networkFromIP derives the network address from an IP and mask.
func networkFromIP(ip, mask string) string {
ipParts := parseIPv4(ip)
maskParts := parseIPv4(mask)
if ipParts == nil || maskParts == nil {
return ip
}
return fmt.Sprintf("%d.%d.%d.%d",
ipParts[0]&maskParts[0],
ipParts[1]&maskParts[1],
ipParts[2]&maskParts[2],
ipParts[3]&maskParts[3],
)
}
func maskBits(mask string) int {
parts := parseIPv4(mask)
if parts == nil {
return 24
}
bits := 0
for _, p := range parts {
for i := 7; i >= 0; i-- {
if p&(1<<uint(i)) != 0 {
bits++
} else {
return bits
}
}
}
return bits
}
func parseIPv4(s string) []int {
var a, b, c, d int
n, _ := fmt.Sscanf(s, "%d.%d.%d.%d", &a, &b, &c, &d)
if n != 4 {
return nil
}
return []int{a, b, c, d}
}

View File

@@ -0,0 +1,109 @@
package cisco
import "fmt"
// iosMode represents the current CLI mode of the IOS state machine.
type iosMode int
const (
modeUserExec iosMode = iota // Router>
modePrivilegedExec // Router#
modeGlobalConfig // Router(config)#
modeInterfaceConfig // Router(config-if)#
)
// ifaceInfo holds interface metadata for show commands.
type ifaceInfo struct {
name string
ip string
mask string
status string
protocol string
mac string
bandwidth string
mtu int
rxPackets int
txPackets int
rxBytes int
txBytes int
shutdown bool
desc string
}
// iosState holds all mutable state for the Cisco IOS shell session.
type iosState struct {
mode iosMode
hostname string
model string
iosVersion string
serial string
enablePass string
interfaces []ifaceInfo
currentIf string
}
func newIOSState(hostname, model, iosVersion, enablePass string) *iosState {
return &iosState{
mode: modeUserExec,
hostname: hostname,
model: model,
iosVersion: iosVersion,
serial: "FTX1524Z0P3",
enablePass: enablePass,
interfaces: defaultInterfaces(),
}
}
func defaultInterfaces() []ifaceInfo {
return []ifaceInfo{
{
name: "GigabitEthernet0/0", ip: "192.168.1.1", mask: "255.255.255.0",
status: "up", protocol: "up", mac: "0050.7966.6800",
bandwidth: "1000000 Kbit", mtu: 1500,
rxPackets: 148253, txPackets: 93127, rxBytes: 19284732, txBytes: 8291043,
},
{
name: "GigabitEthernet0/1", ip: "10.0.0.1", mask: "255.255.255.252",
status: "up", protocol: "up", mac: "0050.7966.6801",
bandwidth: "1000000 Kbit", mtu: 1500,
rxPackets: 52104, txPackets: 48891, rxBytes: 4182934, txBytes: 3901284,
},
{
name: "GigabitEthernet0/2", ip: "unassigned", mask: "",
status: "administratively down", protocol: "down", mac: "0050.7966.6802",
bandwidth: "1000000 Kbit", mtu: 1500, shutdown: true,
},
{
name: "Vlan1", ip: "172.16.0.1", mask: "255.255.0.0",
status: "up", protocol: "up", mac: "0050.7966.6810",
bandwidth: "1000000 Kbit", mtu: 1500,
rxPackets: 8421, txPackets: 7103, rxBytes: 512384, txBytes: 423901,
},
}
}
// prompt returns the IOS prompt string for the current mode.
func (s *iosState) prompt() string {
switch s.mode {
case modeUserExec:
return fmt.Sprintf("%s>", s.hostname)
case modePrivilegedExec:
return fmt.Sprintf("%s#", s.hostname)
case modeGlobalConfig:
return fmt.Sprintf("%s(config)#", s.hostname)
case modeInterfaceConfig:
return fmt.Sprintf("%s(config-if)#", s.hostname)
default:
return fmt.Sprintf("%s>", s.hostname)
}
}
// findInterface returns a pointer to the interface with the given name, or nil.
func (s *iosState) findInterface(name string) *ifaceInfo {
for i := range s.interfaces {
if s.interfaces[i].name == name {
return &s.interfaces[i]
}
}
return nil
}

View File

@@ -6,7 +6,7 @@ import (
"sync"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// EventRecorder buffers I/O events in memory and periodically flushes them to

View File

@@ -6,7 +6,7 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func TestEventRecorderFlush(t *testing.T) {
@@ -14,7 +14,7 @@ func TestEventRecorderFlush(t *testing.T) {
ctx := context.Background()
// Create a session so events have a valid session ID.
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
@@ -55,7 +55,7 @@ func TestEventRecorderPeriodicFlush(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}

View File

@@ -8,7 +8,7 @@ import (
"strings"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
@@ -69,6 +69,9 @@ func (f *FridgeShell) Handle(ctx context.Context, sess *shell.SessionContext, rw
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("fridge")
}
if result.exit {
return nil

View File

@@ -8,8 +8,8 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {
@@ -22,7 +22,7 @@ func (r *rwCloser) Close() error { return nil }
func runShell(t *testing.T, commands string) string {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
sess := &shell.SessionContext{
SessionID: sessID,
@@ -205,7 +205,7 @@ func TestLogoutCommand(t *testing.T) {
func TestSessionLogs(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
sess := &shell.SessionContext{
SessionID: sessID,

View File

@@ -0,0 +1,123 @@
package psql
import (
"fmt"
"strings"
"time"
)
// commandResult holds the output of a command and whether the session should end.
type commandResult struct {
output string
exit bool
}
// dispatchBackslash handles psql backslash meta-commands.
func dispatchBackslash(cmd, dbName string) commandResult {
// Normalize: trim spaces after the backslash command word.
parts := strings.Fields(cmd)
if len(parts) == 0 {
return commandResult{output: "Invalid command \\. Try \\? for help."}
}
verb := parts[0] // e.g. `\q`, `\dt`, `\d`
args := parts[1:]
switch verb {
case `\q`:
return commandResult{exit: true}
case `\dt`:
return commandResult{output: listTables()}
case `\d`:
if len(args) == 0 {
return commandResult{output: listTables()}
}
return commandResult{output: describeTable(args[0])}
case `\l`:
return commandResult{output: listDatabases()}
case `\du`:
return commandResult{output: listRoles()}
case `\conninfo`:
return commandResult{output: connInfo(dbName)}
case `\?`:
return commandResult{output: backslashHelp()}
case `\h`:
return commandResult{output: sqlHelp()}
default:
return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
}
}
// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
func dispatchSQL(sql, dbName, pgVersion string) commandResult {
// Strip trailing semicolon and whitespace for matching.
trimmed := strings.TrimRight(sql, "; \t")
trimmed = strings.TrimSpace(trimmed)
upper := strings.ToUpper(trimmed)
switch {
case upper == "SELECT VERSION()":
ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
return commandResult{output: formatSingleValue("version", ver)}
case upper == "SELECT CURRENT_DATABASE()":
return commandResult{output: formatSingleValue("current_database", dbName)}
case upper == "SELECT CURRENT_USER":
return commandResult{output: formatSingleValue("current_user", "postgres")}
case upper == "SELECT NOW()":
now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
return commandResult{output: formatSingleValue("now", now)}
case upper == "SELECT 1":
return commandResult{output: formatSingleValue("?column?", "1")}
case strings.HasPrefix(upper, "INSERT"):
return commandResult{output: "INSERT 0 1"}
case strings.HasPrefix(upper, "UPDATE"):
return commandResult{output: "UPDATE 1"}
case strings.HasPrefix(upper, "DELETE"):
return commandResult{output: "DELETE 1"}
case strings.HasPrefix(upper, "CREATE TABLE"):
return commandResult{output: "CREATE TABLE"}
case strings.HasPrefix(upper, "CREATE DATABASE"):
return commandResult{output: "CREATE DATABASE"}
case strings.HasPrefix(upper, "DROP TABLE"):
return commandResult{output: "DROP TABLE"}
case strings.HasPrefix(upper, "ALTER TABLE"):
return commandResult{output: "ALTER TABLE"}
case upper == "BEGIN":
return commandResult{output: "BEGIN"}
case upper == "COMMIT":
return commandResult{output: "COMMIT"}
case upper == "ROLLBACK":
return commandResult{output: "ROLLBACK"}
case upper == "SHOW SERVER_VERSION":
return commandResult{output: formatSingleValue("server_version", pgVersion)}
case upper == "SHOW SEARCH_PATH":
return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
case strings.HasPrefix(upper, "SET "):
return commandResult{output: "SET"}
default:
// Extract the first token for the error message.
firstToken := strings.Fields(trimmed)
token := trimmed
if len(firstToken) > 0 {
token = firstToken[0]
}
return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
}
}
// formatSingleValue formats a single-row, single-column psql result.
func formatSingleValue(colName, value string) string {
width := max(len(colName), len(value))
var b strings.Builder
// Header
fmt.Fprintf(&b, " %-*s \n", width, colName)
// Separator
b.WriteString(strings.Repeat("-", width+2))
b.WriteString("\n")
// Value
fmt.Fprintf(&b, " %-*s\n", width, value)
// Row count
b.WriteString("(1 row)")
return b.String()
}

View File

@@ -0,0 +1,155 @@
package psql
import "fmt"
func startupBanner(version string) string {
return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version)
}
func listTables() string {
return ` List of relations
Schema | Name | Type | Owner
--------+---------------+-------+----------
public | audit_log | table | postgres
public | credentials | table | postgres
public | sessions | table | postgres
public | users | table | postgres
(4 rows)`
}
func listDatabases() string {
return ` List of databases
Name | Owner | Encoding | Collate | Ctype | Access privileges
-----------+----------+----------+-------------+-------------+-----------------------
app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
| | | | | postgres=CTc/postgres
template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
| | | | | postgres=CTc/postgres
(4 rows)`
}
func listRoles() string {
return ` List of roles
Role name | Attributes | Member of
-----------+------------------------------------------------------------+-----------
app_user | | {}
postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {}
readonly | Cannot login | {}`
}
func describeTable(name string) string {
switch name {
case "users":
return ` Table "public.users"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+-----------------------------------
id | integer | | not null | nextval('users_id_seq'::regclass)
username | character varying(255) | | not null |
email | character varying(255) | | not null |
password | character varying(255) | | not null |
created_at | timestamp without time zone | | | now()
updated_at | timestamp without time zone | | | now()
Indexes:
"users_pkey" PRIMARY KEY, btree (id)
"users_email_key" UNIQUE, btree (email)
"users_username_key" UNIQUE, btree (username)`
case "sessions":
return ` Table "public.sessions"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+--------------------------------------
id | integer | | not null | nextval('sessions_id_seq'::regclass)
user_id | integer | | |
token | character varying(255) | | not null |
ip_address | inet | | |
created_at | timestamp without time zone | | | now()
expires_at | timestamp without time zone | | not null |
Indexes:
"sessions_pkey" PRIMARY KEY, btree (id)
"sessions_token_key" UNIQUE, btree (token)
Foreign-key constraints:
"sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
case "credentials":
return ` Table "public.credentials"
Column | Type | Collation | Nullable | Default
-----------+-----------------------------+-----------+----------+-----------------------------------------
id | integer | | not null | nextval('credentials_id_seq'::regclass)
user_id | integer | | |
type | character varying(50) | | not null |
value | text | | not null |
created_at| timestamp without time zone | | | now()
Indexes:
"credentials_pkey" PRIMARY KEY, btree (id)
Foreign-key constraints:
"credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
case "audit_log":
return ` Table "public.audit_log"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+---------------------------------------
id | integer | | not null | nextval('audit_log_id_seq'::regclass)
user_id | integer | | |
action | character varying(100) | | not null |
details | text | | |
ip_address | inet | | |
created_at | timestamp without time zone | | | now()
Indexes:
"audit_log_pkey" PRIMARY KEY, btree (id)
Foreign-key constraints:
"audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
default:
return fmt.Sprintf("Did not find any relation named \"%s\".", name)
}
}
func connInfo(dbName string) string {
return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName)
}
func backslashHelp() string {
return `General
\copyright show PostgreSQL usage and distribution terms
\crosstabview [COLUMNS] execute query and display result in crosstab
\errverbose show most recent error message at maximum verbosity
\g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe)
\gdesc describe result of query, without executing it
\gexec execute query, then execute each value in its result
\gset [PREFIX] execute query and store result in psql variables
\gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode
\q quit psql
\watch [SEC] execute query every SEC seconds
Informational
(options: S = show system objects, + = additional detail)
\d[S+] list tables, views, and sequences
\d[S+] NAME describe table, view, sequence, or index
\da[S] [PATTERN] list aggregates
\dA[+] [PATTERN] list access methods
\dt[S+] [PATTERN] list tables
\du[S+] [PATTERN] list roles
\l[+] [PATTERN] list databases`
}
func sqlHelp() string {
return `Available help:
ABORT CREATE LANGUAGE
ALTER AGGREGATE CREATE MATERIALIZED VIEW
ALTER COLLATION CREATE OPERATOR
ALTER CONVERSION CREATE POLICY
ALTER DATABASE CREATE PROCEDURE
ALTER DEFAULT PRIVILEGES CREATE PUBLICATION
ALTER DOMAIN CREATE ROLE
ALTER EVENT TRIGGER CREATE RULE
ALTER EXTENSION CREATE SCHEMA
ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE
ALTER FOREIGN TABLE CREATE SERVER
ALTER FUNCTION CREATE STATISTICS
ALTER GROUP CREATE SUBSCRIPTION
ALTER INDEX CREATE TABLE
ALTER LANGUAGE CREATE TABLESPACE
BEGIN DELETE
COMMIT DROP TABLE
CREATE DATABASE INSERT
CREATE INDEX ROLLBACK
SELECT UPDATE`
}

137
internal/shell/psql/psql.go Normal file
View File

@@ -0,0 +1,137 @@
package psql
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// PsqlShell emulates a PostgreSQL psql interactive terminal.
type PsqlShell struct{}
// NewPsqlShell returns a new PsqlShell instance.
func NewPsqlShell() *PsqlShell {
return &PsqlShell{}
}
func (p *PsqlShell) Name() string { return "psql" }
func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" }
func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
dbName := configString(sess.ShellConfig, "db_name", "postgres")
pgVersion := configString(sess.ShellConfig, "pg_version", "15.4")
// Print startup banner.
fmt.Fprint(rw, startupBanner(pgVersion))
var sqlBuf []string // accumulates multi-line SQL
for {
prompt := buildPrompt(dbName, len(sqlBuf) > 0)
if _, err := fmt.Fprint(rw, prompt); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
// Empty line in non-buffering state: just re-prompt.
if trimmed == "" && len(sqlBuf) == 0 {
continue
}
// Backslash commands dispatch immediately (even mid-buffer they cancel the buffer).
if strings.HasPrefix(trimmed, `\`) {
sqlBuf = nil // discard any partial SQL
result := dispatchBackslash(trimmed, dbName)
if result.output != "" {
output := strings.ReplaceAll(result.output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("psql")
}
if result.exit {
return nil
}
continue
}
// Accumulate SQL lines.
sqlBuf = append(sqlBuf, line)
// Check if the statement is terminated by a semicolon.
if !strings.HasSuffix(strings.TrimSpace(line), ";") {
continue
}
// Full statement ready — join and dispatch.
fullSQL := strings.Join(sqlBuf, " ")
sqlBuf = nil
result := dispatchSQL(fullSQL, dbName, pgVersion)
if result.output != "" {
output := strings.ReplaceAll(result.output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("psql")
}
if result.exit {
return nil
}
}
}
// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL.
func buildPrompt(dbName string, continuation bool) string {
if continuation {
return dbName + "-# "
}
return dbName + "=# "
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,330 @@
package psql
import (
"strings"
"testing"
)
// --- Prompt tests ---
func TestBuildPromptNormal(t *testing.T) {
got := buildPrompt("postgres", false)
if got != "postgres=# " {
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
}
}
func TestBuildPromptContinuation(t *testing.T) {
got := buildPrompt("postgres", true)
if got != "postgres-# " {
t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
}
}
func TestBuildPromptCustomDB(t *testing.T) {
got := buildPrompt("mydb", false)
if got != "mydb=# " {
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
}
}
// --- Backslash command dispatch tests ---
func TestBackslashQuit(t *testing.T) {
result := dispatchBackslash(`\q`, "postgres")
if !result.exit {
t.Error("\\q should set exit=true")
}
}
func TestBackslashListTables(t *testing.T) {
result := dispatchBackslash(`\dt`, "postgres")
if !strings.Contains(result.output, "users") {
t.Error("\\dt should list tables including 'users'")
}
if !strings.Contains(result.output, "sessions") {
t.Error("\\dt should list tables including 'sessions'")
}
}
func TestBackslashDescribeTable(t *testing.T) {
result := dispatchBackslash(`\d users`, "postgres")
if !strings.Contains(result.output, "username") {
t.Error("\\d users should describe users table with 'username' column")
}
if !strings.Contains(result.output, "PRIMARY KEY") {
t.Error("\\d users should include index info")
}
}
func TestBackslashDescribeUnknownTable(t *testing.T) {
result := dispatchBackslash(`\d nonexistent`, "postgres")
if !strings.Contains(result.output, "Did not find") {
t.Error("\\d nonexistent should return not found message")
}
}
func TestBackslashListDatabases(t *testing.T) {
result := dispatchBackslash(`\l`, "postgres")
if !strings.Contains(result.output, "postgres") {
t.Error("\\l should list databases including 'postgres'")
}
if !strings.Contains(result.output, "template0") {
t.Error("\\l should list databases including 'template0'")
}
}
func TestBackslashListRoles(t *testing.T) {
result := dispatchBackslash(`\du`, "postgres")
if !strings.Contains(result.output, "postgres") {
t.Error("\\du should list roles including 'postgres'")
}
if !strings.Contains(result.output, "Superuser") {
t.Error("\\du should show Superuser attribute for postgres")
}
}
func TestBackslashConnInfo(t *testing.T) {
result := dispatchBackslash(`\conninfo`, "mydb")
if !strings.Contains(result.output, "mydb") {
t.Error("\\conninfo should include database name")
}
if !strings.Contains(result.output, "5432") {
t.Error("\\conninfo should include port")
}
}
func TestBackslashHelp(t *testing.T) {
result := dispatchBackslash(`\?`, "postgres")
if !strings.Contains(result.output, `\q`) {
t.Error("\\? should include \\q in help output")
}
}
func TestBackslashSQLHelp(t *testing.T) {
result := dispatchBackslash(`\h`, "postgres")
if !strings.Contains(result.output, "SELECT") {
t.Error("\\h should include SQL commands like SELECT")
}
}
func TestBackslashUnknown(t *testing.T) {
result := dispatchBackslash(`\xyz`, "postgres")
if !strings.Contains(result.output, "Invalid command") {
t.Error("unknown backslash command should return error")
}
}
// --- SQL dispatch tests ---
func TestSQLSelectVersion(t *testing.T) {
result := dispatchSQL("SELECT version();", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("SELECT version() should contain pg version")
}
if !strings.Contains(result.output, "(1 row)") {
t.Error("SELECT version() should show row count")
}
}
func TestSQLSelectCurrentDatabase(t *testing.T) {
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
if !strings.Contains(result.output, "mydb") {
t.Error("SELECT current_database() should return db name")
}
}
func TestSQLSelectCurrentUser(t *testing.T) {
result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
if !strings.Contains(result.output, "postgres") {
t.Error("SELECT current_user should return postgres")
}
}
func TestSQLSelectNow(t *testing.T) {
result := dispatchSQL("SELECT now();", "postgres", "15.4")
if !strings.Contains(result.output, "(1 row)") {
t.Error("SELECT now() should show row count")
}
}
func TestSQLSelectOne(t *testing.T) {
result := dispatchSQL("SELECT 1;", "postgres", "15.4")
if !strings.Contains(result.output, "1") {
t.Error("SELECT 1 should return 1")
}
}
func TestSQLInsert(t *testing.T) {
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
if result.output != "INSERT 0 1" {
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
}
}
func TestSQLUpdate(t *testing.T) {
result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
if result.output != "UPDATE 1" {
t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
}
}
func TestSQLDelete(t *testing.T) {
result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
if result.output != "DELETE 1" {
t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
}
}
func TestSQLCreateTable(t *testing.T) {
result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
if result.output != "CREATE TABLE" {
t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
}
}
func TestSQLCreateDatabase(t *testing.T) {
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
if result.output != "CREATE DATABASE" {
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
}
}
func TestSQLDropTable(t *testing.T) {
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
if result.output != "DROP TABLE" {
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
}
}
func TestSQLAlterTable(t *testing.T) {
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
if result.output != "ALTER TABLE" {
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
}
}
func TestSQLBeginCommitRollback(t *testing.T) {
tests := []struct {
sql string
want string
}{
{"BEGIN;", "BEGIN"},
{"COMMIT;", "COMMIT"},
{"ROLLBACK;", "ROLLBACK"},
}
for _, tt := range tests {
result := dispatchSQL(tt.sql, "postgres", "15.4")
if result.output != tt.want {
t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
}
}
}
func TestSQLShowServerVersion(t *testing.T) {
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("SHOW server_version should contain version")
}
}
func TestSQLShowSearchPath(t *testing.T) {
result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
if !strings.Contains(result.output, "public") {
t.Error("SHOW search_path should contain public")
}
}
func TestSQLSet(t *testing.T) {
result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
if result.output != "SET" {
t.Errorf("SET output = %q, want %q", result.output, "SET")
}
}
func TestSQLUnrecognized(t *testing.T) {
result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
if !strings.Contains(result.output, "ERROR") {
t.Error("unrecognized SQL should return error")
}
if !strings.Contains(result.output, "FOOBAR") {
t.Error("error should reference the offending token")
}
}
// --- Case insensitivity ---
func TestSQLCaseInsensitive(t *testing.T) {
result := dispatchSQL("select version();", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("select version() (lowercase) should work")
}
result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
if !strings.Contains(result.output, "mydb") {
t.Error("mixed case SELECT should work")
}
}
// --- Startup banner ---
func TestStartupBanner(t *testing.T) {
banner := startupBanner("15.4")
if !strings.Contains(banner, "psql (15.4)") {
t.Errorf("banner should contain version, got: %s", banner)
}
if !strings.Contains(banner, "help") {
t.Error("banner should mention help")
}
}
// --- configString ---
func TestConfigString(t *testing.T) {
cfg := map[string]any{"db_name": "mydb"}
if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
t.Errorf("configString() = %q, want %q", got, "mydb")
}
if got := configString(cfg, "missing", "default"); got != "default" {
t.Errorf("configString() for missing = %q, want %q", got, "default")
}
if got := configString(nil, "key", "default"); got != "default" {
t.Errorf("configString(nil) = %q, want %q", got, "default")
}
}
// --- Shell metadata ---
func TestShellNameAndDescription(t *testing.T) {
s := NewPsqlShell()
if s.Name() != "psql" {
t.Errorf("Name() = %q, want %q", s.Name(), "psql")
}
if s.Description() == "" {
t.Error("Description() should not be empty")
}
}
// --- formatSingleValue ---
func TestFormatSingleValue(t *testing.T) {
out := formatSingleValue("?column?", "1")
if !strings.Contains(out, "?column?") {
t.Error("should contain column name")
}
if !strings.Contains(out, "1") {
t.Error("should contain value")
}
if !strings.Contains(out, "(1 row)") {
t.Error("should contain row count")
}
}
// --- \d with no args ---
func TestBackslashDescribeNoArgs(t *testing.T) {
result := dispatchBackslash(`\d`, "postgres")
if !strings.Contains(result.output, "users") {
t.Error("\\d with no args should list tables")
}
}

View File

@@ -0,0 +1,463 @@
package roomba
import (
"context"
"errors"
"fmt"
"io"
"slices"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// RoombaShell emulates an iRobot Roomba vacuum robot interface.
type RoombaShell struct{}
// NewRoombaShell returns a new RoombaShell instance.
func NewRoombaShell() *RoombaShell {
return &RoombaShell{}
}
func (r *RoombaShell) Name() string { return "roomba" }
func (r *RoombaShell) Description() string { return "iRobot Roomba shell emulator" }
func (r *RoombaShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
state := newRoombaState()
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
fmt.Fprint(rw, banner)
for {
if _, err := fmt.Fprint(rw, "RoombaOS> "); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
fmt.Fprint(rw, "logout\r\n")
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
result := state.dispatch(trimmed)
var output string
if result.output != "" {
output = result.output
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("roomba")
}
if result.exit {
return nil
}
}
}
func bootBanner() string {
return `
____ _ ___ ____
| _ \ ___ ___ _ __ ___ | |__ __ _ / _ \/ ___|
| |_) / _ \ / _ \| '_ ` + "`" + ` _ \| '_ \ / _` + "`" + ` | | | \___ \
| _ < (_) | (_) | | | | | | |_) | (_| | |_| |___) |
|_| \_\___/ \___/|_| |_| |_|_.__/ \__,_|\___/|____/
iRobot Roomba j7+ | RoombaOS v4.3.7
Serial: RMB-7291-J7P-0482 | Firmware: 4.3.7-stable
Battery: 73% | WiFi: Connected (SmartHome-5G)
Type 'help' for available commands.
`
}
type room struct {
name string
areaSqFt int
lastCleaned time.Time
}
type scheduleEntry struct {
day string
time string
}
type historyEntry struct {
timestamp time.Time
room string
duration string
note string
}
type roombaState struct {
battery int
dustbin int
status string
rooms []room
schedule []scheduleEntry
cleanHistory []historyEntry
}
type commandResult struct {
output string
exit bool
}
func newRoombaState() *roombaState {
now := time.Now()
return &roombaState{
battery: 73,
dustbin: 61,
status: "Docked",
rooms: []room{
{"Kitchen", 180, now.Add(-2 * time.Hour)},
{"Living Room", 320, now.Add(-5 * time.Hour)},
{"Bedroom", 200, now.Add(-26 * time.Hour)},
{"Hallway", 60, now.Add(-5 * time.Hour)},
{"Bathroom", 75, now.Add(-50 * time.Hour)},
{"Cat's Room", 110, now.Add(-3 * time.Hour)},
},
schedule: []scheduleEntry{
{"Monday", "09:00"},
{"Wednesday", "09:00"},
{"Friday", "09:00"},
{"Saturday", "14:00"},
},
cleanHistory: []historyEntry{
{now.Add(-2 * time.Hour), "Kitchen", "23 min", "Completed normally"},
{now.Add(-3 * time.Hour), "Cat's Room", "18 min", "Cat detected - rerouting"},
{now.Add(-5 * time.Hour), "Living Room", "34 min", "Encountered sock near couch"},
{now.Add(-5*time.Hour - 40*time.Minute), "Hallway", "8 min", "Completed normally"},
{now.Add(-26 * time.Hour), "Bedroom", "27 min", "Tangled in phone charger"},
{now.Add(-50 * time.Hour), "Bathroom", "14 min", "Unidentified sticky substance detected"},
},
}
}
func (s *roombaState) dispatch(input string) commandResult {
parts := strings.Fields(input)
if len(parts) == 0 {
return commandResult{}
}
cmd := strings.ToLower(parts[0])
args := parts[1:]
switch cmd {
case "help":
return s.cmdHelp()
case "status":
return s.cmdStatus()
case "clean":
return s.cmdClean(args)
case "dock":
return s.cmdDock()
case "map":
return s.cmdMap()
case "schedule":
return s.cmdSchedule(args)
case "history":
return s.cmdHistory()
case "diagnostics":
return s.cmdDiagnostics()
case "alerts":
return s.cmdAlerts()
case "reboot":
return s.cmdReboot()
case "exit", "logout":
return commandResult{output: "Disconnecting from RoombaOS. Happy cleaning!", exit: true}
default:
return commandResult{output: fmt.Sprintf("RoombaOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
}
}
func (s *roombaState) cmdHelp() commandResult {
help := `Available commands:
help - Show this help message
status - Show robot status
clean - Start full cleaning job
clean room <name> - Clean a specific room
dock - Return to dock
map - Show floor plan and room list
schedule - List cleaning schedule
schedule add <day> <time> - Add scheduled cleaning
schedule remove <day> - Remove scheduled cleaning
history - Show recent cleaning history
diagnostics - Run system diagnostics
alerts - Show active alerts
reboot - Reboot RoombaOS
exit / logout - Disconnect`
return commandResult{output: help}
}
func (s *roombaState) cmdStatus() commandResult {
var b strings.Builder
b.WriteString("=== RoombaOS System Status ===\n")
b.WriteString("Model: iRobot Roomba j7+\n")
b.WriteString(fmt.Sprintf("Status: %s\n", s.status))
b.WriteString(fmt.Sprintf("Battery: %d%%\n", s.battery))
b.WriteString(fmt.Sprintf("Dustbin: %d%% full\n", s.dustbin))
b.WriteString("Side brush: OK (142 hrs)\n")
b.WriteString("Main brush: OK (98 hrs)\n")
b.WriteString("\n")
b.WriteString("WiFi: Connected (SmartHome-5G)\n")
b.WriteString("Signal: -38 dBm\n")
b.WriteString("Alexa: Linked\n")
b.WriteString("Google Home: Linked\n")
b.WriteString("iRobot Home App: Connected\n")
b.WriteString("\n")
b.WriteString("Firmware: v4.3.7-stable\n")
b.WriteString("LIDAR: Operational\n")
b.WriteString("Clean Area Total: 12,847 sq ft (lifetime)")
return commandResult{output: b.String()}
}
func (s *roombaState) cmdClean(args []string) commandResult {
if s.status == "Cleaning" {
return commandResult{output: "Already cleaning. Use 'dock' to cancel and return to dock."}
}
if len(args) >= 2 && strings.ToLower(args[0]) == "room" {
roomName := strings.Join(args[1:], " ")
for _, r := range s.rooms {
if strings.EqualFold(r.name, roomName) {
s.status = "Cleaning"
return commandResult{output: fmt.Sprintf(
"Starting targeted clean: %s (%d sq ft)\nEstimated time: %d minutes\nUndocking... navigating to %s...",
r.name, r.areaSqFt, r.areaSqFt/8, r.name,
)}
}
}
return commandResult{output: fmt.Sprintf("Room '%s' not found. Use 'map' to see available rooms.", roomName)}
}
if len(args) > 0 {
return commandResult{output: "Usage: clean [room <name>]"}
}
s.status = "Cleaning"
var totalArea int
for _, r := range s.rooms {
totalArea += r.areaSqFt
}
return commandResult{output: fmt.Sprintf(
"Starting full house clean\nTotal area: %d sq ft across %d rooms\nEstimated time: %d minutes\nUndocking... beginning clean cycle...",
totalArea, len(s.rooms), totalArea/8,
)}
}
func (s *roombaState) cmdDock() commandResult {
if s.status == "Docked" {
return commandResult{output: "Already docked."}
}
if s.status == "Returning to dock" {
return commandResult{output: "Already returning to dock."}
}
s.status = "Returning to dock"
return commandResult{output: "Cancelling current job. Returning to dock...\nEstimated arrival: 2 minutes"}
}
func (s *roombaState) cmdMap() commandResult {
var b strings.Builder
b.WriteString("=== Floor Plan ===\n\n")
b.WriteString(" +------------+----------+\n")
b.WriteString(" | | |\n")
b.WriteString(" | Kitchen | Bathroom |\n")
b.WriteString(" | 180sqft | 75sqft |\n")
b.WriteString(" | | |\n")
b.WriteString(" +------+-----+----+-----+\n")
b.WriteString(" | | | |\n")
b.WriteString(" | Hall | Living | Cat |\n")
b.WriteString(" | 60sf | Room | Rm |\n")
b.WriteString(" | | 320sqft |110sf|\n")
b.WriteString(" +------+ +-----+\n")
b.WriteString(" | | |\n")
b.WriteString(" | Bed +----------+\n")
b.WriteString(" | room | [DOCK]\n")
b.WriteString(" |200sf |\n")
b.WriteString(" +------+\n")
b.WriteString("\nRoom Details:\n")
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "ROOM", "AREA", "LAST CLEANED"))
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "----", "----", "------------"))
for _, r := range s.rooms {
ago := time.Since(r.lastCleaned).Truncate(time.Minute)
b.WriteString(fmt.Sprintf(" %-15s %-10s %s ago\n", r.name, fmt.Sprintf("%d sqft", r.areaSqFt), formatDuration(ago)))
}
return commandResult{output: b.String()}
}
func (s *roombaState) cmdSchedule(args []string) commandResult {
if len(args) == 0 {
return s.scheduleList()
}
sub := strings.ToLower(args[0])
switch sub {
case "add":
if len(args) < 3 {
return commandResult{output: "Usage: schedule add <day> <time>\nExample: schedule add Tuesday 10:00"}
}
return s.scheduleAdd(args[1], args[2])
case "remove":
if len(args) < 2 {
return commandResult{output: "Usage: schedule remove <day>"}
}
return s.scheduleRemove(args[1])
default:
return commandResult{output: fmt.Sprintf("Unknown schedule subcommand '%s'. Try: add, remove", sub)}
}
}
func (s *roombaState) scheduleList() commandResult {
if len(s.schedule) == 0 {
return commandResult{output: "No cleaning schedule configured."}
}
var b strings.Builder
b.WriteString("=== Cleaning Schedule ===\n")
b.WriteString(fmt.Sprintf(" %-12s %s\n", "DAY", "TIME"))
b.WriteString(fmt.Sprintf(" %-12s %s\n", "---", "----"))
for _, e := range s.schedule {
b.WriteString(fmt.Sprintf(" %-12s %s\n", e.day, e.time))
}
b.WriteString(fmt.Sprintf("\n%d scheduled cleaning(s)", len(s.schedule)))
return commandResult{output: b.String()}
}
func (s *roombaState) scheduleAdd(day, t string) commandResult {
day = capitalizeFirst(strings.ToLower(day))
validDays := []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
if !slices.Contains(validDays, day) {
return commandResult{output: fmt.Sprintf("Invalid day '%s'. Use a day of the week (e.g. Monday, Tuesday).", day)}
}
for _, e := range s.schedule {
if strings.EqualFold(e.day, day) {
return commandResult{output: fmt.Sprintf("Schedule for %s already exists. Remove it first.", day)}
}
}
s.schedule = append(s.schedule, scheduleEntry{day: day, time: t})
return commandResult{output: fmt.Sprintf("Scheduled cleaning added: %s at %s", day, t)}
}
func (s *roombaState) scheduleRemove(day string) commandResult {
day = capitalizeFirst(strings.ToLower(day))
for i, e := range s.schedule {
if strings.EqualFold(e.day, day) {
s.schedule = append(s.schedule[:i], s.schedule[i+1:]...)
return commandResult{output: fmt.Sprintf("Removed schedule for %s.", day)}
}
}
return commandResult{output: fmt.Sprintf("No schedule found for '%s'.", day)}
}
func (s *roombaState) cmdHistory() commandResult {
if len(s.cleanHistory) == 0 {
return commandResult{output: "No cleaning history."}
}
var b strings.Builder
b.WriteString("=== Cleaning History ===\n")
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "TIME", "ROOM", "DURATION", "NOTE"))
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "----", "----", "--------", "----"))
for _, h := range s.cleanHistory {
ts := h.timestamp.Format("2006-01-02 15:04")
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", ts, h.room, h.duration, h.note))
}
b.WriteString(fmt.Sprintf("\n%d session(s) recorded", len(s.cleanHistory)))
return commandResult{output: b.String()}
}
func (s *roombaState) cmdDiagnostics() commandResult {
diag := `Running RoombaOS diagnostics...
[1/8] Cliff sensors........... OK
[2/8] Bumper sensor........... OK
[3/8] Side brush motor........ OK (142 hrs until replacement)
[4/8] Main brush motor........ OK (98 hrs until replacement)
[5/8] Wheel motors............ OK (L: 1204 hrs, R: 1204 hrs)
[6/8] LIDAR module............ OK (last calibrated 3 days ago)
[7/8] Dustbin sensor.......... OK
[8/8] WiFi module............. OK (signal: -38 dBm)
ALL SYSTEMS NOMINAL`
return commandResult{output: diag}
}
func (s *roombaState) cmdAlerts() commandResult {
var alerts []string
if s.dustbin >= 60 {
alerts = append(alerts, fmt.Sprintf("WARNING: Dustbin %d%% full - consider emptying", s.dustbin))
}
alerts = append(alerts,
"WARNING: Side brush replacement due in 12 hours",
"INFO: Unidentified sticky substance detected in Kitchen",
"INFO: Cat frequently blocking cleaning path in Cat's Room",
"INFO: Firmware update available: v4.4.0-beta",
"INFO: Filter replacement recommended in 14 days",
)
var b strings.Builder
b.WriteString("=== Active Alerts ===\n")
for _, a := range alerts {
b.WriteString(a + "\n")
}
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
return commandResult{output: b.String()}
}
func (s *roombaState) cmdReboot() commandResult {
reboot := `RoombaOS is rebooting...
Stopping navigation engine..... done
Saving room map data........... done
Flushing cleaning logs......... done
Disconnecting from WiFi........ done
Rebooting now. Goodbye!`
return commandResult{output: reboot, exit: true}
}
func capitalizeFirst(s string) string {
if s == "" {
return s
}
return strings.ToUpper(s[:1]) + s[1:]
}
func formatDuration(d time.Duration) string {
hours := int(d.Hours())
minutes := int(d.Minutes()) % 60
if hours >= 24 {
days := hours / 24
hours %= 24
return fmt.Sprintf("%dd %dh", days, hours)
}
if hours > 0 {
return fmt.Sprintf("%dh %dm", hours, minutes)
}
return fmt.Sprintf("%dm", minutes)
}

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"io"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// Shell is the interface that all honeypot shell implementations must satisfy.
@@ -24,6 +24,7 @@ type SessionContext struct {
Store storage.Store
ShellConfig map[string]any
CommonConfig ShellCommonConfig
OnCommand func(shell string) // called when a command is executed; may be nil
}
// ShellCommonConfig holds settings shared across all shell types.

View File

@@ -0,0 +1,101 @@
package tetris
import "github.com/charmbracelet/lipgloss"
// pieceType identifies a tetromino (06).
type pieceType int
const (
pieceI pieceType = iota
pieceO
pieceT
pieceS
pieceZ
pieceJ
pieceL
)
const numPieceTypes = 7
// Standard Tetris colors.
var pieceColors = [numPieceTypes]lipgloss.Color{
lipgloss.Color("#00FFFF"), // I — cyan
lipgloss.Color("#FFFF00"), // O — yellow
lipgloss.Color("#AA00FF"), // T — purple
lipgloss.Color("#00FF00"), // S — green
lipgloss.Color("#FF0000"), // Z — red
lipgloss.Color("#0000FF"), // J — blue
lipgloss.Color("#FF8800"), // L — orange
}
// Each piece has 4 rotations, each rotation is a list of (row, col) offsets
// relative to the piece origin.
type rotation [4][2]int
var pieces = [numPieceTypes][4]rotation{
// I
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
},
// O
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
},
// T
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{1, 1}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
},
// S
{
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
},
// Z
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
},
// J
{
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{2, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
},
// L
{
{[2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{2, 1}},
},
}
// spawnCol returns the starting column for a piece, centering it on the board.
func spawnCol(pt pieceType, rot int) int {
shape := pieces[pt][rot]
minC, maxC := shape[0][1], shape[0][1]
for _, off := range shape {
if off[1] < minC {
minC = off[1]
}
if off[1] > maxC {
maxC = off[1]
}
}
width := maxC - minC + 1
return (boardCols - width) / 2
}

View File

@@ -0,0 +1,210 @@
package tetris
import "math/rand/v2"
const (
boardRows = 20
boardCols = 10
)
// cell represents a single board cell. Zero value is empty.
type cell struct {
filled bool
piece pieceType // which piece type filled this cell (for color)
}
// gameState holds all mutable state for a Tetris game.
type gameState struct {
board [boardRows][boardCols]cell
current pieceType
currentRot int
currentRow int
currentCol int
next pieceType
score int
level int
lines int
gameOver bool
}
// newGame creates a new game state, optionally starting at a given level.
func newGame(startLevel int) *gameState {
g := &gameState{
level: startLevel,
next: pieceType(rand.IntN(numPieceTypes)),
}
g.spawnPiece()
return g
}
// spawnPiece pulls the next piece and generates a new next.
func (g *gameState) spawnPiece() {
g.current = g.next
g.next = pieceType(rand.IntN(numPieceTypes))
g.currentRot = 0
g.currentRow = 0
g.currentCol = spawnCol(g.current, 0)
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
g.gameOver = true
}
}
// canPlace checks whether the piece fits at the given position.
func (g *gameState) canPlace(pt pieceType, rot, row, col int) bool {
shape := pieces[pt][rot]
for _, off := range shape {
r, c := row+off[0], col+off[1]
if r < 0 || r >= boardRows || c < 0 || c >= boardCols {
return false
}
if g.board[r][c].filled {
return false
}
}
return true
}
// moveLeft moves the current piece left if possible.
func (g *gameState) moveLeft() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol-1) {
g.currentCol--
return true
}
return false
}
// moveRight moves the current piece right if possible.
func (g *gameState) moveRight() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol+1) {
g.currentCol++
return true
}
return false
}
// moveDown moves the current piece down one row. Returns false if it cannot.
func (g *gameState) moveDown() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow+1, g.currentCol) {
g.currentRow++
return true
}
return false
}
// rotate rotates the current piece clockwise with wall kick attempts.
func (g *gameState) rotate() bool {
newRot := (g.currentRot + 1) % 4
// Try in-place first.
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol) {
g.currentRot = newRot
return true
}
// Wall kick: try +-1 column offset.
for _, offset := range []int{-1, 1} {
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
g.currentRot = newRot
g.currentCol += offset
return true
}
}
// I piece: try +-2.
if g.current == pieceI {
for _, offset := range []int{-2, 2} {
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
g.currentRot = newRot
g.currentCol += offset
return true
}
}
}
return false
}
// ghostRow returns the row where the piece would land.
func (g *gameState) ghostRow() int {
row := g.currentRow
for g.canPlace(g.current, g.currentRot, row+1, g.currentCol) {
row++
}
return row
}
// hardDrop drops the piece to the bottom and returns the number of rows dropped.
func (g *gameState) hardDrop() int {
ghost := g.ghostRow()
dropped := ghost - g.currentRow
g.currentRow = ghost
return dropped
}
// lockPiece writes the current piece into the board.
func (g *gameState) lockPiece() {
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := g.currentRow+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
g.board[r][c] = cell{filled: true, piece: g.current}
}
}
}
// clearLines removes completed rows and returns how many were cleared.
func (g *gameState) clearLines() int {
cleared := 0
for r := boardRows - 1; r >= 0; r-- {
full := true
for c := range boardCols {
if !g.board[r][c].filled {
full = false
break
}
}
if full {
cleared++
// Shift everything above down.
for rr := r; rr > 0; rr-- {
g.board[rr] = g.board[rr-1]
}
g.board[0] = [boardCols]cell{}
r++ // re-check this row since we shifted
}
}
return cleared
}
// NES-style scoring multipliers per lines cleared.
var lineScoreMultipliers = [5]int{0, 40, 100, 300, 1200}
// addScore updates score, lines, and level after clearing rows.
func (g *gameState) addScore(linesCleared int) {
if linesCleared > 0 && linesCleared <= 4 {
g.score += lineScoreMultipliers[linesCleared] * (g.level + 1)
}
g.lines += linesCleared
// Level up every 10 lines.
newLevel := g.lines / 10
if newLevel > g.level {
g.level = newLevel
}
}
// afterLock locks the piece, clears lines, scores, and spawns the next piece.
// Returns the number of lines cleared.
func (g *gameState) afterLock() int {
g.lockPiece()
cleared := g.clearLines()
g.addScore(cleared)
g.spawnPiece()
return cleared
}
// tickInterval returns the gravity interval in milliseconds for the current level.
func tickInterval(level int) int {
return max(800-level*60, 100)
}

View File

@@ -0,0 +1,331 @@
package tetris
import (
"context"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
type screen int
const (
screenTitle screen = iota
screenGame
screenGameOver
)
type tickMsg time.Time
type lockMsg time.Time
const lockDelay = 500 * time.Millisecond
type model struct {
sess *shell.SessionContext
difficulty string
screen screen
game *gameState
quitting bool
height int
keypresses int
locking bool // true when piece has landed and lock delay is active
}
func newModel(sess *shell.SessionContext, difficulty string) *model {
return &model{
sess: sess,
difficulty: difficulty,
screen: screenTitle,
}
}
func (m *model) Init() tea.Cmd {
return nil
}
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.quitting {
return m, tea.Quit
}
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.height = msg.Height
return m, nil
case tea.KeyMsg:
m.keypresses++
if msg.Type == tea.KeyCtrlC {
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.gameScore(), m.gameLevel(), m.gameLines(), m.keypresses), "SESSION ENDED"),
tea.Quit,
)
}
}
switch m.screen {
case screenTitle:
return m.updateTitle(msg)
case screenGame:
return m.updateGame(msg)
case screenGameOver:
return m.updateGameOver(msg)
}
return m, nil
}
func (m *model) View() string {
var content string
switch m.screen {
case screenTitle:
content = m.titleView()
case screenGame:
content = gameView(m.game)
case screenGameOver:
content = m.gameOverView()
}
return gameFrame(content, m.height)
}
// --- Title screen ---
func (m *model) titleView() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ████████╗███████╗████████╗██████╗ ██╗███████╗"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ╚══██╔══╝██╔════╝╚══██╔══╝██╔══██╗██║██╔════╝"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ █████╗ ██║ ██████╔╝██║███████╗"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ ██╔══╝ ██║ ██╔══██╗██║╚════██║"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ ███████╗ ██║ ██║ ██║██║███████║"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚══════╝"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" Press any key to start"))
b.WriteString("\n")
return b.String()
}
func (m *model) updateTitle(msg tea.Msg) (tea.Model, tea.Cmd) {
if _, ok := msg.(tea.KeyMsg); ok {
m.screen = screenGame
var startLevel int
if m.difficulty == "hard" {
startLevel = 5
}
m.game = newGame(startLevel)
return m, tea.Batch(
tea.ClearScreen,
m.scheduleTick(),
logAction(m.sess, "GAME START", fmt.Sprintf("difficulty=%s", m.difficulty)),
)
}
return m, nil
}
// --- Game screen ---
func (m *model) scheduleTick() tea.Cmd {
ms := tickInterval(m.game.level)
if m.difficulty == "easy" {
ms = max(1000-m.game.level*60, 150)
}
return tea.Tick(time.Duration(ms)*time.Millisecond, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func (m *model) scheduleLock() tea.Cmd {
return tea.Tick(lockDelay, func(t time.Time) tea.Msg {
return lockMsg(t)
})
}
// performLock locks the piece, clears lines, and returns commands for logging
// and scheduling the next tick. Returns nil if game over (goToGameOver is
// included in the returned batch).
func (m *model) performLock() tea.Cmd {
m.locking = false
cleared := m.game.afterLock()
if m.game.gameOver {
return tea.Batch(
logAction(m.sess, fmt.Sprintf("GAME OVER score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "GAME OVER"),
m.goToGameOver(),
)
}
var cmds []tea.Cmd
cmds = append(cmds, m.scheduleTick())
if cleared > 0 {
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LINES %d score=%d", cleared, m.game.score), fmt.Sprintf("total=%d", m.game.lines)))
prevLevel := (m.game.lines - cleared) / 10
if m.game.level > prevLevel {
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LEVEL UP %d", m.game.level), fmt.Sprintf("score=%d", m.game.score)))
}
}
return tea.Batch(cmds...)
}
func (m *model) updateGame(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case lockMsg:
if m.game.gameOver || !m.locking {
return m, nil
}
// Lock delay expired — lock the piece now.
return m, m.performLock()
case tickMsg:
if m.game.gameOver || m.locking {
return m, nil
}
if !m.game.moveDown() {
// Piece landed — start lock delay instead of locking immediately.
m.locking = true
return m, m.scheduleLock()
}
return m, m.scheduleTick()
case tea.KeyMsg:
if m.game.gameOver {
return m, nil
}
switch msg.String() {
case "left":
m.game.moveLeft()
// If piece can now drop further, cancel lock delay.
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case "right":
m.game.moveRight()
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case "down":
if m.game.moveDown() {
m.game.score++ // soft drop bonus
if m.locking {
m.locking = false
}
}
case "up", "z":
m.game.rotate()
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case " ":
m.locking = false
dropped := m.game.hardDrop()
m.game.score += dropped * 2
return m, m.performLock()
case "q":
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
tea.Quit,
)
}
return m, nil
}
return m, nil
}
// --- Game over screen ---
func (m *model) goToGameOver() tea.Cmd {
m.screen = screenGameOver
return tea.ClearScreen
}
func (m *model) gameOverView() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(titleStyle.Render(" GAME OVER"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Score: %s", formatScore(m.game.score))))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Level: %d", m.game.level)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Lines: %d", m.game.lines)))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" R - Play again"))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" Q - Quit"))
b.WriteString("\n")
return b.String()
}
func (m *model) updateGameOver(msg tea.Msg) (tea.Model, tea.Cmd) {
if keyMsg, ok := msg.(tea.KeyMsg); ok {
switch keyMsg.String() {
case "r":
startLevel := 0
if m.difficulty == "hard" {
startLevel = 5
}
m.game = newGame(startLevel)
m.screen = screenGame
m.keypresses = 0
return m, tea.Batch(
tea.ClearScreen,
m.scheduleTick(),
logAction(m.sess, "RESTART", fmt.Sprintf("difficulty=%s", m.difficulty)),
)
case "q":
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
tea.Quit,
)
}
}
return m, nil
}
// Helper methods for safe access when game may be nil.
func (m *model) gameScore() int {
if m.game != nil {
return m.game.score
}
return 0
}
func (m *model) gameLevel() int {
if m.game != nil {
return m.game.level
}
return 0
}
func (m *model) gameLines() int {
if m.game != nil {
return m.game.lines
}
return 0
}
// logAction returns a tea.Cmd that logs an action to the session store.
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
return func() tea.Msg {
if sess.Store != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
}
if sess.OnCommand != nil {
sess.OnCommand("tetris")
}
return nil
}
}

View File

@@ -0,0 +1,286 @@
package tetris
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
)
const termWidth = 80
var (
colorWhite = lipgloss.Color("#FFFFFF")
colorDim = lipgloss.Color("#555555")
colorBlack = lipgloss.Color("#000000")
colorGhost = lipgloss.Color("#333333")
)
var (
baseStyle = lipgloss.NewStyle().
Foreground(colorWhite).
Background(colorBlack)
dimStyle = lipgloss.NewStyle().
Foreground(colorDim).
Background(colorBlack)
titleStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("#00FFFF")).
Background(colorBlack).
Bold(true)
sidebarLabelStyle = lipgloss.NewStyle().
Foreground(colorDim).
Background(colorBlack)
sidebarValueStyle = lipgloss.NewStyle().
Foreground(colorWhite).
Background(colorBlack).
Bold(true)
)
// cellStyle returns a style for a filled cell of a given piece type.
func cellStyle(pt pieceType) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(pieceColors[pt]).
Background(colorBlack)
}
// ghostStyle returns a dimmed style for the ghost piece.
func ghostCellStyle() lipgloss.Style {
return lipgloss.NewStyle().
Foreground(colorGhost).
Background(colorBlack)
}
// renderBoard renders the board, current piece, and ghost piece as a string.
func renderBoard(g *gameState) string {
// Build a display grid that includes the current piece and ghost.
type displayCell struct {
filled bool
ghost bool
piece pieceType
}
var grid [boardRows][boardCols]displayCell
// Copy locked cells.
for r := range boardRows {
for c := range boardCols {
if g.board[r][c].filled {
grid[r][c] = displayCell{filled: true, piece: g.board[r][c].piece}
}
}
}
// Ghost piece.
ghostR := g.ghostRow()
if ghostR != g.currentRow {
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := ghostR+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols && !grid[r][c].filled {
grid[r][c] = displayCell{ghost: true, piece: g.current}
}
}
}
// Current piece.
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := g.currentRow+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
grid[r][c] = displayCell{filled: true, piece: g.current}
}
}
// Render grid.
var b strings.Builder
borderStyle := dimStyle
for _, row := range grid {
b.WriteString(borderStyle.Render("|"))
for _, dc := range row {
switch {
case dc.filled:
b.WriteString(cellStyle(dc.piece).Render("[]"))
case dc.ghost:
b.WriteString(ghostCellStyle().Render("::"))
default:
b.WriteString(baseStyle.Render(" "))
}
}
b.WriteString(borderStyle.Render("|"))
b.WriteString("\n")
}
b.WriteString(borderStyle.Render("+" + strings.Repeat("--", boardCols) + "+"))
return b.String()
}
// renderNextPiece renders the "next piece" preview box.
func renderNextPiece(pt pieceType) string {
shape := pieces[pt][0]
// Determine bounding box.
minR, maxR := shape[0][0], shape[0][0]
minC, maxC := shape[0][1], shape[0][1]
for _, off := range shape {
if off[0] < minR {
minR = off[0]
}
if off[0] > maxR {
maxR = off[0]
}
if off[1] < minC {
minC = off[1]
}
if off[1] > maxC {
maxC = off[1]
}
}
rows := maxR - minR + 1
cols := maxC - minC + 1
// Build a small grid.
grid := make([][]bool, rows)
for i := range grid {
grid[i] = make([]bool, cols)
}
for _, off := range shape {
grid[off[0]-minR][off[1]-minC] = true
}
var b strings.Builder
boxWidth := 8 // chars for the box interior
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
b.WriteString("\n")
for r := range rows {
b.WriteString(dimStyle.Render("|"))
// Center the piece in the box.
pieceWidth := cols * 2
leftPad := (boxWidth - pieceWidth) / 2
rightPad := boxWidth - pieceWidth - leftPad
b.WriteString(baseStyle.Render(strings.Repeat(" ", leftPad)))
for c := range cols {
if grid[r][c] {
b.WriteString(cellStyle(pt).Render("[]"))
} else {
b.WriteString(baseStyle.Render(" "))
}
}
b.WriteString(baseStyle.Render(strings.Repeat(" ", rightPad)))
b.WriteString(dimStyle.Render("|"))
b.WriteString("\n")
}
// Fill remaining rows in the box (max 4 rows for I piece).
for r := rows; r < 2; r++ {
b.WriteString(dimStyle.Render("|"))
b.WriteString(baseStyle.Render(strings.Repeat(" ", boxWidth)))
b.WriteString(dimStyle.Render("|"))
b.WriteString("\n")
}
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
return b.String()
}
// formatScore formats a score with comma separators.
func formatScore(n int) string {
s := fmt.Sprintf("%d", n)
if len(s) <= 3 {
return s
}
var parts []string
for len(s) > 3 {
parts = append([]string{s[len(s)-3:]}, parts...)
s = s[:len(s)-3]
}
parts = append([]string{s}, parts...)
return strings.Join(parts, ",")
}
// gameView combines the board and sidebar into the game screen.
func gameView(g *gameState) string {
boardStr := renderBoard(g)
boardLines := strings.Split(boardStr, "\n")
nextStr := renderNextPiece(g.next)
nextLines := strings.Split(nextStr, "\n")
// Build sidebar lines.
var sidebar []string
sidebar = append(sidebar, sidebarLabelStyle.Render(" NEXT:"))
sidebar = append(sidebar, nextLines...)
sidebar = append(sidebar, "")
sidebar = append(sidebar, sidebarLabelStyle.Render(" SCORE: ")+sidebarValueStyle.Render(formatScore(g.score)))
sidebar = append(sidebar, sidebarLabelStyle.Render(" LEVEL: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.level)))
sidebar = append(sidebar, sidebarLabelStyle.Render(" LINES: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.lines)))
sidebar = append(sidebar, "")
sidebar = append(sidebar, dimStyle.Render(" Controls:"))
sidebar = append(sidebar, dimStyle.Render(" <- -> Move"))
sidebar = append(sidebar, dimStyle.Render(" Up/Z Rotate"))
sidebar = append(sidebar, dimStyle.Render(" Down Soft drop"))
sidebar = append(sidebar, dimStyle.Render(" Space Hard drop"))
sidebar = append(sidebar, dimStyle.Render(" Q Quit"))
// Combine board and sidebar side by side.
var b strings.Builder
maxLines := max(len(boardLines), len(sidebar))
for i := range maxLines {
boardLine := ""
if i < len(boardLines) {
boardLine = boardLines[i]
}
sidebarLine := ""
if i < len(sidebar) {
sidebarLine = sidebar[i]
}
// Pad board to fixed width (| + 10*2 + | = 22 chars visual).
b.WriteString(boardLine)
b.WriteString(sidebarLine)
b.WriteString("\n")
}
return b.String()
}
// padLine pads a single line to termWidth.
func padLine(line string) string {
w := lipgloss.Width(line)
if w >= termWidth {
return line
}
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
}
// padLines pads every line in a multi-line string to termWidth.
func padLines(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
lines[i] = padLine(line)
}
return strings.Join(lines, "\n")
}
// gameFrame wraps content with padding to fill the terminal.
func gameFrame(content string, height int) string {
var b strings.Builder
b.WriteString(content)
// Pad with blank lines to fill terminal height.
if height > 0 {
contentLines := strings.Count(content, "\n") + 1
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
for i := contentLines; i < height; i++ {
b.WriteString(blankLine)
b.WriteString("\n")
}
}
return padLines(b.String())
}

View File

@@ -0,0 +1,66 @@
package tetris
import (
"context"
"io"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute
// TetrisShell is a Tetris game TUI for the honeypot.
type TetrisShell struct{}
// NewTetrisShell returns a new TetrisShell instance.
func NewTetrisShell() *TetrisShell {
return &TetrisShell{}
}
func (t *TetrisShell) Name() string { return "tetris" }
func (t *TetrisShell) Description() string { return "Tetris game TUI" }
func (t *TetrisShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
difficulty := configString(sess.ShellConfig, "difficulty", "normal")
m := newModel(sess, difficulty)
p := tea.NewProgram(m,
tea.WithInput(rw),
tea.WithOutput(rw),
tea.WithAltScreen(),
)
done := make(chan error, 1)
go func() {
_, err := p.Run()
done <- err
}()
select {
case err := <-done:
return err
case <-ctx.Done():
p.Quit()
<-done
return nil
}
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,582 @@
package tetris
import (
"context"
"strings"
"testing"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// newTestModel creates a model with a test session context.
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "player", "tetris", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "player",
Store: store,
}
m := newModel(sess, "normal")
return m, store
}
// sendKey sends a single key message to the model and returns the command.
func sendKey(m *model, key string) tea.Cmd {
var msg tea.KeyMsg
switch key {
case "enter":
msg = tea.KeyMsg{Type: tea.KeyEnter}
case "up":
msg = tea.KeyMsg{Type: tea.KeyUp}
case "down":
msg = tea.KeyMsg{Type: tea.KeyDown}
case "left":
msg = tea.KeyMsg{Type: tea.KeyLeft}
case "right":
msg = tea.KeyMsg{Type: tea.KeyRight}
case "space":
msg = tea.KeyMsg{Type: tea.KeySpace}
case "ctrl+c":
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
default:
if len(key) == 1 {
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}
}
}
_, cmd := m.Update(msg)
return cmd
}
// sendTick sends a tick message to the model.
func sendTick(m *model) tea.Cmd {
_, cmd := m.Update(tickMsg(time.Now()))
return cmd
}
// execCmds recursively executes tea.Cmd functions (including batches).
func execCmds(cmd tea.Cmd) {
if cmd == nil {
return
}
msg := cmd()
if batch, ok := msg.(tea.BatchMsg); ok {
for _, c := range batch {
execCmds(c)
}
}
}
func TestTetrisShellName(t *testing.T) {
sh := NewTetrisShell()
if sh.Name() != "tetris" {
t.Errorf("Name() = %q, want %q", sh.Name(), "tetris")
}
if sh.Description() == "" {
t.Error("Description() should not be empty")
}
}
func TestConfigString(t *testing.T) {
cfg := map[string]any{
"difficulty": "hard",
}
if got := configString(cfg, "difficulty", "normal"); got != "hard" {
t.Errorf("configString() = %q, want %q", got, "hard")
}
if got := configString(cfg, "missing", "normal"); got != "normal" {
t.Errorf("configString() = %q, want %q", got, "normal")
}
if got := configString(nil, "difficulty", "normal"); got != "normal" {
t.Errorf("configString(nil) = %q, want %q", got, "normal")
}
}
func TestTitleScreenRenders(t *testing.T) {
m, _ := newTestModel(t)
view := m.View()
if !strings.Contains(view, "████") {
t.Error("title screen should show TETRIS logo")
}
if !strings.Contains(view, "Press any key") {
t.Error("title screen should show 'Press any key'")
}
}
func TestTitleToGame(t *testing.T) {
m, _ := newTestModel(t)
if m.screen != screenTitle {
t.Fatalf("expected screenTitle, got %d", m.screen)
}
sendKey(m, "enter")
if m.screen != screenGame {
t.Errorf("expected screenGame after keypress, got %d", m.screen)
}
if m.game == nil {
t.Fatal("game should be initialized")
}
}
func TestGameRenders(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
view := m.View()
if !strings.Contains(view, "|") {
t.Error("game view should contain board borders")
}
if !strings.Contains(view, "SCORE") {
t.Error("game view should show SCORE")
}
if !strings.Contains(view, "LEVEL") {
t.Error("game view should show LEVEL")
}
if !strings.Contains(view, "LINES") {
t.Error("game view should show LINES")
}
if !strings.Contains(view, "NEXT") {
t.Error("game view should show NEXT")
}
}
// --- Pure game logic tests ---
func TestNewGame(t *testing.T) {
g := newGame(0)
if g.gameOver {
t.Error("new game should not be game over")
}
if g.score != 0 {
t.Errorf("initial score = %d, want 0", g.score)
}
if g.level != 0 {
t.Errorf("initial level = %d, want 0", g.level)
}
if g.lines != 0 {
t.Errorf("initial lines = %d, want 0", g.lines)
}
}
func TestNewGameHardLevel(t *testing.T) {
g := newGame(5)
if g.level != 5 {
t.Errorf("hard start level = %d, want 5", g.level)
}
}
func TestMoveLeft(t *testing.T) {
g := newGame(0)
startCol := g.currentCol
g.moveLeft()
if g.currentCol != startCol-1 {
t.Errorf("after moveLeft: col = %d, want %d", g.currentCol, startCol-1)
}
}
func TestMoveRight(t *testing.T) {
g := newGame(0)
startCol := g.currentCol
g.moveRight()
if g.currentCol != startCol+1 {
t.Errorf("after moveRight: col = %d, want %d", g.currentCol, startCol+1)
}
}
func TestMoveDown(t *testing.T) {
g := newGame(0)
startRow := g.currentRow
moved := g.moveDown()
if !moved {
t.Error("moveDown should succeed from starting position")
}
if g.currentRow != startRow+1 {
t.Errorf("after moveDown: row = %d, want %d", g.currentRow, startRow+1)
}
}
func TestCannotMoveLeftBeyondWall(t *testing.T) {
g := newGame(0)
// Move all the way left.
for range boardCols {
g.moveLeft()
}
col := g.currentCol
g.moveLeft() // should not move further
if g.currentCol != col {
t.Errorf("should not move past left wall: col = %d, was %d", g.currentCol, col)
}
}
func TestCannotMoveRightBeyondWall(t *testing.T) {
g := newGame(0)
// Move all the way right.
for range boardCols {
g.moveRight()
}
col := g.currentCol
g.moveRight() // should not move further
if g.currentCol != col {
t.Errorf("should not move past right wall: col = %d, was %d", g.currentCol, col)
}
}
func TestRotate(t *testing.T) {
g := newGame(0)
startRot := g.currentRot
g.rotate()
// Rotation should change (possibly with wall kick).
if g.currentRot == startRot {
// Rotation might legitimately fail in some edge cases, so just check
// that the game state is valid.
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
t.Error("piece should be in a valid position after rotate attempt")
}
}
}
func TestHardDrop(t *testing.T) {
g := newGame(0)
startRow := g.currentRow
dropped := g.hardDrop()
if dropped == 0 {
t.Error("hard drop should move piece down at least some rows from top")
}
if g.currentRow <= startRow {
t.Errorf("after hardDrop: row = %d should be > %d", g.currentRow, startRow)
}
}
func TestGhostRow(t *testing.T) {
g := newGame(0)
ghost := g.ghostRow()
if ghost < g.currentRow {
t.Errorf("ghost row %d should be >= current row %d", ghost, g.currentRow)
}
// Ghost should be at a position where moving down one more is impossible.
if g.canPlace(g.current, g.currentRot, ghost+1, g.currentCol) {
t.Error("ghost row should be the lowest valid position")
}
}
func TestLockPiece(t *testing.T) {
g := newGame(0)
g.hardDrop()
pt := g.current
row, col, rot := g.currentRow, g.currentCol, g.currentRot
g.lockPiece()
// Verify that the piece's cells are now filled.
shape := pieces[pt][rot]
for _, off := range shape {
r, c := row+off[0], col+off[1]
if !g.board[r][c].filled {
t.Errorf("cell (%d, %d) should be filled after lockPiece", r, c)
}
}
}
func TestClearLines(t *testing.T) {
g := newGame(0)
// Fill the bottom row completely.
for c := range boardCols {
g.board[boardRows-1][c] = cell{filled: true, piece: pieceI}
}
cleared := g.clearLines()
if cleared != 1 {
t.Errorf("clearLines() = %d, want 1", cleared)
}
// Bottom row should now be empty (shifted from above).
for c := range boardCols {
if g.board[boardRows-1][c].filled {
t.Errorf("bottom row col %d should be empty after clearing", c)
}
}
}
func TestClearMultipleLines(t *testing.T) {
g := newGame(0)
// Fill the bottom 4 rows.
for r := boardRows - 4; r < boardRows; r++ {
for c := range boardCols {
g.board[r][c] = cell{filled: true, piece: pieceI}
}
}
cleared := g.clearLines()
if cleared != 4 {
t.Errorf("clearLines() = %d, want 4", cleared)
}
}
func TestScoring(t *testing.T) {
tests := []struct {
lines int
level int
want int
}{
{1, 0, 40},
{2, 0, 100},
{3, 0, 300},
{4, 0, 1200},
{1, 1, 80},
{4, 2, 3600},
}
for _, tt := range tests {
g := newGame(tt.level)
g.addScore(tt.lines)
if g.score != tt.want {
t.Errorf("score for %d lines at level %d = %d, want %d", tt.lines, tt.level, g.score, tt.want)
}
}
}
func TestLevelUp(t *testing.T) {
g := newGame(0)
g.lines = 9
g.addScore(1) // This should push lines to 10, triggering level 1.
if g.level != 1 {
t.Errorf("level = %d, want 1 after 10 lines", g.level)
}
}
func TestTickInterval(t *testing.T) {
if got := tickInterval(0); got != 800 {
t.Errorf("tickInterval(0) = %d, want 800", got)
}
if got := tickInterval(5); got != 500 {
t.Errorf("tickInterval(5) = %d, want 500", got)
}
// Floor at 100ms.
if got := tickInterval(20); got != 100 {
t.Errorf("tickInterval(20) = %d, want 100", got)
}
}
func TestFormatScore(t *testing.T) {
tests := []struct {
n int
want string
}{
{0, "0"},
{100, "100"},
{1250, "1,250"},
{1000000, "1,000,000"},
}
for _, tt := range tests {
if got := formatScore(tt.n); got != tt.want {
t.Errorf("formatScore(%d) = %q, want %q", tt.n, got, tt.want)
}
}
}
func TestGameOverScreen(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Force game over.
m.game.gameOver = true
m.screen = screenGameOver
view := m.View()
if !strings.Contains(view, "GAME OVER") {
t.Error("game over screen should show GAME OVER")
}
if !strings.Contains(view, "Score") {
t.Error("game over screen should show score")
}
}
func TestRestartFromGameOver(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
m.game.gameOver = true
m.screen = screenGameOver
sendKey(m, "r")
if m.screen != screenGame {
t.Errorf("expected screenGame after restart, got %d", m.screen)
}
if m.game.gameOver {
t.Error("game should not be over after restart")
}
}
func TestQuitFromGame(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
sendKey(m, "q")
if !m.quitting {
t.Error("should be quitting after pressing q")
}
}
func TestQuitFromGameOver(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
m.game.gameOver = true
m.screen = screenGameOver
sendKey(m, "q")
if !m.quitting {
t.Error("should be quitting after pressing q in game over")
}
}
func TestSoftDropScoring(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
scoreBefore := m.game.score
sendKey(m, "down")
if m.game.score != scoreBefore+1 {
t.Errorf("score after soft drop = %d, want %d", m.game.score, scoreBefore+1)
}
}
func TestHardDropScoring(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Hard drop gives 2 points per row dropped.
sendKey(m, "space")
if m.game.score < 2 {
t.Errorf("score after hard drop = %d, should be at least 2", m.game.score)
}
}
func TestTickMovesDown(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
rowBefore := m.game.currentRow
sendTick(m)
// Piece should either move down by 1, or lock and spawn a new piece at top.
movedDown := m.game.currentRow == rowBefore+1
respawned := m.game.currentRow < rowBefore
if !movedDown && !respawned && !m.game.gameOver {
t.Errorf("tick should move piece down or lock+respawn: row was %d, now %d", rowBefore, m.game.currentRow)
}
}
func TestSessionLogs(t *testing.T) {
m, store := newTestModel(t)
// Press key to start game — returns a logAction cmd.
cmd := sendKey(m, "enter")
if cmd != nil {
execCmds(cmd)
}
time.Sleep(50 * time.Millisecond)
found := false
for _, log := range store.SessionLogs {
if strings.Contains(log.Input, "GAME START") {
found = true
break
}
}
if !found {
t.Error("expected GAME START in session logs")
}
}
func TestKeypressCounter(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
sendKey(m, "left")
sendKey(m, "right")
sendKey(m, "down")
if m.keypresses != 4 { // enter + 3 game keys
t.Errorf("keypresses = %d, want 4", m.keypresses)
}
}
func TestLockDelay(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Drop piece to the bottom via ticks until it can't move down.
for range boardRows + 5 {
if m.locking {
break
}
sendTick(m)
}
if !m.locking {
t.Fatal("piece should be in locking state after hitting bottom")
}
// During lock delay, we should still be able to move left/right.
colBefore := m.game.currentCol
sendKey(m, "left")
if m.game.currentCol >= colBefore {
// Might not have moved if against wall, try right.
sendKey(m, "right")
}
// Sending a lockMsg should finalize the piece.
m.Update(lockMsg(time.Now()))
// After lock, a new piece should have spawned (row near top).
if m.game.currentRow > 1 && !m.game.gameOver {
t.Errorf("after lock delay, new piece should spawn near top, got row %d", m.game.currentRow)
}
}
func TestLockDelayCancelledByDrop(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Build a ledge: fill rows 18-19 but leave column 0 empty.
for r := boardRows - 2; r < boardRows; r++ {
for c := 1; c < boardCols; c++ {
m.game.board[r][c] = cell{filled: true, piece: pieceI}
}
}
// Move piece to column 0 area and drop it onto the ledge.
for range boardCols {
m.game.moveLeft()
}
// Tick down until locking.
for range boardRows + 5 {
if m.locking {
break
}
sendTick(m)
}
// If piece is on the ledge and we slide it to col 0 (open column),
// the lock delay should cancel since it can fall further.
// This test just validates the locking flag logic works.
if m.locking {
// Try moving — if piece can drop further, locking should cancel.
sendKey(m, "left")
// Whether locking cancels depends on the board state; just verify no crash.
}
}
func TestSpawnCol(t *testing.T) {
// All pieces should spawn roughly centered.
for pt := range pieceType(numPieceTypes) {
col := spawnCol(pt, 0)
if col < 0 || col > boardCols-1 {
t.Errorf("spawnCol(%d, 0) = %d, out of range", pt, col)
}
// Verify piece fits at spawn position.
shape := pieces[pt][0]
for _, off := range shape {
c := col + off[1]
if c < 0 || c >= boardCols {
t.Errorf("piece %d overflows board at spawn: col+offset = %d", pt, c)
}
}
}
}

View File

@@ -0,0 +1,217 @@
package storage
import (
"context"
"time"
"github.com/prometheus/client_golang/prometheus"
)
// InstrumentedStore wraps a Store and records query duration and errors
// as Prometheus metrics for each method call.
type InstrumentedStore struct {
store Store
queryDuration *prometheus.HistogramVec
queryErrors *prometheus.CounterVec
}
// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store.
func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore {
return &InstrumentedStore{
store: store,
queryDuration: queryDuration,
queryErrors: queryErrors,
}
}
func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) {
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
v, err := fn()
timer.ObserveDuration()
if err != nil {
s.queryErrors.WithLabelValues(method).Inc()
}
return v, err
}
func observeErr(s *InstrumentedStore, method string, fn func() error) error {
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
err := fn()
timer.ObserveDuration()
if err != nil {
s.queryErrors.WithLabelValues(method).Inc()
}
return err
}
func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
return observeErr(s, "RecordLoginAttempt", func() error {
return s.store.RecordLoginAttempt(ctx, username, password, ip, country)
})
}
func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
return observe(s, "CreateSession", func() (string, error) {
return s.store.CreateSession(ctx, ip, username, shellName, country)
})
}
func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
return observeErr(s, "EndSession", func() error {
return s.store.EndSession(ctx, sessionID, disconnectedAt)
})
}
func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
return observeErr(s, "UpdateHumanScore", func() error {
return s.store.UpdateHumanScore(ctx, sessionID, score)
})
}
func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
return observeErr(s, "SetExecCommand", func() error {
return s.store.SetExecCommand(ctx, sessionID, command)
})
}
func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
return observeErr(s, "AppendSessionLog", func() error {
return s.store.AppendSessionLog(ctx, sessionID, input, output)
})
}
func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
return observe(s, "DeleteRecordsBefore", func() (int64, error) {
return s.store.DeleteRecordsBefore(ctx, cutoff)
})
}
func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
return observe(s, "GetDashboardStats", func() (*DashboardStats, error) {
return s.store.GetDashboardStats(ctx)
})
}
func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopUsernames", func() ([]TopEntry, error) {
return s.store.GetTopUsernames(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopPasswords", func() ([]TopEntry, error) {
return s.store.GetTopPasswords(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopIPs", func() ([]TopEntry, error) {
return s.store.GetTopIPs(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopCountries", func() ([]TopEntry, error) {
return s.store.GetTopCountries(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) {
return s.store.GetTopExecCommands(ctx, limit)
})
}
func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
return observe(s, "GetRecentSessions", func() ([]Session, error) {
return s.store.GetRecentSessions(ctx, limit, activeOnly)
})
}
func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
return observe(s, "GetFilteredSessions", func() ([]Session, error) {
return s.store.GetFilteredSessions(ctx, limit, activeOnly, f)
})
}
func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
return observe(s, "GetSession", func() (*Session, error) {
return s.store.GetSession(ctx, sessionID)
})
}
func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
return observe(s, "GetSessionLogs", func() ([]SessionLog, error) {
return s.store.GetSessionLogs(ctx, sessionID)
})
}
func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
return observeErr(s, "AppendSessionEvents", func() error {
return s.store.AppendSessionEvents(ctx, events)
})
}
func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) {
return s.store.GetSessionEvents(ctx, sessionID)
})
}
func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
return observe(s, "CloseActiveSessions", func() (int64, error) {
return s.store.CloseActiveSessions(ctx, disconnectedAt)
})
}
func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) {
return s.store.GetAttemptsOverTime(ctx, days, since, until)
})
}
func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) {
return s.store.GetHourlyPattern(ctx, since, until)
})
}
func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
return observe(s, "GetCountryStats", func() ([]CountryCount, error) {
return s.store.GetCountryStats(ctx)
})
}
func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) {
return s.store.GetFilteredDashboardStats(ctx, f)
})
}
func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) {
return s.store.GetFilteredTopUsernames(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) {
return s.store.GetFilteredTopPasswords(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) {
return s.store.GetFilteredTopIPs(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) {
return s.store.GetFilteredTopCountries(ctx, limit, f)
})
}
func (s *InstrumentedStore) Close() error {
return s.store.Close()
}

View File

@@ -0,0 +1,163 @@
package storage
import (
"context"
"errors"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
)
func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test_query_errors_total",
Help: "test",
}, []string{"method"})
store := NewMemoryStore()
return NewInstrumentedStore(store, dur, errs), dur, errs
}
func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 {
m := &dto.Metric{}
h.WithLabelValues(method).(prometheus.Histogram).Write(m)
return m.GetHistogram().GetSampleCount()
}
func getCounterValue(c *prometheus.CounterVec, method string) float64 {
m := &dto.Metric{}
c.WithLabelValues(method).Write(m)
return m.GetCounter().GetValue()
}
func TestInstrumentedStoreDelegation(t *testing.T) {
s, dur, _ := newTestInstrumented()
ctx := context.Background()
// RecordLoginAttempt should delegate and record duration.
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
// CreateSession should delegate and return a valid ID.
id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if id == "" {
t.Fatal("CreateSession returned empty ID")
}
if c := getHistogramCount(dur, "CreateSession"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
// GetDashboardStats should delegate.
stats, err := s.GetDashboardStats(ctx)
if err != nil {
t.Fatalf("GetDashboardStats: %v", err)
}
if stats == nil {
t.Fatal("GetDashboardStats returned nil")
}
if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
}
func TestInstrumentedStoreErrorCounting(t *testing.T) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test_ec_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test_ec_query_errors_total",
Help: "test",
}, []string{"method"})
es := &errorStore{}
s := NewInstrumentedStore(es, dur, errs)
ctx := context.Background()
// Error should be counted.
err := s.EndSession(ctx, "nonexistent", time.Now())
if !errors.Is(err, errFake) {
t.Fatalf("expected errFake, got %v", err)
}
if c := getHistogramCount(dur, "EndSession"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
if c := getCounterValue(errs, "EndSession"); c != 1 {
t.Fatalf("expected error count 1, got %f", c)
}
// Successful call should not increment error counter.
s2, _, errs2 := newTestInstrumented()
err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 {
t.Fatalf("expected error count 0, got %f", c)
}
}
// errorStore is a Store that returns errors for all methods.
type errorStore struct {
MemoryStore
}
var errFake = errors.New("fake error")
func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error {
return errFake
}
func (s *errorStore) EndSession(context.Context, string, time.Time) error {
return errFake
}
func TestInstrumentedStoreObserveErr(t *testing.T) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test2_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test2_query_errors_total",
Help: "test",
}, []string{"method"})
es := &errorStore{}
s := NewInstrumentedStore(es, dur, errs)
ctx := context.Background()
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if !errors.Is(err, errFake) {
t.Fatalf("expected errFake, got %v", err)
}
if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected error count 1, got %f", c)
}
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
}
func TestInstrumentedStoreClose(t *testing.T) {
s, _, _ := newTestInstrumented()
if err := s.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
}

View File

@@ -25,7 +25,7 @@ func NewMemoryStore() *MemoryStore {
}
}
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip string) error {
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip, country string) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -35,6 +35,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
if a.Username == username && a.Password == password && a.IP == ip {
a.Count++
a.LastSeen = now
a.Country = country
return nil
}
}
@@ -44,6 +45,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
Username: username,
Password: password,
IP: ip,
Country: country,
Count: 1,
FirstSeen: now,
LastSeen: now,
@@ -51,7 +53,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
return nil
}
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName string) (string, error) {
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName, country string) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -60,6 +62,7 @@ func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName s
m.Sessions[id] = &Session{
ID: id,
IP: ip,
Country: country,
Username: username,
ShellName: shellName,
ConnectedAt: now,
@@ -88,6 +91,16 @@ func (m *MemoryStore) UpdateHumanScore(_ context.Context, sessionID string, scor
return nil
}
func (m *MemoryStore) SetExecCommand(_ context.Context, sessionID string, command string) error {
m.mu.Lock()
defer m.mu.Unlock()
if s, ok := m.Sessions[sessionID]; ok {
s.ExecCommand = &command
}
return nil
}
func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -234,7 +247,60 @@ func (m *MemoryStore) GetTopPasswords(_ context.Context, limit int) ([]TopEntry,
func (m *MemoryStore) GetTopIPs(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.topN("ip", limit), nil
type ipInfo struct {
count int64
country string
}
agg := make(map[string]*ipInfo)
for _, a := range m.LoginAttempts {
info, ok := agg[a.IP]
if !ok {
info = &ipInfo{}
agg[a.IP] = info
}
info.count += int64(a.Count)
if a.Country != "" {
info.country = a.Country
}
}
entries := make([]TopEntry, 0, len(agg))
for ip, info := range agg {
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
func (m *MemoryStore) GetTopCountries(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.Country == "" {
continue
}
counts[a.Country] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
// topN aggregates login attempts by the given field and returns the top N. Must be called with m.mu held.
@@ -270,20 +336,372 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly
m.mu.Lock()
defer m.mu.Unlock()
return m.collectSessions(limit, activeOnly, DashboardFilter{}), nil
}
func (m *MemoryStore) GetFilteredSessions(_ context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.collectSessions(limit, activeOnly, f), nil
}
// collectSessions gathers sessions matching filter criteria. Must be called with m.mu held.
func (m *MemoryStore) collectSessions(limit int, activeOnly bool, f DashboardFilter) []Session {
// Compute event counts and input bytes per session.
eventCounts := make(map[string]int)
inputBytes := make(map[string]int64)
for _, e := range m.SessionEvents {
eventCounts[e.SessionID]++
if e.Direction == 0 {
inputBytes[e.SessionID] += int64(len(e.Data))
}
}
var sessions []Session
for _, s := range m.Sessions {
if activeOnly && s.DisconnectedAt != nil {
continue
}
sessions = append(sessions, *s)
if !matchesSessionFilter(s, f) {
continue
}
sess := *s
sess.EventCount = eventCounts[s.ID]
sess.InputBytes = inputBytes[s.ID]
sessions = append(sessions, sess)
}
if f.SortBy == "input_bytes" {
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].InputBytes > sessions[j].InputBytes
})
} else {
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
})
}
if limit > 0 && len(sessions) > limit {
sessions = sessions[:limit]
}
return sessions, nil
return sessions
}
// matchesSessionFilter returns true if the session matches the given filter.
func matchesSessionFilter(s *Session, f DashboardFilter) bool {
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
return false
}
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
return false
}
if f.IP != "" && s.IP != f.IP {
return false
}
if f.Country != "" && s.Country != f.Country {
return false
}
if f.Username != "" && s.Username != f.Username {
return false
}
if f.HumanScoreAboveZero {
if s.HumanScore == nil || *s.HumanScore <= 0 {
return false
}
}
return true
}
func (m *MemoryStore) GetTopExecCommands(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for _, s := range m.Sessions {
if s.ExecCommand != nil {
counts[*s.ExecCommand]++
}
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time.Time) (int64, error) {
m.mu.Lock()
defer m.mu.Unlock()
var count int64
t := disconnectedAt.UTC()
for _, s := range m.Sessions {
if s.DisconnectedAt == nil {
s.DisconnectedAt = &t
count++
}
}
return count, nil
}
func (m *MemoryStore) GetAttemptsOverTime(_ context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
m.mu.Lock()
defer m.mu.Unlock()
var cutoff time.Time
if since != nil {
cutoff = *since
} else {
cutoff = time.Now().UTC().AddDate(0, 0, -days)
}
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.LastSeen.Before(cutoff) {
continue
}
if until != nil && a.LastSeen.After(*until) {
continue
}
day := a.LastSeen.Format("2006-01-02")
counts[day] += int64(a.Count)
}
points := make([]TimeSeriesPoint, 0, len(counts))
for day, count := range counts {
t, _ := time.Parse("2006-01-02", day)
points = append(points, TimeSeriesPoint{Timestamp: t, Count: count})
}
sort.Slice(points, func(i, j int) bool {
return points[i].Timestamp.Before(points[j].Timestamp)
})
return points, nil
}
func (m *MemoryStore) GetHourlyPattern(_ context.Context, since, until *time.Time) ([]HourlyCount, error) {
m.mu.Lock()
defer m.mu.Unlock()
hourCounts := make(map[int]int64)
for _, a := range m.LoginAttempts {
if since != nil && a.LastSeen.Before(*since) {
continue
}
if until != nil && a.LastSeen.After(*until) {
continue
}
hour := a.LastSeen.Hour()
hourCounts[hour] += int64(a.Count)
}
counts := make([]HourlyCount, 0, len(hourCounts))
for h, c := range hourCounts {
counts = append(counts, HourlyCount{Hour: h, Count: c})
}
sort.Slice(counts, func(i, j int) bool {
return counts[i].Hour < counts[j].Hour
})
return counts, nil
}
func (m *MemoryStore) GetCountryStats(_ context.Context) ([]CountryCount, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.Country == "" {
continue
}
counts[a.Country] += int64(a.Count)
}
result := make([]CountryCount, 0, len(counts))
for country, count := range counts {
result = append(result, CountryCount{Country: country, Count: count})
}
sort.Slice(result, func(i, j int) bool {
return result[i].Count > result[j].Count
})
return result, nil
}
// matchesFilter returns true if the login attempt matches the given filter. Must be called with m.mu held.
func matchesFilter(a *LoginAttempt, f DashboardFilter) bool {
if f.Since != nil && a.LastSeen.Before(*f.Since) {
return false
}
if f.Until != nil && a.LastSeen.After(*f.Until) {
return false
}
if f.IP != "" && a.IP != f.IP {
return false
}
if f.Country != "" && a.Country != f.Country {
return false
}
if f.Username != "" && a.Username != f.Username {
return false
}
return true
}
func (m *MemoryStore) GetFilteredDashboardStats(_ context.Context, f DashboardFilter) (*DashboardStats, error) {
m.mu.Lock()
defer m.mu.Unlock()
stats := &DashboardStats{}
ips := make(map[string]struct{})
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
stats.TotalAttempts += int64(a.Count)
ips[a.IP] = struct{}{}
}
stats.UniqueIPs = int64(len(ips))
for _, s := range m.Sessions {
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
continue
}
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
continue
}
if f.IP != "" && s.IP != f.IP {
continue
}
if f.Country != "" && s.Country != f.Country {
continue
}
stats.TotalSessions++
if s.DisconnectedAt == nil {
stats.ActiveSessions++
}
}
return stats, nil
}
func (m *MemoryStore) GetFilteredTopUsernames(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.filteredTopN("username", limit, f), nil
}
func (m *MemoryStore) GetFilteredTopPasswords(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.filteredTopN("password", limit, f), nil
}
func (m *MemoryStore) GetFilteredTopIPs(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
type ipInfo struct {
count int64
country string
}
agg := make(map[string]*ipInfo)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
info, ok := agg[a.IP]
if !ok {
info = &ipInfo{}
agg[a.IP] = info
}
info.count += int64(a.Count)
if a.Country != "" {
info.country = a.Country
}
}
entries := make([]TopEntry, 0, len(agg))
for ip, info := range agg {
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
func (m *MemoryStore) GetFilteredTopCountries(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if a.Country == "" {
continue
}
if !matchesFilter(a, f) {
continue
}
counts[a.Country] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
// filteredTopN aggregates login attempts by the given field with filter applied and returns the top N. Must be called with m.mu held.
func (m *MemoryStore) filteredTopN(field string, limit int, f DashboardFilter) []TopEntry {
counts := make(map[string]int64)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
var key string
switch field {
case "username":
key = a.Username
case "password":
key = a.Password
case "ip":
key = a.IP
}
counts[key] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries
}
func (m *MemoryStore) Close() error {

View File

@@ -0,0 +1,3 @@
ALTER TABLE login_attempts ADD COLUMN country TEXT NOT NULL DEFAULT '';
ALTER TABLE sessions ADD COLUMN country TEXT NOT NULL DEFAULT '';
CREATE INDEX idx_login_attempts_country ON login_attempts(country);

View File

@@ -0,0 +1 @@
ALTER TABLE sessions ADD COLUMN exec_command TEXT;

View File

@@ -0,0 +1,3 @@
CREATE INDEX idx_login_attempts_username ON login_attempts(username);
CREATE INDEX idx_login_attempts_password ON login_attempts(password);
CREATE INDEX idx_sessions_disconnected_at ON sessions(disconnected_at);

View File

@@ -25,8 +25,8 @@ func TestMigrateCreatesTablesAndVersion(t *testing.T) {
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
t.Fatalf("query version: %v", err)
}
if version != 2 {
t.Errorf("version = %d, want 2", version)
if version != 5 {
t.Errorf("version = %d, want 5", version)
}
// Verify tables exist by inserting into them.
@@ -64,8 +64,8 @@ func TestMigrateIdempotent(t *testing.T) {
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
t.Fatalf("query version: %v", err)
}
if version != 2 {
t.Errorf("version = %d after double migrate, want 2", version)
if version != 5 {
t.Errorf("version = %d after double migrate, want 5", version)
}
}

View File

@@ -22,7 +22,7 @@ func TestRunRetentionDeletesOldRecords(t *testing.T) {
}
// Insert a recent login attempt.
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil {
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
t.Fatalf("insert recent attempt: %v", err)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/google/uuid"
@@ -34,28 +35,29 @@ func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
return &SQLiteStore{db: db}, nil
}
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip string) error {
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
now := time.Now().UTC().Format(time.RFC3339)
_, err := s.db.ExecContext(ctx, `
INSERT INTO login_attempts (username, password, ip, count, first_seen, last_seen)
VALUES (?, ?, ?, 1, ?, ?)
INSERT INTO login_attempts (username, password, ip, country, count, first_seen, last_seen)
VALUES (?, ?, ?, ?, 1, ?, ?)
ON CONFLICT(username, password, ip) DO UPDATE SET
count = count + 1,
last_seen = ?`,
username, password, ip, now, now, now)
last_seen = ?,
country = ?`,
username, password, ip, country, now, now, now, country)
if err != nil {
return fmt.Errorf("recording login attempt: %w", err)
}
return nil
}
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName string) (string, error) {
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
id := uuid.New().String()
now := time.Now().UTC().Format(time.RFC3339)
_, err := s.db.ExecContext(ctx, `
INSERT INTO sessions (id, ip, username, shell_name, connected_at)
VALUES (?, ?, ?, ?, ?)`,
id, ip, username, shellName, now)
INSERT INTO sessions (id, ip, username, shell_name, country, connected_at)
VALUES (?, ?, ?, ?, ?, ?)`,
id, ip, username, shellName, country, now)
if err != nil {
return "", fmt.Errorf("creating session: %w", err)
}
@@ -82,6 +84,16 @@ func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, sc
return nil
}
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
_, err := s.db.ExecContext(ctx, `
UPDATE sessions SET exec_command = ? WHERE id = ?`,
command, sessionID)
if err != nil {
return fmt.Errorf("setting exec command: %w", err)
}
return nil
}
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
now := time.Now().UTC().Format(time.RFC3339)
_, err := s.db.ExecContext(ctx, `
@@ -99,12 +111,13 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
var connectedAt string
var disconnectedAt sql.NullString
var humanScore sql.NullFloat64
var execCommand sql.NullString
err := s.db.QueryRowContext(ctx, `
SELECT id, ip, username, shell_name, connected_at, disconnected_at, human_score
SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command
FROM sessions WHERE id = ?`, sessionID).Scan(
&sess.ID, &sess.IP, &sess.Username, &sess.ShellName,
&connectedAt, &disconnectedAt, &humanScore,
&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
&connectedAt, &disconnectedAt, &humanScore, &execCommand,
)
if err == sql.ErrNoRows {
return nil, nil
@@ -121,6 +134,9 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
if humanScore.Valid {
sess.HumanScore = &humanScore.Float64
}
if execCommand.Valid {
sess.ExecCommand = &execCommand.String
}
return &sess, nil
}
@@ -288,10 +304,60 @@ func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntr
}
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
return s.queryTopN(ctx, "ip", limit)
rows, err := s.db.QueryContext(ctx, `
SELECT ip, country, SUM(count) AS total
FROM login_attempts
GROUP BY ip
ORDER BY total DESC
LIMIT ?`, limit)
if err != nil {
return nil, fmt.Errorf("querying top IPs: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
return nil, fmt.Errorf("scanning top IPs: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT country, SUM(count) AS total
FROM login_attempts
WHERE country != ''
GROUP BY country
ORDER BY total DESC
LIMIT ?`, limit)
if err != nil {
return nil, fmt.Errorf("querying top countries: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning top countries: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
switch column {
case "username", "password", "ip":
// valid columns
default:
return nil, fmt.Errorf("invalid column: %s", column)
}
query := fmt.Sprintf(`
SELECT %s, SUM(count) AS total
FROM login_attempts
@@ -317,40 +383,401 @@ func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) (
}
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
query := `SELECT id, ip, username, shell_name, connected_at, disconnected_at, human_score FROM sessions`
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
if activeOnly {
query += ` WHERE disconnected_at IS NULL`
query += ` WHERE s.disconnected_at IS NULL`
}
query += ` ORDER BY connected_at DESC LIMIT ?`
query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, limit)
return s.scanSessions(ctx, query, limit)
}
// buildSessionWhereClause builds a dynamic WHERE clause for session filtering.
func buildSessionWhereClause(f DashboardFilter, activeOnly bool) (string, []any) {
var clauses []string
var args []any
if activeOnly {
clauses = append(clauses, "s.disconnected_at IS NULL")
}
if f.Since != nil {
clauses = append(clauses, "s.connected_at >= ?")
args = append(args, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
clauses = append(clauses, "s.connected_at <= ?")
args = append(args, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
clauses = append(clauses, "s.ip = ?")
args = append(args, f.IP)
}
if f.Country != "" {
clauses = append(clauses, "s.country = ?")
args = append(args, f.Country)
}
if f.Username != "" {
clauses = append(clauses, "s.username = ?")
args = append(args, f.Username)
}
if f.HumanScoreAboveZero {
clauses = append(clauses, "s.human_score > 0")
}
if len(clauses) == 0 {
return "", nil
}
return " WHERE " + strings.Join(clauses, " AND "), args
}
// validSessionSorts maps allowed SortBy values to SQL ORDER BY clauses.
var validSessionSorts = map[string]string{
"connected_at": "s.connected_at DESC",
"input_bytes": "input_bytes DESC",
}
func (s *SQLiteStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
where, args := buildSessionWhereClause(f, activeOnly)
args = append(args, limit)
orderBy := validSessionSorts["connected_at"]
if mapped, ok := validSessionSorts[f.SortBy]; ok {
orderBy = mapped
}
//nolint:gosec // where/order clauses built from allowlisted constants, not raw user input
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + where + ` GROUP BY s.id ORDER BY ` + orderBy + ` LIMIT ?`
return s.scanSessions(ctx, query, args...)
}
// scanSessions executes a session query and scans the results.
func (s *SQLiteStore) scanSessions(ctx context.Context, query string, args ...any) ([]Session, error) {
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying recent sessions: %w", err)
return nil, fmt.Errorf("querying sessions: %w", err)
}
defer func() { _ = rows.Close() }()
var sessions []Session
for rows.Next() {
var s Session
var sess Session
var connectedAt string
var disconnectedAt sql.NullString
var humanScore sql.NullFloat64
if err := rows.Scan(&s.ID, &s.IP, &s.Username, &s.ShellName, &connectedAt, &disconnectedAt, &humanScore); err != nil {
var execCommand sql.NullString
if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &sess.EventCount, &sess.InputBytes); err != nil {
return nil, fmt.Errorf("scanning session: %w", err)
}
s.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
if disconnectedAt.Valid {
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
s.DisconnectedAt = &t
sess.DisconnectedAt = &t
}
if humanScore.Valid {
s.HumanScore = &humanScore.Float64
sess.HumanScore = &humanScore.Float64
}
sessions = append(sessions, s)
if execCommand.Valid {
sess.ExecCommand = &execCommand.String
}
sessions = append(sessions, sess)
}
return sessions, rows.Err()
}
func (s *SQLiteStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT exec_command, COUNT(*) as total
FROM sessions
WHERE exec_command IS NOT NULL
GROUP BY exec_command
ORDER BY total DESC
LIMIT ?`, limit)
if err != nil {
return nil, fmt.Errorf("querying top exec commands: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning top exec commands: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
res, err := s.db.ExecContext(ctx, `
UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`,
disconnectedAt.UTC().Format(time.RFC3339))
if err != nil {
return 0, fmt.Errorf("closing active sessions: %w", err)
}
return res.RowsAffected()
}
func (s *SQLiteStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
query := `SELECT DATE(last_seen) AS d, SUM(count) FROM login_attempts WHERE 1=1`
var args []any
if since != nil {
query += ` AND last_seen >= ?`
args = append(args, since.UTC().Format(time.RFC3339))
} else {
query += ` AND last_seen >= ?`
args = append(args, time.Now().UTC().AddDate(0, 0, -days).Format("2006-01-02"))
}
if until != nil {
query += ` AND last_seen <= ?`
args = append(args, until.UTC().Format(time.RFC3339))
}
query += ` GROUP BY d ORDER BY d`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying attempts over time: %w", err)
}
defer func() { _ = rows.Close() }()
var points []TimeSeriesPoint
for rows.Next() {
var dateStr string
var p TimeSeriesPoint
if err := rows.Scan(&dateStr, &p.Count); err != nil {
return nil, fmt.Errorf("scanning time series point: %w", err)
}
p.Timestamp, _ = time.Parse("2006-01-02", dateStr)
points = append(points, p)
}
return points, rows.Err()
}
func (s *SQLiteStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
query := `SELECT CAST(STRFTIME('%H', last_seen) AS INTEGER) AS h, SUM(count) FROM login_attempts WHERE 1=1`
var args []any
if since != nil {
query += ` AND last_seen >= ?`
args = append(args, since.UTC().Format(time.RFC3339))
}
if until != nil {
query += ` AND last_seen <= ?`
args = append(args, until.UTC().Format(time.RFC3339))
}
query += ` GROUP BY h ORDER BY h`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying hourly pattern: %w", err)
}
defer func() { _ = rows.Close() }()
var counts []HourlyCount
for rows.Next() {
var c HourlyCount
if err := rows.Scan(&c.Hour, &c.Count); err != nil {
return nil, fmt.Errorf("scanning hourly count: %w", err)
}
counts = append(counts, c)
}
return counts, rows.Err()
}
func (s *SQLiteStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT country, SUM(count) AS total
FROM login_attempts
WHERE country != ''
GROUP BY country
ORDER BY total DESC`)
if err != nil {
return nil, fmt.Errorf("querying country stats: %w", err)
}
defer func() { _ = rows.Close() }()
var counts []CountryCount
for rows.Next() {
var c CountryCount
if err := rows.Scan(&c.Country, &c.Count); err != nil {
return nil, fmt.Errorf("scanning country count: %w", err)
}
counts = append(counts, c)
}
return counts, rows.Err()
}
// buildAttemptWhereClause builds a dynamic WHERE clause for login_attempts filtering.
func buildAttemptWhereClause(f DashboardFilter) (string, []any) {
var clauses []string
var args []any
if f.Since != nil {
clauses = append(clauses, "last_seen >= ?")
args = append(args, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
clauses = append(clauses, "last_seen <= ?")
args = append(args, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
clauses = append(clauses, "ip = ?")
args = append(args, f.IP)
}
if f.Country != "" {
clauses = append(clauses, "country = ?")
args = append(args, f.Country)
}
if f.Username != "" {
clauses = append(clauses, "username = ?")
args = append(args, f.Username)
}
if len(clauses) == 0 {
return "", nil
}
return " WHERE " + strings.Join(clauses, " AND "), args
}
func (s *SQLiteStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
where, args := buildAttemptWhereClause(f)
stats := &DashboardStats{}
err := s.db.QueryRowContext(ctx,
`SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip) FROM login_attempts`+where, args...).
Scan(&stats.TotalAttempts, &stats.UniqueIPs)
if err != nil {
return nil, fmt.Errorf("querying filtered attempt stats: %w", err)
}
// Sessions don't have username/password, so only filter by time, IP, country.
sessQuery := `SELECT COUNT(*) FROM sessions WHERE 1=1`
var sessArgs []any
if f.Since != nil {
sessQuery += ` AND connected_at >= ?`
sessArgs = append(sessArgs, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
sessQuery += ` AND connected_at <= ?`
sessArgs = append(sessArgs, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
sessQuery += ` AND ip = ?`
sessArgs = append(sessArgs, f.IP)
}
if f.Country != "" {
sessQuery += ` AND country = ?`
sessArgs = append(sessArgs, f.Country)
}
err = s.db.QueryRowContext(ctx, sessQuery, sessArgs...).Scan(&stats.TotalSessions)
if err != nil {
return nil, fmt.Errorf("querying filtered total sessions: %w", err)
}
err = s.db.QueryRowContext(ctx, sessQuery+` AND disconnected_at IS NULL`, sessArgs...).Scan(&stats.ActiveSessions)
if err != nil {
return nil, fmt.Errorf("querying filtered active sessions: %w", err)
}
return stats, nil
}
func (s *SQLiteStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return s.queryFilteredTopN(ctx, "username", limit, f)
}
func (s *SQLiteStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return s.queryFilteredTopN(ctx, "password", limit, f)
}
func (s *SQLiteStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
where, args := buildAttemptWhereClause(f)
args = append(args, limit)
//nolint:gosec // where clause built from trusted constants, not user input
query := `SELECT ip, country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY ip ORDER BY total DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top IPs: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top IPs: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
where, args := buildAttemptWhereClause(f)
countryClause := "country != ''"
if where == "" {
where = " WHERE " + countryClause
} else {
where += " AND " + countryClause
}
args = append(args, limit)
//nolint:gosec // where clause built from trusted constants, not user input
query := `SELECT country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY country ORDER BY total DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top countries: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top countries: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) queryFilteredTopN(ctx context.Context, column string, limit int, f DashboardFilter) ([]TopEntry, error) {
switch column {
case "username", "password":
// valid columns
default:
return nil, fmt.Errorf("invalid column: %s", column)
}
where, args := buildAttemptWhereClause(f)
args = append(args, limit)
query := fmt.Sprintf(`
SELECT %s, SUM(count) AS total
FROM login_attempts%s
GROUP BY %s
ORDER BY total DESC
LIMIT ?`, column, where, column)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top %s: %w", column, err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top %s: %w", column, err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) Close() error {
return s.db.Close()
}

View File

@@ -23,17 +23,17 @@ func TestRecordLoginAttempt(t *testing.T) {
ctx := context.Background()
// First attempt creates a new record.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("first attempt: %v", err)
}
// Second attempt with same credentials increments count.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("second attempt: %v", err)
}
// Different IP is a separate record.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
t.Fatalf("different IP: %v", err)
}
@@ -62,7 +62,7 @@ func TestCreateAndEndSession(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
@@ -100,7 +100,7 @@ func TestUpdateHumanScore(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
@@ -123,7 +123,7 @@ func TestAppendSessionLog(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
@@ -159,7 +159,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
}
// Insert a recent login attempt.
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil {
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
t.Fatalf("insert recent attempt: %v", err)
}
@@ -178,7 +178,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
}
// Insert a recent session.
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", ""); err != nil {
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", "", ""); err != nil {
t.Fatalf("insert recent session: %v", err)
}
@@ -204,6 +204,79 @@ func TestDeleteRecordsBefore(t *testing.T) {
}
}
func TestGetTopExecCommands(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Create sessions with exec commands.
for range 3 {
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
t.Fatalf("setting exec command: %v", err)
}
}
for range 2 {
id, err := store.CreateSession(ctx, "10.0.0.2", "admin", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
if err := store.SetExecCommand(ctx, id, "cat /etc/passwd"); err != nil {
t.Fatalf("setting exec command: %v", err)
}
}
// Session without exec command — should not appear.
if _, err := store.CreateSession(ctx, "10.0.0.3", "test", "bash", ""); err != nil {
t.Fatalf("creating session: %v", err)
}
entries, err := store.GetTopExecCommands(ctx, 10)
if err != nil {
t.Fatalf("GetTopExecCommands: %v", err)
}
if len(entries) != 2 {
t.Fatalf("len = %d, want 2", len(entries))
}
if entries[0].Value != "uname -a" || entries[0].Count != 3 {
t.Errorf("entries[0] = %+v, want uname -a:3", entries[0])
}
if entries[1].Value != "cat /etc/passwd" || entries[1].Count != 2 {
t.Errorf("entries[1] = %+v, want cat /etc/passwd:2", entries[1])
}
}
func TestGetRecentSessionsEventCount(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
// Add some events.
events := []SessionEvent{
{SessionID: id, Timestamp: time.Now(), Direction: 0, Data: []byte("ls\n")},
{SessionID: id, Timestamp: time.Now(), Direction: 1, Data: []byte("file1\n")},
}
if err := store.AppendSessionEvents(ctx, events); err != nil {
t.Fatalf("appending events: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].EventCount != 2 {
t.Errorf("EventCount = %d, want 2", sessions[0].EventCount)
}
}
func TestNewSQLiteStoreCreatesFile(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "test.db")
store, err := NewSQLiteStore(dbPath)
@@ -214,7 +287,7 @@ func TestNewSQLiteStoreCreatesFile(t *testing.T) {
// Verify we can use the store.
ctx := context.Background()
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1", ""); err != nil {
t.Fatalf("recording attempt: %v", err)
}
}

View File

@@ -11,6 +11,7 @@ type LoginAttempt struct {
Username string
Password string
IP string
Country string
Count int
FirstSeen time.Time
LastSeen time.Time
@@ -20,11 +21,15 @@ type LoginAttempt struct {
type Session struct {
ID string
IP string
Country string
Username string
ShellName string
ConnectedAt time.Time
DisconnectedAt *time.Time
HumanScore *float64
ExecCommand *string
EventCount int
InputBytes int64
}
// SessionLog represents a single log entry for a session.
@@ -52,9 +57,39 @@ type DashboardStats struct {
ActiveSessions int64
}
// TimeSeriesPoint represents a single data point in a time series.
type TimeSeriesPoint struct {
Timestamp time.Time
Count int64
}
// HourlyCount represents the total attempts for a given hour of day.
type HourlyCount struct {
Hour int // 0-23
Count int64
}
// CountryCount represents the total attempts from a given country.
type CountryCount struct {
Country string
Count int64
}
// DashboardFilter contains optional filters for dashboard queries.
type DashboardFilter struct {
Since *time.Time
Until *time.Time
IP string
Country string
Username string
HumanScoreAboveZero bool
SortBy string
}
// TopEntry represents a value and its count for top-N queries.
type TopEntry struct {
Value string
Country string // populated by GetTopIPs
Count int64
}
@@ -62,10 +97,10 @@ type TopEntry struct {
type Store interface {
// RecordLoginAttempt upserts a login attempt, incrementing the count
// for existing (username, password, ip) combinations.
RecordLoginAttempt(ctx context.Context, username, password, ip string) error
RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error
// CreateSession creates a new session record and returns its UUID.
CreateSession(ctx context.Context, ip, username, shellName string) (string, error)
CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error)
// EndSession sets the disconnected_at timestamp for a session.
EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error
@@ -73,6 +108,9 @@ type Store interface {
// UpdateHumanScore sets the human detection score for a session.
UpdateHumanScore(ctx context.Context, sessionID string, score float64) error
// SetExecCommand sets the exec command for a session.
SetExecCommand(ctx context.Context, sessionID string, command string) error
// AppendSessionLog adds a log entry to a session.
AppendSessionLog(ctx context.Context, sessionID, input, output string) error
@@ -92,10 +130,20 @@ type Store interface {
// GetTopIPs returns the top N IPs by total attempt count.
GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error)
// GetTopCountries returns the top N countries by total attempt count.
GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error)
// GetTopExecCommands returns the top N exec commands by session count.
GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error)
// GetRecentSessions returns the most recent sessions ordered by connected_at DESC.
// If activeOnly is true, only sessions with no disconnected_at are returned.
GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error)
// GetFilteredSessions returns sessions matching the given filter, ordered
// by the filter's SortBy field (default: connected_at DESC).
GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error)
// GetSession returns a single session by ID.
GetSession(ctx context.Context, sessionID string) (*Session, error)
@@ -108,6 +156,35 @@ type Store interface {
// GetSessionEvents returns all events for a session ordered by id.
GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error)
// CloseActiveSessions sets disconnected_at for all sessions that are
// still marked as active. This should be called at startup to clean up
// sessions left over from a previous unclean shutdown.
CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error)
// GetAttemptsOverTime returns daily attempt counts for the last N days.
GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error)
// GetHourlyPattern returns total attempts grouped by hour of day (0-23).
GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error)
// GetCountryStats returns total attempts per country, ordered by count DESC.
GetCountryStats(ctx context.Context) ([]CountryCount, error)
// GetFilteredDashboardStats returns aggregate counts with optional filters applied.
GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error)
// GetFilteredTopUsernames returns top usernames with optional filters applied.
GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopPasswords returns top passwords with optional filters applied.
GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopIPs returns top IPs with optional filters applied.
GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopCountries returns top countries with optional filters applied.
GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// Close releases any resources held by the store.
Close() error
}

View File

@@ -38,23 +38,23 @@ func seedData(t *testing.T, store Store) {
// Login attempts: root/toor from two IPs, admin/admin from one IP.
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 3 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 2 {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
// Sessions: one active, one ended.
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
@@ -62,7 +62,7 @@ func seedData(t *testing.T, store Store) {
t.Fatalf("ending session: %v", err)
}
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash"); err != nil {
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
t.Fatalf("creating session: %v", err)
}
}
@@ -210,7 +210,7 @@ func TestGetSession(t *testing.T) {
t.Run("found", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
@@ -233,7 +233,7 @@ func TestGetSessionLogs(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
@@ -277,7 +277,7 @@ func TestSessionEvents(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
@@ -316,6 +316,334 @@ func TestSessionEvents(t *testing.T) {
})
}
func TestCloseActiveSessions(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("no active sessions", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
n, err := store.CloseActiveSessions(ctx, time.Now())
if err != nil {
t.Fatalf("CloseActiveSessions: %v", err)
}
if n != 0 {
t.Errorf("closed %d, want 0", n)
}
})
t.Run("closes only active sessions", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
// Create 3 sessions: end one, leave two active.
id1, _ := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
store.CreateSession(ctx, "10.0.0.3", "test", "bash", "")
store.EndSession(ctx, id1, time.Now())
n, err := store.CloseActiveSessions(ctx, time.Now())
if err != nil {
t.Fatalf("CloseActiveSessions: %v", err)
}
if n != 2 {
t.Errorf("closed %d, want 2", n)
}
// Verify no active sessions remain.
active, err := store.GetRecentSessions(ctx, 10, true)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(active) != 0 {
t.Errorf("active sessions = %d, want 0", len(active))
}
})
})
}
func TestSetExecCommand(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("set and retrieve", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
// Initially nil.
s, err := store.GetSession(ctx, id)
if err != nil {
t.Fatalf("GetSession: %v", err)
}
if s.ExecCommand != nil {
t.Errorf("expected nil ExecCommand, got %q", *s.ExecCommand)
}
// Set exec command.
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
t.Fatalf("SetExecCommand: %v", err)
}
s, err = store.GetSession(ctx, id)
if err != nil {
t.Fatalf("GetSession: %v", err)
}
if s.ExecCommand == nil {
t.Fatal("expected non-nil ExecCommand")
}
if *s.ExecCommand != "uname -a" {
t.Errorf("ExecCommand = %q, want %q", *s.ExecCommand, "uname -a")
}
})
t.Run("appears in recent sessions", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.SetExecCommand(ctx, id, "id"); err != nil {
t.Fatalf("SetExecCommand: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ExecCommand == nil || *sessions[0].ExecCommand != "id" {
t.Errorf("ExecCommand = %v, want \"id\"", sessions[0].ExecCommand)
}
})
})
}
func seedChartData(t *testing.T, store Store) {
t.Helper()
ctx := context.Background()
// Record attempts with country data from different IPs.
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 3 {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 2 {
if err := store.RecordLoginAttempt(ctx, "root", "123456", "10.0.0.3", "CN"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
}
func TestGetAttemptsOverTime(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
if err != nil {
t.Fatalf("GetAttemptsOverTime: %v", err)
}
if len(points) != 0 {
t.Errorf("expected empty, got %v", points)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
if err != nil {
t.Fatalf("GetAttemptsOverTime: %v", err)
}
// All data was inserted today, so should be one point.
if len(points) != 1 {
t.Fatalf("len = %d, want 1", len(points))
}
// 5 + 3 + 2 = 10 total.
if points[0].Count != 10 {
t.Errorf("count = %d, want 10", points[0].Count)
}
})
})
}
func TestGetHourlyPattern(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
if err != nil {
t.Fatalf("GetHourlyPattern: %v", err)
}
if len(counts) != 0 {
t.Errorf("expected empty, got %v", counts)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
if err != nil {
t.Fatalf("GetHourlyPattern: %v", err)
}
// All data was inserted at the same hour.
if len(counts) != 1 {
t.Fatalf("len = %d, want 1", len(counts))
}
if counts[0].Count != 10 {
t.Errorf("count = %d, want 10", counts[0].Count)
}
})
})
}
func TestGetCountryStats(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
counts, err := store.GetCountryStats(context.Background())
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 0 {
t.Errorf("expected empty, got %v", counts)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
counts, err := store.GetCountryStats(context.Background())
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 2 {
t.Fatalf("len = %d, want 2", len(counts))
}
// CN: 5 + 2 = 7, RU: 3 - ordered by count DESC.
if counts[0].Country != "CN" || counts[0].Count != 7 {
t.Errorf("counts[0] = %+v, want CN/7", counts[0])
}
if counts[1].Country != "RU" || counts[1].Count != 3 {
t.Errorf("counts[1] = %+v, want RU/3", counts[1])
}
})
t.Run("excludes empty country", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding: %v", err)
}
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.2", "US"); err != nil {
t.Fatalf("seeding: %v", err)
}
counts, err := store.GetCountryStats(ctx)
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 1 {
t.Fatalf("len = %d, want 1", len(counts))
}
if counts[0].Country != "US" {
t.Errorf("country = %q, want US", counts[0].Country)
}
})
})
}
func TestGetFilteredDashboardStats(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("no filter", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 10 {
t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
}
})
t.Run("filter by country", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Country: "CN"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
// CN: 5 + 2 = 7
if stats.TotalAttempts != 7 {
t.Errorf("TotalAttempts = %d, want 7", stats.TotalAttempts)
}
})
t.Run("filter by IP", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{IP: "10.0.0.1"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 5 {
t.Errorf("TotalAttempts = %d, want 5", stats.TotalAttempts)
}
})
t.Run("filter by username", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Username: "admin"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 3 {
t.Errorf("TotalAttempts = %d, want 3", stats.TotalAttempts)
}
})
})
}
func TestGetFilteredTopUsernames(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
store := newStore(t)
seedChartData(t, store)
// Filter by country CN should only show root.
entries, err := store.GetFilteredTopUsernames(context.Background(), 10, DashboardFilter{Country: "CN"})
if err != nil {
t.Fatalf("GetFilteredTopUsernames: %v", err)
}
if len(entries) != 1 {
t.Fatalf("len = %d, want 1", len(entries))
}
if entries[0].Value != "root" || entries[0].Count != 7 {
t.Errorf("entries[0] = %+v, want root/7", entries[0])
}
})
}
func TestGetRecentSessions(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
@@ -372,3 +700,192 @@ func TestGetRecentSessions(t *testing.T) {
})
})
}
func TestInputBytes(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("counts only input direction", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
now := time.Now().UTC()
events := []SessionEvent{
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")}, // 3 bytes input
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")}, // 11 bytes output
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")}, // 4 bytes input
}
if err := store.AppendSessionEvents(ctx, events); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
// Only direction=0 data: "ls\n" (3) + "pwd\n" (4) = 7
if sessions[0].InputBytes != 7 {
t.Errorf("InputBytes = %d, want 7", sessions[0].InputBytes)
}
})
t.Run("zero when no events", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
_, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].InputBytes != 0 {
t.Errorf("InputBytes = %d, want 0", sessions[0].InputBytes)
}
})
})
}
func TestGetFilteredSessions(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("filter by human score", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
// Create two sessions, one with human score > 0.
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id1, 0.75); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
_, err = store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{HumanScoreAboveZero: true})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
}
})
t.Run("sort by input bytes", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
// Session with more input (created first).
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
now := time.Now().UTC()
if err := store.AppendSessionEvents(ctx, []SessionEvent{
{SessionID: id1, Timestamp: now, Direction: 0, Data: []byte("ls -la /tmp\n")},
{SessionID: id1, Timestamp: now.Add(time.Millisecond), Direction: 0, Data: []byte("cat /etc/passwd\n")},
}); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
// Session with less input (created after id1, so would be first by connected_at).
// Sleep >1s to ensure different RFC3339 timestamps in SQLite.
time.Sleep(1100 * time.Millisecond)
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.AppendSessionEvents(ctx, []SessionEvent{
{SessionID: id2, Timestamp: now.Add(2 * time.Second), Direction: 0, Data: []byte("x\n")},
}); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
// Default sort (connected_at DESC) should show id2 first.
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 2 {
t.Fatalf("len = %d, want 2", len(sessions))
}
if sessions[0].ID != id2 {
t.Errorf("default sort: expected %s first, got %s", id2, sessions[0].ID)
}
// Sort by input_bytes should show id1 first (more input).
sessions, err = store.GetFilteredSessions(ctx, 50, false, DashboardFilter{SortBy: "input_bytes"})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 2 {
t.Fatalf("len = %d, want 2", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("input_bytes sort: expected %s first, got %s", id1, sessions[0].ID)
}
})
t.Run("combined filters", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id1, 0.5); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
// Different country, also has score.
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id2, 0.8); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
// Same country CN but no score.
_, err = store.CreateSession(ctx, "10.0.0.3", "test", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
// Filter: CN + human score > 0 -> only id1.
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{
Country: "CN",
HumanScoreAboveZero: true,
})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
}
})
})
}

View File

@@ -1,24 +1,37 @@
package web
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"strconv"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// dbContext returns a context detached from the HTTP request lifecycle with a
// 30-second timeout. This prevents HTMX polling from canceling in-flight DB
// queries when the browser aborts the previous XHR.
func dbContext(r *http.Request) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
}
type dashboardData struct {
Stats *storage.DashboardStats
TopUsernames []storage.TopEntry
TopPasswords []storage.TopEntry
TopIPs []storage.TopEntry
TopCountries []storage.TopEntry
TopExecCommands []storage.TopEntry
ActiveSessions []storage.Session
RecentSessions []storage.Session
}
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, cancel := dbContext(r)
defer cancel()
stats, err := s.store.GetDashboardStats(ctx)
if err != nil {
@@ -48,6 +61,20 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
return
}
topCountries, err := s.store.GetTopCountries(ctx, 10)
if err != nil {
s.logger.Error("failed to get top countries", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topExecCommands, err := s.store.GetTopExecCommands(ctx, 10)
if err != nil {
s.logger.Error("failed to get top exec commands", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
activeSessions, err := s.store.GetRecentSessions(ctx, 50, true)
if err != nil {
s.logger.Error("failed to get active sessions", "err", err)
@@ -67,6 +94,8 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
TopUsernames: topUsernames,
TopPasswords: topPasswords,
TopIPs: topIPs,
TopCountries: topCountries,
TopExecCommands: topExecCommands,
ActiveSessions: activeSessions,
RecentSessions: recentSessions,
}
@@ -78,7 +107,10 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
stats, err := s.store.GetDashboardStats(r.Context())
ctx, cancel := dbContext(r)
defer cancel()
stats, err := s.store.GetDashboardStats(ctx)
if err != nil {
s.logger.Error("failed to get dashboard stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -92,7 +124,10 @@ func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Request) {
sessions, err := s.store.GetRecentSessions(r.Context(), 50, true)
ctx, cancel := dbContext(r)
defer cancel()
sessions, err := s.store.GetRecentSessions(ctx, 50, true)
if err != nil {
s.logger.Error("failed to get active sessions", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -105,6 +140,24 @@ func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Req
}
}
func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
f := parseDashboardFilter(r)
sessions, err := s.store.GetFilteredSessions(ctx, 50, false, f)
if err != nil {
s.logger.Error("failed to get filtered sessions", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "recent_sessions", sessions); err != nil {
s.logger.Error("failed to render recent sessions fragment", "err", err)
}
}
type sessionDetailData struct {
Session *storage.Session
Logs []storage.SessionLog
@@ -112,7 +165,8 @@ type sessionDetailData struct {
}
func (s *Server) handleSessionDetail(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, cancel := dbContext(r)
defer cancel()
sessionID := r.PathValue("id")
session, err := s.store.GetSession(ctx, sessionID)
@@ -162,8 +216,201 @@ type apiEventsResponse struct {
Events []apiEvent `json:"events"`
}
// parseDateParam parses a "YYYY-MM-DD" query parameter into a *time.Time.
func parseDateParam(r *http.Request, name string) *time.Time {
v := r.URL.Query().Get(name)
if v == "" {
return nil
}
t, err := time.Parse("2006-01-02", v)
if err != nil {
return nil
}
// For "until" dates, set to end of day.
if name == "until" {
t = t.Add(24*time.Hour - time.Second)
}
return &t
}
func parseDashboardFilter(r *http.Request) storage.DashboardFilter {
return storage.DashboardFilter{
Since: parseDateParam(r, "since"),
Until: parseDateParam(r, "until"),
IP: r.URL.Query().Get("ip"),
Country: r.URL.Query().Get("country"),
Username: r.URL.Query().Get("username"),
HumanScoreAboveZero: r.URL.Query().Get("human_score") == "1",
SortBy: r.URL.Query().Get("sort"),
}
}
type apiTimeSeriesPoint struct {
Date string `json:"date"`
Count int64 `json:"count"`
}
type apiAttemptsOverTimeResponse struct {
Points []apiTimeSeriesPoint `json:"points"`
}
func (s *Server) handleAPIAttemptsOverTime(w http.ResponseWriter, r *http.Request) {
days := 30
if v := r.URL.Query().Get("days"); v != "" {
if d, err := strconv.Atoi(v); err == nil && d > 0 && d <= 365 {
days = d
}
}
since := parseDateParam(r, "since")
until := parseDateParam(r, "until")
ctx, cancel := dbContext(r)
defer cancel()
points, err := s.store.GetAttemptsOverTime(ctx, days, since, until)
if err != nil {
s.logger.Error("failed to get attempts over time", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiAttemptsOverTimeResponse{Points: make([]apiTimeSeriesPoint, len(points))}
for i, p := range points {
resp.Points[i] = apiTimeSeriesPoint{
Date: p.Timestamp.Format("2006-01-02"),
Count: p.Count,
}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode attempts over time", "err", err)
}
}
type apiHourlyCount struct {
Hour int `json:"hour"`
Count int64 `json:"count"`
}
type apiHourlyPatternResponse struct {
Hours []apiHourlyCount `json:"hours"`
}
func (s *Server) handleAPIHourlyPattern(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
since := parseDateParam(r, "since")
until := parseDateParam(r, "until")
counts, err := s.store.GetHourlyPattern(ctx, since, until)
if err != nil {
s.logger.Error("failed to get hourly pattern", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiHourlyPatternResponse{Hours: make([]apiHourlyCount, len(counts))}
for i, c := range counts {
resp.Hours[i] = apiHourlyCount{Hour: c.Hour, Count: c.Count}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode hourly pattern", "err", err)
}
}
type apiCountryCount struct {
Country string `json:"country"`
Count int64 `json:"count"`
}
type apiCountryStatsResponse struct {
Countries []apiCountryCount `json:"countries"`
}
func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
counts, err := s.store.GetCountryStats(ctx)
if err != nil {
s.logger.Error("failed to get country stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiCountryStatsResponse{Countries: make([]apiCountryCount, len(counts))}
for i, c := range counts {
resp.Countries[i] = apiCountryCount{Country: c.Country, Count: c.Count}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode country stats", "err", err)
}
}
func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
f := parseDashboardFilter(r)
stats, err := s.store.GetFilteredDashboardStats(ctx, f)
if err != nil {
s.logger.Error("failed to get filtered stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topUsernames, err := s.store.GetFilteredTopUsernames(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top usernames", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topPasswords, err := s.store.GetFilteredTopPasswords(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top passwords", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topIPs, err := s.store.GetFilteredTopIPs(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top IPs", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topCountries, err := s.store.GetFilteredTopCountries(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top countries", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
data := dashboardData{
Stats: stats,
TopUsernames: topUsernames,
TopPasswords: topPasswords,
TopIPs: topIPs,
TopCountries: topCountries,
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "dashboard_content", data); err != nil {
s.logger.Error("failed to render dashboard content fragment", "err", err)
}
}
func (s *Server) handleAPISessionEvents(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, cancel := dbContext(r)
defer cancel()
sessionID := r.PathValue("id")
events, err := s.store.GetSessionEvents(ctx, sessionID)

14
internal/web/static/chart.min.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,275 @@
(function() {
'use strict';
// Chart.js theme for Pico dark mode
Chart.defaults.color = '#b0b0b8';
Chart.defaults.borderColor = '#3a3a4a';
var attemptsChart = null;
var hourlyChart = null;
function getFilterParams() {
var form = document.getElementById('filter-form');
if (!form) return '';
var params = new URLSearchParams();
var since = form.elements['since'].value;
var until = form.elements['until'].value;
if (since) params.set('since', since);
if (until) params.set('until', until);
var humanScore = form.elements['human_score'];
if (humanScore && humanScore.checked) params.set('human_score', '1');
var sortBy = form.elements['sort'];
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
return params.toString();
}
function initAttemptsChart() {
var canvas = document.getElementById('chart-attempts');
if (!canvas) return;
var ctx = canvas.getContext('2d');
var qs = getFilterParams();
var url = '/api/charts/attempts-over-time' + (qs ? '?' + qs : '');
fetch(url)
.then(function(r) { return r.json(); })
.then(function(data) {
var labels = data.points.map(function(p) { return p.date; });
var values = data.points.map(function(p) { return p.count; });
if (attemptsChart) {
attemptsChart.data.labels = labels;
attemptsChart.data.datasets[0].data = values;
attemptsChart.update();
return;
}
attemptsChart = new Chart(ctx, {
type: 'line',
data: {
labels: labels,
datasets: [{
label: 'Attempts',
data: values,
borderColor: '#6366f1',
backgroundColor: 'rgba(99, 102, 241, 0.1)',
fill: true,
tension: 0.3,
pointRadius: 2
}]
},
options: {
responsive: true,
maintainAspectRatio: true,
plugins: { legend: { display: false } },
scales: {
x: { grid: { display: false } },
y: { beginAtZero: true }
}
}
});
});
}
function initHourlyChart() {
var canvas = document.getElementById('chart-hourly');
if (!canvas) return;
var ctx = canvas.getContext('2d');
var qs = getFilterParams();
var url = '/api/charts/hourly-pattern' + (qs ? '?' + qs : '');
fetch(url)
.then(function(r) { return r.json(); })
.then(function(data) {
// Fill all 24 hours, defaulting to 0
var hourMap = {};
data.hours.forEach(function(h) { hourMap[h.hour] = h.count; });
var labels = [];
var values = [];
for (var i = 0; i < 24; i++) {
labels.push(i + ':00');
values.push(hourMap[i] || 0);
}
if (hourlyChart) {
hourlyChart.data.labels = labels;
hourlyChart.data.datasets[0].data = values;
hourlyChart.update();
return;
}
hourlyChart = new Chart(ctx, {
type: 'bar',
data: {
labels: labels,
datasets: [{
label: 'Attempts',
data: values,
backgroundColor: 'rgba(99, 102, 241, 0.6)',
borderColor: '#6366f1',
borderWidth: 1
}]
},
options: {
responsive: true,
maintainAspectRatio: true,
plugins: { legend: { display: false } },
scales: {
x: { grid: { display: false } },
y: { beginAtZero: true }
}
}
});
});
}
function initWorldMap() {
var container = document.getElementById('world-map');
if (!container) return;
fetch('/static/world.svg')
.then(function(r) { return r.text(); })
.then(function(svgText) {
container.innerHTML = svgText;
fetch('/api/charts/country-stats')
.then(function(r) { return r.json(); })
.then(function(data) {
colorMap(container, data.countries);
});
});
}
function colorMap(container, countries) {
if (!countries || countries.length === 0) return;
var maxCount = countries[0].count; // already sorted DESC
var logMax = Math.log(maxCount + 1);
// Build lookup
var lookup = {};
countries.forEach(function(c) {
lookup[c.country.toLowerCase()] = c.count;
});
// Create tooltip element
var tooltip = document.createElement('div');
tooltip.id = 'map-tooltip';
tooltip.style.cssText = 'position:fixed;display:none;background:#1a1a2e;color:#e0e0e8;padding:4px 8px;border-radius:4px;font-size:13px;pointer-events:none;z-index:1000;border:1px solid #3a3a4a;';
document.body.appendChild(tooltip);
var svg = container.querySelector('svg');
if (!svg) return;
// Remove SVG title to prevent browser native tooltip
var svgTitle = svg.querySelector('title');
if (svgTitle) svgTitle.remove();
// Select both <path id="xx"> and <g id="xx"> country elements
var elements = svg.querySelectorAll('path[id], g[id]');
elements.forEach(function(el) {
var id = el.id.toLowerCase();
if (id.charAt(0) === '_') return; // skip non-country paths
var count = lookup[id];
if (count) {
var intensity = Math.log(count + 1) / logMax;
var r = Math.round(30 + intensity * 69); // 30 -> 99
var g = Math.round(30 + intensity * 72); // 30 -> 102
var b = Math.round(62 + intensity * 179); // 62 -> 241
var color = 'rgb(' + r + ',' + g + ',' + b + ')';
// For <g> elements, color child paths; for <path>, color directly
if (el.tagName.toLowerCase() === 'g') {
el.querySelectorAll('path').forEach(function(p) {
p.style.fill = color;
});
} else {
el.style.fill = color;
}
}
el.addEventListener('mouseenter', function(e) {
var cc = id.toUpperCase();
var n = lookup[id] || 0;
tooltip.textContent = cc + ': ' + n.toLocaleString() + ' attempts';
tooltip.style.display = 'block';
});
el.addEventListener('mousemove', function(e) {
tooltip.style.left = (e.clientX + 12) + 'px';
tooltip.style.top = (e.clientY - 10) + 'px';
});
el.addEventListener('mouseleave', function() {
tooltip.style.display = 'none';
});
el.addEventListener('click', function() {
var input = document.querySelector('#filter-form input[name="country"]');
if (input) {
input.value = id.toUpperCase();
applyFilters();
}
});
el.style.cursor = 'pointer';
});
}
function applyFilters() {
// Re-fetch charts with filter params
initAttemptsChart();
initHourlyChart();
// Re-fetch dashboard content via htmx
var form = document.getElementById('filter-form');
if (!form) return;
var params = new URLSearchParams();
['since', 'until', 'ip', 'country', 'username'].forEach(function(name) {
var val = form.elements[name].value;
if (val) params.set(name, val);
});
var humanScore = form.elements['human_score'];
if (humanScore && humanScore.checked) params.set('human_score', '1');
var sortBy = form.elements['sort'];
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
var target = document.getElementById('dashboard-content');
if (target) {
var url = '/fragments/dashboard-content?' + params.toString();
htmx.ajax('GET', url, {target: '#dashboard-content', swap: 'innerHTML'});
}
// Server-side filter for recent sessions table
var sessionsUrl = '/fragments/recent-sessions?' + params.toString();
htmx.ajax('GET', sessionsUrl, {target: '#recent-sessions-table tbody', swap: 'innerHTML'});
}
window.clearFilters = function() {
var form = document.getElementById('filter-form');
if (form) {
form.reset();
applyFilters();
}
};
window.applyFilters = applyFilters;
// Initialize on DOM ready
document.addEventListener('DOMContentLoaded', function() {
initAttemptsChart();
initHourlyChart();
initWorldMap();
var form = document.getElementById('filter-form');
if (form) {
form.addEventListener('submit', function(e) {
e.preventDefault();
applyFilters();
});
}
});
})();

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 55 KiB

View File

@@ -44,6 +44,32 @@ func templateFuncMap() template.FuncMap {
}
return fmt.Sprintf("%.0f%%", *f*100)
},
"derefString": func(s *string) string {
if s == nil {
return ""
}
return *s
},
"truncateCommand": func(s string) string {
if len(s) > 50 {
return s[:50] + "..."
}
return s
},
"formatBytes": func(b int64) string {
const (
kb = 1024
mb = 1024 * kb
)
switch {
case b >= mb:
return fmt.Sprintf("%.1f MB", float64(b)/float64(mb))
case b >= kb:
return fmt.Sprintf("%.1f KB", float64(b)/float64(kb))
default:
return fmt.Sprintf("%d B", b)
}
},
}
}
@@ -55,6 +81,7 @@ func loadTemplates() (*templateSet, error) {
"templates/dashboard.html",
"templates/fragments/stats.html",
"templates/fragments/active_sessions.html",
"templates/fragments/recent_sessions.html",
)
if err != nil {
return nil, fmt.Errorf("parsing dashboard templates: %w", err)

View File

@@ -3,6 +3,86 @@
{{template "stats" .Stats}}
</section>
<details>
<summary>Filters</summary>
<form id="filter-form">
<div class="grid">
<label>Since <input type="date" name="since"></label>
<label>Until <input type="date" name="until"></label>
<label>IP <input type="text" name="ip" placeholder="10.0.0.1"></label>
<label>Country <input type="text" name="country" placeholder="CN" maxlength="2"></label>
<label>Username <input type="text" name="username" placeholder="root"></label>
</div>
<div class="grid">
<label><input type="checkbox" name="human_score" value="1"> Human score &gt; 0</label>
<label>Sort by <select name="sort"><option value="connected_at">Recent</option><option value="input_bytes">Input Bytes</option></select></label>
</div>
<button type="submit">Apply</button>
<button type="button" class="secondary" onclick="clearFilters()">Clear</button>
</form>
</details>
<section>
<h3>Attack Trends</h3>
<div class="grid">
<article>
<header>Attempts Over Time</header>
<canvas id="chart-attempts"></canvas>
</article>
<article>
<header>Hourly Pattern (UTC)</header>
<canvas id="chart-hourly"></canvas>
</article>
</div>
</section>
<section>
<h3>Attack Origins</h3>
<article>
<div id="world-map"></div>
</article>
</section>
<div id="dashboard-content">
{{template "dashboard_content" .}}
</div>
<section>
<h3>Active Sessions</h3>
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
{{template "active_sessions" .ActiveSessions}}
</div>
</section>
<section>
<h3>Recent Sessions</h3>
<table id="recent-sessions-table">
<thead>
<tr>
<th>ID</th>
<th>IP</th>
<th>Country</th>
<th>Username</th>
<th>Type</th>
<th>Score</th>
<th>Input</th>
<th>Connected</th>
<th>Disconnected</th>
</tr>
</thead>
<tbody>
{{template "recent_sessions" .RecentSessions}}
</tbody>
</table>
</section>
{{end}}
{{define "scripts"}}
<script src="/static/chart.min.js"></script>
<script src="/static/dashboard.js"></script>
{{end}}
{{define "dashboard_content"}}
<section>
<h3>Top Credentials & IPs</h3>
<div class="top-grid">
@@ -40,10 +120,25 @@
<header>Top IPs</header>
<table>
<thead>
<tr><th>IP</th><th>Attempts</th></tr>
<tr><th>IP</th><th>Country</th><th>Attempts</th></tr>
</thead>
<tbody>
{{range .TopIPs}}
<tr><td>{{.Value}}</td><td>{{.Country}}</td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="3">No data</td></tr>
{{end}}
</tbody>
</table>
</article>
<article>
<header>Top Countries</header>
<table>
<thead>
<tr><th>Country</th><th>Attempts</th></tr>
</thead>
<tbody>
{{range .TopCountries}}
<tr><td>{{.Value}}</td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="2">No data</td></tr>
@@ -51,45 +146,21 @@
</tbody>
</table>
</article>
<article>
<header>Top Exec Commands</header>
<table>
<thead>
<tr><th>Command</th><th>Count</th></tr>
</thead>
<tbody>
{{range .TopExecCommands}}
<tr><td><code>{{truncateCommand .Value}}</code></td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="2">No data</td></tr>
{{end}}
</tbody>
</table>
</article>
</div>
</section>
<section>
<h3>Active Sessions</h3>
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
{{template "active_sessions" .ActiveSessions}}
</div>
</section>
<section>
<h3>Recent Sessions</h3>
<table>
<thead>
<tr>
<th>ID</th>
<th>IP</th>
<th>Username</th>
<th>Shell</th>
<th>Score</th>
<th>Connected</th>
<th>Disconnected</th>
</tr>
</thead>
<tbody>
{{range .RecentSessions}}
<tr>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a></td>
<td>{{.IP}}</td>
<td>{{.Username}}</td>
<td>{{.ShellName}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatTime .ConnectedAt}}</td>
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
</tr>
{{else}}
<tr><td colspan="7">No sessions</td></tr>
{{end}}
</tbody>
</table>
</section>
{{end}}

View File

@@ -4,24 +4,28 @@
<tr>
<th>ID</th>
<th>IP</th>
<th>Country</th>
<th>Username</th>
<th>Shell</th>
<th>Type</th>
<th>Score</th>
<th>Input</th>
<th>Connected</th>
</tr>
</thead>
<tbody>
{{range .}}
<tr>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a></td>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
<td>{{.IP}}</td>
<td>{{.Country}}</td>
<td>{{.Username}}</td>
<td>{{.ShellName}}</td>
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatBytes .InputBytes}}</td>
<td>{{formatTime .ConnectedAt}}</td>
</tr>
{{else}}
<tr><td colspan="6">No active sessions</td></tr>
<tr><td colspan="8">No active sessions</td></tr>
{{end}}
</tbody>
</table>

View File

@@ -0,0 +1,17 @@
{{define "recent_sessions"}}
{{range .}}
<tr>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
<td>{{.IP}}</td>
<td>{{.Country}}</td>
<td>{{.Username}}</td>
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatBytes .InputBytes}}</td>
<td>{{formatTime .ConnectedAt}}</td>
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
</tr>
{{else}}
<tr><td colspan="9">No sessions</td></tr>
{{end}}
{{end}}

View File

@@ -29,9 +29,16 @@
}
.top-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
grid-template-columns: repeat(auto-fit, minmax(380px, 1fr));
gap: 1rem;
}
.top-grid article {
overflow: hidden;
min-width: 0;
}
#world-map svg { width: 100%; height: auto; }
#world-map svg path { fill: #2a2a3e; stroke: #555; stroke-width: 0.5; transition: fill 0.2s; }
#world-map svg path:hover, #world-map svg g:hover path { stroke: #fff; stroke-width: 1; }
nav h1 {
margin: 0;
}
@@ -52,5 +59,6 @@
<main class="container">
{{block "content" .}}{{end}}
</main>
{{block "scripts" .}}{{end}}
</body>
</html>

View File

@@ -7,8 +7,10 @@
<table>
<tbody>
<tr><td><strong>IP</strong></td><td>{{.Session.IP}}</td></tr>
<tr><td><strong>Country</strong></td><td>{{.Session.Country}}</td></tr>
<tr><td><strong>Username</strong></td><td>{{.Session.Username}}</td></tr>
<tr><td><strong>Shell</strong></td><td>{{.Session.ShellName}}</td></tr>
{{if .Session.ExecCommand}}<tr><td><strong>Exec Command</strong></td><td><code>{{derefString .Session.ExecCommand}}</code></td></tr>{{end}}
<tr><td><strong>Score</strong></td><td>{{formatScore .Session.HumanScore}}</td></tr>
<tr><td><strong>Connected</strong></td><td>{{formatTime .Session.ConnectedAt}}</td></tr>
<tr>

View File

@@ -1,11 +1,13 @@
package web
import (
"crypto/subtle"
"embed"
"log/slog"
"net/http"
"strings"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
//go:embed static/*
@@ -20,7 +22,9 @@ type Server struct {
}
// NewServer creates a new web Server with routes registered.
func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
// If metricsHandler is non-nil, a /metrics endpoint is registered.
// If metricsToken is non-empty, the metrics endpoint requires Bearer token auth.
func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Handler, metricsToken string) (*Server, error) {
tmpl, err := loadTemplates()
if err != nil {
return nil, err
@@ -36,9 +40,22 @@ func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
s.mux.Handle("GET /static/", http.FileServerFS(staticFS))
s.mux.HandleFunc("GET /sessions/{id}", s.handleSessionDetail)
s.mux.HandleFunc("GET /api/sessions/{id}/events", s.handleAPISessionEvents)
s.mux.HandleFunc("GET /api/charts/attempts-over-time", s.handleAPIAttemptsOverTime)
s.mux.HandleFunc("GET /api/charts/hourly-pattern", s.handleAPIHourlyPattern)
s.mux.HandleFunc("GET /api/charts/country-stats", s.handleAPICountryStats)
s.mux.HandleFunc("GET /", s.handleDashboard)
s.mux.HandleFunc("GET /fragments/stats", s.handleFragmentStats)
s.mux.HandleFunc("GET /fragments/active-sessions", s.handleFragmentActiveSessions)
s.mux.HandleFunc("GET /fragments/dashboard-content", s.handleFragmentDashboardContent)
s.mux.HandleFunc("GET /fragments/recent-sessions", s.handleFragmentRecentSessions)
if metricsHandler != nil {
h := metricsHandler
if metricsToken != "" {
h = requireBearerToken(metricsToken, h)
}
s.mux.Handle("GET /metrics", h)
}
return s, nil
}
@@ -47,3 +64,20 @@ func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.mux.ServeHTTP(w, r)
}
// requireBearerToken wraps a handler to require a valid Bearer token.
func requireBearerToken(token string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
provided := auth[len("Bearer "):]
if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}

View File

@@ -10,14 +10,15 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func newTestServer(t *testing.T) *Server {
t.Helper()
store := storage.NewMemoryStore()
logger := slog.Default()
srv, err := NewServer(store, logger)
srv, err := NewServer(store, logger, nil, "")
if err != nil {
t.Fatalf("creating server: %v", err)
}
@@ -30,29 +31,53 @@ func newSeededTestServer(t *testing.T) *Server {
ctx := context.Background()
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2"); err != nil {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash"); err != nil {
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
t.Fatalf("creating session: %v", err)
}
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash"); err != nil {
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
t.Fatalf("creating session: %v", err)
}
logger := slog.Default()
srv, err := NewServer(store, logger)
srv, err := NewServer(store, logger, nil, "")
if err != nil {
t.Fatalf("creating server: %v", err)
}
return srv
}
func TestDbContextNotCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
req := httptest.NewRequest(http.MethodGet, "/", nil)
req = req.WithContext(ctx)
dbCtx, dbCancel := dbContext(req)
defer dbCancel()
// Cancel the original request context.
cancel()
// The DB context should still be usable.
select {
case <-dbCtx.Done():
t.Fatal("dbContext should not be canceled when request context is canceled")
default:
}
// Verify the DB context has a deadline (from the timeout).
if _, ok := dbCtx.Deadline(); !ok {
t.Error("dbContext should have a deadline")
}
}
func TestDashboardHandler(t *testing.T) {
t.Run("empty store", func(t *testing.T) {
srv := newTestServer(t)
@@ -149,12 +174,12 @@ func TestSessionDetailHandler(t *testing.T) {
t.Run("found", func(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
srv, err := NewServer(store, slog.Default())
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
@@ -180,7 +205,7 @@ func TestSessionDetailHandler(t *testing.T) {
func TestAPISessionEvents(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
@@ -194,7 +219,7 @@ func TestAPISessionEvents(t *testing.T) {
t.Fatalf("AppendSessionEvents: %v", err)
}
srv, err := NewServer(store, slog.Default())
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
@@ -236,6 +261,293 @@ func TestAPISessionEvents(t *testing.T) {
}
}
func TestMetricsEndpoint(t *testing.T) {
t.Run("enabled", func(t *testing.T) {
m := metrics.New("test")
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), m.Handler(), "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body := w.Body.String()
if !strings.Contains(body, `oubliette_build_info{version="test"} 1`) {
t.Errorf("response should contain build_info metric, got:\n%s", body)
}
})
t.Run("disabled", func(t *testing.T) {
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
// Without a metrics handler, /metrics falls through to the dashboard.
body := w.Body.String()
if strings.Contains(body, "oubliette_build_info") {
t.Error("response should not contain prometheus metrics when disabled")
}
})
}
func TestMetricsBearerToken(t *testing.T) {
m := metrics.New("test")
t.Run("valid token", func(t *testing.T) {
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
req.Header.Set("Authorization", "Bearer secret")
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
})
t.Run("wrong token", func(t *testing.T) {
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
req.Header.Set("Authorization", "Bearer wrong")
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", w.Code)
}
})
t.Run("missing header", func(t *testing.T) {
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", w.Code)
}
})
t.Run("no token configured", func(t *testing.T) {
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), m.Handler(), "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
})
}
func TestTruncateCommand(t *testing.T) {
funcMap := templateFuncMap()
fn := funcMap["truncateCommand"].(func(string) string)
tests := []struct {
input string
want string
}{
{"short", "short"},
{"exactly fifty characters long! that is what it i.", "exactly fifty characters long! that is what it i."},
{"this string is definitely longer than fifty characters and should be truncated", "this string is definitely longer than fifty charac..."},
{"", ""},
}
for _, tt := range tests {
got := fn(tt.input)
if got != tt.want {
t.Errorf("truncateCommand(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestDashboardExecCommands(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
t.Fatalf("setting exec command: %v", err)
}
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body := w.Body.String()
if !strings.Contains(body, "Top Exec Commands") {
t.Error("response should contain 'Top Exec Commands'")
}
if !strings.Contains(body, "uname -a") {
t.Error("response should contain exec command 'uname -a'")
}
}
func TestAPIAttemptsOverTime(t *testing.T) {
srv := newSeededTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/api/charts/attempts-over-time", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
ct := w.Header().Get("Content-Type")
if !strings.Contains(ct, "application/json") {
t.Errorf("Content-Type = %q, want application/json", ct)
}
var resp apiAttemptsOverTimeResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
// Seeded data inserted today -> at least 1 point.
if len(resp.Points) == 0 {
t.Error("expected at least one data point")
}
}
func TestAPIHourlyPattern(t *testing.T) {
srv := newSeededTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/api/charts/hourly-pattern", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
var resp apiHourlyPatternResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
if len(resp.Hours) == 0 {
t.Error("expected at least one hourly data point")
}
}
func TestAPICountryStats(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
t.Fatalf("seeding: %v", err)
}
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
t.Fatalf("seeding: %v", err)
}
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/api/charts/country-stats", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
var resp apiCountryStatsResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
if len(resp.Countries) != 2 {
t.Fatalf("len = %d, want 2", len(resp.Countries))
}
}
func TestFragmentDashboardContent(t *testing.T) {
srv := newSeededTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body := w.Body.String()
if strings.Contains(body, "<!DOCTYPE html>") {
t.Error("dashboard content fragment should not contain full HTML document")
}
if !strings.Contains(body, "Top Usernames") {
t.Error("dashboard content fragment should contain 'Top Usernames'")
}
}
func TestFragmentDashboardContentWithFilter(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
t.Fatalf("seeding: %v", err)
}
}
for range 3 {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
t.Fatalf("seeding: %v", err)
}
}
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content?country=CN", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body := w.Body.String()
// When filtered by CN, should show root but not admin.
if !strings.Contains(body, "root") {
t.Error("response should contain 'root' when filtered by CN")
}
}
func TestStaticAssets(t *testing.T) {
srv := newTestServer(t)
@@ -245,6 +557,9 @@ func TestStaticAssets(t *testing.T) {
}{
{"/static/pico.min.css", "text/css"},
{"/static/htmx.min.js", "text/javascript"},
{"/static/chart.min.js", "text/javascript"},
{"/static/dashboard.js", "text/javascript"},
{"/static/world.svg", "image/svg+xml"},
}
for _, tt := range tests {

View File

@@ -24,6 +24,26 @@ password = "admin"
# password = "fridge"
# shell = "fridge"
# [[auth.static_credentials]]
# username = "teller"
# password = "banking"
# shell = "banking"
# [[auth.static_credentials]]
# username = "admin"
# password = "cisco"
# shell = "cisco"
# [[auth.static_credentials]]
# username = "irobot"
# password = "roomba"
# shell = "roomba"
# [[auth.static_credentials]]
# username = "player"
# password = "tetris"
# shell = "tetris"
[storage]
db_path = "oubliette.db"
retention_days = 90
@@ -32,12 +52,45 @@ retention_interval = "1h"
# [web]
# enabled = true
# listen_addr = ":8080"
# metrics_enabled = true
# metrics_token = "" # bearer token for /metrics; empty = no auth
[shell]
hostname = "ubuntu-server"
# banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
# fake_user = "" # override username in prompt; empty = use authenticated user
# Map usernames to specific shells (regardless of how auth succeeded).
# Credential-specific shell overrides take priority over username routes.
# [shell.username_routes]
# postgres = "psql"
# admin = "bash"
# Per-shell configuration (optional).
# [shell.banking]
# bank_name = "SECUREBANK"
# terminal_id = "SB-0001" # random if not set
# region = "NORTHEAST"
# [shell.adventure]
# dungeon_name = "THE OUBLIETTE"
# [shell.cisco]
# hostname = "Router"
# model = "C2960"
# ios_version = "15.0(2)SE11"
# enable_password = "" # empty = accept after 1 failed attempt
# [shell.psql]
# db_name = "postgres"
# pg_version = "15.4"
# [shell.roomba]
# No configuration options currently.
# [shell.tetris]
# difficulty = "normal" # "easy" (slower start), "normal" (standard), "hard" (start at level 5)
# [detection]
# enabled = true
# threshold = 0.6 # 0.01.0, sessions above this trigger notifications

18
scripts/fetch-geoip.sh Executable file
View File

@@ -0,0 +1,18 @@
#!/usr/bin/env bash
# Downloads the DB-IP Lite country MMDB database for development.
# The Nix build fetches this automatically; this script is for local dev only.
set -euo pipefail
URL="https://download.db-ip.com/free/dbip-country-lite-2026-02.mmdb.gz"
DEST="internal/geoip/dbip-country-lite.mmdb"
cd "$(git rev-parse --show-toplevel)"
if [ -f "$DEST" ]; then
echo "GeoIP database already exists at $DEST"
exit 0
fi
echo "Downloading DB-IP Lite country database..."
curl -fSL "$URL" | gunzip > "$DEST"
echo "Saved to $DEST"