package config import ( "fmt" "os" "time" "github.com/BurntSushi/toml" ) type Config struct { SSH SSHConfig `toml:"ssh"` Auth AuthConfig `toml:"auth"` Storage StorageConfig `toml:"storage"` Shell ShellConfig `toml:"shell"` Web WebConfig `toml:"web"` LogLevel string `toml:"log_level"` LogFormat string `toml:"log_format"` // "text" (default) or "json" } type WebConfig struct { Enabled bool `toml:"enabled"` ListenAddr string `toml:"listen_addr"` } type ShellConfig struct { Hostname string `toml:"hostname"` Banner string `toml:"banner"` FakeUser string `toml:"fake_user"` Shells map[string]map[string]any `toml:"-"` // per-shell config extracted manually } type StorageConfig struct { DBPath string `toml:"db_path"` RetentionDays int `toml:"retention_days"` RetentionInterval string `toml:"retention_interval"` // Parsed duration, not from TOML directly. RetentionIntervalDuration time.Duration `toml:"-"` } type SSHConfig struct { ListenAddr string `toml:"listen_addr"` HostKeyPath string `toml:"host_key_path"` MaxConnections int `toml:"max_connections"` } type AuthConfig struct { AcceptAfter int `toml:"accept_after"` CredentialTTL string `toml:"credential_ttl"` StaticCredentials []Credential `toml:"static_credentials"` // Parsed duration, not from TOML directly. CredentialTTLDuration time.Duration `toml:"-"` } type Credential struct { Username string `toml:"username"` Password string `toml:"password"` } func Load(path string) (*Config, error) { data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("reading config: %w", err) } cfg := &Config{} if err := toml.Unmarshal(data, cfg); err != nil { return nil, fmt.Errorf("parsing config: %w", err) } // Second pass: extract per-shell sub-tables (e.g. [shell.bash]). var raw map[string]any if err := toml.Unmarshal(data, &raw); err == nil { if shellSection, ok := raw["shell"].(map[string]any); ok { cfg.Shell.Shells = extractShellTables(shellSection) } } applyDefaults(cfg) if err := validate(cfg); err != nil { return nil, fmt.Errorf("validating config: %w", err) } return cfg, nil } func applyDefaults(cfg *Config) { if cfg.SSH.ListenAddr == "" { cfg.SSH.ListenAddr = ":2222" } if cfg.SSH.HostKeyPath == "" { cfg.SSH.HostKeyPath = "oubliette_host_key" } if cfg.SSH.MaxConnections == 0 { cfg.SSH.MaxConnections = 500 } if cfg.Auth.AcceptAfter == 0 { cfg.Auth.AcceptAfter = 10 } if cfg.Auth.CredentialTTL == "" { cfg.Auth.CredentialTTL = "24h" } if cfg.LogLevel == "" { cfg.LogLevel = "info" } if cfg.LogFormat == "" { cfg.LogFormat = "text" } if cfg.Storage.DBPath == "" { cfg.Storage.DBPath = "oubliette.db" } if cfg.Storage.RetentionDays == 0 { cfg.Storage.RetentionDays = 90 } if cfg.Storage.RetentionInterval == "" { cfg.Storage.RetentionInterval = "1h" } if cfg.Web.ListenAddr == "" { cfg.Web.ListenAddr = ":8080" } if cfg.Shell.Hostname == "" { cfg.Shell.Hostname = "ubuntu-server" } if cfg.Shell.Banner == "" { cfg.Shell.Banner = "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n" } } // knownShellKeys are top-level keys in [shell] that are not per-shell sub-tables. var knownShellKeys = map[string]bool{ "hostname": true, "banner": true, "fake_user": true, } // extractShellTables pulls per-shell config sub-tables from the raw [shell] section. func extractShellTables(section map[string]any) map[string]map[string]any { result := make(map[string]map[string]any) for key, val := range section { if knownShellKeys[key] { continue } if sub, ok := val.(map[string]any); ok { result[key] = sub } } if len(result) == 0 { return nil } return result } func validate(cfg *Config) error { d, err := time.ParseDuration(cfg.Auth.CredentialTTL) if err != nil { return fmt.Errorf("invalid credential_ttl %q: %w", cfg.Auth.CredentialTTL, err) } if d <= 0 { return fmt.Errorf("credential_ttl must be positive, got %s", d) } cfg.Auth.CredentialTTLDuration = d if cfg.Auth.AcceptAfter < 1 { return fmt.Errorf("accept_after must be at least 1, got %d", cfg.Auth.AcceptAfter) } ri, err := time.ParseDuration(cfg.Storage.RetentionInterval) if err != nil { return fmt.Errorf("invalid retention_interval %q: %w", cfg.Storage.RetentionInterval, err) } if ri <= 0 { return fmt.Errorf("retention_interval must be positive, got %s", ri) } cfg.Storage.RetentionIntervalDuration = ri if cfg.Storage.RetentionDays < 1 { return fmt.Errorf("retention_days must be at least 1, got %d", cfg.Storage.RetentionDays) } for i, cred := range cfg.Auth.StaticCredentials { if cred.Username == "" { return fmt.Errorf("static_credentials[%d]: username must not be empty", i) } if cred.Password == "" { return fmt.Errorf("static_credentials[%d]: password must not be empty", i) } } return nil }