This repository has been archived on 2026-03-09. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
oubliette/internal/server/server_test.go
Torjus Håkestad 0133d956a5 feat: capture SSH exec commands (PLAN.md 4.4)
Bots often send commands via `ssh user@host <command>` (exec request)
rather than requesting an interactive shell. These were previously
rejected silently. Now exec commands are captured, stored on the session
record, and displayed in the web UI session detail page.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:43:11 +01:00

334 lines
8.0 KiB
Go

package server
import (
"bytes"
"context"
"log/slog"
"net"
"os"
"path/filepath"
"strings"
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/metrics"
"git.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
// testAddr is a minimal net.Addr whose network name and string form are fixed
// when it is built, so tests can hand arbitrary address strings to helpers
// such as extractIP.
type testAddr struct {
	str     string
	network string
}

// Compile-time check that testAddr satisfies net.Addr.
var _ net.Addr = testAddr{}

// Network returns the network name given to newAddr.
func (a testAddr) Network() string {
	return a.network
}

// String returns the address string given to newAddr, unmodified.
func (a testAddr) String() string {
	return a.str
}

// newAddr wraps s and network in a testAddr and returns it as a net.Addr.
func newAddr(s, network string) net.Addr {
	var a testAddr
	a.str = s
	a.network = network
	return a
}
// TestHostKey_Generate checks that loadOrGenerateHostKey, given a path with no
// existing file, returns a non-nil signer and writes the key file with
// owner-only (0600) permissions.
func TestHostKey_Generate(t *testing.T) {
	keyFile := filepath.Join(t.TempDir(), "host_key")

	signer, genErr := loadOrGenerateHostKey(keyFile)
	switch {
	case genErr != nil:
		t.Fatalf("unexpected error: %v", genErr)
	case signer == nil:
		t.Fatal("signer is nil")
	}

	// The private key file must not be readable by group or others.
	fi, statErr := os.Stat(keyFile)
	if statErr != nil {
		t.Fatalf("stat host key: %v", statErr)
	}
	if got := fi.Mode().Perm(); got != 0o600 {
		t.Errorf("permissions = %o, want 0600", got)
	}
}
// TestHostKey_Load checks that a second call to loadOrGenerateHostKey on the
// same path returns the key written by the first call, rather than generating
// a new one.
func TestHostKey_Load(t *testing.T) {
	path := filepath.Join(t.TempDir(), "host_key")

	// First call: nothing exists at path yet, so a key is generated.
	signer1, err := loadOrGenerateHostKey(path)
	if err != nil {
		t.Fatalf("generate: %v", err)
	}

	// Second call: the file written above should be loaded.
	signer2, err := loadOrGenerateHostKey(path)
	if err != nil {
		t.Fatalf("load: %v", err)
	}

	// Guard against a nil signer with a nil error, which would otherwise
	// panic on PublicKey() instead of producing a test failure.
	if signer1 == nil || signer2 == nil {
		t.Fatalf("nil signer returned (generate=%v, load=%v)", signer1 == nil, signer2 == nil)
	}

	// Keys should be the same: compare their wire-format encodings.
	if !bytes.Equal(signer1.PublicKey().Marshal(), signer2.PublicKey().Marshal()) {
		t.Error("loaded key differs from generated key")
	}
}
func TestExtractIP(t *testing.T) {
tests := []struct {
addr string
want string
}{
{"192.168.1.1:22", "192.168.1.1"},
{"[::1]:22", "::1"},
{"[::ffff:192.168.1.1]:22", "192.168.1.1"},
{"10.0.0.1:12345", "10.0.0.1"},
}
for _, tt := range tests {
t.Run(tt.addr, func(t *testing.T) {
addr := newAddr(tt.addr, "tcp")
got := extractIP(addr)
if got != tt.want {
t.Errorf("extractIP(%q) = %q, want %q", tt.addr, got, tt.want)
}
})
}
}
// TestIntegrationSSHConnect runs the server on a loopback port and drives it
// with a real SSH client. Subtests cover:
//   - an interactive shell session with a static credential;
//   - rejection of a wrong password;
//   - capture of an exec-request command;
//   - threshold acceptance of an unknown credential.
//
// Skipped with -short.
func TestIntegrationSSHConnect(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test")
	}

	tmpDir := t.TempDir()
	cfg := config.Config{
		SSH: config.SSHConfig{
			ListenAddr:     "127.0.0.1:0",
			HostKeyPath:    filepath.Join(tmpDir, "host_key"),
			MaxConnections: 100,
		},
		Auth: config.AuthConfig{
			AcceptAfter:           2,
			CredentialTTLDuration: time.Hour,
			StaticCredentials: []config.Credential{
				{Username: "root", Password: "toor", Shell: "bash"},
			},
		},
		Shell: config.ShellConfig{
			Hostname: "ubuntu-server",
			Banner:   "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
		},
		LogLevel: "debug",
	}

	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
	store := storage.NewMemoryStore()
	srv, err := New(cfg, store, logger, metrics.New("test"))
	if err != nil {
		t.Fatalf("creating server: %v", err)
	}

	// Reserve a free loopback port, release it, and point the server at it.
	// Another process could grab the port in between; acceptable for a test.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	addr := listener.Addr().String()
	listener.Close()
	cfg.SSH.ListenAddr = addr
	srv.cfg = cfg

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	errCh := make(chan error, 1)
	go func() {
		errCh <- srv.ListenAndServe(ctx)
	}()

	// Wait until the server accepts TCP connections. If ListenAndServe
	// returns first (e.g. the port was taken), fail with the server's own
	// error instead of a generic dial failure.
	const readyAttempts = 50
	for i := range readyAttempts {
		select {
		case srvErr := <-errCh:
			t.Fatalf("server exited before becoming ready: %v", srvErr)
		default:
		}
		conn, dialErr := net.DialTimeout("tcp", addr, 100*time.Millisecond)
		if dialErr == nil {
			conn.Close()
			break
		}
		if i == readyAttempts-1 {
			t.Fatalf("server not ready after retries: %v", dialErr)
		}
		time.Sleep(50 * time.Millisecond)
	}

	// clientConfig builds a password-auth client config. Host key checking is
	// skipped because the server uses a throwaway key generated in tmpDir.
	clientConfig := func(user, password string) *ssh.ClientConfig {
		return &ssh.ClientConfig{
			User:            user,
			Auth:            []ssh.AuthMethod{ssh.Password(password)},
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
			Timeout:         5 * time.Second,
		}
	}

	// eventually polls cond until it returns true or 2s elapse. The server
	// may record session data after the client has already seen the channel
	// close, so a single fixed sleep is racy on slow machines.
	eventually := func(cond func() bool) bool {
		deadline := time.Now().Add(2 * time.Second)
		for {
			if cond() {
				return true
			}
			if time.Now().After(deadline) {
				return false
			}
			time.Sleep(20 * time.Millisecond)
		}
	}

	// Test static credential login with shell interaction.
	t.Run("static_cred", func(t *testing.T) {
		client, err := ssh.Dial("tcp", addr, clientConfig("root", "toor"))
		if err != nil {
			t.Fatalf("SSH dial: %v", err)
		}
		defer client.Close()

		session, err := client.NewSession()
		if err != nil {
			t.Fatalf("new session: %v", err)
		}
		defer session.Close()

		// Request PTY and shell.
		if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
			t.Fatalf("request pty: %v", err)
		}
		stdin, err := session.StdinPipe()
		if err != nil {
			t.Fatalf("stdin pipe: %v", err)
		}
		var output bytes.Buffer
		session.Stdout = &output
		if err := session.Shell(); err != nil {
			t.Fatalf("shell: %v", err)
		}

		// send writes one command line to the shell, terminated by CR.
		send := func(line string) {
			if _, err := stdin.Write([]byte(line + "\r")); err != nil {
				t.Fatalf("write %q: %v", line, err)
			}
		}

		// Wait for the prompt, then send commands.
		time.Sleep(500 * time.Millisecond)
		send("pwd")
		time.Sleep(200 * time.Millisecond)
		send("whoami")
		time.Sleep(200 * time.Millisecond)
		send("exit")

		// Wait for the session to end. The exit status is not under test.
		_ = session.Wait()

		out := output.String()
		if !strings.Contains(out, "Welcome to Ubuntu") {
			t.Errorf("output should contain banner, got: %s", out)
		}
		if !strings.Contains(out, "/root") {
			t.Errorf("output should contain /root from pwd, got: %s", out)
		}
		// NOTE(review): "/root" from pwd already satisfies this check, so it
		// does not prove whoami ran. Tighten it once the exact prompt and
		// output format is pinned down.
		if !strings.Contains(out, "root") {
			t.Errorf("output should contain 'root' from whoami, got: %s", out)
		}

		// Verify session logs were recorded.
		if len(store.SessionLogs) < 2 {
			t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
		}

		// Verify a session was created with the shell name. Query through the
		// store API (as exec_command does) rather than reading store fields
		// directly while the server may still be writing them.
		foundBash := eventually(func() bool {
			sessions, err := store.GetRecentSessions(context.Background(), 50, false)
			if err != nil {
				t.Fatalf("GetRecentSessions: %v", err)
			}
			for _, s := range sessions {
				if s.ShellName == "bash" {
					return true
				}
			}
			return false
		})
		if !foundBash {
			t.Error("expected a session with shell_name='bash'")
		}
	})

	// Test wrong password is rejected.
	t.Run("wrong_password", func(t *testing.T) {
		if _, err := ssh.Dial("tcp", addr, clientConfig("root", "wrong")); err == nil {
			t.Fatal("expected error for wrong password")
		}
	})

	// Test exec command capture.
	t.Run("exec_command", func(t *testing.T) {
		client, err := ssh.Dial("tcp", addr, clientConfig("root", "toor"))
		if err != nil {
			t.Fatalf("SSH dial: %v", err)
		}
		defer client.Close()

		session, err := client.NewSession()
		if err != nil {
			t.Fatalf("new session: %v", err)
		}
		defer session.Close()

		// Run a command via exec (no PTY, no shell). Run returns an error
		// because the server closes the channel; that is expected here.
		_ = session.Run("uname -a")

		// Verify the exec command was captured.
		foundExec := eventually(func() bool {
			sessions, err := store.GetRecentSessions(context.Background(), 50, false)
			if err != nil {
				t.Fatalf("GetRecentSessions: %v", err)
			}
			for _, s := range sessions {
				if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
					return true
				}
			}
			return false
		})
		if !foundExec {
			t.Error("expected a session with exec_command='uname -a'")
		}
	})

	// Test threshold acceptance: after enough failed dials, a subsequent
	// dial with the same credentials should succeed via threshold or
	// remembered credential.
	t.Run("threshold", func(t *testing.T) {
		cc := clientConfig("threshuser", "threshpass")

		// Make several dials to accumulate failures past the threshold.
		for range 5 {
			if c, err := ssh.Dial("tcp", addr, cc); err == nil {
				// Threshold reached, success!
				c.Close()
				return
			}
		}

		// After enough failures the credential should be remembered.
		client, err := ssh.Dial("tcp", addr, cc)
		if err != nil {
			t.Fatalf("expected threshold/remembered acceptance after many attempts: %v", err)
		}
		client.Close()
	})

	// Shut the server down and wait for ListenAndServe to return, so its
	// goroutine does not outlive the test.
	cancel()
	select {
	case srvErr := <-errCh:
		if srvErr != nil {
			t.Logf("server stopped with: %v", srvErr)
		}
	case <-time.After(5 * time.Second):
		t.Error("server did not stop within 5s of context cancellation")
	}
}