feat: add psql shell and username-to-shell routing
Add a PostgreSQL psql interactive terminal shell with backslash meta-commands, SQL statement handling with multi-line buffering, and canned responses for common queries. Add username-based shell routing via [shell.username_routes] config (second priority after credential- specific shell, before random selection). Bump version to 0.13.0. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
330
internal/shell/psql/psql_test.go
Normal file
330
internal/shell/psql/psql_test.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package psql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- Prompt tests ---
|
||||
|
||||
func TestBuildPromptNormal(t *testing.T) {
|
||||
got := buildPrompt("postgres", false)
|
||||
if got != "postgres=# " {
|
||||
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptContinuation(t *testing.T) {
|
||||
got := buildPrompt("postgres", true)
|
||||
if got != "postgres-# " {
|
||||
t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPromptCustomDB(t *testing.T) {
|
||||
got := buildPrompt("mydb", false)
|
||||
if got != "mydb=# " {
|
||||
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Backslash command dispatch tests ---
|
||||
|
||||
func TestBackslashQuit(t *testing.T) {
|
||||
result := dispatchBackslash(`\q`, "postgres")
|
||||
if !result.exit {
|
||||
t.Error("\\q should set exit=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListTables(t *testing.T) {
|
||||
result := dispatchBackslash(`\dt`, "postgres")
|
||||
if !strings.Contains(result.output, "users") {
|
||||
t.Error("\\dt should list tables including 'users'")
|
||||
}
|
||||
if !strings.Contains(result.output, "sessions") {
|
||||
t.Error("\\dt should list tables including 'sessions'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashDescribeTable(t *testing.T) {
|
||||
result := dispatchBackslash(`\d users`, "postgres")
|
||||
if !strings.Contains(result.output, "username") {
|
||||
t.Error("\\d users should describe users table with 'username' column")
|
||||
}
|
||||
if !strings.Contains(result.output, "PRIMARY KEY") {
|
||||
t.Error("\\d users should include index info")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashDescribeUnknownTable(t *testing.T) {
|
||||
result := dispatchBackslash(`\d nonexistent`, "postgres")
|
||||
if !strings.Contains(result.output, "Did not find") {
|
||||
t.Error("\\d nonexistent should return not found message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListDatabases(t *testing.T) {
|
||||
result := dispatchBackslash(`\l`, "postgres")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("\\l should list databases including 'postgres'")
|
||||
}
|
||||
if !strings.Contains(result.output, "template0") {
|
||||
t.Error("\\l should list databases including 'template0'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashListRoles(t *testing.T) {
|
||||
result := dispatchBackslash(`\du`, "postgres")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("\\du should list roles including 'postgres'")
|
||||
}
|
||||
if !strings.Contains(result.output, "Superuser") {
|
||||
t.Error("\\du should show Superuser attribute for postgres")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashConnInfo(t *testing.T) {
|
||||
result := dispatchBackslash(`\conninfo`, "mydb")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("\\conninfo should include database name")
|
||||
}
|
||||
if !strings.Contains(result.output, "5432") {
|
||||
t.Error("\\conninfo should include port")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashHelp(t *testing.T) {
|
||||
result := dispatchBackslash(`\?`, "postgres")
|
||||
if !strings.Contains(result.output, `\q`) {
|
||||
t.Error("\\? should include \\q in help output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashSQLHelp(t *testing.T) {
|
||||
result := dispatchBackslash(`\h`, "postgres")
|
||||
if !strings.Contains(result.output, "SELECT") {
|
||||
t.Error("\\h should include SQL commands like SELECT")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackslashUnknown(t *testing.T) {
|
||||
result := dispatchBackslash(`\xyz`, "postgres")
|
||||
if !strings.Contains(result.output, "Invalid command") {
|
||||
t.Error("unknown backslash command should return error")
|
||||
}
|
||||
}
|
||||
|
||||
// --- SQL dispatch tests ---
|
||||
|
||||
func TestSQLSelectVersion(t *testing.T) {
|
||||
result := dispatchSQL("SELECT version();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("SELECT version() should contain pg version")
|
||||
}
|
||||
if !strings.Contains(result.output, "(1 row)") {
|
||||
t.Error("SELECT version() should show row count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectCurrentDatabase(t *testing.T) {
|
||||
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("SELECT current_database() should return db name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectCurrentUser(t *testing.T) {
|
||||
result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "postgres") {
|
||||
t.Error("SELECT current_user should return postgres")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectNow(t *testing.T) {
|
||||
result := dispatchSQL("SELECT now();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "(1 row)") {
|
||||
t.Error("SELECT now() should show row count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSelectOne(t *testing.T) {
|
||||
result := dispatchSQL("SELECT 1;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "1") {
|
||||
t.Error("SELECT 1 should return 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLInsert(t *testing.T) {
|
||||
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
|
||||
if result.output != "INSERT 0 1" {
|
||||
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLUpdate(t *testing.T) {
|
||||
result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
|
||||
if result.output != "UPDATE 1" {
|
||||
t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDelete(t *testing.T) {
|
||||
result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
|
||||
if result.output != "DELETE 1" {
|
||||
t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCreateTable(t *testing.T) {
|
||||
result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
|
||||
if result.output != "CREATE TABLE" {
|
||||
t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCreateDatabase(t *testing.T) {
|
||||
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
|
||||
if result.output != "CREATE DATABASE" {
|
||||
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLDropTable(t *testing.T) {
|
||||
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
|
||||
if result.output != "DROP TABLE" {
|
||||
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLAlterTable(t *testing.T) {
|
||||
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
|
||||
if result.output != "ALTER TABLE" {
|
||||
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLBeginCommitRollback(t *testing.T) {
|
||||
tests := []struct {
|
||||
sql string
|
||||
want string
|
||||
}{
|
||||
{"BEGIN;", "BEGIN"},
|
||||
{"COMMIT;", "COMMIT"},
|
||||
{"ROLLBACK;", "ROLLBACK"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := dispatchSQL(tt.sql, "postgres", "15.4")
|
||||
if result.output != tt.want {
|
||||
t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLShowServerVersion(t *testing.T) {
|
||||
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("SHOW server_version should contain version")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLShowSearchPath(t *testing.T) {
|
||||
result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "public") {
|
||||
t.Error("SHOW search_path should contain public")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLSet(t *testing.T) {
|
||||
result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
|
||||
if result.output != "SET" {
|
||||
t.Errorf("SET output = %q, want %q", result.output, "SET")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLUnrecognized(t *testing.T) {
|
||||
result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "ERROR") {
|
||||
t.Error("unrecognized SQL should return error")
|
||||
}
|
||||
if !strings.Contains(result.output, "FOOBAR") {
|
||||
t.Error("error should reference the offending token")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Case insensitivity ---
|
||||
|
||||
func TestSQLCaseInsensitive(t *testing.T) {
|
||||
result := dispatchSQL("select version();", "postgres", "15.4")
|
||||
if !strings.Contains(result.output, "15.4") {
|
||||
t.Error("select version() (lowercase) should work")
|
||||
}
|
||||
|
||||
result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
|
||||
if !strings.Contains(result.output, "mydb") {
|
||||
t.Error("mixed case SELECT should work")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Startup banner ---
|
||||
|
||||
func TestStartupBanner(t *testing.T) {
|
||||
banner := startupBanner("15.4")
|
||||
if !strings.Contains(banner, "psql (15.4)") {
|
||||
t.Errorf("banner should contain version, got: %s", banner)
|
||||
}
|
||||
if !strings.Contains(banner, "help") {
|
||||
t.Error("banner should mention help")
|
||||
}
|
||||
}
|
||||
|
||||
// --- configString ---
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{"db_name": "mydb"}
|
||||
if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
|
||||
t.Errorf("configString() = %q, want %q", got, "mydb")
|
||||
}
|
||||
if got := configString(cfg, "missing", "default"); got != "default" {
|
||||
t.Errorf("configString() for missing = %q, want %q", got, "default")
|
||||
}
|
||||
if got := configString(nil, "key", "default"); got != "default" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "default")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Shell metadata ---
|
||||
|
||||
func TestShellNameAndDescription(t *testing.T) {
|
||||
s := NewPsqlShell()
|
||||
if s.Name() != "psql" {
|
||||
t.Errorf("Name() = %q, want %q", s.Name(), "psql")
|
||||
}
|
||||
if s.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// --- formatSingleValue ---
|
||||
|
||||
func TestFormatSingleValue(t *testing.T) {
|
||||
out := formatSingleValue("?column?", "1")
|
||||
if !strings.Contains(out, "?column?") {
|
||||
t.Error("should contain column name")
|
||||
}
|
||||
if !strings.Contains(out, "1") {
|
||||
t.Error("should contain value")
|
||||
}
|
||||
if !strings.Contains(out, "(1 row)") {
|
||||
t.Error("should contain row count")
|
||||
}
|
||||
}
|
||||
|
||||
// --- \d with no args ---
|
||||
|
||||
func TestBackslashDescribeNoArgs(t *testing.T) {
|
||||
result := dispatchBackslash(`\d`, "postgres")
|
||||
if !strings.Contains(result.output, "users") {
|
||||
t.Error("\\d with no args should list tables")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user