package psql

import (
	"strings"
	"testing"
)

// --- Prompt tests ---

// TestBuildPromptNormal checks the primary prompt uses the "=#" marker.
func TestBuildPromptNormal(t *testing.T) {
	got := buildPrompt("postgres", false)
	if got != "postgres=# " {
		t.Errorf("buildPrompt(postgres, false) = %q, want %q", got, "postgres=# ")
	}
}

// TestBuildPromptContinuation checks the continuation prompt uses the "-#" marker.
func TestBuildPromptContinuation(t *testing.T) {
	got := buildPrompt("postgres", true)
	if got != "postgres-# " {
		t.Errorf("buildPrompt(postgres, true) = %q, want %q", got, "postgres-# ")
	}
}

// TestBuildPromptCustomDB checks the prompt reflects a non-default database name.
func TestBuildPromptCustomDB(t *testing.T) {
	got := buildPrompt("mydb", false)
	if got != "mydb=# " {
		t.Errorf("buildPrompt(mydb, false) = %q, want %q", got, "mydb=# ")
	}
}

// --- Backslash command dispatch tests ---

// TestBackslashQuit checks that \q requests the shell to exit.
func TestBackslashQuit(t *testing.T) {
	result := dispatchBackslash(`\q`, "postgres")
	if !result.exit {
		t.Error("\\q should set exit=true")
	}
}

// TestBackslashListTables checks that \dt lists the users and sessions tables.
func TestBackslashListTables(t *testing.T) {
	result := dispatchBackslash(`\dt`, "postgres")
	if !strings.Contains(result.output, "users") {
		t.Error("\\dt should list tables including 'users'")
	}
	if !strings.Contains(result.output, "sessions") {
		t.Error("\\dt should list tables including 'sessions'")
	}
}

// TestBackslashDescribeTable checks that \d users shows columns and index info.
func TestBackslashDescribeTable(t *testing.T) {
	result := dispatchBackslash(`\d users`, "postgres")
	if !strings.Contains(result.output, "username") {
		t.Error("\\d users should describe users table with 'username' column")
	}
	if !strings.Contains(result.output, "PRIMARY KEY") {
		t.Error("\\d users should include index info")
	}
}

// TestBackslashDescribeUnknownTable checks the not-found message for \d on a missing table.
func TestBackslashDescribeUnknownTable(t *testing.T) {
	result := dispatchBackslash(`\d nonexistent`, "postgres")
	if !strings.Contains(result.output, "Did not find") {
		t.Error("\\d nonexistent should return not found message")
	}
}

// TestBackslashListDatabases checks that \l lists postgres and template0.
func TestBackslashListDatabases(t *testing.T) {
	result := dispatchBackslash(`\l`, "postgres")
	if !strings.Contains(result.output, "postgres") {
		t.Error("\\l should list databases including 'postgres'")
	}
	if !strings.Contains(result.output, "template0") {
		t.Error("\\l should list databases including 'template0'")
	}
}

// TestBackslashListRoles checks that \du lists the postgres role as Superuser.
func TestBackslashListRoles(t *testing.T) {
	result := dispatchBackslash(`\du`, "postgres")
	if !strings.Contains(result.output, "postgres") {
		t.Error("\\du should list roles including 'postgres'")
	}
	if !strings.Contains(result.output, "Superuser") {
		t.Error("\\du should show Superuser attribute for postgres")
	}
}

// TestBackslashConnInfo checks that \conninfo reports the database name and port.
func TestBackslashConnInfo(t *testing.T) {
	result := dispatchBackslash(`\conninfo`, "mydb")
	if !strings.Contains(result.output, "mydb") {
		t.Error("\\conninfo should include database name")
	}
	if !strings.Contains(result.output, "5432") {
		t.Error("\\conninfo should include port")
	}
}

// TestBackslashHelp checks that \? mentions \q.
func TestBackslashHelp(t *testing.T) {
	result := dispatchBackslash(`\?`, "postgres")
	if !strings.Contains(result.output, `\q`) {
		t.Error("\\? should include \\q in help output")
	}
}

// TestBackslashSQLHelp checks that \h lists SQL commands.
func TestBackslashSQLHelp(t *testing.T) {
	result := dispatchBackslash(`\h`, "postgres")
	if !strings.Contains(result.output, "SELECT") {
		t.Error("\\h should include SQL commands like SELECT")
	}
}

// TestBackslashUnknown checks the error for an unrecognized backslash command.
func TestBackslashUnknown(t *testing.T) {
	result := dispatchBackslash(`\xyz`, "postgres")
	if !strings.Contains(result.output, "Invalid command") {
		t.Error("unknown backslash command should return error")
	}
}

