Compare commits

...

32 Commits

Author SHA1 Message Date
1b28f10ca8 refactor: migrate module path from git.t-juice.club to code.t-juice.club
Update Go module path and all import references to reflect the migration
from Gitea (git.t-juice.club) to Forgejo (code.t-juice.club).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 18:51:23 +01:00
664e79fce6 feat: add Prometheus metrics for Store queries
Add InstrumentedStore decorator that wraps any Store and records
per-method query duration histograms and error counters. Wired into
main.go so all storage consumers get automatic observability.

Bump version to 0.18.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:29:51 +01:00
c74313c195 fix: resolve linting issues in roomba shell
Replace unnecessary fmt.Sprintf calls with string literals, use
slices.Contains instead of manual loop, and use compound assignment
operator.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:18:25 +01:00
9783ae5865 fix: prevent context canceled errors in web dashboard
Detach DB queries from HTTP request context so HTMX polling doesn't
cancel in-flight queries when the browser aborts previous XHRs. Add
indexes on login_attempts and sessions to speed up frequent dashboard
queries. Bump version to 0.17.1.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 22:16:49 +01:00
62de222488 feat: add tetris shell (Tetris game TUI)
Full-screen Tetris game using Bubbletea with title screen, ghost piece,
lock delay, NES-style scoring, configurable difficulty (easy/normal/hard),
and honeypot event logging. Bumps version to 0.17.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 00:59:46 +01:00
c9d143d84b feat: add roomba shell (iRobot Roomba j7+ vacuum emulator)
New novelty shell emulating RoombaOS with cleaning, scheduling,
diagnostics, floor map, and humorous history entries. Bump version
to 0.16.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 14:06:59 +01:00
d18a904ed5 chore: bump version to 0.15.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 09:13:50 +01:00
cb7be28f42 feat: add server-side session filtering with input bytes and human score
Replace client-side session table filtering with server-side filtering
via a new /fragments/recent-sessions htmx endpoint. Add InputBytes column
to session tables, Human score > 0 checkbox filter, and Sort by Input
Bytes option to help identify sessions with actual shell interaction.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 09:12:51 +01:00
0908b43724 chore: bump version to 0.14.2
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:24:01 +01:00
52310f588d fix: highlight all polygons on hover for multi-path countries
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:24:01 +01:00
b52216bd2f chore: bump version to 0.14.1
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:21:01 +01:00
2bc83a17dd fix: handle SVG group elements in world map for multi-path countries
The SVG world map uses <g> group elements for countries with complex
shapes (US, CN, RU, GB, etc.), but the JS only queried <path> elements,
causing 36 countries to be missing from the map. Also removes the SVG
<title> element that was overriding the custom tooltip.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:20:23 +01:00
faf6e2abd7 docs: mark 4.1 and 4.4 as completed in PLAN.md
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:34:17 +01:00
0a4eac188a chore: bump version to 0.14.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:31:53 +01:00
7c90c9ed4a feat: add charts, world map, and filters to web dashboard
Add Chart.js line/bar charts for attack trends (attempts over time,
hourly pattern), an SVG world map choropleth colored by attack origin
country, and a collapsible filter form (date range, IP, country,
username) that narrows both charts and top-N tables.

