- Use subtle.ConstantTimeCompare for static credential checks to prevent timing side-channel attacks - Cap failCounts (100k) and rememberedCreds (10k) maps with eviction to prevent memory exhaustion from botnet-scale scanning - Sweep expired credentials on each auth attempt - Add configurable max_connections (default 500) with semaphore to limit concurrent connections and prevent goroutine/fd exhaustion Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
103 lines
2.3 KiB
Go
103 lines
2.3 KiB
Go
package config
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/BurntSushi/toml"
|
|
)
|
|
|
|
type Config struct {
|
|
SSH SSHConfig `toml:"ssh"`
|
|
Auth AuthConfig `toml:"auth"`
|
|
LogLevel string `toml:"log_level"`
|
|
}
|
|
|
|
// SSHConfig is the [ssh] table of the configuration file.
//
// When a field is left unset, Load substitutes a default:
//   - ListenAddr:     ":2222"
//   - HostKeyPath:    "oubliette_host_key"
//   - MaxConnections: 500 (the limit on concurrently handled connections)
type SSHConfig struct {
	ListenAddr     string `toml:"listen_addr"`
	HostKeyPath    string `toml:"host_key_path"`
	MaxConnections int    `toml:"max_connections"`
}
|
|
|
|
type AuthConfig struct {
|
|
AcceptAfter int `toml:"accept_after"`
|
|
CredentialTTL string `toml:"credential_ttl"`
|
|
StaticCredentials []Credential `toml:"static_credentials"`
|
|
|
|
// Parsed duration, not from TOML directly.
|
|
CredentialTTLDuration time.Duration `toml:"-"`
|
|
}
|
|
|
|
// Credential is one entry of auth.static_credentials: a username and
// password pair. Validation rejects entries in which either is empty.
type Credential struct {
	Username string `toml:"username"`
	Password string `toml:"password"`
}
|
|
|
|
func Load(path string) (*Config, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading config: %w", err)
|
|
}
|
|
|
|
cfg := &Config{}
|
|
if err := toml.Unmarshal(data, cfg); err != nil {
|
|
return nil, fmt.Errorf("parsing config: %w", err)
|
|
}
|
|
|
|
applyDefaults(cfg)
|
|
|
|
if err := validate(cfg); err != nil {
|
|
return nil, fmt.Errorf("validating config: %w", err)
|
|
}
|
|
|
|
return cfg, nil
|
|
}
|
|
|
|
func applyDefaults(cfg *Config) {
|
|
if cfg.SSH.ListenAddr == "" {
|
|
cfg.SSH.ListenAddr = ":2222"
|
|
}
|
|
if cfg.SSH.HostKeyPath == "" {
|
|
cfg.SSH.HostKeyPath = "oubliette_host_key"
|
|
}
|
|
if cfg.SSH.MaxConnections == 0 {
|
|
cfg.SSH.MaxConnections = 500
|
|
}
|
|
if cfg.Auth.AcceptAfter == 0 {
|
|
cfg.Auth.AcceptAfter = 10
|
|
}
|
|
if cfg.Auth.CredentialTTL == "" {
|
|
cfg.Auth.CredentialTTL = "24h"
|
|
}
|
|
if cfg.LogLevel == "" {
|
|
cfg.LogLevel = "info"
|
|
}
|
|
}
|
|
|
|
func validate(cfg *Config) error {
|
|
d, err := time.ParseDuration(cfg.Auth.CredentialTTL)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid credential_ttl %q: %w", cfg.Auth.CredentialTTL, err)
|
|
}
|
|
if d <= 0 {
|
|
return fmt.Errorf("credential_ttl must be positive, got %s", d)
|
|
}
|
|
cfg.Auth.CredentialTTLDuration = d
|
|
|
|
if cfg.Auth.AcceptAfter < 1 {
|
|
return fmt.Errorf("accept_after must be at least 1, got %d", cfg.Auth.AcceptAfter)
|
|
}
|
|
|
|
for i, cred := range cfg.Auth.StaticCredentials {
|
|
if cred.Username == "" {
|
|
return fmt.Errorf("static_credentials[%d]: username must not be empty", i)
|
|
}
|
|
if cred.Password == "" {
|
|
return fmt.Errorf("static_credentials[%d]: password must not be empty", i)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|