Bots often send commands via `ssh user@host <command>` (exec request) rather than requesting an interactive shell. These were previously rejected silently. Now exec commands are captured, stored on the session record, and displayed in the web UI session detail page. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
334 lines
8.0 KiB
Go
334 lines
8.0 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"log/slog"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.t-juice.club/torjus/oubliette/internal/config"
|
|
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
|
"git.t-juice.club/torjus/oubliette/internal/storage"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// testAddr is a minimal net.Addr implementation that lets tests hand an
// arbitrary address string to code expecting a net.Addr.
type testAddr struct {
	str     string
	network string
}

// Network returns the network name the address was created with.
func (a testAddr) Network() string {
	return a.network
}

// String returns the raw address text the address was created with.
func (a testAddr) String() string {
	return a.str
}

// newAddr builds a net.Addr whose String method yields s and whose Network
// method yields network.
func newAddr(s, network string) net.Addr {
	var result net.Addr = testAddr{network: network, str: s}
	return result
}
|
|
|
|
func TestHostKey_Generate(t *testing.T) {
|
|
path := filepath.Join(t.TempDir(), "host_key")
|
|
|
|
signer, err := loadOrGenerateHostKey(path)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if signer == nil {
|
|
t.Fatal("signer is nil")
|
|
}
|
|
|
|
// File should exist with correct permissions.
|
|
info, err := os.Stat(path)
|
|
if err != nil {
|
|
t.Fatalf("stat host key: %v", err)
|
|
}
|
|
if perm := info.Mode().Perm(); perm != 0600 {
|
|
t.Errorf("permissions = %o, want 0600", perm)
|
|
}
|
|
}
|
|
|
|
func TestHostKey_Load(t *testing.T) {
|
|
path := filepath.Join(t.TempDir(), "host_key")
|
|
|
|
// Generate first.
|
|
signer1, err := loadOrGenerateHostKey(path)
|
|
if err != nil {
|
|
t.Fatalf("generate: %v", err)
|
|
}
|
|
|
|
// Load existing.
|
|
signer2, err := loadOrGenerateHostKey(path)
|
|
if err != nil {
|
|
t.Fatalf("load: %v", err)
|
|
}
|
|
|
|
// Keys should be the same.
|
|
if string(signer1.PublicKey().Marshal()) != string(signer2.PublicKey().Marshal()) {
|
|
t.Error("loaded key differs from generated key")
|
|
}
|
|
}
|
|
|
|
func TestExtractIP(t *testing.T) {
|
|
tests := []struct {
|
|
addr string
|
|
want string
|
|
}{
|
|
{"192.168.1.1:22", "192.168.1.1"},
|
|
{"[::1]:22", "::1"},
|
|
{"[::ffff:192.168.1.1]:22", "192.168.1.1"},
|
|
{"10.0.0.1:12345", "10.0.0.1"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.addr, func(t *testing.T) {
|
|
addr := newAddr(tt.addr, "tcp")
|
|
got := extractIP(addr)
|
|
if got != tt.want {
|
|
t.Errorf("extractIP(%q) = %q, want %q", tt.addr, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIntegrationSSHConnect(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping integration test")
|
|
}
|
|
|
|
tmpDir := t.TempDir()
|
|
cfg := config.Config{
|
|
SSH: config.SSHConfig{
|
|
ListenAddr: "127.0.0.1:0",
|
|
HostKeyPath: filepath.Join(tmpDir, "host_key"),
|
|
MaxConnections: 100,
|
|
},
|
|
Auth: config.AuthConfig{
|
|
AcceptAfter: 2,
|
|
CredentialTTLDuration: time.Hour,
|
|
StaticCredentials: []config.Credential{
|
|
{Username: "root", Password: "toor", Shell: "bash"},
|
|
},
|
|
},
|
|
Shell: config.ShellConfig{
|
|
Hostname: "ubuntu-server",
|
|
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
|
|
},
|
|
LogLevel: "debug",
|
|
}
|
|
|
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
|
store := storage.NewMemoryStore()
|
|
srv, err := New(cfg, store, logger, metrics.New("test"))
|
|
if err != nil {
|
|
t.Fatalf("creating server: %v", err)
|
|
}
|
|
|
|
// Use a listener to get the actual port.
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen: %v", err)
|
|
}
|
|
addr := listener.Addr().String()
|
|
listener.Close()
|
|
|
|
cfg.SSH.ListenAddr = addr
|
|
srv.cfg = cfg
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- srv.ListenAndServe(ctx)
|
|
}()
|
|
|
|
// Wait for server to be ready.
|
|
var conn net.Conn
|
|
for i := range 50 {
|
|
conn, err = net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
|
if err == nil {
|
|
conn.Close()
|
|
break
|
|
}
|
|
if i == 49 {
|
|
t.Fatalf("server not ready after retries: %v", err)
|
|
}
|
|
time.Sleep(50 * time.Millisecond)
|
|
}
|
|
|
|
// Test static credential login with shell interaction.
|
|
t.Run("static_cred", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "root",
|
|
Auth: []ssh.AuthMethod{ssh.Password("toor")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err != nil {
|
|
t.Fatalf("SSH dial: %v", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
t.Fatalf("new session: %v", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
// Request PTY and shell.
|
|
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
|
t.Fatalf("request pty: %v", err)
|
|
}
|
|
|
|
stdin, err := session.StdinPipe()
|
|
if err != nil {
|
|
t.Fatalf("stdin pipe: %v", err)
|
|
}
|
|
|
|
var output bytes.Buffer
|
|
session.Stdout = &output
|
|
|
|
if err := session.Shell(); err != nil {
|
|
t.Fatalf("shell: %v", err)
|
|
}
|
|
|
|
// Wait for the prompt, then send commands.
|
|
time.Sleep(500 * time.Millisecond)
|
|
stdin.Write([]byte("pwd\r"))
|
|
time.Sleep(200 * time.Millisecond)
|
|
stdin.Write([]byte("whoami\r"))
|
|
time.Sleep(200 * time.Millisecond)
|
|
stdin.Write([]byte("exit\r"))
|
|
|
|
// Wait for session to end.
|
|
session.Wait()
|
|
|
|
out := output.String()
|
|
if !strings.Contains(out, "Welcome to Ubuntu") {
|
|
t.Errorf("output should contain banner, got: %s", out)
|
|
}
|
|
if !strings.Contains(out, "/root") {
|
|
t.Errorf("output should contain /root from pwd, got: %s", out)
|
|
}
|
|
if !strings.Contains(out, "root") {
|
|
t.Errorf("output should contain 'root' from whoami, got: %s", out)
|
|
}
|
|
|
|
// Verify session logs were recorded.
|
|
if len(store.SessionLogs) < 2 {
|
|
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
|
}
|
|
|
|
// Verify session was created with shell name.
|
|
var foundBash bool
|
|
for _, s := range store.Sessions {
|
|
if s.ShellName == "bash" {
|
|
foundBash = true
|
|
break
|
|
}
|
|
}
|
|
if !foundBash {
|
|
t.Error("expected a session with shell_name='bash'")
|
|
}
|
|
})
|
|
|
|
// Test wrong password is rejected.
|
|
t.Run("wrong_password", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "root",
|
|
Auth: []ssh.AuthMethod{ssh.Password("wrong")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
_, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err == nil {
|
|
t.Fatal("expected error for wrong password")
|
|
}
|
|
})
|
|
|
|
// Test exec command capture.
|
|
t.Run("exec_command", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "root",
|
|
Auth: []ssh.AuthMethod{ssh.Password("toor")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err != nil {
|
|
t.Fatalf("SSH dial: %v", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
t.Fatalf("new session: %v", err)
|
|
}
|
|
defer session.Close()
|
|
|
|
// Run a command via exec (no PTY, no shell).
|
|
if err := session.Run("uname -a"); err != nil {
|
|
// Run returns an error because the server closes the channel,
|
|
// but that's expected.
|
|
_ = err
|
|
}
|
|
|
|
// Give the server a moment to store the command.
|
|
time.Sleep(200 * time.Millisecond)
|
|
|
|
// Verify the exec command was captured.
|
|
sessions, err := store.GetRecentSessions(context.Background(), 50, false)
|
|
if err != nil {
|
|
t.Fatalf("GetRecentSessions: %v", err)
|
|
}
|
|
var foundExec bool
|
|
for _, s := range sessions {
|
|
if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
|
|
foundExec = true
|
|
break
|
|
}
|
|
}
|
|
if !foundExec {
|
|
t.Error("expected a session with exec_command='uname -a'")
|
|
}
|
|
})
|
|
|
|
// Test threshold acceptance: after enough failed dials, a subsequent
|
|
// dial with the same credentials should succeed via threshold or
|
|
// remembered credential.
|
|
t.Run("threshold", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "threshuser",
|
|
Auth: []ssh.AuthMethod{ssh.Password("threshpass")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
// Make several dials to accumulate failures past the threshold.
|
|
for range 5 {
|
|
c, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err == nil {
|
|
// Threshold reached, success!
|
|
c.Close()
|
|
return
|
|
}
|
|
}
|
|
|
|
// After enough failures the credential should be remembered.
|
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err != nil {
|
|
t.Fatalf("expected threshold/remembered acceptance after many attempts: %v", err)
|
|
}
|
|
client.Close()
|
|
})
|
|
|
|
cancel()
|
|
}
|