From a40110f2f5f52509ff79a664395ab6a0f1e3e397 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torjus=20H=C3=A5kestad?= Date: Sat, 14 Feb 2026 16:41:23 +0100 Subject: [PATCH] fix: address high-severity security issues from review - Use subtle.ConstantTimeCompare for static credential checks to prevent timing side-channel attacks - Cap failCounts (100k) and rememberedCreds (10k) maps with eviction to prevent memory exhaustion from botnet-scale scanning - Sweep expired credentials on each auth attempt - Add configurable max_connections (default 500) with semaphore to limit concurrent connections and prevent goroutine/fd exhaustion Co-Authored-By: Claude Opus 4.6 --- internal/auth/auth.go | 42 ++++++++++++++++++++++++++++++++-- internal/auth/auth_test.go | 26 +++++++++++++++++++++ internal/config/config.go | 8 +++++-- internal/server/server.go | 15 +++++++++++- internal/server/server_test.go | 5 ++-- oubliette.toml.example | 1 + 6 files changed, 90 insertions(+), 7 deletions(-) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 518d343..f47b49a 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,12 +1,18 @@ package auth import ( + "crypto/subtle" "sync" "time" "git.t-juice.club/torjus/oubliette/internal/config" ) +const ( + maxFailCountEntries = 100000 + maxRememberedCredentials = 10000 +) + type credKey struct { Username string Password string @@ -38,9 +44,11 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision { a.mu.Lock() defer a.mu.Unlock() - // 1. Check static credentials. + // 1. Check static credentials (constant-time comparison). for _, cred := range a.cfg.StaticCredentials { - if cred.Username == username && cred.Password == password { + uMatch := subtle.ConstantTimeCompare([]byte(cred.Username), []byte(username)) + pMatch := subtle.ConstantTimeCompare([]byte(cred.Password), []byte(password)) + if uMatch == 1 && pMatch == 1 { a.failCounts[ip] = 0 return Decision{Accepted: true, Reason: "static_credential"} } @@ -57,6 +65,7 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision { } // 3. Increment fail count, check threshold. + a.evictIfNeeded() a.failCounts[ip]++ if a.failCounts[ip] >= a.cfg.AcceptAfter { a.failCounts[ip] = 0 @@ -66,3 +75,32 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision { return Decision{Accepted: false, Reason: "rejected"} } + +// evictIfNeeded removes stale entries when maps exceed their size limits. +// Must be called with a.mu held. +func (a *Authenticator) evictIfNeeded() { + now := a.now() + + // Sweep expired remembered credentials. + for k, expiry := range a.rememberedCreds { + if now.After(expiry) { + delete(a.rememberedCreds, k) + } + } + + // If remembered creds still over limit, drop oldest entries. + for len(a.rememberedCreds) > maxRememberedCredentials { + for k := range a.rememberedCreds { + delete(a.rememberedCreds, k) + break + } + } + + // If fail counts over limit, drop arbitrary entries. + for len(a.failCounts) > maxFailCountEntries { + for k := range a.failCounts { + delete(a.failCounts, k) + break + } + } +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index 00406c3..c083db3 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -1,6 +1,7 @@ package auth import ( + "fmt" "sync" "testing" "time" @@ -127,6 +128,31 @@ func TestCounterResetsAfterAcceptance(t *testing.T) { } } +func TestExpiredCredentialsSweep(t *testing.T) { + a := newTestAuth(2, time.Hour) + now := time.Now() + a.now = func() time.Time { return now } + + // Create several remembered credentials by reaching the threshold. + for i := range 5 { + ip := fmt.Sprintf("10.0.0.%d", i) + a.Authenticate(ip, fmt.Sprintf("user%d", i), "pass") + a.Authenticate(ip, fmt.Sprintf("user%d", i), "pass") + } + + if len(a.rememberedCreds) != 5 { + t.Fatalf("expected 5 remembered creds, got %d", len(a.rememberedCreds)) + } + + // Advance past TTL so all are expired, then trigger sweep. + a.now = func() time.Time { return now.Add(2 * time.Hour) } + a.Authenticate("99.99.99.99", "trigger", "sweep") + + if len(a.rememberedCreds) != 0 { + t.Errorf("expected 0 remembered creds after sweep, got %d", len(a.rememberedCreds)) + } +} + func TestConcurrentAccess(t *testing.T) { a := newTestAuth(5, time.Hour) var wg sync.WaitGroup diff --git a/internal/config/config.go b/internal/config/config.go index 1098fb9..4155b92 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,8 +15,9 @@ type Config struct { } type SSHConfig struct { - ListenAddr string `toml:"listen_addr"` - HostKeyPath string `toml:"host_key_path"` + ListenAddr string `toml:"listen_addr"` + HostKeyPath string `toml:"host_key_path"` + MaxConnections int `toml:"max_connections"` } type AuthConfig struct { @@ -60,6 +61,9 @@ func applyDefaults(cfg *Config) { if cfg.SSH.HostKeyPath == "" { cfg.SSH.HostKeyPath = "oubliette_host_key" } + if cfg.SSH.MaxConnections == 0 { + cfg.SSH.MaxConnections = 500 + } if cfg.Auth.AcceptAfter == 0 { cfg.Auth.AcceptAfter = 10 } diff --git a/internal/server/server.go b/internal/server/server.go index 91212c7..17c9540 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -24,6 +24,7 @@ type Server struct { authenticator *auth.Authenticator sshConfig *ssh.ServerConfig logger *slog.Logger + connSem chan struct{} // semaphore limiting concurrent connections } func New(cfg config.Config, logger *slog.Logger) (*Server, error) { @@ -31,6 +32,7 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) { cfg: cfg, authenticator: auth.NewAuthenticator(cfg.Auth), logger: logger, + connSem: make(chan struct{}, cfg.SSH.MaxConnections), } hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath) @@ -70,7 +72,18 @@ func (s *Server) ListenAndServe(ctx context.Context) error { s.logger.Error("accept error", "err", err) continue } - go s.handleConn(conn) + + // Enforce max concurrent connections. + select { + case s.connSem <- struct{}{}: + go func() { + defer func() { <-s.connSem }() + s.handleConn(conn) + }() + default: + s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr()) + conn.Close() + } } } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index ff8d303..d8a83d1 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -97,8 +97,9 @@ func TestIntegrationSSHConnect(t *testing.T) { tmpDir := t.TempDir() cfg := config.Config{ SSH: config.SSHConfig{ - ListenAddr: "127.0.0.1:0", - HostKeyPath: filepath.Join(tmpDir, "host_key"), + ListenAddr: "127.0.0.1:0", + HostKeyPath: filepath.Join(tmpDir, "host_key"), + MaxConnections: 100, }, Auth: config.AuthConfig{ AcceptAfter: 2, diff --git a/oubliette.toml.example b/oubliette.toml.example index 6b5b64a..c976f3d 100644 --- a/oubliette.toml.example +++ b/oubliette.toml.example @@ -3,6 +3,7 @@ log_level = "info" [ssh] listen_addr = ":2222" host_key_path = "oubliette_host_key" +max_connections = 500 [auth] accept_after = 10