This repository has been archived on 2026-03-09. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
oubliette/internal/server/server.go
Torjus Håkestad 0133d956a5 feat: capture SSH exec commands (PLAN.md 4.4)
Bots often send commands via `ssh user@host <command>` (exec request)
rather than requesting an interactive shell. These were previously
rejected silently. Now exec commands are captured, stored on the session
record, and displayed in the web UI session detail page.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:43:11 +01:00

477 lines
13 KiB
Go

package server
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"errors"
"fmt"
"log/slog"
"net"
"os"
"time"
"git.t-juice.club/torjus/oubliette/internal/auth"
"git.t-juice.club/torjus/oubliette/internal/config"
"git.t-juice.club/torjus/oubliette/internal/detection"
"git.t-juice.club/torjus/oubliette/internal/geoip"
"git.t-juice.club/torjus/oubliette/internal/metrics"
"git.t-juice.club/torjus/oubliette/internal/notify"
"git.t-juice.club/torjus/oubliette/internal/shell"
"git.t-juice.club/torjus/oubliette/internal/shell/adventure"
"git.t-juice.club/torjus/oubliette/internal/shell/banking"
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
"git.t-juice.club/torjus/oubliette/internal/shell/cisco"
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
"git.t-juice.club/torjus/oubliette/internal/storage"
"golang.org/x/crypto/ssh"
)
// Server is the SSH honeypot. It accepts TCP connections, authenticates
// password logins through its Authenticator, and hands every "session"
// channel to one of the fake shells held in its shell registry.
type Server struct {
	// Configuration and persistence.
	cfg   config.Config
	store storage.Store

	// SSH authentication and protocol configuration.
	authenticator *auth.Authenticator
	sshConfig     *ssh.ServerConfig

	logger *slog.Logger

	// connSem is a counting semaphore bounding the number of concurrently
	// handled connections; its capacity is cfg.SSH.MaxConnections (see New).
	connSem chan struct{}

	// Collaborators used while serving sessions.
	shellRegistry *shell.Registry
	notifier      notify.Sender
	metrics       *metrics.Metrics
	geoip         *geoip.Reader
}
// New builds a honeypot Server from cfg.
//
// It registers the built-in fake shells (bash, fridge, banking, adventure,
// cisco — each with weight 1), opens the GeoIP database, and loads the SSH
// host key from cfg.SSH.HostKeyPath (generating and persisting a new Ed25519
// key if the file does not exist).
//
// The returned Server owns the GeoIP reader, which ListenAndServe closes when
// it returns. If construction fails after the reader has been opened, New
// closes it itself, because no Server is handed back to do so.
func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics.Metrics) (*Server, error) {
	registry := shell.NewRegistry()
	builtinShells := []struct {
		label string // used only in the registration error message
		sh    shell.Shell
	}{
		{"bash", bash.NewBashShell()},
		{"fridge", fridge.NewFridgeShell()},
		{"banking", banking.NewBankingShell()},
		{"adventure", adventure.NewAdventureShell()},
		{"cisco", cisco.NewCiscoShell()},
	}
	for _, b := range builtinShells {
		if err := registry.Register(b.sh, 1); err != nil {
			return nil, fmt.Errorf("registering %s shell: %w", b.label, err)
		}
	}

	geo, err := geoip.New()
	if err != nil {
		return nil, fmt.Errorf("opening geoip database: %w", err)
	}

	s := &Server{
		cfg:           cfg,
		store:         store,
		authenticator: auth.NewAuthenticator(cfg.Auth),
		logger:        logger,
		connSem:       make(chan struct{}, cfg.SSH.MaxConnections),
		shellRegistry: registry,
		notifier:      notify.NewSender(cfg.Notify.Webhooks, logger),
		metrics:       m,
		geoip:         geo,
	}

	hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
	if err != nil {
		// s is discarded, so ListenAndServe will never close the reader;
		// release it here to avoid leaking the open database.
		geo.Close()
		return nil, fmt.Errorf("host key: %w", err)
	}

	s.sshConfig = &ssh.ServerConfig{
		PasswordCallback: s.passwordCallback,
		// Advertise a common Ubuntu OpenSSH banner to look like a real host.
		ServerVersion: "SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.6",
	}
	s.sshConfig.AddHostKey(hostKey)
	return s, nil
}
func (s *Server) ListenAndServe(ctx context.Context) error {
defer s.geoip.Close()
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
if err != nil {
return fmt.Errorf("listen: %w", err)
}
defer listener.Close()
s.logger.Info("SSH server listening", "addr", s.cfg.SSH.ListenAddr)
go func() {
<-ctx.Done()
listener.Close()
}()
for {
conn, err := listener.Accept()
if err != nil {
if ctx.Err() != nil {
return nil
}
s.logger.Error("accept error", "err", err)
continue
}
// Enforce max concurrent connections.
select {
case s.connSem <- struct{}{}:
s.metrics.SSHConnectionsActive.Inc()
go func() {
defer func() {
<-s.connSem
s.metrics.SSHConnectionsActive.Dec()
}()
s.handleConn(conn)
}()
default:
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_max_connections").Inc()
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
conn.Close()
}
}
}
// handleConn performs the SSH handshake on conn and then dispatches channels.
// Only "session" channels are accepted, and each is served by handleSession on
// its own goroutine. Other channel types are rejected with UnknownChannelType.
// Global (connection-level) requests are discarded. The underlying connection
// is closed when the channel stream ends or a channel cannot be accepted.
func (s *Server) handleConn(conn net.Conn) {
	defer conn.Close()

	sshConn, newChannels, globalReqs, handshakeErr := ssh.NewServerConn(conn, s.sshConfig)
	if handshakeErr != nil {
		s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_handshake").Inc()
		s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", handshakeErr)
		return
	}
	defer sshConn.Close()

	s.metrics.SSHConnectionsTotal.WithLabelValues("accepted").Inc()
	s.logger.Info("SSH connection established", "remote_addr", sshConn.RemoteAddr(), "user", sshConn.User())

	go ssh.DiscardRequests(globalReqs)

	for nc := range newChannels {
		if nc.ChannelType() == "session" {
			ch, chReqs, acceptErr := nc.Accept()
			if acceptErr != nil {
				s.logger.Error("channel accept error", "err", acceptErr)
				return
			}
			go s.handleSession(ch, chReqs, sshConn)
			continue
		}
		nc.Reject(ssh.UnknownChannelType, "unknown channel type")
	}
}
// handleSession serves one accepted "session" channel.
//
// It picks a fake shell: the one named in the auth layer's "shell" permission
// extension if it is registered, otherwise a random one from the registry. It
// then records the session in the store and fires a session_started
// notification.
//
// Next it waits up to 500ms for an "exec" request. If one arrives with a
// non-empty command, the command is stored on the session, exit-status 0 is
// sent, and the channel is closed without running a shell. Otherwise the
// selected shell runs interactively over a recording wrapper of the channel,
// with event recording and, when enabled, human-detection scoring.
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
	defer channel.Close()

	// Select a shell from the registry.
	// If the auth layer specified a shell preference, use it; otherwise random.
	var selectedShell shell.Shell
	if conn.Permissions != nil && conn.Permissions.Extensions["shell"] != "" {
		shellName := conn.Permissions.Extensions["shell"]
		sh, ok := s.shellRegistry.Get(shellName)
		if ok {
			selectedShell = sh
		} else {
			s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
		}
	}
	if selectedShell == nil {
		var err error
		selectedShell, err = s.shellRegistry.Select()
		if err != nil {
			s.logger.Error("failed to select shell", "err", err)
			return
		}
	}

	ip := extractIP(conn.RemoteAddr())
	country := s.geoip.Lookup(ip)
	sessionStart := time.Now()
	sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name(), country)
	if err != nil {
		s.logger.Error("failed to create session", "err", err)
	} else {
		s.metrics.SessionsTotal.WithLabelValues(selectedShell.Name()).Inc()
		s.metrics.SessionsActive.Inc()
		defer func() {
			s.metrics.SessionsActive.Dec()
			s.metrics.SessionDuration.Observe(time.Since(sessionStart).Seconds())
			if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
				s.logger.Error("failed to end session", "err", err)
			}
		}()
	}

	s.logger.Info("session started",
		"remote_addr", conn.RemoteAddr(),
		"user", conn.User(),
		"shell", selectedShell.Name(),
		"session_id", sessionID,
	)

	// Send session_started notification.
	connectedAt := time.Now()
	sessionInfo := notify.SessionInfo{
		ID:          sessionID,
		IP:          ip,
		Username:    conn.User(),
		ShellName:   selectedShell.Name(),
		ConnectedAt: notify.FormatConnectedAt(connectedAt),
	}
	s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
	defer s.notifier.CleanupSession(sessionID)

	// Service channel requests (pty-req, shell, exec, ...) for the whole
	// lifetime of the channel. x/crypto/ssh requires this request stream to be
	// drained or the connection can hang, so the loop must never block.
	execCh := make(chan string, 1)
	go func() {
		defer close(execCh)
		for req := range requests {
			switch req.Type {
			case "pty-req", "shell":
				if req.WantReply {
					req.Reply(true, nil)
				}
			case "exec":
				if req.WantReply {
					req.Reply(true, nil)
				}
				var payload struct{ Command string }
				if err := ssh.Unmarshal(req.Payload, &payload); err == nil {
					// Non-blocking send: execCh holds one command and is read at
					// most once (within the wait window below). A blocking send
					// would stall this loop on a second or late exec request,
					// leaking the goroutine and hanging the connection.
					select {
					case execCh <- payload.Command:
					default:
						s.logger.Debug("exec request ignored",
							"session_id", sessionID,
							"command", payload.Command,
						)
					}
				}
			default:
				if req.WantReply {
					req.Reply(false, nil)
				}
			}
		}
	}()

	// Check for exec request before proceeding to interactive shell.
	select {
	case cmd, ok := <-execCh:
		if ok && cmd != "" {
			s.logger.Info("exec command received",
				"remote_addr", conn.RemoteAddr(),
				"user", conn.User(),
				"session_id", sessionID,
				"command", cmd,
			)
			if err := s.store.SetExecCommand(context.Background(), sessionID, cmd); err != nil {
				s.logger.Error("failed to set exec command", "err", err, "session_id", sessionID)
			}
			s.metrics.ExecCommandsTotal.Inc()
			// Send exit-status 0 and close channel (via the deferred Close).
			exitPayload := make([]byte, 4) // uint32(0)
			_, _ = channel.SendRequest("exit-status", false, exitPayload)
			return
		}
	case <-time.After(500 * time.Millisecond):
		// No exec request within timeout — proceed with interactive shell.
	}

	// Build session context.
	var shellCfg map[string]any
	if s.cfg.Shell.Shells != nil {
		shellCfg = s.cfg.Shell.Shells[selectedShell.Name()]
	}
	sessCtx := &shell.SessionContext{
		SessionID:     sessionID,
		Username:      conn.User(),
		RemoteAddr:    ip,
		ClientVersion: string(conn.ClientVersion()),
		Store:         s.store,
		ShellConfig:   shellCfg,
		CommonConfig: shell.ShellCommonConfig{
			Hostname: s.cfg.Shell.Hostname,
			Banner:   s.cfg.Shell.Banner,
			FakeUser: s.cfg.Shell.FakeUser,
		},
		OnCommand: func(sh string) {
			s.metrics.CommandsExecuted.WithLabelValues(sh).Inc()
		},
	}

	// Wrap channel in RecordingChannel.
	recorder := shell.NewRecordingChannel(channel)

	// Always record session events for replay.
	eventRec := shell.NewEventRecorder(sessionID, s.store, s.logger)
	eventRec.Start(context.Background())
	defer eventRec.Close()
	recorder.AddCallback(eventRec.RecordEvent)

	// Set up detection scorer if enabled.
	var scorer *detection.Scorer
	var scoreCancel context.CancelFunc
	if s.cfg.Detection.Enabled {
		scorer = detection.NewScorer()
		recorder.AddCallback(func(ts time.Time, direction int, data []byte) {
			scorer.RecordEvent(ts, direction, data)
		})
		var scoreCtx context.Context
		scoreCtx, scoreCancel = context.WithCancel(context.Background())
		go s.runScoreUpdater(scoreCtx, sessionID, scorer, sessionInfo)
	}

	if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
		s.logger.Error("shell error", "err", err, "session_id", sessionID)
	}

	// Stop score updater and write final score.
	if scoreCancel != nil {
		scoreCancel()
	}
	if scorer != nil {
		finalScore := scorer.Score()
		s.metrics.HumanScore.Observe(finalScore)
		if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil {
			s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID)
		}
		s.logger.Info("session ended",
			"remote_addr", conn.RemoteAddr(),
			"user", conn.User(),
			"session_id", sessionID,
			"human_score", finalScore,
		)
	} else {
		s.logger.Info("session ended",
			"remote_addr", conn.RemoteAddr(),
			"user", conn.User(),
			"session_id", sessionID,
		)
	}
}
// runScoreUpdater periodically computes the human score, writes it to the DB,
// and triggers a notification if the threshold is crossed.
func (s *Server) runScoreUpdater(ctx context.Context, sessionID string, scorer *detection.Scorer, sessionInfo notify.SessionInfo) {
	// Every cfg.Detection.UpdateIntervalDuration: persist the scorer's current
	// value, and if it meets cfg.Detection.Threshold send an EventHumanDetected
	// notification carrying that score. Stops as soon as ctx is cancelled.
	tick := time.NewTicker(s.cfg.Detection.UpdateIntervalDuration)
	defer tick.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
		}

		current := scorer.Score()
		if updateErr := s.store.UpdateHumanScore(ctx, sessionID, current); updateErr != nil {
			// Skip notifying on a failed write; retry on the next tick.
			s.logger.Error("failed to update human score", "err", updateErr, "session_id", sessionID)
			continue
		}
		s.logger.Debug("human score updated", "session_id", sessionID, "score", current)

		if current >= s.cfg.Detection.Threshold {
			detected := sessionInfo // copy; the caller's value is left untouched
			detected.HumanScore = current
			s.notifier.Notify(ctx, notify.EventHumanDetected, detected)
		}
	}
}
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
ip := extractIP(conn.RemoteAddr())
d := s.authenticator.Authenticate(ip, conn.User(), string(password))
if d.Accepted {
s.metrics.AuthAttemptsTotal.WithLabelValues("accepted", d.Reason).Inc()
} else {
s.metrics.AuthAttemptsTotal.WithLabelValues("rejected", d.Reason).Inc()
}
s.logger.Info("auth attempt",
"remote_addr", conn.RemoteAddr(),
"username", conn.User(),
"accepted", d.Accepted,
"reason", d.Reason,
)
country := s.geoip.Lookup(ip)
if country != "" {
s.metrics.AuthAttemptsByCountry.WithLabelValues(country).Inc()
}
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip, country); err != nil {
s.logger.Error("failed to record login attempt", "err", err)
}
if d.Accepted {
var perms *ssh.Permissions
if d.Shell != "" {
perms = &ssh.Permissions{
Extensions: map[string]string{"shell": d.Shell},
}
}
return perms, nil
}
return nil, fmt.Errorf("rejected")
}
// extractIP returns the host part of addr with any port stripped.
//
// Parseable IPs are returned in canonical form, with IPv4-mapped IPv6
// addresses (::ffff:a.b.c.d) collapsed to dotted-quad. If addr has no port,
// its string form is returned unchanged. If the host is not a parseable IP
// (e.g. a hostname or an address with a zone), the host is returned as-is.
func extractIP(addr net.Addr) string {
	full := addr.String()
	host, _, splitErr := net.SplitHostPort(full)
	if splitErr != nil {
		return full
	}

	parsed := net.ParseIP(host)
	switch {
	case parsed == nil:
		return host
	case parsed.To4() != nil:
		return parsed.To4().String()
	default:
		return parsed.String()
	}
}
// loadOrGenerateHostKey returns the SSH host key stored at path.
//
// If no file exists at path, a new Ed25519 key is generated, written there as
// PEM with mode 0600, and returned. Any other read error, or contents that
// cannot be parsed as a private key, is returned as an error.
func loadOrGenerateHostKey(path string) (ssh.Signer, error) {
	keyPEM, readErr := os.ReadFile(path)
	switch {
	case readErr == nil:
		signer, parseErr := ssh.ParsePrivateKey(keyPEM)
		if parseErr != nil {
			return nil, fmt.Errorf("parsing host key: %w", parseErr)
		}
		return signer, nil
	case !errors.Is(readErr, os.ErrNotExist):
		return nil, fmt.Errorf("reading host key: %w", readErr)
	}

	// No key on disk yet: create a fresh Ed25519 key and persist it.
	_, edKey, genErr := ed25519.GenerateKey(rand.Reader)
	if genErr != nil {
		return nil, fmt.Errorf("generating key: %w", genErr)
	}
	pemBlock, marshalErr := ssh.MarshalPrivateKey(edKey, "")
	if marshalErr != nil {
		return nil, fmt.Errorf("marshaling key: %w", marshalErr)
	}
	keyPEM = pem.EncodeToMemory(pemBlock)
	if writeErr := os.WriteFile(path, keyPEM, 0600); writeErr != nil {
		return nil, fmt.Errorf("writing host key: %w", writeErr)
	}

	// Parse what was just written so the returned signer matches the file.
	signer, parseErr := ssh.ParsePrivateKey(keyPEM)
	if parseErr != nil {
		return nil, fmt.Errorf("parsing generated key: %w", parseErr)
	}
	slog.Info("generated new host key", "path", path)
	return signer, nil
}