fix: address high-severity security issues from review
- Use subtle.ConstantTimeCompare for static credential checks to prevent timing side-channel attacks - Cap failCounts (100k) and rememberedCreds (10k) maps with eviction to prevent memory exhaustion from botnet-scale scanning - Sweep expired credentials on each auth attempt - Add configurable max_connections (default 500) with semaphore to limit concurrent connections and prevent goroutine/fd exhaustion Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,12 +1,18 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
maxFailCountEntries = 100000
|
||||
maxRememberedCredentials = 10000
|
||||
)
|
||||
|
||||
type credKey struct {
|
||||
Username string
|
||||
Password string
|
||||
@@ -38,9 +44,11 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
// 1. Check static credentials.
|
||||
// 1. Check static credentials (constant-time comparison).
|
||||
for _, cred := range a.cfg.StaticCredentials {
|
||||
if cred.Username == username && cred.Password == password {
|
||||
uMatch := subtle.ConstantTimeCompare([]byte(cred.Username), []byte(username))
|
||||
pMatch := subtle.ConstantTimeCompare([]byte(cred.Password), []byte(password))
|
||||
if uMatch == 1 && pMatch == 1 {
|
||||
a.failCounts[ip] = 0
|
||||
return Decision{Accepted: true, Reason: "static_credential"}
|
||||
}
|
||||
@@ -57,6 +65,7 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
|
||||
}
|
||||
|
||||
// 3. Increment fail count, check threshold.
|
||||
a.evictIfNeeded()
|
||||
a.failCounts[ip]++
|
||||
if a.failCounts[ip] >= a.cfg.AcceptAfter {
|
||||
a.failCounts[ip] = 0
|
||||
@@ -66,3 +75,32 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
|
||||
|
||||
return Decision{Accepted: false, Reason: "rejected"}
|
||||
}
|
||||
|
||||
// evictIfNeeded removes stale entries when maps exceed their size limits.
|
||||
// Must be called with a.mu held.
|
||||
func (a *Authenticator) evictIfNeeded() {
|
||||
now := a.now()
|
||||
|
||||
// Sweep expired remembered credentials.
|
||||
for k, expiry := range a.rememberedCreds {
|
||||
if now.After(expiry) {
|
||||
delete(a.rememberedCreds, k)
|
||||
}
|
||||
}
|
||||
|
||||
// If remembered creds still over limit, drop oldest entries.
|
||||
for len(a.rememberedCreds) > maxRememberedCredentials {
|
||||
for k := range a.rememberedCreds {
|
||||
delete(a.rememberedCreds, k)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If fail counts over limit, drop arbitrary entries.
|
||||
for len(a.failCounts) > maxFailCountEntries {
|
||||
for k := range a.failCounts {
|
||||
delete(a.failCounts, k)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -127,6 +128,31 @@ func TestCounterResetsAfterAcceptance(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiredCredentialsSweep(t *testing.T) {
|
||||
a := newTestAuth(2, time.Hour)
|
||||
now := time.Now()
|
||||
a.now = func() time.Time { return now }
|
||||
|
||||
// Create several remembered credentials by reaching the threshold.
|
||||
for i := range 5 {
|
||||
ip := fmt.Sprintf("10.0.0.%d", i)
|
||||
a.Authenticate(ip, fmt.Sprintf("user%d", i), "pass")
|
||||
a.Authenticate(ip, fmt.Sprintf("user%d", i), "pass")
|
||||
}
|
||||
|
||||
if len(a.rememberedCreds) != 5 {
|
||||
t.Fatalf("expected 5 remembered creds, got %d", len(a.rememberedCreds))
|
||||
}
|
||||
|
||||
// Advance past TTL so all are expired, then trigger sweep.
|
||||
a.now = func() time.Time { return now.Add(2 * time.Hour) }
|
||||
a.Authenticate("99.99.99.99", "trigger", "sweep")
|
||||
|
||||
if len(a.rememberedCreds) != 0 {
|
||||
t.Errorf("expected 0 remembered creds after sweep, got %d", len(a.rememberedCreds))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
a := newTestAuth(5, time.Hour)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
@@ -15,8 +15,9 @@ type Config struct {
|
||||
}
|
||||
|
||||
type SSHConfig struct {
|
||||
ListenAddr string `toml:"listen_addr"`
|
||||
HostKeyPath string `toml:"host_key_path"`
|
||||
ListenAddr string `toml:"listen_addr"`
|
||||
HostKeyPath string `toml:"host_key_path"`
|
||||
MaxConnections int `toml:"max_connections"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
@@ -60,6 +61,9 @@ func applyDefaults(cfg *Config) {
|
||||
if cfg.SSH.HostKeyPath == "" {
|
||||
cfg.SSH.HostKeyPath = "oubliette_host_key"
|
||||
}
|
||||
if cfg.SSH.MaxConnections == 0 {
|
||||
cfg.SSH.MaxConnections = 500
|
||||
}
|
||||
if cfg.Auth.AcceptAfter == 0 {
|
||||
cfg.Auth.AcceptAfter = 10
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ type Server struct {
|
||||
authenticator *auth.Authenticator
|
||||
sshConfig *ssh.ServerConfig
|
||||
logger *slog.Logger
|
||||
connSem chan struct{} // semaphore limiting concurrent connections
|
||||
}
|
||||
|
||||
func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
||||
@@ -31,6 +32,7 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
||||
cfg: cfg,
|
||||
authenticator: auth.NewAuthenticator(cfg.Auth),
|
||||
logger: logger,
|
||||
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
||||
}
|
||||
|
||||
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
||||
@@ -70,7 +72,18 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
|
||||
s.logger.Error("accept error", "err", err)
|
||||
continue
|
||||
}
|
||||
go s.handleConn(conn)
|
||||
|
||||
// Enforce max concurrent connections.
|
||||
select {
|
||||
case s.connSem <- struct{}{}:
|
||||
go func() {
|
||||
defer func() { <-s.connSem }()
|
||||
s.handleConn(conn)
|
||||
}()
|
||||
default:
|
||||
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -97,8 +97,9 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
cfg := config.Config{
|
||||
SSH: config.SSHConfig{
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
HostKeyPath: filepath.Join(tmpDir, "host_key"),
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
HostKeyPath: filepath.Join(tmpDir, "host_key"),
|
||||
MaxConnections: 100,
|
||||
},
|
||||
Auth: config.AuthConfig{
|
||||
AcceptAfter: 2,
|
||||
|
||||
Reference in New Issue
Block a user