diff --git a/PLAN.md b/PLAN.md
index 9607755..8ffe7b2 100644
--- a/PLAN.md
+++ b/PLAN.md
@@ -171,7 +171,15 @@ Goal: Add the entertaining shell implementations.
### 3.5 Banking TUI Shell ✅
- 80s-style green-on-black bank terminal
-### 3.6 Other Shell Ideas (Future)
+### 3.6 PostgreSQL psql Shell ✅
+- Simulates psql interactive terminal with `db_name` and `pg_version` config
+- Backslash meta-commands: `\q`, `\dt`, `\d
`, `\l`, `\du`, `\conninfo`, `\?`, `\h`
+- SQL statement handling with multi-line buffering (semicolon-terminated)
+- Canned responses for common queries (SELECT version(), current_database(), etc.)
+- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.)
+- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells
+
+### 3.7 Other Shell Ideas (Future)
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
- **ELIZA therapist:** every response is a therapy question
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
diff --git a/README.md b/README.md
index 181a319..1ed5d4d 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,8 @@ Key settings:
- `auth.accept_after` — accept login after N failures per IP (default `10`)
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
- `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell)
-- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation)
+- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal)
+- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority
- `storage.db_path` — SQLite database path (default `oubliette.db`)
- `storage.retention_days` — auto-prune records older than N days (default `90`)
- `storage.retention_interval` — how often to run retention (default `1h`)
diff --git a/cmd/oubliette/main.go b/cmd/oubliette/main.go
index a0b7ba6..9784561 100644
--- a/cmd/oubliette/main.go
+++ b/cmd/oubliette/main.go
@@ -20,7 +20,7 @@ import (
"git.t-juice.club/torjus/oubliette/internal/web"
)
-const Version = "0.12.0"
+const Version = "0.13.0"
func main() {
if err := run(); err != nil {
diff --git a/internal/config/config.go b/internal/config/config.go
index 751e53f..3d40213 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -28,10 +28,11 @@ type WebConfig struct {
}
type ShellConfig struct {
- Hostname string `toml:"hostname"`
- Banner string `toml:"banner"`
- FakeUser string `toml:"fake_user"`
- Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
+ Hostname string `toml:"hostname"`
+ Banner string `toml:"banner"`
+ FakeUser string `toml:"fake_user"`
+ UsernameRoutes map[string]string `toml:"username_routes"`
+ Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
}
type StorageConfig struct {
@@ -165,9 +166,10 @@ func applyDefaults(cfg *Config) {
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
var knownShellKeys = map[string]bool{
- "hostname": true,
- "banner": true,
- "fake_user": true,
+ "hostname": true,
+ "banner": true,
+ "fake_user": true,
+ "username_routes": true,
}
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
index 4abeae2..d6bfb5a 100644
--- a/internal/config/config_test.go
+++ b/internal/config/config_test.go
@@ -313,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) {
}
}
+func TestLoadUsernameRoutes(t *testing.T) {
+ content := `
+[shell]
+hostname = "myhost"
+
+[shell.username_routes]
+postgres = "psql"
+admin = "bash"
+
+[shell.bash]
+custom_key = "value"
+`
+ path := writeTemp(t, content)
+ cfg, err := Load(path)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if cfg.Shell.UsernameRoutes == nil {
+ t.Fatal("UsernameRoutes should not be nil")
+ }
+ if cfg.Shell.UsernameRoutes["postgres"] != "psql" {
+ t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql")
+ }
+ if cfg.Shell.UsernameRoutes["admin"] != "bash" {
+ t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash")
+ }
+ // username_routes should NOT appear in the Shells map.
+ if _, ok := cfg.Shell.Shells["username_routes"]; ok {
+ t.Error("username_routes should not appear in Shells map")
+ }
+ // bash should still appear in Shells map.
+ if _, ok := cfg.Shell.Shells["bash"]; !ok {
+ t.Error("Shells[\"bash\"] should still be present")
+ }
+}
+
func writeTemp(t *testing.T, content string) string {
t.Helper()
path := filepath.Join(t.TempDir(), "config.toml")
diff --git a/internal/server/server.go b/internal/server/server.go
index c4dc38a..e4229c3 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -24,6 +24,7 @@ import (
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
"git.t-juice.club/torjus/oubliette/internal/shell/cisco"
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
+ psqlshell "git.t-juice.club/torjus/oubliette/internal/shell/psql"
"git.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
@@ -58,6 +59,9 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
return nil, fmt.Errorf("registering cisco shell: %w", err)
}
+ if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
+ return nil, fmt.Errorf("registering psql shell: %w", err)
+ }
geo, err := geoip.New()
if err != nil {
@@ -185,6 +189,18 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
}
}
+ // Second priority: username-based route.
+ if selectedShell == nil {
+ if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
+ sh, found := s.shellRegistry.Get(shellName)
+ if found {
+ selectedShell = sh
+ } else {
+ s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
+ }
+ }
+ }
+ // Lowest priority: random selection.
if selectedShell == nil {
var err error
selectedShell, err = s.shellRegistry.Select()
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index a6b45cd..9795513 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -11,6 +11,7 @@ import (
"testing"
"time"
+ "git.t-juice.club/torjus/oubliette/internal/auth"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/metrics"
"git.t-juice.club/torjus/oubliette/internal/storage"
@@ -300,6 +301,89 @@ func TestIntegrationSSHConnect(t *testing.T) {
}
})
+ // Test username route: add username_routes so that "postgres" gets psql shell.
+ t.Run("username_route", func(t *testing.T) {
+ // Reconfigure with username routes.
+ srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
+ defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
+
+ // Need to get the "postgres" user in via static creds or threshold.
+ // Use static creds for simplicity.
+ srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
+ config.Credential{Username: "postgres", Password: "postgres"},
+ )
+ srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
+ defer func() {
+ srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
+ srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
+ }()
+
+ clientCfg := &ssh.ClientConfig{
+ User: "postgres",
+ Auth: []ssh.AuthMethod{ssh.Password("postgres")},
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ Timeout: 5 * time.Second,
+ }
+
+ client, err := ssh.Dial("tcp", addr, clientCfg)
+ if err != nil {
+ t.Fatalf("SSH dial: %v", err)
+ }
+ defer client.Close()
+
+ session, err := client.NewSession()
+ if err != nil {
+ t.Fatalf("new session: %v", err)
+ }
+ defer session.Close()
+
+ if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
+ t.Fatalf("request pty: %v", err)
+ }
+
+ stdin, err := session.StdinPipe()
+ if err != nil {
+ t.Fatalf("stdin pipe: %v", err)
+ }
+
+ var output bytes.Buffer
+ session.Stdout = &output
+
+ if err := session.Shell(); err != nil {
+ t.Fatalf("shell: %v", err)
+ }
+
+ // Wait for the psql banner.
+ time.Sleep(500 * time.Millisecond)
+
+ // Send \q to quit.
+ stdin.Write([]byte(`\q` + "\r"))
+ time.Sleep(200 * time.Millisecond)
+
+ session.Wait()
+
+ out := output.String()
+ if !strings.Contains(out, "psql") {
+ t.Errorf("output should contain psql banner, got: %s", out)
+ }
+
+ // Verify session was created with shell name "psql".
+ sessions, err := store.GetRecentSessions(context.Background(), 50, false)
+ if err != nil {
+ t.Fatalf("GetRecentSessions: %v", err)
+ }
+ var foundPsql bool
+ for _, s := range sessions {
+ if s.ShellName == "psql" && s.Username == "postgres" {
+ foundPsql = true
+ break
+ }
+ }
+ if !foundPsql {
+ t.Error("expected a session with shell_name='psql' for user 'postgres'")
+ }
+ })
+
// Test threshold acceptance: after enough failed dials, a subsequent
// dial with the same credentials should succeed via threshold or
// remembered credential.
diff --git a/internal/shell/psql/commands.go b/internal/shell/psql/commands.go
new file mode 100644
index 0000000..c75b523
--- /dev/null
+++ b/internal/shell/psql/commands.go
@@ -0,0 +1,123 @@
+package psql
+
+import (
+ "fmt"
+ "strings"
+ "time"
+)
+
+// commandResult holds the output of a command and whether the session should end.
+type commandResult struct {
+ output string
+ exit bool
+}
+
+// dispatchBackslash handles psql backslash meta-commands.
+func dispatchBackslash(cmd, dbName string) commandResult {
+ // Normalize: trim spaces after the backslash command word.
+ parts := strings.Fields(cmd)
+ if len(parts) == 0 {
+ return commandResult{output: "Invalid command \\. Try \\? for help."}
+ }
+
+ verb := parts[0] // e.g. `\q`, `\dt`, `\d`
+ args := parts[1:]
+
+ switch verb {
+ case `\q`:
+ return commandResult{exit: true}
+ case `\dt`:
+ return commandResult{output: listTables()}
+ case `\d`:
+ if len(args) == 0 {
+ return commandResult{output: listTables()}
+ }
+ return commandResult{output: describeTable(args[0])}
+ case `\l`:
+ return commandResult{output: listDatabases()}
+ case `\du`:
+ return commandResult{output: listRoles()}
+ case `\conninfo`:
+ return commandResult{output: connInfo(dbName)}
+ case `\?`:
+ return commandResult{output: backslashHelp()}
+ case `\h`:
+ return commandResult{output: sqlHelp()}
+ default:
+ return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
+ }
+}
+
+// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
+func dispatchSQL(sql, dbName, pgVersion string) commandResult {
+ // Strip trailing semicolon and whitespace for matching.
+ trimmed := strings.TrimRight(sql, "; \t")
+ trimmed = strings.TrimSpace(trimmed)
+ upper := strings.ToUpper(trimmed)
+
+ switch {
+ case upper == "SELECT VERSION()":
+ ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
+ return commandResult{output: formatSingleValue("version", ver)}
+ case upper == "SELECT CURRENT_DATABASE()":
+ return commandResult{output: formatSingleValue("current_database", dbName)}
+ case upper == "SELECT CURRENT_USER":
+ return commandResult{output: formatSingleValue("current_user", "postgres")}
+ case upper == "SELECT NOW()":
+ now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
+ return commandResult{output: formatSingleValue("now", now)}
+ case upper == "SELECT 1":
+ return commandResult{output: formatSingleValue("?column?", "1")}
+ case strings.HasPrefix(upper, "INSERT"):
+ return commandResult{output: "INSERT 0 1"}
+ case strings.HasPrefix(upper, "UPDATE"):
+ return commandResult{output: "UPDATE 1"}
+ case strings.HasPrefix(upper, "DELETE"):
+ return commandResult{output: "DELETE 1"}
+ case strings.HasPrefix(upper, "CREATE TABLE"):
+ return commandResult{output: "CREATE TABLE"}
+ case strings.HasPrefix(upper, "CREATE DATABASE"):
+ return commandResult{output: "CREATE DATABASE"}
+ case strings.HasPrefix(upper, "DROP TABLE"):
+ return commandResult{output: "DROP TABLE"}
+ case strings.HasPrefix(upper, "ALTER TABLE"):
+ return commandResult{output: "ALTER TABLE"}
+ case upper == "BEGIN":
+ return commandResult{output: "BEGIN"}
+ case upper == "COMMIT":
+ return commandResult{output: "COMMIT"}
+ case upper == "ROLLBACK":
+ return commandResult{output: "ROLLBACK"}
+ case upper == "SHOW SERVER_VERSION":
+ return commandResult{output: formatSingleValue("server_version", pgVersion)}
+ case upper == "SHOW SEARCH_PATH":
+ return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
+ case strings.HasPrefix(upper, "SET "):
+ return commandResult{output: "SET"}
+ default:
+ // Extract the first token for the error message.
+ firstToken := strings.Fields(trimmed)
+ token := trimmed
+ if len(firstToken) > 0 {
+ token = firstToken[0]
+ }
+ return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
+ }
+}
+
+// formatSingleValue formats a single-row, single-column psql result.
+func formatSingleValue(colName, value string) string {
+ width := max(len(colName), len(value))
+
+ var b strings.Builder
+ // Header
+ fmt.Fprintf(&b, " %-*s \n", width, colName)
+ // Separator
+ b.WriteString(strings.Repeat("-", width+2))
+ b.WriteString("\n")
+ // Value
+ fmt.Fprintf(&b, " %-*s\n", width, value)
+ // Row count
+ b.WriteString("(1 row)")
+ return b.String()
+}
diff --git a/internal/shell/psql/output.go b/internal/shell/psql/output.go
new file mode 100644
index 0000000..be87a6f
--- /dev/null
+++ b/internal/shell/psql/output.go
@@ -0,0 +1,155 @@
+package psql
+
+import "fmt"
+
+func startupBanner(version string) string {
+ return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version)
+}
+
+func listTables() string {
+ return ` List of relations
+ Schema | Name | Type | Owner
+--------+---------------+-------+----------
+ public | audit_log | table | postgres
+ public | credentials | table | postgres
+ public | sessions | table | postgres
+ public | users | table | postgres
+(4 rows)`
+}
+
+func listDatabases() string {
+ return ` List of databases
+ Name | Owner | Encoding | Collate | Ctype | Access privileges
+-----------+----------+----------+-------------+-------------+-----------------------
+ app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
+ postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
+ template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
+ | | | | | postgres=CTc/postgres
+ template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
+ | | | | | postgres=CTc/postgres
+(4 rows)`
+}
+
+func listRoles() string {
+ return ` List of roles
+ Role name | Attributes | Member of
+-----------+------------------------------------------------------------+-----------
+ app_user | | {}
+ postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {}
+ readonly | Cannot login | {}`
+}
+
+func describeTable(name string) string {
+ switch name {
+ case "users":
+ return ` Table "public.users"
+ Column | Type | Collation | Nullable | Default
+------------+-----------------------------+-----------+----------+-----------------------------------
+ id | integer | | not null | nextval('users_id_seq'::regclass)
+ username | character varying(255) | | not null |
+ email | character varying(255) | | not null |
+ password | character varying(255) | | not null |
+ created_at | timestamp without time zone | | | now()
+ updated_at | timestamp without time zone | | | now()
+Indexes:
+ "users_pkey" PRIMARY KEY, btree (id)
+ "users_email_key" UNIQUE, btree (email)
+ "users_username_key" UNIQUE, btree (username)`
+ case "sessions":
+ return ` Table "public.sessions"
+ Column | Type | Collation | Nullable | Default
+------------+-----------------------------+-----------+----------+--------------------------------------
+ id | integer | | not null | nextval('sessions_id_seq'::regclass)
+ user_id | integer | | |
+ token | character varying(255) | | not null |
+ ip_address | inet | | |
+ created_at | timestamp without time zone | | | now()
+ expires_at | timestamp without time zone | | not null |
+Indexes:
+ "sessions_pkey" PRIMARY KEY, btree (id)
+ "sessions_token_key" UNIQUE, btree (token)
+Foreign-key constraints:
+ "sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
+ case "credentials":
+ return ` Table "public.credentials"
+ Column | Type | Collation | Nullable | Default
+-----------+-----------------------------+-----------+----------+-----------------------------------------
+ id | integer | | not null | nextval('credentials_id_seq'::regclass)
+ user_id | integer | | |
+ type | character varying(50) | | not null |
+ value | text | | not null |
+ created_at| timestamp without time zone | | | now()
+Indexes:
+ "credentials_pkey" PRIMARY KEY, btree (id)
+Foreign-key constraints:
+ "credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
+ case "audit_log":
+ return ` Table "public.audit_log"
+ Column | Type | Collation | Nullable | Default
+------------+-----------------------------+-----------+----------+---------------------------------------
+ id | integer | | not null | nextval('audit_log_id_seq'::regclass)
+ user_id | integer | | |
+ action | character varying(100) | | not null |
+ details | text | | |
+ ip_address | inet | | |
+ created_at | timestamp without time zone | | | now()
+Indexes:
+ "audit_log_pkey" PRIMARY KEY, btree (id)
+Foreign-key constraints:
+ "audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
+ default:
+ return fmt.Sprintf("Did not find any relation named \"%s\".", name)
+ }
+}
+
+func connInfo(dbName string) string {
+ return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName)
+}
+
+func backslashHelp() string {
+ return `General
+ \copyright show PostgreSQL usage and distribution terms
+ \crosstabview [COLUMNS] execute query and display result in crosstab
+ \errverbose show most recent error message at maximum verbosity
+ \g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe)
+ \gdesc describe result of query, without executing it
+ \gexec execute query, then execute each value in its result
+ \gset [PREFIX] execute query and store result in psql variables
+ \gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode
+ \q quit psql
+ \watch [SEC] execute query every SEC seconds
+
+Informational
+ (options: S = show system objects, + = additional detail)
+ \d[S+] list tables, views, and sequences
+ \d[S+] NAME describe table, view, sequence, or index
+ \da[S] [PATTERN] list aggregates
+ \dA[+] [PATTERN] list access methods
+ \dt[S+] [PATTERN] list tables
+ \du[S+] [PATTERN] list roles
+ \l[+] [PATTERN] list databases`
+}
+
+func sqlHelp() string {
+ return `Available help:
+ ABORT CREATE LANGUAGE
+ ALTER AGGREGATE CREATE MATERIALIZED VIEW
+ ALTER COLLATION CREATE OPERATOR
+ ALTER CONVERSION CREATE POLICY
+ ALTER DATABASE CREATE PROCEDURE
+ ALTER DEFAULT PRIVILEGES CREATE PUBLICATION
+ ALTER DOMAIN CREATE ROLE
+ ALTER EVENT TRIGGER CREATE RULE
+ ALTER EXTENSION CREATE SCHEMA
+ ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE
+ ALTER FOREIGN TABLE CREATE SERVER
+ ALTER FUNCTION CREATE STATISTICS
+ ALTER GROUP CREATE SUBSCRIPTION
+ ALTER INDEX CREATE TABLE
+ ALTER LANGUAGE CREATE TABLESPACE
+ BEGIN DELETE
+ COMMIT DROP TABLE
+ CREATE DATABASE INSERT
+ CREATE INDEX ROLLBACK
+ SELECT UPDATE`
+}
diff --git a/internal/shell/psql/psql.go b/internal/shell/psql/psql.go
new file mode 100644
index 0000000..2aa8d07
--- /dev/null
+++ b/internal/shell/psql/psql.go
@@ -0,0 +1,137 @@
+package psql
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "strings"
+ "time"
+
+ "git.t-juice.club/torjus/oubliette/internal/shell"
+)
+
+const sessionTimeout = 5 * time.Minute
+
+// PsqlShell emulates a PostgreSQL psql interactive terminal.
+type PsqlShell struct{}
+
+// NewPsqlShell returns a new PsqlShell instance.
+func NewPsqlShell() *PsqlShell {
+ return &PsqlShell{}
+}
+
+func (p *PsqlShell) Name() string { return "psql" }
+func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" }
+
+func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
+ ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
+ defer cancel()
+
+ dbName := configString(sess.ShellConfig, "db_name", "postgres")
+ pgVersion := configString(sess.ShellConfig, "pg_version", "15.4")
+
+ // Print startup banner.
+ fmt.Fprint(rw, startupBanner(pgVersion))
+
+ var sqlBuf []string // accumulates multi-line SQL
+
+ for {
+ prompt := buildPrompt(dbName, len(sqlBuf) > 0)
+ if _, err := fmt.Fprint(rw, prompt); err != nil {
+ return nil
+ }
+
+ line, err := shell.ReadLine(ctx, rw)
+ if errors.Is(err, io.EOF) {
+ return nil
+ }
+ if err != nil {
+ return nil
+ }
+
+ trimmed := strings.TrimSpace(line)
+
+ // Empty line in non-buffering state: just re-prompt.
+ if trimmed == "" && len(sqlBuf) == 0 {
+ continue
+ }
+
+ // Backslash commands dispatch immediately (even mid-buffer they cancel the buffer).
+ if strings.HasPrefix(trimmed, `\`) {
+ sqlBuf = nil // discard any partial SQL
+
+ result := dispatchBackslash(trimmed, dbName)
+ if result.output != "" {
+ output := strings.ReplaceAll(result.output, "\n", "\r\n")
+ fmt.Fprintf(rw, "%s\r\n", output)
+ }
+
+ if sess.Store != nil {
+ if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil {
+ return fmt.Errorf("append session log: %w", err)
+ }
+ }
+ if sess.OnCommand != nil {
+ sess.OnCommand("psql")
+ }
+
+ if result.exit {
+ return nil
+ }
+ continue
+ }
+
+ // Accumulate SQL lines.
+ sqlBuf = append(sqlBuf, line)
+
+ // Check if the statement is terminated by a semicolon.
+ if !strings.HasSuffix(strings.TrimSpace(line), ";") {
+ continue
+ }
+
+ // Full statement ready — join and dispatch.
+ fullSQL := strings.Join(sqlBuf, " ")
+ sqlBuf = nil
+
+ result := dispatchSQL(fullSQL, dbName, pgVersion)
+ if result.output != "" {
+ output := strings.ReplaceAll(result.output, "\n", "\r\n")
+ fmt.Fprintf(rw, "%s\r\n", output)
+ }
+
+ if sess.Store != nil {
+ if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil {
+ return fmt.Errorf("append session log: %w", err)
+ }
+ }
+ if sess.OnCommand != nil {
+ sess.OnCommand("psql")
+ }
+
+ if result.exit {
+ return nil
+ }
+ }
+}
+
+// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL.
+func buildPrompt(dbName string, continuation bool) string {
+ if continuation {
+ return dbName + "-# "
+ }
+ return dbName + "=# "
+}
+
+// configString reads a string from the shell config map with a default.
+func configString(cfg map[string]any, key, defaultVal string) string {
+ if cfg == nil {
+ return defaultVal
+ }
+ if v, ok := cfg[key]; ok {
+ if s, ok := v.(string); ok && s != "" {
+ return s
+ }
+ }
+ return defaultVal
+}
diff --git a/internal/shell/psql/psql_test.go b/internal/shell/psql/psql_test.go
new file mode 100644
index 0000000..01c7baf
--- /dev/null
+++ b/internal/shell/psql/psql_test.go
@@ -0,0 +1,330 @@
+package psql
+
+import (
+ "strings"
+ "testing"
+)
+
+// --- Prompt tests ---
+
+func TestBuildPromptNormal(t *testing.T) {
+ got := buildPrompt("postgres", false)
+ if got != "postgres=# " {
+ t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
+ }
+}
+
+func TestBuildPromptContinuation(t *testing.T) {
+ got := buildPrompt("postgres", true)
+ if got != "postgres-# " {
+ t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
+ }
+}
+
+func TestBuildPromptCustomDB(t *testing.T) {
+ got := buildPrompt("mydb", false)
+ if got != "mydb=# " {
+ t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
+ }
+}
+
+// --- Backslash command dispatch tests ---
+
+func TestBackslashQuit(t *testing.T) {
+ result := dispatchBackslash(`\q`, "postgres")
+ if !result.exit {
+ t.Error("\\q should set exit=true")
+ }
+}
+
+func TestBackslashListTables(t *testing.T) {
+ result := dispatchBackslash(`\dt`, "postgres")
+ if !strings.Contains(result.output, "users") {
+ t.Error("\\dt should list tables including 'users'")
+ }
+ if !strings.Contains(result.output, "sessions") {
+ t.Error("\\dt should list tables including 'sessions'")
+ }
+}
+
+func TestBackslashDescribeTable(t *testing.T) {
+ result := dispatchBackslash(`\d users`, "postgres")
+ if !strings.Contains(result.output, "username") {
+ t.Error("\\d users should describe users table with 'username' column")
+ }
+ if !strings.Contains(result.output, "PRIMARY KEY") {
+ t.Error("\\d users should include index info")
+ }
+}
+
+func TestBackslashDescribeUnknownTable(t *testing.T) {
+ result := dispatchBackslash(`\d nonexistent`, "postgres")
+ if !strings.Contains(result.output, "Did not find") {
+ t.Error("\\d nonexistent should return not found message")
+ }
+}
+
+func TestBackslashListDatabases(t *testing.T) {
+ result := dispatchBackslash(`\l`, "postgres")
+ if !strings.Contains(result.output, "postgres") {
+ t.Error("\\l should list databases including 'postgres'")
+ }
+ if !strings.Contains(result.output, "template0") {
+ t.Error("\\l should list databases including 'template0'")
+ }
+}
+
+func TestBackslashListRoles(t *testing.T) {
+ result := dispatchBackslash(`\du`, "postgres")
+ if !strings.Contains(result.output, "postgres") {
+ t.Error("\\du should list roles including 'postgres'")
+ }
+ if !strings.Contains(result.output, "Superuser") {
+ t.Error("\\du should show Superuser attribute for postgres")
+ }
+}
+
+func TestBackslashConnInfo(t *testing.T) {
+ result := dispatchBackslash(`\conninfo`, "mydb")
+ if !strings.Contains(result.output, "mydb") {
+ t.Error("\\conninfo should include database name")
+ }
+ if !strings.Contains(result.output, "5432") {
+ t.Error("\\conninfo should include port")
+ }
+}
+
+func TestBackslashHelp(t *testing.T) {
+ result := dispatchBackslash(`\?`, "postgres")
+ if !strings.Contains(result.output, `\q`) {
+ t.Error("\\? should include \\q in help output")
+ }
+}
+
+func TestBackslashSQLHelp(t *testing.T) {
+ result := dispatchBackslash(`\h`, "postgres")
+ if !strings.Contains(result.output, "SELECT") {
+ t.Error("\\h should include SQL commands like SELECT")
+ }
+}
+
+func TestBackslashUnknown(t *testing.T) {
+ result := dispatchBackslash(`\xyz`, "postgres")
+ if !strings.Contains(result.output, "Invalid command") {
+ t.Error("unknown backslash command should return error")
+ }
+}
+
+// --- SQL dispatch tests ---
+
+func TestSQLSelectVersion(t *testing.T) {
+ result := dispatchSQL("SELECT version();", "postgres", "15.4")
+ if !strings.Contains(result.output, "15.4") {
+ t.Error("SELECT version() should contain pg version")
+ }
+ if !strings.Contains(result.output, "(1 row)") {
+ t.Error("SELECT version() should show row count")
+ }
+}
+
+func TestSQLSelectCurrentDatabase(t *testing.T) {
+ result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
+ if !strings.Contains(result.output, "mydb") {
+ t.Error("SELECT current_database() should return db name")
+ }
+}
+
+func TestSQLSelectCurrentUser(t *testing.T) {
+ result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
+ if !strings.Contains(result.output, "postgres") {
+ t.Error("SELECT current_user should return postgres")
+ }
+}
+
+func TestSQLSelectNow(t *testing.T) {
+ result := dispatchSQL("SELECT now();", "postgres", "15.4")
+ if !strings.Contains(result.output, "(1 row)") {
+ t.Error("SELECT now() should show row count")
+ }
+}
+
+func TestSQLSelectOne(t *testing.T) {
+ result := dispatchSQL("SELECT 1;", "postgres", "15.4")
+ if !strings.Contains(result.output, "1") {
+ t.Error("SELECT 1 should return 1")
+ }
+}
+
+func TestSQLInsert(t *testing.T) {
+ result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
+ if result.output != "INSERT 0 1" {
+ t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
+ }
+}
+
+func TestSQLUpdate(t *testing.T) {
+ result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
+ if result.output != "UPDATE 1" {
+ t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
+ }
+}
+
+func TestSQLDelete(t *testing.T) {
+ result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
+ if result.output != "DELETE 1" {
+ t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
+ }
+}
+
+func TestSQLCreateTable(t *testing.T) {
+ result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
+ if result.output != "CREATE TABLE" {
+ t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
+ }
+}
+
+func TestSQLCreateDatabase(t *testing.T) {
+ result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
+ if result.output != "CREATE DATABASE" {
+ t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
+ }
+}
+
+func TestSQLDropTable(t *testing.T) {
+ result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
+ if result.output != "DROP TABLE" {
+ t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
+ }
+}
+
+func TestSQLAlterTable(t *testing.T) {
+ result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
+ if result.output != "ALTER TABLE" {
+ t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
+ }
+}
+
+func TestSQLBeginCommitRollback(t *testing.T) {
+ tests := []struct {
+ sql string
+ want string
+ }{
+ {"BEGIN;", "BEGIN"},
+ {"COMMIT;", "COMMIT"},
+ {"ROLLBACK;", "ROLLBACK"},
+ }
+ for _, tt := range tests {
+ result := dispatchSQL(tt.sql, "postgres", "15.4")
+ if result.output != tt.want {
+ t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
+ }
+ }
+}
+
+func TestSQLShowServerVersion(t *testing.T) {
+ result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
+ if !strings.Contains(result.output, "15.4") {
+ t.Error("SHOW server_version should contain version")
+ }
+}
+
+func TestSQLShowSearchPath(t *testing.T) {
+ result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
+ if !strings.Contains(result.output, "public") {
+ t.Error("SHOW search_path should contain public")
+ }
+}
+
+func TestSQLSet(t *testing.T) {
+ result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
+ if result.output != "SET" {
+ t.Errorf("SET output = %q, want %q", result.output, "SET")
+ }
+}
+
+func TestSQLUnrecognized(t *testing.T) {
+ result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
+ if !strings.Contains(result.output, "ERROR") {
+ t.Error("unrecognized SQL should return error")
+ }
+ if !strings.Contains(result.output, "FOOBAR") {
+ t.Error("error should reference the offending token")
+ }
+}
+
+// --- Case insensitivity ---
+
+func TestSQLCaseInsensitive(t *testing.T) {
+ result := dispatchSQL("select version();", "postgres", "15.4")
+ if !strings.Contains(result.output, "15.4") {
+ t.Error("select version() (lowercase) should work")
+ }
+
+ result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
+ if !strings.Contains(result.output, "mydb") {
+ t.Error("mixed case SELECT should work")
+ }
+}
+
+// --- Startup banner ---
+
+func TestStartupBanner(t *testing.T) {
+ banner := startupBanner("15.4")
+ if !strings.Contains(banner, "psql (15.4)") {
+ t.Errorf("banner should contain version, got: %s", banner)
+ }
+ if !strings.Contains(banner, "help") {
+ t.Error("banner should mention help")
+ }
+}
+
+// --- configString ---
+
+func TestConfigString(t *testing.T) {
+ cfg := map[string]any{"db_name": "mydb"}
+ if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
+ t.Errorf("configString() = %q, want %q", got, "mydb")
+ }
+ if got := configString(cfg, "missing", "default"); got != "default" {
+ t.Errorf("configString() for missing = %q, want %q", got, "default")
+ }
+ if got := configString(nil, "key", "default"); got != "default" {
+ t.Errorf("configString(nil) = %q, want %q", got, "default")
+ }
+}
+
+// --- Shell metadata ---
+
+func TestShellNameAndDescription(t *testing.T) {
+ s := NewPsqlShell()
+ if s.Name() != "psql" {
+ t.Errorf("Name() = %q, want %q", s.Name(), "psql")
+ }
+ if s.Description() == "" {
+ t.Error("Description() should not be empty")
+ }
+}
+
+// --- formatSingleValue ---
+
+func TestFormatSingleValue(t *testing.T) {
+ out := formatSingleValue("?column?", "1")
+ if !strings.Contains(out, "?column?") {
+ t.Error("should contain column name")
+ }
+ if !strings.Contains(out, "1") {
+ t.Error("should contain value")
+ }
+ if !strings.Contains(out, "(1 row)") {
+ t.Error("should contain row count")
+ }
+}
+
+// --- \d with no args ---
+
+func TestBackslashDescribeNoArgs(t *testing.T) {
+ result := dispatchBackslash(`\d`, "postgres")
+ if !strings.Contains(result.output, "users") {
+ t.Error("\\d with no args should list tables")
+ }
+}
diff --git a/oubliette.toml.example b/oubliette.toml.example
index 3960dd4..c4b0254 100644
--- a/oubliette.toml.example
+++ b/oubliette.toml.example
@@ -50,6 +50,12 @@ hostname = "ubuntu-server"
# banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
# fake_user = "" # override username in prompt; empty = use authenticated user
+# Map usernames to specific shells (regardless of how auth succeeded).
+# Credential-specific shell overrides take priority over username routes.
+# [shell.username_routes]
+# postgres = "psql"
+# admin = "bash"
+
# Per-shell configuration (optional).
# [shell.banking]
# bank_name = "SECUREBANK"
@@ -65,6 +71,10 @@ hostname = "ubuntu-server"
# ios_version = "15.0(2)SE11"
# enable_password = "" # empty = accept after 1 failed attempt
+# [shell.psql]
+# db_name = "postgres"
+# pg_version = "15.4"
+
# [detection]
# enabled = true
# threshold = 0.6 # 0.0–1.0, sessions above this trigger notifications