// Package config loads, defaults, and validates the TOML configuration file.
package config

import (
	"fmt"
	"os"
	"time"

	"github.com/BurntSushi/toml"
)

// Config is the top-level configuration decoded from the TOML file by Load.
type Config struct {
	SSH       SSHConfig       `toml:"ssh"`
	Auth      AuthConfig      `toml:"auth"`
	Storage   StorageConfig   `toml:"storage"`
	Shell     ShellConfig     `toml:"shell"`
	Web       WebConfig       `toml:"web"`
	Detection DetectionConfig `toml:"detection"`
	Notify    NotifyConfig    `toml:"notify"`

	// LogLevel defaults to "info" when unset.
	LogLevel string `toml:"log_level"`

	// LogFormat is "text" (default) or "json".
	LogFormat string `toml:"log_format"`
}

// WebConfig holds the [web] section.
type WebConfig struct {
	Enabled bool `toml:"enabled"`

	// ListenAddr defaults to ":8080" when unset.
	ListenAddr string `toml:"listen_addr"`

	// MetricsEnabled is a pointer so an absent key (nil) can be told apart
	// from an explicit false; applyDefaults turns nil into true.
	MetricsEnabled *bool `toml:"metrics_enabled"`

	MetricsToken string `toml:"metrics_token"`
}

// ShellConfig holds the [shell] section plus any per-shell sub-tables.
type ShellConfig struct {
	// Hostname defaults to "ubuntu-server" when unset.
	Hostname string `toml:"hostname"`

	// Banner defaults to an Ubuntu 22.04 welcome line when unset.
	Banner string `toml:"banner"`

	FakeUser string `toml:"fake_user"`

	// UsernameRoutes maps a username to a shell name.
	// NOTE(review): meaning inferred from the name and Credential.Shell — confirm.
	UsernameRoutes map[string]string `toml:"username_routes"`

	// Shells holds per-shell sub-tables such as [shell.bash], keyed by the
	// sub-table name. The struct mapping does not decode it; Load fills it
	// from a second, untyped decode pass. It is nil when no sub-tables exist.
	Shells map[string]map[string]any `toml:"-"`
}

// StorageConfig holds the [storage] section.
type StorageConfig struct {
	// DBPath defaults to "oubliette.db" when unset.
	DBPath string `toml:"db_path"`

	// RetentionDays defaults to 90 when unset and must be at least 1.
	RetentionDays int `toml:"retention_days"`

	// RetentionInterval is a Go duration string; it defaults to "1h" and
	// must be positive.
	RetentionInterval string `toml:"retention_interval"`

	// RetentionIntervalDuration is RetentionInterval parsed by validate; it
	// is not read from TOML directly.
	RetentionIntervalDuration time.Duration `toml:"-"`
}

// SSHConfig holds the [ssh] section.
type SSHConfig struct {
	// ListenAddr defaults to ":2222" when unset.
	ListenAddr string `toml:"listen_addr"`

	// HostKeyPath defaults to "oubliette_host_key" when unset.
	HostKeyPath string `toml:"host_key_path"`

	// MaxConnections defaults to 500 when unset (0) and must be at least 1.
	MaxConnections int `toml:"max_connections"`
}

// AuthConfig holds the [auth] section.
type AuthConfig struct {
	// AcceptAfter defaults to 10 when unset and must be at least 1.
	// NOTE(review): presumably the number of attempts before a login is
	// accepted — confirm against the auth handler.
	AcceptAfter int `toml:"accept_after"`

	// CredentialTTL is a Go duration string; it defaults to "24h" and must
	// be positive.
	CredentialTTL string `toml:"credential_ttl"`

	// StaticCredentials must each carry a non-empty username and password.
	StaticCredentials []Credential `toml:"static_credentials"`

	// CredentialTTLDuration is CredentialTTL parsed by validate; it is not
	// read from TOML directly.
	CredentialTTLDuration time.Duration `toml:"-"`
}

// Credential is one entry of auth.static_credentials.
type Credential struct {
	Username string `toml:"username"`
	Password string `toml:"password"`
	Shell    string `toml:"shell"` // optional: route to specific shell (empty = random)
}

