This repository has been archived on 2026-03-09. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
oubliette/internal/shell/psql/psql_test.go
Torjus Håkestad 40fda3420c feat: add psql shell and username-to-shell routing
Add a PostgreSQL psql interactive terminal shell with backslash
meta-commands, SQL statement handling with multi-line buffering, and
canned responses for common queries. Add username-based shell routing
via [shell.username_routes] config (second priority after credential-
specific shell, before random selection). Bump version to 0.13.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 19:58:34 +01:00

331 lines
9.3 KiB
Go

package psql
import (
"strings"
"testing"
)
// --- Prompt tests ---
func TestBuildPromptNormal(t *testing.T) {
got := buildPrompt("postgres", false)
if got != "postgres=# " {
t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
}
}
// TestBuildPromptContinuation checks that the continuation prompt for the
// "postgres" database ends in "-# ".
func TestBuildPromptContinuation(t *testing.T) {
	got, want := buildPrompt("postgres", true), "postgres-# "
	if got != want {
		t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, want)
	}
}
func TestBuildPromptCustomDB(t *testing.T) {
got := buildPrompt("mydb", false)
if got != "mydb=# " {
t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
}
}
// --- Backslash command dispatch tests ---
// TestBackslashQuit checks that \q sets the exit flag on the dispatch result.
func TestBackslashQuit(t *testing.T) {
	quit := dispatchBackslash(`\q`, "postgres").exit
	if !quit {
		t.Error(`\q should set exit=true`)
	}
}
// TestBackslashListTables checks that \dt output mentions the "users" and
// "sessions" tables.
func TestBackslashListTables(t *testing.T) {
	out := dispatchBackslash(`\dt`, "postgres").output
	checks := []struct{ substr, msg string }{
		{"users", `\dt should list tables including 'users'`},
		{"sessions", `\dt should list tables including 'sessions'`},
	}
	for _, c := range checks {
		if !strings.Contains(out, c.substr) {
			t.Error(c.msg)
		}
	}
}
// TestBackslashDescribeTable checks that \d users shows the "username" column
// and a PRIMARY KEY index line.
func TestBackslashDescribeTable(t *testing.T) {
	desc := dispatchBackslash(`\d users`, "postgres").output
	hasColumn := strings.Contains(desc, "username")
	hasIndex := strings.Contains(desc, "PRIMARY KEY")

	if !hasColumn {
		t.Error(`\d users should describe users table with 'username' column`)
	}
	if !hasIndex {
		t.Error(`\d users should include index info`)
	}
}
func TestBackslashDescribeUnknownTable(t *testing.T) {
result := dispatchBackslash(`\d nonexistent`, "postgres")
if !strings.Contains(result.output, "Did not find") {
t.Error("\\d nonexistent should return not found message")
}
}
// TestBackslashListDatabases checks that \l lists the "postgres" and
// "template0" databases.
func TestBackslashListDatabases(t *testing.T) {
	out := dispatchBackslash(`\l`, "postgres").output
	expect := func(substr, msg string) {
		if !strings.Contains(out, substr) {
			t.Error(msg)
		}
	}
	expect("postgres", `\l should list databases including 'postgres'`)
	expect("template0", `\l should list databases including 'template0'`)
}
// TestBackslashListRoles checks that \du lists the "postgres" role and the
// Superuser attribute.
func TestBackslashListRoles(t *testing.T) {
	roles := dispatchBackslash(`\du`, "postgres").output
	if hasRole := strings.Contains(roles, "postgres"); !hasRole {
		t.Error(`\du should list roles including 'postgres'`)
	}
	if hasAttr := strings.Contains(roles, "Superuser"); !hasAttr {
		t.Error(`\du should show Superuser attribute for postgres`)
	}
}
// TestBackslashConnInfo checks that \conninfo reports the database name it was
// given and the port 5432.
func TestBackslashConnInfo(t *testing.T) {
	info := dispatchBackslash(`\conninfo`, "mydb").output
	wants := [...]string{"mydb", "5432"}
	msgs := [...]string{
		`\conninfo should include database name`,
		`\conninfo should include port`,
	}
	for i := range wants {
		if !strings.Contains(info, wants[i]) {
			t.Error(msgs[i])
		}
	}
}
// TestBackslashHelp checks that the \? help text mentions \q.
func TestBackslashHelp(t *testing.T) {
	help := dispatchBackslash(`\?`, "postgres").output
	if strings.Contains(help, `\q`) {
		return
	}
	t.Error(`\? should include \q in help output`)
}
// TestBackslashSQLHelp checks that \h output lists SQL commands such as SELECT.
func TestBackslashSQLHelp(t *testing.T) {
	if out := dispatchBackslash(`\h`, "postgres").output; !strings.Contains(out, "SELECT") {
		t.Error(`\h should include SQL commands like SELECT`)
	}
}
func TestBackslashUnknown(t *testing.T) {
result := dispatchBackslash(`\xyz`, "postgres")
if !strings.Contains(result.output, "Invalid command") {
t.Error("unknown backslash command should return error")
}
}
// --- SQL dispatch tests ---
// TestSQLSelectVersion checks that SELECT version() echoes the configured
// server version and a "(1 row)" footer.
func TestSQLSelectVersion(t *testing.T) {
	out := dispatchSQL("SELECT version();", "postgres", "15.4").output
	for _, c := range []struct{ substr, msg string }{
		{"15.4", "SELECT version() should contain pg version"},
		{"(1 row)", "SELECT version() should show row count"},
	} {
		if !strings.Contains(out, c.substr) {
			t.Error(c.msg)
		}
	}
}
func TestSQLSelectCurrentDatabase(t *testing.T) {
result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
if !strings.Contains(result.output, "mydb") {
t.Error("SELECT current_database() should return db name")
}
}
// TestSQLSelectCurrentUser checks that SELECT current_user reports "postgres".
//
// The database is deliberately not named "postgres". Otherwise the assertion
// would also pass if the handler echoed the database name instead of the user.
func TestSQLSelectCurrentUser(t *testing.T) {
	result := dispatchSQL("SELECT current_user;", "mydb", "15.4")
	if !strings.Contains(result.output, "postgres") {
		t.Errorf("SELECT current_user should return postgres, got:\n%s", result.output)
	}
}
// TestSQLSelectNow checks that SELECT now() produces a single-row result.
func TestSQLSelectNow(t *testing.T) {
	rows := dispatchSQL("SELECT now();", "postgres", "15.4").output
	if hasFooter := strings.Contains(rows, "(1 row)"); !hasFooter {
		t.Error("SELECT now() should show row count")
	}
}
// TestSQLSelectOne checks that SELECT 1 returns the value 1.
func TestSQLSelectOne(t *testing.T) {
	result := dispatchSQL("SELECT 1;", "postgres", "15.4")
	// The "(1 row)" footer contains a "1" of its own. Search only the text
	// before it; otherwise the value check could never fail.
	body, _, _ := strings.Cut(result.output, "(1 row)")
	if !strings.Contains(body, "1") {
		t.Errorf("SELECT 1 should return 1, got:\n%s", result.output)
	}
}
func TestSQLInsert(t *testing.T) {
result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
if result.output != "INSERT 0 1" {
t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
}
}
// TestSQLUpdate checks the exact status output returned for an UPDATE.
func TestSQLUpdate(t *testing.T) {
	got, want := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4").output, "UPDATE 1"
	if got != want {
		t.Errorf("UPDATE output = %q, want %q", got, want)
	}
}
// TestSQLDelete checks the exact status output returned for a DELETE.
func TestSQLDelete(t *testing.T) {
	if got, want := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4").output, "DELETE 1"; got != want {
		t.Errorf("DELETE output = %q, want %q", got, want)
	}
}
// TestSQLCreateTable checks the exact status output returned for CREATE TABLE.
func TestSQLCreateTable(t *testing.T) {
	res := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
	if want := "CREATE TABLE"; res.output != want {
		t.Errorf("CREATE TABLE output = %q, want %q", res.output, want)
	}
}
func TestSQLCreateDatabase(t *testing.T) {
result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
if result.output != "CREATE DATABASE" {
t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
}
}
func TestSQLDropTable(t *testing.T) {
result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
if result.output != "DROP TABLE" {
t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
}
}
func TestSQLAlterTable(t *testing.T) {
result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
if result.output != "ALTER TABLE" {
t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
}
}
// TestSQLBeginCommitRollback checks the exact status output of the transaction
// control statements BEGIN, COMMIT and ROLLBACK.
//
// Each statement runs as its own subtest. Failures are then reported per
// statement, and a single case can be selected with -run.
func TestSQLBeginCommitRollback(t *testing.T) {
	tests := []struct {
		sql  string
		want string
	}{
		{"BEGIN;", "BEGIN"},
		{"COMMIT;", "COMMIT"},
		{"ROLLBACK;", "ROLLBACK"},
	}
	for _, tt := range tests {
		t.Run(tt.want, func(t *testing.T) {
			got := dispatchSQL(tt.sql, "postgres", "15.4").output
			if got != tt.want {
				t.Errorf("dispatchSQL(%q).output = %q, want %q", tt.sql, got, tt.want)
			}
		})
	}
}
func TestSQLShowServerVersion(t *testing.T) {
result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
if !strings.Contains(result.output, "15.4") {
t.Error("SHOW server_version should contain version")
}
}
// TestSQLShowSearchPath checks that SHOW search_path mentions the public schema.
func TestSQLShowSearchPath(t *testing.T) {
	if out := dispatchSQL("SHOW search_path;", "postgres", "15.4").output; !strings.Contains(out, "public") {
		t.Error("SHOW search_path should contain public")
	}
}
// TestSQLSet checks that a SET statement returns exactly "SET".
func TestSQLSet(t *testing.T) {
	if got, want := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4").output, "SET"; got != want {
		t.Errorf("SET output = %q, want %q", got, want)
	}
}
// TestSQLUnrecognized checks that an unknown statement yields an ERROR that
// names the offending leading token.
func TestSQLUnrecognized(t *testing.T) {
	out := dispatchSQL("FOOBAR baz;", "postgres", "15.4").output
	isError := strings.Contains(out, "ERROR")
	namesToken := strings.Contains(out, "FOOBAR")

	if !isError {
		t.Error("unrecognized SQL should return error")
	}
	if !namesToken {
		t.Error("error should reference the offending token")
	}
}
// --- Case insensitivity ---
// TestSQLCaseInsensitive checks that lowercase and mixed-case spellings of
// known queries are recognised.
func TestSQLCaseInsensitive(t *testing.T) {
	cases := []struct {
		sql, db, want, msg string
	}{
		{"select version();", "postgres", "15.4", "select version() (lowercase) should work"},
		{"Select Current_Database();", "mydb", "mydb", "mixed case SELECT should work"},
	}
	for _, c := range cases {
		out := dispatchSQL(c.sql, c.db, "15.4").output
		if !strings.Contains(out, c.want) {
			t.Error(c.msg)
		}
	}
}
// --- Startup banner ---
func TestStartupBanner(t *testing.T) {
banner := startupBanner("15.4")
if !strings.Contains(banner, "psql (15.4)") {
t.Errorf("banner should contain version, got: %s", banner)
}
if !strings.Contains(banner, "help") {
t.Error("banner should mention help")
}
}
// --- configString ---
// TestConfigString checks configString for a present key, a missing key and a
// nil config map.
func TestConfigString(t *testing.T) {
	cfg := map[string]any{"db_name": "mydb"}

	// Present key: the configured value is returned instead of the fallback.
	got := configString(cfg, "db_name", "postgres")
	if want := "mydb"; got != want {
		t.Errorf("configString() = %q, want %q", got, want)
	}

	// Missing key: the fallback is returned.
	got = configString(cfg, "missing", "default")
	if want := "default"; got != want {
		t.Errorf("configString() for missing = %q, want %q", got, want)
	}

	// Nil config: the fallback is returned.
	got = configString(nil, "key", "default")
	if want := "default"; got != want {
		t.Errorf("configString(nil) = %q, want %q", got, want)
	}
}
// --- Shell metadata ---
// TestShellNameAndDescription checks that the shell is named "psql" and has a
// non-empty description.
func TestShellNameAndDescription(t *testing.T) {
	shell := NewPsqlShell()
	if name := shell.Name(); name != "psql" {
		t.Errorf("Name() = %q, want %q", name, "psql")
	}
	if len(shell.Description()) == 0 {
		t.Error("Description() should not be empty")
	}
}
// --- formatSingleValue ---
// TestFormatSingleValue checks that a single-value result includes the column
// header, the value and a "(1 row)" footer.
func TestFormatSingleValue(t *testing.T) {
	out := formatSingleValue("?column?", "1")
	if !strings.Contains(out, "?column?") {
		t.Errorf("should contain column name, got:\n%s", out)
	}
	// The "(1 row)" footer contains a "1" of its own. Search only the text
	// before it; otherwise the value check could never fail.
	body, _, _ := strings.Cut(out, "(1 row)")
	if !strings.Contains(body, "1") {
		t.Errorf("should contain value, got:\n%s", out)
	}
	if !strings.Contains(out, "(1 row)") {
		t.Errorf("should contain row count, got:\n%s", out)
	}
}
// --- \d with no args ---
// TestBackslashDescribeNoArgs checks that a bare \d lists relations, including
// "users".
func TestBackslashDescribeNoArgs(t *testing.T) {
	listing := dispatchBackslash(`\d`, "postgres").output
	if strings.Contains(listing, "users") {
		return
	}
	t.Error(`\d with no args should list tables`)
}