package auth

import (
	"fmt"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"git.t-juice.club/torjus/oubliette/internal/config"
)

// newTestAuth builds an Authenticator from a config.AuthConfig with the given
// AcceptAfter threshold, CredentialTTLDuration and static credentials.
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {
	return NewAuthenticator(config.AuthConfig{
		AcceptAfter:           acceptAfter,
		CredentialTTLDuration: ttl,
		StaticCredentials:     statics,
	})
}

// TestStaticCredentialsAccepted checks that a configured static credential is
// accepted on the very first attempt, well below the AcceptAfter threshold.
func TestStaticCredentialsAccepted(t *testing.T) {
	a := newTestAuth(10, time.Hour, config.Credential{Username: "root", Password: "toor"})

	d := a.Authenticate("1.2.3.4", "root", "toor")
	if !d.Accepted || d.Reason != "static_credential" {
		t.Errorf("got %+v, want accepted with static_credential", d)
	}
}

// TestStaticCredentialsWrongPassword checks that a static username paired with
// the wrong password is not accepted.
func TestStaticCredentialsWrongPassword(t *testing.T) {
	a := newTestAuth(10, time.Hour, config.Credential{Username: "root", Password: "toor"})

	d := a.Authenticate("1.2.3.4", "root", "wrong")
	if d.Accepted {
		t.Errorf("should not accept wrong password for static credential")
	}
}

// TestRejectionBeforeThreshold checks that attempts below AcceptAfter are
// rejected with reason "rejected".
func TestRejectionBeforeThreshold(t *testing.T) {
	a := newTestAuth(3, time.Hour)

	for i := range 2 {
		d := a.Authenticate("1.2.3.4", "user", "pass")
		if d.Accepted {
			t.Fatalf("attempt %d should be rejected", i+1)
		}
		if d.Reason != "rejected" {
			t.Errorf("attempt %d reason = %q, want %q", i+1, d.Reason, "rejected")
		}
	}
}

// TestThresholdAcceptance checks that the AcceptAfter-th attempt from one IP
// is accepted with reason "threshold_reached".
func TestThresholdAcceptance(t *testing.T) {
	a := newTestAuth(3, time.Hour)

	for i := range 2 {
		d := a.Authenticate("1.2.3.4", "user", "pass")
		if d.Accepted {
			t.Fatalf("attempt %d should be rejected", i+1)
		}
	}

	d := a.Authenticate("1.2.3.4", "user", "pass")
	if !d.Accepted || d.Reason != "threshold_reached" {
		t.Errorf("attempt 3 got %+v, want accepted with threshold_reached", d)
	}
}

// TestPerIPIsolation checks that attempt counters are tracked per IP: failures
// from one IP must not move another IP closer to the threshold.
func TestPerIPIsolation(t *testing.T) {
	a := newTestAuth(3, time.Hour)

	// IP1 gets 2 failures.
	for range 2 {
		a.Authenticate("1.1.1.1", "user", "pass")
	}

	// IP2 should start at 0, not inherit IP1's count.
	d := a.Authenticate("2.2.2.2", "user", "pass")
	if d.Accepted {
		t.Error("IP2's first attempt should be rejected")
	}
}

// TestCredentialMemoryAcrossIPs checks that a credential accepted via the
// threshold on one IP is accepted immediately from a different IP.
func TestCredentialMemoryAcrossIPs(t *testing.T) {
	a := newTestAuth(2, time.Hour)

	// IP1 reaches threshold, credential is remembered.
	a.Authenticate("1.1.1.1", "user", "pass")
	d := a.Authenticate("1.1.1.1", "user", "pass")
	if !d.Accepted || d.Reason != "threshold_reached" {
		t.Fatalf("threshold not reached: %+v", d)
	}

	// IP2 should get in with the remembered credential.
	d = a.Authenticate("2.2.2.2", "user", "pass")
	if !d.Accepted || d.Reason != "remembered_credential" {
		t.Errorf("IP2 got %+v, want accepted with remembered_credential", d)
	}
}

