package psql import ( "context" "errors" "fmt" "io" "strings" "time" "git.t-juice.club/torjus/oubliette/internal/shell" ) const sessionTimeout = 5 * time.Minute // PsqlShell emulates a PostgreSQL psql interactive terminal. type PsqlShell struct{} // NewPsqlShell returns a new PsqlShell instance. func NewPsqlShell() *PsqlShell { return &PsqlShell{} } func (p *PsqlShell) Name() string { return "psql" } func (p *PsqlShell) Description() string { return "PostgreSQL psql interactive terminal" } func (p *PsqlShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error { ctx, cancel := context.WithTimeout(ctx, sessionTimeout) defer cancel() dbName := configString(sess.ShellConfig, "db_name", "postgres") pgVersion := configString(sess.ShellConfig, "pg_version", "15.4") // Print startup banner. fmt.Fprint(rw, startupBanner(pgVersion)) var sqlBuf []string // accumulates multi-line SQL for { prompt := buildPrompt(dbName, len(sqlBuf) > 0) if _, err := fmt.Fprint(rw, prompt); err != nil { return nil } line, err := shell.ReadLine(ctx, rw) if errors.Is(err, io.EOF) { return nil } if err != nil { return nil } trimmed := strings.TrimSpace(line) // Empty line in non-buffering state: just re-prompt. if trimmed == "" && len(sqlBuf) == 0 { continue } // Backslash commands dispatch immediately (even mid-buffer they cancel the buffer). if strings.HasPrefix(trimmed, `\`) { sqlBuf = nil // discard any partial SQL result := dispatchBackslash(trimmed, dbName) if result.output != "" { output := strings.ReplaceAll(result.output, "\n", "\r\n") fmt.Fprintf(rw, "%s\r\n", output) } if sess.Store != nil { if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, result.output); err != nil { return fmt.Errorf("append session log: %w", err) } } if sess.OnCommand != nil { sess.OnCommand("psql") } if result.exit { return nil } continue } // Accumulate SQL lines. sqlBuf = append(sqlBuf, line) // Check if the statement is terminated by a semicolon. if !strings.HasSuffix(strings.TrimSpace(line), ";") { continue } // Full statement ready — join and dispatch. fullSQL := strings.Join(sqlBuf, " ") sqlBuf = nil result := dispatchSQL(fullSQL, dbName, pgVersion) if result.output != "" { output := strings.ReplaceAll(result.output, "\n", "\r\n") fmt.Fprintf(rw, "%s\r\n", output) } if sess.Store != nil { if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, fullSQL, result.output); err != nil { return fmt.Errorf("append session log: %w", err) } } if sess.OnCommand != nil { sess.OnCommand("psql") } if result.exit { return nil } } } // buildPrompt returns the psql prompt. continuation is true when buffering multi-line SQL. func buildPrompt(dbName string, continuation bool) string { if continuation { return dbName + "-# " } return dbName + "=# " } // configString reads a string from the shell config map with a default. func configString(cfg map[string]any, key, defaultVal string) string { if cfg == nil { return defaultVal } if v, ok := cfg[key]; ok { if s, ok := v.(string); ok && s != "" { return s } } return defaultVal }