From 40fda3420c0f1df7e843b41d70369f6e498a822f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torjus=20H=C3=A5kestad?= Date: Sun, 15 Feb 2026 19:58:34 +0100 Subject: [PATCH] feat: add psql shell and username-to-shell routing Add a PostgreSQL psql interactive terminal shell with backslash meta-commands, SQL statement handling with multi-line buffering, and canned responses for common queries. Add username-based shell routing via [shell.username_routes] config (second priority after credential- specific shell, before random selection). Bump version to 0.13.0. Co-Authored-By: Claude Opus 4.6 --- PLAN.md | 10 +- README.md | 3 +- cmd/oubliette/main.go | 2 +- internal/config/config.go | 16 +- internal/config/config_test.go | 36 ++++ internal/server/server.go | 16 ++ internal/server/server_test.go | 84 ++++++++ internal/shell/psql/commands.go | 123 ++++++++++++ internal/shell/psql/output.go | 155 +++++++++++++++ internal/shell/psql/psql.go | 137 +++++++++++++ internal/shell/psql/psql_test.go | 330 +++++++++++++++++++++++++++++++ oubliette.toml.example | 10 + 12 files changed, 912 insertions(+), 10 deletions(-) create mode 100644 internal/shell/psql/commands.go create mode 100644 internal/shell/psql/output.go create mode 100644 internal/shell/psql/psql.go create mode 100644 internal/shell/psql/psql_test.go diff --git a/PLAN.md b/PLAN.md index 9607755..8ffe7b2 100644 --- a/PLAN.md +++ b/PLAN.md @@ -171,7 +171,15 @@ Goal: Add the entertaining shell implementations. ### 3.5 Banking TUI Shell ✅ - 80s-style green-on-black bank terminal -### 3.6 Other Shell Ideas (Future) +### 3.6 PostgreSQL psql Shell ✅ +- Simulates psql interactive terminal with `db_name` and `pg_version` config +- Backslash meta-commands: `\q`, `\dt`, `\d `, `\l`, `\du`, `\conninfo`, `\?`, `\h` +- SQL statement handling with multi-line buffering (semicolon-terminated) +- Canned responses for common queries (SELECT version(), current_database(), etc.) +- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.) +- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells + +### 3.7 Other Shell Ideas (Future) - **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE" - **ELIZA therapist:** every response is a therapy question - **Pizza ordering terminal:** "Welcome to PizzaNet v2.3" diff --git a/README.md b/README.md index 181a319..1ed5d4d 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,8 @@ Key settings: - `auth.accept_after` — accept login after N failures per IP (default `10`) - `auth.credential_ttl` — how long to remember accepted credentials (default `24h`) - `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell) -- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation) +- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal) +- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority - `storage.db_path` — SQLite database path (default `oubliette.db`) - `storage.retention_days` — auto-prune records older than N days (default `90`) - `storage.retention_interval` — how often to run retention (default `1h`) diff --git a/cmd/oubliette/main.go b/cmd/oubliette/main.go index a0b7ba6..9784561 100644 --- a/cmd/oubliette/main.go +++ b/cmd/oubliette/main.go @@ -20,7 +20,7 @@ import ( "git.t-juice.club/torjus/oubliette/internal/web" ) -const Version = "0.12.0" +const Version = "0.13.0" func main() { if err := run(); err != nil { diff --git a/internal/config/config.go b/internal/config/config.go index 751e53f..3d40213 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -28,10 +28,11 @@ type WebConfig struct { } type ShellConfig struct { - Hostname string `toml:"hostname"` - Banner string `toml:"banner"` - FakeUser string `toml:"fake_user"` - Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually + Hostname string `toml:"hostname"` + Banner string `toml:"banner"` + FakeUser string `toml:"fake_user"` + UsernameRoutes map[string]string `toml:"username_routes"` + Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually } type StorageConfig struct { @@ -165,9 +166,10 @@ func applyDefaults(cfg *Config) { // knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables. var knownShellKeys = map[string]bool{ - "hostname": true, - "banner": true, - "fake_user": true, + "hostname": true, + "banner": true, + "fake_user": true, + "username_routes": true, } // extractShellTables pulls per-shell config sub-tables from the raw [shell] section. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4abeae2..d6bfb5a 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -313,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) { } } +func TestLoadUsernameRoutes(t *testing.T) { + content := ` +[shell] +hostname = "myhost" + +[shell.username_routes] +postgres = "psql" +admin = "bash" + +[shell.bash] +custom_key = "value" +` + path := writeTemp(t, content) + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Shell.UsernameRoutes == nil { + t.Fatal("UsernameRoutes should not be nil") + } + if cfg.Shell.UsernameRoutes["postgres"] != "psql" { + t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql") + } + if cfg.Shell.UsernameRoutes["admin"] != "bash" { + t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash") + } + // username_routes should NOT appear in the Shells map. + if _, ok := cfg.Shell.Shells["username_routes"]; ok { + t.Error("username_routes should not appear in Shells map") + } + // bash should still appear in Shells map. + if _, ok := cfg.Shell.Shells["bash"]; !ok { + t.Error("Shells[\"bash\"] should still be present") + } +} + func writeTemp(t *testing.T, content string) string { t.Helper() path := filepath.Join(t.TempDir(), "config.toml") diff --git a/internal/server/server.go b/internal/server/server.go index c4dc38a..e4229c3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -24,6 +24,7 @@ import ( "git.t-juice.club/torjus/oubliette/internal/shell/bash" "git.t-juice.club/torjus/oubliette/internal/shell/cisco" "git.t-juice.club/torjus/oubliette/internal/shell/fridge" + psqlshell "git.t-juice.club/torjus/oubliette/internal/shell/psql" "git.t-juice.club/torjus/oubliette/internal/storage" "golang.org/x/crypto/ssh" ) @@ -58,6 +59,9 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil { return nil, fmt.Errorf("registering cisco shell: %w", err) } + if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil { + return nil, fmt.Errorf("registering psql shell: %w", err) + } geo, err := geoip.New() if err != nil { @@ -185,6 +189,18 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request s.logger.Warn("configured shell not found, falling back to random", "shell", shellName) } } + // Second priority: username-based route. + if selectedShell == nil { + if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok { + sh, found := s.shellRegistry.Get(shellName) + if found { + selectedShell = sh + } else { + s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User()) + } + } + } + // Lowest priority: random selection. if selectedShell == nil { var err error selectedShell, err = s.shellRegistry.Select() diff --git a/internal/server/server_test.go b/internal/server/server_test.go index a6b45cd..9795513 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "git.t-juice.club/torjus/oubliette/internal/auth" "git.t-juice.club/torjus/oubliette/internal/config" "git.t-juice.club/torjus/oubliette/internal/metrics" "git.t-juice.club/torjus/oubliette/internal/storage" @@ -300,6 +301,89 @@ func TestIntegrationSSHConnect(t *testing.T) { } }) + // Test username route: add username_routes so that "postgres" gets psql shell. + t.Run("username_route", func(t *testing.T) { + // Reconfigure with username routes. + srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"} + defer func() { srv.cfg.Shell.UsernameRoutes = nil }() + + // Need to get the "postgres" user in via static creds or threshold. + // Use static creds for simplicity. + srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials, + config.Credential{Username: "postgres", Password: "postgres"}, + ) + srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth) + defer func() { + srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1] + srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth) + }() + + clientCfg := &ssh.ClientConfig{ + User: "postgres", + Auth: []ssh.AuthMethod{ssh.Password("postgres")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + + client, err := ssh.Dial("tcp", addr, clientCfg) + if err != nil { + t.Fatalf("SSH dial: %v", err) + } + defer client.Close() + + session, err := client.NewSession() + if err != nil { + t.Fatalf("new session: %v", err) + } + defer session.Close() + + if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil { + t.Fatalf("request pty: %v", err) + } + + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("stdin pipe: %v", err) + } + + var output bytes.Buffer + session.Stdout = &output + + if err := session.Shell(); err != nil { + t.Fatalf("shell: %v", err) + } + + // Wait for the psql banner. + time.Sleep(500 * time.Millisecond) + + // Send \q to quit. + stdin.Write([]byte(`\q` + "\r")) + time.Sleep(200 * time.Millisecond) + + session.Wait() + + out := output.String() + if !strings.Contains(out, "psql") { + t.Errorf("output should contain psql banner, got: %s", out) + } + + // Verify session was created with shell name "psql". + sessions, err := store.GetRecentSessions(context.Background(), 50, false) + if err != nil { + t.Fatalf("GetRecentSessions: %v", err) + } + var foundPsql bool + for _, s := range sessions { + if s.ShellName == "psql" && s.Username == "postgres" { + foundPsql = true + break + } + } + if !foundPsql { + t.Error("expected a session with shell_name='psql' for user 'postgres'") + } + }) + // Test threshold acceptance: after enough failed dials, a subsequent // dial with the same credentials should succeed via threshold or // remembered credential. diff --git a/internal/shell/psql/commands.go b/internal/shell/psql/commands.go new file mode 100644 index 0000000..c75b523 --- /dev/null +++ b/internal/shell/psql/commands.go @@ -0,0 +1,123 @@ +package psql + +import ( + "fmt" + "strings" + "time" +) + +// commandResult holds the output of a command and whether the session should end. +type commandResult struct { + output string + exit bool +} + +// dispatchBackslash handles psql backslash meta-commands. +func dispatchBackslash(cmd, dbName string) commandResult { + // Normalize: trim spaces after the backslash command word. + parts := strings.Fields(cmd) + if len(parts) == 0 { + return commandResult{output: "Invalid command \\. Try \\? for help."} + } + + verb := parts[0] // e.g. `\q`, `\dt`, `\d` + args := parts[1:] + + switch verb { + case `\q`: + return commandResult{exit: true} + case `\dt`: + return commandResult{output: listTables()} + case `\d`: + if len(args) == 0 { + return commandResult{output: listTables()} + } + return commandResult{output: describeTable(args[0])} + case `\l`: + return commandResult{output: listDatabases()} + case `\du`: + return commandResult{output: listRoles()} + case `\conninfo`: + return commandResult{output: connInfo(dbName)} + case `\?`: + return commandResult{output: backslashHelp()} + case `\h`: + return commandResult{output: sqlHelp()} + default: + return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)} + } +} + +// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated). +func dispatchSQL(sql, dbName, pgVersion string) commandResult { + // Strip trailing semicolon and whitespace for matching. + trimmed := strings.TrimRight(sql, "; \t") + trimmed = strings.TrimSpace(trimmed) + upper := strings.ToUpper(trimmed) + + switch { + case upper == "SELECT VERSION()": + ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion) + return commandResult{output: formatSingleValue("version", ver)} + case upper == "SELECT CURRENT_DATABASE()": + return commandResult{output: formatSingleValue("current_database", dbName)} + case upper == "SELECT CURRENT_USER": + return commandResult{output: formatSingleValue("current_user", "postgres")} + case upper == "SELECT NOW()": + now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00") + return commandResult{output: formatSingleValue("now", now)} + case upper == "SELECT 1": + return commandResult{output: formatSingleValue("?column?", "1")} + case strings.HasPrefix(upper, "INSERT"): + return commandResult{output: "INSERT 0 1"} + case strings.HasPrefix(upper, "UPDATE"): + return commandResult{output: "UPDATE 1"} + case strings.HasPrefix(upper, "DELETE"): + return commandResult{output: "DELETE 1"} + case strings.HasPrefix(upper, "CREATE TABLE"): + return commandResult{output: "CREATE TABLE"} + case strings.HasPrefix(upper, "CREATE DATABASE"): + return commandResult{output: "CREATE DATABASE"} + case strings.HasPrefix(upper, "DROP TABLE"): + return commandResult{output: "DROP TABLE"} + case strings.HasPrefix(upper, "ALTER TABLE"): + return commandResult{output: "ALTER TABLE"} + case upper == "BEGIN": + return commandResult{output: "BEGIN"} + case upper == "COMMIT": + return commandResult{output: "COMMIT"} + case upper == "ROLLBACK": + return commandResult{output: "ROLLBACK"} + case upper == "SHOW SERVER_VERSION": + return commandResult{output: formatSingleValue("server_version", pgVersion)} + case upper == "SHOW SEARCH_PATH": + return commandResult{output: formatSingleValue("search_path", "\"$user\", public")} + case strings.HasPrefix(upper, "SET "): + return commandResult{output: "SET"} + default: + // Extract the first token for the error message. + firstToken := strings.Fields(trimmed) + token := trimmed + if len(firstToken) > 0 { + token = firstToken[0] + } + return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)} + } +} + +// formatSingleValue formats a single-row, single-column psql result. +func formatSingleValue(colName, value string) string { + width := max(len(colName), len(value)) + + var b strings.Builder + // Header + fmt.Fprintf(&b, " %-*s \n", width, colName) + // Separator + b.WriteString(strings.Repeat("-", width+2)) + b.WriteString("\n") + // Value + fmt.Fprintf(&b, " %-*s\n", width, value) + // Row count + b.WriteString("(1 row)") + return b.String() +} diff --git a/internal/shell/psql/output.go b/internal/shell/psql/output.go new file mode 100644 index 0000000..be87a6f --- /dev/null +++ b/internal/shell/psql/output.go @@ -0,0 +1,155 @@ +package psql + +import "fmt" + +func startupBanner(version string) string { + return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version) +} + +func listTables() string { + return ` List of relations + Schema | Name | Type | Owner +--------+---------------+-------+---------- + public | audit_log | table | postgres + public | credentials | table | postgres + public | sessions | table | postgres + public | users | table | postgres +(4 rows)` +} + +func listDatabases() string { + return ` List of databases + Name | Owner | Encoding | Collate | Ctype | Access privileges +-----------+----------+----------+-------------+-------------+----------------------- + app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | + postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | + template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres + + | | | | | postgres=CTc/postgres + template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres + + | | | | | postgres=CTc/postgres +(4 rows)` +} + +func listRoles() string { + return ` List of roles + Role name | Attributes | Member of +-----------+------------------------------------------------------------+----------- + app_user | | {} + postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {} + readonly | Cannot login | {}` +} + +func describeTable(name string) string { + switch name { + case "users": + return ` Table "public.users" + Column | Type | Collation | Nullable | Default +------------+-----------------------------+-----------+----------+----------------------------------- + id | integer | | not null | nextval('users_id_seq'::regclass) + username | character varying(255) | | not null | + email | character varying(255) | | not null | + password | character varying(255) | | not null | + created_at | timestamp without time zone | | | now() + updated_at | timestamp without time zone | | | now() +Indexes: + "users_pkey" PRIMARY KEY, btree (id) + "users_email_key" UNIQUE, btree (email) + "users_username_key" UNIQUE, btree (username)` + case "sessions": + return ` Table "public.sessions" + Column | Type | Collation | Nullable | Default +------------+-----------------------------+-----------+----------+-------------------------------------- + id | integer | | not null | nextval('sessions_id_seq'::regclass) + user_id | integer | | | + token | character varying(255) | | not null | + ip_address | inet | | | + created_at | timestamp without time zone | | | now() + expires_at | timestamp without time zone | | not null | +Indexes: + "sessions_pkey" PRIMARY KEY, btree (id) + "sessions_token_key" UNIQUE, btree (token) +Foreign-key constraints: + "sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)` + case "credentials": + return ` Table "public.credentials" + Column | Type | Collation | Nullable | Default +-----------+-----------------------------+-----------+----------+----------------------------------------- + id | integer | | not null | nextval('credentials_id_seq'::regclass) + user_id | integer | | | + type | character varying(50) | | not null | + value | text | | not null | + created_at| timestamp without time zone | | | now() +Indexes: + "credentials_pkey" PRIMARY KEY, btree (id) +Foreign-key constraints: + "credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)` + case "audit_log": + return ` Table "public.audit_log" + Column | Type | Collation | Nullable | Default +------------+-----------------------------+-----------+----------+--------------------------------------- + id | integer | | not null | nextval('audit_log_id_seq'::regclass) + user_id | integer | | | + action | character varying(100) | | not null | + details | text | | | + ip_address | inet | | | + created_at | timestamp without time zone | | | now() +Indexes: + "audit_log_pkey" PRIMARY KEY, btree (id) +Foreign-key constraints: + "audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)` + default: + return fmt.Sprintf("Did not find any relation named \"%s\".", name) + } +} + +func connInfo(dbName string) string { + return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName) +} + +func backslashHelp() string { + return `General + \copyright show PostgreSQL usage and distribution terms + \crosstabview [COLUMNS] execute query and display result in crosstab + \errverbose show most recent error message at maximum verbosity + \g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe) + \gdesc describe result of query, without executing it + \gexec execute query, then execute each value in its result + \gset [PREFIX] execute query and store result in psql variables + \gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode + \q quit psql + \watch [SEC] execute query every SEC seconds + +Informational + (options: S = show system objects, + = additional detail) + \d[S+] list tables, views, and sequences + \d[S+] NAME describe table, view, sequence, or index + \da[S] [PATTERN] list aggregates + \dA[+] [PATTERN] list access methods + \dt[S+] [PATTERN] list tables + \du[S+] [PATTERN] list roles + \l[+] [PATTERN] list databases` +} + +func sqlHelp() string { + return `Available help: + ABORT CREATE LANGUAGE + ALTER AGGREGATE CREATE MATERIALIZED VIEW + ALTER COLLATION CREATE OPERATOR + ALTER CONVERSION CREATE POLICY + ALTER DATABASE CREATE PROCEDURE + ALTER DEFAULT PRIVILEGES CREATE PUBLICATION + ALTER DOMAIN CREATE ROLE + ALTER EVENT TRIGGER CREATE RULE + ALTER EXTENSION CREATE SCHEMA + ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE + ALTER FOREIGN TABLE CREATE SERVER + ALTER FUNCTION CREATE STATISTICS + ALTER GROUP CREATE SUBSCRIPTION + ALTER INDEX CREATE TABLE + ALTER LANGUAGE CREATE TABLESPACE + BEGIN DELETE + COMMIT DROP TABLE + CREATE DATABASE INSERT + CREATE INDEX ROLLBACK + SELECT UPDATE` +} diff --git a/internal/shell/psql/psql.go b/internal/shell/psql/psql.go new file mode 100644 index 0000000..2aa8d07 --- /dev/null +++ b/internal/shell/psql/psql.go @@ -0,0 +1,137 @@ +package psql + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "time" + + "git.t-juice.club/torjus/oubliette/internal/shell" +) + +const sessionTimeout = 5 * time.Minute + +// PsqlShell emulates a PostgreSQL psql interactive terminal. +type PsqlShell struct{} + +// NewPsqlShell returns a new PsqlShell instance. +func NewPsqlShell() *PsqlShell { + return &PsqlShell{} +} + +func (p *PsqlShell) Name() string { return "psql" } +func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" } + +func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error { + ctx, cancel := context.WithTimeout(ctx, sessionTimeout) + defer cancel() + + dbName := configString(sess.ShellConfig, "db_name", "postgres") + pgVersion := configString(sess.ShellConfig, "pg_version", "15.4") + + // Print startup banner. + fmt.Fprint(rw, startupBanner(pgVersion)) + + var sqlBuf []string // accumulates multi-line SQL + + for { + prompt := buildPrompt(dbName, len(sqlBuf) > 0) + if _, err := fmt.Fprint(rw, prompt); err != nil { + return nil + } + + line, err := shell.ReadLine(ctx, rw) + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return nil + } + + trimmed := strings.TrimSpace(line) + + // Empty line in non-buffering state: just re-prompt. + if trimmed == "" && len(sqlBuf) == 0 { + continue + } + + // Backslash commands dispatch immediately (even mid-buffer they cancel the buffer). + if strings.HasPrefix(trimmed, `\`) { + sqlBuf = nil // discard any partial SQL + + result := dispatchBackslash(trimmed, dbName) + if result.output != "" { + output := strings.ReplaceAll(result.output, "\n", "\r\n") + fmt.Fprintf(rw, "%s\r\n", output) + } + + if sess.Store != nil { + if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil { + return fmt.Errorf("append session log: %w", err) + } + } + if sess.OnCommand != nil { + sess.OnCommand("psql") + } + + if result.exit { + return nil + } + continue + } + + // Accumulate SQL lines. + sqlBuf = append(sqlBuf, line) + + // Check if the statement is terminated by a semicolon. + if !strings.HasSuffix(strings.TrimSpace(line), ";") { + continue + } + + // Full statement ready — join and dispatch. + fullSQL := strings.Join(sqlBuf, " ") + sqlBuf = nil + + result := dispatchSQL(fullSQL, dbName, pgVersion) + if result.output != "" { + output := strings.ReplaceAll(result.output, "\n", "\r\n") + fmt.Fprintf(rw, "%s\r\n", output) + } + + if sess.Store != nil { + if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil { + return fmt.Errorf("append session log: %w", err) + } + } + if sess.OnCommand != nil { + sess.OnCommand("psql") + } + + if result.exit { + return nil + } + } +} + +// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL. +func buildPrompt(dbName string, continuation bool) string { + if continuation { + return dbName + "-# " + } + return dbName + "=# " +} + +// configString reads a string from the shell config map with a default. +func configString(cfg map[string]any, key, defaultVal string) string { + if cfg == nil { + return defaultVal + } + if v, ok := cfg[key]; ok { + if s, ok := v.(string); ok && s != "" { + return s + } + } + return defaultVal +} diff --git a/internal/shell/psql/psql_test.go b/internal/shell/psql/psql_test.go new file mode 100644 index 0000000..01c7baf --- /dev/null +++ b/internal/shell/psql/psql_test.go @@ -0,0 +1,330 @@ +package psql + +import ( + "strings" + "testing" +) + +// --- Prompt tests --- + +func TestBuildPromptNormal(t *testing.T) { + got := buildPrompt("postgres", false) + if got != "postgres=# " { + t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ") + } +} + +func TestBuildPromptContinuation(t *testing.T) { + got := buildPrompt("postgres", true) + if got != "postgres-# " { + t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ") + } +} + +func TestBuildPromptCustomDB(t *testing.T) { + got := buildPrompt("mydb", false) + if got != "mydb=# " { + t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ") + } +} + +// --- Backslash command dispatch tests --- + +func TestBackslashQuit(t *testing.T) { + result := dispatchBackslash(`\q`, "postgres") + if !result.exit { + t.Error("\\q should set exit=true") + } +} + +func TestBackslashListTables(t *testing.T) { + result := dispatchBackslash(`\dt`, "postgres") + if !strings.Contains(result.output, "users") { + t.Error("\\dt should list tables including 'users'") + } + if !strings.Contains(result.output, "sessions") { + t.Error("\\dt should list tables including 'sessions'") + } +} + +func TestBackslashDescribeTable(t *testing.T) { + result := dispatchBackslash(`\d users`, "postgres") + if !strings.Contains(result.output, "username") { + t.Error("\\d users should describe users table with 'username' column") + } + if !strings.Contains(result.output, "PRIMARY KEY") { + t.Error("\\d users should include index info") + } +} + +func TestBackslashDescribeUnknownTable(t *testing.T) { + result := dispatchBackslash(`\d nonexistent`, "postgres") + if !strings.Contains(result.output, "Did not find") { + t.Error("\\d nonexistent should return not found message") + } +} + +func TestBackslashListDatabases(t *testing.T) { + result := dispatchBackslash(`\l`, "postgres") + if !strings.Contains(result.output, "postgres") { + t.Error("\\l should list databases including 'postgres'") + } + if !strings.Contains(result.output, "template0") { + t.Error("\\l should list databases including 'template0'") + } +} + +func TestBackslashListRoles(t *testing.T) { + result := dispatchBackslash(`\du`, "postgres") + if !strings.Contains(result.output, "postgres") { + t.Error("\\du should list roles including 'postgres'") + } + if !strings.Contains(result.output, "Superuser") { + t.Error("\\du should show Superuser attribute for postgres") + } +} + +func TestBackslashConnInfo(t *testing.T) { + result := dispatchBackslash(`\conninfo`, "mydb") + if !strings.Contains(result.output, "mydb") { + t.Error("\\conninfo should include database name") + } + if !strings.Contains(result.output, "5432") { + t.Error("\\conninfo should include port") + } +} + +func TestBackslashHelp(t *testing.T) { + result := dispatchBackslash(`\?`, "postgres") + if !strings.Contains(result.output, `\q`) { + t.Error("\\? should include \\q in help output") + } +} + +func TestBackslashSQLHelp(t *testing.T) { + result := dispatchBackslash(`\h`, "postgres") + if !strings.Contains(result.output, "SELECT") { + t.Error("\\h should include SQL commands like SELECT") + } +} + +func TestBackslashUnknown(t *testing.T) { + result := dispatchBackslash(`\xyz`, "postgres") + if !strings.Contains(result.output, "Invalid command") { + t.Error("unknown backslash command should return error") + } +} + +// --- SQL dispatch tests --- + +func TestSQLSelectVersion(t *testing.T) { + result := dispatchSQL("SELECT version();", "postgres", "15.4") + if !strings.Contains(result.output, "15.4") { + t.Error("SELECT version() should contain pg version") + } + if !strings.Contains(result.output, "(1 row)") { + t.Error("SELECT version() should show row count") + } +} + +func TestSQLSelectCurrentDatabase(t *testing.T) { + result := dispatchSQL("SELECT current_database();", "mydb", "15.4") + if !strings.Contains(result.output, "mydb") { + t.Error("SELECT current_database() should return db name") + } +} + +func TestSQLSelectCurrentUser(t *testing.T) { + result := dispatchSQL("SELECT current_user;", "postgres", "15.4") + if !strings.Contains(result.output, "postgres") { + t.Error("SELECT current_user should return postgres") + } +} + +func TestSQLSelectNow(t *testing.T) { + result := dispatchSQL("SELECT now();", "postgres", "15.4") + if !strings.Contains(result.output, "(1 row)") { + t.Error("SELECT now() should show row count") + } +} + +func TestSQLSelectOne(t *testing.T) { + result := dispatchSQL("SELECT 1;", "postgres", "15.4") + if !strings.Contains(result.output, "1") { + t.Error("SELECT 1 should return 1") + } +} + +func TestSQLInsert(t *testing.T) { + result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4") + if result.output != "INSERT 0 1" { + t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1") + } +} + +func TestSQLUpdate(t *testing.T) { + result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4") + if result.output != "UPDATE 1" { + t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1") + } +} + +func TestSQLDelete(t *testing.T) { + result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4") + if result.output != "DELETE 1" { + t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1") + } +} + +func TestSQLCreateTable(t *testing.T) { + result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4") + if result.output != "CREATE TABLE" { + t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE") + } +} + +func TestSQLCreateDatabase(t *testing.T) { + result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4") + if result.output != "CREATE DATABASE" { + t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE") + } +} + +func TestSQLDropTable(t *testing.T) { + result := dispatchSQL("DROP TABLE test;", "postgres", "15.4") + if result.output != "DROP TABLE" { + t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE") + } +} + +func TestSQLAlterTable(t *testing.T) { + result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4") + if result.output != "ALTER TABLE" { + t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE") + } +} + +func TestSQLBeginCommitRollback(t *testing.T) { + tests := []struct { + sql string + want string + }{ + {"BEGIN;", "BEGIN"}, + {"COMMIT;", "COMMIT"}, + {"ROLLBACK;", "ROLLBACK"}, + } + for _, tt := range tests { + result := dispatchSQL(tt.sql, "postgres", "15.4") + if result.output != tt.want { + t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want) + } + } +} + +func TestSQLShowServerVersion(t *testing.T) { + result := dispatchSQL("SHOW server_version;", "postgres", "15.4") + if !strings.Contains(result.output, "15.4") { + t.Error("SHOW server_version should contain version") + } +} + +func TestSQLShowSearchPath(t *testing.T) { + result := dispatchSQL("SHOW search_path;", "postgres", "15.4") + if !strings.Contains(result.output, "public") { + t.Error("SHOW search_path should contain public") + } +} + +func TestSQLSet(t *testing.T) { + result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4") + if result.output != "SET" { + t.Errorf("SET output = %q, want %q", result.output, "SET") + } +} + +func TestSQLUnrecognized(t *testing.T) { + result := dispatchSQL("FOOBAR baz;", "postgres", "15.4") + if !strings.Contains(result.output, "ERROR") { + t.Error("unrecognized SQL should return error") + } + if !strings.Contains(result.output, "FOOBAR") { + t.Error("error should reference the offending token") + } +} + +// --- Case insensitivity --- + +func TestSQLCaseInsensitive(t *testing.T) { + result := dispatchSQL("select version();", "postgres", "15.4") + if !strings.Contains(result.output, "15.4") { + t.Error("select version() (lowercase) should work") + } + + result = dispatchSQL("Select Current_Database();", "mydb", "15.4") + if !strings.Contains(result.output, "mydb") { + t.Error("mixed case SELECT should work") + } +} + +// --- Startup banner --- + +func TestStartupBanner(t *testing.T) { + banner := startupBanner("15.4") + if !strings.Contains(banner, "psql (15.4)") { + t.Errorf("banner should contain version, got: %s", banner) + } + if !strings.Contains(banner, "help") { + t.Error("banner should mention help") + } +} + +// --- configString --- + +func TestConfigString(t *testing.T) { + cfg := map[string]any{"db_name": "mydb"} + if got := configString(cfg, "db_name", "postgres"); got != "mydb" { + t.Errorf("configString() = %q, want %q", got, "mydb") + } + if got := configString(cfg, "missing", "default"); got != "default" { + t.Errorf("configString() for missing = %q, want %q", got, "default") + } + if got := configString(nil, "key", "default"); got != "default" { + t.Errorf("configString(nil) = %q, want %q", got, "default") + } +} + +// --- Shell metadata --- + +func TestShellNameAndDescription(t *testing.T) { + s := NewPsqlShell() + if s.Name() != "psql" { + t.Errorf("Name() = %q, want %q", s.Name(), "psql") + } + if s.Description() == "" { + t.Error("Description() should not be empty") + } +} + +// --- formatSingleValue --- + +func TestFormatSingleValue(t *testing.T) { + out := formatSingleValue("?column?", "1") + if !strings.Contains(out, "?column?") { + t.Error("should contain column name") + } + if !strings.Contains(out, "1") { + t.Error("should contain value") + } + if !strings.Contains(out, "(1 row)") { + t.Error("should contain row count") + } +} + +// --- \d with no args --- + +func TestBackslashDescribeNoArgs(t *testing.T) { + result := dispatchBackslash(`\d`, "postgres") + if !strings.Contains(result.output, "users") { + t.Error("\\d with no args should list tables") + } +} diff --git a/oubliette.toml.example b/oubliette.toml.example index 3960dd4..c4b0254 100644 --- a/oubliette.toml.example +++ b/oubliette.toml.example @@ -50,6 +50,12 @@ hostname = "ubuntu-server" # banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n" # fake_user = "" # override username in prompt; empty = use authenticated user +# Map usernames to specific shells (regardless of how auth succeeded). +# Credential-specific shell overrides take priority over username routes. +# [shell.username_routes] +# postgres = "psql" +# admin = "bash" + # Per-shell configuration (optional). # [shell.banking] # bank_name = "SECUREBANK" @@ -65,6 +71,10 @@ hostname = "ubuntu-server" # ios_version = "15.0(2)SE11" # enable_password = "" # empty = accept after 1 failed attempt +# [shell.psql] +# db_name = "postgres" +# pg_version = "15.4" + # [detection] # enabled = true # threshold = 0.6 # 0.0–1.0, sessions above this trigger notifications