// TestCredentialMemoryExpires checks that a remembered credential is no longer
// honoured once the injected clock has moved past CredentialTTLDuration.
func TestCredentialMemoryExpires(t *testing.T) {
	a := newTestAuth(2, time.Hour)
	now := time.Now()
	a.now = func() time.Time { return now }

	// Reach threshold to remember credential. Verify it was actually reached:
	// otherwise IP2's attempt below would be rejected simply for being below
	// the threshold, and the expiry assertion would pass vacuously.
	a.Authenticate("1.1.1.1", "user", "pass")
	if d := a.Authenticate("1.1.1.1", "user", "pass"); !d.Accepted {
		t.Fatalf("threshold not reached: %+v", d)
	}

	// Advance past TTL.
	a.now = func() time.Time { return now.Add(2 * time.Hour) }

	d := a.Authenticate("2.2.2.2", "user", "pass")
	if d.Accepted {
		t.Errorf("expired credential should not be accepted: %+v", d)
	}
}

// TestCounterResetsAfterAcceptance checks that an IP's attempt counter starts
// over after an acceptance, so a new credential is not let straight in.
func TestCounterResetsAfterAcceptance(t *testing.T) {
	a := newTestAuth(2, time.Hour)

	// Reach threshold.
	a.Authenticate("1.1.1.1", "user", "pass")
	d := a.Authenticate("1.1.1.1", "user", "pass")
	if !d.Accepted {
		t.Fatal("threshold not reached")
	}

	// With a different credential, counter should be reset.
	d = a.Authenticate("1.1.1.1", "other", "cred")
	if d.Accepted {
		t.Error("first attempt after reset should be rejected")
	}
}

// TestExpiredCredentialsSweep checks that expired entries are removed from
// rememberedCreds by a later Authenticate call rather than kept indefinitely.
func TestExpiredCredentialsSweep(t *testing.T) {
	a := newTestAuth(2, time.Hour)
	now := time.Now()
	a.now = func() time.Time { return now }

	// Create several remembered credentials by reaching the threshold.
	for i := range 5 {
		ip := fmt.Sprintf("10.0.0.%d", i)
		user := fmt.Sprintf("user%d", i)
		a.Authenticate(ip, user, "pass")
		a.Authenticate(ip, user, "pass")
	}
	if len(a.rememberedCreds) != 5 {
		t.Fatalf("expected 5 remembered creds, got %d", len(a.rememberedCreds))
	}

	// Advance past TTL so all are expired, then trigger sweep.
	a.now = func() time.Time { return now.Add(2 * time.Hour) }
	a.Authenticate("99.99.99.99", "trigger", "sweep")

	if len(a.rememberedCreds) != 0 {
		t.Errorf("expected 0 remembered creds after sweep, got %d", len(a.rememberedCreds))
	}
}

// TestConcurrentAccess calls Authenticate from many goroutines at once; run it
// with -race to detect unsynchronized state. All calls use one IP and one
// credential, so once the calls are serialized the first acceptAfter-1 must be
// rejected and the threshold must be reached at least once.
func TestConcurrentAccess(t *testing.T) {
	const (
		acceptAfter = 5
		attempts    = 100
	)
	a := newTestAuth(acceptAfter, time.Hour)

	var (
		wg       sync.WaitGroup
		accepted atomic.Int64
		rejected atomic.Int64
	)
	for range attempts {
		wg.Go(func() {
			if a.Authenticate("1.2.3.4", "user", "pass").Accepted {
				accepted.Add(1)
			} else {
				rejected.Add(1)
			}
		})
	}
	wg.Wait()

	if accepted.Load() == 0 {
		t.Errorf("no attempt accepted out of %d; threshold of %d should have been reached", attempts, acceptAfter)
	}
	if got := rejected.Load(); got < acceptAfter-1 {
		t.Errorf("rejected %d attempts, want at least %d before the threshold", got, acceptAfter-1)
	}
}