feat: add psql shell and username-to-shell routing
Add a PostgreSQL psql interactive terminal shell with backslash meta-commands, SQL statement handling with multi-line buffering, and canned responses for common queries. Add username-based shell routing via [shell.username_routes] config (second priority after credential- specific shell, before random selection). Bump version to 0.13.0. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -28,10 +28,11 @@ type WebConfig struct {
|
||||
}
|
||||
|
||||
type ShellConfig struct {
|
||||
Hostname string `toml:"hostname"`
|
||||
Banner string `toml:"banner"`
|
||||
FakeUser string `toml:"fake_user"`
|
||||
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||
Hostname string `toml:"hostname"`
|
||||
Banner string `toml:"banner"`
|
||||
FakeUser string `toml:"fake_user"`
|
||||
UsernameRoutes map[string]string `toml:"username_routes"`
|
||||
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
@@ -165,9 +166,10 @@ func applyDefaults(cfg *Config) {
|
||||
|
||||
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
|
||||
var knownShellKeys = map[string]bool{
|
||||
"hostname": true,
|
||||
"banner": true,
|
||||
"fake_user": true,
|
||||
"hostname": true,
|
||||
"banner": true,
|
||||
"fake_user": true,
|
||||
"username_routes": true,
|
||||
}
|
||||
|
||||
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.
|
||||
|
||||
@@ -313,6 +313,42 @@ func TestLoadInvalidTOML(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadUsernameRoutes(t *testing.T) {
|
||||
content := `
|
||||
[shell]
|
||||
hostname = "myhost"
|
||||
|
||||
[shell.username_routes]
|
||||
postgres = "psql"
|
||||
admin = "bash"
|
||||
|
||||
[shell.bash]
|
||||
custom_key = "value"
|
||||
`
|
||||
path := writeTemp(t, content)
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes == nil {
|
||||
t.Fatal("UsernameRoutes should not be nil")
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes["postgres"] != "psql" {
|
||||
t.Errorf("UsernameRoutes[\"postgres\"] = %q, want %q", cfg.Shell.UsernameRoutes["postgres"], "psql")
|
||||
}
|
||||
if cfg.Shell.UsernameRoutes["admin"] != "bash" {
|
||||
t.Errorf("UsernameRoutes[\"admin\"] = %q, want %q", cfg.Shell.UsernameRoutes["admin"], "bash")
|
||||
}
|
||||
// username_routes should NOT appear in the Shells map.
|
||||
if _, ok := cfg.Shell.Shells["username_routes"]; ok {
|
||||
t.Error("username_routes should not appear in Shells map")
|
||||
}
|
||||
// bash should still appear in Shells map.
|
||||
if _, ok := cfg.Shell.Shells["bash"]; !ok {
|
||||
t.Error("Shells[\"bash\"] should still be present")
|
||||
}
|
||||
}
|
||||
|
||||
func writeTemp(t *testing.T, content string) string {
|
||||
t.Helper()
|
||||
path := filepath.Join(t.TempDir(), "config.toml")
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
||||
psqlshell "git.t-juice.club/torjus/oubliette/internal/shell/psql"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
@@ -58,6 +59,9 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics
|
||||
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering cisco shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering psql shell: %w", err)
|
||||
}
|
||||
|
||||
geo, err := geoip.New()
|
||||
if err != nil {
|
||||
@@ -185,6 +189,18 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
|
||||
}
|
||||
}
|
||||
// Second priority: username-based route.
|
||||
if selectedShell == nil {
|
||||
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
|
||||
sh, found := s.shellRegistry.Get(shellName)
|
||||
if found {
|
||||
selectedShell = sh
|
||||
} else {
|
||||
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
|
||||
}
|
||||
}
|
||||
}
|
||||
// Lowest priority: random selection.
|
||||
if selectedShell == nil {
|
||||
var err error
|
||||
selectedShell, err = s.shellRegistry.Select()
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
@@ -300,6 +301,89 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Test username route: add username_routes so that "postgres" gets psql shell.
|
||||
t.Run("username_route", func(t *testing.T) {
|
||||
// Reconfigure with username routes.
|
||||
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
|
||||
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
|
||||
|
||||
// Need to get the "postgres" user in via static creds or threshold.
|
||||
// Use static creds for simplicity.
|
||||
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
|
||||
config.Credential{Username: "postgres", Password: "postgres"},
|
||||
)
|
||||
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||
defer func() {
|
||||
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
|
||||
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||
}()
|
||||
|
||||
clientCfg := &ssh.ClientConfig{
|
||||
User: "postgres",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", addr, clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SSH dial: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("new session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
||||
t.Fatalf("request pty: %v", err)
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("stdin pipe: %v", err)
|
||||
}
|
||||
|
||||
var output bytes.Buffer
|
||||
session.Stdout = &output
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatalf("shell: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the psql banner.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Send \q to quit.
|
||||
stdin.Write([]byte(`\q` + "\r"))
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
session.Wait()
|
||||
|
||||
out := output.String()
|
||||
if !strings.Contains(out, "psql") {
|
||||
t.Errorf("output should contain psql banner, got: %s", out)
|
||||
}
|
||||
|
||||
// Verify session was created with shell name "psql".
|
||||
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
var foundPsql bool
|
||||
for _, s := range sessions {
|
||||
if s.ShellName == "psql" && s.Username == "postgres" {
|
||||
foundPsql = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundPsql {
|
||||
t.Error("expected a session with shell_name='psql' for user 'postgres'")
|
||||
}
|
||||
})
|
||||
|
||||
// Test threshold acceptance: after enough failed dials, a subsequent
|
||||
// dial with the same credentials should succeed via threshold or
|
||||
// remembered credential.
|
||||
|
||||
123
internal/shell/psql/commands.go
Normal file
123
internal/shell/psql/commands.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// commandResult holds the output of a command and whether the session should end.
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
// dispatchBackslash handles psql backslash meta-commands.
|
||||
func dispatchBackslash(cmd, dbName string) commandResult {
|
||||
// Normalize: trim spaces after the backslash command word.
|
||||
parts := strings.Fields(cmd)
|
||||
if len(parts) == 0 {
|
||||
return commandResult{output: "Invalid command \\. Try \\? for help."}
|
||||
}
|
||||
|
||||
verb := parts[0] // e.g. `\q`, `\dt`, `\d`
|
||||
args := parts[1:]
|
||||
|
||||
switch verb {
|
||||
case `\q`:
|
||||
return commandResult{exit: true}
|
||||
case `\dt`:
|
||||
return commandResult{output: listTables()}
|
||||
case `\d`:
|
||||
if len(args) == 0 {
|
||||
return commandResult{output: listTables()}
|
||||
}
|
||||
return commandResult{output: describeTable(args[0])}
|
||||
case `\l`:
|
||||
return commandResult{output: listDatabases()}
|
||||
case `\du`:
|
||||
return commandResult{output: listRoles()}
|
||||
case `\conninfo`:
|
||||
return commandResult{output: connInfo(dbName)}
|
||||
case `\?`:
|
||||
return commandResult{output: backslashHelp()}
|
||||
case `\h`:
|
||||
return commandResult{output: sqlHelp()}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
|
||||
}
|
||||
}
|
||||
|
||||
// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
|
||||
func dispatchSQL(sql, dbName, pgVersion string) commandResult {
|
||||
// Strip trailing semicolon and whitespace for matching.
|
||||
trimmed := strings.TrimRight(sql, "; \t")
|
||||
trimmed = strings.TrimSpace(trimmed)
|
||||
upper := strings.ToUpper(trimmed)
|
||||
|
||||
switch {
|
||||
case upper == "SELECT VERSION()":
|
||||
ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
|
||||
return commandResult{output: formatSingleValue("version", ver)}
|
||||
case upper == "SELECT CURRENT_DATABASE()":
|
||||
return commandResult{output: formatSingleValue("current_database", dbName)}
|
||||
case upper == "SELECT CURRENT_USER":
|
||||
return commandResult{output: formatSingleValue("current_user", "postgres")}
|
||||
case upper == "SELECT NOW()":
|
||||
now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
|
||||
return commandResult{output: formatSingleValue("now", now)}
|
||||
case upper == "SELECT 1":
|
||||
return commandResult{output: formatSingleValue("?column?", "1")}
|
||||
case strings.HasPrefix(upper, "INSERT"):
|
||||
return commandResult{output: "INSERT 0 1"}
|
||||
case strings.HasPrefix(upper, "UPDATE"):
|
||||
return commandResult{output: "UPDATE 1"}
|
||||
case strings.HasPrefix(upper, "DELETE"):
|
||||
return commandResult{output: "DELETE 1"}
|
||||
case strings.HasPrefix(upper, "CREATE TABLE"):
|
||||
return commandResult{output: "CREATE TABLE"}
|
||||
case strings.HasPrefix(upper, "CREATE DATABASE"):
|
||||
return commandResult{output: "CREATE DATABASE"}
|
||||
case strings.HasPrefix(upper, "DROP TABLE"):
|
||||
return commandResult{output: "DROP TABLE"}
|
||||
case strings.HasPrefix(upper, "ALTER TABLE"):
|
||||
return commandResult{output: "ALTER TABLE"}
|
||||
case upper == "BEGIN":
|
||||
return commandResult{output: "BEGIN"}
|
||||
case upper == "COMMIT":
|
||||
return commandResult{output: "COMMIT"}
|
||||
case upper == "ROLLBACK":
|
||||
return commandResult{output: "ROLLBACK"}
|
||||
case upper == "SHOW SERVER_VERSION":
|
||||
return commandResult{output: formatSingleValue("server_version", pgVersion)}
|
||||
case upper == "SHOW SEARCH_PATH":
|
||||
return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
|
||||
case strings.HasPrefix(upper, "SET "):
|
||||
return commandResult{output: "SET"}
|
||||
default:
|
||||
// Extract the first token for the error message.
|
||||
firstToken := strings.Fields(trimmed)
|
||||
token := trimmed
|
||||
if len(firstToken) > 0 {
|
||||
token = firstToken[0]
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
|
||||
}
|
||||
}
|
||||
|
||||
// formatSingleValue formats a single-row, single-column psql result.
|
||||
func formatSingleValue(colName, value string) string {
|
||||
width := max(len(colName), len(value))
|
||||
|
||||
var b strings.Builder
|
||||
// Header
|
||||
fmt.Fprintf(&b, " %-*s \n", width, colName)
|
||||
// Separator
|
||||
b.WriteString(strings.Repeat("-", width+2))
|
||||
b.WriteString("\n")
|
||||
// Value
|
||||
fmt.Fprintf(&b, " %-*s\n", width, value)
|
||||
// Row count
|
||||
b.WriteString("(1 row)")
|
||||
return b.String()
|
||||
}
|
||||
155
internal/shell/psql/output.go
Normal file
155
internal/shell/psql/output.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package psql
|
||||
|
||||
import "fmt"
|
||||
|
||||
func startupBanner(version string) string {
|
||||
return fmt.Sprintf("psql (%s)\nType \"help\" for help.\n", version)
|
||||
}
|
||||
|
||||
func listTables() string {
|
||||
return ` List of relations
|
||||
Schema | Name | Type | Owner
|
||||
--------+---------------+-------+----------
|
||||
public | audit_log | table | postgres
|
||||
public | credentials | table | postgres
|
||||
public | sessions | table | postgres
|
||||
public | users | table | postgres
|
||||
(4 rows)`
|
||||
}
|
||||
|
||||
func listDatabases() string {
|
||||
return ` List of databases
|
||||
Name | Owner | Encoding | Collate | Ctype | Access privileges
|
||||
-----------+----------+----------+-------------+-------------+-----------------------
|
||||
app_db | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
|
||||
postgres | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 |
|
||||
template0 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
|
||||
| | | | | postgres=CTc/postgres
|
||||
template1 | postgres | UTF8 | en_US.UTF-8 | en_US.UTF-8 | =c/postgres +
|
||||
| | | | | postgres=CTc/postgres
|
||||
(4 rows)`
|
||||
}
|
||||
|
||||
func listRoles() string {
|
||||
return ` List of roles
|
||||
Role name | Attributes | Member of
|
||||
-----------+------------------------------------------------------------+-----------
|
||||
app_user | | {}
|
||||
postgres | Superuser, Create role, Create DB, Replication, Bypass RLS | {}
|
||||
readonly | Cannot login | {}`
|
||||
}
|
||||
|
||||
func describeTable(name string) string {
|
||||
switch name {
|
||||
case "users":
|
||||
return ` Table "public.users"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+-----------------------------------
|
||||
id | integer | | not null | nextval('users_id_seq'::regclass)
|
||||
username | character varying(255) | | not null |
|
||||
email | character varying(255) | | not null |
|
||||
password | character varying(255) | | not null |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
updated_at | timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"users_pkey" PRIMARY KEY, btree (id)
|
||||
"users_email_key" UNIQUE, btree (email)
|
||||
"users_username_key" UNIQUE, btree (username)`
|
||||
case "sessions":
|
||||
return ` Table "public.sessions"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+--------------------------------------
|
||||
id | integer | | not null | nextval('sessions_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
token | character varying(255) | | not null |
|
||||
ip_address | inet | | |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
expires_at | timestamp without time zone | | not null |
|
||||
Indexes:
|
||||
"sessions_pkey" PRIMARY KEY, btree (id)
|
||||
"sessions_token_key" UNIQUE, btree (token)
|
||||
Foreign-key constraints:
|
||||
"sessions_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
case "credentials":
|
||||
return ` Table "public.credentials"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
-----------+-----------------------------+-----------+----------+-----------------------------------------
|
||||
id | integer | | not null | nextval('credentials_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
type | character varying(50) | | not null |
|
||||
value | text | | not null |
|
||||
created_at| timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"credentials_pkey" PRIMARY KEY, btree (id)
|
||||
Foreign-key constraints:
|
||||
"credentials_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
case "audit_log":
|
||||
return ` Table "public.audit_log"
|
||||
Column | Type | Collation | Nullable | Default
|
||||
------------+-----------------------------+-----------+----------+---------------------------------------
|
||||
id | integer | | not null | nextval('audit_log_id_seq'::regclass)
|
||||
user_id | integer | | |
|
||||
action | character varying(100) | | not null |
|
||||
details | text | | |
|
||||
ip_address | inet | | |
|
||||
created_at | timestamp without time zone | | | now()
|
||||
Indexes:
|
||||
"audit_log_pkey" PRIMARY KEY, btree (id)
|
||||
Foreign-key constraints:
|
||||
"audit_log_user_id_fkey" FOREIGN KEY (user_id) REFERENCES users(id)`
|
||||
default:
|
||||
return fmt.Sprintf("Did not find any relation named \"%s\".", name)
|
||||
}
|
||||
}
|
||||
|
||||
func connInfo(dbName string) string {
|
||||
return fmt.Sprintf("You are connected to database \"%s\" as user \"postgres\" via socket in \"/var/run/postgresql\" at port \"5432\".", dbName)
|
||||
}
|
||||
|
||||
func backslashHelp() string {
|
||||
return `General
|
||||
\copyright show PostgreSQL usage and distribution terms
|
||||
\crosstabview [COLUMNS] execute query and display result in crosstab
|
||||
\errverbose show most recent error message at maximum verbosity
|
||||
\g [(OPTIONS)] [FILE] execute query (and send result to file or |pipe)
|
||||
\gdesc describe result of query, without executing it
|
||||
\gexec execute query, then execute each value in its result
|
||||
\gset [PREFIX] execute query and store result in psql variables
|
||||
\gx [(OPTIONS)] [FILE] as \g, but forces expanded output mode
|
||||
\q quit psql
|
||||
\watch [SEC] execute query every SEC seconds
|
||||
|
||||
Informational
|
||||
(options: S = show system objects, + = additional detail)
|
||||
\d[S+] list tables, views, and sequences
|
||||
\d[S+] NAME describe table, view, sequence, or index
|
||||
\da[S] [PATTERN] list aggregates
|
||||
\dA[+] [PATTERN] list access methods
|
||||
\dt[S+] [PATTERN] list tables
|
||||
\du[S+] [PATTERN] list roles
|
||||
\l[+] [PATTERN] list databases`
|
||||
}
|
||||
|
||||
func sqlHelp() string {
|
||||
return `Available help:
|
||||
ABORT CREATE LANGUAGE
|
||||
ALTER AGGREGATE CREATE MATERIALIZED VIEW
|
||||
ALTER COLLATION CREATE OPERATOR
|
||||
ALTER CONVERSION CREATE POLICY
|
||||
ALTER DATABASE CREATE PROCEDURE
|
||||
ALTER DEFAULT PRIVILEGES CREATE PUBLICATION
|
||||
ALTER DOMAIN CREATE ROLE
|
||||
ALTER EVENT TRIGGER CREATE RULE
|
||||
ALTER EXTENSION CREATE SCHEMA
|
||||
ALTER FOREIGN DATA WRAPPER CREATE SEQUENCE
|
||||
ALTER FOREIGN TABLE CREATE SERVER
|
||||
ALTER FUNCTION CREATE STATISTICS
|
||||
ALTER GROUP CREATE SUBSCRIPTION
|
||||
ALTER INDEX CREATE TABLE
|
||||
ALTER LANGUAGE CREATE TABLESPACE
|
||||
BEGIN DELETE
|
||||
COMMIT DROP TABLE
|
||||
CREATE DATABASE INSERT
|
||||
CREATE INDEX ROLLBACK
|
||||
SELECT UPDATE`
|
||||
}
|
||||
137
internal/shell/psql/psql.go
Normal file
137
internal/shell/psql/psql.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// PsqlShell emulates a PostgreSQL psql interactive terminal.
|
||||
type PsqlShell struct{}
|
||||
|
||||
// NewPsqlShell returns a new PsqlShell instance.
|
||||
func NewPsqlShell() *PsqlShell {
|
||||
return &PsqlShell{}
|
||||
}
|
||||
|
||||
func (p *PsqlShell) Name() string { return "psql" }
|
||||
func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" }
|
||||
|
||||
func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
dbName := configString(sess.ShellConfig, "db_name", "postgres")
|
||||
pgVersion := configString(sess.ShellConfig, "pg_version", "15.4")
|
||||
|
||||
// Print startup banner.
|
||||
fmt.Fprint(rw, startupBanner(pgVersion))
|
||||
|
||||
var sqlBuf []string // accumulates multi-line SQL
|
||||
|
||||
for {
|
||||
prompt := buildPrompt(dbName, len(sqlBuf) > 0)
|
||||
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
|
||||
// Empty line in non-buffering state: just re-prompt.
|
||||
if trimmed == "" && len(sqlBuf) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Backslash commands dispatch immediately (even mid-buffer they cancel the buffer).
|
||||
if strings.HasPrefix(trimmed, `\`) {
|
||||
sqlBuf = nil // discard any partial SQL
|
||||
|
||||
result := dispatchBackslash(trimmed, dbName)
|
||||
if result.output != "" {
|
||||
output := strings.ReplaceAll(result.output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("psql")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate SQL lines.
|
||||
sqlBuf = append(sqlBuf, line)
|
||||
|
||||
// Check if the statement is terminated by a semicolon.
|
||||
if !strings.HasSuffix(strings.TrimSpace(line), ";") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Full statement ready — join and dispatch.
|
||||
fullSQL := strings.Join(sqlBuf, " ")
|
||||
sqlBuf = nil
|
||||
|
||||
result := dispatchSQL(fullSQL, dbName, pgVersion)
|
||||
if result.output != "" {
|
||||
output := strings.ReplaceAll(result.output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("psql")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL.
|
||||
func buildPrompt(dbName string, continuation bool) string {
|
||||
if continuation {
|
||||
return dbName + "-# "
|
||||
}
|
||||
return dbName + "=# "
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
330
internal/shell/psql/psql_test.go
Normal file
330
internal/shell/psql/psql_test.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- Prompt tests ---
|
||||
|
||||
func TestBuildPromptNormal(t *testing.T) {
|
||||
got := buildPrompt("postgres", false)
|
||||
if got != "postgres=# " {
|
||||
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptContinuation(t *testing.T) {
|
||||
got := buildPrompt("postgres", true)
|
||||
if got != "postgres-# " {
|
||||
t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptCustomDB(t *testing.T) {
|
||||
got := buildPrompt("mydb", false)
|
||||
if got != "mydb=# " {
|
||||
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Backslash command dispatch tests ---
|
||||
|
||||
func TestBackslashQuit(t *testing.T) {
|
||||
result := dispatchBackslash(`\q`, "postgres")
|
||||
if !result.exit {
|
||||
t.Error("\\q should set exit=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListTables(t *testing.T) {
|
||||
result := dispatchBackslash(`\dt`, "postgres")
|
||||
if !strings.Contains(result.output, "users") {
|
||||
t.Error("\\dt should list tables including 'users'")
|
||||
}
|
||||
if !strings.Contains(result.output, "sessions") {
|
||||
t.Error("\\dt should list tables including 'sessions'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashDescribeTable(t *testing.T) {
|
||||
result := dispatchBackslash(`\d users`, "postgres")
|
||||
if !strings.Contains(result.output, "username") {
|
||||
t.Error("\\d users should describe users table with 'username' column")
|
||||
}
|
||||
if !strings.Contains(result.output, "PRIMARY KEY") {
|
||||
t.Error("\\d users should include index info")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashDescribeUnknownTable(t *testing.T) {
|
||||
result := dispatchBackslash(`\d nonexistent`, "postgres")
|
||||
if !strings.Contains(result.output, "Did not find") {
|
||||
t.Error("\\d nonexistent should return not found message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListDatabases(t *testing.T) {
|
||||
result := dispatchBackslash(`\l`, "postgres")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("\\l should list databases including 'postgres'")
|
||||
}
|
||||
if !strings.Contains(result.output, "template0") {
|
||||
t.Error("\\l should list databases including 'template0'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListRoles(t *testing.T) {
|
||||
result := dispatchBackslash(`\du`, "postgres")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("\\du should list roles including 'postgres'")
|
||||
}
|
||||
if !strings.Contains(result.output, "Superuser") {
|
||||
t.Error("\\du should show Superuser attribute for postgres")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashConnInfo(t *testing.T) {
|
||||
result := dispatchBackslash(`\conninfo`, "mydb")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("\\conninfo should include database name")
|
||||
}
|
||||
if !strings.Contains(result.output, "5432") {
|
||||
t.Error("\\conninfo should include port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashHelp(t *testing.T) {
|
||||
result := dispatchBackslash(`\?`, "postgres")
|
||||
if !strings.Contains(result.output, `\q`) {
|
||||
t.Error("\\? should include \\q in help output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashSQLHelp(t *testing.T) {
|
||||
result := dispatchBackslash(`\h`, "postgres")
|
||||
if !strings.Contains(result.output, "SELECT") {
|
||||
t.Error("\\h should include SQL commands like SELECT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashUnknown(t *testing.T) {
|
||||
result := dispatchBackslash(`\xyz`, "postgres")
|
||||
if !strings.Contains(result.output, "Invalid command") {
|
||||
t.Error("unknown backslash command should return error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- SQL dispatch tests ---
|
||||
|
||||
func TestSQLSelectVersion(t *testing.T) {
|
||||
result := dispatchSQL("SELECT version();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("SELECT version() should contain pg version")
|
||||
}
|
||||
if !strings.Contains(result.output, "(1 row)") {
|
||||
t.Error("SELECT version() should show row count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectCurrentDatabase(t *testing.T) {
|
||||
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("SELECT current_database() should return db name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectCurrentUser(t *testing.T) {
|
||||
result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("SELECT current_user should return postgres")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectNow(t *testing.T) {
|
||||
result := dispatchSQL("SELECT now();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "(1 row)") {
|
||||
t.Error("SELECT now() should show row count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectOne(t *testing.T) {
|
||||
result := dispatchSQL("SELECT 1;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "1") {
|
||||
t.Error("SELECT 1 should return 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLInsert(t *testing.T) {
|
||||
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
|
||||
if result.output != "INSERT 0 1" {
|
||||
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLUpdate(t *testing.T) {
|
||||
result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
|
||||
if result.output != "UPDATE 1" {
|
||||
t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDelete(t *testing.T) {
|
||||
result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
|
||||
if result.output != "DELETE 1" {
|
||||
t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCreateTable(t *testing.T) {
|
||||
result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
|
||||
if result.output != "CREATE TABLE" {
|
||||
t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCreateDatabase(t *testing.T) {
|
||||
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
|
||||
if result.output != "CREATE DATABASE" {
|
||||
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDropTable(t *testing.T) {
|
||||
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
|
||||
if result.output != "DROP TABLE" {
|
||||
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLAlterTable(t *testing.T) {
|
||||
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
|
||||
if result.output != "ALTER TABLE" {
|
||||
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLBeginCommitRollback(t *testing.T) {
|
||||
tests := []struct {
|
||||
sql string
|
||||
want string
|
||||
}{
|
||||
{"BEGIN;", "BEGIN"},
|
||||
{"COMMIT;", "COMMIT"},
|
||||
{"ROLLBACK;", "ROLLBACK"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := dispatchSQL(tt.sql, "postgres", "15.4")
|
||||
if result.output != tt.want {
|
||||
t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLShowServerVersion(t *testing.T) {
|
||||
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("SHOW server_version should contain version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLShowSearchPath(t *testing.T) {
|
||||
result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "public") {
|
||||
t.Error("SHOW search_path should contain public")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSet(t *testing.T) {
|
||||
result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
|
||||
if result.output != "SET" {
|
||||
t.Errorf("SET output = %q, want %q", result.output, "SET")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLUnrecognized(t *testing.T) {
|
||||
result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "ERROR") {
|
||||
t.Error("unrecognized SQL should return error")
|
||||
}
|
||||
if !strings.Contains(result.output, "FOOBAR") {
|
||||
t.Error("error should reference the offending token")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Case insensitivity ---
|
||||
|
||||
func TestSQLCaseInsensitive(t *testing.T) {
|
||||
result := dispatchSQL("select version();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("select version() (lowercase) should work")
|
||||
}
|
||||
|
||||
result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("mixed case SELECT should work")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Startup banner ---
|
||||
|
||||
func TestStartupBanner(t *testing.T) {
|
||||
banner := startupBanner("15.4")
|
||||
if !strings.Contains(banner, "psql (15.4)") {
|
||||
t.Errorf("banner should contain version, got: %s", banner)
|
||||
}
|
||||
if !strings.Contains(banner, "help") {
|
||||
t.Error("banner should mention help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- configString ---
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{"db_name": "mydb"}
|
||||
if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
|
||||
t.Errorf("configString() = %q, want %q", got, "mydb")
|
||||
}
|
||||
if got := configString(cfg, "missing", "default"); got != "default" {
|
||||
t.Errorf("configString() for missing = %q, want %q", got, "default")
|
||||
}
|
||||
if got := configString(nil, "key", "default"); got != "default" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Shell metadata ---
|
||||
|
||||
func TestShellNameAndDescription(t *testing.T) {
|
||||
s := NewPsqlShell()
|
||||
if s.Name() != "psql" {
|
||||
t.Errorf("Name() = %q, want %q", s.Name(), "psql")
|
||||
}
|
||||
if s.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// --- formatSingleValue ---
|
||||
|
||||
func TestFormatSingleValue(t *testing.T) {
|
||||
out := formatSingleValue("?column?", "1")
|
||||
if !strings.Contains(out, "?column?") {
|
||||
t.Error("should contain column name")
|
||||
}
|
||||
if !strings.Contains(out, "1") {
|
||||
t.Error("should contain value")
|
||||
}
|
||||
if !strings.Contains(out, "(1 row)") {
|
||||
t.Error("should contain row count")
|
||||
}
|
||||
}
|
||||
|
||||
// --- \d with no args ---
|
||||
|
||||
func TestBackslashDescribeNoArgs(t *testing.T) {
|
||||
result := dispatchBackslash(`\d`, "postgres")
|
||||
if !strings.Contains(result.output, "users") {
|
||||
t.Error("\\d with no args should list tables")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user