package server

import (
	"context"
	"crypto/ed25519"
	"crypto/rand"
	"encoding/pem"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"os"
	"time"

	"git.t-juice.club/torjus/oubliette/internal/auth"
	"git.t-juice.club/torjus/oubliette/internal/config"
	"git.t-juice.club/torjus/oubliette/internal/detection"
	"git.t-juice.club/torjus/oubliette/internal/metrics"
	"git.t-juice.club/torjus/oubliette/internal/notify"
	"git.t-juice.club/torjus/oubliette/internal/shell"
	"git.t-juice.club/torjus/oubliette/internal/shell/adventure"
	"git.t-juice.club/torjus/oubliette/internal/shell/banking"
	"git.t-juice.club/torjus/oubliette/internal/shell/bash"
	"git.t-juice.club/torjus/oubliette/internal/shell/cisco"
	"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
	"git.t-juice.club/torjus/oubliette/internal/storage"
	"golang.org/x/crypto/ssh"
)

const (
	// handshakeTimeout bounds how long a client may spend in the SSH handshake,
	// which in x/crypto/ssh includes password authentication. Without it a
	// client that connects and then stalls would hold one of the
	// MaxConnections slots forever. Two minutes mirrors OpenSSH's default
	// LoginGraceTime.
	handshakeTimeout = 2 * time.Minute

	// minAcceptBackoff and maxAcceptBackoff bound the delay between retries
	// after consecutive Accept errors.
	minAcceptBackoff = 5 * time.Millisecond
	maxAcceptBackoff = time.Second
)

// Server accepts SSH connections, authenticates them via an auth.Authenticator,
// and serves each "session" channel with a shell taken from a shell.Registry.
type Server struct {
	cfg           config.Config
	store         storage.Store
	authenticator *auth.Authenticator
	sshConfig     *ssh.ServerConfig
	logger        *slog.Logger
	connSem       chan struct{} // semaphore limiting concurrent connections
	shellRegistry *shell.Registry
	notifier      notify.Sender
	metrics       *metrics.Metrics
}

// New builds a Server from cfg. It registers the built-in shells (bash,
// fridge, banking, adventure, cisco), each with the value 1 (presumably a
// selection weight — confirm against shell.Registry). It then loads the host
// key from cfg.SSH.HostKeyPath, generating one if the file does not exist.
// It returns an error if a shell cannot be registered or the host key cannot
// be loaded or created.
func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics.Metrics) (*Server, error) {
	registry := shell.NewRegistry()
	if err := registry.Register(bash.NewBashShell(), 1); err != nil {
		return nil, fmt.Errorf("registering bash shell: %w", err)
	}
	if err := registry.Register(fridge.NewFridgeShell(), 1); err != nil {
		return nil, fmt.Errorf("registering fridge shell: %w", err)
	}
	if err := registry.Register(banking.NewBankingShell(), 1); err != nil {
		return nil, fmt.Errorf("registering banking shell: %w", err)
	}
	if err := registry.Register(adventure.NewAdventureShell(), 1); err != nil {
		return nil, fmt.Errorf("registering adventure shell: %w", err)
	}
	if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
		return nil, fmt.Errorf("registering cisco shell: %w", err)
	}

	s := &Server{
		cfg:           cfg,
		store:         store,
		authenticator: auth.NewAuthenticator(cfg.Auth),
		logger:        logger,
		connSem:       make(chan struct{}, cfg.SSH.MaxConnections),
		shellRegistry: registry,
		notifier:      notify.NewSender(cfg.Notify.Webhooks, logger),
		metrics:       m,
	}

	hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
	if err != nil {
		return nil, fmt.Errorf("host key: %w", err)
	}

	s.sshConfig = &ssh.ServerConfig{
		PasswordCallback: s.passwordCallback,
		// Advertise a stock Ubuntu OpenSSH banner to clients.
		ServerVersion: "SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.6",
	}
	s.sshConfig.AddHostKey(hostKey)

	return s, nil
}