// --- SQL dispatch tests ---

// TestSQLSelectVersion checks SELECT version() includes the server version and row count.
func TestSQLSelectVersion(t *testing.T) {
	result := dispatchSQL("SELECT version();", "postgres", "15.4")
	if !strings.Contains(result.output, "15.4") {
		t.Error("SELECT version() should contain pg version")
	}
	if !strings.Contains(result.output, "(1 row)") {
		t.Error("SELECT version() should show row count")
	}
}

// TestSQLSelectCurrentDatabase checks SELECT current_database() returns the db name.
func TestSQLSelectCurrentDatabase(t *testing.T) {
	result := dispatchSQL("SELECT current_database();", "mydb", "15.4")
	if !strings.Contains(result.output, "mydb") {
		t.Error("SELECT current_database() should return db name")
	}
}

// TestSQLSelectCurrentUser checks SELECT current_user mentions postgres.
func TestSQLSelectCurrentUser(t *testing.T) {
	result := dispatchSQL("SELECT current_user;", "postgres", "15.4")
	if !strings.Contains(result.output, "postgres") {
		t.Error("SELECT current_user should return postgres")
	}
}

// TestSQLSelectNow checks SELECT now() produces a single-row result.
func TestSQLSelectNow(t *testing.T) {
	result := dispatchSQL("SELECT now();", "postgres", "15.4")
	if !strings.Contains(result.output, "(1 row)") {
		t.Error("SELECT now() should show row count")
	}
}

// TestSQLSelectOne checks SELECT 1 returns the value 1.
func TestSQLSelectOne(t *testing.T) {
	result := dispatchSQL("SELECT 1;", "postgres", "15.4")
	// The "(1 row)" footer itself contains "1", so remove it before searching;
	// otherwise this check passes even when the value is missing.
	body := strings.ReplaceAll(result.output, "(1 row)", "")
	if !strings.Contains(body, "1") {
		t.Errorf("SELECT 1 should return 1, got %q", result.output)
	}
}

// TestSQLInsert checks the INSERT command tag.
func TestSQLInsert(t *testing.T) {
	result := dispatchSQL("INSERT INTO users (name) VALUES ('test');", "postgres", "15.4")
	if result.output != "INSERT 0 1" {
		t.Errorf("INSERT output = %q, want %q", result.output, "INSERT 0 1")
	}
}

// TestSQLUpdate checks the UPDATE command tag.
func TestSQLUpdate(t *testing.T) {
	result := dispatchSQL("UPDATE users SET name = 'foo';", "postgres", "15.4")
	if result.output != "UPDATE 1" {
		t.Errorf("UPDATE output = %q, want %q", result.output, "UPDATE 1")
	}
}

// TestSQLDelete checks the DELETE command tag.
func TestSQLDelete(t *testing.T) {
	result := dispatchSQL("DELETE FROM users WHERE id = 1;", "postgres", "15.4")
	if result.output != "DELETE 1" {
		t.Errorf("DELETE output = %q, want %q", result.output, "DELETE 1")
	}
}

// TestSQLCreateTable checks the CREATE TABLE command tag.
func TestSQLCreateTable(t *testing.T) {
	result := dispatchSQL("CREATE TABLE test (id int);", "postgres", "15.4")
	if result.output != "CREATE TABLE" {
		t.Errorf("CREATE TABLE output = %q, want %q", result.output, "CREATE TABLE")
	}
}

// TestSQLCreateDatabase checks the CREATE DATABASE command tag.
func TestSQLCreateDatabase(t *testing.T) {
	result := dispatchSQL("CREATE DATABASE testdb;", "postgres", "15.4")
	if result.output != "CREATE DATABASE" {
		t.Errorf("CREATE DATABASE output = %q, want %q", result.output, "CREATE DATABASE")
	}
}

// TestSQLDropTable checks the DROP TABLE command tag.
func TestSQLDropTable(t *testing.T) {
	result := dispatchSQL("DROP TABLE test;", "postgres", "15.4")
	if result.output != "DROP TABLE" {
		t.Errorf("DROP TABLE output = %q, want %q", result.output, "DROP TABLE")
	}
}

// TestSQLAlterTable checks the ALTER TABLE command tag.
func TestSQLAlterTable(t *testing.T) {
	result := dispatchSQL("ALTER TABLE users ADD COLUMN age int;", "postgres", "15.4")
	if result.output != "ALTER TABLE" {
		t.Errorf("ALTER TABLE output = %q, want %q", result.output, "ALTER TABLE")
	}
}

