This repository has been archived on 2026-03-09. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
oubliette/internal/server/server.go
Torjus Håkestad 8189a108d1 feat: add shell interface, registry, and bash shell emulator
Implement Phase 1.4: replaces the hardcoded banner/timeout stub with a
proper shell system. Adds a Shell interface with weighted registry for
shell selection, a RecordingChannel wrapper (pass-through for now, prep
for Phase 2.3 replay), and a bash-like shell with fake filesystem,
terminal line reader, and command handling (pwd, ls, cd, cat, whoami,
hostname, id, uname, exit). Sessions now log command/output pairs to
the store and record the shell name.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 20:24:48 +01:00

288 lines
6.9 KiB
Go

package server
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"errors"
"fmt"
"log/slog"
"net"
"os"
"time"
"git.t-juice.club/torjus/oubliette/internal/auth"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
"git.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
// Server accepts SSH connections, authenticates them with passwords via its
// authenticator, and serves each session with a shell picked from its shell
// registry. Construct one with New.
type Server struct {
	cfg           config.Config
	store         storage.Store       // receives login attempts and session records
	authenticator *auth.Authenticator // decides whether a password attempt is accepted
	sshConfig     *ssh.ServerConfig   // handshake settings, including the host key
	logger        *slog.Logger
	connSem       chan struct{}   // counting semaphore; its capacity bounds concurrent connections
	shellRegistry *shell.Registry // shells that sessions are assigned from
}
// New builds a Server from cfg.
//
// It registers the available shells, loads the SSH host key from
// cfg.SSH.HostKeyPath (generating and saving a new one when the file does not
// exist), and configures password authentication backed by the server's
// authenticator.
//
// It returns an error if cfg.SSH.MaxConnections is not positive, if shell
// registration fails, or if the host key cannot be loaded or created.
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
	// connSem is created with capacity MaxConnections. A negative capacity
	// panics in make, and a zero capacity makes ListenAndServe's non-blocking
	// send fall through to the reject branch for every connection, so fail
	// fast instead of building an unusable server.
	if cfg.SSH.MaxConnections <= 0 {
		return nil, fmt.Errorf("invalid ssh max connections %d: must be positive", cfg.SSH.MaxConnections)
	}

	registry := shell.NewRegistry()
	// NOTE(review): the second argument is presumably the selection weight.
	if err := registry.Register(bash.NewBashShell(), 1); err != nil {
		return nil, fmt.Errorf("registering bash shell: %w", err)
	}

	s := &Server{
		cfg:           cfg,
		store:         store,
		authenticator: auth.NewAuthenticator(cfg.Auth),
		logger:        logger,
		connSem:       make(chan struct{}, cfg.SSH.MaxConnections),
		shellRegistry: registry,
	}

	hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
	if err != nil {
		return nil, fmt.Errorf("host key: %w", err)
	}

	s.sshConfig = &ssh.ServerConfig{
		PasswordCallback: s.passwordCallback,
		// Advertise a stock Ubuntu OpenSSH version string.
		ServerVersion: "SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.6",
	}
	s.sshConfig.AddHostKey(hostKey)
	return s, nil
}
// ListenAndServe listens for TCP connections on cfg.SSH.ListenAddr and serves
// each one in its own goroutine until ctx is cancelled, at which point the
// listener is closed and nil is returned.
//
// Connections beyond the MaxConnections limit are closed immediately. Other
// Accept errors (e.g. running out of file descriptors) are retried with
// exponential backoff rather than in a tight loop. If the listener is closed
// for any reason other than ctx cancellation, the error is returned.
//
// Connections already being handled are not interrupted when ctx is cancelled.
func (s *Server) ListenAndServe(ctx context.Context) error {
	listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
	if err != nil {
		return fmt.Errorf("listen: %w", err)
	}
	defer listener.Close()
	s.logger.Info("SSH server listening", "addr", s.cfg.SSH.ListenAddr)

	// Closing the listener is what unblocks Accept on cancellation. done lets
	// this watcher exit if we return for any other reason.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			listener.Close()
		case <-done:
		}
	}()

	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	var backoff time.Duration

	for {
		conn, err := listener.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return nil
			}
			if errors.Is(err, net.ErrClosed) {
				return fmt.Errorf("accept: %w", err)
			}

			// Presumably transient; wait before retrying so a persistent
			// error does not spin the CPU and flood the log.
			switch {
			case backoff == 0:
				backoff = minAcceptBackoff
			case backoff*2 > maxAcceptBackoff:
				backoff = maxAcceptBackoff
			default:
				backoff *= 2
			}
			s.logger.Error("accept error", "err", err, "retry_in", backoff)

			timer := time.NewTimer(backoff)
			select {
			case <-timer.C:
			case <-ctx.Done():
				timer.Stop()
				return nil
			}
			continue
		}
		backoff = 0

		// Enforce max concurrent connections.
		select {
		case s.connSem <- struct{}{}:
			go func() {
				defer func() { <-s.connSem }()
				s.handleConn(conn)
			}()
		default:
			s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
			conn.Close()
		}
	}
}
// handleConn runs the SSH handshake (including password authentication) on
// conn and then dispatches its channels.
//
// Global requests are discarded. Only "session" channels are accepted; each is
// served by handleSession in its own goroutine, and any other channel type is
// rejected. conn is always closed before handleConn returns.
func (s *Server) handleConn(conn net.Conn) {
	defer conn.Close()

	// Bound the handshake so a client that connects and then stalls cannot
	// hold one of the limited connection slots forever. Two minutes mirrors
	// OpenSSH's default LoginGraceTime.
	const handshakeTimeout = 2 * time.Minute
	if err := conn.SetDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		s.logger.Debug("setting handshake deadline failed", "remote_addr", conn.RemoteAddr(), "err", err)
		return
	}

	sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
	if err != nil {
		s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
		return
	}
	defer sshConn.Close()

	// Handshake complete: lift the deadline (this also applies to reads the
	// SSH library already has pending) so established sessions are not cut off.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		s.logger.Debug("clearing handshake deadline failed", "remote_addr", conn.RemoteAddr(), "err", err)
		return
	}

	s.logger.Info("SSH connection established",
		"remote_addr", sshConn.RemoteAddr(),
		"user", sshConn.User(),
	)

	go ssh.DiscardRequests(reqs)

	for newChan := range chans {
		if newChan.ChannelType() != "session" {
			newChan.Reject(ssh.UnknownChannelType, "unknown channel type")
			continue
		}
		channel, requests, err := newChan.Accept()
		if err != nil {
			s.logger.Error("channel accept error", "err", err)
			return
		}
		go s.handleSession(channel, requests, sshConn)
	}
}
// handleSession serves one accepted "session" channel. The channel is closed
// when it returns.
//
// It picks a shell from the registry and records the session in the store. If
// recording succeeds, the session is marked ended in the store on return.
// Channel requests are answered in the background, and the shell runs over a
// RecordingChannel wrapper until it returns.
//
// If CreateSession fails, the shell still runs; sessionID is then whatever
// CreateSession returned alongside the error (presumably its zero value).
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
	defer channel.Close()

	selectedShell, err := s.shellRegistry.Select()
	if err != nil {
		s.logger.Error("failed to select shell", "err", err)
		return
	}
	shellName := selectedShell.Name()
	ip := extractIP(conn.RemoteAddr())

	sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), shellName)
	if err == nil {
		// Close out the session record once the shell has finished.
		defer func() {
			if endErr := s.store.EndSession(context.Background(), sessionID, time.Now()); endErr != nil {
				s.logger.Error("failed to end session", "err", endErr)
			}
		}()
	} else {
		s.logger.Error("failed to create session", "err", err)
	}

	s.logger.Info("session started",
		"remote_addr", conn.RemoteAddr(),
		"user", conn.User(),
		"shell", shellName,
		"session_id", sessionID,
	)

	// Answer channel requests until the requests channel is closed: "pty-req"
	// and "shell" get a positive reply, every other type a negative one. Only
	// requests that ask for a reply are answered.
	go func() {
		for req := range requests {
			accepted := req.Type == "pty-req" || req.Type == "shell"
			if req.WantReply {
				req.Reply(accepted, nil)
			}
		}
	}()

	var shellCfg map[string]any
	if shells := s.cfg.Shell.Shells; shells != nil {
		shellCfg = shells[shellName]
	}
	common := shell.ShellCommonConfig{
		Hostname: s.cfg.Shell.Hostname,
		Banner:   s.cfg.Shell.Banner,
		FakeUser: s.cfg.Shell.FakeUser,
	}
	sessCtx := &shell.SessionContext{
		SessionID:     sessionID,
		Username:      conn.User(),
		RemoteAddr:    ip,
		ClientVersion: string(conn.ClientVersion()),
		Store:         s.store,
		ShellConfig:   shellCfg,
		CommonConfig:  common,
	}

	// RecordingChannel currently passes traffic through; it is the hook for
	// future byte-level recording.
	if err := selectedShell.Handle(context.Background(), sessCtx, shell.NewRecordingChannel(channel)); err != nil {
		s.logger.Error("shell error", "err", err, "session_id", sessionID)
	}

	s.logger.Info("session ended",
		"remote_addr", conn.RemoteAddr(),
		"user", conn.User(),
		"session_id", sessionID,
	)
}
// passwordCallback is installed as the ssh.ServerConfig PasswordCallback.
//
// Each attempt is judged by the authenticator, logged, and recorded in the
// store. A store failure is only logged and does not change the decision.
// Accepted attempts return (nil, nil); rejected ones return a non-nil error,
// which makes the SSH library refuse the password.
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
	remote := conn.RemoteAddr()
	user := conn.User()
	pass := string(password)
	ip := extractIP(remote)

	decision := s.authenticator.Authenticate(ip, user, pass)
	s.logger.Info("auth attempt",
		"remote_addr", remote,
		"username", user,
		"accepted", decision.Accepted,
		"reason", decision.Reason,
	)

	// Every attempt is recorded, whether or not it was accepted.
	if err := s.store.RecordLoginAttempt(context.Background(), user, pass, ip); err != nil {
		s.logger.Error("failed to record login attempt", "err", err)
	}

	if !decision.Accepted {
		return nil, errors.New("rejected")
	}
	return nil, nil
}
// extractIP returns the host part of addr in canonical form.
//
// If addr's string form has no port, it is returned unchanged. If the host is
// not a parseable IP (a hostname, or an IPv6 literal with a zone), the host is
// returned as-is. Otherwise the IP is re-rendered canonically: IPv4-mapped
// IPv6 addresses such as ::ffff:10.0.0.1 come out as plain dotted quads,
// because net.IP.String formats them that way.
func extractIP(addr net.Addr) string {
	raw := addr.String()
	host, _, err := net.SplitHostPort(raw)
	if err != nil {
		return raw
	}
	if parsed := net.ParseIP(host); parsed != nil {
		return parsed.String()
	}
	return host
}
// loadOrGenerateHostKey returns a signer for the SSH host key stored at path.
//
// If path does not exist, it generates a new Ed25519 key and writes it to path
// (mode 0600, PEM-encoded via ssh.MarshalPrivateKey). It then returns a signer
// parsed back from the written PEM data. Read errors other than "not exist",
// unparseable key data, and generation/write failures are returned as errors.
//
// The new key file is created exclusively and removed if writing fails. This
// way a failed write cannot leave a truncated key behind (which would make
// every later startup fail to parse it), and a key file created concurrently
// by another process is never overwritten.
func loadOrGenerateHostKey(path string) (ssh.Signer, error) {
	data, err := os.ReadFile(path)
	if err == nil {
		signer, err := ssh.ParsePrivateKey(data)
		if err != nil {
			return nil, fmt.Errorf("parsing host key: %w", err)
		}
		return signer, nil
	}
	if !errors.Is(err, os.ErrNotExist) {
		return nil, fmt.Errorf("reading host key: %w", err)
	}

	// No key on disk yet: generate a new Ed25519 key.
	_, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, fmt.Errorf("generating key: %w", err)
	}
	privBytes, err := ssh.MarshalPrivateKey(priv, "")
	if err != nil {
		return nil, fmt.Errorf("marshaling key: %w", err)
	}
	pemData := pem.EncodeToMemory(privBytes)
	if err := writeNewFile(path, pemData, 0o600); err != nil {
		return nil, fmt.Errorf("writing host key: %w", err)
	}

	signer, err := ssh.ParsePrivateKey(pemData)
	if err != nil {
		return nil, fmt.Errorf("parsing generated key: %w", err)
	}
	// Uses the default slog logger: this function has no Server to log through.
	slog.Info("generated new host key", "path", path)
	return signer, nil
}

// writeNewFile creates path with perm, failing if it already exists, then
// writes data and syncs it to disk. If any step after creation fails, the
// partially written file is removed.
func writeNewFile(path string, data []byte, perm os.FileMode) (err error) {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_EXCL, perm)
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			// Best-effort cleanup; the error being returned is the one that matters.
			_ = f.Close()
			_ = os.Remove(path)
		}
	}()
	if _, err = f.Write(data); err != nil {
		return err
	}
	if err = f.Sync(); err != nil {
		return err
	}
	return f.Close()
}