feat: add shell interface, registry, and bash shell emulator

Implement Phase 1.4: replaces the hardcoded banner/timeout stub with a
proper shell system. Adds a Shell interface with weighted registry for
shell selection, a RecordingChannel wrapper (pass-through for now, prep
for Phase 2.3 replay), and a bash-like shell with fake filesystem,
terminal line reader, and command handling (pwd, ls, cd, cat, whoami,
hostname, id, uname, exit). Sessions now log command/output pairs to
the store and record the shell name.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-14 20:24:48 +01:00
parent ae9924ffbb
commit 8189a108d1
17 changed files with 1503 additions and 41 deletions

View File

@@ -14,28 +14,35 @@ import (
"git.t-juice.club/torjus/oubliette/internal/auth"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
"git.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
const sessionTimeout = 30 * time.Second
type Server struct {
cfg config.Config
store storage.Store
authenticator *auth.Authenticator
sshConfig *ssh.ServerConfig
logger *slog.Logger
connSem chan struct{} // semaphore limiting concurrent connections
cfg config.Config
store storage.Store
authenticator *auth.Authenticator
sshConfig *ssh.ServerConfig
logger *slog.Logger
connSem chan struct{} // semaphore limiting concurrent connections
shellRegistry *shell.Registry
}
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
registry := shell.NewRegistry()
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
return nil, fmt.Errorf("registering bash shell: %w", err)
}
s := &Server{
cfg: cfg,
store: store,
authenticator: auth.NewAuthenticator(cfg.Auth),
logger: logger,
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
shellRegistry: registry,
}
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
@@ -126,8 +133,15 @@ func (s *Server) handleConn(conn net.Conn) {
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
defer channel.Close()
// Select a shell from the registry.
selectedShell, err := s.shellRegistry.Select()
if err != nil {
s.logger.Error("failed to select shell", "err", err)
return
}
ip := extractIP(conn.RemoteAddr())
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), "")
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name())
if err != nil {
s.logger.Error("failed to create session", "err", err)
} else {
@@ -138,6 +152,13 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
}()
}
s.logger.Info("session started",
"remote_addr", conn.RemoteAddr(),
"user", conn.User(),
"shell", selectedShell.Name(),
"session_id", sessionID,
)
// Handle session requests (pty-req, shell, etc.)
go func() {
for req := range requests {
@@ -154,33 +175,37 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request
}
}()
// Write a fake banner.
fmt.Fprint(channel, "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n")
fmt.Fprintf(channel, "Last login: %s from 10.0.0.1\r\n", time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006"))
fmt.Fprintf(channel, "%s@ubuntu:~$ ", conn.User())
// Hold connection open until timeout or client disconnect.
timer := time.NewTimer(sessionTimeout)
defer timer.Stop()
done := make(chan struct{})
go func() {
buf := make([]byte, 256)
for {
_, err := channel.Read(buf)
if err != nil {
close(done)
return
}
}
}()
select {
case <-timer.C:
s.logger.Info("session timed out", "remote_addr", conn.RemoteAddr(), "user", conn.User())
case <-done:
s.logger.Info("session closed by client", "remote_addr", conn.RemoteAddr(), "user", conn.User())
// Build session context.
var shellCfg map[string]any
if s.cfg.Shell.Shells != nil {
shellCfg = s.cfg.Shell.Shells[selectedShell.Name()]
}
sessCtx := &shell.SessionContext{
SessionID: sessionID,
Username: conn.User(),
RemoteAddr: ip,
ClientVersion: string(conn.ClientVersion()),
Store: s.store,
ShellConfig: shellCfg,
CommonConfig: shell.ShellCommonConfig{
Hostname: s.cfg.Shell.Hostname,
Banner: s.cfg.Shell.Banner,
FakeUser: s.cfg.Shell.FakeUser,
},
}
// Wrap channel in RecordingChannel for future byte-level recording.
recorder := shell.NewRecordingChannel(channel)
if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
s.logger.Error("shell error", "err", err, "session_id", sessionID)
}
s.logger.Info("session ended",
"remote_addr", conn.RemoteAddr(),
"user", conn.User(),
"session_id", sessionID,
)
}
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {

View File

@@ -1,11 +1,13 @@
package server
import (
"bytes"
"context"
"log/slog"
"net"
"os"
"path/filepath"
"strings"
"testing"
"time"
@@ -109,6 +111,10 @@ func TestIntegrationSSHConnect(t *testing.T) {
{Username: "root", Password: "toor"},
},
},
Shell: config.ShellConfig{
Hostname: "ubuntu-server",
Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
},
LogLevel: "debug",
}
@@ -152,7 +158,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
time.Sleep(50 * time.Millisecond)
}
// Test static credential login.
// Test static credential login with shell interaction.
t.Run("static_cred", func(t *testing.T) {
clientCfg := &ssh.ClientConfig{
User: "root",
@@ -172,6 +178,62 @@ func TestIntegrationSSHConnect(t *testing.T) {
t.Fatalf("new session: %v", err)
}
defer session.Close()
// Request PTY and shell.
if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
t.Fatalf("request pty: %v", err)
}
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("stdin pipe: %v", err)
}
var output bytes.Buffer
session.Stdout = &output
if err := session.Shell(); err != nil {
t.Fatalf("shell: %v", err)
}
// Wait for the prompt, then send commands.
time.Sleep(500 * time.Millisecond)
stdin.Write([]byte("pwd\r"))
time.Sleep(200 * time.Millisecond)
stdin.Write([]byte("whoami\r"))
time.Sleep(200 * time.Millisecond)
stdin.Write([]byte("exit\r"))
// Wait for session to end.
session.Wait()
out := output.String()
if !strings.Contains(out, "Welcome to Ubuntu") {
t.Errorf("output should contain banner, got: %s", out)
}
if !strings.Contains(out, "/root") {
t.Errorf("output should contain /root from pwd, got: %s", out)
}
if !strings.Contains(out, "root") {
t.Errorf("output should contain 'root' from whoami, got: %s", out)
}
// Verify session logs were recorded.
if len(store.SessionLogs) < 2 {
t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
}
// Verify session was created with shell name.
var foundBash bool
for _, s := range store.Sessions {
if s.ShellName == "bash" {
foundBash = true
break
}
}
if !foundBash {
t.Error("expected a session with shell_name='bash'")
}
})
// Test wrong password is rejected.