Adds persistent storage using modernc.org/sqlite (pure Go). Login attempts are deduplicated by (username, password, ip) with counts. Sessions and session logs are tracked with UUID IDs. Includes embedded SQL migrations, configurable retention with background pruning, and an in-memory store for tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
223 lines
5.1 KiB
Go
223 lines
5.1 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.t-juice.club/torjus/oubliette/internal/config"
|
|
"git.t-juice.club/torjus/oubliette/internal/storage"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// testAddr is a minimal net.Addr implementation that lets tests hand
// arbitrary address strings to code under test without opening sockets.
type testAddr struct {
	str     string // value reported by String
	network string // value reported by Network
}

// Compile-time check that testAddr satisfies net.Addr.
var _ net.Addr = testAddr{}

// Network reports the network name the address was constructed with.
func (ta testAddr) Network() string {
	return ta.network
}

// String reports the raw address string the address was constructed with.
func (ta testAddr) String() string {
	return ta.str
}

// newAddr wraps s and network in a testAddr and returns it as a net.Addr.
func newAddr(s, network string) net.Addr {
	var addr testAddr
	addr.str, addr.network = s, network
	return addr
}
|
|
|
|
func TestHostKey_Generate(t *testing.T) {
|
|
path := filepath.Join(t.TempDir(), "host_key")
|
|
|
|
signer, err := loadOrGenerateHostKey(path)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if signer == nil {
|
|
t.Fatal("signer is nil")
|
|
}
|
|
|
|
// File should exist with correct permissions.
|
|
info, err := os.Stat(path)
|
|
if err != nil {
|
|
t.Fatalf("stat host key: %v", err)
|
|
}
|
|
if perm := info.Mode().Perm(); perm != 0600 {
|
|
t.Errorf("permissions = %o, want 0600", perm)
|
|
}
|
|
}
|
|
|
|
func TestHostKey_Load(t *testing.T) {
|
|
path := filepath.Join(t.TempDir(), "host_key")
|
|
|
|
// Generate first.
|
|
signer1, err := loadOrGenerateHostKey(path)
|
|
if err != nil {
|
|
t.Fatalf("generate: %v", err)
|
|
}
|
|
|
|
// Load existing.
|
|
signer2, err := loadOrGenerateHostKey(path)
|
|
if err != nil {
|
|
t.Fatalf("load: %v", err)
|
|
}
|
|
|
|
// Keys should be the same.
|
|
if string(signer1.PublicKey().Marshal()) != string(signer2.PublicKey().Marshal()) {
|
|
t.Error("loaded key differs from generated key")
|
|
}
|
|
}
|
|
|
|
func TestExtractIP(t *testing.T) {
|
|
tests := []struct {
|
|
addr string
|
|
want string
|
|
}{
|
|
{"192.168.1.1:22", "192.168.1.1"},
|
|
{"[::1]:22", "::1"},
|
|
{"[::ffff:192.168.1.1]:22", "192.168.1.1"},
|
|
{"10.0.0.1:12345", "10.0.0.1"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.addr, func(t *testing.T) {
|
|
addr := newAddr(tt.addr, "tcp")
|
|
got := extractIP(addr)
|
|
if got != tt.want {
|
|
t.Errorf("extractIP(%q) = %q, want %q", tt.addr, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIntegrationSSHConnect(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping integration test")
|
|
}
|
|
|
|
tmpDir := t.TempDir()
|
|
cfg := config.Config{
|
|
SSH: config.SSHConfig{
|
|
ListenAddr: "127.0.0.1:0",
|
|
HostKeyPath: filepath.Join(tmpDir, "host_key"),
|
|
MaxConnections: 100,
|
|
},
|
|
Auth: config.AuthConfig{
|
|
AcceptAfter: 2,
|
|
CredentialTTLDuration: time.Hour,
|
|
StaticCredentials: []config.Credential{
|
|
{Username: "root", Password: "toor"},
|
|
},
|
|
},
|
|
LogLevel: "debug",
|
|
}
|
|
|
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
|
store := storage.NewMemoryStore()
|
|
srv, err := New(cfg, store, logger)
|
|
if err != nil {
|
|
t.Fatalf("creating server: %v", err)
|
|
}
|
|
|
|
// Use a listener to get the actual port.
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen: %v", err)
|
|
}
|
|
addr := listener.Addr().String()
|
|
listener.Close()
|
|
|
|
cfg.SSH.ListenAddr = addr
|
|
srv.cfg = cfg
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
errCh <- srv.ListenAndServe(ctx)
|
|
}()
|
|
|
|
// Wait for server to be ready.
|
|
var conn net.Conn
|
|
for i := range 50 {
|
|
conn, err = net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
|
if err == nil {
|
|
conn.Close()
|
|
break
|
|
}
|
|
if i == 49 {
|
|
t.Fatalf("server not ready after retries: %v", err)
|
|
}
|
|
time.Sleep(50 * time.Millisecond)
|
|
}
|
|
|
|
// Test static credential login.
|
|
t.Run("static_cred", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "root",
|
|
Auth: []ssh.AuthMethod{ssh.Password("toor")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err != nil {
|
|
t.Fatalf("SSH dial: %v", err)
|
|
}
|
|
defer client.Close()
|
|
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
t.Fatalf("new session: %v", err)
|
|
}
|
|
defer session.Close()
|
|
})
|
|
|
|
// Test wrong password is rejected.
|
|
t.Run("wrong_password", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "root",
|
|
Auth: []ssh.AuthMethod{ssh.Password("wrong")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
_, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err == nil {
|
|
t.Fatal("expected error for wrong password")
|
|
}
|
|
})
|
|
|
|
// Test threshold acceptance: after enough failed dials, a subsequent
|
|
// dial with the same credentials should succeed via threshold or
|
|
// remembered credential.
|
|
t.Run("threshold", func(t *testing.T) {
|
|
clientCfg := &ssh.ClientConfig{
|
|
User: "threshuser",
|
|
Auth: []ssh.AuthMethod{ssh.Password("threshpass")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
// Make several dials to accumulate failures past the threshold.
|
|
for range 5 {
|
|
c, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err == nil {
|
|
// Threshold reached, success!
|
|
c.Close()
|
|
return
|
|
}
|
|
}
|
|
|
|
// After enough failures the credential should be remembered.
|
|
client, err := ssh.Dial("tcp", addr, clientCfg)
|
|
if err != nil {
|
|
t.Fatalf("expected threshold/remembered acceptance after many attempts: %v", err)
|
|
}
|
|
client.Close()
|
|
})
|
|
|
|
cancel()
|
|
}
|