feat: add psql shell and username-to-shell routing
Add a PostgreSQL psql interactive terminal shell with backslash meta-commands, SQL statement handling with multi-line buffering, and canned responses for common queries. Add username-based shell routing via [shell.username_routes] config (second priority after credential- specific shell, before random selection). Bump version to 0.13.0. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -24,6 +24,7 @@ import (
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
||||
psqlshell "git.t-juice.club/torjus/oubliette/internal/shell/psql"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
@@ -58,6 +59,9 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics
|
||||
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering cisco shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering psql shell: %w", err)
|
||||
}
|
||||
|
||||
geo, err := geoip.New()
|
||||
if err != nil {
|
||||
@@ -185,6 +189,18 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
|
||||
}
|
||||
}
|
||||
// Second priority: username-based route.
|
||||
if selectedShell == nil {
|
||||
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
|
||||
sh, found := s.shellRegistry.Get(shellName)
|
||||
if found {
|
||||
selectedShell = sh
|
||||
} else {
|
||||
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
|
||||
}
|
||||
}
|
||||
}
|
||||
// Lowest priority: random selection.
|
||||
if selectedShell == nil {
|
||||
var err error
|
||||
selectedShell, err = s.shellRegistry.Select()
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
@@ -300,6 +301,89 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
// Test username route: add username_routes so that "postgres" gets psql shell.
|
||||
t.Run("username_route", func(t *testing.T) {
|
||||
// Reconfigure with username routes.
|
||||
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
|
||||
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
|
||||
|
||||
// Need to get the "postgres" user in via static creds or threshold.
|
||||
// Use static creds for simplicity.
|
||||
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
|
||||
config.Credential{Username: "postgres", Password: "postgres"},
|
||||
)
|
||||
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||
defer func() {
|
||||
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
|
||||
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
||||
}()
|
||||
|
||||
clientCfg := &ssh.ClientConfig{
|
||||
User: "postgres",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", addr, clientCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SSH dial: %v", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
t.Fatalf("new session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
||||
t.Fatalf("request pty: %v", err)
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("stdin pipe: %v", err)
|
||||
}
|
||||
|
||||
var output bytes.Buffer
|
||||
session.Stdout = &output
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatalf("shell: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the psql banner.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Send \q to quit.
|
||||
stdin.Write([]byte(`\q` + "\r"))
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
session.Wait()
|
||||
|
||||
out := output.String()
|
||||
if !strings.Contains(out, "psql") {
|
||||
t.Errorf("output should contain psql banner, got: %s", out)
|
||||
}
|
||||
|
||||
// Verify session was created with shell name "psql".
|
||||
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
var foundPsql bool
|
||||
for _, s := range sessions {
|
||||
if s.ShellName == "psql" && s.Username == "postgres" {
|
||||
foundPsql = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundPsql {
|
||||
t.Error("expected a session with shell_name='psql' for user 'postgres'")
|
||||
}
|
||||
})
|
||||
|
||||
// Test threshold acceptance: after enough failed dials, a subsequent
|
||||
// dial with the same credentials should succeed via threshold or
|
||||
// remembered credential.
|
||||
|
||||
Reference in New Issue
Block a user