// ListenAndServe listens on TCP address cfg.SSH.ListenAddr and serves
// connections until ctx is cancelled, then returns nil. It returns an error
// only if the listener cannot be created. At most cfg.SSH.MaxConnections
// connections are handled at once; any connection beyond that is counted as
// "rejected_max_connections" and closed immediately.
func (s *Server) ListenAndServe(ctx context.Context) error {
	listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
	if err != nil {
		return fmt.Errorf("listen: %w", err)
	}
	defer listener.Close()

	s.logger.Info("SSH server listening", "addr", s.cfg.SSH.ListenAddr)

	// Closing the listener is what unblocks Accept on cancellation.
	go func() {
		<-ctx.Done()
		listener.Close()
	}()

	var backoff time.Duration
	for {
		conn, err := listener.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return nil
			}
			// Back off exponentially on consecutive failures (e.g. file
			// descriptor exhaustion) instead of spinning and flooding the log.
			backoff = min(max(2*backoff, minAcceptBackoff), maxAcceptBackoff)
			s.logger.Error("accept error", "err", err, "retry_in", backoff)
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(backoff):
			}
			continue
		}
		backoff = 0

		// Enforce max concurrent connections without blocking the accept loop.
		select {
		case s.connSem <- struct{}{}:
			s.metrics.SSHConnectionsActive.Inc()
			go func() {
				defer func() {
					<-s.connSem
					s.metrics.SSHConnectionsActive.Dec()
				}()
				s.handleConn(conn)
			}()
		default:
			s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_max_connections").Inc()
			s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
			conn.Close()
		}
	}
}

// handleConn performs the SSH handshake on conn and dispatches its channels.
// Only "session" channels are accepted; each runs in its own goroutine, and
// out-of-band global requests are discarded. The handshake, including
// authentication, must finish within handshakeTimeout.
func (s *Server) handleConn(conn net.Conn) {
	defer conn.Close()

	// Bound the handshake so a stalled client cannot pin a connection slot.
	if err := conn.SetDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		s.logger.Debug("failed to set handshake deadline", "remote_addr", conn.RemoteAddr(), "err", err)
		return
	}

	sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
	if err != nil {
		s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_handshake").Inc()
		s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
		return
	}
	defer sshConn.Close()

	s.metrics.SSHConnectionsTotal.WithLabelValues("accepted").Inc()
	s.logger.Info("SSH connection established",
		"remote_addr", sshConn.RemoteAddr(),
		"user", sshConn.User(),
	)

	// Handshake done: lift the deadline so long interactive sessions are not
	// cut off. A zero time clears it, including for any pending read.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		s.logger.Debug("failed to clear handshake deadline", "remote_addr", conn.RemoteAddr(), "err", err)
		return
	}

	go ssh.DiscardRequests(reqs)

	for newChan := range chans {
		if newChan.ChannelType() != "session" {
			newChan.Reject(ssh.UnknownChannelType, "unknown channel type")
			continue
		}

		channel, requests, err := newChan.Accept()
		if err != nil {
			s.logger.Error("channel accept error", "err", err)
			return
		}

		go s.handleSession(channel, requests, sshConn)
	}
}

// selectShell returns the shell named by the "shell" permission extension
// (set by passwordCallback). It falls back to shellRegistry.Select when no
// shell is named or the named shell is not registered.
func (s *Server) selectShell(conn *ssh.ServerConn) (shell.Shell, error) {
	if conn.Permissions != nil {
		if name := conn.Permissions.Extensions["shell"]; name != "" {
			if sh, ok := s.shellRegistry.Get(name); ok {
				return sh, nil
			}
			s.logger.Warn("configured shell not found, falling back to random", "shell", name)
		}
	}
	return s.shellRegistry.Select()
}

