diff --git a/README.md b/README.md index 7058849..cd939a3 100644 --- a/README.md +++ b/README.md @@ -7,3 +7,42 @@ Named after the medieval dungeon - a place you throw people into and forget abou ## Status Early development. See `PLAN.md` for the roadmap. + +## Usage + +### Build + +```sh +# With Nix +nix build + +# With Go +nix develop -c go build ./cmd/oubliette +``` + +### Configure + +Copy and edit the example config: + +```sh +cp oubliette.toml.example oubliette.toml +``` + +Key settings: +- `ssh.listen_addr` — listen address (default `:2222`) +- `ssh.host_key_path` — Ed25519 host key, auto-generated if missing +- `auth.accept_after` — accept login after N failures per IP (default `10`) +- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`) +- `auth.static_credentials` — always-accepted username/password pairs + +### Run + +```sh +./oubliette -config oubliette.toml +``` + +Test with: + +```sh +ssh -o StrictHostKeyChecking=no -p 2222 root@localhost +``` diff --git a/cmd/oubliette/main.go b/cmd/oubliette/main.go new file mode 100644 index 0000000..be43975 --- /dev/null +++ b/cmd/oubliette/main.go @@ -0,0 +1,55 @@ +package main + +import ( + "context" + "flag" + "log/slog" + "os" + "os/signal" + "syscall" + + "git.t-juice.club/torjus/oubliette/internal/config" + "git.t-juice.club/torjus/oubliette/internal/server" +) + +func main() { + configPath := flag.String("config", "oubliette.toml", "path to config file") + flag.Parse() + + cfg, err := config.Load(*configPath) + if err != nil { + slog.Error("failed to load config", "err", err) + os.Exit(1) + } + + level := new(slog.LevelVar) + switch cfg.LogLevel { + case "debug": + level.Set(slog.LevelDebug) + case "warn": + level.Set(slog.LevelWarn) + case "error": + level.Set(slog.LevelError) + default: + level.Set(slog.LevelInfo) + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})) + slog.SetDefault(logger) + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + srv, err := server.New(*cfg, logger) + if err != nil { + logger.Error("failed to create server", "err", err) + os.Exit(1) + } + + if err := srv.ListenAndServe(ctx); err != nil { + logger.Error("server error", "err", err) + os.Exit(1) + } + + logger.Info("server stopped") +} diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..18b7a3c --- /dev/null +++ b/flake.lock @@ -0,0 +1,27 @@ +{ + "nodes": { + "nixpkgs": { + "locked": { + "lastModified": 1771008912, + "narHash": "sha256-gf2AmWVTs8lEq7z/3ZAsgnZDhWIckkb+ZnAo5RzSxJg=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "a82ccc39b39b621151d6732718e3e250109076fa", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix index 1bf099f..2f1ac0d 100644 --- a/flake.nix +++ b/flake.nix @@ -11,6 +11,24 @@ forAllSystems = nixpkgs.lib.genAttrs supportedSystems; in { + packages = forAllSystems (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + default = pkgs.buildGoModule { + pname = "oubliette"; + version = "0.1.0"; + src = ./.; + vendorHash = "sha256-z/E1ZDfedOxI8CSUfcpFGYX0SrdcnAYuu2p0ATozDaA="; + subPackages = [ "cmd/oubliette" ]; + meta = { + description = "SSH honeypot"; + mainProgram = "oubliette"; + }; + }; + }); + devShells = forAllSystems (system: let pkgs = nixpkgs.legacyPackages.${system}; diff --git a/go.mod b/go.mod index 0fb32a3..77e5311 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,9 @@ module git.t-juice.club/torjus/oubliette go 1.25.5 + +require ( + github.com/BurntSushi/toml v1.6.0 // indirect + golang.org/x/crypto v0.48.0 // indirect + golang.org/x/sys v0.41.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..df3fc95 --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= +github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..518d343 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,68 @@ +package auth + +import ( + "sync" + "time" + + "git.t-juice.club/torjus/oubliette/internal/config" +) + +type credKey struct { + Username string + Password string +} + +type Decision struct { + Accepted bool + Reason string // "static_credential", "threshold_reached", "remembered_credential", "rejected" +} + +type Authenticator struct { + mu sync.Mutex + cfg config.AuthConfig + failCounts map[string]int // IP -> consecutive failures + rememberedCreds map[credKey]time.Time // (user,pass) -> expiry + now func() time.Time // for testing +} + +func NewAuthenticator(cfg config.AuthConfig) *Authenticator { + return &Authenticator{ + cfg: cfg, + failCounts: make(map[string]int), + rememberedCreds: make(map[credKey]time.Time), + now: time.Now, + } +} + +func (a *Authenticator) Authenticate(ip, username, password string) Decision { + a.mu.Lock() + defer a.mu.Unlock() + + // 1. Check static credentials. + for _, cred := range a.cfg.StaticCredentials { + if cred.Username == username && cred.Password == password { + a.failCounts[ip] = 0 + return Decision{Accepted: true, Reason: "static_credential"} + } + } + + // 2. Check remembered credentials. + key := credKey{Username: username, Password: password} + if expiry, ok := a.rememberedCreds[key]; ok { + if a.now().Before(expiry) { + a.failCounts[ip] = 0 + return Decision{Accepted: true, Reason: "remembered_credential"} + } + delete(a.rememberedCreds, key) + } + + // 3. Increment fail count, check threshold. + a.failCounts[ip]++ + if a.failCounts[ip] >= a.cfg.AcceptAfter { + a.failCounts[ip] = 0 + a.rememberedCreds[key] = a.now().Add(a.cfg.CredentialTTLDuration) + return Decision{Accepted: true, Reason: "threshold_reached"} + } + + return Decision{Accepted: false, Reason: "rejected"} +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go new file mode 100644 index 0000000..00406c3 --- /dev/null +++ b/internal/auth/auth_test.go @@ -0,0 +1,142 @@ +package auth + +import ( + "sync" + "testing" + "time" + + "git.t-juice.club/torjus/oubliette/internal/config" +) + +func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator { + a := NewAuthenticator(config.AuthConfig{ + AcceptAfter: acceptAfter, + CredentialTTLDuration: ttl, + StaticCredentials: statics, + }) + return a +} + +func TestStaticCredentialsAccepted(t *testing.T) { + a := newTestAuth(10, time.Hour, config.Credential{Username: "root", Password: "toor"}) + d := a.Authenticate("1.2.3.4", "root", "toor") + if !d.Accepted || d.Reason != "static_credential" { + t.Errorf("got %+v, want accepted with static_credential", d) + } +} + +func TestStaticCredentialsWrongPassword(t *testing.T) { + a := newTestAuth(10, time.Hour, config.Credential{Username: "root", Password: "toor"}) + d := a.Authenticate("1.2.3.4", "root", "wrong") + if d.Accepted { + t.Errorf("should not accept wrong password for static credential") + } +} + +func TestRejectionBeforeThreshold(t *testing.T) { + a := newTestAuth(3, time.Hour) + for i := 0; i < 2; i++ { + d := a.Authenticate("1.2.3.4", "user", "pass") + if d.Accepted { + t.Fatalf("attempt %d should be rejected", i+1) + } + if d.Reason != "rejected" { + t.Errorf("attempt %d reason = %q, want %q", i+1, d.Reason, "rejected") + } + } +} + +func TestThresholdAcceptance(t *testing.T) { + a := newTestAuth(3, time.Hour) + for i := 0; i < 2; i++ { + d := a.Authenticate("1.2.3.4", "user", "pass") + if d.Accepted { + t.Fatalf("attempt %d should be rejected", i+1) + } + } + d := a.Authenticate("1.2.3.4", "user", "pass") + if !d.Accepted || d.Reason != "threshold_reached" { + t.Errorf("attempt 3 got %+v, want accepted with threshold_reached", d) + } +} + +func TestPerIPIsolation(t *testing.T) { + a := newTestAuth(3, time.Hour) + + // IP1 gets 2 failures. + for i := 0; i < 2; i++ { + a.Authenticate("1.1.1.1", "user", "pass") + } + + // IP2 should start at 0, not inherit IP1's count. + d := a.Authenticate("2.2.2.2", "user", "pass") + if d.Accepted { + t.Error("IP2's first attempt should be rejected") + } +} + +func TestCredentialMemoryAcrossIPs(t *testing.T) { + a := newTestAuth(2, time.Hour) + + // IP1 reaches threshold, credential is remembered. + a.Authenticate("1.1.1.1", "user", "pass") + d := a.Authenticate("1.1.1.1", "user", "pass") + if !d.Accepted || d.Reason != "threshold_reached" { + t.Fatalf("threshold not reached: %+v", d) + } + + // IP2 should get in with the remembered credential. + d = a.Authenticate("2.2.2.2", "user", "pass") + if !d.Accepted || d.Reason != "remembered_credential" { + t.Errorf("IP2 got %+v, want accepted with remembered_credential", d) + } +} + +func TestCredentialMemoryExpires(t *testing.T) { + a := newTestAuth(2, time.Hour) + now := time.Now() + a.now = func() time.Time { return now } + + // Reach threshold to remember credential. + a.Authenticate("1.1.1.1", "user", "pass") + a.Authenticate("1.1.1.1", "user", "pass") + + // Advance past TTL. + a.now = func() time.Time { return now.Add(2 * time.Hour) } + + d := a.Authenticate("2.2.2.2", "user", "pass") + if d.Accepted { + t.Errorf("expired credential should not be accepted: %+v", d) + } +} + +func TestCounterResetsAfterAcceptance(t *testing.T) { + a := newTestAuth(2, time.Hour) + + // Reach threshold. + a.Authenticate("1.1.1.1", "user", "pass") + d := a.Authenticate("1.1.1.1", "user", "pass") + if !d.Accepted { + t.Fatal("threshold not reached") + } + + // With a different credential, counter should be reset. + d = a.Authenticate("1.1.1.1", "other", "cred") + if d.Accepted { + t.Error("first attempt after reset should be rejected") + } +} + +func TestConcurrentAccess(t *testing.T) { + a := newTestAuth(5, time.Hour) + var wg sync.WaitGroup + + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + a.Authenticate("1.2.3.4", "user", "pass") + }() + } + wg.Wait() +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..1098fb9 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,98 @@ +package config + +import ( + "fmt" + "os" + "time" + + "github.com/BurntSushi/toml" +) + +type Config struct { + SSH SSHConfig `toml:"ssh"` + Auth AuthConfig `toml:"auth"` + LogLevel string `toml:"log_level"` +} + +type SSHConfig struct { + ListenAddr string `toml:"listen_addr"` + HostKeyPath string `toml:"host_key_path"` +} + +type AuthConfig struct { + AcceptAfter int `toml:"accept_after"` + CredentialTTL string `toml:"credential_ttl"` + StaticCredentials []Credential `toml:"static_credentials"` + + // Parsed duration, not from TOML directly. + CredentialTTLDuration time.Duration `toml:"-"` +} + +type Credential struct { + Username string `toml:"username"` + Password string `toml:"password"` +} + +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("reading config: %w", err) + } + + cfg := &Config{} + if err := toml.Unmarshal(data, cfg); err != nil { + return nil, fmt.Errorf("parsing config: %w", err) + } + + applyDefaults(cfg) + + if err := validate(cfg); err != nil { + return nil, fmt.Errorf("validating config: %w", err) + } + + return cfg, nil +} + +func applyDefaults(cfg *Config) { + if cfg.SSH.ListenAddr == "" { + cfg.SSH.ListenAddr = ":2222" + } + if cfg.SSH.HostKeyPath == "" { + cfg.SSH.HostKeyPath = "oubliette_host_key" + } + if cfg.Auth.AcceptAfter == 0 { + cfg.Auth.AcceptAfter = 10 + } + if cfg.Auth.CredentialTTL == "" { + cfg.Auth.CredentialTTL = "24h" + } + if cfg.LogLevel == "" { + cfg.LogLevel = "info" + } +} + +func validate(cfg *Config) error { + d, err := time.ParseDuration(cfg.Auth.CredentialTTL) + if err != nil { + return fmt.Errorf("invalid credential_ttl %q: %w", cfg.Auth.CredentialTTL, err) + } + if d <= 0 { + return fmt.Errorf("credential_ttl must be positive, got %s", d) + } + cfg.Auth.CredentialTTLDuration = d + + if cfg.Auth.AcceptAfter < 1 { + return fmt.Errorf("accept_after must be at least 1, got %d", cfg.Auth.AcceptAfter) + } + + for i, cred := range cfg.Auth.StaticCredentials { + if cred.Username == "" { + return fmt.Errorf("static_credentials[%d]: username must not be empty", i) + } + if cred.Password == "" { + return fmt.Errorf("static_credentials[%d]: password must not be empty", i) + } + } + + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..c6b5288 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,138 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestLoadValidConfig(t *testing.T) { + content := ` +log_level = "debug" + +[ssh] +listen_addr = ":3333" +host_key_path = "/tmp/test_key" + +[auth] +accept_after = 5 +credential_ttl = "1h" + +[[auth.static_credentials]] +username = "root" +password = "toor" +` + path := writeTemp(t, content) + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.LogLevel != "debug" { + t.Errorf("log_level = %q, want %q", cfg.LogLevel, "debug") + } + if cfg.SSH.ListenAddr != ":3333" { + t.Errorf("listen_addr = %q, want %q", cfg.SSH.ListenAddr, ":3333") + } + if cfg.SSH.HostKeyPath != "/tmp/test_key" { + t.Errorf("host_key_path = %q, want %q", cfg.SSH.HostKeyPath, "/tmp/test_key") + } + if cfg.Auth.AcceptAfter != 5 { + t.Errorf("accept_after = %d, want %d", cfg.Auth.AcceptAfter, 5) + } + if cfg.Auth.CredentialTTLDuration != time.Hour { + t.Errorf("credential_ttl_duration = %v, want %v", cfg.Auth.CredentialTTLDuration, time.Hour) + } + if len(cfg.Auth.StaticCredentials) != 1 { + t.Fatalf("static_credentials len = %d, want 1", len(cfg.Auth.StaticCredentials)) + } + if cfg.Auth.StaticCredentials[0].Username != "root" { + t.Errorf("username = %q, want %q", cfg.Auth.StaticCredentials[0].Username, "root") + } +} + +func TestLoadDefaults(t *testing.T) { + path := writeTemp(t, "") + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.SSH.ListenAddr != ":2222" { + t.Errorf("default listen_addr = %q, want %q", cfg.SSH.ListenAddr, ":2222") + } + if cfg.SSH.HostKeyPath != "oubliette_host_key" { + t.Errorf("default host_key_path = %q, want %q", cfg.SSH.HostKeyPath, "oubliette_host_key") + } + if cfg.Auth.AcceptAfter != 10 { + t.Errorf("default accept_after = %d, want %d", cfg.Auth.AcceptAfter, 10) + } + if cfg.Auth.CredentialTTLDuration != 24*time.Hour { + t.Errorf("default credential_ttl = %v, want %v", cfg.Auth.CredentialTTLDuration, 24*time.Hour) + } + if cfg.LogLevel != "info" { + t.Errorf("default log_level = %q, want %q", cfg.LogLevel, "info") + } +} + +func TestLoadInvalidTTL(t *testing.T) { + content := ` +[auth] +credential_ttl = "notaduration" +` + path := writeTemp(t, content) + _, err := Load(path) + if err == nil { + t.Fatal("expected error for invalid credential_ttl") + } +} + +func TestLoadNegativeTTL(t *testing.T) { + content := ` +[auth] +credential_ttl = "-1h" +` + path := writeTemp(t, content) + _, err := Load(path) + if err == nil { + t.Fatal("expected error for negative credential_ttl") + } +} + +func TestLoadInvalidStaticCredential(t *testing.T) { + content := ` +[[auth.static_credentials]] +username = "" +password = "test" +` + path := writeTemp(t, content) + _, err := Load(path) + if err == nil { + t.Fatal("expected error for empty username") + } +} + +func TestLoadMissingFile(t *testing.T) { + _, err := Load("/nonexistent/path/config.toml") + if err == nil { + t.Fatal("expected error for missing file") + } +} + +func TestLoadInvalidTOML(t *testing.T) { + path := writeTemp(t, "{{{{invalid toml") + _, err := Load(path) + if err == nil { + t.Fatal("expected error for invalid TOML") + } +} + +func writeTemp(t *testing.T, content string) string { + t.Helper() + path := filepath.Join(t.TempDir(), "config.toml") + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("writing temp config: %v", err) + } + return path +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..91212c7 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,230 @@ +package server + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "encoding/pem" + "errors" + "fmt" + "log/slog" + "net" + "os" + "time" + + "git.t-juice.club/torjus/oubliette/internal/auth" + "git.t-juice.club/torjus/oubliette/internal/config" + "golang.org/x/crypto/ssh" +) + +const sessionTimeout = 30 * time.Second + +type Server struct { + cfg config.Config + authenticator *auth.Authenticator + sshConfig *ssh.ServerConfig + logger *slog.Logger +} + +func New(cfg config.Config, logger *slog.Logger) (*Server, error) { + s := &Server{ + cfg: cfg, + authenticator: auth.NewAuthenticator(cfg.Auth), + logger: logger, + } + + hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath) + if err != nil { + return nil, fmt.Errorf("host key: %w", err) + } + + s.sshConfig = &ssh.ServerConfig{ + PasswordCallback: s.passwordCallback, + ServerVersion: "SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.6", + } + s.sshConfig.AddHostKey(hostKey) + + return s, nil +} + +func (s *Server) ListenAndServe(ctx context.Context) error { + listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr) + if err != nil { + return fmt.Errorf("listen: %w", err) + } + defer listener.Close() + + s.logger.Info("SSH server listening", "addr", s.cfg.SSH.ListenAddr) + + go func() { + <-ctx.Done() + listener.Close() + }() + + for { + conn, err := listener.Accept() + if err != nil { + if ctx.Err() != nil { + return nil + } + s.logger.Error("accept error", "err", err) + continue + } + go s.handleConn(conn) + } +} + +func (s *Server) handleConn(conn net.Conn) { + defer conn.Close() + + sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig) + if err != nil { + s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err) + return + } + defer sshConn.Close() + + s.logger.Info("SSH connection established", + "remote_addr", sshConn.RemoteAddr(), + "user", sshConn.User(), + ) + + go ssh.DiscardRequests(reqs) + + for newChan := range chans { + if newChan.ChannelType() != "session" { + newChan.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + + channel, requests, err := newChan.Accept() + if err != nil { + s.logger.Error("channel accept error", "err", err) + return + } + + go s.handleSession(channel, requests, sshConn) + } +} + +func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) { + defer channel.Close() + + // Handle session requests (pty-req, shell, etc.) + go func() { + for req := range requests { + switch req.Type { + case "pty-req", "shell": + if req.WantReply { + req.Reply(true, nil) + } + default: + if req.WantReply { + req.Reply(false, nil) + } + } + } + }() + + // Write a fake banner. + fmt.Fprint(channel, "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n") + fmt.Fprintf(channel, "Last login: %s from 10.0.0.1\r\n", time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006")) + fmt.Fprintf(channel, "%s@ubuntu:~$ ", conn.User()) + + // Hold connection open until timeout or client disconnect. + timer := time.NewTimer(sessionTimeout) + defer timer.Stop() + + done := make(chan struct{}) + go func() { + buf := make([]byte, 256) + for { + _, err := channel.Read(buf) + if err != nil { + close(done) + return + } + } + }() + + select { + case <-timer.C: + s.logger.Info("session timed out", "remote_addr", conn.RemoteAddr(), "user", conn.User()) + case <-done: + s.logger.Info("session closed by client", "remote_addr", conn.RemoteAddr(), "user", conn.User()) + } +} + +func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + ip := extractIP(conn.RemoteAddr()) + d := s.authenticator.Authenticate(ip, conn.User(), string(password)) + + s.logger.Info("auth attempt", + "remote_addr", conn.RemoteAddr(), + "username", conn.User(), + "accepted", d.Accepted, + "reason", d.Reason, + ) + + if d.Accepted { + return nil, nil + } + return nil, fmt.Errorf("rejected") +} + +func extractIP(addr net.Addr) string { + host, _, err := net.SplitHostPort(addr.String()) + if err != nil { + // Might not have a port, try using the string directly. + return addr.String() + } + // Normalize IPv4-mapped IPv6 addresses. + ip := net.ParseIP(host) + if ip == nil { + return host + } + if v4 := ip.To4(); v4 != nil { + return v4.String() + } + return ip.String() +} + +func loadOrGenerateHostKey(path string) (ssh.Signer, error) { + data, err := os.ReadFile(path) + if err == nil { + signer, err := ssh.ParsePrivateKey(data) + if err != nil { + return nil, fmt.Errorf("parsing host key: %w", err) + } + return signer, nil + } + + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("reading host key: %w", err) + } + + // Generate new Ed25519 key. + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return nil, fmt.Errorf("generating key: %w", err) + } + + privBytes, err := ssh.MarshalPrivateKey(priv, "") + if err != nil { + return nil, fmt.Errorf("marshaling key: %w", err) + } + + pemData := pem.EncodeToMemory(privBytes) + if err := os.WriteFile(path, pemData, 0600); err != nil { + return nil, fmt.Errorf("writing host key: %w", err) + } + + signer, err := ssh.ParsePrivateKey(pemData) + if err != nil { + return nil, fmt.Errorf("parsing generated key: %w", err) + } + + slog.Info("generated new host key", "path", path) + return signer, nil +} + diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..ff8d303 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,219 @@ +package server + +import ( + "context" + "log/slog" + "net" + "os" + "path/filepath" + "testing" + "time" + + "git.t-juice.club/torjus/oubliette/internal/config" + "golang.org/x/crypto/ssh" +) + +type testAddr struct { + str string + network string +} + +func (a testAddr) Network() string { return a.network } +func (a testAddr) String() string { return a.str } + +func newAddr(s, network string) net.Addr { + return testAddr{str: s, network: network} +} + +func TestHostKey_Generate(t *testing.T) { + path := filepath.Join(t.TempDir(), "host_key") + + signer, err := loadOrGenerateHostKey(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if signer == nil { + t.Fatal("signer is nil") + } + + // File should exist with correct permissions. + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat host key: %v", err) + } + if perm := info.Mode().Perm(); perm != 0600 { + t.Errorf("permissions = %o, want 0600", perm) + } +} + +func TestHostKey_Load(t *testing.T) { + path := filepath.Join(t.TempDir(), "host_key") + + // Generate first. + signer1, err := loadOrGenerateHostKey(path) + if err != nil { + t.Fatalf("generate: %v", err) + } + + // Load existing. + signer2, err := loadOrGenerateHostKey(path) + if err != nil { + t.Fatalf("load: %v", err) + } + + // Keys should be the same. + if string(signer1.PublicKey().Marshal()) != string(signer2.PublicKey().Marshal()) { + t.Error("loaded key differs from generated key") + } +} + +func TestExtractIP(t *testing.T) { + tests := []struct { + addr string + want string + }{ + {"192.168.1.1:22", "192.168.1.1"}, + {"[::1]:22", "::1"}, + {"[::ffff:192.168.1.1]:22", "192.168.1.1"}, + {"10.0.0.1:12345", "10.0.0.1"}, + } + + for _, tt := range tests { + t.Run(tt.addr, func(t *testing.T) { + addr := newAddr(tt.addr, "tcp") + got := extractIP(addr) + if got != tt.want { + t.Errorf("extractIP(%q) = %q, want %q", tt.addr, got, tt.want) + } + }) + } +} + +func TestIntegrationSSHConnect(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test") + } + + tmpDir := t.TempDir() + cfg := config.Config{ + SSH: config.SSHConfig{ + ListenAddr: "127.0.0.1:0", + HostKeyPath: filepath.Join(tmpDir, "host_key"), + }, + Auth: config.AuthConfig{ + AcceptAfter: 2, + CredentialTTLDuration: time.Hour, + StaticCredentials: []config.Credential{ + {Username: "root", Password: "toor"}, + }, + }, + LogLevel: "debug", + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) + srv, err := New(cfg, logger) + if err != nil { + t.Fatalf("creating server: %v", err) + } + + // Use a listener to get the actual port. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + addr := listener.Addr().String() + listener.Close() + + cfg.SSH.ListenAddr = addr + srv.cfg = cfg + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- srv.ListenAndServe(ctx) + }() + + // Wait for server to be ready. + var conn net.Conn + for i := range 50 { + conn, err = net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err == nil { + conn.Close() + break + } + if i == 49 { + t.Fatalf("server not ready after retries: %v", err) + } + time.Sleep(50 * time.Millisecond) + } + + // Test static credential login. + t.Run("static_cred", func(t *testing.T) { + clientCfg := &ssh.ClientConfig{ + User: "root", + Auth: []ssh.AuthMethod{ssh.Password("toor")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + + client, err := ssh.Dial("tcp", addr, clientCfg) + if err != nil { + t.Fatalf("SSH dial: %v", err) + } + defer client.Close() + + session, err := client.NewSession() + if err != nil { + t.Fatalf("new session: %v", err) + } + defer session.Close() + }) + + // Test wrong password is rejected. + t.Run("wrong_password", func(t *testing.T) { + clientCfg := &ssh.ClientConfig{ + User: "root", + Auth: []ssh.AuthMethod{ssh.Password("wrong")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + + _, err := ssh.Dial("tcp", addr, clientCfg) + if err == nil { + t.Fatal("expected error for wrong password") + } + }) + + // Test threshold acceptance: after enough failed dials, a subsequent + // dial with the same credentials should succeed via threshold or + // remembered credential. + t.Run("threshold", func(t *testing.T) { + clientCfg := &ssh.ClientConfig{ + User: "threshuser", + Auth: []ssh.AuthMethod{ssh.Password("threshpass")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + + // Make several dials to accumulate failures past the threshold. + for range 5 { + c, err := ssh.Dial("tcp", addr, clientCfg) + if err == nil { + // Threshold reached, success! + c.Close() + return + } + } + + // After enough failures the credential should be remembered. + client, err := ssh.Dial("tcp", addr, clientCfg) + if err != nil { + t.Fatalf("expected threshold/remembered acceptance after many attempts: %v", err) + } + client.Close() + }) + + cancel() +} diff --git a/oubliette.toml.example b/oubliette.toml.example new file mode 100644 index 0000000..6b5b64a --- /dev/null +++ b/oubliette.toml.example @@ -0,0 +1,17 @@ +log_level = "info" + +[ssh] +listen_addr = ":2222" +host_key_path = "oubliette_host_key" + +[auth] +accept_after = 10 +credential_ttl = "24h" + +[[auth.static_credentials]] +username = "root" +password = "toor" + +[[auth.static_credentials]] +username = "admin" +password = "admin"