// DetectionConfig holds the [detection] section.
type DetectionConfig struct {
	Enabled bool `toml:"enabled"`

	// Threshold defaults to 0.6 when unset (an explicit 0 is therefore also
	// replaced by the default) and, when Enabled, must lie within [0, 1].
	Threshold float64 `toml:"threshold"`

	// UpdateInterval is a Go duration string; it defaults to "5s".
	UpdateInterval string `toml:"update_interval"`

	// UpdateIntervalDuration is UpdateInterval parsed by validate. It is only
	// populated when Enabled is true and is not read from TOML directly.
	UpdateIntervalDuration time.Duration `toml:"-"`
}

// NotifyConfig holds the [notify] section.
type NotifyConfig struct {
	Webhooks []WebhookNotifyConfig `toml:"webhooks"`
}

// WebhookNotifyConfig is one entry of notify.webhooks. URL must be non-empty
// and every entry in Events must be one of knownWebhookEvents.
type WebhookNotifyConfig struct {
	URL     string            `toml:"url"`
	Headers map[string]string `toml:"headers"`
	Events  []string          `toml:"events"` // empty = all events
}

// Load reads the TOML file at path and decodes it into a Config. It then
// fills in defaults for unset fields and validates the result. Validation
// also fills the parsed *Duration fields.
func Load(path string) (*Config, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("reading config: %w", err)
	}

	cfg := &Config{}
	if err := toml.Unmarshal(data, cfg); err != nil {
		return nil, fmt.Errorf("parsing config: %w", err)
	}

	// Second pass: the struct mapping cannot capture arbitrary per-shell
	// sub-tables (e.g. [shell.bash]), so decode untyped and pull them out.
	// The error is propagated rather than silently dropped.
	var raw map[string]any
	if err := toml.Unmarshal(data, &raw); err != nil {
		return nil, fmt.Errorf("parsing shell tables: %w", err)
	}
	if shellSection, ok := raw["shell"].(map[string]any); ok {
		cfg.Shell.Shells = extractShellTables(shellSection)
	}

	applyDefaults(cfg)
	if err := validate(cfg); err != nil {
		return nil, fmt.Errorf("validating config: %w", err)
	}
	return cfg, nil
}

// applyDefaults replaces zero-valued (unset) fields with their defaults.
// A zero value is indistinguishable from "unset" here, so an explicit 0 or ""
// in the file is also replaced. The exception is Web.MetricsEnabled, which
// uses a pointer.
func applyDefaults(cfg *Config) {
	if cfg.SSH.ListenAddr == "" {
		cfg.SSH.ListenAddr = ":2222"
	}
	if cfg.SSH.HostKeyPath == "" {
		cfg.SSH.HostKeyPath = "oubliette_host_key"
	}
	if cfg.SSH.MaxConnections == 0 {
		cfg.SSH.MaxConnections = 500
	}
	if cfg.Auth.AcceptAfter == 0 {
		cfg.Auth.AcceptAfter = 10
	}
	if cfg.Auth.CredentialTTL == "" {
		cfg.Auth.CredentialTTL = "24h"
	}
	if cfg.LogLevel == "" {
		cfg.LogLevel = "info"
	}
	if cfg.LogFormat == "" {
		cfg.LogFormat = "text"
	}
	if cfg.Storage.DBPath == "" {
		cfg.Storage.DBPath = "oubliette.db"
	}
	if cfg.Storage.RetentionDays == 0 {
		cfg.Storage.RetentionDays = 90
	}
	if cfg.Storage.RetentionInterval == "" {
		cfg.Storage.RetentionInterval = "1h"
	}
	if cfg.Web.ListenAddr == "" {
		cfg.Web.ListenAddr = ":8080"
	}
	if cfg.Web.MetricsEnabled == nil {
		t := true
		cfg.Web.MetricsEnabled = &t
	}
	if cfg.Shell.Hostname == "" {
		cfg.Shell.Hostname = "ubuntu-server"
	}
	if cfg.Shell.Banner == "" {
		cfg.Shell.Banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n"
	}
	if cfg.Detection.Threshold == 0 {
		cfg.Detection.Threshold = 0.6
	}
	if cfg.Detection.UpdateInterval == "" {
		cfg.Detection.UpdateInterval = "5s"
	}
}

// knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables.
var knownShellKeys = map[string]bool{
	"hostname":        true,
	"banner":          true,
	"fake_user":       true,
	"username_routes": true,
}

// knownWebhookEvents lists the event names accepted in notify.webhooks[].events.
var knownWebhookEvents = map[string]bool{
	"human_detected":  true,
	"session_started": true,
}

// extractShellTables pulls per-shell config sub-tables from the raw [shell]
// section. It skips knownShellKeys and any value that is not a table. It
// returns nil when no sub-tables are found.
func extractShellTables(section map[string]any) map[string]map[string]any {
	result := make(map[string]map[string]any)
	for key, val := range section {
		if knownShellKeys[key] {
			continue
		}
		if sub, ok := val.(map[string]any); ok {
			result[key] = sub
		}
	}
	if len(result) == 0 {
		return nil
	}
	return result
}

// validate rejects invalid values in cfg and stores the parsed durations in
// the *Duration fields. It must run after applyDefaults: empty duration
// strings do not parse.
func validate(cfg *Config) error {
	if cfg.SSH.MaxConnections < 1 {
		return fmt.Errorf("max_connections must be at least 1, got %d", cfg.SSH.MaxConnections)
	}

	d, err := time.ParseDuration(cfg.Auth.CredentialTTL)
	if err != nil {
		return fmt.Errorf("invalid credential_ttl %q: %w", cfg.Auth.CredentialTTL, err)
	}
	if d <= 0 {
		return fmt.Errorf("credential_ttl must be positive, got %s", d)
	}
	cfg.Auth.CredentialTTLDuration = d

	if cfg.Auth.AcceptAfter < 1 {
		return fmt.Errorf("accept_after must be at least 1, got %d", cfg.Auth.AcceptAfter)
	}

	ri, err := time.ParseDuration(cfg.Storage.RetentionInterval)
	if err != nil {
		return fmt.Errorf("invalid retention_interval %q: %w", cfg.Storage.RetentionInterval, err)
	}
	if ri <= 0 {
		return fmt.Errorf("retention_interval must be positive, got %s", ri)
	}
	cfg.Storage.RetentionIntervalDuration = ri

	if cfg.Storage.RetentionDays < 1 {
		return fmt.Errorf("retention_days must be at least 1, got %d", cfg.Storage.RetentionDays)
	}

	for i, cred := range cfg.Auth.StaticCredentials {
		if cred.Username == "" {
			return fmt.Errorf("static_credentials[%d]: username must not be empty", i)
		}
		if cred.Password == "" {
			return fmt.Errorf("static_credentials[%d]: password must not be empty", i)
		}
	}

	// Validate detection config (only when detection is enabled).
	if cfg.Detection.Enabled {
		// Written as a negated in-range check so that NaN (TOML accepts
		// `nan`) is rejected: every comparison involving NaN is false, so
		// `t < 0 || t > 1` would let it through.
		if t := cfg.Detection.Threshold; !(t >= 0 && t <= 1) {
			return fmt.Errorf("detection.threshold must be between 0 and 1, got %f", t)
		}
		ui, err := time.ParseDuration(cfg.Detection.UpdateInterval)
		if err != nil {
			return fmt.Errorf("invalid detection.update_interval %q: %w", cfg.Detection.UpdateInterval, err)
		}
		if ui <= 0 {
			return fmt.Errorf("detection.update_interval must be positive, got %s", ui)
		}
		cfg.Detection.UpdateIntervalDuration = ui
	}

	// Validate notify config.
	for i, wh := range cfg.Notify.Webhooks {
		if wh.URL == "" {
			return fmt.Errorf("notify.webhooks[%d]: url must not be empty", i)
		}
		for j, ev := range wh.Events {
			if !knownWebhookEvents[ev] {
				return fmt.Errorf("notify.webhooks[%d].events[%d]: unknown event %q", i, j, ev)
			}
		}
	}
	return nil
}