Update Go module path and all import references to reflect the migration from Gitea (git.t-juice.club) to Forgejo (code.t-juice.club). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
501 lines
14 KiB
Go
501 lines
14 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"os"
|
|
"time"
|
|
|
|
"code.t-juice.club/torjus/oubliette/internal/auth"
|
|
"code.t-juice.club/torjus/oubliette/internal/config"
|
|
"code.t-juice.club/torjus/oubliette/internal/detection"
|
|
"code.t-juice.club/torjus/oubliette/internal/geoip"
|
|
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
|
"code.t-juice.club/torjus/oubliette/internal/notify"
|
|
"code.t-juice.club/torjus/oubliette/internal/shell"
|
|
"code.t-juice.club/torjus/oubliette/internal/shell/adventure"
|
|
"code.t-juice.club/torjus/oubliette/internal/shell/banking"
|
|
"code.t-juice.club/torjus/oubliette/internal/shell/bash"
|
|
"code.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
|
"code.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
|
psqlshell "code.t-juice.club/torjus/oubliette/internal/shell/psql"
|
|
"code.t-juice.club/torjus/oubliette/internal/shell/roomba"
|
|
"code.t-juice.club/torjus/oubliette/internal/shell/tetris"
|
|
"code.t-juice.club/torjus/oubliette/internal/storage"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type Server struct {
|
|
cfg config.Config
|
|
store storage.Store
|
|
authenticator *auth.Authenticator
|
|
sshConfig *ssh.ServerConfig
|
|
logger *slog.Logger
|
|
connSem chan struct{} // semaphore limiting concurrent connections
|
|
shellRegistry *shell.Registry
|
|
notifier notify.Sender
|
|
metrics *metrics.Metrics
|
|
geoip *geoip.Reader
|
|
}
|
|
|
|
func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics.Metrics) (*Server, error) {
|
|
registry := shell.NewRegistry()
|
|
if err := registry.Register(bash.NewBashShell(), 1); err != nil {
|
|
return nil, fmt.Errorf("registering bash shell: %w", err)
|
|
}
|
|
if err := registry.Register(fridge.NewFridgeShell(), 1); err != nil {
|
|
return nil, fmt.Errorf("registering fridge shell: %w", err)
|
|
}
|
|
if err := registry.Register(banking.NewBankingShell(), 1); err != nil {
|
|
return nil, fmt.Errorf("registering banking shell: %w", err)
|
|
}
|
|
if err := registry.Register(adventure.NewAdventureShell(), 1); err != nil {
|
|
return nil, fmt.Errorf("registering adventure shell: %w", err)
|
|
}
|
|
if err := registry.Register(cisco.NewCiscoShell(), 1); err != nil {
|
|
return nil, fmt.Errorf("registering cisco shell: %w", err)
|
|
}
|
|
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
|
|
return nil, fmt.Errorf("registering psql shell: %w", err)
|
|
}
|
|
if err := registry.Register(roomba.NewRoombaShell(), 1); err != nil {
|
|
return nil, fmt.Errorf("registering roomba shell: %w", err)
|
|
}
|
|
if err := registry.Register(tetris.NewTetrisShell(), 1); err != nil {
|
|
return nil, fmt.Errorf("registering tetris shell: %w", err)
|
|
}
|
|
|
|
geo, err := geoip.New()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("opening geoip database: %w", err)
|
|
}
|
|
|
|
s := &Server{
|
|
cfg: cfg,
|
|
store: store,
|
|
authenticator: auth.NewAuthenticator(cfg.Auth),
|
|
logger: logger,
|
|
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
|
shellRegistry: registry,
|
|
notifier: notify.NewSender(cfg.Notify.Webhooks, logger),
|
|
metrics: m,
|
|
geoip: geo,
|
|
}
|
|
|
|
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("host key: %w", err)
|
|
}
|
|
|
|
s.sshConfig = &ssh.ServerConfig{
|
|
PasswordCallback: s.passwordCallback,
|
|
ServerVersion: "SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.6",
|
|
}
|
|
s.sshConfig.AddHostKey(hostKey)
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Server) ListenAndServe(ctx context.Context) error {
|
|
defer s.geoip.Close()
|
|
|
|
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("listen: %w", err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
s.logger.Info("SSH server listening", "addr", s.cfg.SSH.ListenAddr)
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
listener.Close()
|
|
}()
|
|
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return nil
|
|
}
|
|
s.logger.Error("accept error", "err", err)
|
|
continue
|
|
}
|
|
|
|
// Enforce max concurrent connections.
|
|
select {
|
|
case s.connSem <- struct{}{}:
|
|
s.metrics.SSHConnectionsActive.Inc()
|
|
go func() {
|
|
defer func() {
|
|
<-s.connSem
|
|
s.metrics.SSHConnectionsActive.Dec()
|
|
}()
|
|
s.handleConn(conn)
|
|
}()
|
|
default:
|
|
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_max_connections").Inc()
|
|
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
|
|
conn.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleConn(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
|
if err != nil {
|
|
s.metrics.SSHConnectionsTotal.WithLabelValues("rejected_handshake").Inc()
|
|
s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
|
|
return
|
|
}
|
|
defer sshConn.Close()
|
|
|
|
s.metrics.SSHConnectionsTotal.WithLabelValues("accepted").Inc()
|
|
s.logger.Info("SSH connection established",
|
|
"remote_addr", sshConn.RemoteAddr(),
|
|
"user", sshConn.User(),
|
|
)
|
|
|
|
go ssh.DiscardRequests(reqs)
|
|
|
|
for newChan := range chans {
|
|
if newChan.ChannelType() != "session" {
|
|
newChan.Reject(ssh.UnknownChannelType, "unknown channel type")
|
|
continue
|
|
}
|
|
|
|
channel, requests, err := newChan.Accept()
|
|
if err != nil {
|
|
s.logger.Error("channel accept error", "err", err)
|
|
return
|
|
}
|
|
|
|
go s.handleSession(channel, requests, sshConn)
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
|
|
defer channel.Close()
|
|
|
|
// Select a shell from the registry.
|
|
// If the auth layer specified a shell preference, use it; otherwise random.
|
|
var selectedShell shell.Shell
|
|
if conn.Permissions != nil && conn.Permissions.Extensions["shell"] != "" {
|
|
shellName := conn.Permissions.Extensions["shell"]
|
|
sh, ok := s.shellRegistry.Get(shellName)
|
|
if ok {
|
|
selectedShell = sh
|
|
} else {
|
|
s.logger.Warn("configured shell not found, falling back to random", "shell", shellName)
|
|
}
|
|
}
|
|
// Second priority: username-based route.
|
|
if selectedShell == nil {
|
|
if shellName, ok := s.cfg.Shell.UsernameRoutes[conn.User()]; ok {
|
|
sh, found := s.shellRegistry.Get(shellName)
|
|
if found {
|
|
selectedShell = sh
|
|
} else {
|
|
s.logger.Warn("username route shell not found, falling back to random", "shell", shellName, "user", conn.User())
|
|
}
|
|
}
|
|
}
|
|
// Lowest priority: random selection.
|
|
if selectedShell == nil {
|
|
var err error
|
|
selectedShell, err = s.shellRegistry.Select()
|
|
if err != nil {
|
|
s.logger.Error("failed to select shell", "err", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
ip := extractIP(conn.RemoteAddr())
|
|
country := s.geoip.Lookup(ip)
|
|
sessionStart := time.Now()
|
|
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), selectedShell.Name(), country)
|
|
if err != nil {
|
|
s.logger.Error("failed to create session", "err", err)
|
|
} else {
|
|
s.metrics.SessionsTotal.WithLabelValues(selectedShell.Name()).Inc()
|
|
s.metrics.SessionsActive.Inc()
|
|
defer func() {
|
|
s.metrics.SessionsActive.Dec()
|
|
s.metrics.SessionDuration.Observe(time.Since(sessionStart).Seconds())
|
|
if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
|
|
s.logger.Error("failed to end session", "err", err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
s.logger.Info("session started",
|
|
"remote_addr", conn.RemoteAddr(),
|
|
"user", conn.User(),
|
|
"shell", selectedShell.Name(),
|
|
"session_id", sessionID,
|
|
)
|
|
|
|
// Send session_started notification.
|
|
connectedAt := time.Now()
|
|
sessionInfo := notify.SessionInfo{
|
|
ID: sessionID,
|
|
IP: ip,
|
|
Username: conn.User(),
|
|
ShellName: selectedShell.Name(),
|
|
ConnectedAt: notify.FormatConnectedAt(connectedAt),
|
|
}
|
|
s.notifier.Notify(context.Background(), notify.EventSessionStarted, sessionInfo)
|
|
defer s.notifier.CleanupSession(sessionID)
|
|
|
|
// Handle session requests (pty-req, shell, exec, etc.)
|
|
execCh := make(chan string, 1)
|
|
go func() {
|
|
defer close(execCh)
|
|
for req := range requests {
|
|
switch req.Type {
|
|
case "pty-req", "shell":
|
|
if req.WantReply {
|
|
req.Reply(true, nil)
|
|
}
|
|
case "exec":
|
|
if req.WantReply {
|
|
req.Reply(true, nil)
|
|
}
|
|
var payload struct{ Command string }
|
|
if err := ssh.Unmarshal(req.Payload, &payload); err == nil {
|
|
execCh <- payload.Command
|
|
}
|
|
default:
|
|
if req.WantReply {
|
|
req.Reply(false, nil)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Check for exec request before proceeding to interactive shell.
|
|
select {
|
|
case cmd, ok := <-execCh:
|
|
if ok && cmd != "" {
|
|
s.logger.Info("exec command received",
|
|
"remote_addr", conn.RemoteAddr(),
|
|
"user", conn.User(),
|
|
"session_id", sessionID,
|
|
"command", cmd,
|
|
)
|
|
if err := s.store.SetExecCommand(context.Background(), sessionID, cmd); err != nil {
|
|
s.logger.Error("failed to set exec command", "err", err, "session_id", sessionID)
|
|
}
|
|
s.metrics.ExecCommandsTotal.Inc()
|
|
// Send exit-status 0 and close channel.
|
|
exitPayload := make([]byte, 4) // uint32(0)
|
|
_, _ = channel.SendRequest("exit-status", false, exitPayload)
|
|
return
|
|
}
|
|
case <-time.After(500 * time.Millisecond):
|
|
// No exec request within timeout — proceed with interactive shell.
|
|
}
|
|
|
|
// Build session context.
|
|
var shellCfg map[string]any
|
|
if s.cfg.Shell.Shells != nil {
|
|
shellCfg = s.cfg.Shell.Shells[selectedShell.Name()]
|
|
}
|
|
sessCtx := &shell.SessionContext{
|
|
SessionID: sessionID,
|
|
Username: conn.User(),
|
|
RemoteAddr: ip,
|
|
ClientVersion: string(conn.ClientVersion()),
|
|
Store: s.store,
|
|
ShellConfig: shellCfg,
|
|
CommonConfig: shell.ShellCommonConfig{
|
|
Hostname: s.cfg.Shell.Hostname,
|
|
Banner: s.cfg.Shell.Banner,
|
|
FakeUser: s.cfg.Shell.FakeUser,
|
|
},
|
|
OnCommand: func(sh string) {
|
|
s.metrics.CommandsExecuted.WithLabelValues(sh).Inc()
|
|
},
|
|
}
|
|
|
|
// Wrap channel in RecordingChannel.
|
|
recorder := shell.NewRecordingChannel(channel)
|
|
|
|
// Always record session events for replay.
|
|
eventRec := shell.NewEventRecorder(sessionID, s.store, s.logger)
|
|
eventRec.Start(context.Background())
|
|
defer eventRec.Close()
|
|
recorder.AddCallback(eventRec.RecordEvent)
|
|
|
|
// Set up detection scorer if enabled.
|
|
var scorer *detection.Scorer
|
|
var scoreCancel context.CancelFunc
|
|
if s.cfg.Detection.Enabled {
|
|
scorer = detection.NewScorer()
|
|
recorder.AddCallback(func(ts time.Time, direction int, data []byte) {
|
|
scorer.RecordEvent(ts, direction, data)
|
|
})
|
|
|
|
var scoreCtx context.Context
|
|
scoreCtx, scoreCancel = context.WithCancel(context.Background())
|
|
go s.runScoreUpdater(scoreCtx, sessionID, scorer, sessionInfo)
|
|
}
|
|
|
|
if err := selectedShell.Handle(context.Background(), sessCtx, recorder); err != nil {
|
|
s.logger.Error("shell error", "err", err, "session_id", sessionID)
|
|
}
|
|
|
|
// Stop score updater and write final score.
|
|
if scoreCancel != nil {
|
|
scoreCancel()
|
|
}
|
|
if scorer != nil {
|
|
finalScore := scorer.Score()
|
|
s.metrics.HumanScore.Observe(finalScore)
|
|
if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil {
|
|
s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID)
|
|
}
|
|
s.logger.Info("session ended",
|
|
"remote_addr", conn.RemoteAddr(),
|
|
"user", conn.User(),
|
|
"session_id", sessionID,
|
|
"human_score", finalScore,
|
|
)
|
|
} else {
|
|
s.logger.Info("session ended",
|
|
"remote_addr", conn.RemoteAddr(),
|
|
"user", conn.User(),
|
|
"session_id", sessionID,
|
|
)
|
|
}
|
|
}
|
|
|
|
// runScoreUpdater periodically computes the human score, writes it to the DB,
|
|
// and triggers a notification if the threshold is crossed.
|
|
func (s *Server) runScoreUpdater(ctx context.Context, sessionID string, scorer *detection.Scorer, sessionInfo notify.SessionInfo) {
|
|
ticker := time.NewTicker(s.cfg.Detection.UpdateIntervalDuration)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
score := scorer.Score()
|
|
if err := s.store.UpdateHumanScore(ctx, sessionID, score); err != nil {
|
|
s.logger.Error("failed to update human score", "err", err, "session_id", sessionID)
|
|
continue
|
|
}
|
|
s.logger.Debug("human score updated", "session_id", sessionID, "score", score)
|
|
|
|
if score >= s.cfg.Detection.Threshold {
|
|
info := sessionInfo
|
|
info.HumanScore = score
|
|
s.notifier.Notify(ctx, notify.EventHumanDetected, info)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
ip := extractIP(conn.RemoteAddr())
|
|
d := s.authenticator.Authenticate(ip, conn.User(), string(password))
|
|
|
|
if d.Accepted {
|
|
s.metrics.AuthAttemptsTotal.WithLabelValues("accepted", d.Reason).Inc()
|
|
} else {
|
|
s.metrics.AuthAttemptsTotal.WithLabelValues("rejected", d.Reason).Inc()
|
|
}
|
|
|
|
s.logger.Info("auth attempt",
|
|
"remote_addr", conn.RemoteAddr(),
|
|
"username", conn.User(),
|
|
"accepted", d.Accepted,
|
|
"reason", d.Reason,
|
|
)
|
|
|
|
country := s.geoip.Lookup(ip)
|
|
if country != "" {
|
|
s.metrics.AuthAttemptsByCountry.WithLabelValues(country).Inc()
|
|
}
|
|
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip, country); err != nil {
|
|
s.logger.Error("failed to record login attempt", "err", err)
|
|
}
|
|
|
|
if d.Accepted {
|
|
var perms *ssh.Permissions
|
|
if d.Shell != "" {
|
|
perms = &ssh.Permissions{
|
|
Extensions: map[string]string{"shell": d.Shell},
|
|
}
|
|
}
|
|
return perms, nil
|
|
}
|
|
return nil, fmt.Errorf("rejected")
|
|
}
|
|
|
|
// extractIP returns the host part of addr as a normalized IP string.
//
// Any port is stripped. IPv4-mapped IPv6 addresses ("::ffff:a.b.c.d") are
// reduced to dotted IPv4, and other IPv6 addresses are rendered in canonical
// form. Addresses that carry no port are normalized the same way, so one client
// always maps to one string for storage and GeoIP lookups. A host that does not
// parse as an IP is returned unchanged.
func extractIP(addr net.Addr) string {
	host, _, err := net.SplitHostPort(addr.String())
	if err != nil {
		// Not in host:port form (e.g. no port): treat the whole string as the
		// host so it still goes through normalization below.
		host = addr.String()
	}

	ip := net.ParseIP(host)
	if ip == nil {
		return host
	}
	if v4 := ip.To4(); v4 != nil {
		return v4.String()
	}
	return ip.String()
}
|
|
|
|
func loadOrGenerateHostKey(path string) (ssh.Signer, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err == nil {
|
|
signer, err := ssh.ParsePrivateKey(data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing host key: %w", err)
|
|
}
|
|
return signer, nil
|
|
}
|
|
|
|
if !errors.Is(err, os.ErrNotExist) {
|
|
return nil, fmt.Errorf("reading host key: %w", err)
|
|
}
|
|
|
|
// Generate new Ed25519 key.
|
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating key: %w", err)
|
|
}
|
|
|
|
privBytes, err := ssh.MarshalPrivateKey(priv, "")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshaling key: %w", err)
|
|
}
|
|
|
|
pemData := pem.EncodeToMemory(privBytes)
|
|
if err := os.WriteFile(path, pemData, 0600); err != nil {
|
|
return nil, fmt.Errorf("writing host key: %w", err)
|
|
}
|
|
|
|
signer, err := ssh.ParsePrivateKey(pemData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing generated key: %w", err)
|
|
}
|
|
|
|
slog.Info("generated new host key", "path", path)
|
|
return signer, nil
|
|
}
|
|
|