// handleSession serves one "session" channel. It:
//   - picks a shell and records the session in the store;
//   - sends a session-started notification;
//   - acknowledges pty-req and shell requests;
//   - records all channel traffic for replay;
//   - optionally runs human-likeness scoring;
//   - runs the shell until it returns.
//
// When scoring is enabled, the final score is written after the shell exits.
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
	defer channel.Close()

	selectedShell, err := s.selectShell(conn)
	if err != nil {
		s.logger.Error("failed to select shell", "err", err)
		return
	}
	shellName := selectedShell.Name()

	ip := extractIP(conn.RemoteAddr())
	sessionStart := time.Now()
	sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), shellName)
	if err != nil {
		// NOTE(review): the session still proceeds (best effort). Downstream
		// writes then use whatever ID CreateSession returned with the error.
		s.logger.Error("failed to create session", "err", err)
	} else {
		s.metrics.SessionsTotal.WithLabelValues(shellName).Inc()
		s.metrics.SessionsActive.Inc()
		defer func() {
			s.metrics.SessionsActive.Dec()
			s.metrics.SessionDuration.Observe(time.Since(sessionStart).Seconds())
			if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
				s.logger.Error("failed to end session", "err", err)
			}
		}()
	}

	s.logger.Info("session started",
		"remote_addr", conn.RemoteAddr(),
		"user", conn.User(),
		"shell", shellName,
		"session_id", sessionID,
	)

	// Send session_started notification.
	connectedAt := time.Now()
	sessionInfo := notify.SessionInfo{
		ID:          sessionID,
		IP:          ip,
		Username:    conn.User(),
		ShellName:   shellName,
		ConnectedAt: notify.FormatConnectedAt(connectedAt),
	}
	s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
	defer s.notifier.CleanupSession(sessionID)

	// Accept pty-req and shell; refuse every other request that wants a reply.
	go func() {
		for req := range requests {
			switch req.Type {
			case "pty-req", "shell":
				if req.WantReply {
					req.Reply(true, nil)
				}
			default:
				if req.WantReply {
					req.Reply(false, nil)
				}
			}
		}
	}()

	// Build session context.
	var shellCfg map[string]any
	if s.cfg.Shell.Shells != nil {
		shellCfg = s.cfg.Shell.Shells[shellName]
	}

	sessCtx := &shell.SessionContext{
		SessionID:     sessionID,
		Username:      conn.User(),
		RemoteAddr:    ip,
		ClientVersion: string(conn.ClientVersion()),
		Store:         s.store,
		ShellConfig:   shellCfg,
		CommonConfig: shell.ShellCommonConfig{
			Hostname: s.cfg.Shell.Hostname,
			Banner:   s.cfg.Shell.Banner,
			FakeUser: s.cfg.Shell.FakeUser,
		},
	}

	// Wrap channel in RecordingChannel.
	recorder := shell.NewRecordingChannel(channel)

	// Always record session events for replay.
	eventRec := shell.NewEventRecorder(sessionID, s.store, s.logger)
	eventRec.Start(context.Background())
	defer eventRec.Close()
	recorder.AddCallback(eventRec.RecordEvent)

	// Set up detection scorer if enabled.
	var scorer *detection.Scorer
	var stopScoreUpdater func()
	if s.cfg.Detection.Enabled {
		scorer = detection.NewScorer()
		recorder.AddCallback(func(ts time.Time, direction int, data []byte) {
			scorer.RecordEvent(ts, direction, data)
		})

		scoreCtx, scoreCancel := context.WithCancel(context.Background())
		updaterDone := make(chan struct{})
		go func() {
			defer close(updaterDone)
			s.runScoreUpdater(scoreCtx, sessionID, scorer, sessionInfo)
		}()
		stopScoreUpdater = func() {
			scoreCancel()
			// Wait for the updater to exit so an in-flight periodic write cannot
			// land after, and overwrite, the final score written below. Assumes
			// the store/notifier calls honour ctx cancellation — TODO confirm.
			<-updaterDone
		}
	}

	if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
		s.logger.Error("shell error", "err", err, "session_id", sessionID)
	}

	// Stop score updater and write final score.
	if stopScoreUpdater != nil {
		stopScoreUpdater()
	}

	endAttrs := []any{
		"remote_addr", conn.RemoteAddr(),
		"user", conn.User(),
		"session_id", sessionID,
	}
	if scorer != nil {
		finalScore := scorer.Score()
		if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil {
			s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID)
		}
		endAttrs = append(endAttrs, "human_score", finalScore)
	}
	s.logger.Info("session ended", endAttrs...)
}

