package server

import (
	"bytes"
	"context"
	"log/slog"
	"net"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"git.t-juice.club/torjus/oubliette/internal/config"
	"git.t-juice.club/torjus/oubliette/internal/metrics"
	"git.t-juice.club/torjus/oubliette/internal/storage"
	"golang.org/x/crypto/ssh"
)

// testAddr is a minimal net.Addr whose Network and String methods return
// fixed values, used to feed literal address strings to code under test.
type testAddr struct {
	str     string
	network string
}

func (a testAddr) Network() string { return a.network }
func (a testAddr) String() string  { return a.str }

// newAddr returns a net.Addr that reports s from String and network from
// Network, verbatim.
func newAddr(s, network string) net.Addr {
	return testAddr{str: s, network: network}
}

// TestHostKey_Generate checks that loadOrGenerateHostKey creates a missing key
// file, returns a non-nil signer, and writes the file with 0600 permissions.
func TestHostKey_Generate(t *testing.T) {
	path := filepath.Join(t.TempDir(), "host_key")

	signer, err := loadOrGenerateHostKey(path)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if signer == nil {
		t.Fatal("signer is nil")
	}

	// File should exist with correct permissions.
	info, err := os.Stat(path)
	if err != nil {
		t.Fatalf("stat host key: %v", err)
	}
	if perm := info.Mode().Perm(); perm != 0600 {
		t.Errorf("permissions = %o, want 0600", perm)
	}
}

// TestHostKey_Load checks that a second call with the same path returns a
// signer for the same public key as the first (generating) call.
func TestHostKey_Load(t *testing.T) {
	path := filepath.Join(t.TempDir(), "host_key")

	// Generate first.
	signer1, err := loadOrGenerateHostKey(path)
	if err != nil {
		t.Fatalf("generate: %v", err)
	}

	// Load existing.
	signer2, err := loadOrGenerateHostKey(path)
	if err != nil {
		t.Fatalf("load: %v", err)
	}

	// Keys should be the same.
	if !bytes.Equal(signer1.PublicKey().Marshal(), signer2.PublicKey().Marshal()) {
		t.Error("loaded key differs from generated key")
	}
}

// TestExtractIP checks that extractIP strips the port, unwraps bracketed IPv6
// literals, and reports IPv4-mapped IPv6 addresses in dotted IPv4 form.
func TestExtractIP(t *testing.T) {
	tests := []struct {
		addr string
		want string
	}{
		{"192.168.1.1:22", "192.168.1.1"},
		{"[::1]:22", "::1"},
		{"[::ffff:192.168.1.1]:22", "192.168.1.1"},
		{"10.0.0.1:12345", "10.0.0.1"},
	}

	for _, tt := range tests {
		t.Run(tt.addr, func(t *testing.T) {
			addr := newAddr(tt.addr, "tcp")
			got := extractIP(addr)
			if got != tt.want {
				t.Errorf("extractIP(%q) = %q, want %q", tt.addr, got, tt.want)
			}
		})
	}
}

