package storage import ( "context" "errors" "testing" "time" "github.com/prometheus/client_golang/prometheus" dto "github.com/prometheus/client_model/go" ) func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) { dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{ Name: "test_query_duration_seconds", Help: "test", Buckets: []float64{0.001, 0.01, 0.1, 1}, }, []string{"method"}) errs := prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "test_query_errors_total", Help: "test", }, []string{"method"}) store := NewMemoryStore() return NewInstrumentedStore(store, dur, errs), dur, errs } func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 { m := &dto.Metric{} h.WithLabelValues(method).(prometheus.Histogram).Write(m) return m.GetHistogram().GetSampleCount() } func getCounterValue(c *prometheus.CounterVec, method string) float64 { m := &dto.Metric{} c.WithLabelValues(method).Write(m) return m.GetCounter().GetValue() } func TestInstrumentedStoreDelegation(t *testing.T) { s, dur, _ := newTestInstrumented() ctx := context.Background() // RecordLoginAttempt should delegate and record duration. err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US") if err != nil { t.Fatalf("RecordLoginAttempt: %v", err) } if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 { t.Fatalf("expected 1 observation, got %d", c) } // CreateSession should delegate and return a valid ID. id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US") if err != nil { t.Fatalf("CreateSession: %v", err) } if id == "" { t.Fatal("CreateSession returned empty ID") } if c := getHistogramCount(dur, "CreateSession"); c != 1 { t.Fatalf("expected 1 observation, got %d", c) } // GetDashboardStats should delegate. stats, err := s.GetDashboardStats(ctx) if err != nil { t.Fatalf("GetDashboardStats: %v", err) } if stats == nil { t.Fatal("GetDashboardStats returned nil") } if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 { t.Fatalf("expected 1 observation, got %d", c) } } func TestInstrumentedStoreErrorCounting(t *testing.T) { dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{ Name: "test_ec_query_duration_seconds", Help: "test", Buckets: []float64{0.001, 0.01, 0.1, 1}, }, []string{"method"}) errs := prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "test_ec_query_errors_total", Help: "test", }, []string{"method"}) es := &errorStore{} s := NewInstrumentedStore(es, dur, errs) ctx := context.Background() // Error should be counted. err := s.EndSession(ctx, "nonexistent", time.Now()) if !errors.Is(err, errFake) { t.Fatalf("expected errFake, got %v", err) } if c := getHistogramCount(dur, "EndSession"); c != 1 { t.Fatalf("expected 1 observation, got %d", c) } if c := getCounterValue(errs, "EndSession"); c != 1 { t.Fatalf("expected error count 1, got %f", c) } // Successful call should not increment error counter. s2, _, errs2 := newTestInstrumented() err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US") if err != nil { t.Fatalf("RecordLoginAttempt: %v", err) } if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 { t.Fatalf("expected error count 0, got %f", c) } } // errorStore is a Store that returns errors for all methods. type errorStore struct { MemoryStore } var errFake = errors.New("fake error") func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error { return errFake } func (s *errorStore) EndSession(context.Context, string, time.Time) error { return errFake } func TestInstrumentedStoreObserveErr(t *testing.T) { dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{ Name: "test2_query_duration_seconds", Help: "test", Buckets: []float64{0.001, 0.01, 0.1, 1}, }, []string{"method"}) errs := prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "test2_query_errors_total", Help: "test", }, []string{"method"}) es := &errorStore{} s := NewInstrumentedStore(es, dur, errs) ctx := context.Background() err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US") if !errors.Is(err, errFake) { t.Fatalf("expected errFake, got %v", err) } if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 { t.Fatalf("expected error count 1, got %f", c) } if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 { t.Fatalf("expected 1 observation, got %d", c) } } func TestInstrumentedStoreClose(t *testing.T) { s, _, _ := newTestInstrumented() if err := s.Close(); err != nil { t.Fatalf("Close: %v", err) } }