New store methods: GetAttemptsOverTime, GetHourlyPattern, GetCountryStats,
and filtered variants of dashboard stats/top-N queries. New JSON API
endpoints at /api/charts/* and an htmx fragment at
/fragments/dashboard-content for filtered table updates.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 20:27:15 +01:00
8a631af0d2 fix: prevent dashboard top-grid cards from overflowing horizontally
Increase minimum column width from 280px to 380px so the 3-column Top
IPs table fits without clipping. Add overflow/min-width safety net for
narrow viewports.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 21:25:20 +01:00
40fda3420c feat: add psql shell and username-to-shell routing
Add a PostgreSQL psql interactive terminal shell with backslash
meta-commands, SQL statement handling with multi-line buffering, and
canned responses for common queries. Add username-based shell routing
via [shell.username_routes] config (second priority after credential-
specific shell, before random selection). Bump version to 0.13.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:58:34 +01:00
c4801e3309 chore: bump version to 0.12.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:38:47 +01:00
4f10a8a422 feat: add session indicators and top exec commands to dashboard
Add visual indicators to session tables (replay badge when events exist,
exec badge for exec sessions) and a new "Top Exec Commands" table on the
dashboard. Includes EventCount field on Session, GetTopExecCommands on
Store interface, and truncateCommand template function.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:38:10 +01:00
0b44d1c83f docs: detail fake exec output approach in PLAN.md 4.4.1
Regex-based output assembly: scan exec commands for known patterns
and return plausible fake values rather than interpreting shell
pipelines. Waiting on more real-world bot examples before implementing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 18:01:42 +01:00
0133d956a5 feat: capture SSH exec commands (PLAN.md 4.4)
Bots often send commands via `ssh user@host <command>` (exec request)
rather than requesting an interactive shell. These were previously
rejected silently. Now exec commands are captured, stored on the session
record, and displayed in the web UI session detail page.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:43:11 +01:00
3c20e854aa docs: add plan for capturing SSH exec commands (PLAN.md 4.4)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:25:52 +01:00
090dbec390 chore: bump version to 0.10.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:55:10 +01:00
df860b3061 feat: add new Prometheus metrics and bearer token auth for /metrics
Add 6 new Prometheus metrics for richer observability:
- auth_attempts_by_country_total (counter by country)
- commands_executed_total (counter by shell via OnCommand callback)
- human_score (histogram of final detection scores)
- storage_login_attempts_total, storage_unique_ips, storage_sessions_total
  (gauges via custom collector querying GetDashboardStats on each scrape)

Add optional bearer token authentication for the /metrics endpoint via
web.metrics_token config option. Uses crypto/subtle.ConstantTimeCompare.
Empty token (default) means no auth for backwards compatibility.

Also adds "cisco" to pre-initialized session/command metric labels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:54:29 +01:00
9aecc7ce02 chore: bump version to 0.9.0
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:29:37 +01:00
94f1f1c266 feat: add GeoIP country lookup with embedded DB-IP Lite database (PLAN.md 4.3)
Embeds a DB-IP Lite country MMDB (~5MB) in the binary via go:embed,
keeping the single-binary deployment story clean. Country codes are
stored alongside login attempts and sessions, shown in the dashboard
(Top IPs, Top Countries card, Recent/Active Sessions, session detail).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:27:46 +01:00
8fff893d25 docs: mark Cisco IOS shell (PLAN.md 3.2) as completed
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 15:04:51 +01:00
5ba62afec3 feat: add Cisco IOS shell with mode state machine and abbreviation matching (PLAN.md 3.2)
Implements a Cisco IOS CLI emulator with four modes (user exec, privileged exec,
global config, interface config), Cisco-style command abbreviation (e.g. sh run,
conf t), enable password flow, and realistic show command output including
running-config, interfaces, IP routes, and VLANs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 14:58:26 +01:00
058da51f86 fix: add column whitelist to queryTopN to prevent SQL injection
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 10:08:28 +01:00
adfe372d13 refactor: extract changePinModel into its own sub-model
The Change PIN screen was the only screen with its state (pinInput,
pinStage, pinMessage) stored directly on the top-level model. Extract
it into a changePinModel in screen_changepin.go to match the pattern
used by all other screens.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 09:34:56 +01:00
3163ea47dc chore: add bubbletea skill 2026-02-15 09:28:28 +01:00
ab07e6a8dc feat: add Prometheus metrics endpoint and Docker image (PLAN.md 4.2)
Add internal/metrics package with dedicated Prometheus registry exposing
SSH connection, auth attempt, session, and build info metrics. Wire into
SSH server (4 instrumentation points) and web server (/metrics endpoint).
Add dockerImage output to flake.nix via dockerTools.buildLayeredImage.
Bump version to 0.7.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 05:47:16 +01:00
75 changed files with 8464 additions and 314 deletions

View File

@@ -0,0 +1,55 @@
---
name: bubbletea
description: Browse Bubbletea TUI framework documentation and examples. Use when working with Bubbletea components, models, commands, or building terminal user interfaces in Go.
---
# Bubbletea Documentation
Bubbletea is a Go framework for building terminal user interfaces based on The Elm Architecture.
## Key Resources
When you need to understand Bubbletea patterns or find examples:
1. **Examples README** - Overview of all available examples:
https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
2. **Examples Directory** - Full source code for all examples:
https://github.com/charmbracelet/bubbletea/tree/main/examples
## How to Use
1. First, fetch the examples README to get an overview of available examples:
```
WebFetch https://github.com/charmbracelet/bubbletea/blob/main/examples/README.md
```
2. Once you identify a relevant example, fetch its source code from the examples directory.
## Common Examples to Reference
- `list` - List component with filtering
- `table` - Table component
- `textinput` - Text input handling
- `textarea` - Multi-line text input
- `viewport` - Scrollable content
- `paginator` - Pagination
- `spinner` - Loading spinners
- `progress` - Progress bars
- `tabs` - Tab navigation
- `help` - Help text/keybindings display
## Core Concepts
- **Model**: Application state
- **Update**: Handles messages and returns updated model + commands
- **View**: Renders the model to a string
- **Cmd**: Side effects that produce messages
- **Msg**: Events that trigger updates
## Related Charm Libraries
- **Bubbles**: Pre-built components (github.com/charmbracelet/bubbles)
- **Lipgloss**: Styling and layout (github.com/charmbracelet/lipgloss)
- **Glamour**: Markdown rendering (github.com/charmbracelet/glamour)

2
.gitignore vendored
View File

@@ -4,3 +4,5 @@ oubliette.toml
*.db-wal
*.db-shm
/oubliette
*.mmdb
*.mmdb.gz

81
PLAN.md
View File

@@ -150,7 +150,7 @@ Goal: Add the entertaining shell implementations.
- **Haunted:** commands gradually return stranger output, files appear/disappear, `whoami` returns different users
- **Bread crumbs:** fake .bash_history, id_rsa files, database configs pointing to other honeypots
### 3.2 Cisco IOS Shell
### 3.2 Cisco IOS Shell
- Realistic `>` and `#` prompts
- Common commands: `show running-config`, `show interfaces`, `enable`, `configure terminal`
- Fake device info that looks like a real router
@@ -171,7 +171,20 @@ Goal: Add the entertaining shell implementations.
### 3.5 Banking TUI Shell ✅
- 80s-style green-on-black bank terminal
### 3.6 Other Shell Ideas (Future)
### 3.6 PostgreSQL psql Shell ✅
- Simulates psql interactive terminal with `db_name` and `pg_version` config
- Backslash meta-commands: `\q`, `\dt`, `\d <table>`, `\l`, `\du`, `\conninfo`, `\?`, `\h`
- SQL statement handling with multi-line buffering (semicolon-terminated)
- Canned responses for common queries (SELECT version(), current_database(), etc.)
- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.)
- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells
### 3.7 Roomba Shell ✅
- iRobot Roomba j7+ vacuum robot interface
- Status, cleaning, scheduling, diagnostics, floor map
- Humorous history entries (cat encounters, sock tangles, sticky substances)
### 3.8 Other Shell Ideas (Future)
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
- **ELIZA therapist:** every response is a therapy question
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
@@ -183,19 +196,55 @@ Goal: Add the entertaining shell implementations.
Goal: Make the web UI great and add operational niceties.
### 4.1 Enhanced Web UI
- GeoIP lookups and world map visualization of attack sources
- Charts: attempts over time, hourly patterns, credential trends
- Session detail view with full command log
- Filtering and search
### 4.1 Enhanced Web UI
- GeoIP lookups and world map visualization of attack sources
- Charts: attempts over time, hourly patterns, credential trends
- Session detail view with full command log
- Filtering and search
### 4.2 Operational
- Prometheus metrics endpoint
- Structured logging (slog)
- Graceful shutdown
- Systemd unit file / deployment docs
### 4.2 Operational
- Prometheus metrics endpoint
- Structured logging (slog)
- Graceful shutdown
- Docker image (nix dockerTools) ✅
- Systemd unit file / deployment docs ✅
### 4.3 GeoIP
- Embed a lightweight GeoIP database or use an API
- Store country/city with each attempt
- Aggregate stats by country
### 4.3 GeoIP
- Embed a lightweight GeoIP database or use an API
- Store country/city with each attempt
- Aggregate stats by country
### 4.4 Capture SSH Exec Commands ✅
Many bots send a command directly via `ssh user@host <command>` (an SSH "exec" request) rather than requesting an interactive shell. Currently these are rejected and the command is lost. We should capture them.
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`) ✅
- Parse the command string from the exec payload ✅
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration ✅
- Store the command on the session record before closing the channel ✅
- Optionally return plausible fake output for common commands (e.g. `uname`, `id`, `cat /etc/passwd`) to encourage further interaction
- Surface exec commands in the web UI (session detail view) ✅
#### 4.4.1 Fake Exec Output
Return plausible fake output for exec commands to encourage bots to interact further.
**Approach: regex-based output assembly.** Bots typically send a single long command that chains recon commands and then echoes a summary (e.g. `echo "UNAME:$uname"`). Rather than interpreting arbitrary shell pipelines, we scan the command string for known patterns and assemble fake output.
Implementation:
- A map of common command/variable patterns to fake output strings, e.g.:
- `uname -a` / `uname -s -v -n -m``"Linux ubuntu-server 5.15.0-91-generic #101-Ubuntu SMP Tue Jan 2 15:13:10 UTC 2024 x86_64"`
- `uname -m` / `arch``"x86_64"`
- `cat /proc/uptime``"86432.71 172801.55"`
- `nproc` / `grep -c "^processor" /proc/cpuinfo``"2"`
- `cat /proc/cpuinfo` → fake cpuinfo block
- `lspci` → empty (no GPU — discourages cryptominer targeting)
- `id``"uid=0(root) gid=0(root) groups=0(root)"`
- `cat /etc/passwd` → minimal fake passwd file
- `last` → fake login entries
- `cat --help`, `ls --help` → canned GNU coreutils help text
- Scan the exec command for `echo "KEY:$var"` patterns; for each key, look up the corresponding fake value from the variable assignment earlier in the command
- If we recognise echo patterns, assemble and return the expected output
- If we don't recognise the command at all, return empty output with exit 0 (current behaviour)
- Values should draw from the existing shell config where possible (hostname, fake_user) for consistency
- New package `internal/execfake` or a file in `internal/server/` — keep it simple
Gather more real-world bot examples before implementing to ensure good coverage of common recon patterns.

View File

@@ -34,7 +34,8 @@ Key settings:
- `auth.accept_after` — accept login after N failures per IP (default `10`)
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
- `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell)
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon)
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal), `roomba` (iRobot Roomba vacuum robot), `tetris` (Tetris game TUI)
- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority
- `storage.db_path` — SQLite database path (default `oubliette.db`)
- `storage.retention_days` — auto-prune records older than N days (default `90`)
- `storage.retention_interval` — how often to run retention (default `1h`)
@@ -43,12 +44,21 @@ Key settings:
- `shell.fake_user` — override username in prompt; empty uses the authenticated user
- `web.enabled` — enable the web dashboard (default `false`)
- `web.listen_addr` — web dashboard listen address (default `:8080`)
- Dashboard includes Chart.js charts (attempts over time, hourly pattern), an SVG world map choropleth colored by attack origin, and filter controls for date range / IP / country / username
- Session detail pages at `/sessions/{id}` include terminal replay via xterm.js
- `web.metrics_enabled` — expose Prometheus metrics at `/metrics` (default `true`)
- `web.metrics_token` — bearer token to protect `/metrics`; empty means no auth (default empty)
- `detection.enabled` — enable human detection scoring (default `false`)
- `detection.threshold` — score threshold (0.01.0) for flagging sessions (default `0.6`)
- `detection.update_interval` — how often to recompute scores (default `5s`)
- `notify.webhooks` — list of webhook endpoints for notifications (see example config)
### GeoIP
Country-level GeoIP lookups are embedded in the binary using the [DB-IP Lite](https://db-ip.com/db/lite.php) database (CC-BY-4.0). The dashboard shows country alongside IPs and includes a "Top Countries" table.
For local development, run `scripts/fetch-geoip.sh` to download the MMDB file. The Nix build fetches it automatically.
### Run
```sh
@@ -61,6 +71,9 @@ Test with:
ssh -o StrictHostKeyChecking=no -p 2222 root@localhost
```
SSH exec commands (`ssh user@host <command>`) are also captured and stored on the session record.
### NixOS Module
Add the flake as an input and enable the service:
@@ -82,3 +95,15 @@ Add the flake as an input and enable the service:
```
Alternatively, use `configFile` to pass a pre-written TOML file instead of `settings`.
### Docker
Build a Docker image via nix:
```sh
nix build .#dockerImage
docker load < result
docker run -v /path/to/data:/data -p 2222:2222 -p 8080:8080 oubliette:0.8.0
```
Place your `oubliette.toml` in the data volume. The container exposes ports 2222 (SSH) and 8080 (web/metrics).

View File

@@ -13,13 +13,14 @@ import (
"syscall"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/server"
"git.t-juice.club/torjus/oubliette/internal/storage"
"git.t-juice.club/torjus/oubliette/internal/web"
"code.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/server"
"code.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/web"
)
const Version = "0.6.0"
const Version = "0.18.0"
func main() {
if err := run(); err != nil {
@@ -75,9 +76,13 @@ func run() error {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
m := metrics.New(Version)
instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors)
m.RegisterStoreCollector(instrumentedStore)
srv, err := server.New(*cfg, store, logger)
go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
srv, err := server.New(*cfg, instrumentedStore, logger, m)
if err != nil {
return fmt.Errorf("create server: %w", err)
}
@@ -86,7 +91,12 @@ func run() error {
// Start web server if enabled.
if cfg.Web.Enabled {
webHandler, err := web.NewServer(store, logger.With("component", "web"))
var metricsHandler http.Handler
if *cfg.Web.MetricsEnabled {
metricsHandler = m.Handler()
}
webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
if err != nil {
return fmt.Errorf("create web server: %w", err)
}

View File

@@ -18,19 +18,44 @@
pkgs = nixpkgs.legacyPackages.${system};
mainGo = builtins.readFile ./cmd/oubliette/main.go;
version = builtins.head (builtins.match ''.*const Version = "([^"]+)".*'' mainGo);
geoipDb = pkgs.fetchurl {
url = "https://download.db-ip.com/free/dbip-country-lite-2026-02.mmdb.gz";
hash = "sha256-xmQZEJZ5WzE9uQww1Sdb8248l+liYw46tjbfJeu945Q=";
};
in
{
default = pkgs.buildGoModule {
pname = "oubliette";
inherit version;
src = ./.;
vendorHash = "sha256-oH92jRD+2niIf7xAX1HeZvhux8lVqj43Qxdef5GjX4Q=";
vendorHash = "sha256-/zxK6CABLYBNtuSOI8dIVgMNxKiDIcbZUS7bQR5TenA=";
subPackages = [ "cmd/oubliette" ];
nativeBuildInputs = [ pkgs.gzip ];
preBuild = ''
gunzip -c ${geoipDb} > internal/geoip/dbip-country-lite.mmdb
'';
meta = {
description = "SSH honeypot";
mainProgram = "oubliette";
};
};
dockerImage = pkgs.dockerTools.buildLayeredImage {
name = "oubliette";
tag = version;
contents = [ self.packages.${system}.default ];
config = {
Entrypoint = [ "/bin/oubliette" ];
Cmd = [ "-config" "/data/oubliette.toml" ];
ExposedPorts = {
"2222/tcp" = {};
"8080/tcp" = {};
};
Volumes = {
"/data" = {};
};
};
};
});
devShells = forAllSystems (system:

13
go.mod
View File

@@ -1,4 +1,4 @@
module git.t-juice.club/torjus/oubliette
module code.t-juice.club/torjus/oubliette
go 1.25.5
@@ -7,18 +7,24 @@ require (
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/google/uuid v1.6.0
github.com/oschwald/maxminddb-golang v1.13.1
github.com/prometheus/client_golang v1.23.2
github.com/prometheus/client_model v0.6.2
golang.org/x/crypto v0.48.0
modernc.org/sqlite v1.45.0
)
require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/x/ansi v0.10.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
@@ -26,13 +32,18 @@ require (
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect

46
go.sum
View File

@@ -2,6 +2,10 @@ github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk
github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
@@ -14,16 +18,29 @@ github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0G
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@@ -38,15 +55,37 @@ github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELU
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
@@ -65,6 +104,13 @@ golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=

View File

@@ -5,7 +5,7 @@ import (
"sync"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
const (

View File

@@ -6,7 +6,7 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {

View File

@@ -21,15 +21,18 @@ type Config struct {
}
type WebConfig struct {
Enabled bool `toml:"enabled"`
ListenAddr string `toml:"listen_addr"`
Enabled bool `toml:"enabled"`
ListenAddr string `toml:"listen_addr"`
MetricsEnabled *bool `toml:"metrics_enabled"`
MetricsToken string `toml:"metrics_token"`
}
type ShellConfig struct {
Hostname string `toml:"hostname"`
Banner string `toml:"banner"`
FakeUser string `toml:"fake_user"`
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
Hostname string `toml:"hostname"`
Banner string `toml:"banner"`
FakeUser string `toml:"fake_user"`
UsernameRoutes map[string]string `toml:"username_routes"`
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
}
type StorageConfig struct {
@@ -143,6 +146,10 @@ func applyDefaults(cfg *Config) {
if cfg.Web.ListenAddr == "" {
cfg.Web.ListenAddr = ":8080"
}
if cfg.Web.MetricsEnabled == nil {
t := true
cfg.Web.MetricsEnabled = &t
}
if cfg.Shell.Hostname == "" {
cfg.Shell.Hostname = "ubuntu-server"
}
@@ -159,9 +166,10 @@ func applyDefaults(cfg *Config) {
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
var knownShellKeys = map[string]bool{
"hostname": true,
"banner": true,
"fake_user": true,
"hostname": true,
"banner": true,
"fake_user": true,
"username_routes": true,
}
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.

View File

@@ -282,6 +282,22 @@ password = "toor"
}
}
func TestLoadMetricsToken(t *testing.T) {
content := `
[web]
enabled = true
metrics_token = "my-secret-token"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Web.MetricsToken != "my-secret-token" {
t.Errorf("metrics_token = %q, want %q", cfg.Web.MetricsToken, "my-secret-token")
}
}
func TestLoadMissingFile(t *testing.T) {
_, err := Load("/nonexistent/path/config.toml")
if err == nil {
@@ -297,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) {
}
}
func TestLoadUsernameRoutes(t *testing.T) {
content := `
[shell]
hostname = "myhost"
[shell.username_routes]
postgres = "psql"
admin = "bash"
[shell.bash]
custom_key = "value"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Shell.UsernameRoutes == nil {
t.Fatal("UsernameRoutes should not be nil")
}
if cfg.Shell.UsernameRoutes["postgres"] != "psql" {
t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql")
}
if cfg.Shell.UsernameRoutes["admin"] != "bash" {
t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash")
}
// username_routes should NOT appear in the Shells map.
if _, ok := cfg.Shell.Shells["username_routes"]; ok {
t.Error("username_routes should not appear in Shells map")
}
// bash should still appear in Shells map.
if _, ok := cfg.Shell.Shells["bash"]; !ok {
t.Error("Shells[\"bash\"] should still be present")
}
}
func writeTemp(t *testing.T, content string) string {
t.Helper()
path := filepath.Join(t.TempDir(), "config.toml")

51
internal/geoip/geoip.go Normal file
View File

@@ -0,0 +1,51 @@
package geoip
import (
_ "embed"
"net"
"github.com/oschwald/maxminddb-golang"
)
//go:embed dbip-country-lite.mmdb
var mmdbData []byte
// Reader provides country-level GeoIP lookups using an embedded DB-IP Lite database.
type Reader struct {
db *maxminddb.Reader
}
// New opens the embedded MMDB and returns a ready-to-use Reader.
func New() (*Reader, error) {
db, err := maxminddb.FromBytes(mmdbData)
if err != nil {
return nil, err
}
return &Reader{db: db}, nil
}
type countryRecord struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
}
// Lookup returns the ISO 3166-1 alpha-2 country code for the given IP address,
// or an empty string if the lookup fails or no result is found.
func (r *Reader) Lookup(ipStr string) string {
ip := net.ParseIP(ipStr)
if ip == nil {
return ""
}
var record countryRecord
if err := r.db.Lookup(ip, &record); err != nil {
return ""
}
return record.Country.ISOCode
}
// Close releases resources held by the reader.
func (r *Reader) Close() error {
return r.db.Close()
}

View File

@@ -0,0 +1,44 @@
package geoip
import "testing"
func TestLookup(t *testing.T) {
reader, err := New()
if err != nil {
t.Fatalf("New: %v", err)
}
defer reader.Close()
tests := []struct {
ip string
want string
}{
{"8.8.8.8", "US"},
{"1.1.1.1", "AU"},
{"invalid", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
got := reader.Lookup(tt.ip)
if got != tt.want {
t.Errorf("Lookup(%q) = %q, want %q", tt.ip, got, tt.want)
}
})
}
}
func TestLookupPrivateIP(t *testing.T) {
reader, err := New()
if err != nil {
t.Fatalf("New: %v", err)
}
defer reader.Close()
// Private IPs should return empty string (no country).
got := reader.Lookup("10.0.0.1")
if got != "" {
t.Errorf("Lookup(10.0.0.1) = %q, want empty", got)
}
}

178
internal/metrics/metrics.go Normal file
View File

@@ -0,0 +1,178 @@
package metrics
import (
"context"
"net/http"
"code.t-juice.club/torjus/oubliette/internal/storage"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// Metrics holds all Prometheus collectors for the honeypot.
type Metrics struct {
registry *prometheus.Registry
SSHConnectionsTotal *prometheus.CounterVec
SSHConnectionsActive prometheus.Gauge
AuthAttemptsTotal *prometheus.CounterVec
AuthAttemptsByCountry *prometheus.CounterVec
CommandsExecuted *prometheus.CounterVec
HumanScore prometheus.Histogram
SessionsTotal *prometheus.CounterVec
SessionsActive prometheus.Gauge
SessionDuration prometheus.Histogram
ExecCommandsTotal prometheus.Counter
BuildInfo *prometheus.GaugeVec
StorageQueryDuration *prometheus.HistogramVec
StorageQueryErrors *prometheus.CounterVec
}
// New creates a new Metrics instance with all collectors registered.
func New(version string) *Metrics {
reg := prometheus.NewRegistry()
m := &Metrics{
registry: reg,
SSHConnectionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_ssh_connections_total",
Help: "Total SSH connections received.",
}, []string{"outcome"}),
SSHConnectionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "oubliette_ssh_connections_active",
Help: "Current active SSH connections.",
}),
AuthAttemptsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_auth_attempts_total",
Help: "Total authentication attempts.",
}, []string{"result", "reason"}),
AuthAttemptsByCountry: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_auth_attempts_by_country_total",
Help: "Total authentication attempts by country.",
}, []string{"country"}),
CommandsExecuted: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_commands_executed_total",
Help: "Total commands executed in shells.",
}, []string{"shell"}),
HumanScore: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "oubliette_human_score",
Help: "Distribution of final human detection scores.",
Buckets: prometheus.LinearBuckets(0, 0.1, 11), // 0.0, 0.1, ..., 1.0
}),
SessionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_sessions_total",
Help: "Total sessions created.",
}, []string{"shell"}),
SessionsActive: prometheus.NewGauge(prometheus.GaugeOpts{
Name: "oubliette_sessions_active",
Help: "Current active sessions.",
}),
SessionDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "oubliette_session_duration_seconds",
Help: "Session duration in seconds.",
Buckets: []float64{1, 5, 10, 30, 60, 120, 300, 600, 1800, 3600},
}),
ExecCommandsTotal: prometheus.NewCounter(prometheus.CounterOpts{
Name: "oubliette_exec_commands_total",
Help: "Total SSH exec commands received.",
}),
BuildInfo: prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "oubliette_build_info",
Help: "Build information. Always 1.",
}, []string{"version"}),
StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "oubliette_storage_query_duration_seconds",
Help: "Duration of storage query calls in seconds.",
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
}, []string{"method"}),
StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "oubliette_storage_query_errors_total",
Help: "Total storage query errors.",
}, []string{"method"}),
}
reg.MustRegister(
collectors.NewGoCollector(),
collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}),
m.SSHConnectionsTotal,
m.SSHConnectionsActive,
m.AuthAttemptsTotal,
m.AuthAttemptsByCountry,
m.CommandsExecuted,
m.HumanScore,
m.SessionsTotal,
m.SessionsActive,
m.SessionDuration,
m.ExecCommandsTotal,
m.BuildInfo,
m.StorageQueryDuration,
m.StorageQueryErrors,
)
m.BuildInfo.WithLabelValues(version).Set(1)
// Initialize label combinations so they appear in Gather/output.
for _, outcome := range []string{"accepted", "rejected_handshake", "rejected_max_connections"} {
m.SSHConnectionsTotal.WithLabelValues(outcome)
}
for _, reason := range []string{"static_credential", "remembered_credential", "threshold_reached", "rejected"} {
m.AuthAttemptsTotal.WithLabelValues("accepted", reason)
m.AuthAttemptsTotal.WithLabelValues("rejected", reason)
}
for _, sh := range []string{"bash", "fridge", "banking", "adventure", "cisco"} {
m.SessionsTotal.WithLabelValues(sh)
m.CommandsExecuted.WithLabelValues(sh)
}
return m
}
// RegisterStoreCollector registers a collector that queries storage stats on each scrape.
func (m *Metrics) RegisterStoreCollector(store storage.Store) {
m.registry.MustRegister(&storeCollector{store: store})
}
// Handler returns an http.Handler that serves Prometheus metrics.
func (m *Metrics) Handler() http.Handler {
return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{})
}
// storeCollector implements prometheus.Collector, querying storage on each scrape.
type storeCollector struct {
store storage.Store
}
var (
storageLoginAttemptsDesc = prometheus.NewDesc(
"oubliette_storage_login_attempts_total",
"Total login attempts in storage.",
nil, nil,
)
storageUniqueIPsDesc = prometheus.NewDesc(
"oubliette_storage_unique_ips",
"Unique IPs in storage.",
nil, nil,
)
storageSessionsDesc = prometheus.NewDesc(
"oubliette_storage_sessions_total",
"Total sessions in storage.",
nil, nil,
)
)
func (c *storeCollector) Describe(ch chan<- *prometheus.Desc) {
ch <- storageLoginAttemptsDesc
ch <- storageUniqueIPsDesc
ch <- storageSessionsDesc
}
func (c *storeCollector) Collect(ch chan<- prometheus.Metric) {
stats, err := c.store.GetDashboardStats(context.Background())
if err != nil {
return
}
ch <- prometheus.MustNewConstMetric(storageLoginAttemptsDesc, prometheus.GaugeValue, float64(stats.TotalAttempts))
ch <- prometheus.MustNewConstMetric(storageUniqueIPsDesc, prometheus.GaugeValue, float64(stats.UniqueIPs))
ch <- prometheus.MustNewConstMetric(storageSessionsDesc, prometheus.GaugeValue, float64(stats.TotalSessions))
}

View File

@@ -0,0 +1,142 @@
package metrics
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func TestNew(t *testing.T) {
m := New("1.2.3")
// Gather all metrics and check expected names exist.
families, err := m.registry.Gather()
if err != nil {
t.Fatalf("gather: %v", err)
}
want := map[string]bool{
"oubliette_ssh_connections_total": false,
"oubliette_ssh_connections_active": false,
"oubliette_auth_attempts_total": false,
"oubliette_commands_executed_total": false,
"oubliette_human_score": false,
"oubliette_sessions_total": false,
"oubliette_sessions_active": false,
"oubliette_session_duration_seconds": false,
"oubliette_build_info": false,
}
for _, f := range families {
if _, ok := want[f.GetName()]; ok {
want[f.GetName()] = true
}
}
for name, found := range want {
if !found {
t.Errorf("metric %q not registered", name)
}
}
}
func TestAuthAttemptsByCountry(t *testing.T) {
m := New("1.0.0")
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
m.AuthAttemptsByCountry.WithLabelValues("DE").Inc()
m.AuthAttemptsByCountry.WithLabelValues("US").Inc()
families, err := m.registry.Gather()
if err != nil {
t.Fatalf("gather: %v", err)
}
var found bool
for _, f := range families {
if f.GetName() == "oubliette_auth_attempts_by_country_total" {
found = true
if len(f.GetMetric()) != 2 {
t.Errorf("expected 2 label pairs (US, DE), got %d", len(f.GetMetric()))
}
}
}
if !found {
t.Error("oubliette_auth_attempts_by_country_total not found after incrementing")
}
}
func TestHandler(t *testing.T) {
m := New("1.2.3")
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
m.Handler().ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body, err := io.ReadAll(w.Body)
if err != nil {
t.Fatalf("reading body: %v", err)
}
if !strings.Contains(string(body), `oubliette_build_info{version="1.2.3"} 1`) {
t.Errorf("response should contain build_info metric, got:\n%s", body)
}
}
func TestStoreCollector(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
// Seed some data.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
t.Fatalf("CreateSession: %v", err)
}
m := New("test")
m.RegisterStoreCollector(store)
families, err := m.registry.Gather()
if err != nil {
t.Fatalf("gather: %v", err)
}
wantMetrics := map[string]float64{
"oubliette_storage_login_attempts_total": 2,
"oubliette_storage_unique_ips": 2,
"oubliette_storage_sessions_total": 1,
}
for _, f := range families {
expected, ok := wantMetrics[f.GetName()]
if !ok {
continue
}
if len(f.GetMetric()) == 0 {
t.Errorf("metric %q has no samples", f.GetName())
continue
}
got := f.GetMetric()[0].GetGauge().GetValue()
if got != expected {
t.Errorf("metric %q = %f, want %f", f.GetName(), got, expected)
}
delete(wantMetrics, f.GetName())
}
for name := range wantMetrics {
t.Errorf("metric %q not found in gather output", name)
}
}

View File

@@ -10,7 +10,7 @@ import (
"sync"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
// Event types.

View File

@@ -10,7 +10,7 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/config"
)
func testSession() SessionInfo {

View File

@@ -12,16 +12,22 @@ import (
"os"
"time"
"git.t-juice.club/torjus/oubliette/internal/auth"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/detection"
"git.t-juice.club/torjus/oubliette/internal/notify"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/shell/adventure"
"git.t-juice.club/torjus/oubliette/internal/shell/banking"
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/auth"
"code.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/detection"
"code.t-juice.club/torjus/oubliette/internal/geoip"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/notify"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell/adventure"
"code.t-juice.club/torjus/oubliette/internal/shell/banking"
"code.t-juice.club/torjus/oubliette/internal/shell/bash"
"code.t-juice.club/torjus/oubliette/internal/shell/cisco"
"code.t-juice.club/torjus/oubliette/internal/shell/fridge"
psqlshell "code.t-juice.club/torjus/oubliette/internal/shell/psql"
"code.t-juice.club/torjus/oubliette/internal/shell/roomba"
"code.t-juice.club/torjus/oubliette/internal/shell/tetris"
"code.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
@@ -34,9 +40,11 @@ type Server struct {
connSem chan struct{} // semaphore limiting concurrent connections
shellRegistry *shell.Registry
notifier notify.Sender
metrics *metrics.Metrics
geoip *geoip.Reader
}
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics.Metrics) (*Server, error) {
registry := shell.NewRegistry()
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
return nil, fmt.Errorf("registering bash shell: %w", err)
@@ -50,6 +58,23 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
if err := registry.Register(adventure.NewAdventureShell(), 1); err != nil {
return nil, fmt.Errorf("registering adventure shell: %w", err)
}
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
return nil, fmt.Errorf("registering cisco shell: %w", err)
}
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
return nil, fmt.Errorf("registering psql shell: %w", err)
}
if err := registry.Register(roomba.NewRoombaShell(), 1); err != nil {
return nil, fmt.Errorf("registering roomba shell: %w", err)
}
if err := registry.Register(tetris.NewTetrisShell(), 1); err != nil {
return nil, fmt.Errorf("registering tetris shell: %w", err)
}
geo, err := geoip.New()
if err != nil {
return nil, fmt.Errorf("opening geoip database: %w", err)
}
s := &Server{
cfg: cfg,
@@ -59,6 +84,8 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
shellRegistry: registry,
notifier: notify.NewSender(cfg.Notify.Webhooks, logger),
metrics: m,
geoip: geo,
}
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
@@ -76,6 +103,8 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server,
}
func (s *Server) ListenAndServe(ctx context.Context) error {
defer s.geoip.Close()
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
if err != nil {
return fmt.Errorf("listen: %w", err)
@@ -102,11 +131,16 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
// Enforce max concurrent connections.
select {
case s.connSem <- struct{}{}:
s.metrics.SSHConnectionsActive.Inc()
go func() {
defer func() { <-s.connSem }()
defer func() {
<-s.connSem
s.metrics.SSHConnectionsActive.Dec()
}()
s.handleConn(conn)
}()
default:
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_max_connections").Inc()
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
conn.Close()
}
@@ -118,11 +152,13 @@ func (s *Server) handleConn(conn net.Conn) {
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil {
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_handshake").Inc()
s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
return
}
defer sshConn.Close()
s.metrics.SSHConnectionsTotal.WithLabelValues("accepted").Inc()
s.logger.Info("SSH connection established",
"remote_addr", sshConn.RemoteAddr(),
"user", sshConn.User(),
@@ -161,6 +197,18 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
}
}
// Second priority: username-based route.
if selectedShell == nil {
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
sh, found := s.shellRegistry.Get(shellName)
if found {
selectedShell = sh
} else {
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
}
}
}
// Lowest priority: random selection.
if selectedShell == nil {
var err error
selectedShell, err = s.shellRegistry.Select()
@@ -171,11 +219,17 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
}
ip := extractIP(conn.RemoteAddr())
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name())
country := s.geoip.Lookup(ip)
sessionStart := time.Now()
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name(), country)
if err != nil {
s.logger.Error("failed to create session", "err", err)
} else {
s.metrics.SessionsTotal.WithLabelValues(selectedShell.Name()).Inc()
s.metrics.SessionsActive.Inc()
defer func() {
s.metrics.SessionsActive.Dec()
s.metrics.SessionDuration.Observe(time.Since(sessionStart).Seconds())
if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
s.logger.Error("failed to end session", "err", err)
}
@@ -201,14 +255,24 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
defer s.notifier.CleanupSession(sessionID)
// Handle session requests (pty-req, shell, etc.)
// Handle session requests (pty-req, shell, exec, etc.)
execCh := make(chan string, 1)
go func() {
defer close(execCh)
for req := range requests {
switch req.Type {
case "pty-req", "shell":
if req.WantReply {
req.Reply(true, nil)
}
case "exec":
if req.WantReply {
req.Reply(true, nil)
}
var payload struct{ Command string }
if err := ssh.Unmarshal(req.Payload, &payload); err == nil {
execCh <- payload.Command
}
default:
if req.WantReply {
req.Reply(false, nil)
@@ -217,6 +281,29 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
}
}()
// Check for exec request before proceeding to interactive shell.
select {
case cmd, ok := <-execCh:
if ok && cmd != "" {
s.logger.Info("exec command received",
"remote_addr", conn.RemoteAddr(),
"user", conn.User(),
"session_id", sessionID,
"command", cmd,
)
if err := s.store.SetExecCommand(context.Background(), sessionID, cmd); err != nil {
s.logger.Error("failed to set exec command", "err", err, "session_id", sessionID)
}
s.metrics.ExecCommandsTotal.Inc()
// Send exit-status 0 and close channel.
exitPayload := make([]byte, 4) // uint32(0)
_, _ = channel.SendRequest("exit-status", false, exitPayload)
return
}
case <-time.After(500 * time.Millisecond):
// No exec request within timeout — proceed with interactive shell.
}
// Build session context.
var shellCfg map[string]any
if s.cfg.Shell.Shells != nil {
@@ -234,6 +321,9 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
Banner: s.cfg.Shell.Banner,
FakeUser: s.cfg.Shell.FakeUser,
},
OnCommand: func(sh string) {
s.metrics.CommandsExecuted.WithLabelValues(sh).Inc()
},
}
// Wrap channel in RecordingChannel.
@@ -269,6 +359,7 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
}
if scorer != nil {
finalScore := scorer.Score()
s.metrics.HumanScore.Observe(finalScore)
if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil {
s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID)
}
@@ -318,6 +409,12 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
ip := extractIP(conn.RemoteAddr())
d := s.authenticator.Authenticate(ip, conn.User(), string(password))
if d.Accepted {
s.metrics.AuthAttemptsTotal.WithLabelValues("accepted", d.Reason).Inc()
} else {
s.metrics.AuthAttemptsTotal.WithLabelValues("rejected", d.Reason).Inc()
}
s.logger.Info("auth attempt",
"remote_addr", conn.RemoteAddr(),
"username", conn.User(),
@@ -325,7 +422,11 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.
"reason", d.Reason,
)
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip); err != nil {
country := s.geoip.Lookup(ip)
if country != "" {
s.metrics.AuthAttemptsByCountry.WithLabelValues(country).Inc()
}
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip, country); err != nil {
s.logger.Error("failed to record login attempt", "err", err)
}

View File

@@ -11,8 +11,10 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/auth"
"code.t-juice.club/torjus/oubliette/internal/config"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
@@ -120,7 +122,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
store := storage.NewMemoryStore()
srv, err := New(cfg, store, logger)
srv, err := New(cfg, store, logger, metrics.New("test"))
if err != nil {
t.Fatalf("creating server: %v", err)
}
@@ -251,6 +253,137 @@ func TestIntegrationSSHConnect(t *testing.T) {
}
})
// Test exec command capture.
t.Run("exec_command", func(t *testing.T) {
clientCfg := &ssh.ClientConfig{
User: "root",
Auth: []ssh.AuthMethod{ssh.Password("toor")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
client, err := ssh.Dial("tcp", addr, clientCfg)
if err != nil {
t.Fatalf("SSH dial: %v", err)
}
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Fatalf("new session: %v", err)
}
defer session.Close()
// Run a command via exec (no PTY, no shell).
if err := session.Run("uname -a"); err != nil {
// Run returns an error because the server closes the channel,
// but that's expected.
_ = err
}
// Give the server a moment to store the command.
time.Sleep(200 * time.Millisecond)
// Verify the exec command was captured.
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
var foundExec bool
for _, s := range sessions {
if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
foundExec = true
break
}
}
if !foundExec {
t.Error("expected a session with exec_command='uname -a'")
}
})
// Test username route: add username_routes so that "postgres" gets psql shell.
t.Run("username_route", func(t *testing.T) {
// Reconfigure with username routes.
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
// Need to get the "postgres" user in via static creds or threshold.
// Use static creds for simplicity.
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
config.Credential{Username: "postgres", Password: "postgres"},
)
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
defer func() {
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
}()
clientCfg := &ssh.ClientConfig{
User: "postgres",
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
client, err := ssh.Dial("tcp", addr, clientCfg)
if err != nil {
t.Fatalf("SSH dial: %v", err)
}
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Fatalf("new session: %v", err)
}
defer session.Close()
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
t.Fatalf("request pty: %v", err)
}
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("stdin pipe: %v", err)
}
var output bytes.Buffer
session.Stdout = &output
if err := session.Shell(); err != nil {
t.Fatalf("shell: %v", err)
}
// Wait for the psql banner.
time.Sleep(500 * time.Millisecond)
// Send \q to quit.
stdin.Write([]byte(`\q` + "\r"))
time.Sleep(200 * time.Millisecond)
session.Wait()
out := output.String()
if !strings.Contains(out, "psql") {
t.Errorf("output should contain psql banner, got: %s", out)
}
// Verify session was created with shell name "psql".
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
var foundPsql bool
for _, s := range sessions {
if s.ShellName == "psql" && s.Username == "postgres" {
foundPsql = true
break
}
}
if !foundPsql {
t.Error("expected a session with shell_name='psql' for user 'postgres'")
}
})
// Test threshold acceptance: after enough failed dials, a subsequent
// dial with the same credentials should succeed via threshold or
// remembered credential.

View File

@@ -8,7 +8,7 @@ import (
"strings"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute
@@ -75,6 +75,9 @@ func (a *AdventureShell) Handle(ctx context.Context, sess *shell.SessionContext,
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("adventure")
}
if result.exit {
return nil

View File

@@ -8,8 +8,8 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {
@@ -22,7 +22,7 @@ func (r *rwCloser) Close() error { return nil }
func runShell(t *testing.T, commands string) string {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
sess := &shell.SessionContext{
SessionID: sessID,
@@ -287,7 +287,7 @@ func TestEthernetCable(t *testing.T) {
func TestSessionLogs(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "adventure", "")
sess := &shell.SessionContext{
SessionID: sessID,

View File

@@ -9,7 +9,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute

View File

@@ -9,15 +9,15 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// newTestModel creates a model with a test session context.
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "banker", "banking")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "banker", "banking", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "banker",

View File

@@ -8,7 +8,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
type screen int
@@ -42,10 +42,8 @@ type model struct {
transfer transferModel
history historyModel
messages messagesModel
admin adminModel
pinInput string
pinStage int // 0=old, 1=new, 2=confirm, 3=done
pinMessage string
admin adminModel
changePin changePinModel
}
func newModel(sess *shell.SessionContext, bankName, terminalID, region string) *model {
@@ -130,7 +128,7 @@ func (m *model) View() string {
case screenMessages:
content = m.messages.View()
case screenChangePin:
content = m.viewChangePin()
content = m.changePin.View()
case screenAdmin:
content = m.admin.View()
}
@@ -182,9 +180,7 @@ func (m *model) updateMenu(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 5", "SECURE MESSAGES"))
case "6":
m.screen = screenChangePin
m.pinInput = ""
m.pinStage = 0
m.pinMessage = ""
m.changePin = newChangePinModel()
return m, tea.Batch(tea.ClearScreen, logAction(m.sess, "MENU 6", "CHANGE PIN"))
case "7":
m.quitting = true
@@ -278,95 +274,19 @@ func (m *model) updateMessages(msg tea.Msg) (tea.Model, tea.Cmd) {
}
func (m *model) updateChangePin(msg tea.Msg) (tea.Model, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
prevStage := m.changePin.stage
var cmd tea.Cmd
m.changePin, cmd = m.changePin.Update(msg)
// Log successful PIN change.
if m.changePin.stage == 3 && prevStage != 3 {
cmd = tea.Batch(cmd, logAction(m.sess, "CHANGE PIN", "PIN CHANGED SUCCESSFULLY"))
}
if m.pinStage == 3 {
return m, m.goToMenu()
if m.changePin.done {
return m, tea.Batch(cmd, m.goToMenu())
}
switch keyMsg.Type {
case tea.KeyEnter:
switch m.pinStage {
case 0:
if m.pinInput != "" {
m.pinStage = 1
m.pinInput = ""
}
case 1:
if len(m.pinInput) >= 4 {
m.pinMessage = m.pinInput
m.pinStage = 2
m.pinInput = ""
}
case 2:
if m.pinInput == m.pinMessage {
m.pinStage = 3
return m, logAction(m.sess, "CHANGE PIN", "PIN CHANGED SUCCESSFULLY")
}
m.pinInput = ""
m.pinMessage = ""
m.pinStage = 1
}
case tea.KeyEscape:
return m, m.goToMenu()
case tea.KeyBackspace:
if len(m.pinInput) > 0 {
m.pinInput = m.pinInput[:len(m.pinInput)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.pinInput) < 12 {
m.pinInput += ch
}
}
return m, nil
}
func (m *model) viewChangePin() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("CHANGE PIN"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
if m.pinStage == 3 {
b.WriteString(titleStyle.Render(" PIN CHANGED SUCCESSFULLY"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" YOUR NEW PIN IS NOW ACTIVE."))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" PLEASE USE YOUR NEW PIN FOR ALL FUTURE TRANSACTIONS."))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
} else {
prompts := []string{" CURRENT PIN: ", " NEW PIN: ", " CONFIRM PIN: "}
for i := 0; i < m.pinStage; i++ {
b.WriteString(baseStyle.Render(prompts[i]))
b.WriteString(baseStyle.Render(strings.Repeat("*", 4)))
b.WriteString("\n")
}
if m.pinStage < 3 {
b.WriteString(titleStyle.Render(prompts[m.pinStage]))
masked := strings.Repeat("*", len(m.pinInput))
b.WriteString(inputStyle.Render(masked))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
}
b.WriteString("\n")
if m.pinStage == 1 {
b.WriteString(dimStyle.Render(" PIN MUST BE AT LEAST 4 CHARACTERS"))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
}
b.WriteString("\n")
return b.String()
return m, cmd
}
func (m *model) updateAdmin(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -425,6 +345,9 @@ func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
defer cancel()
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
}
if sess.OnCommand != nil {
sess.OnCommand("banking")
}
return nil
}
}

View File

@@ -0,0 +1,111 @@
package banking
import (
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type changePinModel struct {
input string
stage int // 0=old, 1=new, 2=confirm, 3=done
newPin string
done bool
}
func newChangePinModel() changePinModel {
return changePinModel{}
}
func (m changePinModel) Update(msg tea.Msg) (changePinModel, tea.Cmd) {
keyMsg, ok := msg.(tea.KeyMsg)
if !ok {
return m, nil
}
if m.stage == 3 {
m.done = true
return m, nil
}
switch keyMsg.Type {
case tea.KeyEnter:
switch m.stage {
case 0:
if m.input != "" {
m.stage = 1
m.input = ""
}
case 1:
if len(m.input) >= 4 {
m.newPin = m.input
m.stage = 2
m.input = ""
}
case 2:
if m.input == m.newPin {
m.stage = 3
} else {
m.input = ""
m.newPin = ""
m.stage = 1
}
}
case tea.KeyEscape:
m.done = true
case tea.KeyBackspace:
if len(m.input) > 0 {
m.input = m.input[:len(m.input)-1]
}
default:
ch := keyMsg.String()
if len(ch) == 1 && ch[0] >= 32 && ch[0] < 127 && len(m.input) < 12 {
m.input += ch
}
}
return m, nil
}
func (m changePinModel) View() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(centerText("CHANGE PIN"))
b.WriteString("\n\n")
b.WriteString(thinDivider())
b.WriteString("\n\n")
if m.stage == 3 {
b.WriteString(titleStyle.Render(" PIN CHANGED SUCCESSFULLY"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" YOUR NEW PIN IS NOW ACTIVE."))
b.WriteString("\n")
b.WriteString(baseStyle.Render(" PLEASE USE YOUR NEW PIN FOR ALL FUTURE TRANSACTIONS."))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" PRESS ANY KEY TO RETURN TO MAIN MENU"))
} else {
prompts := []string{" CURRENT PIN: ", " NEW PIN: ", " CONFIRM PIN: "}
for i := 0; i < m.stage; i++ {
b.WriteString(baseStyle.Render(prompts[i]))
b.WriteString(baseStyle.Render(strings.Repeat("*", 4)))
b.WriteString("\n")
}
if m.stage < 3 {
b.WriteString(titleStyle.Render(prompts[m.stage]))
masked := strings.Repeat("*", len(m.input))
b.WriteString(inputStyle.Render(masked))
b.WriteString(inputStyle.Render("_"))
b.WriteString("\n")
}
b.WriteString("\n")
if m.stage == 1 {
b.WriteString(dimStyle.Render(" PIN MUST BE AT LEAST 4 CHARACTERS"))
b.WriteString("\n")
}
b.WriteString("\n")
b.WriteString(dimStyle.Render(" PRESS ESC TO RETURN TO MAIN MENU"))
}
b.WriteString("\n")
return b.String()
}

View File

@@ -8,7 +8,7 @@ import (
"strings"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
@@ -86,6 +86,9 @@ func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw i
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("bash")
}
if result.exit {
return nil

View File

@@ -9,8 +9,8 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {
@@ -116,7 +116,7 @@ func TestReadLineCtrlD(t *testing.T) {
func TestBashShellHandle(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash", "")
sess := &shell.SessionContext{
SessionID: sessID,
@@ -166,7 +166,7 @@ func TestBashShellHandle(t *testing.T) {
func TestBashShellFakeUser(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash", "")
sess := &shell.SessionContext{
SessionID: sessID,

View File

@@ -0,0 +1,206 @@
package cisco
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// CiscoShell emulates a Cisco IOS CLI.
type CiscoShell struct{}
// NewCiscoShell returns a new CiscoShell instance.
func NewCiscoShell() *CiscoShell {
return &CiscoShell{}
}
func (c *CiscoShell) Name() string { return "cisco" }
func (c *CiscoShell) Description() string { return "Cisco IOS CLI emulator" }
func (c *CiscoShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
hostname := configString(sess.ShellConfig, "hostname", "Router")
model := configString(sess.ShellConfig, "model", "C2960")
iosVersion := configString(sess.ShellConfig, "ios_version", "15.0(2)SE11")
enablePass := configString(sess.ShellConfig, "enable_password", "")
state := newIOSState(hostname, model, iosVersion, enablePass)
// IOS just shows a blank line then the prompt after SSH auth.
fmt.Fprint(rw, "\r\n")
for {
prompt := state.prompt()
if _, err := fmt.Fprint(rw, prompt); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
// Check for Ctrl+Z (^Z) — return to privileged exec.
if trimmed == "\x1a" || trimmed == "^Z" {
if state.mode == modeGlobalConfig || state.mode == modeInterfaceConfig {
state.mode = modePrivilegedExec
state.currentIf = ""
}
continue
}
// Handle "enable" specially — it needs password prompting.
if state.mode == modeUserExec && isEnableCommand(trimmed) {
output := handleEnable(ctx, state, rw)
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("cisco")
}
continue
}
result := state.dispatch(trimmed)
var output string
if result.output != "" {
output = result.output
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("cisco")
}
if result.exit {
return nil
}
}
}
// isEnableCommand checks if input resolves to "enable" in user exec mode.
func isEnableCommand(input string) bool {
words := strings.Fields(input)
if len(words) != 1 {
return false
}
w := strings.ToLower(words[0])
enable := "enable"
return len(w) >= 2 && len(w) <= len(enable) && enable[:len(w)] == w
}
// handleEnable manages the enable password prompt flow.
// Returns the output string (for logging).
func handleEnable(ctx context.Context, state *iosState, rw io.ReadWriter) string {
const maxAttempts = 3
hadFailure := false
for range maxAttempts {
fmt.Fprint(rw, "Password: ")
password, err := readPassword(ctx, rw)
if err != nil {
return ""
}
fmt.Fprint(rw, "\r\n")
if state.enablePass == "" {
// No password configured — accept after one failed attempt.
if hadFailure {
state.mode = modePrivilegedExec
return ""
}
hadFailure = true
} else if password == state.enablePass {
state.mode = modePrivilegedExec
return ""
}
}
output := "% Bad passwords"
fmt.Fprintf(rw, "%s\r\n", output)
return output
}
// readPassword reads a password without echoing characters.
func readPassword(ctx context.Context, rw io.ReadWriter) (string, error) {
var buf []byte
b := make([]byte, 1)
for {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
}
n, err := rw.Read(b)
if err != nil {
return "", err
}
if n == 0 {
continue
}
ch := b[0]
switch {
case ch == '\r' || ch == '\n':
return string(buf), nil
case ch == 4: // Ctrl+D
return string(buf), io.EOF
case ch == 3: // Ctrl+C
return "", io.EOF
case ch == 127 || ch == 8: // Backspace/DEL
if len(buf) > 0 {
buf = buf[:len(buf)-1]
}
case ch == 27: // ESC sequence
next := make([]byte, 1)
if n, _ := rw.Read(next); n > 0 && next[0] == '[' {
rw.Read(next)
}
case ch >= 32 && ch < 127:
buf = append(buf, ch)
// Don't echo.
}
}
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,531 @@
package cisco
import (
"testing"
)
// --- Abbreviation resolution tests ---
func TestResolveAbbreviationExact(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "shutdown"},
}
got, err := resolveAbbreviation("show", entries)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "show" {
t.Errorf("got %q, want %q", got, "show")
}
}
func TestResolveAbbreviationUnique(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "enable"},
{name: "exit"},
}
got, err := resolveAbbreviation("sh", entries)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "show" {
t.Errorf("got %q, want %q", got, "show")
}
}
func TestResolveAbbreviationAmbiguous(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "shutdown"},
}
_, err := resolveAbbreviation("sh", entries)
if err == nil {
t.Fatal("expected ambiguous error, got nil")
}
if err.Error() != "ambiguous" {
t.Errorf("got error %q, want %q", err.Error(), "ambiguous")
}
}
func TestResolveAbbreviationUnknown(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "enable"},
}
_, err := resolveAbbreviation("xyz", entries)
if err == nil {
t.Fatal("expected unknown error, got nil")
}
if err.Error() != "unknown" {
t.Errorf("got error %q, want %q", err.Error(), "unknown")
}
}
func TestResolveAbbreviationCaseInsensitive(t *testing.T) {
entries := []commandEntry{
{name: "show"},
{name: "enable"},
}
got, err := resolveAbbreviation("SH", entries)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "show" {
t.Errorf("got %q, want %q", got, "show")
}
}
// --- Multi-word command resolution tests ---
func TestResolveCommandShowRunningConfig(t *testing.T) {
resolved, args, err := resolveCommand([]string{"sh", "run"}, privilegedExecCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(args) != 0 {
t.Errorf("unexpected args: %v", args)
}
want := []string{"show", "running-config"}
if len(resolved) != len(want) {
t.Fatalf("resolved = %v, want %v", resolved, want)
}
for i := range want {
if resolved[i] != want[i] {
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
}
}
}
func TestResolveCommandConfigureTerminal(t *testing.T) {
resolved, _, err := resolveCommand([]string{"conf", "t"}, privilegedExecCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := []string{"configure", "terminal"}
if len(resolved) != len(want) {
t.Fatalf("resolved = %v, want %v", resolved, want)
}
for i := range want {
if resolved[i] != want[i] {
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
}
}
}
func TestResolveCommandShowIPInterfaceBrief(t *testing.T) {
resolved, _, err := resolveCommand([]string{"sh", "ip", "int", "br"}, privilegedExecCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
want := []string{"show", "ip", "interface", "brief"}
if len(resolved) != len(want) {
t.Fatalf("resolved = %v, want %v", resolved, want)
}
for i := range want {
if resolved[i] != want[i] {
t.Errorf("resolved[%d] = %q, want %q", i, resolved[i], want[i])
}
}
}
func TestResolveCommandWithArgs(t *testing.T) {
// "hostname MyRouter" → resolved=["hostname"], args=["MyRouter"]
resolved, args, err := resolveCommand([]string{"hostname", "MyRouter"}, globalConfigCommands)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(resolved) != 1 || resolved[0] != "hostname" {
t.Errorf("resolved = %v, want [hostname]", resolved)
}
if len(args) != 1 || args[0] != "MyRouter" {
t.Errorf("args = %v, want [MyRouter]", args)
}
}
func TestResolveCommandAmbiguous(t *testing.T) {
// In user exec, "e" matches "enable" and "exit" — ambiguous
_, _, err := resolveCommand([]string{"e"}, userExecCommands)
if err == nil {
t.Fatal("expected ambiguous error")
}
}
// --- Mode state machine tests ---
func TestPromptGeneration(t *testing.T) {
tests := []struct {
mode iosMode
want string
}{
{modeUserExec, "Router>"},
{modePrivilegedExec, "Router#"},
{modeGlobalConfig, "Router(config)#"},
{modeInterfaceConfig, "Router(config-if)#"},
}
for _, tt := range tests {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = tt.mode
if got := s.prompt(); got != tt.want {
t.Errorf("prompt(%d) = %q, want %q", tt.mode, got, tt.want)
}
}
}
func TestPromptAfterHostnameChange(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeGlobalConfig
s.dispatch("hostname Switch1")
if s.hostname != "Switch1" {
t.Fatalf("hostname = %q, want %q", s.hostname, "Switch1")
}
if got := s.prompt(); got != "Switch1(config)#" {
t.Errorf("prompt = %q, want %q", got, "Switch1(config)#")
}
}
func TestModeTransitions(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
// Start in user exec.
if s.mode != modeUserExec {
t.Fatalf("initial mode = %d, want %d", s.mode, modeUserExec)
}
// Can't skip to config mode directly from user exec.
result := s.dispatch("configure terminal")
if result.output == "" {
t.Error("expected error for conf t in user exec mode")
}
// Manually set privileged mode (enable tested separately).
s.mode = modePrivilegedExec
// conf t → global config
s.dispatch("configure terminal")
if s.mode != modeGlobalConfig {
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
}
// interface Gi0/0 → interface config
s.dispatch("interface GigabitEthernet0/0")
if s.mode != modeInterfaceConfig {
t.Errorf("mode after interface = %d, want %d", s.mode, modeInterfaceConfig)
}
// exit → back to global config
s.dispatch("exit")
if s.mode != modeGlobalConfig {
t.Errorf("mode after exit from if-config = %d, want %d", s.mode, modeGlobalConfig)
}
// end → back to privileged exec
s.dispatch("end")
if s.mode != modePrivilegedExec {
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
}
// disable → back to user exec
s.dispatch("disable")
if s.mode != modeUserExec {
t.Errorf("mode after disable = %d, want %d", s.mode, modeUserExec)
}
}
func TestEndFromInterfaceConfig(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeInterfaceConfig
s.currentIf = "GigabitEthernet0/0"
s.dispatch("end")
if s.mode != modePrivilegedExec {
t.Errorf("mode after end = %d, want %d", s.mode, modePrivilegedExec)
}
if s.currentIf != "" {
t.Errorf("currentIf = %q, want empty", s.currentIf)
}
}
func TestExitFromPrivilegedExec(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modePrivilegedExec
result := s.dispatch("exit")
if !result.exit {
t.Error("expected exit=true from privileged exec exit")
}
}
// --- Show command output tests ---
func TestShowVersionContainsModel(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showVersion(s)
if !contains(output, "C2960") {
t.Error("show version missing model")
}
if !contains(output, "15.0(2)SE11") {
t.Error("show version missing IOS version")
}
if !contains(output, "Router") {
t.Error("show version missing hostname")
}
}
func TestShowRunningConfigContainsInterfaces(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showRunningConfig(s)
if !contains(output, "hostname Router") {
t.Error("running-config missing hostname")
}
if !contains(output, "interface GigabitEthernet0/0") {
t.Error("running-config missing interface")
}
if !contains(output, "ip address 192.168.1.1") {
t.Error("running-config missing IP address")
}
if !contains(output, "line vty") {
t.Error("running-config missing VTY config")
}
}
func TestShowRunningConfigWithEnableSecret(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "secret123")
output := showRunningConfig(s)
if !contains(output, "enable secret") {
t.Error("running-config missing enable secret when password is set")
}
}
func TestShowRunningConfigWithoutEnableSecret(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showRunningConfig(s)
if contains(output, "enable secret") {
t.Error("running-config should not have enable secret when password is empty")
}
}
func TestShowIPInterfaceBrief(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showIPInterfaceBrief(s)
if !contains(output, "GigabitEthernet0/0") {
t.Error("ip interface brief missing GigabitEthernet0/0")
}
if !contains(output, "192.168.1.1") {
t.Error("ip interface brief missing 192.168.1.1")
}
}
func TestShowIPRoute(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
output := showIPRoute(s)
if !contains(output, "directly connected") {
t.Error("ip route missing connected routes")
}
if !contains(output, "0.0.0.0/0") {
t.Error("ip route missing default route")
}
}
func TestShowVLANBrief(t *testing.T) {
output := showVLANBrief()
if !contains(output, "default") {
t.Error("vlan brief missing default vlan")
}
if !contains(output, "MGMT") {
t.Error("vlan brief missing MGMT vlan")
}
}
// --- Interface config tests ---
func TestInterfaceShutdownNoShutdown(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeInterfaceConfig
s.currentIf = "GigabitEthernet0/0"
s.dispatch("shutdown")
iface := s.findInterface("GigabitEthernet0/0")
if iface == nil {
t.Fatal("interface not found")
}
if !iface.shutdown {
t.Error("interface should be shutdown")
}
if iface.status != "administratively down" {
t.Errorf("status = %q, want %q", iface.status, "administratively down")
}
s.dispatch("no shutdown")
if iface.shutdown {
t.Error("interface should not be shutdown after no shutdown")
}
if iface.status != "up" {
t.Errorf("status = %q, want %q", iface.status, "up")
}
}
func TestInterfaceIPAddress(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modeInterfaceConfig
s.currentIf = "GigabitEthernet0/0"
s.dispatch("ip address 10.10.10.1 255.255.255.0")
iface := s.findInterface("GigabitEthernet0/0")
if iface == nil {
t.Fatal("interface not found")
}
if iface.ip != "10.10.10.1" {
t.Errorf("ip = %q, want %q", iface.ip, "10.10.10.1")
}
if iface.mask != "255.255.255.0" {
t.Errorf("mask = %q, want %q", iface.mask, "255.255.255.0")
}
}
// --- Dispatch / invalid command tests ---
func TestInvalidCommandInUserExec(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
result := s.dispatch("foobar")
if !contains(result.output, "Invalid input") {
t.Errorf("expected invalid input error, got %q", result.output)
}
}
func TestAmbiguousCommandOutput(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
// "e" in user exec is ambiguous (enable, exit)
result := s.dispatch("e")
if !contains(result.output, "Ambiguous") {
t.Errorf("expected ambiguous error, got %q", result.output)
}
}
func TestHelpCommand(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
result := s.dispatch("?")
if !contains(result.output, "show") {
t.Error("help missing 'show'")
}
if !contains(result.output, "enable") {
t.Error("help missing 'enable'")
}
}
// --- Abbreviation integration tests ---
func TestShowAbbreviationInDispatch(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modePrivilegedExec
result := s.dispatch("sh ver")
if !contains(result.output, "Cisco IOS Software") {
t.Error("'sh ver' should produce version output")
}
}
func TestConfTAbbreviation(t *testing.T) {
s := newIOSState("Router", "C2960", "15.0(2)SE11", "")
s.mode = modePrivilegedExec
s.dispatch("conf t")
if s.mode != modeGlobalConfig {
t.Errorf("mode after conf t = %d, want %d", s.mode, modeGlobalConfig)
}
}
// --- Enable command detection ---
func TestIsEnableCommand(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"enable", true},
{"en", true},
{"ena", true},
{"e", false}, // too short (single char could be other commands)
{"enab", true},
{"ENABLE", true},
{"exit", false},
{"enable 15", false}, // has extra argument
}
for _, tt := range tests {
if got := isEnableCommand(tt.input); got != tt.want {
t.Errorf("isEnableCommand(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
// --- configString tests ---
func TestConfigString(t *testing.T) {
cfg := map[string]any{"hostname": "MySwitch"}
if got := configString(cfg, "hostname", "Router"); got != "MySwitch" {
t.Errorf("configString() = %q, want %q", got, "MySwitch")
}
if got := configString(cfg, "missing", "Default"); got != "Default" {
t.Errorf("configString() for missing = %q, want %q", got, "Default")
}
if got := configString(nil, "key", "Default"); got != "Default" {
t.Errorf("configString(nil) = %q, want %q", got, "Default")
}
}
// --- Helper ---
func TestMaskBits(t *testing.T) {
tests := []struct {
mask string
want int
}{
{"255.255.255.0", 24},
{"255.255.255.252", 30},
{"255.255.0.0", 16},
{"255.0.0.0", 8},
}
for _, tt := range tests {
if got := maskBits(tt.mask); got != tt.want {
t.Errorf("maskBits(%q) = %d, want %d", tt.mask, got, tt.want)
}
}
}
func TestNetworkFromIP(t *testing.T) {
tests := []struct {
ip, mask, want string
}{
{"192.168.1.1", "255.255.255.0", "192.168.1.0"},
{"10.0.0.1", "255.255.255.252", "10.0.0.0"},
{"172.16.5.100", "255.255.0.0", "172.16.0.0"},
}
for _, tt := range tests {
if got := networkFromIP(tt.ip, tt.mask); got != tt.want {
t.Errorf("networkFromIP(%q, %q) = %q, want %q", tt.ip, tt.mask, got, tt.want)
}
}
}
// --- Shell metadata ---
func TestShellNameAndDescription(t *testing.T) {
s := NewCiscoShell()
if s.Name() != "cisco" {
t.Errorf("Name() = %q, want %q", s.Name(), "cisco")
}
if s.Description() == "" {
t.Error("Description() should not be empty")
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && containsHelper(s, substr)
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -0,0 +1,414 @@
package cisco
import (
"fmt"
"strings"
)
// commandResult holds the output of a command and whether the session should end.
type commandResult struct {
output string
exit bool
}
// commandEntry defines a single command with its name and optional sub-commands.
type commandEntry struct {
name string
subs []commandEntry // nil for leaf commands
}
// userExecCommands defines the command tree for user EXEC mode.
var userExecCommands = []commandEntry{
{name: "show", subs: []commandEntry{
{name: "version"},
{name: "clock"},
{name: "ip", subs: []commandEntry{
{name: "route"},
{name: "interface", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "interfaces"},
{name: "vlan", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "enable"},
{name: "exit"},
{name: "?"},
}
// privilegedExecCommands extends user commands for privileged mode.
var privilegedExecCommands = []commandEntry{
{name: "show", subs: []commandEntry{
{name: "version"},
{name: "clock"},
{name: "ip", subs: []commandEntry{
{name: "route"},
{name: "interface", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "interfaces"},
{name: "running-config"},
{name: "startup-config"},
{name: "vlan", subs: []commandEntry{
{name: "brief"},
}},
}},
{name: "configure", subs: []commandEntry{
{name: "terminal"},
}},
{name: "write", subs: []commandEntry{
{name: "memory"},
}},
{name: "copy"},
{name: "reload"},
{name: "disable"},
{name: "terminal", subs: []commandEntry{
{name: "length"},
}},
{name: "exit"},
{name: "?"},
}
// globalConfigCommands defines the command tree for global config mode.
var globalConfigCommands = []commandEntry{
{name: "hostname"},
{name: "interface"},
{name: "ip", subs: []commandEntry{
{name: "route"},
}},
{name: "no"},
{name: "end"},
{name: "exit"},
{name: "?"},
}
// interfaceConfigCommands defines the command tree for interface config mode.
var interfaceConfigCommands = []commandEntry{
{name: "ip", subs: []commandEntry{
{name: "address"},
}},
{name: "description"},
{name: "shutdown"},
{name: "no", subs: []commandEntry{
{name: "shutdown"},
}},
{name: "switchport", subs: []commandEntry{
{name: "mode"},
}},
{name: "end"},
{name: "exit"},
{name: "?"},
}
// commandsForMode returns the command tree for the given IOS mode.
func commandsForMode(mode iosMode) []commandEntry {
switch mode {
case modeUserExec:
return userExecCommands
case modePrivilegedExec:
return privilegedExecCommands
case modeGlobalConfig:
return globalConfigCommands
case modeInterfaceConfig:
return interfaceConfigCommands
default:
return userExecCommands
}
}
// resolveAbbreviation attempts to match an abbreviated word against a list of
// command entries. It returns the matched entry name, or an error string if
// ambiguous or unknown.
func resolveAbbreviation(word string, entries []commandEntry) (string, error) {
word = strings.ToLower(word)
var matches []string
for _, e := range entries {
if strings.ToLower(e.name) == word {
return e.name, nil // exact match
}
if strings.HasPrefix(strings.ToLower(e.name), word) {
matches = append(matches, e.name)
}
}
switch len(matches) {
case 0:
return "", fmt.Errorf("unknown")
case 1:
return matches[0], nil
default:
return "", fmt.Errorf("ambiguous")
}
}
// resolveCommand resolves a sequence of abbreviated words into the canonical
// command path (e.g., ["sh", "run"] → ["show", "running-config"]).
// It returns the resolved path, any remaining arguments, and an error if
// resolution fails.
func resolveCommand(words []string, entries []commandEntry) ([]string, []string, error) {
var resolved []string
current := entries
for i, w := range words {
name, err := resolveAbbreviation(w, current)
if err != nil {
if err.Error() == "unknown" && len(resolved) > 0 {
// Remaining words are arguments to the resolved command.
return resolved, words[i:], nil
}
return resolved, words[i:], err
}
resolved = append(resolved, name)
// Find sub-commands for the matched entry.
var nextLevel []commandEntry
for _, e := range current {
if e.name == name {
nextLevel = e.subs
break
}
}
if nextLevel == nil {
// Leaf command — rest are arguments.
return resolved, words[i+1:], nil
}
current = nextLevel
}
return resolved, nil, nil
}
// dispatch processes a command line in the context of the current IOS state.
func (s *iosState) dispatch(input string) commandResult {
words := strings.Fields(input)
if len(words) == 0 {
return commandResult{}
}
// Handle "?" as a help request.
if words[0] == "?" {
return s.cmdHelp()
}
cmds := commandsForMode(s.mode)
resolved, args, err := resolveCommand(words, cmds)
if err != nil {
if err.Error() == "ambiguous" {
return commandResult{output: fmt.Sprintf("%% Ambiguous command: \"%s\"", input)}
}
return commandResult{output: invalidInput(input)}
}
if len(resolved) == 0 {
return commandResult{output: invalidInput(input)}
}
cmd := strings.Join(resolved, " ")
switch s.mode {
case modeUserExec:
return s.dispatchUserExec(cmd, args)
case modePrivilegedExec:
return s.dispatchPrivilegedExec(cmd, args)
case modeGlobalConfig:
return s.dispatchGlobalConfig(cmd, args)
case modeInterfaceConfig:
return s.dispatchInterfaceConfig(cmd, args)
}
return commandResult{output: invalidInput(input)}
}
func (s *iosState) dispatchUserExec(cmd string, args []string) commandResult {
switch cmd {
case "show version":
return commandResult{output: showVersion(s)}
case "show clock":
return commandResult{output: showClock()}
case "show ip route":
return commandResult{output: showIPRoute(s)}
case "show ip interface brief":
return commandResult{output: showIPInterfaceBrief(s)}
case "show interfaces":
return commandResult{output: showInterfaces(s)}
case "show vlan brief":
return commandResult{output: showVLANBrief()}
case "enable":
return commandResult{} // handled in Handle() loop
case "exit":
return commandResult{exit: true}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) dispatchPrivilegedExec(cmd string, args []string) commandResult {
switch cmd {
case "show version":
return commandResult{output: showVersion(s)}
case "show clock":
return commandResult{output: showClock()}
case "show ip route":
return commandResult{output: showIPRoute(s)}
case "show ip interface brief":
return commandResult{output: showIPInterfaceBrief(s)}
case "show interfaces":
return commandResult{output: showInterfaces(s)}
case "show running-config":
return commandResult{output: showRunningConfig(s)}
case "show startup-config":
return commandResult{output: showRunningConfig(s)} // same as running
case "show vlan brief":
return commandResult{output: showVLANBrief()}
case "configure terminal":
s.mode = modeGlobalConfig
return commandResult{output: "Enter configuration commands, one per line. End with CNTL/Z."}
case "write memory":
return commandResult{output: "[OK]"}
case "copy":
return commandResult{output: "[OK]"}
case "reload":
return commandResult{output: "System configuration has been modified. Save? [yes/no]: ", exit: true}
case "disable":
s.mode = modeUserExec
return commandResult{}
case "terminal length":
return commandResult{} // accept silently
case "exit":
return commandResult{exit: true}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) dispatchGlobalConfig(cmd string, args []string) commandResult {
switch cmd {
case "hostname":
if len(args) < 1 {
return commandResult{output: "% Incomplete command."}
}
s.hostname = args[0]
return commandResult{}
case "interface":
if len(args) < 1 {
return commandResult{output: "% Incomplete command."}
}
ifName := strings.Join(args, "")
s.currentIf = ifName
s.mode = modeInterfaceConfig
return commandResult{}
case "ip route":
return commandResult{} // accept silently
case "no":
return commandResult{} // accept silently
case "end":
s.mode = modePrivilegedExec
return commandResult{}
case "exit":
s.mode = modePrivilegedExec
return commandResult{}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) dispatchInterfaceConfig(cmd string, args []string) commandResult {
switch cmd {
case "ip address":
if len(args) < 2 {
return commandResult{output: "% Incomplete command."}
}
if iface := s.findInterface(s.currentIf); iface != nil {
iface.ip = args[0]
iface.mask = args[1]
}
return commandResult{}
case "description":
if len(args) < 1 {
return commandResult{output: "% Incomplete command."}
}
if iface := s.findInterface(s.currentIf); iface != nil {
iface.desc = strings.Join(args, " ")
}
return commandResult{}
case "shutdown":
if iface := s.findInterface(s.currentIf); iface != nil {
iface.shutdown = true
iface.status = "administratively down"
iface.protocol = "down"
}
return commandResult{}
case "no shutdown":
if iface := s.findInterface(s.currentIf); iface != nil {
iface.shutdown = false
iface.status = "up"
iface.protocol = "up"
}
return commandResult{}
case "switchport mode":
return commandResult{} // accept silently
case "end":
s.mode = modePrivilegedExec
s.currentIf = ""
return commandResult{}
case "exit":
s.mode = modeGlobalConfig
s.currentIf = ""
return commandResult{}
}
return commandResult{output: invalidInput(cmd)}
}
func (s *iosState) cmdHelp() commandResult {
cmds := commandsForMode(s.mode)
var b strings.Builder
for _, e := range cmds {
if e.name == "?" {
continue
}
b.WriteString(fmt.Sprintf(" %-20s %s\n", e.name, helpText(e.name)))
}
return commandResult{output: b.String()}
}
func helpText(name string) string {
switch name {
case "show":
return "Show running system information"
case "enable":
return "Turn on privileged commands"
case "disable":
return "Turn off privileged commands"
case "exit":
return "Exit from the EXEC"
case "configure":
return "Enter configuration mode"
case "write":
return "Write running configuration to memory"
case "copy":
return "Copy from one file to another"
case "reload":
return "Halt and perform a cold restart"
case "terminal":
return "Set terminal line parameters"
case "hostname":
return "Set system's network name"
case "interface":
return "Select an interface to configure"
case "ip":
return "Global IP configuration subcommands"
case "no":
return "Negate a command or set its defaults"
case "end":
return "Exit from configure mode"
case "description":
return "Interface specific description"
case "shutdown":
return "Shutdown the selected interface"
case "switchport":
return "Set switching mode characteristics"
default:
return ""
}
}
func invalidInput(input string) string {
return fmt.Sprintf("%% Invalid input detected at '^' marker.\n\n%s\n^", input)
}

View File

@@ -0,0 +1,234 @@
package cisco
import (
"fmt"
"math/rand"
"strings"
"time"
)
func showVersion(s *iosState) string {
days := 14 + rand.Intn(350)
hours := rand.Intn(24)
mins := rand.Intn(60)
return fmt.Sprintf(`Cisco IOS Software, %s Software (%s-UNIVERSALK9-M), Version %s, RELEASE SOFTWARE (fc3)
Technical Support: http://www.cisco.com/techsupport
Copyright (c) 1986-2019 by Cisco Systems, Inc.
Compiled Thu 30-Jan-19 10:08 by prod_rel_team
ROM: Bootstrap program is %s boot loader
BOOTLDR: %s Boot Loader (C2960-HBOOT-M) Version 15.0(2r)SE, RELEASE SOFTWARE (fc1)
%s uptime is %d days, %d hours, %d minutes
System returned to ROM by power-on
System image file is "flash:/%s-universalk9-mz.SPA.%s.bin"
This product contains cryptographic features and is subject to United States
and local country laws governing import, export, transfer and use.
cisco %s (%s) processor (revision K0) with 524288K bytes of memory.
Processor board ID %s
Last reset from power-on
2 Gigabit Ethernet interfaces
1 Virtual Ethernet interface
64K bytes of flash-simulated non-volatile configuration memory.
Total of 65536K bytes of APC System Flash (Read/Write)
Configuration register is 0x2102`,
s.model, s.model, s.iosVersion,
s.model, s.model,
s.hostname, days, hours, mins,
s.model, s.iosVersion,
s.model, processorForModel(s.model),
s.serial,
)
}
func processorForModel(model string) string {
if strings.HasPrefix(model, "C29") {
return "PowerPC405"
}
return "MIPS"
}
func showClock() string {
now := time.Now().UTC()
return fmt.Sprintf("*%s UTC", now.Format("15:04:05.000 Mon Jan 2 2006"))
}
func showIPRoute(s *iosState) string {
var b strings.Builder
b.WriteString("Codes: C - connected, S - static, R - RIP, M - mobile, B - BGP\n")
b.WriteString(" D - EIGRP, EX - EIGRP external, O - OSPF, IA - OSPF inter area\n")
b.WriteString(" N1 - OSPF NSSA external type 1, N2 - OSPF NSSA external type 2\n")
b.WriteString(" E1 - OSPF external type 1, E2 - OSPF external type 2\n")
b.WriteString(" i - IS-IS, su - IS-IS summary, L1 - IS-IS level-1, L2 - IS-IS level-2\n")
b.WriteString(" ia - IS-IS inter area, * - candidate default, U - per-user static route\n")
b.WriteString(" o - ODR, P - periodic downloaded static route\n\n")
b.WriteString("Gateway of last resort is 10.0.0.2 to network 0.0.0.0\n\n")
for _, iface := range s.interfaces {
if iface.ip == "unassigned" || iface.status != "up" {
continue
}
network := networkFromIP(iface.ip, iface.mask)
maskBits := maskBits(iface.mask)
fmt.Fprintf(&b, "C %s/%d is directly connected, %s\n", network, maskBits, iface.name)
}
b.WriteString("S* 0.0.0.0/0 [1/0] via 10.0.0.2")
return b.String()
}
func showIPInterfaceBrief(s *iosState) string {
var b strings.Builder
fmt.Fprintf(&b, "%-25s %-15s %-4s %-7s %-22s %s\n",
"Interface", "IP-Address", "OK?", "Method", "Status", "Protocol")
for _, iface := range s.interfaces {
ip := iface.ip
if ip == "" {
ip = "unassigned"
}
fmt.Fprintf(&b, "%-25s %-15s YES manual %-22s %s\n",
iface.name, ip, iface.status, iface.protocol)
}
return b.String()
}
func showInterfaces(s *iosState) string {
var b strings.Builder
for i, iface := range s.interfaces {
if i > 0 {
b.WriteString("\n")
}
upDown := "up"
if iface.shutdown {
upDown = "administratively down"
}
fmt.Fprintf(&b, "%s is %s, line protocol is %s\n", iface.name, upDown, iface.protocol)
fmt.Fprintf(&b, " Hardware is Gigabit Ethernet, address is %s (bia %s)\n", iface.mac, iface.mac)
if iface.ip != "unassigned" && iface.ip != "" {
fmt.Fprintf(&b, " Internet address is %s/%d\n", iface.ip, maskBits(iface.mask))
}
fmt.Fprintf(&b, " MTU %d bytes, BW %s sec, DLY 10 usec,\n", iface.mtu, iface.bandwidth)
b.WriteString(" reliability 255/255, txload 1/255, rxload 1/255\n")
b.WriteString(" Encapsulation ARPA, loopback not set\n")
fmt.Fprintf(&b, " %d packets input, %d bytes, 0 no buffer\n", iface.rxPackets, iface.rxBytes)
fmt.Fprintf(&b, " %d packets output, %d bytes, 0 underruns", iface.txPackets, iface.txBytes)
}
return b.String()
}
func showRunningConfig(s *iosState) string {
var b strings.Builder
b.WriteString("Building configuration...\n\n")
b.WriteString("Current configuration : 1482 bytes\n")
b.WriteString("!\n")
b.WriteString("! Last configuration change at 14:32:22 UTC Mon Feb 10 2025\n")
b.WriteString("!\n")
b.WriteString("version 15.0\n")
b.WriteString("service timestamps debug datetime msec\n")
b.WriteString("service timestamps log datetime msec\n")
b.WriteString("no service password-encryption\n")
b.WriteString("!\n")
fmt.Fprintf(&b, "hostname %s\n", s.hostname)
b.WriteString("!\n")
b.WriteString("boot-start-marker\n")
b.WriteString("boot-end-marker\n")
b.WriteString("!\n")
if s.enablePass != "" {
b.WriteString("enable secret 5 $1$mERr$hx5rVt7rPNoS4wqbXKX7m0\n")
}
b.WriteString("!\n")
b.WriteString("no aaa new-model\n")
b.WriteString("!\n")
for _, iface := range s.interfaces {
b.WriteString("!\n")
fmt.Fprintf(&b, "interface %s\n", iface.name)
if iface.desc != "" {
fmt.Fprintf(&b, " description %s\n", iface.desc)
}
if iface.ip != "unassigned" && iface.ip != "" {
fmt.Fprintf(&b, " ip address %s %s\n", iface.ip, iface.mask)
} else {
b.WriteString(" no ip address\n")
}
if iface.shutdown {
b.WriteString(" shutdown\n")
}
}
b.WriteString("!\n")
b.WriteString("ip forward-protocol nd\n")
b.WriteString("!\n")
b.WriteString("ip route 0.0.0.0 0.0.0.0 10.0.0.2\n")
b.WriteString("!\n")
b.WriteString("access-list 10 permit 192.168.1.0 0.0.0.255\n")
b.WriteString("access-list 10 deny any\n")
b.WriteString("!\n")
b.WriteString("line con 0\n")
b.WriteString(" logging synchronous\n")
b.WriteString("line vty 0 4\n")
b.WriteString(" login local\n")
b.WriteString(" transport input ssh\n")
b.WriteString("!\n")
b.WriteString("end")
return b.String()
}
func showVLANBrief() string {
var b strings.Builder
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "VLAN", "Name", "Status", "Ports")
b.WriteString("---- -------------------------------- --------- -------------------------------\n")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1", "default", "active", "Gi0/0, Gi0/1, Gi0/2")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "10", "MGMT", "active", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "20", "USERS", "active", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "99", "NATIVE", "active", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1002", "fddi-default", "act/unsup", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s\n", "1003", "token-ring-default", "act/unsup", "")
fmt.Fprintf(&b, "%-6s %-32s %-10s %s", "1004", "fddinet-default", "act/unsup", "")
return b.String()
}
// networkFromIP derives the network address from an IP and mask.
func networkFromIP(ip, mask string) string {
ipParts := parseIPv4(ip)
maskParts := parseIPv4(mask)
if ipParts == nil || maskParts == nil {
return ip
}
return fmt.Sprintf("%d.%d.%d.%d",
ipParts[0]&maskParts[0],
ipParts[1]&maskParts[1],
ipParts[2]&maskParts[2],
ipParts[3]&maskParts[3],
)
}
func maskBits(mask string) int {
parts := parseIPv4(mask)
if parts == nil {
return 24
}
bits := 0
for _, p := range parts {
for i := 7; i >= 0; i-- {
if p&(1<<uint(i)) != 0 {
bits++
} else {
return bits
}
}
}
return bits
}
func parseIPv4(s string) []int {
var a, b, c, d int
n, _ := fmt.Sscanf(s, "%d.%d.%d.%d", &a, &b, &c, &d)
if n != 4 {
return nil
}
return []int{a, b, c, d}
}

View File

@@ -0,0 +1,109 @@
package cisco
import "fmt"
// iosMode represents the current CLI mode of the IOS state machine.
type iosMode int
const (
modeUserExec iosMode = iota // Router>
modePrivilegedExec // Router#
modeGlobalConfig // Router(config)#
modeInterfaceConfig // Router(config-if)#
)
// ifaceInfo holds interface metadata for show commands.
type ifaceInfo struct {
name string
ip string
mask string
status string
protocol string
mac string
bandwidth string
mtu int
rxPackets int
txPackets int
rxBytes int
txBytes int
shutdown bool
desc string
}
// iosState holds all mutable state for the Cisco IOS shell session.
type iosState struct {
mode iosMode
hostname string
model string
iosVersion string
serial string
enablePass string
interfaces []ifaceInfo
currentIf string
}
func newIOSState(hostname, model, iosVersion, enablePass string) *iosState {
return &iosState{
mode: modeUserExec,
hostname: hostname,
model: model,
iosVersion: iosVersion,
serial: "FTX1524Z0P3",
enablePass: enablePass,
interfaces: defaultInterfaces(),
}
}
func defaultInterfaces() []ifaceInfo {
return []ifaceInfo{
{
name: "GigabitEthernet0/0", ip: "192.168.1.1", mask: "255.255.255.0",
status: "up", protocol: "up", mac: "0050.7966.6800",
bandwidth: "1000000 Kbit", mtu: 1500,
rxPackets: 148253, txPackets: 93127, rxBytes: 19284732, txBytes: 8291043,
},
{
name: "GigabitEthernet0/1", ip: "10.0.0.1", mask: "255.255.255.252",
status: "up", protocol: "up", mac: "0050.7966.6801",
bandwidth: "1000000 Kbit", mtu: 1500,
rxPackets: 52104, txPackets: 48891, rxBytes: 4182934, txBytes: 3901284,
},
{
name: "GigabitEthernet0/2", ip: "unassigned", mask: "",
status: "administratively down", protocol: "down", mac: "0050.7966.6802",
bandwidth: "1000000 Kbit", mtu: 1500, shutdown: true,
},
{
name: "Vlan1", ip: "172.16.0.1", mask: "255.255.0.0",
status: "up", protocol: "up", mac: "0050.7966.6810",
bandwidth: "1000000 Kbit", mtu: 1500,
rxPackets: 8421, txPackets: 7103, rxBytes: 512384, txBytes: 423901,
},
}
}
// prompt returns the IOS prompt string for the current mode.
func (s *iosState) prompt() string {
switch s.mode {
case modeUserExec:
return fmt.Sprintf("%s>", s.hostname)
case modePrivilegedExec:
return fmt.Sprintf("%s#", s.hostname)
case modeGlobalConfig:
return fmt.Sprintf("%s(config)#", s.hostname)
case modeInterfaceConfig:
return fmt.Sprintf("%s(config-if)#", s.hostname)
default:
return fmt.Sprintf("%s>", s.hostname)
}
}
// findInterface returns a pointer to the interface with the given name, or nil.
func (s *iosState) findInterface(name string) *ifaceInfo {
for i := range s.interfaces {
if s.interfaces[i].name == name {
return &s.interfaces[i]
}
}
return nil
}

View File

@@ -6,7 +6,7 @@ import (
"sync"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// EventRecorder buffers I/O events in memory and periodically flushes them to

View File

@@ -6,7 +6,7 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func TestEventRecorderFlush(t *testing.T) {
@@ -14,7 +14,7 @@ func TestEventRecorderFlush(t *testing.T) {
ctx := context.Background()
// Create a session so events have a valid session ID.
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
@@ -55,7 +55,7 @@ func TestEventRecorderPeriodicFlush(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}

View File

@@ -8,7 +8,7 @@ import (
"strings"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
@@ -69,6 +69,9 @@ func (f *FridgeShell) Handle(ctx context.Context, sess *shell.SessionContext, rw
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("fridge")
}
if result.exit {
return nil

View File

@@ -8,8 +8,8 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {
@@ -22,7 +22,7 @@ func (r *rwCloser) Close() error { return nil }
func runShell(t *testing.T, commands string) string {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
sess := &shell.SessionContext{
SessionID: sessID,
@@ -205,7 +205,7 @@ func TestLogoutCommand(t *testing.T) {
func TestSessionLogs(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge")
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "fridge", "")
sess := &shell.SessionContext{
SessionID: sessID,

View File

@@ -0,0 +1,123 @@
package psql
import (
"fmt"
"strings"
"time"
)
// commandResult holds the output of a command and whether the session should end.
type commandResult struct {
output string
exit bool
}
// dispatchBackslash handles psql backslash meta-commands.
func dispatchBackslash(cmd, dbName string) commandResult {
// Normalize: trim spaces after the backslash command word.
parts := strings.Fields(cmd)
if len(parts) == 0 {
return commandResult{output: "Invalid command \\. Try \\? for help."}
}
verb := parts[0] // e.g. `\q`, `\dt`, `\d`
args := parts[1:]
switch verb {
case `\q`:
return commandResult{exit: true}
case `\dt`:
return commandResult{output: listTables()}
case `\d`:
if len(args) == 0 {
return commandResult{output: listTables()}
}
return commandResult{output: describeTable(args[0])}
case `\l`:
return commandResult{output: listDatabases()}
case `\du`:
return commandResult{output: listRoles()}
case `\conninfo`:
return commandResult{output: connInfo(dbName)}
case `\?`:
return commandResult{output: backslashHelp()}
case `\h`:
return commandResult{output: sqlHelp()}
default:
return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
}
}
// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
func dispatchSQL(sql, dbName, pgVersion string) commandResult {
// Strip trailing semicolon and whitespace for matching.
trimmed := strings.TrimRight(sql, "; \t")
trimmed = strings.TrimSpace(trimmed)
upper := strings.ToUpper(trimmed)
switch {
case upper == "SELECT VERSION()":
ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
return commandResult{output: formatSingleValue("version", ver)}
case upper == "SELECT CURRENT_DATABASE()":
return commandResult{output: formatSingleValue("current_database", dbName)}
case upper == "SELECT CURRENT_USER":
return commandResult{output: formatSingleValue("current_user", "postgres")}
case upper == "SELECT NOW()":
now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
return commandResult{output: formatSingleValue("now", now)}
case upper == "SELECT 1":
return commandResult{output: formatSingleValue("?column?", "1")}
case strings.HasPrefix(upper, "INSERT"):
return commandResult{output: "INSERT 0 1"}
case strings.HasPrefix(upper, "UPDATE"):
return commandResult{output: "UPDATE 1"}
case strings.HasPrefix(upper, "DELETE"):
return commandResult{output: "DELETE 1"}
case strings.HasPrefix(upper, "CREATE TABLE"):
return commandResult{output: "CREATE TABLE"}
case strings.HasPrefix(upper, "CREATE DATABASE"):
return commandResult{output: "CREATE DATABASE"}
case strings.HasPrefix(upper, "DROP TABLE"):
return commandResult{output: "DROP TABLE"}
case strings.HasPrefix(upper, "ALTER TABLE"):
return commandResult{output: "ALTER TABLE"}
case upper == "BEGIN":
return commandResult{output: "BEGIN"}
case upper == "COMMIT":
return commandResult{output: "COMMIT"}
case upper == "ROLLBACK":
return commandResult{output: "ROLLBACK"}
case upper == "SHOW SERVER_VERSION":
return commandResult{output: formatSingleValue("server_version", pgVersion)}
case upper == "SHOW SEARCH_PATH":
return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
case strings.HasPrefix(upper, "SET "):
return commandResult{output: "SET"}
default:
// Extract the first token for the error message.
firstToken := strings.Fields(trimmed)
token := trimmed
if len(firstToken) > 0 {
token = firstToken[0]
}
return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
}
}
// formatSingleValue formats a single-row, single-column psql result.
func formatSingleValue(colName, value string) string {
width := max(len(colName), len(value))
var b strings.Builder
// Header
fmt.Fprintf(&b, " %-*s \n", width, colName)
// Separator
b.WriteString(strings.Repeat("-", width+2))
b.WriteString("\n")
// Value
fmt.Fprintf(&b, " %-*s\n", width, value)
// Row count
b.WriteString("(1 row)")
return b.String()
}

View File

@@ -0,0 +1,155 @@
package psql
import "fmt"
func startupBanner(version string) string {
return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version)
}
func listTables() string {
return ` List of relations
Schema | Name | Type | Owner
--------+---------------+-------+----------
public | audit_log | table | postgres
public | credentials | table | postgres
public | sessions | table | postgres
public | users | table | postgres
(4 rows)`
}
func listDatabases() string {
return ` List of databases
Name | Owner | Encoding | Collate | Ctype | Access privileges
-----------+----------+----------+-------------+-------------+-----------------------
app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
| | | | | postgres=CTc/postgres
template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
| | | | | postgres=CTc/postgres
(4 rows)`
}
func listRoles() string {
return ` List of roles
Role name | Attributes | Member of
-----------+------------------------------------------------------------+-----------
app_user | | {}
postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {}
readonly | Cannot login | {}`
}
func describeTable(name string) string {
switch name {
case "users":
return ` Table "public.users"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+-----------------------------------
id | integer | | not null | nextval('users_id_seq'::regclass)
username | character varying(255) | | not null |
email | character varying(255) | | not null |
password | character varying(255) | | not null |
created_at | timestamp without time zone | | | now()
updated_at | timestamp without time zone | | | now()
Indexes:
"users_pkey" PRIMARY KEY, btree (id)
"users_email_key" UNIQUE, btree (email)
"users_username_key" UNIQUE, btree (username)`
case "sessions":
return ` Table "public.sessions"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+--------------------------------------
id | integer | | not null | nextval('sessions_id_seq'::regclass)
user_id | integer | | |
token | character varying(255) | | not null |
ip_address | inet | | |
created_at | timestamp without time zone | | | now()
expires_at | timestamp without time zone | | not null |
Indexes:
"sessions_pkey" PRIMARY KEY, btree (id)
"sessions_token_key" UNIQUE, btree (token)
Foreign-key constraints:
"sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
case "credentials":
return ` Table "public.credentials"
Column | Type | Collation | Nullable | Default
-----------+-----------------------------+-----------+----------+-----------------------------------------
id | integer | | not null | nextval('credentials_id_seq'::regclass)
user_id | integer | | |
type | character varying(50) | | not null |
value | text | | not null |
created_at| timestamp without time zone | | | now()
Indexes:
"credentials_pkey" PRIMARY KEY, btree (id)
Foreign-key constraints:
"credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
case "audit_log":
return ` Table "public.audit_log"
Column | Type | Collation | Nullable | Default
------------+-----------------------------+-----------+----------+---------------------------------------
id | integer | | not null | nextval('audit_log_id_seq'::regclass)
user_id | integer | | |
action | character varying(100) | | not null |
details | text | | |
ip_address | inet | | |
created_at | timestamp without time zone | | | now()
Indexes:
"audit_log_pkey" PRIMARY KEY, btree (id)
Foreign-key constraints:
"audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
default:
return fmt.Sprintf("Did not find any relation named \"%s\".", name)
}
}
func connInfo(dbName string) string {
return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName)
}
func backslashHelp() string {
return `General
\copyright show PostgreSQL usage and distribution terms
\crosstabview [COLUMNS] execute query and display result in crosstab
\errverbose show most recent error message at maximum verbosity
\g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe)
\gdesc describe result of query, without executing it
\gexec execute query, then execute each value in its result
\gset [PREFIX] execute query and store result in psql variables
\gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode
\q quit psql
\watch [SEC] execute query every SEC seconds
Informational
(options: S = show system objects, + = additional detail)
\d[S+] list tables, views, and sequences
\d[S+] NAME describe table, view, sequence, or index
\da[S] [PATTERN] list aggregates
\dA[+] [PATTERN] list access methods
\dt[S+] [PATTERN] list tables
\du[S+] [PATTERN] list roles
\l[+] [PATTERN] list databases`
}
func sqlHelp() string {
return `Available help:
ABORT CREATE LANGUAGE
ALTER AGGREGATE CREATE MATERIALIZED VIEW
ALTER COLLATION CREATE OPERATOR
ALTER CONVERSION CREATE POLICY
ALTER DATABASE CREATE PROCEDURE
ALTER DEFAULT PRIVILEGES CREATE PUBLICATION
ALTER DOMAIN CREATE ROLE
ALTER EVENT TRIGGER CREATE RULE
ALTER EXTENSION CREATE SCHEMA
ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE
ALTER FOREIGN TABLE CREATE SERVER
ALTER FUNCTION CREATE STATISTICS
ALTER GROUP CREATE SUBSCRIPTION
ALTER INDEX CREATE TABLE
ALTER LANGUAGE CREATE TABLESPACE
BEGIN DELETE
COMMIT DROP TABLE
CREATE DATABASE INSERT
CREATE INDEX ROLLBACK
SELECT UPDATE`
}

137
internal/shell/psql/psql.go Normal file
View File

@@ -0,0 +1,137 @@
package psql
import (
"context"
"errors"
"fmt"
"io"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// PsqlShell emulates a PostgreSQL psql interactive terminal.
type PsqlShell struct{}
// NewPsqlShell returns a new PsqlShell instance.
func NewPsqlShell() *PsqlShell {
return &PsqlShell{}
}
func (p *PsqlShell) Name() string { return "psql" }
func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" }
func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
dbName := configString(sess.ShellConfig, "db_name", "postgres")
pgVersion := configString(sess.ShellConfig, "pg_version", "15.4")
// Print startup banner.
fmt.Fprint(rw, startupBanner(pgVersion))
var sqlBuf []string // accumulates multi-line SQL
for {
prompt := buildPrompt(dbName, len(sqlBuf) > 0)
if _, err := fmt.Fprint(rw, prompt); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
// Empty line in non-buffering state: just re-prompt.
if trimmed == "" && len(sqlBuf) == 0 {
continue
}
// Backslash commands dispatch immediately (even mid-buffer they cancel the buffer).
if strings.HasPrefix(trimmed, `\`) {
sqlBuf = nil // discard any partial SQL
result := dispatchBackslash(trimmed, dbName)
if result.output != "" {
output := strings.ReplaceAll(result.output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("psql")
}
if result.exit {
return nil
}
continue
}
// Accumulate SQL lines.
sqlBuf = append(sqlBuf, line)
// Check if the statement is terminated by a semicolon.
if !strings.HasSuffix(strings.TrimSpace(line), ";") {
continue
}
// Full statement ready — join and dispatch.
fullSQL := strings.Join(sqlBuf, " ")
sqlBuf = nil
result := dispatchSQL(fullSQL, dbName, pgVersion)
if result.output != "" {
output := strings.ReplaceAll(result.output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("psql")
}
if result.exit {
return nil
}
}
}
// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL.
func buildPrompt(dbName string, continuation bool) string {
if continuation {
return dbName + "-# "
}
return dbName + "=# "
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,330 @@
package psql
import (
"strings"
"testing"
)
// --- Prompt tests ---
func TestBuildPromptNormal(t *testing.T) {
got := buildPrompt("postgres", false)
if got != "postgres=# " {
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
}
}
func TestBuildPromptContinuation(t *testing.T) {
got := buildPrompt("postgres", true)
if got != "postgres-# " {
t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
}
}
func TestBuildPromptCustomDB(t *testing.T) {
got := buildPrompt("mydb", false)
if got != "mydb=# " {
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
}
}
// --- Backslash command dispatch tests ---
func TestBackslashQuit(t *testing.T) {
result := dispatchBackslash(`\q`, "postgres")
if !result.exit {
t.Error("\\q should set exit=true")
}
}
func TestBackslashListTables(t *testing.T) {
result := dispatchBackslash(`\dt`, "postgres")
if !strings.Contains(result.output, "users") {
t.Error("\\dt should list tables including 'users'")
}
if !strings.Contains(result.output, "sessions") {
t.Error("\\dt should list tables including 'sessions'")
}
}
func TestBackslashDescribeTable(t *testing.T) {
result := dispatchBackslash(`\d users`, "postgres")
if !strings.Contains(result.output, "username") {
t.Error("\\d users should describe users table with 'username' column")
}
if !strings.Contains(result.output, "PRIMARY KEY") {
t.Error("\\d users should include index info")
}
}
func TestBackslashDescribeUnknownTable(t *testing.T) {
result := dispatchBackslash(`\d nonexistent`, "postgres")
if !strings.Contains(result.output, "Did not find") {
t.Error("\\d nonexistent should return not found message")
}
}
func TestBackslashListDatabases(t *testing.T) {
result := dispatchBackslash(`\l`, "postgres")
if !strings.Contains(result.output, "postgres") {
t.Error("\\l should list databases including 'postgres'")
}
if !strings.Contains(result.output, "template0") {
t.Error("\\l should list databases including 'template0'")
}
}
func TestBackslashListRoles(t *testing.T) {
result := dispatchBackslash(`\du`, "postgres")
if !strings.Contains(result.output, "postgres") {
t.Error("\\du should list roles including 'postgres'")
}
if !strings.Contains(result.output, "Superuser") {
t.Error("\\du should show Superuser attribute for postgres")
}
}
func TestBackslashConnInfo(t *testing.T) {
result := dispatchBackslash(`\conninfo`, "mydb")
if !strings.Contains(result.output, "mydb") {
t.Error("\\conninfo should include database name")
}
if !strings.Contains(result.output, "5432") {
t.Error("\\conninfo should include port")
}
}
func TestBackslashHelp(t *testing.T) {
result := dispatchBackslash(`\?`, "postgres")
if !strings.Contains(result.output, `\q`) {
t.Error("\\? should include \\q in help output")
}
}
func TestBackslashSQLHelp(t *testing.T) {
result := dispatchBackslash(`\h`, "postgres")
if !strings.Contains(result.output, "SELECT") {
t.Error("\\h should include SQL commands like SELECT")
}
}
func TestBackslashUnknown(t *testing.T) {
result := dispatchBackslash(`\xyz`, "postgres")
if !strings.Contains(result.output, "Invalid command") {
t.Error("unknown backslash command should return error")
}
}
// --- SQL dispatch tests ---
func TestSQLSelectVersion(t *testing.T) {
result := dispatchSQL("SELECT version();", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("SELECT version() should contain pg version")
}
if !strings.Contains(result.output, "(1 row)") {
t.Error("SELECT version() should show row count")
}
}
func TestSQLSelectCurrentDatabase(t *testing.T) {
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
if !strings.Contains(result.output, "mydb") {
t.Error("SELECT current_database() should return db name")
}
}
func TestSQLSelectCurrentUser(t *testing.T) {
result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
if !strings.Contains(result.output, "postgres") {
t.Error("SELECT current_user should return postgres")
}
}
func TestSQLSelectNow(t *testing.T) {
result := dispatchSQL("SELECT now();", "postgres", "15.4")
if !strings.Contains(result.output, "(1 row)") {
t.Error("SELECT now() should show row count")
}
}
func TestSQLSelectOne(t *testing.T) {
result := dispatchSQL("SELECT 1;", "postgres", "15.4")
if !strings.Contains(result.output, "1") {
t.Error("SELECT 1 should return 1")
}
}
func TestSQLInsert(t *testing.T) {
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
if result.output != "INSERT 0 1" {
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
}
}
func TestSQLUpdate(t *testing.T) {
result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
if result.output != "UPDATE 1" {
t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
}
}
func TestSQLDelete(t *testing.T) {
result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
if result.output != "DELETE 1" {
t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
}
}
func TestSQLCreateTable(t *testing.T) {
result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
if result.output != "CREATE TABLE" {
t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
}
}
func TestSQLCreateDatabase(t *testing.T) {
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
if result.output != "CREATE DATABASE" {
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
}
}
func TestSQLDropTable(t *testing.T) {
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
if result.output != "DROP TABLE" {
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
}
}
func TestSQLAlterTable(t *testing.T) {
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
if result.output != "ALTER TABLE" {
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
}
}
func TestSQLBeginCommitRollback(t *testing.T) {
tests := []struct {
sql string
want string
}{
{"BEGIN;", "BEGIN"},
{"COMMIT;", "COMMIT"},
{"ROLLBACK;", "ROLLBACK"},
}
for _, tt := range tests {
result := dispatchSQL(tt.sql, "postgres", "15.4")
if result.output != tt.want {
t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
}
}
}
func TestSQLShowServerVersion(t *testing.T) {
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("SHOW server_version should contain version")
}
}
func TestSQLShowSearchPath(t *testing.T) {
result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
if !strings.Contains(result.output, "public") {
t.Error("SHOW search_path should contain public")
}
}
func TestSQLSet(t *testing.T) {
result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
if result.output != "SET" {
t.Errorf("SET output = %q, want %q", result.output, "SET")
}
}
func TestSQLUnrecognized(t *testing.T) {
result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
if !strings.Contains(result.output, "ERROR") {
t.Error("unrecognized SQL should return error")
}
if !strings.Contains(result.output, "FOOBAR") {
t.Error("error should reference the offending token")
}
}
// --- Case insensitivity ---
func TestSQLCaseInsensitive(t *testing.T) {
result := dispatchSQL("select version();", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("select version() (lowercase) should work")
}
result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
if !strings.Contains(result.output, "mydb") {
t.Error("mixed case SELECT should work")
}
}
// --- Startup banner ---
func TestStartupBanner(t *testing.T) {
banner := startupBanner("15.4")
if !strings.Contains(banner, "psql (15.4)") {
t.Errorf("banner should contain version, got: %s", banner)
}
if !strings.Contains(banner, "help") {
t.Error("banner should mention help")
}
}
// --- configString ---
func TestConfigString(t *testing.T) {
cfg := map[string]any{"db_name": "mydb"}
if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
t.Errorf("configString() = %q, want %q", got, "mydb")
}
if got := configString(cfg, "missing", "default"); got != "default" {
t.Errorf("configString() for missing = %q, want %q", got, "default")
}
if got := configString(nil, "key", "default"); got != "default" {
t.Errorf("configString(nil) = %q, want %q", got, "default")
}
}
// --- Shell metadata ---
func TestShellNameAndDescription(t *testing.T) {
s := NewPsqlShell()
if s.Name() != "psql" {
t.Errorf("Name() = %q, want %q", s.Name(), "psql")
}
if s.Description() == "" {
t.Error("Description() should not be empty")
}
}
// --- formatSingleValue ---
func TestFormatSingleValue(t *testing.T) {
out := formatSingleValue("?column?", "1")
if !strings.Contains(out, "?column?") {
t.Error("should contain column name")
}
if !strings.Contains(out, "1") {
t.Error("should contain value")
}
if !strings.Contains(out, "(1 row)") {
t.Error("should contain row count")
}
}
// --- \d with no args ---
func TestBackslashDescribeNoArgs(t *testing.T) {
result := dispatchBackslash(`\d`, "postgres")
if !strings.Contains(result.output, "users") {
t.Error("\\d with no args should list tables")
}
}

View File

@@ -0,0 +1,463 @@
package roomba
import (
"context"
"errors"
"fmt"
"io"
"slices"
"strings"
"time"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// RoombaShell emulates an iRobot Roomba vacuum robot interface.
type RoombaShell struct{}
// NewRoombaShell returns a new RoombaShell instance.
func NewRoombaShell() *RoombaShell {
return &RoombaShell{}
}
func (r *RoombaShell) Name() string { return "roomba" }
func (r *RoombaShell) Description() string { return "iRobot Roomba shell emulator" }
func (r *RoombaShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
state := newRoombaState()
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
fmt.Fprint(rw, banner)
for {
if _, err := fmt.Fprint(rw, "RoombaOS> "); err != nil {
return nil
}
line, err := shell.ReadLine(ctx, rw)
if errors.Is(err, io.EOF) {
fmt.Fprint(rw, "logout\r\n")
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
result := state.dispatch(trimmed)
var output string
if result.output != "" {
output = result.output
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
if sess.Store != nil {
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
return fmt.Errorf("append session log: %w", err)
}
}
if sess.OnCommand != nil {
sess.OnCommand("roomba")
}
if result.exit {
return nil
}
}
}
func bootBanner() string {
return `
____ _ ___ ____
| _ \ ___ ___ _ __ ___ | |__ __ _ / _ \/ ___|
| |_) / _ \ / _ \| '_ ` + "`" + ` _ \| '_ \ / _` + "`" + ` | | | \___ \
| _ < (_) | (_) | | | | | | |_) | (_| | |_| |___) |
|_| \_\___/ \___/|_| |_| |_|_.__/ \__,_|\___/|____/
iRobot Roomba j7+ | RoombaOS v4.3.7
Serial: RMB-7291-J7P-0482 | Firmware: 4.3.7-stable
Battery: 73% | WiFi: Connected (SmartHome-5G)
Type 'help' for available commands.
`
}
type room struct {
name string
areaSqFt int
lastCleaned time.Time
}
type scheduleEntry struct {
day string
time string
}
type historyEntry struct {
timestamp time.Time
room string
duration string
note string
}
type roombaState struct {
battery int
dustbin int
status string
rooms []room
schedule []scheduleEntry
cleanHistory []historyEntry
}
type commandResult struct {
output string
exit bool
}
func newRoombaState() *roombaState {
now := time.Now()
return &roombaState{
battery: 73,
dustbin: 61,
status: "Docked",
rooms: []room{
{"Kitchen", 180, now.Add(-2 * time.Hour)},
{"Living Room", 320, now.Add(-5 * time.Hour)},
{"Bedroom", 200, now.Add(-26 * time.Hour)},
{"Hallway", 60, now.Add(-5 * time.Hour)},
{"Bathroom", 75, now.Add(-50 * time.Hour)},
{"Cat's Room", 110, now.Add(-3 * time.Hour)},
},
schedule: []scheduleEntry{
{"Monday", "09:00"},
{"Wednesday", "09:00"},
{"Friday", "09:00"},
{"Saturday", "14:00"},
},
cleanHistory: []historyEntry{
{now.Add(-2 * time.Hour), "Kitchen", "23 min", "Completed normally"},
{now.Add(-3 * time.Hour), "Cat's Room", "18 min", "Cat detected - rerouting"},
{now.Add(-5 * time.Hour), "Living Room", "34 min", "Encountered sock near couch"},
{now.Add(-5*time.Hour - 40*time.Minute), "Hallway", "8 min", "Completed normally"},
{now.Add(-26 * time.Hour), "Bedroom", "27 min", "Tangled in phone charger"},
{now.Add(-50 * time.Hour), "Bathroom", "14 min", "Unidentified sticky substance detected"},
},
}
}
func (s *roombaState) dispatch(input string) commandResult {
parts := strings.Fields(input)
if len(parts) == 0 {
return commandResult{}
}
cmd := strings.ToLower(parts[0])
args := parts[1:]
switch cmd {
case "help":
return s.cmdHelp()
case "status":
return s.cmdStatus()
case "clean":
return s.cmdClean(args)
case "dock":
return s.cmdDock()
case "map":
return s.cmdMap()
case "schedule":
return s.cmdSchedule(args)
case "history":
return s.cmdHistory()
case "diagnostics":
return s.cmdDiagnostics()
case "alerts":
return s.cmdAlerts()
case "reboot":
return s.cmdReboot()
case "exit", "logout":
return commandResult{output: "Disconnecting from RoombaOS. Happy cleaning!", exit: true}
default:
return commandResult{output: fmt.Sprintf("RoombaOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
}
}
func (s *roombaState) cmdHelp() commandResult {
help := `Available commands:
help - Show this help message
status - Show robot status
clean - Start full cleaning job
clean room <name> - Clean a specific room
dock - Return to dock
map - Show floor plan and room list
schedule - List cleaning schedule
schedule add <day> <time> - Add scheduled cleaning
schedule remove <day> - Remove scheduled cleaning
history - Show recent cleaning history
diagnostics - Run system diagnostics
alerts - Show active alerts
reboot - Reboot RoombaOS
exit / logout - Disconnect`
return commandResult{output: help}
}
func (s *roombaState) cmdStatus() commandResult {
var b strings.Builder
b.WriteString("=== RoombaOS System Status ===\n")
b.WriteString("Model: iRobot Roomba j7+\n")
b.WriteString(fmt.Sprintf("Status: %s\n", s.status))
b.WriteString(fmt.Sprintf("Battery: %d%%\n", s.battery))
b.WriteString(fmt.Sprintf("Dustbin: %d%% full\n", s.dustbin))
b.WriteString("Side brush: OK (142 hrs)\n")
b.WriteString("Main brush: OK (98 hrs)\n")
b.WriteString("\n")
b.WriteString("WiFi: Connected (SmartHome-5G)\n")
b.WriteString("Signal: -38 dBm\n")
b.WriteString("Alexa: Linked\n")
b.WriteString("Google Home: Linked\n")
b.WriteString("iRobot Home App: Connected\n")
b.WriteString("\n")
b.WriteString("Firmware: v4.3.7-stable\n")
b.WriteString("LIDAR: Operational\n")
b.WriteString("Clean Area Total: 12,847 sq ft (lifetime)")
return commandResult{output: b.String()}
}
func (s *roombaState) cmdClean(args []string) commandResult {
if s.status == "Cleaning" {
return commandResult{output: "Already cleaning. Use 'dock' to cancel and return to dock."}
}
if len(args) >= 2 && strings.ToLower(args[0]) == "room" {
roomName := strings.Join(args[1:], " ")
for _, r := range s.rooms {
if strings.EqualFold(r.name, roomName) {
s.status = "Cleaning"
return commandResult{output: fmt.Sprintf(
"Starting targeted clean: %s (%d sq ft)\nEstimated time: %d minutes\nUndocking... navigating to %s...",
r.name, r.areaSqFt, r.areaSqFt/8, r.name,
)}
}
}
return commandResult{output: fmt.Sprintf("Room '%s' not found. Use 'map' to see available rooms.", roomName)}
}
if len(args) > 0 {
return commandResult{output: "Usage: clean [room <name>]"}
}
s.status = "Cleaning"
var totalArea int
for _, r := range s.rooms {
totalArea += r.areaSqFt
}
return commandResult{output: fmt.Sprintf(
"Starting full house clean\nTotal area: %d sq ft across %d rooms\nEstimated time: %d minutes\nUndocking... beginning clean cycle...",
totalArea, len(s.rooms), totalArea/8,
)}
}
func (s *roombaState) cmdDock() commandResult {
if s.status == "Docked" {
return commandResult{output: "Already docked."}
}
if s.status == "Returning to dock" {
return commandResult{output: "Already returning to dock."}
}
s.status = "Returning to dock"
return commandResult{output: "Cancelling current job. Returning to dock...\nEstimated arrival: 2 minutes"}
}
func (s *roombaState) cmdMap() commandResult {
var b strings.Builder
b.WriteString("=== Floor Plan ===\n\n")
b.WriteString(" +------------+----------+\n")
b.WriteString(" | | |\n")
b.WriteString(" | Kitchen | Bathroom |\n")
b.WriteString(" | 180sqft | 75sqft |\n")
b.WriteString(" | | |\n")
b.WriteString(" +------+-----+----+-----+\n")
b.WriteString(" | | | |\n")
b.WriteString(" | Hall | Living | Cat |\n")
b.WriteString(" | 60sf | Room | Rm |\n")
b.WriteString(" | | 320sqft |110sf|\n")
b.WriteString(" +------+ +-----+\n")
b.WriteString(" | | |\n")
b.WriteString(" | Bed +----------+\n")
b.WriteString(" | room | [DOCK]\n")
b.WriteString(" |200sf |\n")
b.WriteString(" +------+\n")
b.WriteString("\nRoom Details:\n")
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "ROOM", "AREA", "LAST CLEANED"))
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "----", "----", "------------"))
for _, r := range s.rooms {
ago := time.Since(r.lastCleaned).Truncate(time.Minute)
b.WriteString(fmt.Sprintf(" %-15s %-10s %s ago\n", r.name, fmt.Sprintf("%d sqft", r.areaSqFt), formatDuration(ago)))
}
return commandResult{output: b.String()}
}
func (s *roombaState) cmdSchedule(args []string) commandResult {
if len(args) == 0 {
return s.scheduleList()
}
sub := strings.ToLower(args[0])
switch sub {
case "add":
if len(args) < 3 {
return commandResult{output: "Usage: schedule add <day> <time>\nExample: schedule add Tuesday 10:00"}
}
return s.scheduleAdd(args[1], args[2])
case "remove":
if len(args) < 2 {
return commandResult{output: "Usage: schedule remove <day>"}
}
return s.scheduleRemove(args[1])
default:
return commandResult{output: fmt.Sprintf("Unknown schedule subcommand '%s'. Try: add, remove", sub)}
}
}
func (s *roombaState) scheduleList() commandResult {
if len(s.schedule) == 0 {
return commandResult{output: "No cleaning schedule configured."}
}
var b strings.Builder
b.WriteString("=== Cleaning Schedule ===\n")
b.WriteString(fmt.Sprintf(" %-12s %s\n", "DAY", "TIME"))
b.WriteString(fmt.Sprintf(" %-12s %s\n", "---", "----"))
for _, e := range s.schedule {
b.WriteString(fmt.Sprintf(" %-12s %s\n", e.day, e.time))
}
b.WriteString(fmt.Sprintf("\n%d scheduled cleaning(s)", len(s.schedule)))
return commandResult{output: b.String()}
}
func (s *roombaState) scheduleAdd(day, t string) commandResult {
day = capitalizeFirst(strings.ToLower(day))
validDays := []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
if !slices.Contains(validDays, day) {
return commandResult{output: fmt.Sprintf("Invalid day '%s'. Use a day of the week (e.g. Monday, Tuesday).", day)}
}
for _, e := range s.schedule {
if strings.EqualFold(e.day, day) {
return commandResult{output: fmt.Sprintf("Schedule for %s already exists. Remove it first.", day)}
}
}
s.schedule = append(s.schedule, scheduleEntry{day: day, time: t})
return commandResult{output: fmt.Sprintf("Scheduled cleaning added: %s at %s", day, t)}
}
func (s *roombaState) scheduleRemove(day string) commandResult {
day = capitalizeFirst(strings.ToLower(day))
for i, e := range s.schedule {
if strings.EqualFold(e.day, day) {
s.schedule = append(s.schedule[:i], s.schedule[i+1:]...)
return commandResult{output: fmt.Sprintf("Removed schedule for %s.", day)}
}
}
return commandResult{output: fmt.Sprintf("No schedule found for '%s'.", day)}
}
func (s *roombaState) cmdHistory() commandResult {
if len(s.cleanHistory) == 0 {
return commandResult{output: "No cleaning history."}
}
var b strings.Builder
b.WriteString("=== Cleaning History ===\n")
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "TIME", "ROOM", "DURATION", "NOTE"))
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "----", "----", "--------", "----"))
for _, h := range s.cleanHistory {
ts := h.timestamp.Format("2006-01-02 15:04")
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", ts, h.room, h.duration, h.note))
}
b.WriteString(fmt.Sprintf("\n%d session(s) recorded", len(s.cleanHistory)))
return commandResult{output: b.String()}
}
func (s *roombaState) cmdDiagnostics() commandResult {
diag := `Running RoombaOS diagnostics...
[1/8] Cliff sensors........... OK
[2/8] Bumper sensor........... OK
[3/8] Side brush motor........ OK (142 hrs until replacement)
[4/8] Main brush motor........ OK (98 hrs until replacement)
[5/8] Wheel motors............ OK (L: 1204 hrs, R: 1204 hrs)
[6/8] LIDAR module............ OK (last calibrated 3 days ago)
[7/8] Dustbin sensor.......... OK
[8/8] WiFi module............. OK (signal: -38 dBm)
ALL SYSTEMS NOMINAL`
return commandResult{output: diag}
}
func (s *roombaState) cmdAlerts() commandResult {
var alerts []string
if s.dustbin >= 60 {
alerts = append(alerts, fmt.Sprintf("WARNING: Dustbin %d%% full - consider emptying", s.dustbin))
}
alerts = append(alerts,
"WARNING: Side brush replacement due in 12 hours",
"INFO: Unidentified sticky substance detected in Kitchen",
"INFO: Cat frequently blocking cleaning path in Cat's Room",
"INFO: Firmware update available: v4.4.0-beta",
"INFO: Filter replacement recommended in 14 days",
)
var b strings.Builder
b.WriteString("=== Active Alerts ===\n")
for _, a := range alerts {
b.WriteString(a + "\n")
}
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
return commandResult{output: b.String()}
}
func (s *roombaState) cmdReboot() commandResult {
reboot := `RoombaOS is rebooting...
Stopping navigation engine..... done
Saving room map data........... done
Flushing cleaning logs......... done
Disconnecting from WiFi........ done
Rebooting now. Goodbye!`
return commandResult{output: reboot, exit: true}
}
func capitalizeFirst(s string) string {
if s == "" {
return s
}
return strings.ToUpper(s[:1]) + s[1:]
}
func formatDuration(d time.Duration) string {
hours := int(d.Hours())
minutes := int(d.Minutes()) % 60
if hours >= 24 {
days := hours / 24
hours %= 24
return fmt.Sprintf("%dd %dh", days, hours)
}
if hours > 0 {
return fmt.Sprintf("%dh %dm", hours, minutes)
}
return fmt.Sprintf("%dm", minutes)
}

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"io"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// Shell is the interface that all honeypot shell implementations must satisfy.
@@ -24,6 +24,7 @@ type SessionContext struct {
Store storage.Store
ShellConfig map[string]any
CommonConfig ShellCommonConfig
OnCommand func(shell string) // called when a command is executed; may be nil
}
// ShellCommonConfig holds settings shared across all shell types.

View File

@@ -0,0 +1,101 @@
package tetris
import "github.com/charmbracelet/lipgloss"
// pieceType identifies a tetromino (06).
type pieceType int
const (
pieceI pieceType = iota
pieceO
pieceT
pieceS
pieceZ
pieceJ
pieceL
)
const numPieceTypes = 7
// Standard Tetris colors.
var pieceColors = [numPieceTypes]lipgloss.Color{
lipgloss.Color("#00FFFF"), // I — cyan
lipgloss.Color("#FFFF00"), // O — yellow
lipgloss.Color("#AA00FF"), // T — purple
lipgloss.Color("#00FF00"), // S — green
lipgloss.Color("#FF0000"), // Z — red
lipgloss.Color("#0000FF"), // J — blue
lipgloss.Color("#FF8800"), // L — orange
}
// Each piece has 4 rotations, each rotation is a list of (row, col) offsets
// relative to the piece origin.
type rotation [4][2]int
var pieces = [numPieceTypes][4]rotation{
// I
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
},
// O
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
},
// T
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{1, 1}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
},
// S
{
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
},
// Z
{
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
},
// J
{
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{2, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
},
// L
{
{[2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}},
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{2, 1}},
},
}
// spawnCol returns the starting column for a piece, centering it on the board.
func spawnCol(pt pieceType, rot int) int {
shape := pieces[pt][rot]
minC, maxC := shape[0][1], shape[0][1]
for _, off := range shape {
if off[1] < minC {
minC = off[1]
}
if off[1] > maxC {
maxC = off[1]
}
}
width := maxC - minC + 1
return (boardCols - width) / 2
}

View File

@@ -0,0 +1,210 @@
package tetris
import "math/rand/v2"
const (
boardRows = 20
boardCols = 10
)
// cell represents a single board cell. Zero value is empty.
type cell struct {
filled bool
piece pieceType // which piece type filled this cell (for color)
}
// gameState holds all mutable state for a Tetris game.
type gameState struct {
board [boardRows][boardCols]cell
current pieceType
currentRot int
currentRow int
currentCol int
next pieceType
score int
level int
lines int
gameOver bool
}
// newGame creates a new game state, optionally starting at a given level.
func newGame(startLevel int) *gameState {
g := &gameState{
level: startLevel,
next: pieceType(rand.IntN(numPieceTypes)),
}
g.spawnPiece()
return g
}
// spawnPiece pulls the next piece and generates a new next.
func (g *gameState) spawnPiece() {
g.current = g.next
g.next = pieceType(rand.IntN(numPieceTypes))
g.currentRot = 0
g.currentRow = 0
g.currentCol = spawnCol(g.current, 0)
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
g.gameOver = true
}
}
// canPlace checks whether the piece fits at the given position.
func (g *gameState) canPlace(pt pieceType, rot, row, col int) bool {
shape := pieces[pt][rot]
for _, off := range shape {
r, c := row+off[0], col+off[1]
if r < 0 || r >= boardRows || c < 0 || c >= boardCols {
return false
}
if g.board[r][c].filled {
return false
}
}
return true
}
// moveLeft moves the current piece left if possible.
func (g *gameState) moveLeft() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol-1) {
g.currentCol--
return true
}
return false
}
// moveRight moves the current piece right if possible.
func (g *gameState) moveRight() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol+1) {
g.currentCol++
return true
}
return false
}
// moveDown moves the current piece down one row. Returns false if it cannot.
func (g *gameState) moveDown() bool {
if g.canPlace(g.current, g.currentRot, g.currentRow+1, g.currentCol) {
g.currentRow++
return true
}
return false
}
// rotate rotates the current piece clockwise with wall kick attempts.
func (g *gameState) rotate() bool {
newRot := (g.currentRot + 1) % 4
// Try in-place first.
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol) {
g.currentRot = newRot
return true
}
// Wall kick: try +-1 column offset.
for _, offset := range []int{-1, 1} {
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
g.currentRot = newRot
g.currentCol += offset
return true
}
}
// I piece: try +-2.
if g.current == pieceI {
for _, offset := range []int{-2, 2} {
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
g.currentRot = newRot
g.currentCol += offset
return true
}
}
}
return false
}
// ghostRow returns the row where the piece would land.
func (g *gameState) ghostRow() int {
row := g.currentRow
for g.canPlace(g.current, g.currentRot, row+1, g.currentCol) {
row++
}
return row
}
// hardDrop drops the piece to the bottom and returns the number of rows dropped.
func (g *gameState) hardDrop() int {
ghost := g.ghostRow()
dropped := ghost - g.currentRow
g.currentRow = ghost
return dropped
}
// lockPiece writes the current piece into the board.
func (g *gameState) lockPiece() {
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := g.currentRow+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
g.board[r][c] = cell{filled: true, piece: g.current}
}
}
}
// clearLines removes completed rows and returns how many were cleared.
func (g *gameState) clearLines() int {
cleared := 0
for r := boardRows - 1; r >= 0; r-- {
full := true
for c := range boardCols {
if !g.board[r][c].filled {
full = false
break
}
}
if full {
cleared++
// Shift everything above down.
for rr := r; rr > 0; rr-- {
g.board[rr] = g.board[rr-1]
}
g.board[0] = [boardCols]cell{}
r++ // re-check this row since we shifted
}
}
return cleared
}
// NES-style scoring multipliers per lines cleared.
var lineScoreMultipliers = [5]int{0, 40, 100, 300, 1200}
// addScore updates score, lines, and level after clearing rows.
func (g *gameState) addScore(linesCleared int) {
if linesCleared > 0 && linesCleared <= 4 {
g.score += lineScoreMultipliers[linesCleared] * (g.level + 1)
}
g.lines += linesCleared
// Level up every 10 lines.
newLevel := g.lines / 10
if newLevel > g.level {
g.level = newLevel
}
}
// afterLock locks the piece, clears lines, scores, and spawns the next piece.
// Returns the number of lines cleared.
func (g *gameState) afterLock() int {
g.lockPiece()
cleared := g.clearLines()
g.addScore(cleared)
g.spawnPiece()
return cleared
}
// tickInterval returns the gravity interval in milliseconds for the current level.
func tickInterval(level int) int {
return max(800-level*60, 100)
}

View File

@@ -0,0 +1,331 @@
package tetris
import (
"context"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
type screen int
const (
screenTitle screen = iota
screenGame
screenGameOver
)
type tickMsg time.Time
type lockMsg time.Time
const lockDelay = 500 * time.Millisecond
type model struct {
sess *shell.SessionContext
difficulty string
screen screen
game *gameState
quitting bool
height int
keypresses int
locking bool // true when piece has landed and lock delay is active
}
func newModel(sess *shell.SessionContext, difficulty string) *model {
return &model{
sess: sess,
difficulty: difficulty,
screen: screenTitle,
}
}
func (m *model) Init() tea.Cmd {
return nil
}
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if m.quitting {
return m, tea.Quit
}
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.height = msg.Height
return m, nil
case tea.KeyMsg:
m.keypresses++
if msg.Type == tea.KeyCtrlC {
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.gameScore(), m.gameLevel(), m.gameLines(), m.keypresses), "SESSION ENDED"),
tea.Quit,
)
}
}
switch m.screen {
case screenTitle:
return m.updateTitle(msg)
case screenGame:
return m.updateGame(msg)
case screenGameOver:
return m.updateGameOver(msg)
}
return m, nil
}
func (m *model) View() string {
var content string
switch m.screen {
case screenTitle:
content = m.titleView()
case screenGame:
content = gameView(m.game)
case screenGameOver:
content = m.gameOverView()
}
return gameFrame(content, m.height)
}
// --- Title screen ---
func (m *model) titleView() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ████████╗███████╗████████╗██████╗ ██╗███████╗"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ╚══██╔══╝██╔════╝╚══██╔══╝██╔══██╗██║██╔════╝"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ █████╗ ██║ ██████╔╝██║███████╗"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ ██╔══╝ ██║ ██╔══██╗██║╚════██║"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ██║ ███████╗ ██║ ██║ ██║██║███████║"))
b.WriteString("\n")
b.WriteString(titleStyle.Render(" ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚══════╝"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(" Press any key to start"))
b.WriteString("\n")
return b.String()
}
func (m *model) updateTitle(msg tea.Msg) (tea.Model, tea.Cmd) {
if _, ok := msg.(tea.KeyMsg); ok {
m.screen = screenGame
var startLevel int
if m.difficulty == "hard" {
startLevel = 5
}
m.game = newGame(startLevel)
return m, tea.Batch(
tea.ClearScreen,
m.scheduleTick(),
logAction(m.sess, "GAME START", fmt.Sprintf("difficulty=%s", m.difficulty)),
)
}
return m, nil
}
// --- Game screen ---
func (m *model) scheduleTick() tea.Cmd {
ms := tickInterval(m.game.level)
if m.difficulty == "easy" {
ms = max(1000-m.game.level*60, 150)
}
return tea.Tick(time.Duration(ms)*time.Millisecond, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func (m *model) scheduleLock() tea.Cmd {
return tea.Tick(lockDelay, func(t time.Time) tea.Msg {
return lockMsg(t)
})
}
// performLock locks the piece, clears lines, and returns commands for logging
// and scheduling the next tick. Returns nil if game over (goToGameOver is
// included in the returned batch).
func (m *model) performLock() tea.Cmd {
m.locking = false
cleared := m.game.afterLock()
if m.game.gameOver {
return tea.Batch(
logAction(m.sess, fmt.Sprintf("GAME OVER score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "GAME OVER"),
m.goToGameOver(),
)
}
var cmds []tea.Cmd
cmds = append(cmds, m.scheduleTick())
if cleared > 0 {
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LINES %d score=%d", cleared, m.game.score), fmt.Sprintf("total=%d", m.game.lines)))
prevLevel := (m.game.lines - cleared) / 10
if m.game.level > prevLevel {
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LEVEL UP %d", m.game.level), fmt.Sprintf("score=%d", m.game.score)))
}
}
return tea.Batch(cmds...)
}
func (m *model) updateGame(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case lockMsg:
if m.game.gameOver || !m.locking {
return m, nil
}
// Lock delay expired — lock the piece now.
return m, m.performLock()
case tickMsg:
if m.game.gameOver || m.locking {
return m, nil
}
if !m.game.moveDown() {
// Piece landed — start lock delay instead of locking immediately.
m.locking = true
return m, m.scheduleLock()
}
return m, m.scheduleTick()
case tea.KeyMsg:
if m.game.gameOver {
return m, nil
}
switch msg.String() {
case "left":
m.game.moveLeft()
// If piece can now drop further, cancel lock delay.
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case "right":
m.game.moveRight()
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case "down":
if m.game.moveDown() {
m.game.score++ // soft drop bonus
if m.locking {
m.locking = false
}
}
case "up", "z":
m.game.rotate()
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
m.locking = false
}
case " ":
m.locking = false
dropped := m.game.hardDrop()
m.game.score += dropped * 2
return m, m.performLock()
case "q":
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
tea.Quit,
)
}
return m, nil
}
return m, nil
}
// --- Game over screen ---
func (m *model) goToGameOver() tea.Cmd {
m.screen = screenGameOver
return tea.ClearScreen
}
func (m *model) gameOverView() string {
var b strings.Builder
b.WriteString("\n")
b.WriteString(titleStyle.Render(" GAME OVER"))
b.WriteString("\n\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Score: %s", formatScore(m.game.score))))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Level: %d", m.game.level)))
b.WriteString("\n")
b.WriteString(baseStyle.Render(fmt.Sprintf(" Lines: %d", m.game.lines)))
b.WriteString("\n\n")
b.WriteString(dimStyle.Render(" R - Play again"))
b.WriteString("\n")
b.WriteString(dimStyle.Render(" Q - Quit"))
b.WriteString("\n")
return b.String()
}
func (m *model) updateGameOver(msg tea.Msg) (tea.Model, tea.Cmd) {
if keyMsg, ok := msg.(tea.KeyMsg); ok {
switch keyMsg.String() {
case "r":
startLevel := 0
if m.difficulty == "hard" {
startLevel = 5
}
m.game = newGame(startLevel)
m.screen = screenGame
m.keypresses = 0
return m, tea.Batch(
tea.ClearScreen,
m.scheduleTick(),
logAction(m.sess, "RESTART", fmt.Sprintf("difficulty=%s", m.difficulty)),
)
case "q":
m.quitting = true
return m, tea.Batch(
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
tea.Quit,
)
}
}
return m, nil
}
// Helper methods for safe access when game may be nil.
func (m *model) gameScore() int {
if m.game != nil {
return m.game.score
}
return 0
}
func (m *model) gameLevel() int {
if m.game != nil {
return m.game.level
}
return 0
}
func (m *model) gameLines() int {
if m.game != nil {
return m.game.lines
}
return 0
}
// logAction returns a tea.Cmd that logs an action to the session store.
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
return func() tea.Msg {
if sess.Store != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
}
if sess.OnCommand != nil {
sess.OnCommand("tetris")
}
return nil
}
}

View File

@@ -0,0 +1,286 @@
package tetris
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
)
const termWidth = 80
var (
colorWhite = lipgloss.Color("#FFFFFF")
colorDim = lipgloss.Color("#555555")
colorBlack = lipgloss.Color("#000000")
colorGhost = lipgloss.Color("#333333")
)
var (
baseStyle = lipgloss.NewStyle().
Foreground(colorWhite).
Background(colorBlack)
dimStyle = lipgloss.NewStyle().
Foreground(colorDim).
Background(colorBlack)
titleStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("#00FFFF")).
Background(colorBlack).
Bold(true)
sidebarLabelStyle = lipgloss.NewStyle().
Foreground(colorDim).
Background(colorBlack)
sidebarValueStyle = lipgloss.NewStyle().
Foreground(colorWhite).
Background(colorBlack).
Bold(true)
)
// cellStyle returns a style for a filled cell of a given piece type.
func cellStyle(pt pieceType) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(pieceColors[pt]).
Background(colorBlack)
}
// ghostStyle returns a dimmed style for the ghost piece.
func ghostCellStyle() lipgloss.Style {
return lipgloss.NewStyle().
Foreground(colorGhost).
Background(colorBlack)
}
// renderBoard renders the board, current piece, and ghost piece as a string.
func renderBoard(g *gameState) string {
// Build a display grid that includes the current piece and ghost.
type displayCell struct {
filled bool
ghost bool
piece pieceType
}
var grid [boardRows][boardCols]displayCell
// Copy locked cells.
for r := range boardRows {
for c := range boardCols {
if g.board[r][c].filled {
grid[r][c] = displayCell{filled: true, piece: g.board[r][c].piece}
}
}
}
// Ghost piece.
ghostR := g.ghostRow()
if ghostR != g.currentRow {
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := ghostR+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols && !grid[r][c].filled {
grid[r][c] = displayCell{ghost: true, piece: g.current}
}
}
}
// Current piece.
shape := pieces[g.current][g.currentRot]
for _, off := range shape {
r, c := g.currentRow+off[0], g.currentCol+off[1]
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
grid[r][c] = displayCell{filled: true, piece: g.current}
}
}
// Render grid.
var b strings.Builder
borderStyle := dimStyle
for _, row := range grid {
b.WriteString(borderStyle.Render("|"))
for _, dc := range row {
switch {
case dc.filled:
b.WriteString(cellStyle(dc.piece).Render("[]"))
case dc.ghost:
b.WriteString(ghostCellStyle().Render("::"))
default:
b.WriteString(baseStyle.Render(" "))
}
}
b.WriteString(borderStyle.Render("|"))
b.WriteString("\n")
}
b.WriteString(borderStyle.Render("+" + strings.Repeat("--", boardCols) + "+"))
return b.String()
}
// renderNextPiece renders the "next piece" preview box.
func renderNextPiece(pt pieceType) string {
shape := pieces[pt][0]
// Determine bounding box.
minR, maxR := shape[0][0], shape[0][0]
minC, maxC := shape[0][1], shape[0][1]
for _, off := range shape {
if off[0] < minR {
minR = off[0]
}
if off[0] > maxR {
maxR = off[0]
}
if off[1] < minC {
minC = off[1]
}
if off[1] > maxC {
maxC = off[1]
}
}
rows := maxR - minR + 1
cols := maxC - minC + 1
// Build a small grid.
grid := make([][]bool, rows)
for i := range grid {
grid[i] = make([]bool, cols)
}
for _, off := range shape {
grid[off[0]-minR][off[1]-minC] = true
}
var b strings.Builder
boxWidth := 8 // chars for the box interior
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
b.WriteString("\n")
for r := range rows {
b.WriteString(dimStyle.Render("|"))
// Center the piece in the box.
pieceWidth := cols * 2
leftPad := (boxWidth - pieceWidth) / 2
rightPad := boxWidth - pieceWidth - leftPad
b.WriteString(baseStyle.Render(strings.Repeat(" ", leftPad)))
for c := range cols {
if grid[r][c] {
b.WriteString(cellStyle(pt).Render("[]"))
} else {
b.WriteString(baseStyle.Render(" "))
}
}
b.WriteString(baseStyle.Render(strings.Repeat(" ", rightPad)))
b.WriteString(dimStyle.Render("|"))
b.WriteString("\n")
}
// Fill remaining rows in the box (max 4 rows for I piece).
for r := rows; r < 2; r++ {
b.WriteString(dimStyle.Render("|"))
b.WriteString(baseStyle.Render(strings.Repeat(" ", boxWidth)))
b.WriteString(dimStyle.Render("|"))
b.WriteString("\n")
}
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
return b.String()
}
// formatScore formats a score with comma separators.
func formatScore(n int) string {
s := fmt.Sprintf("%d", n)
if len(s) <= 3 {
return s
}
var parts []string
for len(s) > 3 {
parts = append([]string{s[len(s)-3:]}, parts...)
s = s[:len(s)-3]
}
parts = append([]string{s}, parts...)
return strings.Join(parts, ",")
}
// gameView combines the board and sidebar into the game screen.
func gameView(g *gameState) string {
boardStr := renderBoard(g)
boardLines := strings.Split(boardStr, "\n")
nextStr := renderNextPiece(g.next)
nextLines := strings.Split(nextStr, "\n")
// Build sidebar lines.
var sidebar []string
sidebar = append(sidebar, sidebarLabelStyle.Render(" NEXT:"))
sidebar = append(sidebar, nextLines...)
sidebar = append(sidebar, "")
sidebar = append(sidebar, sidebarLabelStyle.Render(" SCORE: ")+sidebarValueStyle.Render(formatScore(g.score)))
sidebar = append(sidebar, sidebarLabelStyle.Render(" LEVEL: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.level)))
sidebar = append(sidebar, sidebarLabelStyle.Render(" LINES: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.lines)))
sidebar = append(sidebar, "")
sidebar = append(sidebar, dimStyle.Render(" Controls:"))
sidebar = append(sidebar, dimStyle.Render(" <- -> Move"))
sidebar = append(sidebar, dimStyle.Render(" Up/Z Rotate"))
sidebar = append(sidebar, dimStyle.Render(" Down Soft drop"))
sidebar = append(sidebar, dimStyle.Render(" Space Hard drop"))
sidebar = append(sidebar, dimStyle.Render(" Q Quit"))
// Combine board and sidebar side by side.
var b strings.Builder
maxLines := max(len(boardLines), len(sidebar))
for i := range maxLines {
boardLine := ""
if i < len(boardLines) {
boardLine = boardLines[i]
}
sidebarLine := ""
if i < len(sidebar) {
sidebarLine = sidebar[i]
}
// Pad board to fixed width (| + 10*2 + | = 22 chars visual).
b.WriteString(boardLine)
b.WriteString(sidebarLine)
b.WriteString("\n")
}
return b.String()
}
// padLine pads a single line to termWidth.
func padLine(line string) string {
w := lipgloss.Width(line)
if w >= termWidth {
return line
}
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
}
// padLines pads every line in a multi-line string to termWidth.
func padLines(s string) string {
lines := strings.Split(s, "\n")
for i, line := range lines {
lines[i] = padLine(line)
}
return strings.Join(lines, "\n")
}
// gameFrame wraps content with padding to fill the terminal.
func gameFrame(content string, height int) string {
var b strings.Builder
b.WriteString(content)
// Pad with blank lines to fill terminal height.
if height > 0 {
contentLines := strings.Count(content, "\n") + 1
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
for i := contentLines; i < height; i++ {
b.WriteString(blankLine)
b.WriteString("\n")
}
}
return padLines(b.String())
}

View File

@@ -0,0 +1,66 @@
package tetris
import (
"context"
"io"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 10 * time.Minute
// TetrisShell is a Tetris game TUI for the honeypot.
type TetrisShell struct{}
// NewTetrisShell returns a new TetrisShell instance.
func NewTetrisShell() *TetrisShell {
return &TetrisShell{}
}
func (t *TetrisShell) Name() string { return "tetris" }
func (t *TetrisShell) Description() string { return "Tetris game TUI" }
func (t *TetrisShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
difficulty := configString(sess.ShellConfig, "difficulty", "normal")
m := newModel(sess, difficulty)
p := tea.NewProgram(m,
tea.WithInput(rw),
tea.WithOutput(rw),
tea.WithAltScreen(),
)
done := make(chan error, 1)
go func() {
_, err := p.Run()
done <- err
}()
select {
case err := <-done:
return err
case <-ctx.Done():
p.Quit()
<-done
return nil
}
}
// configString reads a string from the shell config map with a default.
func configString(cfg map[string]any, key, defaultVal string) string {
if cfg == nil {
return defaultVal
}
if v, ok := cfg[key]; ok {
if s, ok := v.(string); ok && s != "" {
return s
}
}
return defaultVal
}

View File

@@ -0,0 +1,582 @@
package tetris
import (
"context"
"strings"
"testing"
"time"
tea "github.com/charmbracelet/bubbletea"
"code.t-juice.club/torjus/oubliette/internal/shell"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// newTestModel creates a model with a test session context.
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
t.Helper()
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "player", "tetris", "")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "player",
Store: store,
}
m := newModel(sess, "normal")
return m, store
}
// sendKey sends a single key message to the model and returns the command.
func sendKey(m *model, key string) tea.Cmd {
var msg tea.KeyMsg
switch key {
case "enter":
msg = tea.KeyMsg{Type: tea.KeyEnter}
case "up":
msg = tea.KeyMsg{Type: tea.KeyUp}
case "down":
msg = tea.KeyMsg{Type: tea.KeyDown}
case "left":
msg = tea.KeyMsg{Type: tea.KeyLeft}
case "right":
msg = tea.KeyMsg{Type: tea.KeyRight}
case "space":
msg = tea.KeyMsg{Type: tea.KeySpace}
case "ctrl+c":
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
default:
if len(key) == 1 {
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}
}
}
_, cmd := m.Update(msg)
return cmd
}
// sendTick sends a tick message to the model.
func sendTick(m *model) tea.Cmd {
_, cmd := m.Update(tickMsg(time.Now()))
return cmd
}
// execCmds recursively executes tea.Cmd functions (including batches).
func execCmds(cmd tea.Cmd) {
if cmd == nil {
return
}
msg := cmd()
if batch, ok := msg.(tea.BatchMsg); ok {
for _, c := range batch {
execCmds(c)
}
}
}
func TestTetrisShellName(t *testing.T) {
sh := NewTetrisShell()
if sh.Name() != "tetris" {
t.Errorf("Name() = %q, want %q", sh.Name(), "tetris")
}
if sh.Description() == "" {
t.Error("Description() should not be empty")
}
}
func TestConfigString(t *testing.T) {
cfg := map[string]any{
"difficulty": "hard",
}
if got := configString(cfg, "difficulty", "normal"); got != "hard" {
t.Errorf("configString() = %q, want %q", got, "hard")
}
if got := configString(cfg, "missing", "normal"); got != "normal" {
t.Errorf("configString() = %q, want %q", got, "normal")
}
if got := configString(nil, "difficulty", "normal"); got != "normal" {
t.Errorf("configString(nil) = %q, want %q", got, "normal")
}
}
func TestTitleScreenRenders(t *testing.T) {
m, _ := newTestModel(t)
view := m.View()
if !strings.Contains(view, "████") {
t.Error("title screen should show TETRIS logo")
}
if !strings.Contains(view, "Press any key") {
t.Error("title screen should show 'Press any key'")
}
}
func TestTitleToGame(t *testing.T) {
m, _ := newTestModel(t)
if m.screen != screenTitle {
t.Fatalf("expected screenTitle, got %d", m.screen)
}
sendKey(m, "enter")
if m.screen != screenGame {
t.Errorf("expected screenGame after keypress, got %d", m.screen)
}
if m.game == nil {
t.Fatal("game should be initialized")
}
}
func TestGameRenders(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
view := m.View()
if !strings.Contains(view, "|") {
t.Error("game view should contain board borders")
}
if !strings.Contains(view, "SCORE") {
t.Error("game view should show SCORE")
}
if !strings.Contains(view, "LEVEL") {
t.Error("game view should show LEVEL")
}
if !strings.Contains(view, "LINES") {
t.Error("game view should show LINES")
}
if !strings.Contains(view, "NEXT") {
t.Error("game view should show NEXT")
}
}
// --- Pure game logic tests ---
func TestNewGame(t *testing.T) {
g := newGame(0)
if g.gameOver {
t.Error("new game should not be game over")
}
if g.score != 0 {
t.Errorf("initial score = %d, want 0", g.score)
}
if g.level != 0 {
t.Errorf("initial level = %d, want 0", g.level)
}
if g.lines != 0 {
t.Errorf("initial lines = %d, want 0", g.lines)
}
}
func TestNewGameHardLevel(t *testing.T) {
g := newGame(5)
if g.level != 5 {
t.Errorf("hard start level = %d, want 5", g.level)
}
}
func TestMoveLeft(t *testing.T) {
g := newGame(0)
startCol := g.currentCol
g.moveLeft()
if g.currentCol != startCol-1 {
t.Errorf("after moveLeft: col = %d, want %d", g.currentCol, startCol-1)
}
}
func TestMoveRight(t *testing.T) {
g := newGame(0)
startCol := g.currentCol
g.moveRight()
if g.currentCol != startCol+1 {
t.Errorf("after moveRight: col = %d, want %d", g.currentCol, startCol+1)
}
}
func TestMoveDown(t *testing.T) {
g := newGame(0)
startRow := g.currentRow
moved := g.moveDown()
if !moved {
t.Error("moveDown should succeed from starting position")
}
if g.currentRow != startRow+1 {
t.Errorf("after moveDown: row = %d, want %d", g.currentRow, startRow+1)
}
}
func TestCannotMoveLeftBeyondWall(t *testing.T) {
g := newGame(0)
// Move all the way left.
for range boardCols {
g.moveLeft()
}
col := g.currentCol
g.moveLeft() // should not move further
if g.currentCol != col {
t.Errorf("should not move past left wall: col = %d, was %d", g.currentCol, col)
}
}
func TestCannotMoveRightBeyondWall(t *testing.T) {
g := newGame(0)
// Move all the way right.
for range boardCols {
g.moveRight()
}
col := g.currentCol
g.moveRight() // should not move further
if g.currentCol != col {
t.Errorf("should not move past right wall: col = %d, was %d", g.currentCol, col)
}
}
func TestRotate(t *testing.T) {
g := newGame(0)
startRot := g.currentRot
g.rotate()
// Rotation should change (possibly with wall kick).
if g.currentRot == startRot {
// Rotation might legitimately fail in some edge cases, so just check
// that the game state is valid.
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
t.Error("piece should be in a valid position after rotate attempt")
}
}
}
func TestHardDrop(t *testing.T) {
g := newGame(0)
startRow := g.currentRow
dropped := g.hardDrop()
if dropped == 0 {
t.Error("hard drop should move piece down at least some rows from top")
}
if g.currentRow <= startRow {
t.Errorf("after hardDrop: row = %d should be > %d", g.currentRow, startRow)
}
}
func TestGhostRow(t *testing.T) {
g := newGame(0)
ghost := g.ghostRow()
if ghost < g.currentRow {
t.Errorf("ghost row %d should be >= current row %d", ghost, g.currentRow)
}
// Ghost should be at a position where moving down one more is impossible.
if g.canPlace(g.current, g.currentRot, ghost+1, g.currentCol) {
t.Error("ghost row should be the lowest valid position")
}
}
func TestLockPiece(t *testing.T) {
g := newGame(0)
g.hardDrop()
pt := g.current
row, col, rot := g.currentRow, g.currentCol, g.currentRot
g.lockPiece()
// Verify that the piece's cells are now filled.
shape := pieces[pt][rot]
for _, off := range shape {
r, c := row+off[0], col+off[1]
if !g.board[r][c].filled {
t.Errorf("cell (%d, %d) should be filled after lockPiece", r, c)
}
}
}
func TestClearLines(t *testing.T) {
g := newGame(0)
// Fill the bottom row completely.
for c := range boardCols {
g.board[boardRows-1][c] = cell{filled: true, piece: pieceI}
}
cleared := g.clearLines()
if cleared != 1 {
t.Errorf("clearLines() = %d, want 1", cleared)
}
// Bottom row should now be empty (shifted from above).
for c := range boardCols {
if g.board[boardRows-1][c].filled {
t.Errorf("bottom row col %d should be empty after clearing", c)
}
}
}
func TestClearMultipleLines(t *testing.T) {
g := newGame(0)
// Fill the bottom 4 rows.
for r := boardRows - 4; r < boardRows; r++ {
for c := range boardCols {
g.board[r][c] = cell{filled: true, piece: pieceI}
}
}
cleared := g.clearLines()
if cleared != 4 {
t.Errorf("clearLines() = %d, want 4", cleared)
}
}
func TestScoring(t *testing.T) {
tests := []struct {
lines int
level int
want int
}{
{1, 0, 40},
{2, 0, 100},
{3, 0, 300},
{4, 0, 1200},
{1, 1, 80},
{4, 2, 3600},
}
for _, tt := range tests {
g := newGame(tt.level)
g.addScore(tt.lines)
if g.score != tt.want {
t.Errorf("score for %d lines at level %d = %d, want %d", tt.lines, tt.level, g.score, tt.want)
}
}
}
func TestLevelUp(t *testing.T) {
g := newGame(0)
g.lines = 9
g.addScore(1) // This should push lines to 10, triggering level 1.
if g.level != 1 {
t.Errorf("level = %d, want 1 after 10 lines", g.level)
}
}
func TestTickInterval(t *testing.T) {
if got := tickInterval(0); got != 800 {
t.Errorf("tickInterval(0) = %d, want 800", got)
}
if got := tickInterval(5); got != 500 {
t.Errorf("tickInterval(5) = %d, want 500", got)
}
// Floor at 100ms.
if got := tickInterval(20); got != 100 {
t.Errorf("tickInterval(20) = %d, want 100", got)
}
}
func TestFormatScore(t *testing.T) {
tests := []struct {
n int
want string
}{
{0, "0"},
{100, "100"},
{1250, "1,250"},
{1000000, "1,000,000"},
}
for _, tt := range tests {
if got := formatScore(tt.n); got != tt.want {
t.Errorf("formatScore(%d) = %q, want %q", tt.n, got, tt.want)
}
}
}
func TestGameOverScreen(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Force game over.
m.game.gameOver = true
m.screen = screenGameOver
view := m.View()
if !strings.Contains(view, "GAME OVER") {
t.Error("game over screen should show GAME OVER")
}
if !strings.Contains(view, "Score") {
t.Error("game over screen should show score")
}
}
func TestRestartFromGameOver(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
m.game.gameOver = true
m.screen = screenGameOver
sendKey(m, "r")
if m.screen != screenGame {
t.Errorf("expected screenGame after restart, got %d", m.screen)
}
if m.game.gameOver {
t.Error("game should not be over after restart")
}
}
func TestQuitFromGame(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
sendKey(m, "q")
if !m.quitting {
t.Error("should be quitting after pressing q")
}
}
func TestQuitFromGameOver(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
m.game.gameOver = true
m.screen = screenGameOver
sendKey(m, "q")
if !m.quitting {
t.Error("should be quitting after pressing q in game over")
}
}
func TestSoftDropScoring(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
scoreBefore := m.game.score
sendKey(m, "down")
if m.game.score != scoreBefore+1 {
t.Errorf("score after soft drop = %d, want %d", m.game.score, scoreBefore+1)
}
}
func TestHardDropScoring(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Hard drop gives 2 points per row dropped.
sendKey(m, "space")
if m.game.score < 2 {
t.Errorf("score after hard drop = %d, should be at least 2", m.game.score)
}
}
func TestTickMovesDown(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
rowBefore := m.game.currentRow
sendTick(m)
// Piece should either move down by 1, or lock and spawn a new piece at top.
movedDown := m.game.currentRow == rowBefore+1
respawned := m.game.currentRow < rowBefore
if !movedDown && !respawned && !m.game.gameOver {
t.Errorf("tick should move piece down or lock+respawn: row was %d, now %d", rowBefore, m.game.currentRow)
}
}
func TestSessionLogs(t *testing.T) {
m, store := newTestModel(t)
// Press key to start game — returns a logAction cmd.
cmd := sendKey(m, "enter")
if cmd != nil {
execCmds(cmd)
}
time.Sleep(50 * time.Millisecond)
found := false
for _, log := range store.SessionLogs {
if strings.Contains(log.Input, "GAME START") {
found = true
break
}
}
if !found {
t.Error("expected GAME START in session logs")
}
}
func TestKeypressCounter(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
sendKey(m, "left")
sendKey(m, "right")
sendKey(m, "down")
if m.keypresses != 4 { // enter + 3 game keys
t.Errorf("keypresses = %d, want 4", m.keypresses)
}
}
func TestLockDelay(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Drop piece to the bottom via ticks until it can't move down.
for range boardRows + 5 {
if m.locking {
break
}
sendTick(m)
}
if !m.locking {
t.Fatal("piece should be in locking state after hitting bottom")
}
// During lock delay, we should still be able to move left/right.
colBefore := m.game.currentCol
sendKey(m, "left")
if m.game.currentCol >= colBefore {
// Might not have moved if against wall, try right.
sendKey(m, "right")
}
// Sending a lockMsg should finalize the piece.
m.Update(lockMsg(time.Now()))
// After lock, a new piece should have spawned (row near top).
if m.game.currentRow > 1 && !m.game.gameOver {
t.Errorf("after lock delay, new piece should spawn near top, got row %d", m.game.currentRow)
}
}
func TestLockDelayCancelledByDrop(t *testing.T) {
m, _ := newTestModel(t)
sendKey(m, "enter") // start game
// Build a ledge: fill rows 18-19 but leave column 0 empty.
for r := boardRows - 2; r < boardRows; r++ {
for c := 1; c < boardCols; c++ {
m.game.board[r][c] = cell{filled: true, piece: pieceI}
}
}
// Move piece to column 0 area and drop it onto the ledge.
for range boardCols {
m.game.moveLeft()
}
// Tick down until locking.
for range boardRows + 5 {
if m.locking {
break
}
sendTick(m)
}
// If piece is on the ledge and we slide it to col 0 (open column),
// the lock delay should cancel since it can fall further.
// This test just validates the locking flag logic works.
if m.locking {
// Try moving — if piece can drop further, locking should cancel.
sendKey(m, "left")
// Whether locking cancels depends on the board state; just verify no crash.
}
}
func TestSpawnCol(t *testing.T) {
// All pieces should spawn roughly centered.
for pt := range pieceType(numPieceTypes) {
col := spawnCol(pt, 0)
if col < 0 || col > boardCols-1 {
t.Errorf("spawnCol(%d, 0) = %d, out of range", pt, col)
}
// Verify piece fits at spawn position.
shape := pieces[pt][0]
for _, off := range shape {
c := col + off[1]
if c < 0 || c >= boardCols {
t.Errorf("piece %d overflows board at spawn: col+offset = %d", pt, c)
}
}
}
}

View File

@@ -0,0 +1,217 @@
package storage
import (
"context"
"time"
"github.com/prometheus/client_golang/prometheus"
)
// InstrumentedStore wraps a Store and records query duration and errors
// as Prometheus metrics for each method call.
type InstrumentedStore struct {
store Store
queryDuration *prometheus.HistogramVec
queryErrors *prometheus.CounterVec
}
// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store.
func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore {
return &InstrumentedStore{
store: store,
queryDuration: queryDuration,
queryErrors: queryErrors,
}
}
func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) {
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
v, err := fn()
timer.ObserveDuration()
if err != nil {
s.queryErrors.WithLabelValues(method).Inc()
}
return v, err
}
func observeErr(s *InstrumentedStore, method string, fn func() error) error {
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
err := fn()
timer.ObserveDuration()
if err != nil {
s.queryErrors.WithLabelValues(method).Inc()
}
return err
}
func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
return observeErr(s, "RecordLoginAttempt", func() error {
return s.store.RecordLoginAttempt(ctx, username, password, ip, country)
})
}
func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
return observe(s, "CreateSession", func() (string, error) {
return s.store.CreateSession(ctx, ip, username, shellName, country)
})
}
func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
return observeErr(s, "EndSession", func() error {
return s.store.EndSession(ctx, sessionID, disconnectedAt)
})
}
func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
return observeErr(s, "UpdateHumanScore", func() error {
return s.store.UpdateHumanScore(ctx, sessionID, score)
})
}
func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
return observeErr(s, "SetExecCommand", func() error {
return s.store.SetExecCommand(ctx, sessionID, command)
})
}
func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
return observeErr(s, "AppendSessionLog", func() error {
return s.store.AppendSessionLog(ctx, sessionID, input, output)
})
}
func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
return observe(s, "DeleteRecordsBefore", func() (int64, error) {
return s.store.DeleteRecordsBefore(ctx, cutoff)
})
}
func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
return observe(s, "GetDashboardStats", func() (*DashboardStats, error) {
return s.store.GetDashboardStats(ctx)
})
}
func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopUsernames", func() ([]TopEntry, error) {
return s.store.GetTopUsernames(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopPasswords", func() ([]TopEntry, error) {
return s.store.GetTopPasswords(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopIPs", func() ([]TopEntry, error) {
return s.store.GetTopIPs(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopCountries", func() ([]TopEntry, error) {
return s.store.GetTopCountries(ctx, limit)
})
}
func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) {
return s.store.GetTopExecCommands(ctx, limit)
})
}
func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
return observe(s, "GetRecentSessions", func() ([]Session, error) {
return s.store.GetRecentSessions(ctx, limit, activeOnly)
})
}
func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
return observe(s, "GetFilteredSessions", func() ([]Session, error) {
return s.store.GetFilteredSessions(ctx, limit, activeOnly, f)
})
}
func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
return observe(s, "GetSession", func() (*Session, error) {
return s.store.GetSession(ctx, sessionID)
})
}
func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
return observe(s, "GetSessionLogs", func() ([]SessionLog, error) {
return s.store.GetSessionLogs(ctx, sessionID)
})
}
func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
return observeErr(s, "AppendSessionEvents", func() error {
return s.store.AppendSessionEvents(ctx, events)
})
}
func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) {
return s.store.GetSessionEvents(ctx, sessionID)
})
}
func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
return observe(s, "CloseActiveSessions", func() (int64, error) {
return s.store.CloseActiveSessions(ctx, disconnectedAt)
})
}
func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) {
return s.store.GetAttemptsOverTime(ctx, days, since, until)
})
}
func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) {
return s.store.GetHourlyPattern(ctx, since, until)
})
}
func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
return observe(s, "GetCountryStats", func() ([]CountryCount, error) {
return s.store.GetCountryStats(ctx)
})
}
func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) {
return s.store.GetFilteredDashboardStats(ctx, f)
})
}
func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) {
return s.store.GetFilteredTopUsernames(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) {
return s.store.GetFilteredTopPasswords(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) {
return s.store.GetFilteredTopIPs(ctx, limit, f)
})
}
func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) {
return s.store.GetFilteredTopCountries(ctx, limit, f)
})
}
func (s *InstrumentedStore) Close() error {
return s.store.Close()
}

View File

@@ -0,0 +1,163 @@
package storage
import (
"context"
"errors"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
)
func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test_query_errors_total",
Help: "test",
}, []string{"method"})
store := NewMemoryStore()
return NewInstrumentedStore(store, dur, errs), dur, errs
}
func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 {
m := &dto.Metric{}
h.WithLabelValues(method).(prometheus.Histogram).Write(m)
return m.GetHistogram().GetSampleCount()
}
func getCounterValue(c *prometheus.CounterVec, method string) float64 {
m := &dto.Metric{}
c.WithLabelValues(method).Write(m)
return m.GetCounter().GetValue()
}
func TestInstrumentedStoreDelegation(t *testing.T) {
s, dur, _ := newTestInstrumented()
ctx := context.Background()
// RecordLoginAttempt should delegate and record duration.
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
// CreateSession should delegate and return a valid ID.
id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if id == "" {
t.Fatal("CreateSession returned empty ID")
}
if c := getHistogramCount(dur, "CreateSession"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
// GetDashboardStats should delegate.
stats, err := s.GetDashboardStats(ctx)
if err != nil {
t.Fatalf("GetDashboardStats: %v", err)
}
if stats == nil {
t.Fatal("GetDashboardStats returned nil")
}
if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
}
func TestInstrumentedStoreErrorCounting(t *testing.T) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test_ec_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test_ec_query_errors_total",
Help: "test",
}, []string{"method"})
es := &errorStore{}
s := NewInstrumentedStore(es, dur, errs)
ctx := context.Background()
// Error should be counted.
err := s.EndSession(ctx, "nonexistent", time.Now())
if !errors.Is(err, errFake) {
t.Fatalf("expected errFake, got %v", err)
}
if c := getHistogramCount(dur, "EndSession"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
if c := getCounterValue(errs, "EndSession"); c != 1 {
t.Fatalf("expected error count 1, got %f", c)
}
// Successful call should not increment error counter.
s2, _, errs2 := newTestInstrumented()
err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if err != nil {
t.Fatalf("RecordLoginAttempt: %v", err)
}
if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 {
t.Fatalf("expected error count 0, got %f", c)
}
}
// errorStore is a Store that returns errors for all methods.
type errorStore struct {
MemoryStore
}
var errFake = errors.New("fake error")
func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error {
return errFake
}
func (s *errorStore) EndSession(context.Context, string, time.Time) error {
return errFake
}
func TestInstrumentedStoreObserveErr(t *testing.T) {
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "test2_query_duration_seconds",
Help: "test",
Buckets: []float64{0.001, 0.01, 0.1, 1},
}, []string{"method"})
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test2_query_errors_total",
Help: "test",
}, []string{"method"})
es := &errorStore{}
s := NewInstrumentedStore(es, dur, errs)
ctx := context.Background()
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
if !errors.Is(err, errFake) {
t.Fatalf("expected errFake, got %v", err)
}
if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected error count 1, got %f", c)
}
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
t.Fatalf("expected 1 observation, got %d", c)
}
}
func TestInstrumentedStoreClose(t *testing.T) {
s, _, _ := newTestInstrumented()
if err := s.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
}

View File

@@ -25,7 +25,7 @@ func NewMemoryStore() *MemoryStore {
}
}
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip string) error {
func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password, ip, country string) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -35,6 +35,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
if a.Username == username && a.Password == password && a.IP == ip {
a.Count++
a.LastSeen = now
a.Country = country
return nil
}
}
@@ -44,6 +45,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
Username: username,
Password: password,
IP: ip,
Country: country,
Count: 1,
FirstSeen: now,
LastSeen: now,
@@ -51,7 +53,7 @@ func (m *MemoryStore) RecordLoginAttempt(_ context.Context, username, password,
return nil
}
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName string) (string, error) {
func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName, country string) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -60,6 +62,7 @@ func (m *MemoryStore) CreateSession(_ context.Context, ip, username, shellName s
m.Sessions[id] = &Session{
ID: id,
IP: ip,
Country: country,
Username: username,
ShellName: shellName,
ConnectedAt: now,
@@ -88,6 +91,16 @@ func (m *MemoryStore) UpdateHumanScore(_ context.Context, sessionID string, scor
return nil
}
func (m *MemoryStore) SetExecCommand(_ context.Context, sessionID string, command string) error {
m.mu.Lock()
defer m.mu.Unlock()
if s, ok := m.Sessions[sessionID]; ok {
s.ExecCommand = &command
}
return nil
}
func (m *MemoryStore) AppendSessionLog(_ context.Context, sessionID, input, output string) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -234,7 +247,60 @@ func (m *MemoryStore) GetTopPasswords(_ context.Context, limit int) ([]TopEntry,
func (m *MemoryStore) GetTopIPs(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.topN("ip", limit), nil
type ipInfo struct {
count int64
country string
}
agg := make(map[string]*ipInfo)
for _, a := range m.LoginAttempts {
info, ok := agg[a.IP]
if !ok {
info = &ipInfo{}
agg[a.IP] = info
}
info.count += int64(a.Count)
if a.Country != "" {
info.country = a.Country
}
}
entries := make([]TopEntry, 0, len(agg))
for ip, info := range agg {
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
func (m *MemoryStore) GetTopCountries(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.Country == "" {
continue
}
counts[a.Country] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
// topN aggregates login attempts by the given field and returns the top N. Must be called with m.mu held.
@@ -270,20 +336,105 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly
m.mu.Lock()
defer m.mu.Unlock()
return m.collectSessions(limit, activeOnly, DashboardFilter{}), nil
}
func (m *MemoryStore) GetFilteredSessions(_ context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.collectSessions(limit, activeOnly, f), nil
}
// collectSessions gathers sessions matching filter criteria. Must be called with m.mu held.
func (m *MemoryStore) collectSessions(limit int, activeOnly bool, f DashboardFilter) []Session {
// Compute event counts and input bytes per session.
eventCounts := make(map[string]int)
inputBytes := make(map[string]int64)
for _, e := range m.SessionEvents {
eventCounts[e.SessionID]++
if e.Direction == 0 {
inputBytes[e.SessionID] += int64(len(e.Data))
}
}
var sessions []Session
for _, s := range m.Sessions {
if activeOnly && s.DisconnectedAt != nil {
continue
}
sessions = append(sessions, *s)
if !matchesSessionFilter(s, f) {
continue
}
sess := *s
sess.EventCount = eventCounts[s.ID]
sess.InputBytes = inputBytes[s.ID]
sessions = append(sessions, sess)
}
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
})
if f.SortBy == "input_bytes" {
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].InputBytes > sessions[j].InputBytes
})
} else {
sort.Slice(sessions, func(i, j int) bool {
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
})
}
if limit > 0 && len(sessions) > limit {
sessions = sessions[:limit]
}
return sessions, nil
return sessions
}
// matchesSessionFilter returns true if the session matches the given filter.
func matchesSessionFilter(s *Session, f DashboardFilter) bool {
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
return false
}
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
return false
}
if f.IP != "" && s.IP != f.IP {
return false
}
if f.Country != "" && s.Country != f.Country {
return false
}
if f.Username != "" && s.Username != f.Username {
return false
}
if f.HumanScoreAboveZero {
if s.HumanScore == nil || *s.HumanScore <= 0 {
return false
}
}
return true
}
func (m *MemoryStore) GetTopExecCommands(_ context.Context, limit int) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for _, s := range m.Sessions {
if s.ExecCommand != nil {
counts[*s.ExecCommand]++
}
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time.Time) (int64, error) {
@@ -301,6 +452,258 @@ func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time
return count, nil
}
func (m *MemoryStore) GetAttemptsOverTime(_ context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
m.mu.Lock()
defer m.mu.Unlock()
var cutoff time.Time
if since != nil {
cutoff = *since
} else {
cutoff = time.Now().UTC().AddDate(0, 0, -days)
}
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.LastSeen.Before(cutoff) {
continue
}
if until != nil && a.LastSeen.After(*until) {
continue
}
day := a.LastSeen.Format("2006-01-02")
counts[day] += int64(a.Count)
}
points := make([]TimeSeriesPoint, 0, len(counts))
for day, count := range counts {
t, _ := time.Parse("2006-01-02", day)
points = append(points, TimeSeriesPoint{Timestamp: t, Count: count})
}
sort.Slice(points, func(i, j int) bool {
return points[i].Timestamp.Before(points[j].Timestamp)
})
return points, nil
}
func (m *MemoryStore) GetHourlyPattern(_ context.Context, since, until *time.Time) ([]HourlyCount, error) {
m.mu.Lock()
defer m.mu.Unlock()
hourCounts := make(map[int]int64)
for _, a := range m.LoginAttempts {
if since != nil && a.LastSeen.Before(*since) {
continue
}
if until != nil && a.LastSeen.After(*until) {
continue
}
hour := a.LastSeen.Hour()
hourCounts[hour] += int64(a.Count)
}
counts := make([]HourlyCount, 0, len(hourCounts))
for h, c := range hourCounts {
counts = append(counts, HourlyCount{Hour: h, Count: c})
}
sort.Slice(counts, func(i, j int) bool {
return counts[i].Hour < counts[j].Hour
})
return counts, nil
}
func (m *MemoryStore) GetCountryStats(_ context.Context) ([]CountryCount, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for _, a := range m.LoginAttempts {
if a.Country == "" {
continue
}
counts[a.Country] += int64(a.Count)
}
result := make([]CountryCount, 0, len(counts))
for country, count := range counts {
result = append(result, CountryCount{Country: country, Count: count})
}
sort.Slice(result, func(i, j int) bool {
return result[i].Count > result[j].Count
})
return result, nil
}
// matchesFilter returns true if the login attempt matches the given filter. Must be called with m.mu held.
func matchesFilter(a *LoginAttempt, f DashboardFilter) bool {
if f.Since != nil && a.LastSeen.Before(*f.Since) {
return false
}
if f.Until != nil && a.LastSeen.After(*f.Until) {
return false
}
if f.IP != "" && a.IP != f.IP {
return false
}
if f.Country != "" && a.Country != f.Country {
return false
}
if f.Username != "" && a.Username != f.Username {
return false
}
return true
}
func (m *MemoryStore) GetFilteredDashboardStats(_ context.Context, f DashboardFilter) (*DashboardStats, error) {
m.mu.Lock()
defer m.mu.Unlock()
stats := &DashboardStats{}
ips := make(map[string]struct{})
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
stats.TotalAttempts += int64(a.Count)
ips[a.IP] = struct{}{}
}
stats.UniqueIPs = int64(len(ips))
for _, s := range m.Sessions {
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
continue
}
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
continue
}
if f.IP != "" && s.IP != f.IP {
continue
}
if f.Country != "" && s.Country != f.Country {
continue
}
stats.TotalSessions++
if s.DisconnectedAt == nil {
stats.ActiveSessions++
}
}
return stats, nil
}
func (m *MemoryStore) GetFilteredTopUsernames(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.filteredTopN("username", limit, f), nil
}
func (m *MemoryStore) GetFilteredTopPasswords(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.filteredTopN("password", limit, f), nil
}
func (m *MemoryStore) GetFilteredTopIPs(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
type ipInfo struct {
count int64
country string
}
agg := make(map[string]*ipInfo)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
info, ok := agg[a.IP]
if !ok {
info = &ipInfo{}
agg[a.IP] = info
}
info.count += int64(a.Count)
if a.Country != "" {
info.country = a.Country
}
}
entries := make([]TopEntry, 0, len(agg))
for ip, info := range agg {
entries = append(entries, TopEntry{Value: ip, Country: info.country, Count: info.count})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
func (m *MemoryStore) GetFilteredTopCountries(_ context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
m.mu.Lock()
defer m.mu.Unlock()
counts := make(map[string]int64)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if a.Country == "" {
continue
}
if !matchesFilter(a, f) {
continue
}
counts[a.Country] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries, nil
}
// filteredTopN aggregates login attempts by the given field with filter applied and returns the top N. Must be called with m.mu held.
func (m *MemoryStore) filteredTopN(field string, limit int, f DashboardFilter) []TopEntry {
counts := make(map[string]int64)
for i := range m.LoginAttempts {
a := &m.LoginAttempts[i]
if !matchesFilter(a, f) {
continue
}
var key string
switch field {
case "username":
key = a.Username
case "password":
key = a.Password
case "ip":
key = a.IP
}
counts[key] += int64(a.Count)
}
entries := make([]TopEntry, 0, len(counts))
for k, v := range counts {
entries = append(entries, TopEntry{Value: k, Count: v})
}
sort.Slice(entries, func(i, j int) bool {
return entries[i].Count > entries[j].Count
})
if limit > 0 && len(entries) > limit {
entries = entries[:limit]
}
return entries
}
func (m *MemoryStore) Close() error {
return nil
}

View File

@@ -0,0 +1,3 @@
ALTER TABLE login_attempts ADD COLUMN country TEXT NOT NULL DEFAULT '';
ALTER TABLE sessions ADD COLUMN country TEXT NOT NULL DEFAULT '';
CREATE INDEX idx_login_attempts_country ON login_attempts(country);

View File

@@ -0,0 +1 @@
ALTER TABLE sessions ADD COLUMN exec_command TEXT;

View File

@@ -0,0 +1,3 @@
CREATE INDEX idx_login_attempts_username ON login_attempts(username);
CREATE INDEX idx_login_attempts_password ON login_attempts(password);
CREATE INDEX idx_sessions_disconnected_at ON sessions(disconnected_at);

View File

@@ -25,8 +25,8 @@ func TestMigrateCreatesTablesAndVersion(t *testing.T) {
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
t.Fatalf("query version: %v", err)
}
if version != 2 {
t.Errorf("version = %d, want 2", version)
if version != 5 {
t.Errorf("version = %d, want 5", version)
}
// Verify tables exist by inserting into them.
@@ -64,8 +64,8 @@ func TestMigrateIdempotent(t *testing.T) {
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
t.Fatalf("query version: %v", err)
}
if version != 2 {
t.Errorf("version = %d after double migrate, want 2", version)
if version != 5 {
t.Errorf("version = %d after double migrate, want 5", version)
}
}

View File

@@ -22,7 +22,7 @@ func TestRunRetentionDeletesOldRecords(t *testing.T) {
}
// Insert a recent login attempt.
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil {
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
t.Fatalf("insert recent attempt: %v", err)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/google/uuid"
@@ -34,28 +35,29 @@ func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
return &SQLiteStore{db: db}, nil
}
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip string) error {
func (s *SQLiteStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
now := time.Now().UTC().Format(time.RFC3339)
_, err := s.db.ExecContext(ctx, `
INSERT INTO login_attempts (username, password, ip, count, first_seen, last_seen)
VALUES (?, ?, ?, 1, ?, ?)
INSERT INTO login_attempts (username, password, ip, country, count, first_seen, last_seen)
VALUES (?, ?, ?, ?, 1, ?, ?)
ON CONFLICT(username, password, ip) DO UPDATE SET
count = count + 1,
last_seen = ?`,
username, password, ip, now, now, now)
last_seen = ?,
country = ?`,
username, password, ip, country, now, now, now, country)
if err != nil {
return fmt.Errorf("recording login attempt: %w", err)
}
return nil
}
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName string) (string, error) {
func (s *SQLiteStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
id := uuid.New().String()
now := time.Now().UTC().Format(time.RFC3339)
_, err := s.db.ExecContext(ctx, `
INSERT INTO sessions (id, ip, username, shell_name, connected_at)
VALUES (?, ?, ?, ?, ?)`,
id, ip, username, shellName, now)
INSERT INTO sessions (id, ip, username, shell_name, country, connected_at)
VALUES (?, ?, ?, ?, ?, ?)`,
id, ip, username, shellName, country, now)
if err != nil {
return "", fmt.Errorf("creating session: %w", err)
}
@@ -82,6 +84,16 @@ func (s *SQLiteStore) UpdateHumanScore(ctx context.Context, sessionID string, sc
return nil
}
func (s *SQLiteStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
_, err := s.db.ExecContext(ctx, `
UPDATE sessions SET exec_command = ? WHERE id = ?`,
command, sessionID)
if err != nil {
return fmt.Errorf("setting exec command: %w", err)
}
return nil
}
func (s *SQLiteStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
now := time.Now().UTC().Format(time.RFC3339)
_, err := s.db.ExecContext(ctx, `
@@ -99,12 +111,13 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
var connectedAt string
var disconnectedAt sql.NullString
var humanScore sql.NullFloat64
var execCommand sql.NullString
err := s.db.QueryRowContext(ctx, `
SELECT id, ip, username, shell_name, connected_at, disconnected_at, human_score
SELECT id, ip, country, username, shell_name, connected_at, disconnected_at, human_score, exec_command
FROM sessions WHERE id = ?`, sessionID).Scan(
&sess.ID, &sess.IP, &sess.Username, &sess.ShellName,
&connectedAt, &disconnectedAt, &humanScore,
&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName,
&connectedAt, &disconnectedAt, &humanScore, &execCommand,
)
if err == sql.ErrNoRows {
return nil, nil
@@ -121,6 +134,9 @@ func (s *SQLiteStore) GetSession(ctx context.Context, sessionID string) (*Sessio
if humanScore.Valid {
sess.HumanScore = &humanScore.Float64
}
if execCommand.Valid {
sess.ExecCommand = &execCommand.String
}
return &sess, nil
}
@@ -288,10 +304,60 @@ func (s *SQLiteStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntr
}
func (s *SQLiteStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
return s.queryTopN(ctx, "ip", limit)
rows, err := s.db.QueryContext(ctx, `
SELECT ip, country, SUM(count) AS total
FROM login_attempts
GROUP BY ip
ORDER BY total DESC
LIMIT ?`, limit)
if err != nil {
return nil, fmt.Errorf("querying top IPs: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
return nil, fmt.Errorf("scanning top IPs: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT country, SUM(count) AS total
FROM login_attempts
WHERE country != ''
GROUP BY country
ORDER BY total DESC
LIMIT ?`, limit)
if err != nil {
return nil, fmt.Errorf("querying top countries: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning top countries: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ([]TopEntry, error) {
switch column {
case "username", "password", "ip":
// valid columns
default:
return nil, fmt.Errorf("invalid column: %s", column)
}
query := fmt.Sprintf(`
SELECT %s, SUM(count) AS total
FROM login_attempts
@@ -317,40 +383,132 @@ func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) (
}
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
query := `SELECT id, ip, username, shell_name, connected_at, disconnected_at, human_score FROM sessions`
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
if activeOnly {
query += ` WHERE disconnected_at IS NULL`
query += ` WHERE s.disconnected_at IS NULL`
}
query += ` ORDER BY connected_at DESC LIMIT ?`
query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, limit)
return s.scanSessions(ctx, query, limit)
}
// buildSessionWhereClause builds a dynamic WHERE clause for session filtering.
func buildSessionWhereClause(f DashboardFilter, activeOnly bool) (string, []any) {
var clauses []string
var args []any
if activeOnly {
clauses = append(clauses, "s.disconnected_at IS NULL")
}
if f.Since != nil {
clauses = append(clauses, "s.connected_at >= ?")
args = append(args, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
clauses = append(clauses, "s.connected_at <= ?")
args = append(args, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
clauses = append(clauses, "s.ip = ?")
args = append(args, f.IP)
}
if f.Country != "" {
clauses = append(clauses, "s.country = ?")
args = append(args, f.Country)
}
if f.Username != "" {
clauses = append(clauses, "s.username = ?")
args = append(args, f.Username)
}
if f.HumanScoreAboveZero {
clauses = append(clauses, "s.human_score > 0")
}
if len(clauses) == 0 {
return "", nil
}
return " WHERE " + strings.Join(clauses, " AND "), args
}
// validSessionSorts maps allowed SortBy values to SQL ORDER BY clauses.
var validSessionSorts = map[string]string{
"connected_at": "s.connected_at DESC",
"input_bytes": "input_bytes DESC",
}
func (s *SQLiteStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
where, args := buildSessionWhereClause(f, activeOnly)
args = append(args, limit)
orderBy := validSessionSorts["connected_at"]
if mapped, ok := validSessionSorts[f.SortBy]; ok {
orderBy = mapped
}
//nolint:gosec // where/order clauses built from allowlisted constants, not raw user input
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + where + ` GROUP BY s.id ORDER BY ` + orderBy + ` LIMIT ?`
return s.scanSessions(ctx, query, args...)
}
// scanSessions executes a session query and scans the results.
func (s *SQLiteStore) scanSessions(ctx context.Context, query string, args ...any) ([]Session, error) {
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying recent sessions: %w", err)
return nil, fmt.Errorf("querying sessions: %w", err)
}
defer func() { _ = rows.Close() }()
var sessions []Session
for rows.Next() {
var s Session
var sess Session
var connectedAt string
var disconnectedAt sql.NullString
var humanScore sql.NullFloat64
if err := rows.Scan(&s.ID, &s.IP, &s.Username, &s.ShellName, &connectedAt, &disconnectedAt, &humanScore); err != nil {
var execCommand sql.NullString
if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &sess.EventCount, &sess.InputBytes); err != nil {
return nil, fmt.Errorf("scanning session: %w", err)
}
s.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
if disconnectedAt.Valid {
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
s.DisconnectedAt = &t
sess.DisconnectedAt = &t
}
if humanScore.Valid {
s.HumanScore = &humanScore.Float64
sess.HumanScore = &humanScore.Float64
}
sessions = append(sessions, s)
if execCommand.Valid {
sess.ExecCommand = &execCommand.String
}
sessions = append(sessions, sess)
}
return sessions, rows.Err()
}
func (s *SQLiteStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT exec_command, COUNT(*) as total
FROM sessions
WHERE exec_command IS NOT NULL
GROUP BY exec_command
ORDER BY total DESC
LIMIT ?`, limit)
if err != nil {
return nil, fmt.Errorf("querying top exec commands: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning top exec commands: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
res, err := s.db.ExecContext(ctx, `
UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`,
@@ -361,6 +519,265 @@ func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt ti
return res.RowsAffected()
}
func (s *SQLiteStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
query := `SELECT DATE(last_seen) AS d, SUM(count) FROM login_attempts WHERE 1=1`
var args []any
if since != nil {
query += ` AND last_seen >= ?`
args = append(args, since.UTC().Format(time.RFC3339))
} else {
query += ` AND last_seen >= ?`
args = append(args, time.Now().UTC().AddDate(0, 0, -days).Format("2006-01-02"))
}
if until != nil {
query += ` AND last_seen <= ?`
args = append(args, until.UTC().Format(time.RFC3339))
}
query += ` GROUP BY d ORDER BY d`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying attempts over time: %w", err)
}
defer func() { _ = rows.Close() }()
var points []TimeSeriesPoint
for rows.Next() {
var dateStr string
var p TimeSeriesPoint
if err := rows.Scan(&dateStr, &p.Count); err != nil {
return nil, fmt.Errorf("scanning time series point: %w", err)
}
p.Timestamp, _ = time.Parse("2006-01-02", dateStr)
points = append(points, p)
}
return points, rows.Err()
}
func (s *SQLiteStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
query := `SELECT CAST(STRFTIME('%H', last_seen) AS INTEGER) AS h, SUM(count) FROM login_attempts WHERE 1=1`
var args []any
if since != nil {
query += ` AND last_seen >= ?`
args = append(args, since.UTC().Format(time.RFC3339))
}
if until != nil {
query += ` AND last_seen <= ?`
args = append(args, until.UTC().Format(time.RFC3339))
}
query += ` GROUP BY h ORDER BY h`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying hourly pattern: %w", err)
}
defer func() { _ = rows.Close() }()
var counts []HourlyCount
for rows.Next() {
var c HourlyCount
if err := rows.Scan(&c.Hour, &c.Count); err != nil {
return nil, fmt.Errorf("scanning hourly count: %w", err)
}
counts = append(counts, c)
}
return counts, rows.Err()
}
func (s *SQLiteStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
rows, err := s.db.QueryContext(ctx, `
SELECT country, SUM(count) AS total
FROM login_attempts
WHERE country != ''
GROUP BY country
ORDER BY total DESC`)
if err != nil {
return nil, fmt.Errorf("querying country stats: %w", err)
}
defer func() { _ = rows.Close() }()
var counts []CountryCount
for rows.Next() {
var c CountryCount
if err := rows.Scan(&c.Country, &c.Count); err != nil {
return nil, fmt.Errorf("scanning country count: %w", err)
}
counts = append(counts, c)
}
return counts, rows.Err()
}
// buildAttemptWhereClause builds a dynamic WHERE clause for login_attempts filtering.
func buildAttemptWhereClause(f DashboardFilter) (string, []any) {
var clauses []string
var args []any
if f.Since != nil {
clauses = append(clauses, "last_seen >= ?")
args = append(args, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
clauses = append(clauses, "last_seen <= ?")
args = append(args, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
clauses = append(clauses, "ip = ?")
args = append(args, f.IP)
}
if f.Country != "" {
clauses = append(clauses, "country = ?")
args = append(args, f.Country)
}
if f.Username != "" {
clauses = append(clauses, "username = ?")
args = append(args, f.Username)
}
if len(clauses) == 0 {
return "", nil
}
return " WHERE " + strings.Join(clauses, " AND "), args
}
func (s *SQLiteStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
where, args := buildAttemptWhereClause(f)
stats := &DashboardStats{}
err := s.db.QueryRowContext(ctx,
`SELECT COALESCE(SUM(count), 0), COUNT(DISTINCT ip) FROM login_attempts`+where, args...).
Scan(&stats.TotalAttempts, &stats.UniqueIPs)
if err != nil {
return nil, fmt.Errorf("querying filtered attempt stats: %w", err)
}
// Sessions don't have username/password, so only filter by time, IP, country.
sessQuery := `SELECT COUNT(*) FROM sessions WHERE 1=1`
var sessArgs []any
if f.Since != nil {
sessQuery += ` AND connected_at >= ?`
sessArgs = append(sessArgs, f.Since.UTC().Format(time.RFC3339))
}
if f.Until != nil {
sessQuery += ` AND connected_at <= ?`
sessArgs = append(sessArgs, f.Until.UTC().Format(time.RFC3339))
}
if f.IP != "" {
sessQuery += ` AND ip = ?`
sessArgs = append(sessArgs, f.IP)
}
if f.Country != "" {
sessQuery += ` AND country = ?`
sessArgs = append(sessArgs, f.Country)
}
err = s.db.QueryRowContext(ctx, sessQuery, sessArgs...).Scan(&stats.TotalSessions)
if err != nil {
return nil, fmt.Errorf("querying filtered total sessions: %w", err)
}
err = s.db.QueryRowContext(ctx, sessQuery+` AND disconnected_at IS NULL`, sessArgs...).Scan(&stats.ActiveSessions)
if err != nil {
return nil, fmt.Errorf("querying filtered active sessions: %w", err)
}
return stats, nil
}
func (s *SQLiteStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return s.queryFilteredTopN(ctx, "username", limit, f)
}
func (s *SQLiteStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
return s.queryFilteredTopN(ctx, "password", limit, f)
}
func (s *SQLiteStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
where, args := buildAttemptWhereClause(f)
args = append(args, limit)
//nolint:gosec // where clause built from trusted constants, not user input
query := `SELECT ip, country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY ip ORDER BY total DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top IPs: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Country, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top IPs: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
where, args := buildAttemptWhereClause(f)
countryClause := "country != ''"
if where == "" {
where = " WHERE " + countryClause
} else {
where += " AND " + countryClause
}
args = append(args, limit)
//nolint:gosec // where clause built from trusted constants, not user input
query := `SELECT country, SUM(count) AS total FROM login_attempts` + where + ` GROUP BY country ORDER BY total DESC LIMIT ?`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top countries: %w", err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top countries: %w", err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) queryFilteredTopN(ctx context.Context, column string, limit int, f DashboardFilter) ([]TopEntry, error) {
switch column {
case "username", "password":
// valid columns
default:
return nil, fmt.Errorf("invalid column: %s", column)
}
where, args := buildAttemptWhereClause(f)
args = append(args, limit)
query := fmt.Sprintf(`
SELECT %s, SUM(count) AS total
FROM login_attempts%s
GROUP BY %s
ORDER BY total DESC
LIMIT ?`, column, where, column)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, fmt.Errorf("querying filtered top %s: %w", column, err)
}
defer func() { _ = rows.Close() }()
var entries []TopEntry
for rows.Next() {
var e TopEntry
if err := rows.Scan(&e.Value, &e.Count); err != nil {
return nil, fmt.Errorf("scanning filtered top %s: %w", column, err)
}
entries = append(entries, e)
}
return entries, rows.Err()
}
func (s *SQLiteStore) Close() error {
return s.db.Close()
}

View File

@@ -23,17 +23,17 @@ func TestRecordLoginAttempt(t *testing.T) {
ctx := context.Background()
// First attempt creates a new record.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("first attempt: %v", err)
}
// Second attempt with same credentials increments count.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("second attempt: %v", err)
}
// Different IP is a separate record.
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
t.Fatalf("different IP: %v", err)
}
@@ -62,7 +62,7 @@ func TestCreateAndEndSession(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
@@ -100,7 +100,7 @@ func TestUpdateHumanScore(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
@@ -123,7 +123,7 @@ func TestAppendSessionLog(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
@@ -159,7 +159,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
}
// Insert a recent login attempt.
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2"); err != nil {
if err := store.RecordLoginAttempt(ctx, "new", "new", "2.2.2.2", ""); err != nil {
t.Fatalf("insert recent attempt: %v", err)
}
@@ -178,7 +178,7 @@ func TestDeleteRecordsBefore(t *testing.T) {
}
// Insert a recent session.
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", ""); err != nil {
if _, err := store.CreateSession(ctx, "2.2.2.2", "new", "", ""); err != nil {
t.Fatalf("insert recent session: %v", err)
}
@@ -204,6 +204,79 @@ func TestDeleteRecordsBefore(t *testing.T) {
}
}
func TestGetTopExecCommands(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
// Create sessions with exec commands.
for range 3 {
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
t.Fatalf("setting exec command: %v", err)
}
}
for range 2 {
id, err := store.CreateSession(ctx, "10.0.0.2", "admin", "", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
if err := store.SetExecCommand(ctx, id, "cat /etc/passwd"); err != nil {
t.Fatalf("setting exec command: %v", err)
}
}
// Session without exec command — should not appear.
if _, err := store.CreateSession(ctx, "10.0.0.3", "test", "bash", ""); err != nil {
t.Fatalf("creating session: %v", err)
}
entries, err := store.GetTopExecCommands(ctx, 10)
if err != nil {
t.Fatalf("GetTopExecCommands: %v", err)
}
if len(entries) != 2 {
t.Fatalf("len = %d, want 2", len(entries))
}
if entries[0].Value != "uname -a" || entries[0].Count != 3 {
t.Errorf("entries[0] = %+v, want uname -a:3", entries[0])
}
if entries[1].Value != "cat /etc/passwd" || entries[1].Count != 2 {
t.Errorf("entries[1] = %+v, want cat /etc/passwd:2", entries[1])
}
}
func TestGetRecentSessionsEventCount(t *testing.T) {
store := newTestStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
// Add some events.
events := []SessionEvent{
{SessionID: id, Timestamp: time.Now(), Direction: 0, Data: []byte("ls\n")},
{SessionID: id, Timestamp: time.Now(), Direction: 1, Data: []byte("file1\n")},
}
if err := store.AppendSessionEvents(ctx, events); err != nil {
t.Fatalf("appending events: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].EventCount != 2 {
t.Errorf("EventCount = %d, want 2", sessions[0].EventCount)
}
}
func TestNewSQLiteStoreCreatesFile(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "test.db")
store, err := NewSQLiteStore(dbPath)
@@ -214,7 +287,7 @@ func TestNewSQLiteStoreCreatesFile(t *testing.T) {
// Verify we can use the store.
ctx := context.Background()
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "test", "test", "127.0.0.1", ""); err != nil {
t.Fatalf("recording attempt: %v", err)
}
}

View File

@@ -11,6 +11,7 @@ type LoginAttempt struct {
Username string
Password string
IP string
Country string
Count int
FirstSeen time.Time
LastSeen time.Time
@@ -20,11 +21,15 @@ type LoginAttempt struct {
type Session struct {
ID string
IP string
Country string
Username string
ShellName string
ConnectedAt time.Time
DisconnectedAt *time.Time
HumanScore *float64
ExecCommand *string
EventCount int
InputBytes int64
}
// SessionLog represents a single log entry for a session.
@@ -52,20 +57,50 @@ type DashboardStats struct {
ActiveSessions int64
}
// TimeSeriesPoint represents a single data point in a time series.
type TimeSeriesPoint struct {
Timestamp time.Time
Count int64
}
// HourlyCount represents the total attempts for a given hour of day.
type HourlyCount struct {
Hour int // 0-23
Count int64
}
// CountryCount represents the total attempts from a given country.
type CountryCount struct {
Country string
Count int64
}
// DashboardFilter contains optional filters for dashboard queries.
type DashboardFilter struct {
Since *time.Time
Until *time.Time
IP string
Country string
Username string
HumanScoreAboveZero bool
SortBy string
}
// TopEntry represents a value and its count for top-N queries.
type TopEntry struct {
Value string
Count int64
Value string
Country string // populated by GetTopIPs
Count int64
}
// Store is the interface for persistent storage of honeypot data.
type Store interface {
// RecordLoginAttempt upserts a login attempt, incrementing the count
// for existing (username, password, ip) combinations.
RecordLoginAttempt(ctx context.Context, username, password, ip string) error
RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error
// CreateSession creates a new session record and returns its UUID.
CreateSession(ctx context.Context, ip, username, shellName string) (string, error)
CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error)
// EndSession sets the disconnected_at timestamp for a session.
EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error
@@ -73,6 +108,9 @@ type Store interface {
// UpdateHumanScore sets the human detection score for a session.
UpdateHumanScore(ctx context.Context, sessionID string, score float64) error
// SetExecCommand sets the exec command for a session.
SetExecCommand(ctx context.Context, sessionID string, command string) error
// AppendSessionLog adds a log entry to a session.
AppendSessionLog(ctx context.Context, sessionID, input, output string) error
@@ -92,10 +130,20 @@ type Store interface {
// GetTopIPs returns the top N IPs by total attempt count.
GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error)
// GetTopCountries returns the top N countries by total attempt count.
GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error)
// GetTopExecCommands returns the top N exec commands by session count.
GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error)
// GetRecentSessions returns the most recent sessions ordered by connected_at DESC.
// If activeOnly is true, only sessions with no disconnected_at are returned.
GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error)
// GetFilteredSessions returns sessions matching the given filter, ordered
// by the filter's SortBy field (default: connected_at DESC).
GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error)
// GetSession returns a single session by ID.
GetSession(ctx context.Context, sessionID string) (*Session, error)
@@ -113,6 +161,30 @@ type Store interface {
// sessions left over from a previous unclean shutdown.
CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error)
// GetAttemptsOverTime returns daily attempt counts for the last N days.
GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error)
// GetHourlyPattern returns total attempts grouped by hour of day (0-23).
GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error)
// GetCountryStats returns total attempts per country, ordered by count DESC.
GetCountryStats(ctx context.Context) ([]CountryCount, error)
// GetFilteredDashboardStats returns aggregate counts with optional filters applied.
GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error)
// GetFilteredTopUsernames returns top usernames with optional filters applied.
GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopPasswords returns top passwords with optional filters applied.
GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopIPs returns top IPs with optional filters applied.
GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// GetFilteredTopCountries returns top countries with optional filters applied.
GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error)
// Close releases any resources held by the store.
Close() error
}

View File

@@ -38,23 +38,23 @@ func seedData(t *testing.T, store Store) {
// Login attempts: root/toor from two IPs, admin/admin from one IP.
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 3 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.2", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 2 {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
// Sessions: one active, one ended.
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
@@ -62,7 +62,7 @@ func seedData(t *testing.T, store Store) {
t.Fatalf("ending session: %v", err)
}
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash"); err != nil {
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
t.Fatalf("creating session: %v", err)
}
}
@@ -210,7 +210,7 @@ func TestGetSession(t *testing.T) {
t.Run("found", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
@@ -233,7 +233,7 @@ func TestGetSessionLogs(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
@@ -277,7 +277,7 @@ func TestSessionEvents(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
@@ -336,9 +336,9 @@ func TestCloseActiveSessions(t *testing.T) {
ctx := context.Background()
// Create 3 sessions: end one, leave two active.
id1, _ := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
store.CreateSession(ctx, "10.0.0.2", "admin", "bash")
store.CreateSession(ctx, "10.0.0.3", "test", "bash")
id1, _ := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
store.CreateSession(ctx, "10.0.0.3", "test", "bash", "")
store.EndSession(ctx, id1, time.Now())
n, err := store.CloseActiveSessions(ctx, time.Now())
@@ -361,6 +361,289 @@ func TestCloseActiveSessions(t *testing.T) {
})
}
func TestSetExecCommand(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("set and retrieve", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
// Initially nil.
s, err := store.GetSession(ctx, id)
if err != nil {
t.Fatalf("GetSession: %v", err)
}
if s.ExecCommand != nil {
t.Errorf("expected nil ExecCommand, got %q", *s.ExecCommand)
}
// Set exec command.
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
t.Fatalf("SetExecCommand: %v", err)
}
s, err = store.GetSession(ctx, id)
if err != nil {
t.Fatalf("GetSession: %v", err)
}
if s.ExecCommand == nil {
t.Fatal("expected non-nil ExecCommand")
}
if *s.ExecCommand != "uname -a" {
t.Errorf("ExecCommand = %q, want %q", *s.ExecCommand, "uname -a")
}
})
t.Run("appears in recent sessions", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.SetExecCommand(ctx, id, "id"); err != nil {
t.Fatalf("SetExecCommand: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ExecCommand == nil || *sessions[0].ExecCommand != "id" {
t.Errorf("ExecCommand = %v, want \"id\"", sessions[0].ExecCommand)
}
})
})
}
func seedChartData(t *testing.T, store Store) {
t.Helper()
ctx := context.Background()
// Record attempts with country data from different IPs.
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 3 {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
for range 2 {
if err := store.RecordLoginAttempt(ctx, "root", "123456", "10.0.0.3", "CN"); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
}
func TestGetAttemptsOverTime(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
if err != nil {
t.Fatalf("GetAttemptsOverTime: %v", err)
}
if len(points) != 0 {
t.Errorf("expected empty, got %v", points)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
points, err := store.GetAttemptsOverTime(context.Background(), 30, nil, nil)
if err != nil {
t.Fatalf("GetAttemptsOverTime: %v", err)
}
// All data was inserted today, so should be one point.
if len(points) != 1 {
t.Fatalf("len = %d, want 1", len(points))
}
// 5 + 3 + 2 = 10 total.
if points[0].Count != 10 {
t.Errorf("count = %d, want 10", points[0].Count)
}
})
})
}
func TestGetHourlyPattern(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
if err != nil {
t.Fatalf("GetHourlyPattern: %v", err)
}
if len(counts) != 0 {
t.Errorf("expected empty, got %v", counts)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
counts, err := store.GetHourlyPattern(context.Background(), nil, nil)
if err != nil {
t.Fatalf("GetHourlyPattern: %v", err)
}
// All data was inserted at the same hour.
if len(counts) != 1 {
t.Fatalf("len = %d, want 1", len(counts))
}
if counts[0].Count != 10 {
t.Errorf("count = %d, want 10", counts[0].Count)
}
})
})
}
func TestGetCountryStats(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
store := newStore(t)
counts, err := store.GetCountryStats(context.Background())
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 0 {
t.Errorf("expected empty, got %v", counts)
}
})
t.Run("with data", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
counts, err := store.GetCountryStats(context.Background())
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 2 {
t.Fatalf("len = %d, want 2", len(counts))
}
// CN: 5 + 2 = 7, RU: 3 - ordered by count DESC.
if counts[0].Country != "CN" || counts[0].Count != 7 {
t.Errorf("counts[0] = %+v, want CN/7", counts[0])
}
if counts[1].Country != "RU" || counts[1].Count != 3 {
t.Errorf("counts[1] = %+v, want RU/3", counts[1])
}
})
t.Run("excludes empty country", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding: %v", err)
}
if err := store.RecordLoginAttempt(ctx, "test", "test", "10.0.0.2", "US"); err != nil {
t.Fatalf("seeding: %v", err)
}
counts, err := store.GetCountryStats(ctx)
if err != nil {
t.Fatalf("GetCountryStats: %v", err)
}
if len(counts) != 1 {
t.Fatalf("len = %d, want 1", len(counts))
}
if counts[0].Country != "US" {
t.Errorf("country = %q, want US", counts[0].Country)
}
})
})
}
func TestGetFilteredDashboardStats(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("no filter", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 10 {
t.Errorf("TotalAttempts = %d, want 10", stats.TotalAttempts)
}
})
t.Run("filter by country", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Country: "CN"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
// CN: 5 + 2 = 7
if stats.TotalAttempts != 7 {
t.Errorf("TotalAttempts = %d, want 7", stats.TotalAttempts)
}
})
t.Run("filter by IP", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{IP: "10.0.0.1"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 5 {
t.Errorf("TotalAttempts = %d, want 5", stats.TotalAttempts)
}
})
t.Run("filter by username", func(t *testing.T) {
store := newStore(t)
seedChartData(t, store)
stats, err := store.GetFilteredDashboardStats(context.Background(), DashboardFilter{Username: "admin"})
if err != nil {
t.Fatalf("GetFilteredDashboardStats: %v", err)
}
if stats.TotalAttempts != 3 {
t.Errorf("TotalAttempts = %d, want 3", stats.TotalAttempts)
}
})
})
}
func TestGetFilteredTopUsernames(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
store := newStore(t)
seedChartData(t, store)
// Filter by country CN should only show root.
entries, err := store.GetFilteredTopUsernames(context.Background(), 10, DashboardFilter{Country: "CN"})
if err != nil {
t.Fatalf("GetFilteredTopUsernames: %v", err)
}
if len(entries) != 1 {
t.Fatalf("len = %d, want 1", len(entries))
}
if entries[0].Value != "root" || entries[0].Count != 7 {
t.Errorf("entries[0] = %+v, want root/7", entries[0])
}
})
}
func TestGetRecentSessions(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("empty", func(t *testing.T) {
@@ -417,3 +700,192 @@ func TestGetRecentSessions(t *testing.T) {
})
})
}
func TestInputBytes(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("counts only input direction", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
now := time.Now().UTC()
events := []SessionEvent{
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")}, // 3 bytes input
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")}, // 11 bytes output
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")}, // 4 bytes input
}
if err := store.AppendSessionEvents(ctx, events); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
// Only direction=0 data: "ls\n" (3) + "pwd\n" (4) = 7
if sessions[0].InputBytes != 7 {
t.Errorf("InputBytes = %d, want 7", sessions[0].InputBytes)
}
})
t.Run("zero when no events", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
_, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
sessions, err := store.GetRecentSessions(ctx, 10, false)
if err != nil {
t.Fatalf("GetRecentSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].InputBytes != 0 {
t.Errorf("InputBytes = %d, want 0", sessions[0].InputBytes)
}
})
})
}
func TestGetFilteredSessions(t *testing.T) {
testStores(t, func(t *testing.T, newStore storeFactory) {
t.Run("filter by human score", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
// Create two sessions, one with human score > 0.
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id1, 0.75); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
_, err = store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{HumanScoreAboveZero: true})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
}
})
t.Run("sort by input bytes", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
// Session with more input (created first).
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
now := time.Now().UTC()
if err := store.AppendSessionEvents(ctx, []SessionEvent{
{SessionID: id1, Timestamp: now, Direction: 0, Data: []byte("ls -la /tmp\n")},
{SessionID: id1, Timestamp: now.Add(time.Millisecond), Direction: 0, Data: []byte("cat /etc/passwd\n")},
}); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
// Session with less input (created after id1, so would be first by connected_at).
// Sleep >1s to ensure different RFC3339 timestamps in SQLite.
time.Sleep(1100 * time.Millisecond)
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.AppendSessionEvents(ctx, []SessionEvent{
{SessionID: id2, Timestamp: now.Add(2 * time.Second), Direction: 0, Data: []byte("x\n")},
}); err != nil {
t.Fatalf("AppendSessionEvents: %v", err)
}
// Default sort (connected_at DESC) should show id2 first.
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 2 {
t.Fatalf("len = %d, want 2", len(sessions))
}
if sessions[0].ID != id2 {
t.Errorf("default sort: expected %s first, got %s", id2, sessions[0].ID)
}
// Sort by input_bytes should show id1 first (more input).
sessions, err = store.GetFilteredSessions(ctx, 50, false, DashboardFilter{SortBy: "input_bytes"})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 2 {
t.Fatalf("len = %d, want 2", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("input_bytes sort: expected %s first, got %s", id1, sessions[0].ID)
}
})
t.Run("combined filters", func(t *testing.T) {
store := newStore(t)
ctx := context.Background()
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id1, 0.5); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
// Different country, also has score.
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
if err := store.UpdateHumanScore(ctx, id2, 0.8); err != nil {
t.Fatalf("UpdateHumanScore: %v", err)
}
// Same country CN but no score.
_, err = store.CreateSession(ctx, "10.0.0.3", "test", "bash", "CN")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
// Filter: CN + human score > 0 -> only id1.
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{
Country: "CN",
HumanScoreAboveZero: true,
})
if err != nil {
t.Fatalf("GetFilteredSessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("len = %d, want 1", len(sessions))
}
if sessions[0].ID != id1 {
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
}
})
})
}

View File

@@ -1,24 +1,37 @@
package web
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"strconv"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
// dbContext returns a context detached from the HTTP request lifecycle with a
// 30-second timeout. This prevents HTMX polling from canceling in-flight DB
// queries when the browser aborts the previous XHR.
func dbContext(r *http.Request) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
}
type dashboardData struct {
Stats *storage.DashboardStats
TopUsernames []storage.TopEntry
TopPasswords []storage.TopEntry
TopIPs []storage.TopEntry
ActiveSessions []storage.Session
RecentSessions []storage.Session
Stats *storage.DashboardStats
TopUsernames []storage.TopEntry
TopPasswords []storage.TopEntry
TopIPs []storage.TopEntry
TopCountries []storage.TopEntry
TopExecCommands []storage.TopEntry
ActiveSessions []storage.Session
RecentSessions []storage.Session
}
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, cancel := dbContext(r)
defer cancel()
stats, err := s.store.GetDashboardStats(ctx)
if err != nil {
@@ -48,6 +61,20 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
return
}
topCountries, err := s.store.GetTopCountries(ctx, 10)
if err != nil {
s.logger.Error("failed to get top countries", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topExecCommands, err := s.store.GetTopExecCommands(ctx, 10)
if err != nil {
s.logger.Error("failed to get top exec commands", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
activeSessions, err := s.store.GetRecentSessions(ctx, 50, true)
if err != nil {
s.logger.Error("failed to get active sessions", "err", err)
@@ -63,12 +90,14 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
}
data := dashboardData{
Stats: stats,
TopUsernames: topUsernames,
TopPasswords: topPasswords,
TopIPs: topIPs,
ActiveSessions: activeSessions,
RecentSessions: recentSessions,
Stats: stats,
TopUsernames: topUsernames,
TopPasswords: topPasswords,
TopIPs: topIPs,
TopCountries: topCountries,
TopExecCommands: topExecCommands,
ActiveSessions: activeSessions,
RecentSessions: recentSessions,
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
@@ -78,7 +107,10 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
stats, err := s.store.GetDashboardStats(r.Context())
ctx, cancel := dbContext(r)
defer cancel()
stats, err := s.store.GetDashboardStats(ctx)
if err != nil {
s.logger.Error("failed to get dashboard stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -92,7 +124,10 @@ func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Request) {
sessions, err := s.store.GetRecentSessions(r.Context(), 50, true)
ctx, cancel := dbContext(r)
defer cancel()
sessions, err := s.store.GetRecentSessions(ctx, 50, true)
if err != nil {
s.logger.Error("failed to get active sessions", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
@@ -105,6 +140,24 @@ func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Req
}
}
func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
f := parseDashboardFilter(r)
sessions, err := s.store.GetFilteredSessions(ctx, 50, false, f)
if err != nil {
s.logger.Error("failed to get filtered sessions", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "recent_sessions", sessions); err != nil {
s.logger.Error("failed to render recent sessions fragment", "err", err)
}
}
type sessionDetailData struct {
Session *storage.Session
Logs []storage.SessionLog
@@ -112,7 +165,8 @@ type sessionDetailData struct {
}
func (s *Server) handleSessionDetail(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, cancel := dbContext(r)
defer cancel()
sessionID := r.PathValue("id")
session, err := s.store.GetSession(ctx, sessionID)
@@ -162,8 +216,201 @@ type apiEventsResponse struct {
Events []apiEvent `json:"events"`
}
// parseDateParam parses a "YYYY-MM-DD" query parameter into a *time.Time.
func parseDateParam(r *http.Request, name string) *time.Time {
v := r.URL.Query().Get(name)
if v == "" {
return nil
}
t, err := time.Parse("2006-01-02", v)
if err != nil {
return nil
}
// For "until" dates, set to end of day.
if name == "until" {
t = t.Add(24*time.Hour - time.Second)
}
return &t
}
func parseDashboardFilter(r *http.Request) storage.DashboardFilter {
return storage.DashboardFilter{
Since: parseDateParam(r, "since"),
Until: parseDateParam(r, "until"),
IP: r.URL.Query().Get("ip"),
Country: r.URL.Query().Get("country"),
Username: r.URL.Query().Get("username"),
HumanScoreAboveZero: r.URL.Query().Get("human_score") == "1",
SortBy: r.URL.Query().Get("sort"),
}
}
type apiTimeSeriesPoint struct {
Date string `json:"date"`
Count int64 `json:"count"`
}
type apiAttemptsOverTimeResponse struct {
Points []apiTimeSeriesPoint `json:"points"`
}
func (s *Server) handleAPIAttemptsOverTime(w http.ResponseWriter, r *http.Request) {
days := 30
if v := r.URL.Query().Get("days"); v != "" {
if d, err := strconv.Atoi(v); err == nil && d > 0 && d <= 365 {
days = d
}
}
since := parseDateParam(r, "since")
until := parseDateParam(r, "until")
ctx, cancel := dbContext(r)
defer cancel()
points, err := s.store.GetAttemptsOverTime(ctx, days, since, until)
if err != nil {
s.logger.Error("failed to get attempts over time", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiAttemptsOverTimeResponse{Points: make([]apiTimeSeriesPoint, len(points))}
for i, p := range points {
resp.Points[i] = apiTimeSeriesPoint{
Date: p.Timestamp.Format("2006-01-02"),
Count: p.Count,
}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode attempts over time", "err", err)
}
}
type apiHourlyCount struct {
Hour int `json:"hour"`
Count int64 `json:"count"`
}
type apiHourlyPatternResponse struct {
Hours []apiHourlyCount `json:"hours"`
}
func (s *Server) handleAPIHourlyPattern(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
since := parseDateParam(r, "since")
until := parseDateParam(r, "until")
counts, err := s.store.GetHourlyPattern(ctx, since, until)
if err != nil {
s.logger.Error("failed to get hourly pattern", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiHourlyPatternResponse{Hours: make([]apiHourlyCount, len(counts))}
for i, c := range counts {
resp.Hours[i] = apiHourlyCount{Hour: c.Hour, Count: c.Count}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode hourly pattern", "err", err)
}
}
type apiCountryCount struct {
Country string `json:"country"`
Count int64 `json:"count"`
}
type apiCountryStatsResponse struct {
Countries []apiCountryCount `json:"countries"`
}
func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
counts, err := s.store.GetCountryStats(ctx)
if err != nil {
s.logger.Error("failed to get country stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
resp := apiCountryStatsResponse{Countries: make([]apiCountryCount, len(counts))}
for i, c := range counts {
resp.Countries[i] = apiCountryCount{Country: c.Country, Count: c.Count}
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
s.logger.Error("failed to encode country stats", "err", err)
}
}
func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.Request) {
ctx, cancel := dbContext(r)
defer cancel()
f := parseDashboardFilter(r)
stats, err := s.store.GetFilteredDashboardStats(ctx, f)
if err != nil {
s.logger.Error("failed to get filtered stats", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topUsernames, err := s.store.GetFilteredTopUsernames(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top usernames", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topPasswords, err := s.store.GetFilteredTopPasswords(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top passwords", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topIPs, err := s.store.GetFilteredTopIPs(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top IPs", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
topCountries, err := s.store.GetFilteredTopCountries(ctx, 10, f)
if err != nil {
s.logger.Error("failed to get filtered top countries", "err", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
data := dashboardData{
Stats: stats,
TopUsernames: topUsernames,
TopPasswords: topPasswords,
TopIPs: topIPs,
TopCountries: topCountries,
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := s.tmpl.dashboard.ExecuteTemplate(w, "dashboard_content", data); err != nil {
s.logger.Error("failed to render dashboard content fragment", "err", err)
}
}
func (s *Server) handleAPISessionEvents(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx, cancel := dbContext(r)
defer cancel()
sessionID := r.PathValue("id")
events, err := s.store.GetSessionEvents(ctx, sessionID)

14
internal/web/static/chart.min.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,275 @@
(function() {
'use strict';
// Chart.js theme for Pico dark mode
Chart.defaults.color = '#b0b0b8';
Chart.defaults.borderColor = '#3a3a4a';
var attemptsChart = null;
var hourlyChart = null;
function getFilterParams() {
var form = document.getElementById('filter-form');
if (!form) return '';
var params = new URLSearchParams();
var since = form.elements['since'].value;
var until = form.elements['until'].value;
if (since) params.set('since', since);
if (until) params.set('until', until);
var humanScore = form.elements['human_score'];
if (humanScore && humanScore.checked) params.set('human_score', '1');
var sortBy = form.elements['sort'];
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
return params.toString();
}
function initAttemptsChart() {
var canvas = document.getElementById('chart-attempts');
if (!canvas) return;
var ctx = canvas.getContext('2d');
var qs = getFilterParams();
var url = '/api/charts/attempts-over-time' + (qs ? '?' + qs : '');
fetch(url)
.then(function(r) { return r.json(); })
.then(function(data) {
var labels = data.points.map(function(p) { return p.date; });
var values = data.points.map(function(p) { return p.count; });
if (attemptsChart) {
attemptsChart.data.labels = labels;
attemptsChart.data.datasets[0].data = values;
attemptsChart.update();
return;
}
attemptsChart = new Chart(ctx, {
type: 'line',
data: {
labels: labels,
datasets: [{
label: 'Attempts',
data: values,
borderColor: '#6366f1',
backgroundColor: 'rgba(99, 102, 241, 0.1)',
fill: true,
tension: 0.3,
pointRadius: 2
}]
},
options: {
responsive: true,
maintainAspectRatio: true,
plugins: { legend: { display: false } },
scales: {
x: { grid: { display: false } },
y: { beginAtZero: true }
}
}
});
});
}
function initHourlyChart() {
var canvas = document.getElementById('chart-hourly');
if (!canvas) return;
var ctx = canvas.getContext('2d');
var qs = getFilterParams();
var url = '/api/charts/hourly-pattern' + (qs ? '?' + qs : '');
fetch(url)
.then(function(r) { return r.json(); })
.then(function(data) {
// Fill all 24 hours, defaulting to 0
var hourMap = {};
data.hours.forEach(function(h) { hourMap[h.hour] = h.count; });
var labels = [];
var values = [];
for (var i = 0; i < 24; i++) {
labels.push(i + ':00');
values.push(hourMap[i] || 0);
}
if (hourlyChart) {
hourlyChart.data.labels = labels;
hourlyChart.data.datasets[0].data = values;
hourlyChart.update();
return;
}
hourlyChart = new Chart(ctx, {
type: 'bar',
data: {
labels: labels,
datasets: [{
label: 'Attempts',
data: values,
backgroundColor: 'rgba(99, 102, 241, 0.6)',
borderColor: '#6366f1',
borderWidth: 1
}]
},
options: {
responsive: true,
maintainAspectRatio: true,
plugins: { legend: { display: false } },
scales: {
x: { grid: { display: false } },
y: { beginAtZero: true }
}
}
});
});
}
function initWorldMap() {
var container = document.getElementById('world-map');
if (!container) return;
fetch('/static/world.svg')
.then(function(r) { return r.text(); })
.then(function(svgText) {
container.innerHTML = svgText;
fetch('/api/charts/country-stats')
.then(function(r) { return r.json(); })
.then(function(data) {
colorMap(container, data.countries);
});
});
}
function colorMap(container, countries) {
if (!countries || countries.length === 0) return;
var maxCount = countries[0].count; // already sorted DESC
var logMax = Math.log(maxCount + 1);
// Build lookup
var lookup = {};
countries.forEach(function(c) {
lookup[c.country.toLowerCase()] = c.count;
});
// Create tooltip element
var tooltip = document.createElement('div');
tooltip.id = 'map-tooltip';
tooltip.style.cssText = 'position:fixed;display:none;background:#1a1a2e;color:#e0e0e8;padding:4px 8px;border-radius:4px;font-size:13px;pointer-events:none;z-index:1000;border:1px solid #3a3a4a;';
document.body.appendChild(tooltip);
var svg = container.querySelector('svg');
if (!svg) return;
// Remove SVG title to prevent browser native tooltip
var svgTitle = svg.querySelector('title');
if (svgTitle) svgTitle.remove();
// Select both <path id="xx"> and <g id="xx"> country elements
var elements = svg.querySelectorAll('path[id], g[id]');
elements.forEach(function(el) {
var id = el.id.toLowerCase();
if (id.charAt(0) === '_') return; // skip non-country paths
var count = lookup[id];
if (count) {
var intensity = Math.log(count + 1) / logMax;
var r = Math.round(30 + intensity * 69); // 30 -> 99
var g = Math.round(30 + intensity * 72); // 30 -> 102
var b = Math.round(62 + intensity * 179); // 62 -> 241
var color = 'rgb(' + r + ',' + g + ',' + b + ')';
// For <g> elements, color child paths; for <path>, color directly
if (el.tagName.toLowerCase() === 'g') {
el.querySelectorAll('path').forEach(function(p) {
p.style.fill = color;
});
} else {
el.style.fill = color;
}
}
el.addEventListener('mouseenter', function(e) {
var cc = id.toUpperCase();
var n = lookup[id] || 0;
tooltip.textContent = cc + ': ' + n.toLocaleString() + ' attempts';
tooltip.style.display = 'block';
});
el.addEventListener('mousemove', function(e) {
tooltip.style.left = (e.clientX + 12) + 'px';
tooltip.style.top = (e.clientY - 10) + 'px';
});
el.addEventListener('mouseleave', function() {
tooltip.style.display = 'none';
});
el.addEventListener('click', function() {
var input = document.querySelector('#filter-form input[name="country"]');
if (input) {
input.value = id.toUpperCase();
applyFilters();
}
});
el.style.cursor = 'pointer';
});
}
function applyFilters() {
// Re-fetch charts with filter params
initAttemptsChart();
initHourlyChart();
// Re-fetch dashboard content via htmx
var form = document.getElementById('filter-form');
if (!form) return;
var params = new URLSearchParams();
['since', 'until', 'ip', 'country', 'username'].forEach(function(name) {
var val = form.elements[name].value;
if (val) params.set(name, val);
});
var humanScore = form.elements['human_score'];
if (humanScore && humanScore.checked) params.set('human_score', '1');
var sortBy = form.elements['sort'];
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
var target = document.getElementById('dashboard-content');
if (target) {
var url = '/fragments/dashboard-content?' + params.toString();
htmx.ajax('GET', url, {target: '#dashboard-content', swap: 'innerHTML'});
}
// Server-side filter for recent sessions table
var sessionsUrl = '/fragments/recent-sessions?' + params.toString();
htmx.ajax('GET', sessionsUrl, {target: '#recent-sessions-table tbody', swap: 'innerHTML'});
}
window.clearFilters = function() {
var form = document.getElementById('filter-form');
if (form) {
form.reset();
applyFilters();
}
};
window.applyFilters = applyFilters;
// Initialize on DOM ready
document.addEventListener('DOMContentLoaded', function() {
initAttemptsChart();
initHourlyChart();
initWorldMap();
var form = document.getElementById('filter-form');
if (form) {
form.addEventListener('submit', function(e) {
e.preventDefault();
applyFilters();
});
}
});
})();

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 55 KiB

View File

@@ -44,6 +44,32 @@ func templateFuncMap() template.FuncMap {
}
return fmt.Sprintf("%.0f%%", *f*100)
},
"derefString": func(s *string) string {
if s == nil {
return ""
}
return *s
},
"truncateCommand": func(s string) string {
if len(s) > 50 {
return s[:50] + "..."
}
return s
},
"formatBytes": func(b int64) string {
const (
kb = 1024
mb = 1024 * kb
)
switch {
case b >= mb:
return fmt.Sprintf("%.1f MB", float64(b)/float64(mb))
case b >= kb:
return fmt.Sprintf("%.1f KB", float64(b)/float64(kb))
default:
return fmt.Sprintf("%d B", b)
}
},
}
}
@@ -55,6 +81,7 @@ func loadTemplates() (*templateSet, error) {
"templates/dashboard.html",
"templates/fragments/stats.html",
"templates/fragments/active_sessions.html",
"templates/fragments/recent_sessions.html",
)
if err != nil {
return nil, fmt.Errorf("parsing dashboard templates: %w", err)

View File

@@ -3,6 +3,86 @@
{{template "stats" .Stats}}
</section>
<details>
<summary>Filters</summary>
<form id="filter-form">
<div class="grid">
<label>Since <input type="date" name="since"></label>
<label>Until <input type="date" name="until"></label>
<label>IP <input type="text" name="ip" placeholder="10.0.0.1"></label>
<label>Country <input type="text" name="country" placeholder="CN" maxlength="2"></label>
<label>Username <input type="text" name="username" placeholder="root"></label>
</div>
<div class="grid">
<label><input type="checkbox" name="human_score" value="1"> Human score &gt; 0</label>
<label>Sort by <select name="sort"><option value="connected_at">Recent</option><option value="input_bytes">Input Bytes</option></select></label>
</div>
<button type="submit">Apply</button>
<button type="button" class="secondary" onclick="clearFilters()">Clear</button>
</form>
</details>
<section>
<h3>Attack Trends</h3>
<div class="grid">
<article>
<header>Attempts Over Time</header>
<canvas id="chart-attempts"></canvas>
</article>
<article>
<header>Hourly Pattern (UTC)</header>
<canvas id="chart-hourly"></canvas>
</article>
</div>
</section>
<section>
<h3>Attack Origins</h3>
<article>
<div id="world-map"></div>
</article>
</section>
<div id="dashboard-content">
{{template "dashboard_content" .}}
</div>
<section>
<h3>Active Sessions</h3>
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
{{template "active_sessions" .ActiveSessions}}
</div>
</section>
<section>
<h3>Recent Sessions</h3>
<table id="recent-sessions-table">
<thead>
<tr>
<th>ID</th>
<th>IP</th>
<th>Country</th>
<th>Username</th>
<th>Type</th>
<th>Score</th>
<th>Input</th>
<th>Connected</th>
<th>Disconnected</th>
</tr>
</thead>
<tbody>
{{template "recent_sessions" .RecentSessions}}
</tbody>
</table>
</section>
{{end}}
{{define "scripts"}}
<script src="/static/chart.min.js"></script>
<script src="/static/dashboard.js"></script>
{{end}}
{{define "dashboard_content"}}
<section>
<h3>Top Credentials & IPs</h3>
<div class="top-grid">
@@ -40,10 +120,25 @@
<header>Top IPs</header>
<table>
<thead>
<tr><th>IP</th><th>Attempts</th></tr>
<tr><th>IP</th><th>Country</th><th>Attempts</th></tr>
</thead>
<tbody>
{{range .TopIPs}}
<tr><td>{{.Value}}</td><td>{{.Country}}</td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="3">No data</td></tr>
{{end}}
</tbody>
</table>
</article>
<article>
<header>Top Countries</header>
<table>
<thead>
<tr><th>Country</th><th>Attempts</th></tr>
</thead>
<tbody>
{{range .TopCountries}}
<tr><td>{{.Value}}</td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="2">No data</td></tr>
@@ -51,45 +146,21 @@
</tbody>
</table>
</article>
<article>
<header>Top Exec Commands</header>
<table>
<thead>
<tr><th>Command</th><th>Count</th></tr>
</thead>
<tbody>
{{range .TopExecCommands}}
<tr><td><code>{{truncateCommand .Value}}</code></td><td>{{.Count}}</td></tr>
{{else}}
<tr><td colspan="2">No data</td></tr>
{{end}}
</tbody>
</table>
</article>
</div>
</section>
<section>
<h3>Active Sessions</h3>
<div id="active-sessions" hx-get="/fragments/active-sessions" hx-trigger="every 10s" hx-swap="innerHTML">
{{template "active_sessions" .ActiveSessions}}
</div>
</section>
<section>
<h3>Recent Sessions</h3>
<table>
<thead>
<tr>
<th>ID</th>
<th>IP</th>
<th>Username</th>
<th>Shell</th>
<th>Score</th>
<th>Connected</th>
<th>Disconnected</th>
</tr>
</thead>
<tbody>
{{range .RecentSessions}}
<tr>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a></td>
<td>{{.IP}}</td>
<td>{{.Username}}</td>
<td>{{.ShellName}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatTime .ConnectedAt}}</td>
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
</tr>
{{else}}
<tr><td colspan="7">No sessions</td></tr>
{{end}}
</tbody>
</table>
</section>
{{end}}

View File

@@ -4,24 +4,28 @@
<tr>
<th>ID</th>
<th>IP</th>
<th>Country</th>
<th>Username</th>
<th>Shell</th>
<th>Type</th>
<th>Score</th>
<th>Input</th>
<th>Connected</th>
</tr>
</thead>
<tbody>
{{range .}}
<tr>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a></td>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
<td>{{.IP}}</td>
<td>{{.Country}}</td>
<td>{{.Username}}</td>
<td>{{.ShellName}}</td>
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatBytes .InputBytes}}</td>
<td>{{formatTime .ConnectedAt}}</td>
</tr>
{{else}}
<tr><td colspan="6">No active sessions</td></tr>
<tr><td colspan="8">No active sessions</td></tr>
{{end}}
</tbody>
</table>

View File

@@ -0,0 +1,17 @@
{{define "recent_sessions"}}
{{range .}}
<tr>
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
<td>{{.IP}}</td>
<td>{{.Country}}</td>
<td>{{.Username}}</td>
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
<td>{{formatBytes .InputBytes}}</td>
<td>{{formatTime .ConnectedAt}}</td>
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
</tr>
{{else}}
<tr><td colspan="9">No sessions</td></tr>
{{end}}
{{end}}

View File

@@ -29,9 +29,16 @@
}
.top-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(280px, 1fr));
grid-template-columns: repeat(auto-fit, minmax(380px, 1fr));
gap: 1rem;
}
.top-grid article {
overflow: hidden;
min-width: 0;
}
#world-map svg { width: 100%; height: auto; }
#world-map svg path { fill: #2a2a3e; stroke: #555; stroke-width: 0.5; transition: fill 0.2s; }
#world-map svg path:hover, #world-map svg g:hover path { stroke: #fff; stroke-width: 1; }
nav h1 {
margin: 0;
}
@@ -52,5 +59,6 @@
<main class="container">
{{block "content" .}}{{end}}
</main>
{{block "scripts" .}}{{end}}
</body>
</html>

View File

@@ -7,8 +7,10 @@
<table>
<tbody>
<tr><td><strong>IP</strong></td><td>{{.Session.IP}}</td></tr>
<tr><td><strong>Country</strong></td><td>{{.Session.Country}}</td></tr>
<tr><td><strong>Username</strong></td><td>{{.Session.Username}}</td></tr>
<tr><td><strong>Shell</strong></td><td>{{.Session.ShellName}}</td></tr>
{{if .Session.ExecCommand}}<tr><td><strong>Exec Command</strong></td><td><code>{{derefString .Session.ExecCommand}}</code></td></tr>{{end}}
<tr><td><strong>Score</strong></td><td>{{formatScore .Session.HumanScore}}</td></tr>
<tr><td><strong>Connected</strong></td><td>{{formatTime .Session.ConnectedAt}}</td></tr>
<tr>

View File

@@ -1,11 +1,13 @@
package web
import (
"crypto/subtle"
"embed"
"log/slog"
"net/http"
"strings"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
//go:embed static/*
@@ -20,7 +22,9 @@ type Server struct {
}
// NewServer creates a new web Server with routes registered.
func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
// If metricsHandler is non-nil, a /metrics endpoint is registered.
// If metricsToken is non-empty, the metrics endpoint requires Bearer token auth.
func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Handler, metricsToken string) (*Server, error) {
tmpl, err := loadTemplates()
if err != nil {
return nil, err
@@ -36,9 +40,22 @@ func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
s.mux.Handle("GET /static/", http.FileServerFS(staticFS))
s.mux.HandleFunc("GET /sessions/{id}", s.handleSessionDetail)
s.mux.HandleFunc("GET /api/sessions/{id}/events", s.handleAPISessionEvents)
s.mux.HandleFunc("GET /api/charts/attempts-over-time", s.handleAPIAttemptsOverTime)
s.mux.HandleFunc("GET /api/charts/hourly-pattern", s.handleAPIHourlyPattern)
s.mux.HandleFunc("GET /api/charts/country-stats", s.handleAPICountryStats)
s.mux.HandleFunc("GET /", s.handleDashboard)
s.mux.HandleFunc("GET /fragments/stats", s.handleFragmentStats)
s.mux.HandleFunc("GET /fragments/active-sessions", s.handleFragmentActiveSessions)
s.mux.HandleFunc("GET /fragments/dashboard-content", s.handleFragmentDashboardContent)
s.mux.HandleFunc("GET /fragments/recent-sessions", s.handleFragmentRecentSessions)
if metricsHandler != nil {
h := metricsHandler
if metricsToken != "" {
h = requireBearerToken(metricsToken, h)
}
s.mux.Handle("GET /metrics", h)
}
return s, nil
}
@@ -47,3 +64,20 @@ func NewServer(store storage.Store, logger *slog.Logger) (*Server, error) {
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.mux.ServeHTTP(w, r)
}
// requireBearerToken wraps a handler to require a valid Bearer token.
func requireBearerToken(token string, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
provided := auth[len("Bearer "):]
if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}

View File

@@ -10,14 +10,15 @@ import (
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/storage"
"code.t-juice.club/torjus/oubliette/internal/metrics"
"code.t-juice.club/torjus/oubliette/internal/storage"
)
func newTestServer(t *testing.T) *Server {
t.Helper()
store := storage.NewMemoryStore()
logger := slog.Default()
srv, err := NewServer(store, logger)
srv, err := NewServer(store, logger, nil, "")
if err != nil {
t.Fatalf("creating server: %v", err)
}
@@ -30,29 +31,53 @@ func newSeededTestServer(t *testing.T) *Server {
ctx := context.Background()
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1"); err != nil {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
}
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2"); err != nil {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
t.Fatalf("seeding attempt: %v", err)
}
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash"); err != nil {
if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
t.Fatalf("creating session: %v", err)
}
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash"); err != nil {
if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
t.Fatalf("creating session: %v", err)
}
logger := slog.Default()
srv, err := NewServer(store, logger)
srv, err := NewServer(store, logger, nil, "")
if err != nil {
t.Fatalf("creating server: %v", err)
}
return srv
}
func TestDbContextNotCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
req := httptest.NewRequest(http.MethodGet, "/", nil)
req = req.WithContext(ctx)
dbCtx, dbCancel := dbContext(req)
defer dbCancel()
// Cancel the original request context.
cancel()
// The DB context should still be usable.
select {
case <-dbCtx.Done():
t.Fatal("dbContext should not be canceled when request context is canceled")
default:
}
// Verify the DB context has a deadline (from the timeout).
if _, ok := dbCtx.Deadline(); !ok {
t.Error("dbContext should have a deadline")
}
}
func TestDashboardHandler(t *testing.T) {
t.Run("empty store", func(t *testing.T) {
srv := newTestServer(t)
@@ -149,12 +174,12 @@ func TestSessionDetailHandler(t *testing.T) {
t.Run("found", func(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
srv, err := NewServer(store, slog.Default())
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
@@ -180,7 +205,7 @@ func TestSessionDetailHandler(t *testing.T) {
func TestAPISessionEvents(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash")
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("CreateSession: %v", err)
}
@@ -194,7 +219,7 @@ func TestAPISessionEvents(t *testing.T) {
t.Fatalf("AppendSessionEvents: %v", err)
}
srv, err := NewServer(store, slog.Default())
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
@@ -236,6 +261,293 @@ func TestAPISessionEvents(t *testing.T) {
}
}
func TestMetricsEndpoint(t *testing.T) {
t.Run("enabled", func(t *testing.T) {
m := metrics.New("test")
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), m.Handler(), "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body := w.Body.String()
if !strings.Contains(body, `oubliette_build_info{version="test"} 1`) {
t.Errorf("response should contain build_info metric, got:\n%s", body)
}
})
t.Run("disabled", func(t *testing.T) {
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
// Without a metrics handler, /metrics falls through to the dashboard.
body := w.Body.String()
if strings.Contains(body, "oubliette_build_info") {
t.Error("response should not contain prometheus metrics when disabled")
}
})
}
func TestMetricsBearerToken(t *testing.T) {
m := metrics.New("test")
t.Run("valid token", func(t *testing.T) {
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
req.Header.Set("Authorization", "Bearer secret")
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
})
t.Run("wrong token", func(t *testing.T) {
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
req.Header.Set("Authorization", "Bearer wrong")
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", w.Code)
}
})
t.Run("missing header", func(t *testing.T) {
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusUnauthorized {
t.Errorf("status = %d, want 401", w.Code)
}
})
t.Run("no token configured", func(t *testing.T) {
store := storage.NewMemoryStore()
srv, err := NewServer(store, slog.Default(), m.Handler(), "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
})
}
func TestTruncateCommand(t *testing.T) {
funcMap := templateFuncMap()
fn := funcMap["truncateCommand"].(func(string) string)
tests := []struct {
input string
want string
}{
{"short", "short"},
{"exactly fifty characters long! that is what it i.", "exactly fifty characters long! that is what it i."},
{"this string is definitely longer than fifty characters and should be truncated", "this string is definitely longer than fifty charac..."},
{"", ""},
}
for _, tt := range tests {
got := fn(tt.input)
if got != tt.want {
t.Errorf("truncateCommand(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestDashboardExecCommands(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
if err != nil {
t.Fatalf("creating session: %v", err)
}
if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
t.Fatalf("setting exec command: %v", err)
}
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body := w.Body.String()
if !strings.Contains(body, "Top Exec Commands") {
t.Error("response should contain 'Top Exec Commands'")
}
if !strings.Contains(body, "uname -a") {
t.Error("response should contain exec command 'uname -a'")
}
}
func TestAPIAttemptsOverTime(t *testing.T) {
srv := newSeededTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/api/charts/attempts-over-time", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
ct := w.Header().Get("Content-Type")
if !strings.Contains(ct, "application/json") {
t.Errorf("Content-Type = %q, want application/json", ct)
}
var resp apiAttemptsOverTimeResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
// Seeded data inserted today -> at least 1 point.
if len(resp.Points) == 0 {
t.Error("expected at least one data point")
}
}
func TestAPIHourlyPattern(t *testing.T) {
srv := newSeededTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/api/charts/hourly-pattern", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
var resp apiHourlyPatternResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
if len(resp.Hours) == 0 {
t.Error("expected at least one hourly data point")
}
}
func TestAPICountryStats(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
t.Fatalf("seeding: %v", err)
}
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
t.Fatalf("seeding: %v", err)
}
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/api/charts/country-stats", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
var resp apiCountryStatsResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
if len(resp.Countries) != 2 {
t.Fatalf("len = %d, want 2", len(resp.Countries))
}
}
func TestFragmentDashboardContent(t *testing.T) {
srv := newSeededTestServer(t)
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body := w.Body.String()
if strings.Contains(body, "<!DOCTYPE html>") {
t.Error("dashboard content fragment should not contain full HTML document")
}
if !strings.Contains(body, "Top Usernames") {
t.Error("dashboard content fragment should contain 'Top Usernames'")
}
}
func TestFragmentDashboardContentWithFilter(t *testing.T) {
store := storage.NewMemoryStore()
ctx := context.Background()
for range 5 {
if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
t.Fatalf("seeding: %v", err)
}
}
for range 3 {
if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
t.Fatalf("seeding: %v", err)
}
}
srv, err := NewServer(store, slog.Default(), nil, "")
if err != nil {
t.Fatalf("NewServer: %v", err)
}
req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content?country=CN", nil)
w := httptest.NewRecorder()
srv.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want 200", w.Code)
}
body := w.Body.String()
// When filtered by CN, should show root but not admin.
if !strings.Contains(body, "root") {
t.Error("response should contain 'root' when filtered by CN")
}
}
func TestStaticAssets(t *testing.T) {
srv := newTestServer(t)
@@ -245,6 +557,9 @@ func TestStaticAssets(t *testing.T) {
}{
{"/static/pico.min.css", "text/css"},
{"/static/htmx.min.js", "text/javascript"},
{"/static/chart.min.js", "text/javascript"},
{"/static/dashboard.js", "text/javascript"},
{"/static/world.svg", "image/svg+xml"},
}
for _, tt := range tests {

View File

@@ -29,6 +29,21 @@ password = "admin"
# password = "banking"
# shell = "banking"
# [[auth.static_credentials]]
# username = "admin"
# password = "cisco"
# shell = "cisco"
# [[auth.static_credentials]]
# username = "irobot"
# password = "roomba"
# shell = "roomba"
# [[auth.static_credentials]]
# username = "player"
# password = "tetris"
# shell = "tetris"
[storage]
db_path = "oubliette.db"
retention_days = 90
@@ -37,12 +52,20 @@ retention_interval = "1h"
# [web]
# enabled = true
# listen_addr = ":8080"
# metrics_enabled = true
# metrics_token = "" # bearer token for /metrics; empty = no auth
[shell]
hostname = "ubuntu-server"
# banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
# fake_user = "" # override username in prompt; empty = use authenticated user
# Map usernames to specific shells (regardless of how auth succeeded).
# Credential-specific shell overrides take priority over username routes.
# [shell.username_routes]
# postgres = "psql"
# admin = "bash"
# Per-shell configuration (optional).
# [shell.banking]
# bank_name = "SECUREBANK"
@@ -52,6 +75,22 @@ hostname = "ubuntu-server"
# [shell.adventure]
# dungeon_name = "THE OUBLIETTE"
# [shell.cisco]
# hostname = "Router"
# model = "C2960"
# ios_version = "15.0(2)SE11"
# enable_password = "" # empty = accept after 1 failed attempt
# [shell.psql]
# db_name = "postgres"
# pg_version = "15.4"
# [shell.roomba]
# No configuration options currently.
# [shell.tetris]
# difficulty = "normal" # "easy" (slower start), "normal" (standard), "hard" (start at level 5)
# [detection]
# enabled = true
# threshold = 0.6 # 0.01.0, sessions above this trigger notifications

18
scripts/fetch-geoip.sh Executable file
View File

@@ -0,0 +1,18 @@
#!/usr/bin/env bash
# Downloads the DB-IP Lite country MMDB database for development.
# The Nix build fetches this automatically; this script is for local dev only.
set -euo pipefail
URL="https://download.db-ip.com/free/dbip-country-lite-2026-02.mmdb.gz"
DEST="internal/geoip/dbip-country-lite.mmdb"
cd "$(git rev-parse --show-toplevel)"
if [ -f "$DEST" ]; then
echo "GeoIP database already exists at $DEST"
exit 0
fi
echo "Downloading DB-IP Lite country database..."
curl -fSL "$URL" | gunzip > "$DEST"
echo "Saved to $DEST"