// TestIntegrationSSHConnect starts a real server on a loopback port and drives
// it with an SSH client: static-credential shell login, wrong-password
// rejection, exec command capture, and threshold-based acceptance.
// It is skipped under -short.
func TestIntegrationSSHConnect(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test")
	}

	tmpDir := t.TempDir()
	cfg := config.Config{
		SSH: config.SSHConfig{
			ListenAddr:     "127.0.0.1:0",
			HostKeyPath:    filepath.Join(tmpDir, "host_key"),
			MaxConnections: 100,
		},
		Auth: config.AuthConfig{
			AcceptAfter:           2,
			CredentialTTLDuration: time.Hour,
			StaticCredentials: []config.Credential{
				{Username: "root", Password: "toor", Shell: "bash"},
			},
		},
		Shell: config.ShellConfig{
			Hostname: "ubuntu-server",
			Banner:   "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n",
		},
		LogLevel: "debug",
	}

	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
	store := storage.NewMemoryStore()
	srv, err := New(cfg, store, logger, metrics.New("test"))
	if err != nil {
		t.Fatalf("creating server: %v", err)
	}

	// Reserve a free port by listening on :0 and releasing it for the server.
	// Another process could take the port in between; that race is accepted.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	addr := listener.Addr().String()
	listener.Close()
	cfg.SSH.ListenAddr = addr
	srv.cfg = cfg

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	var srvErr error // written by the server goroutine; read only after done is closed
	go func() {
		defer close(done)
		srvErr = srv.ListenAndServe(ctx)
	}()
	// Stop the server and wait for ListenAndServe to return on every exit
	// path (including t.Fatalf), so the server goroutine does not outlive
	// the test.
	t.Cleanup(func() {
		cancel()
		select {
		case <-done:
			if srvErr != nil {
				t.Logf("server stopped: %v", srvErr)
			}
		case <-time.After(5 * time.Second):
			t.Error("server did not stop within 5s of cancellation")
		}
	})

	// Wait for the server to accept TCP connections. Fail fast if it exits
	// instead, so the real error is reported rather than a dial timeout.
	var lastDialErr error
	ready := false
	for range 50 {
		select {
		case <-done:
			t.Fatalf("server exited before accepting connections: %v", srvErr)
		default:
		}
		conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
		if err == nil {
			conn.Close()
			ready = true
			break
		}
		lastDialErr = err
		time.Sleep(50 * time.Millisecond)
	}
	if !ready {
		t.Fatalf("server not ready after retries: %v", lastDialErr)
	}

	// clientConfig builds a password-auth client config. Host key checking is
	// skipped because the server's key is generated fresh in tmpDir.
	clientConfig := func(user, password string) *ssh.ClientConfig {
		return &ssh.ClientConfig{
			User:            user,
			Auth:            []ssh.AuthMethod{ssh.Password(password)},
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
			Timeout:         5 * time.Second,
		}
	}

	// Test static credential login with shell interaction.
	t.Run("static_cred", func(t *testing.T) {
		client, err := ssh.Dial("tcp", addr, clientConfig("root", "toor"))
		if err != nil {
			t.Fatalf("SSH dial: %v", err)
		}
		defer client.Close()

		session, err := client.NewSession()
		if err != nil {
			t.Fatalf("new session: %v", err)
		}
		defer session.Close()

		// Request PTY and shell.
		if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err != nil {
			t.Fatalf("request pty: %v", err)
		}
		stdin, err := session.StdinPipe()
		if err != nil {
			t.Fatalf("stdin pipe: %v", err)
		}
		var output bytes.Buffer
		session.Stdout = &output
		if err := session.Shell(); err != nil {
			t.Fatalf("shell: %v", err)
		}

		// send writes one command line terminated by "\r" and fails the
		// subtest if the write fails.
		send := func(line string) {
			t.Helper()
			if _, err := stdin.Write([]byte(line + "\r")); err != nil {
				t.Fatalf("writing %q to stdin: %v", line, err)
			}
		}

		// Pause for the banner and prompt, and between commands.
		time.Sleep(500 * time.Millisecond)
		send("pwd")
		time.Sleep(200 * time.Millisecond)
		send("whoami")
		time.Sleep(200 * time.Millisecond)
		send("exit")

		// Wait returns once the session ends and stdout has been copied into
		// output. An error here (e.g. no exit status sent) is tolerated; the
		// captured output is what is asserted on.
		if err := session.Wait(); err != nil {
			t.Logf("session.Wait: %v (tolerated)", err)
		}

		out := output.String()
		if !strings.Contains(out, "Welcome to Ubuntu") {
			t.Errorf("output should contain banner, got: %s", out)
		}
		if !strings.Contains(out, "/root") {
			t.Errorf("output should contain /root from pwd, got: %s", out)
		}
		// NOTE(review): "/root" printed by pwd already contains "root", so
		// this check passes even if whoami printed nothing.
		if !strings.Contains(out, "root") {
			t.Errorf("output should contain 'root' from whoami, got: %s", out)
		}

		// Verify session logs were recorded.
		// NOTE(review): this reads the MemoryStore field directly rather than
		// through a store method; if the server is still appending logs the
		// read is unsynchronized. Prefer an accessor if the store offers one.
		if len(store.SessionLogs) < 2 {
			t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs))
		}

		// Verify, through the store API, that a session was created with the
		// configured shell name.
		sessions, err := store.GetRecentSessions(context.Background(), 50, false)
		if err != nil {
			t.Fatalf("GetRecentSessions: %v", err)
		}
		var foundBash bool
		for _, s := range sessions {
			if s.ShellName == "bash" {
				foundBash = true
				break
			}
		}
		if !foundBash {
			t.Error("expected a session with shell_name='bash'")
		}
	})

	// Test wrong password is rejected.
	t.Run("wrong_password", func(t *testing.T) {
		if _, err := ssh.Dial("tcp", addr, clientConfig("root", "wrong")); err == nil {
			t.Fatal("expected error for wrong password")
		}
	})

	// Test exec command capture.
	t.Run("exec_command", func(t *testing.T) {
		client, err := ssh.Dial("tcp", addr, clientConfig("root", "toor"))
		if err != nil {
			t.Fatalf("SSH dial: %v", err)
		}
		defer client.Close()

		session, err := client.NewSession()
		if err != nil {
			t.Fatalf("new session: %v", err)
		}
		defer session.Close()

		// Run a command via exec (no PTY, no shell). Run may return an error
		// when the server closes the channel; that is tolerated because only
		// the recorded command is asserted on.
		if err := session.Run("uname -a"); err != nil {
			t.Logf("session.Run: %v (tolerated)", err)
		}

		// The command may not be stored yet when Run returns, so poll for
		// up to 2s instead of sleeping a fixed amount.
		var foundExec bool
		deadline := time.Now().Add(2 * time.Second)
		for {
			sessions, err := store.GetRecentSessions(context.Background(), 50, false)
			if err != nil {
				t.Fatalf("GetRecentSessions: %v", err)
			}
			for _, s := range sessions {
				if s.ExecCommand != nil && *s.ExecCommand == "uname -a" {
					foundExec = true
					break
				}
			}
			if foundExec || time.Now().After(deadline) {
				break
			}
			time.Sleep(20 * time.Millisecond)
		}
		if !foundExec {
			t.Error("expected a session with exec_command='uname -a'")
		}
	})

	// Test threshold acceptance (AcceptAfter is 2): after enough failed dials,
	// a subsequent dial with the same credentials should succeed via the
	// threshold or a remembered credential.
	t.Run("threshold", func(t *testing.T) {
		clientCfg := clientConfig("threshuser", "threshpass")

		// Dial repeatedly to accumulate failures past the threshold. An early
		// success means the threshold was reached.
		for range 5 {
			if c, err := ssh.Dial("tcp", addr, clientCfg); err == nil {
				c.Close()
				return
			}
		}

		// After enough failures the credential should be remembered.
		client, err := ssh.Dial("tcp", addr, clientCfg)
		if err != nil {
			t.Fatalf("expected threshold/remembered acceptance after many attempts: %v", err)
		}
		client.Close()
	})
}