package psql

import (
	"fmt"
	"strings"
	"time"
	"unicode/utf8"
)

// commandResult holds the output of a command and whether the session should end.
type commandResult struct {
	output string // text to print back to the user; may be empty
	exit   bool   // true when the command asks to terminate the session
}

// invalidBackslash builds the "Invalid command" reply for an unrecognised
// meta-command verb.
func invalidBackslash(verb string) commandResult {
	return commandResult{output: fmt.Sprintf("Invalid command %s. Try \\? for help.", verb)}
}

// dispatchBackslash handles psql backslash meta-commands.
//
// cmd is the raw meta-command line; it is split on whitespace, the first
// field is taken as the verb (e.g. `\q`, `\dt`, `\d`) and the remaining
// fields as its arguments. dbName is only used by `\conninfo`.
//
// `\q` returns a result with exit set; every other recognised verb returns
// the text produced by the matching helper. Unknown verbs, and input that
// contains no fields at all, yield an "Invalid command" message.
func dispatchBackslash(cmd, dbName string) commandResult {
	// Normalize: strings.Fields drops any surrounding or repeated whitespace.
	parts := strings.Fields(cmd)
	if len(parts) == 0 {
		return invalidBackslash(`\`)
	}
	verb, args := parts[0], parts[1:]

	switch verb {
	case `\q`:
		return commandResult{exit: true}
	case `\dt`:
		return commandResult{output: listTables()}
	case `\d`:
		// A bare `\d` lists tables; with an argument it describes that table.
		if len(args) == 0 {
			return commandResult{output: listTables()}
		}
		return commandResult{output: describeTable(args[0])}
	case `\l`:
		return commandResult{output: listDatabases()}
	case `\du`:
		return commandResult{output: listRoles()}
	case `\conninfo`:
		return commandResult{output: connInfo(dbName)}
	case `\?`:
		return commandResult{output: backslashHelp()}
	case `\h`:
		return commandResult{output: sqlHelp()}
	default:
		return invalidBackslash(verb)
	}
}

// dispatchSQL handles SQL statements (already accumulated and semicolon-terminated).
//
// The statement is matched case-insensitively against a fixed set of known
// statements and statement prefixes. dbName is reported by
// `SELECT current_database()`, and pgVersion by `SELECT version()` and
// `SHOW server_version`. Anything unrecognised produces a syntax-error
// message that names the first whitespace-separated token of the statement.
func dispatchSQL(sql, dbName, pgVersion string) commandResult {
	// Strip the trailing terminator together with any surrounding whitespace.
	// Line breaks are part of the cutset so that input ending in ";\n" or
	// ";\r\n" loses its semicolon as well.
	trimmed := strings.TrimSpace(strings.TrimRight(sql, "; \t\r\n"))

	// Match on an upper-cased key with runs of whitespace collapsed to one
	// space, so multi-line or irregularly spaced statements are recognised.
	// The statement text echoed in the error message below is left as typed.
	upper := strings.ToUpper(strings.Join(strings.Fields(trimmed), " "))

	switch {
	case upper == "SELECT VERSION()":
		ver := fmt.Sprintf("PostgreSQL %s on x86_64-pc-linux-gnu, compiled by gcc (GCC) 13.2.0, 64-bit", pgVersion)
		return commandResult{output: formatSingleValue("version", ver)}
	case upper == "SELECT CURRENT_DATABASE()":
		return commandResult{output: formatSingleValue("current_database", dbName)}
	case upper == "SELECT CURRENT_USER":
		return commandResult{output: formatSingleValue("current_user", "postgres")}
	case upper == "SELECT NOW()":
		now := time.Now().UTC().Format("2006-01-02 15:04:05.000000+00")
		return commandResult{output: formatSingleValue("now", now)}
	case upper == "SELECT 1":
		return commandResult{output: formatSingleValue("?column?", "1")}
	case strings.HasPrefix(upper, "INSERT"):
		return commandResult{output: "INSERT 0 1"}
	case strings.HasPrefix(upper, "UPDATE"):
		return commandResult{output: "UPDATE 1"}
	case strings.HasPrefix(upper, "DELETE"):
		return commandResult{output: "DELETE 1"}
	case strings.HasPrefix(upper, "CREATE TABLE"):
		return commandResult{output: "CREATE TABLE"}
	case strings.HasPrefix(upper, "CREATE DATABASE"):
		return commandResult{output: "CREATE DATABASE"}
	case strings.HasPrefix(upper, "DROP TABLE"):
		return commandResult{output: "DROP TABLE"}
	case strings.HasPrefix(upper, "ALTER TABLE"):
		return commandResult{output: "ALTER TABLE"}
	case upper == "BEGIN":
		return commandResult{output: "BEGIN"}
	case upper == "COMMIT":
		return commandResult{output: "COMMIT"}
	case upper == "ROLLBACK":
		return commandResult{output: "ROLLBACK"}
	case upper == "SHOW SERVER_VERSION":
		return commandResult{output: formatSingleValue("server_version", pgVersion)}
	case upper == "SHOW SEARCH_PATH":
		return commandResult{output: formatSingleValue("search_path", "\"$user\", public")}
	case strings.HasPrefix(upper, "SET "):
		return commandResult{output: "SET"}
	default:
		// Extract the first token for the error message; fall back to the
		// whole (possibly empty) statement when there are no tokens.
		token := trimmed
		if fields := strings.Fields(trimmed); len(fields) > 0 {
			token = fields[0]
		}
		return commandResult{output: fmt.Sprintf("ERROR: syntax error at or near \"%s\"\nLINE 1: %s\n ^", token, trimmed)}
	}
}

// formatSingleValue formats a single-row, single-column psql result.
//
// The output consists of a header line with colName, a dashed separator, the
// value line and a "(1 row)" footer with no trailing newline. Both cells are
// left-aligned and padded to the wider of the two strings.
func formatSingleValue(colName, value string) string {
	// fmt measures padding width in code points, so the column width must be
	// counted in runes too; a byte count would over-size the separator for
	// non-ASCII text.
	width := max(utf8.RuneCountInString(colName), utf8.RuneCountInString(value))

	var b strings.Builder
	// Header
	fmt.Fprintf(&b, " %-*s \n", width, colName)
	// Separator: one dash per column character plus the two padding spaces.
	b.WriteString(strings.Repeat("-", width+2))
	b.WriteString("\n")
	// Value
	fmt.Fprintf(&b, " %-*s\n", width, value)
	// Row count
	b.WriteString("(1 row)")
	return b.String()
}