// TestSQLBeginCommitRollback checks the transaction-control command tags.
// Each statement runs as its own subtest so failures are reported independently.
func TestSQLBeginCommitRollback(t *testing.T) {
	tests := []struct {
		sql  string
		want string
	}{
		{"BEGIN;", "BEGIN"},
		{"COMMIT;", "COMMIT"},
		{"ROLLBACK;", "ROLLBACK"},
	}
	for _, tt := range tests {
		t.Run(tt.want, func(t *testing.T) {
			result := dispatchSQL(tt.sql, "postgres", "15.4")
			if result.output != tt.want {
				t.Errorf("dispatchSQL(%q) = %q, want %q", tt.sql, result.output, tt.want)
			}
		})
	}
}

// TestSQLShowServerVersion checks SHOW server_version includes the version.
func TestSQLShowServerVersion(t *testing.T) {
	result := dispatchSQL("SHOW server_version;", "postgres", "15.4")
	if !strings.Contains(result.output, "15.4") {
		t.Error("SHOW server_version should contain version")
	}
}

// TestSQLShowSearchPath checks SHOW search_path includes public.
func TestSQLShowSearchPath(t *testing.T) {
	result := dispatchSQL("SHOW search_path;", "postgres", "15.4")
	if !strings.Contains(result.output, "public") {
		t.Error("SHOW search_path should contain public")
	}
}

// TestSQLSet checks the SET command tag.
func TestSQLSet(t *testing.T) {
	result := dispatchSQL("SET client_encoding = 'UTF8';", "postgres", "15.4")
	if result.output != "SET" {
		t.Errorf("SET output = %q, want %q", result.output, "SET")
	}
}

// TestSQLUnrecognized checks unknown SQL yields an ERROR naming the offending token.
func TestSQLUnrecognized(t *testing.T) {
	result := dispatchSQL("FOOBAR baz;", "postgres", "15.4")
	if !strings.Contains(result.output, "ERROR") {
		t.Error("unrecognized SQL should return error")
	}
	if !strings.Contains(result.output, "FOOBAR") {
		t.Error("error should reference the offending token")
	}
}

// --- Case insensitivity ---

// TestSQLCaseInsensitive checks lowercase and mixed-case SQL is dispatched like uppercase.
func TestSQLCaseInsensitive(t *testing.T) {
	result := dispatchSQL("select version();", "postgres", "15.4")
	if !strings.Contains(result.output, "15.4") {
		t.Error("select version() (lowercase) should work")
	}
	result = dispatchSQL("Select Current_Database();", "mydb", "15.4")
	if !strings.Contains(result.output, "mydb") {
		t.Error("mixed case SELECT should work")
	}
}

// --- Startup banner ---

// TestStartupBanner checks the banner shows the version and mentions help.
func TestStartupBanner(t *testing.T) {
	banner := startupBanner("15.4")
	if !strings.Contains(banner, "psql (15.4)") {
		t.Errorf("banner should contain version, got: %s", banner)
	}
	if !strings.Contains(banner, "help") {
		t.Error("banner should mention help")
	}
}

// --- configString ---

// TestConfigString checks lookup hits, misses, and a nil config map.
func TestConfigString(t *testing.T) {
	cfg := map[string]any{"db_name": "mydb"}
	if got := configString(cfg, "db_name", "postgres"); got != "mydb" {
		t.Errorf("configString() = %q, want %q", got, "mydb")
	}
	if got := configString(cfg, "missing", "default"); got != "default" {
		t.Errorf("configString() for missing = %q, want %q", got, "default")
	}
	if got := configString(nil, "key", "default"); got != "default" {
		t.Errorf("configString(nil) = %q, want %q", got, "default")
	}
}

// --- Shell metadata ---

// TestShellNameAndDescription checks the shell's name and that it has a description.
func TestShellNameAndDescription(t *testing.T) {
	s := NewPsqlShell()
	if s.Name() != "psql" {
		t.Errorf("Name() = %q, want %q", s.Name(), "psql")
	}
	if s.Description() == "" {
		t.Error("Description() should not be empty")
	}
}

// --- formatSingleValue ---

// TestFormatSingleValue checks the column name, value, and row count all appear.
func TestFormatSingleValue(t *testing.T) {
	out := formatSingleValue("?column?", "1")
	if !strings.Contains(out, "?column?") {
		t.Error("should contain column name")
	}
	if !strings.Contains(out, "1") {
		t.Error("should contain value")
	}
	if !strings.Contains(out, "(1 row)") {
		t.Error("should contain row count")
	}
}

// --- \d with no args ---

// TestBackslashDescribeNoArgs checks that bare \d lists tables.
func TestBackslashDescribeNoArgs(t *testing.T) {
	result := dispatchBackslash(`\d`, "postgres")
	if !strings.Contains(result.output, "users") {
		t.Error("\\d with no args should list tables")
	}
}