Add a PostgreSQL psql interactive terminal shell with backslash meta-commands, SQL statement handling with multi-line buffering, and canned responses for common queries. Add username-based shell routing via [shell.username_routes] config (second priority after credential-specific shell, before random selection). Bump version to 0.13.0. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
418 lines
10 KiB
Go
418 lines
10 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"log/slog"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.t-juice.club/torjus/oubliette/internal/auth"
|
|
"git.t-juice.club/torjus/oubliette/internal/config"
|
|
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
|
"git.t-juice.club/torjus/oubliette/internal/storage"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// testAddr is a minimal net.Addr used by tests to hand arbitrary address
// strings to code that expects a net.Addr.
type testAddr struct {
	str     string
	network string
}

// Network returns the network name the address was built with.
func (a testAddr) Network() string {
	return a.network
}

// String returns the raw address string the address was built with.
func (a testAddr) String() string {
	return a.str
}

// newAddr wraps the address string s and the network name in a testAddr.
func newAddr(s, network string) net.Addr {
	var a testAddr
	a.str = s
	a.network = network
	return a
}
|
|
|
|
func TestHostKey_Generate(t *testing.T) {
|
|
path := filepath.Join(t.TempDir(), "host_key")
|
|
|
|
signer, err := loadOrGenerateHostKey(path)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if signer == nil {
|
|
t.Fatal("signer is nil")
|
|
}
|
|
|
|
// File should exist with correct permissions.
|
|
info, err := os.Stat(path)
|
|
if err != nil {
|
|
t.Fatalf("stat host key: %v", err)
|
|
}
|
|
if perm := info.Mode().Perm(); perm != 0600 {
|
|
t.Errorf("permissions = %o, want 0600", perm)
|
|
}
|
|
}
|
|
|
|
func TestHostKey_Load(t *testing.T) {
|
|
path := filepath.Join(t.TempDir(), "host_key")
|
|
|
|
// Generate first.
|
|
signer1, err := loadOrGenerateHostKey(path)
|
|
if err != nil {
|
|
t.Fatalf("generate: %v", err)
|
|
}
|
|
|
|
// Load existing.
|
|
signer2, err := loadOrGenerateHostKey(path)
|
|
if err != nil {
|
|
t.Fatalf("load: %v", err)
|
|
}
|
|
|
|
// Keys should be the same.
|
|
if string(signer1.PublicKey().Marshal()) != string(signer2.PublicKey().Marshal()) {
|
|
t.Error("loaded key differs from generated key")
|
|
}
|
|
}
|
|
|
|
func TestExtractIP(t *testing.T) {
|
|
tests := []struct {
|
|
addr string
|
|
want string
|
|
}{
|
|
{"192.168.1.1:22", "192.168.1.1"},
|
|
{"[::1]:22", "::1"},
|
|
{"[::ffff:192.168.1.1]:22", "192.168.1.1"},
|
|
{"10.0.0.1:12345", "10.0.0.1"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.addr, func(t *testing.T) {
|
|
addr := newAddr(tt.addr, "tcp")
|
|
got := extractIP(addr)
|
|
if got != tt.want {
|
|
t.Errorf("extractIP(%q) = %q, want %q", tt.addr, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIntegrationSSHConnect(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping integration test")
|
|
}
|
|
|
|
tmpDir := t.TempDir()
|
|
cfg := config.Config{
|
|
SSH: config.SSHConfig{
|
|
ListenAddr: "127.0.0.1:0",
|
|
HostKeyPath: filepath.Join(tmpDir, "host_key"),
|
|
MaxConnections: 100,
|
|
},
|
|
Auth: config.AuthConfig{
|
|
AcceptAfter: 2,
|
|
CredentialTTLDuration: time.Hour,
|
|
StaticCredentials: []config.Credential{
|
|
{Username: "root", Password: "toor", Shell: "bash"},
|
|
},
|
|
},
|
|
Shell: config.ShellConfig{
|
|
Hostname: "ubuntu-server",
|
|
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
|
|
},
|
|
LogLevel: "debug",
|
|
}
|
|
|
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
|
store := storage.NewMemoryStore()
|
|
srv, err := New(cfg, store, logger, metrics.New("test"))
|
|
if err != nil {
|
|
t.Fatalf("creating server: %v", err)
|
|
}
|
|
|
|
// Use a listener to get the actual port.
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen: %v", err)
|
|
}
|
|
addr := listener.Addr().String()
|
|
listener.Close()
|
|
|
|
cfg.SSH.ListenAddr = addr
|
|
srv.cfg = cfg
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- srv.ListenAndServe(ctx)
|
|
}()
|
|
|
|
// Wait for server to be ready.
|
|
var conn net.Conn
|
|
for i := range 50 {
|
|
conn, err = net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
|
if err == nil {
|
|
conn.Close()
|
|
break
|
|
}
|
|
if i == 49 {
|
|
t.Fatalf("server not ready after retries: %v", err)
|
|
}
|
|
time.Sleep(50 * time.Millisecond)
|
|
}
|
|
|
|
// Test static credential login with shell interaction.
|
|
t.Run("static_cred", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "root",
|
|
Auth: []ssh.AuthMethod{ssh.Password("toor")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err != nil {
|
|
t.Fatalf("SSH dial: %v", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
t.Fatalf("new session: %v", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
// Request PTY and shell.
|
|
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
|
t.Fatalf("request pty: %v", err)
|
|
}
|
|
|
|
stdin, err := session.StdinPipe()
|
|
if err != nil {
|
|
t.Fatalf("stdin pipe: %v", err)
|
|
}
|
|
|
|
var output bytes.Buffer
|
|
session.Stdout = &output
|
|
|
|
if err := session.Shell(); err != nil {
|
|
t.Fatalf("shell: %v", err)
|
|
}
|
|
|
|
// Wait for the prompt, then send commands.
|
|
time.Sleep(500 * time.Millisecond)
|
|
stdin.Write([]byte("pwd\r"))
|
|
time.Sleep(200 * time.Millisecond)
|
|
stdin.Write([]byte("whoami\r"))
|
|
time.Sleep(200 * time.Millisecond)
|
|
stdin.Write([]byte("exit\r"))
|
|
|
|
// Wait for session to end.
|
|
session.Wait()
|
|
|
|
out := output.String()
|
|
if !strings.Contains(out, "Welcome to Ubuntu") {
|
|
t.Errorf("output should contain banner, got: %s", out)
|
|
}
|
|
if !strings.Contains(out, "/root") {
|
|
t.Errorf("output should contain /root from pwd, got: %s", out)
|
|
}
|
|
if !strings.Contains(out, "root") {
|
|
t.Errorf("output should contain 'root' from whoami, got: %s", out)
|
|
}
|
|
|
|
// Verify session logs were recorded.
|
|
if len(store.SessionLogs) < 2 {
|
|
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
|
}
|
|
|
|
// Verify session was created with shell name.
|
|
var foundBash bool
|
|
for _, s := range store.Sessions {
|
|
if s.ShellName == "bash" {
|
|
foundBash = true
|
|
break
|
|
}
|
|
}
|
|
if !foundBash {
|
|
t.Error("expected a session with shell_name='bash'")
|
|
}
|
|
})
|
|
|
|
// Test wrong password is rejected.
|
|
t.Run("wrong_password", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "root",
|
|
Auth: []ssh.AuthMethod{ssh.Password("wrong")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
_, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err == nil {
|
|
t.Fatal("expected error for wrong password")
|
|
}
|
|
})
|
|
|
|
// Test exec command capture.
|
|
t.Run("exec_command", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "root",
|
|
Auth: []ssh.AuthMethod{ssh.Password("toor")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err != nil {
|
|
t.Fatalf("SSH dial: %v", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
t.Fatalf("new session: %v", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
// Run a command via exec (no PTY, no shell).
|
|
if err := session.Run("uname -a"); err != nil {
|
|
// Run returns an error because the server closes the channel,
|
|
// but that's expected.
|
|
_ = err
|
|
}
|
|
|
|
// Give the server a moment to store the command.
|
|
time.Sleep(200 * time.Millisecond)
|
|
|
|
// Verify the exec command was captured.
|
|
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
|
if err != nil {
|
|
t.Fatalf("GetRecentSessions: %v", err)
|
|
}
|
|
var foundExec bool
|
|
for _, s := range sessions {
|
|
if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
|
|
foundExec = true
|
|
break
|
|
}
|
|
}
|
|
if !foundExec {
|
|
t.Error("expected a session with exec_command='uname -a'")
|
|
}
|
|
})
|
|
|
|
// Test username route: add username_routes so that "postgres" gets psql shell.
|
|
t.Run("username_route", func(t *testing.T) {
|
|
// Reconfigure with username routes.
|
|
srv.cfg.Shell.UsernameRoutes = map[string]string{"postgres": "psql"}
|
|
defer func() { srv.cfg.Shell.UsernameRoutes = nil }()
|
|
|
|
// Need to get the "postgres" user in via static creds or threshold.
|
|
// Use static creds for simplicity.
|
|
srv.cfg.Auth.StaticCredentials = append(srv.cfg.Auth.StaticCredentials,
|
|
config.Credential{Username: "postgres", Password: "postgres"},
|
|
)
|
|
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
|
defer func() {
|
|
srv.cfg.Auth.StaticCredentials = srv.cfg.Auth.StaticCredentials[:1]
|
|
srv.authenticator = auth.NewAuthenticator(srv.cfg.Auth)
|
|
}()
|
|
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "postgres",
|
|
Auth: []ssh.AuthMethod{ssh.Password("postgres")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err != nil {
|
|
t.Fatalf("SSH dial: %v", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
t.Fatalf("new session: %v", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
|
t.Fatalf("request pty: %v", err)
|
|
}
|
|
|
|
stdin, err := session.StdinPipe()
|
|
if err != nil {
|
|
t.Fatalf("stdin pipe: %v", err)
|
|
}
|
|
|
|
var output bytes.Buffer
|
|
session.Stdout = &output
|
|
|
|
if err := session.Shell(); err != nil {
|
|
t.Fatalf("shell: %v", err)
|
|
}
|
|
|
|
// Wait for the psql banner.
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
// Send \q to quit.
|
|
stdin.Write([]byte(`\q` + "\r"))
|
|
time.Sleep(200 * time.Millisecond)
|
|
|
|
session.Wait()
|
|
|
|
out := output.String()
|
|
if !strings.Contains(out, "psql") {
|
|
t.Errorf("output should contain psql banner, got: %s", out)
|
|
}
|
|
|
|
// Verify session was created with shell name "psql".
|
|
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
|
if err != nil {
|
|
t.Fatalf("GetRecentSessions: %v", err)
|
|
}
|
|
var foundPsql bool
|
|
for _, s := range sessions {
|
|
if s.ShellName == "psql" && s.Username == "postgres" {
|
|
foundPsql = true
|
|
break
|
|
}
|
|
}
|
|
if !foundPsql {
|
|
t.Error("expected a session with shell_name='psql' for user 'postgres'")
|
|
}
|
|
})
|
|
|
|
// Test threshold acceptance: after enough failed dials, a subsequent
|
|
// dial with the same credentials should succeed via threshold or
|
|
// remembered credential.
|
|
t.Run("threshold", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "threshuser",
|
|
Auth: []ssh.AuthMethod{ssh.Password("threshpass")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
// Make several dials to accumulate failures past the threshold.
|
|
for range 5 {
|
|
c, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err == nil {
|
|
// Threshold reached, success!
|
|
c.Close()
|
|
return
|
|
}
|
|
}
|
|
|
|
// After enough failures the credential should be remembered.
|
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err != nil {
|
|
t.Fatalf("expected threshold/remembered acceptance after many attempts: %v", err)
|
|
}
|
|
client.Close()
|
|
})
|
|
|
|
cancel()
|
|
}
|