fix: address high-severity security issues from review

- Use subtle.ConstantTimeCompare for static credential checks to
  prevent timing side-channel attacks
- Cap failCounts (100k) and rememberedCreds (10k) maps with eviction
  to prevent memory exhaustion from botnet-scale scanning
- Sweep expired credentials on each auth attempt
- Add configurable max_connections (default 500) with semaphore to
  limit concurrent connections and prevent goroutine/fd exhaustion

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-14 16:41:23 +01:00
parent 51fdea0c2f
commit a40110f2f5
6 changed files with 90 additions and 7 deletions

View File

@@ -1,12 +1,18 @@
package auth package auth
import ( import (
"crypto/subtle"
"sync" "sync"
"time" "time"
"git.t-juice.club/torjus/oubliette/internal/config" "git.t-juice.club/torjus/oubliette/internal/config"
) )
const (
maxFailCountEntries = 100000
maxRememberedCredentials = 10000
)
type credKey struct { type credKey struct {
Username string Username string
Password string Password string
@@ -38,9 +44,11 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
a.mu.Lock() a.mu.Lock()
defer a.mu.Unlock() defer a.mu.Unlock()
// 1. Check static credentials. // 1. Check static credentials (constant-time comparison).
for _, cred := range a.cfg.StaticCredentials { for _, cred := range a.cfg.StaticCredentials {
if cred.Username == username && cred.Password == password { uMatch := subtle.ConstantTimeCompare([]byte(cred.Username), []byte(username))
pMatch := subtle.ConstantTimeCompare([]byte(cred.Password), []byte(password))
if uMatch == 1 && pMatch == 1 {
a.failCounts[ip] = 0 a.failCounts[ip] = 0
return Decision{Accepted: true, Reason: "static_credential"} return Decision{Accepted: true, Reason: "static_credential"}
} }
@@ -57,6 +65,7 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
} }
// 3. Increment fail count, check threshold. // 3. Increment fail count, check threshold.
a.evictIfNeeded()
a.failCounts[ip]++ a.failCounts[ip]++
if a.failCounts[ip] >= a.cfg.AcceptAfter { if a.failCounts[ip] >= a.cfg.AcceptAfter {
a.failCounts[ip] = 0 a.failCounts[ip] = 0
@@ -66,3 +75,32 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
return Decision{Accepted: false, Reason: "rejected"} return Decision{Accepted: false, Reason: "rejected"}
} }
// evictIfNeeded removes stale entries when maps exceed their size limits.
// Must be called with a.mu held.
func (a *Authenticator) evictIfNeeded() {
now := a.now()
// Sweep expired remembered credentials.
for k, expiry := range a.rememberedCreds {
if now.After(expiry) {
delete(a.rememberedCreds, k)
}
}
// If remembered creds still over limit, drop oldest entries.
for len(a.rememberedCreds) > maxRememberedCredentials {
for k := range a.rememberedCreds {
delete(a.rememberedCreds, k)
break
}
}
// If fail counts over limit, drop arbitrary entries.
for len(a.failCounts) > maxFailCountEntries {
for k := range a.failCounts {
delete(a.failCounts, k)
break
}
}
}

View File

@@ -1,6 +1,7 @@
package auth package auth
import ( import (
"fmt"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -127,6 +128,31 @@ func TestCounterResetsAfterAcceptance(t *testing.T) {
} }
} }
func TestExpiredCredentialsSweep(t *testing.T) {
a := newTestAuth(2, time.Hour)
now := time.Now()
a.now = func() time.Time { return now }
// Create several remembered credentials by reaching the threshold.
for i := range 5 {
ip := fmt.Sprintf("10.0.0.%d", i)
a.Authenticate(ip, fmt.Sprintf("user%d", i), "pass")
a.Authenticate(ip, fmt.Sprintf("user%d", i), "pass")
}
if len(a.rememberedCreds) != 5 {
t.Fatalf("expected 5 remembered creds, got %d", len(a.rememberedCreds))
}
// Advance past TTL so all are expired, then trigger sweep.
a.now = func() time.Time { return now.Add(2 * time.Hour) }
a.Authenticate("99.99.99.99", "trigger", "sweep")
if len(a.rememberedCreds) != 0 {
t.Errorf("expected 0 remembered creds after sweep, got %d", len(a.rememberedCreds))
}
}
func TestConcurrentAccess(t *testing.T) { func TestConcurrentAccess(t *testing.T) {
a := newTestAuth(5, time.Hour) a := newTestAuth(5, time.Hour)
var wg sync.WaitGroup var wg sync.WaitGroup

View File

@@ -17,6 +17,7 @@ type Config struct {
type SSHConfig struct { type SSHConfig struct {
ListenAddr string `toml:"listen_addr"` ListenAddr string `toml:"listen_addr"`
HostKeyPath string `toml:"host_key_path"` HostKeyPath string `toml:"host_key_path"`
MaxConnections int `toml:"max_connections"`
} }
type AuthConfig struct { type AuthConfig struct {
@@ -60,6 +61,9 @@ func applyDefaults(cfg *Config) {
if cfg.SSH.HostKeyPath == "" { if cfg.SSH.HostKeyPath == "" {
cfg.SSH.HostKeyPath = "oubliette_host_key" cfg.SSH.HostKeyPath = "oubliette_host_key"
} }
if cfg.SSH.MaxConnections == 0 {
cfg.SSH.MaxConnections = 500
}
if cfg.Auth.AcceptAfter == 0 { if cfg.Auth.AcceptAfter == 0 {
cfg.Auth.AcceptAfter = 10 cfg.Auth.AcceptAfter = 10
} }

View File

@@ -24,6 +24,7 @@ type Server struct {
authenticator *auth.Authenticator authenticator *auth.Authenticator
sshConfig *ssh.ServerConfig sshConfig *ssh.ServerConfig
logger *slog.Logger logger *slog.Logger
connSem chan struct{} // semaphore limiting concurrent connections
} }
func New(cfg config.Config, logger *slog.Logger) (*Server, error) { func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
@@ -31,6 +32,7 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
cfg: cfg, cfg: cfg,
authenticator: auth.NewAuthenticator(cfg.Auth), authenticator: auth.NewAuthenticator(cfg.Auth),
logger: logger, logger: logger,
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
} }
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath) hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
@@ -70,7 +72,18 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
s.logger.Error("accept error", "err", err) s.logger.Error("accept error", "err", err)
continue continue
} }
go s.handleConn(conn)
// Enforce max concurrent connections.
select {
case s.connSem <- struct{}{}:
go func() {
defer func() { <-s.connSem }()
s.handleConn(conn)
}()
default:
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
conn.Close()
}
} }
} }

View File

@@ -99,6 +99,7 @@ func TestIntegrationSSHConnect(t *testing.T) {
SSH: config.SSHConfig{ SSH: config.SSHConfig{
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
HostKeyPath: filepath.Join(tmpDir, "host_key"), HostKeyPath: filepath.Join(tmpDir, "host_key"),
MaxConnections: 100,
}, },
Auth: config.AuthConfig{ Auth: config.AuthConfig{
AcceptAfter: 2, AcceptAfter: 2,

View File

@@ -3,6 +3,7 @@ log_level = "info"
[ssh] [ssh]
listen_addr = ":2222" listen_addr = ":2222"
host_key_path = "oubliette_host_key" host_key_path = "oubliette_host_key"
max_connections = 500
[auth] [auth]
accept_after = 10 accept_after = 10