// runScoreUpdater periodically computes the human score, writes it to the DB,
// and triggers a notification if the threshold is crossed. It returns when
// ctx is cancelled.
func (s *Server) runScoreUpdater(ctx context.Context, sessionID string, scorer *detection.Scorer, sessionInfo notify.SessionInfo) {
	interval := s.cfg.Detection.UpdateIntervalDuration
	if interval <= 0 {
		// time.NewTicker panics on a non-positive interval, and a panic in this
		// goroutine would take down the whole process. Skip periodic updates;
		// handleSession still writes the final score.
		s.logger.Warn("non-positive detection update interval, periodic human score updates disabled",
			"interval", interval, "session_id", sessionID)
		return
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			score := scorer.Score()
			if err := s.store.UpdateHumanScore(ctx, sessionID, score); err != nil {
				s.logger.Error("failed to update human score", "err", err, "session_id", sessionID)
				continue
			}
			s.logger.Debug("human score updated", "session_id", sessionID, "score", score)

			if score >= s.cfg.Detection.Threshold {
				info := sessionInfo
				info.HumanScore = score
				s.notifier.Notify(ctx, notify.EventHumanDetected, info)
			}
		}
	}
}

// passwordCallback implements ssh.ServerConfig.PasswordCallback. It asks the
// authenticator for a decision, updates metrics, logs the attempt, and records
// the username and password in the store. Accepted logins may carry a
// preferred shell name in the "shell" permission extension.
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
	ip := extractIP(conn.RemoteAddr())
	d := s.authenticator.Authenticate(ip, conn.User(), string(password))

	outcome := "rejected"
	if d.Accepted {
		outcome = "accepted"
	}
	s.metrics.AuthAttemptsTotal.WithLabelValues(outcome, d.Reason).Inc()

	s.logger.Info("auth attempt",
		"remote_addr", conn.RemoteAddr(),
		"username", conn.User(),
		"accepted", d.Accepted,
		"reason", d.Reason,
	)

	if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip); err != nil {
		s.logger.Error("failed to record login attempt", "err", err)
	}

	if !d.Accepted {
		return nil, errors.New("rejected")
	}

	var perms *ssh.Permissions
	if d.Shell != "" {
		perms = &ssh.Permissions{
			Extensions: map[string]string{"shell": d.Shell},
		}
	}
	return perms, nil
}

// extractIP returns the host part of addr, rewriting IPv4-mapped IPv6
// addresses (e.g. "::ffff:1.2.3.4") to dotted IPv4 form. If addr has no
// port, its full string form is returned. If the host does not parse as an
// IP, it is returned unchanged.
func extractIP(addr net.Addr) string {
	host, _, err := net.SplitHostPort(addr.String())
	if err != nil {
		// Might not have a port, try using the string directly.
		return addr.String()
	}

	ip := net.ParseIP(host)
	if ip == nil {
		return host
	}
	if v4 := ip.To4(); v4 != nil {
		return v4.String()
	}
	return ip.String()
}

// loadOrGenerateHostKey reads and parses the SSH private key at path. If the
// file does not exist, it generates a new Ed25519 key and writes it to path
// as PEM with mode 0600, then returns a signer for it. Any other read error,
// or a parse failure, is returned.
func loadOrGenerateHostKey(path string) (ssh.Signer, error) {
	data, err := os.ReadFile(path)
	if err == nil {
		signer, err := ssh.ParsePrivateKey(data)
		if err != nil {
			return nil, fmt.Errorf("parsing host key: %w", err)
		}
		return signer, nil
	}

	if !errors.Is(err, os.ErrNotExist) {
		return nil, fmt.Errorf("reading host key: %w", err)
	}

	// Generate new Ed25519 key.
	_, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, fmt.Errorf("generating key: %w", err)
	}

	privBytes, err := ssh.MarshalPrivateKey(priv, "")
	if err != nil {
		return nil, fmt.Errorf("marshaling key: %w", err)
	}

	pemData := pem.EncodeToMemory(privBytes)
	if err := os.WriteFile(path, pemData, 0600); err != nil {
		return nil, fmt.Errorf("writing host key: %w", err)
	}

	signer, err := ssh.ParsePrivateKey(pemData)
	if err != nil {
		return nil, fmt.Errorf("parsing generated key: %w", err)
	}

	slog.Info("generated new host key", "path", path)
	return signer, nil
}