package server import ( "context" "log/slog" "net" "os" "path/filepath" "testing" "time" "git.t-juice.club/torjus/oubliette/internal/config" "git.t-juice.club/torjus/oubliette/internal/storage" "golang.org/x/crypto/ssh" ) type testAddr struct { str string network string } func (a testAddr) Network() string { return a.network } func (a testAddr) String() string { return a.str } func newAddr(s, network string) net.Addr { return testAddr{str: s, network: network} } func TestHostKey_Generate(t *testing.T) { path := filepath.Join(t.TempDir(), "host_key") signer, err := loadOrGenerateHostKey(path) if err != nil { t.Fatalf("unexpected error: %v", err) } if signer == nil { t.Fatal("signer is nil") } // File should exist with correct permissions. info, err := os.Stat(path) if err != nil { t.Fatalf("stat host key: %v", err) } if perm := info.Mode().Perm(); perm != 0600 { t.Errorf("permissions = %o, want 0600", perm) } } func TestHostKey_Load(t *testing.T) { path := filepath.Join(t.TempDir(), "host_key") // Generate first. signer1, err := loadOrGenerateHostKey(path) if err != nil { t.Fatalf("generate: %v", err) } // Load existing. signer2, err := loadOrGenerateHostKey(path) if err != nil { t.Fatalf("load: %v", err) } // Keys should be the same. if string(signer1.PublicKey().Marshal()) != string(signer2.PublicKey().Marshal()) { t.Error("loaded key differs from generated key") } } func TestExtractIP(t *testing.T) { tests := []struct { addr string want string }{ {"192.168.1.1:22", "192.168.1.1"}, {"[::1]:22", "::1"}, {"[::ffff:192.168.1.1]:22", "192.168.1.1"}, {"10.0.0.1:12345", "10.0.0.1"}, } for _, tt := range tests { t.Run(tt.addr, func(t *testing.T) { addr := newAddr(tt.addr, "tcp") got := extractIP(addr) if got != tt.want { t.Errorf("extractIP(%q) = %q, want %q", tt.addr, got, tt.want) } }) } } func TestIntegrationSSHConnect(t *testing.T) { if testing.Short() { t.Skip("skipping integration test") } tmpDir := t.TempDir() cfg := config.Config{ SSH: config.SSHConfig{ ListenAddr: "127.0.0.1:0", HostKeyPath: filepath.Join(tmpDir, "host_key"), MaxConnections: 100, }, Auth: config.AuthConfig{ AcceptAfter: 2, CredentialTTLDuration: time.Hour, StaticCredentials: []config.Credential{ {Username: "root", Password: "toor"}, }, }, LogLevel: "debug", } logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) store := storage.NewMemoryStore() srv, err := New(cfg, store, logger) if err != nil { t.Fatalf("creating server: %v", err) } // Use a listener to get the actual port. listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } addr := listener.Addr().String() listener.Close() cfg.SSH.ListenAddr = addr srv.cfg = cfg ctx, cancel := context.WithCancel(context.Background()) defer cancel() errCh := make(chan error, 1) go func() { errCh <- srv.ListenAndServe(ctx) }() // Wait for server to be ready. var conn net.Conn for i := range 50 { conn, err = net.DialTimeout("tcp", addr, 100*time.Millisecond) if err == nil { conn.Close() break } if i == 49 { t.Fatalf("server not ready after retries: %v", err) } time.Sleep(50 * time.Millisecond) } // Test static credential login. t.Run("static_cred", func(t *testing.T) { clientCfg := &ssh.ClientConfig{ User: "root", Auth: []ssh.AuthMethod{ssh.Password("toor")}, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: 5 * time.Second, } client, err := ssh.Dial("tcp", addr, clientCfg) if err != nil { t.Fatalf("SSH dial: %v", err) } defer client.Close() session, err := client.NewSession() if err != nil { t.Fatalf("new session: %v", err) } defer session.Close() }) // Test wrong password is rejected. t.Run("wrong_password", func(t *testing.T) { clientCfg := &ssh.ClientConfig{ User: "root", Auth: []ssh.AuthMethod{ssh.Password("wrong")}, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: 5 * time.Second, } _, err := ssh.Dial("tcp", addr, clientCfg) if err == nil { t.Fatal("expected error for wrong password") } }) // Test threshold acceptance: after enough failed dials, a subsequent // dial with the same credentials should succeed via threshold or // remembered credential. t.Run("threshold", func(t *testing.T) { clientCfg := &ssh.ClientConfig{ User: "threshuser", Auth: []ssh.AuthMethod{ssh.Password("threshpass")}, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: 5 * time.Second, } // Make several dials to accumulate failures past the threshold. for range 5 { c, err := ssh.Dial("tcp", addr, clientCfg) if err == nil { // Threshold reached, success! c.Close() return } } // After enough failures the credential should be remembered. client, err := ssh.Dial("tcp", addr, clientCfg) if err != nil { t.Fatalf("expected threshold/remembered acceptance after many attempts: %v", err) } client.Close() }) cancel() }