Bots often send commands via `ssh user@host <command>` (exec request) rather than requesting an interactive shell. These were previously rejected silently. Now exec commands are captured, stored on the session record, and displayed in the web UI session detail page. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
477 lines
13 KiB
Go
477 lines
13 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"os"
|
|
"time"
|
|
|
|
"git.t-juice.club/torjus/oubliette/internal/auth"
|
|
"git.t-juice.club/torjus/oubliette/internal/config"
|
|
"git.t-juice.club/torjus/oubliette/internal/detection"
|
|
"git.t-juice.club/torjus/oubliette/internal/geoip"
|
|
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
|
"git.t-juice.club/torjus/oubliette/internal/notify"
|
|
"git.t-juice.club/torjus/oubliette/internal/shell"
|
|
"git.t-juice.club/torjus/oubliette/internal/shell/adventure"
|
|
"git.t-juice.club/torjus/oubliette/internal/shell/banking"
|
|
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
|
|
"git.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
|
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
|
"git.t-juice.club/torjus/oubliette/internal/storage"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type Server struct {
|
|
cfg config.Config
|
|
store storage.Store
|
|
authenticator *auth.Authenticator
|
|
sshConfig *ssh.ServerConfig
|
|
logger *slog.Logger
|
|
connSem chan struct{} // semaphore limiting concurrent connections
|
|
shellRegistry *shell.Registry
|
|
notifier notify.Sender
|
|
metrics *metrics.Metrics
|
|
geoip *geoip.Reader
|
|
}
|
|
|
|
// New builds a Server from cfg. It:
//   - registers the built-in shells (each with the same value 1, presumably
//     a selection weight);
//   - loads or creates the SSH host key at cfg.SSH.HostKeyPath;
//   - opens the GeoIP database;
//   - prepares the SSH server configuration (password auth, advertised
//     server version, host key).
//
// The returned Server owns the GeoIP reader; ListenAndServe closes it when it
// returns.
func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics.Metrics) (*Server, error) {
	registry := shell.NewRegistry()
	if err := registry.Register(bash.NewBashShell(), 1); err != nil {
		return nil, fmt.Errorf("registering bash shell: %w", err)
	}
	if err := registry.Register(fridge.NewFridgeShell(), 1); err != nil {
		return nil, fmt.Errorf("registering fridge shell: %w", err)
	}
	if err := registry.Register(banking.NewBankingShell(), 1); err != nil {
		return nil, fmt.Errorf("registering banking shell: %w", err)
	}
	if err := registry.Register(adventure.NewAdventureShell(), 1); err != nil {
		return nil, fmt.Errorf("registering adventure shell: %w", err)
	}
	if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
		return nil, fmt.Errorf("registering cisco shell: %w", err)
	}

	// Load the host key BEFORE opening the GeoIP database. Nothing after the
	// GeoIP open can return an error, so the reader cannot leak on a failure
	// path. (Previously a host-key error returned with the reader still open.)
	hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
	if err != nil {
		return nil, fmt.Errorf("host key: %w", err)
	}

	geo, err := geoip.New()
	if err != nil {
		return nil, fmt.Errorf("opening geoip database: %w", err)
	}

	s := &Server{
		cfg:           cfg,
		store:         store,
		authenticator: auth.NewAuthenticator(cfg.Auth),
		logger:        logger,
		connSem:       make(chan struct{}, cfg.SSH.MaxConnections),
		shellRegistry: registry,
		notifier:      notify.NewSender(cfg.Notify.Webhooks, logger),
		metrics:       m,
		geoip:         geo,
	}

	s.sshConfig = &ssh.ServerConfig{
		PasswordCallback: s.passwordCallback,
		// Advertise a common OpenSSH banner rather than the Go library's default.
		ServerVersion: "SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.6",
	}
	s.sshConfig.AddHostKey(hostKey)

	return s, nil
}
|
|
|
|
func (s *Server) ListenAndServe(ctx context.Context) error {
|
|
defer s.geoip.Close()
|
|
|
|
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("listen: %w", err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
s.logger.Info("SSH server listening", "addr", s.cfg.SSH.ListenAddr)
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
listener.Close()
|
|
}()
|
|
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return nil
|
|
}
|
|
s.logger.Error("accept error", "err", err)
|
|
continue
|
|
}
|
|
|
|
// Enforce max concurrent connections.
|
|
select {
|
|
case s.connSem <- struct{}{}:
|
|
s.metrics.SSHConnectionsActive.Inc()
|
|
go func() {
|
|
defer func() {
|
|
<-s.connSem
|
|
s.metrics.SSHConnectionsActive.Dec()
|
|
}()
|
|
s.handleConn(conn)
|
|
}()
|
|
default:
|
|
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_max_connections").Inc()
|
|
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
|
|
conn.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleConn(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
|
if err != nil {
|
|
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_handshake").Inc()
|
|
s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
|
|
return
|
|
}
|
|
defer sshConn.Close()
|
|
|
|
s.metrics.SSHConnectionsTotal.WithLabelValues("accepted").Inc()
|
|
s.logger.Info("SSH connection established",
|
|
"remote_addr", sshConn.RemoteAddr(),
|
|
"user", sshConn.User(),
|
|
)
|
|
|
|
go ssh.DiscardRequests(reqs)
|
|
|
|
for newChan := range chans {
|
|
if newChan.ChannelType() != "session" {
|
|
newChan.Reject(ssh.UnknownChannelType, "unknown channel type")
|
|
continue
|
|
}
|
|
|
|
channel, requests, err := newChan.Accept()
|
|
if err != nil {
|
|
s.logger.Error("channel accept error", "err", err)
|
|
return
|
|
}
|
|
|
|
go s.handleSession(channel, requests, sshConn)
|
|
}
|
|
}
|
|
|
|
// handleSession serves one accepted "session" channel on conn.
//
// Setup:
//   - Pick a shell: the auth layer's preference in
//     conn.Permissions.Extensions["shell"] if it is registered, otherwise a
//     random one from the registry.
//   - Create a session record and send a session-started notification.
//
// It then waits briefly for the client's "shell" or "exec" request:
//   - A non-empty exec command is stored on the session and answered with
//     exit status 0; no shell is run.
//   - Otherwise the selected shell runs on the channel. I/O is recorded for
//     replay and, when detection is enabled, scored for human-likeness.
//
// The channel is closed on return.
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
	defer channel.Close()

	// Select a shell from the registry.
	// If the auth layer specified a shell preference, use it; otherwise random.
	var selectedShell shell.Shell
	if conn.Permissions != nil && conn.Permissions.Extensions["shell"] != "" {
		shellName := conn.Permissions.Extensions["shell"]
		sh, ok := s.shellRegistry.Get(shellName)
		if ok {
			selectedShell = sh
		} else {
			s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
		}
	}
	if selectedShell == nil {
		var err error
		selectedShell, err = s.shellRegistry.Select()
		if err != nil {
			s.logger.Error("failed to select shell", "err", err)
			return
		}
	}

	ip := extractIP(conn.RemoteAddr())
	country := s.geoip.Lookup(ip)
	sessionStart := time.Now()
	sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name(), country)
	if err != nil {
		// Best effort: keep serving the client even without a session record.
		s.logger.Error("failed to create session", "err", err)
	} else {
		s.metrics.SessionsTotal.WithLabelValues(selectedShell.Name()).Inc()
		s.metrics.SessionsActive.Inc()
		defer func() {
			s.metrics.SessionsActive.Dec()
			s.metrics.SessionDuration.Observe(time.Since(sessionStart).Seconds())
			if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
				s.logger.Error("failed to end session", "err", err)
			}
		}()
	}

	s.logger.Info("session started",
		"remote_addr", conn.RemoteAddr(),
		"user", conn.User(),
		"shell", selectedShell.Name(),
		"session_id", sessionID,
	)

	// Send session_started notification.
	connectedAt := time.Now()
	sessionInfo := notify.SessionInfo{
		ID:          sessionID,
		IP:          ip,
		Username:    conn.User(),
		ShellName:   selectedShell.Name(),
		ConnectedAt: notify.FormatConnectedAt(connectedAt),
	}
	s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
	defer s.notifier.CleanupSession(sessionID)

	// Service channel requests (pty-req, shell, exec, ...) in the background
	// for the whole lifetime of the channel. "shell" and "exec" requests are
	// forwarded on startCh, and the send never blocks: once one start request
	// is buffered, further ones are dropped, and anything arriving after the
	// decision below is never read.
	//
	// A blocking send here (as before) leaked this goroutine on a second or
	// late exec request. It then stopped draining `requests`, which can stall
	// the connection.
	type startRequest struct {
		exec    bool   // true for "exec", false for "shell"
		command string // exec command line; empty for "shell"
	}
	startCh := make(chan startRequest, 1)
	go func() {
		defer close(startCh)
		offer := func(r startRequest) {
			select {
			case startCh <- r:
			default: // a start request is already buffered; drop this one
			}
		}
		reply := func(req *ssh.Request, ok bool) {
			if req.WantReply {
				req.Reply(ok, nil)
			}
		}
		for req := range requests {
			switch req.Type {
			case "pty-req":
				reply(req, true)
			case "shell":
				reply(req, true)
				offer(startRequest{})
			case "exec":
				reply(req, true)
				var payload struct{ Command string }
				if err := ssh.Unmarshal(req.Payload, &payload); err == nil {
					offer(startRequest{exec: true, command: payload.Command})
				}
			default:
				reply(req, false)
			}
		}
	}()

	// Wait for the client to choose between an exec command and an
	// interactive shell. A "shell" request starts the shell immediately (no
	// fixed delay). If neither arrives within the timeout, or the request
	// stream ends, fall back to the interactive shell.
	const startTimeout = 500 * time.Millisecond
	select {
	case start, ok := <-startCh:
		if ok && start.exec && start.command != "" {
			s.logger.Info("exec command received",
				"remote_addr", conn.RemoteAddr(),
				"user", conn.User(),
				"session_id", sessionID,
				"command", start.command,
			)
			if err := s.store.SetExecCommand(context.Background(), sessionID, start.command); err != nil {
				s.logger.Error("failed to set exec command", "err", err, "session_id", sessionID)
			}
			s.metrics.ExecCommandsTotal.Inc()
			// Send exit-status 0 (a big-endian uint32 of zero) and close the channel.
			exitPayload := make([]byte, 4)
			_, _ = channel.SendRequest("exit-status", false, exitPayload)
			return
		}
	case <-time.After(startTimeout):
		// No shell or exec request within timeout — proceed with interactive shell.
	}

	// Build session context.
	var shellCfg map[string]any
	if s.cfg.Shell.Shells != nil {
		shellCfg = s.cfg.Shell.Shells[selectedShell.Name()]
	}
	sessCtx := &shell.SessionContext{
		SessionID:     sessionID,
		Username:      conn.User(),
		RemoteAddr:    ip,
		ClientVersion: string(conn.ClientVersion()),
		Store:         s.store,
		ShellConfig:   shellCfg,
		CommonConfig: shell.ShellCommonConfig{
			Hostname: s.cfg.Shell.Hostname,
			Banner:   s.cfg.Shell.Banner,
			FakeUser: s.cfg.Shell.FakeUser,
		},
		OnCommand: func(sh string) {
			s.metrics.CommandsExecuted.WithLabelValues(sh).Inc()
		},
	}

	// Wrap channel in RecordingChannel.
	recorder := shell.NewRecordingChannel(channel)

	// Always record session events for replay.
	eventRec := shell.NewEventRecorder(sessionID, s.store, s.logger)
	eventRec.Start(context.Background())
	defer eventRec.Close()
	recorder.AddCallback(eventRec.RecordEvent)

	// Set up detection scorer if enabled.
	var scorer *detection.Scorer
	var scoreCancel context.CancelFunc
	if s.cfg.Detection.Enabled {
		scorer = detection.NewScorer()
		recorder.AddCallback(func(ts time.Time, direction int, data []byte) {
			scorer.RecordEvent(ts, direction, data)
		})

		var scoreCtx context.Context
		scoreCtx, scoreCancel = context.WithCancel(context.Background())
		go s.runScoreUpdater(scoreCtx, sessionID, scorer, sessionInfo)
	}

	if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
		s.logger.Error("shell error", "err", err, "session_id", sessionID)
	}

	// Stop score updater and write final score.
	if scoreCancel != nil {
		scoreCancel()
	}
	if scorer != nil {
		finalScore := scorer.Score()
		s.metrics.HumanScore.Observe(finalScore)
		if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil {
			s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID)
		}
		s.logger.Info("session ended",
			"remote_addr", conn.RemoteAddr(),
			"user", conn.User(),
			"session_id", sessionID,
			"human_score", finalScore,
		)
	} else {
		s.logger.Info("session ended",
			"remote_addr", conn.RemoteAddr(),
			"user", conn.User(),
			"session_id", sessionID,
		)
	}
}
|
|
|
|
// runScoreUpdater recomputes the session's human score on every tick of
// cfg.Detection.UpdateIntervalDuration and persists it via the store.
//
// On each tick where the score is at or above cfg.Detection.Threshold, it
// calls the notifier with EventHumanDetected. Any de-duplication of repeated
// notifications is up to the notifier (not visible here). A failed store
// write skips notification for that tick. It returns when ctx is cancelled.
func (s *Server) runScoreUpdater(ctx context.Context, sessionID string, scorer *detection.Scorer, sessionInfo notify.SessionInfo) {
	tick := time.NewTicker(s.cfg.Detection.UpdateIntervalDuration)
	defer tick.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
		}

		current := scorer.Score()
		if err := s.store.UpdateHumanScore(ctx, sessionID, current); err != nil {
			s.logger.Error("failed to update human score", "err", err, "session_id", sessionID)
			continue
		}
		s.logger.Debug("human score updated", "session_id", sessionID, "score", current)

		if current >= s.cfg.Detection.Threshold {
			// Copy so the caller's SessionInfo is never mutated.
			detected := sessionInfo
			detected.HumanScore = current
			s.notifier.Notify(ctx, notify.EventHumanDetected, detected)
		}
	}
}
|
|
|
|
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
ip := extractIP(conn.RemoteAddr())
|
|
d := s.authenticator.Authenticate(ip, conn.User(), string(password))
|
|
|
|
if d.Accepted {
|
|
s.metrics.AuthAttemptsTotal.WithLabelValues("accepted", d.Reason).Inc()
|
|
} else {
|
|
s.metrics.AuthAttemptsTotal.WithLabelValues("rejected", d.Reason).Inc()
|
|
}
|
|
|
|
s.logger.Info("auth attempt",
|
|
"remote_addr", conn.RemoteAddr(),
|
|
"username", conn.User(),
|
|
"accepted", d.Accepted,
|
|
"reason", d.Reason,
|
|
)
|
|
|
|
country := s.geoip.Lookup(ip)
|
|
if country != "" {
|
|
s.metrics.AuthAttemptsByCountry.WithLabelValues(country).Inc()
|
|
}
|
|
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip, country); err != nil {
|
|
s.logger.Error("failed to record login attempt", "err", err)
|
|
}
|
|
|
|
if d.Accepted {
|
|
var perms *ssh.Permissions
|
|
if d.Shell != "" {
|
|
perms = &ssh.Permissions{
|
|
Extensions: map[string]string{"shell": d.Shell},
|
|
}
|
|
}
|
|
return perms, nil
|
|
}
|
|
return nil, fmt.Errorf("rejected")
|
|
}
|
|
|
|
// extractIP returns the host part of addr with any port removed.
//   - IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are reported in plain
//     dotted-quad form.
//   - If addr's string form cannot be split into host and port (e.g. it has
//     no port), that string is returned unchanged.
//   - If the host part is not an IP literal, it is returned as-is.
func extractIP(addr net.Addr) string {
	full := addr.String()
	host, _, err := net.SplitHostPort(full)
	if err != nil {
		// Might not have a port; fall back to the whole address string.
		return full
	}

	parsed := net.ParseIP(host)
	if parsed == nil {
		return host
	}
	// Normalize IPv4-mapped IPv6 addresses to their 4-byte IPv4 form.
	if v4 := parsed.To4(); v4 != nil {
		parsed = v4
	}
	return parsed.String()
}
|
|
|
|
// loadOrGenerateHostKey returns the SSH host key signer stored at path.
//
// If no file exists there, it generates a fresh Ed25519 key, writes it to path
// as PEM with mode 0600, and returns a signer for it. Read errors other than
// "not exist", and parse, generation, marshal and write failures, are
// returned wrapped with context.
func loadOrGenerateHostKey(path string) (ssh.Signer, error) {
	existing, readErr := os.ReadFile(path)
	switch {
	case readErr == nil:
		signer, err := ssh.ParsePrivateKey(existing)
		if err != nil {
			return nil, fmt.Errorf("parsing host key: %w", err)
		}
		return signer, nil
	case !errors.Is(readErr, os.ErrNotExist):
		return nil, fmt.Errorf("reading host key: %w", readErr)
	}

	// No key on disk yet: create an Ed25519 key, persist it, then load it back
	// through the same parser used for existing keys.
	_, privateKey, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, fmt.Errorf("generating key: %w", err)
	}

	block, err := ssh.MarshalPrivateKey(privateKey, "")
	if err != nil {
		return nil, fmt.Errorf("marshaling key: %w", err)
	}

	encoded := pem.EncodeToMemory(block)
	if err := os.WriteFile(path, encoded, 0o600); err != nil {
		return nil, fmt.Errorf("writing host key: %w", err)
	}

	signer, err := ssh.ParsePrivateKey(encoded)
	if err != nil {
		return nil, fmt.Errorf("parsing generated key: %w", err)
	}

	slog.Info("generated new host key", "path", path)
	return signer, nil
}
|
|
|