fix: address high-severity security issues from review
- Use subtle.ConstantTimeCompare for static credential checks to prevent timing side-channel attacks - Cap failCounts (100k) and rememberedCreds (10k) maps with eviction to prevent memory exhaustion from botnet-scale scanning - Sweep expired credentials on each auth attempt - Add configurable max_connections (default 500) with semaphore to limit concurrent connections and prevent goroutine/fd exhaustion Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,12 +1,18 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/subtle"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxFailCountEntries = 100000
|
||||||
|
maxRememberedCredentials = 10000
|
||||||
|
)
|
||||||
|
|
||||||
type credKey struct {
|
type credKey struct {
|
||||||
Username string
|
Username string
|
||||||
Password string
|
Password string
|
||||||
@@ -38,9 +44,11 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
|
|||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
defer a.mu.Unlock()
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
// 1. Check static credentials.
|
// 1. Check static credentials (constant-time comparison).
|
||||||
for _, cred := range a.cfg.StaticCredentials {
|
for _, cred := range a.cfg.StaticCredentials {
|
||||||
if cred.Username == username && cred.Password == password {
|
uMatch := subtle.ConstantTimeCompare([]byte(cred.Username), []byte(username))
|
||||||
|
pMatch := subtle.ConstantTimeCompare([]byte(cred.Password), []byte(password))
|
||||||
|
if uMatch == 1 && pMatch == 1 {
|
||||||
a.failCounts[ip] = 0
|
a.failCounts[ip] = 0
|
||||||
return Decision{Accepted: true, Reason: "static_credential"}
|
return Decision{Accepted: true, Reason: "static_credential"}
|
||||||
}
|
}
|
||||||
@@ -57,6 +65,7 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3. Increment fail count, check threshold.
|
// 3. Increment fail count, check threshold.
|
||||||
|
a.evictIfNeeded()
|
||||||
a.failCounts[ip]++
|
a.failCounts[ip]++
|
||||||
if a.failCounts[ip] >= a.cfg.AcceptAfter {
|
if a.failCounts[ip] >= a.cfg.AcceptAfter {
|
||||||
a.failCounts[ip] = 0
|
a.failCounts[ip] = 0
|
||||||
@@ -66,3 +75,32 @@ func (a *Authenticator) Authenticate(ip, username, password string) Decision {
|
|||||||
|
|
||||||
return Decision{Accepted: false, Reason: "rejected"}
|
return Decision{Accepted: false, Reason: "rejected"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// evictIfNeeded removes stale entries when maps exceed their size limits.
|
||||||
|
// Must be called with a.mu held.
|
||||||
|
func (a *Authenticator) evictIfNeeded() {
|
||||||
|
now := a.now()
|
||||||
|
|
||||||
|
// Sweep expired remembered credentials.
|
||||||
|
for k, expiry := range a.rememberedCreds {
|
||||||
|
if now.After(expiry) {
|
||||||
|
delete(a.rememberedCreds, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If remembered creds still over limit, drop oldest entries.
|
||||||
|
for len(a.rememberedCreds) > maxRememberedCredentials {
|
||||||
|
for k := range a.rememberedCreds {
|
||||||
|
delete(a.rememberedCreds, k)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If fail counts over limit, drop arbitrary entries.
|
||||||
|
for len(a.failCounts) > maxFailCountEntries {
|
||||||
|
for k := range a.failCounts {
|
||||||
|
delete(a.failCounts, k)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -127,6 +128,31 @@ func TestCounterResetsAfterAcceptance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExpiredCredentialsSweep(t *testing.T) {
|
||||||
|
a := newTestAuth(2, time.Hour)
|
||||||
|
now := time.Now()
|
||||||
|
a.now = func() time.Time { return now }
|
||||||
|
|
||||||
|
// Create several remembered credentials by reaching the threshold.
|
||||||
|
for i := range 5 {
|
||||||
|
ip := fmt.Sprintf("10.0.0.%d", i)
|
||||||
|
a.Authenticate(ip, fmt.Sprintf("user%d", i), "pass")
|
||||||
|
a.Authenticate(ip, fmt.Sprintf("user%d", i), "pass")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(a.rememberedCreds) != 5 {
|
||||||
|
t.Fatalf("expected 5 remembered creds, got %d", len(a.rememberedCreds))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advance past TTL so all are expired, then trigger sweep.
|
||||||
|
a.now = func() time.Time { return now.Add(2 * time.Hour) }
|
||||||
|
a.Authenticate("99.99.99.99", "trigger", "sweep")
|
||||||
|
|
||||||
|
if len(a.rememberedCreds) != 0 {
|
||||||
|
t.Errorf("expected 0 remembered creds after sweep, got %d", len(a.rememberedCreds))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConcurrentAccess(t *testing.T) {
|
func TestConcurrentAccess(t *testing.T) {
|
||||||
a := newTestAuth(5, time.Hour)
|
a := newTestAuth(5, time.Hour)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|||||||
@@ -15,8 +15,9 @@ type Config struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SSHConfig struct {
|
type SSHConfig struct {
|
||||||
ListenAddr string `toml:"listen_addr"`
|
ListenAddr string `toml:"listen_addr"`
|
||||||
HostKeyPath string `toml:"host_key_path"`
|
HostKeyPath string `toml:"host_key_path"`
|
||||||
|
MaxConnections int `toml:"max_connections"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthConfig struct {
|
type AuthConfig struct {
|
||||||
@@ -60,6 +61,9 @@ func applyDefaults(cfg *Config) {
|
|||||||
if cfg.SSH.HostKeyPath == "" {
|
if cfg.SSH.HostKeyPath == "" {
|
||||||
cfg.SSH.HostKeyPath = "oubliette_host_key"
|
cfg.SSH.HostKeyPath = "oubliette_host_key"
|
||||||
}
|
}
|
||||||
|
if cfg.SSH.MaxConnections == 0 {
|
||||||
|
cfg.SSH.MaxConnections = 500
|
||||||
|
}
|
||||||
if cfg.Auth.AcceptAfter == 0 {
|
if cfg.Auth.AcceptAfter == 0 {
|
||||||
cfg.Auth.AcceptAfter = 10
|
cfg.Auth.AcceptAfter = 10
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ type Server struct {
|
|||||||
authenticator *auth.Authenticator
|
authenticator *auth.Authenticator
|
||||||
sshConfig *ssh.ServerConfig
|
sshConfig *ssh.ServerConfig
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
|
connSem chan struct{} // semaphore limiting concurrent connections
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
||||||
@@ -31,6 +32,7 @@ func New(cfg config.Config, logger *slog.Logger) (*Server, error) {
|
|||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
authenticator: auth.NewAuthenticator(cfg.Auth),
|
authenticator: auth.NewAuthenticator(cfg.Auth),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
||||||
}
|
}
|
||||||
|
|
||||||
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
||||||
@@ -70,7 +72,18 @@ func (s *Server) ListenAndServe(ctx context.Context) error {
|
|||||||
s.logger.Error("accept error", "err", err)
|
s.logger.Error("accept error", "err", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go s.handleConn(conn)
|
|
||||||
|
// Enforce max concurrent connections.
|
||||||
|
select {
|
||||||
|
case s.connSem <- struct{}{}:
|
||||||
|
go func() {
|
||||||
|
defer func() { <-s.connSem }()
|
||||||
|
s.handleConn(conn)
|
||||||
|
}()
|
||||||
|
default:
|
||||||
|
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -97,8 +97,9 @@ func TestIntegrationSSHConnect(t *testing.T) {
|
|||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
cfg := config.Config{
|
cfg := config.Config{
|
||||||
SSH: config.SSHConfig{
|
SSH: config.SSHConfig{
|
||||||
ListenAddr: "127.0.0.1:0",
|
ListenAddr: "127.0.0.1:0",
|
||||||
HostKeyPath: filepath.Join(tmpDir, "host_key"),
|
HostKeyPath: filepath.Join(tmpDir, "host_key"),
|
||||||
|
MaxConnections: 100,
|
||||||
},
|
},
|
||||||
Auth: config.AuthConfig{
|
Auth: config.AuthConfig{
|
||||||
AcceptAfter: 2,
|
AcceptAfter: 2,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ log_level = "info"
|
|||||||
[ssh]
|
[ssh]
|
||||||
listen_addr = ":2222"
|
listen_addr = ":2222"
|
||||||
host_key_path = "oubliette_host_key"
|
host_key_path = "oubliette_host_key"
|
||||||
|
max_connections = 500
|
||||||
|
|
||||||
[auth]
|
[auth]
|
||||||
accept_after = 10
|
accept_after = 10
|
||||||
|
|||||||
Reference in New Issue
Block a user