From 664e79fce6322e651e8a17dfcdb36e76064045ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torjus=20H=C3=A5kestad?= Date: Sat, 7 Mar 2026 22:29:51 +0100 Subject: [PATCH] feat: add Prometheus metrics for Store queries Add InstrumentedStore decorator that wraps any Store and records per-method query duration histograms and error counters. Wired into main.go so all storage consumers get automatic observability. Bump version to 0.18.0. Co-Authored-By: Claude Opus 4.6 --- cmd/oubliette/main.go | 13 +- go.mod | 2 +- internal/metrics/metrics.go | 13 ++ internal/storage/instrumented.go | 217 ++++++++++++++++++++++++++ internal/storage/instrumented_test.go | 163 +++++++++++++++++++ 5 files changed, 401 insertions(+), 7 deletions(-) create mode 100644 internal/storage/instrumented.go create mode 100644 internal/storage/instrumented_test.go diff --git a/cmd/oubliette/main.go b/cmd/oubliette/main.go index bf191fc..0cb87d3 100644 --- a/cmd/oubliette/main.go +++ b/cmd/oubliette/main.go @@ -20,7 +20,7 @@ import ( "git.t-juice.club/torjus/oubliette/internal/web" ) -const Version = "0.17.1" +const Version = "0.18.0" func main() { if err := run(); err != nil { @@ -76,12 +76,13 @@ func run() error { ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() - go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger) - m := metrics.New(Version) - m.RegisterStoreCollector(store) + instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors) + m.RegisterStoreCollector(instrumentedStore) - srv, err := server.New(*cfg, store, logger, m) + go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger) + + srv, err := server.New(*cfg, instrumentedStore, logger, m) if err != nil { return fmt.Errorf("create server: %w", err) } @@ -95,7 +96,7 @@ func run() error { metricsHandler = m.Handler() } - webHandler, err := web.NewServer(store, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken) + webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken) if err != nil { return fmt.Errorf("create web server: %w", err) } diff --git a/go.mod b/go.mod index fd8d90b..fcdf608 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/google/uuid v1.6.0 github.com/oschwald/maxminddb-golang v1.13.1 github.com/prometheus/client_golang v1.23.2 + github.com/prometheus/client_model v0.6.2 golang.org/x/crypto v0.48.0 modernc.org/sqlite v1.45.0 ) @@ -33,7 +34,6 @@ require ( github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect - github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index e5dd0b5..16cda99 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -25,6 +25,8 @@ type Metrics struct { SessionDuration prometheus.Histogram ExecCommandsTotal prometheus.Counter BuildInfo *prometheus.GaugeVec + StorageQueryDuration *prometheus.HistogramVec + StorageQueryErrors *prometheus.CounterVec } // New creates a new Metrics instance with all collectors registered. @@ -79,6 +81,15 @@ func New(version string) *Metrics { Name: "oubliette_build_info", Help: "Build information. Always 1.", }, []string{"version"}), + StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: "oubliette_storage_query_duration_seconds", + Help: "Duration of storage query calls in seconds.", + Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}, + }, []string{"method"}), + StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "oubliette_storage_query_errors_total", + Help: "Total storage query errors.", + }, []string{"method"}), } reg.MustRegister( @@ -95,6 +106,8 @@ func New(version string) *Metrics { m.SessionDuration, m.ExecCommandsTotal, m.BuildInfo, + m.StorageQueryDuration, + m.StorageQueryErrors, ) m.BuildInfo.WithLabelValues(version).Set(1) diff --git a/internal/storage/instrumented.go b/internal/storage/instrumented.go new file mode 100644 index 0000000..a13cbac --- /dev/null +++ b/internal/storage/instrumented.go @@ -0,0 +1,217 @@ +package storage + +import ( + "context" + "time" + + "github.com/prometheus/client_golang/prometheus" +) + +// InstrumentedStore wraps a Store and records query duration and errors +// as Prometheus metrics for each method call. +type InstrumentedStore struct { + store Store + queryDuration *prometheus.HistogramVec + queryErrors *prometheus.CounterVec +} + +// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store. +func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore { + return &InstrumentedStore{ + store: store, + queryDuration: queryDuration, + queryErrors: queryErrors, + } +} + +func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) { + timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method)) + v, err := fn() + timer.ObserveDuration() + if err != nil { + s.queryErrors.WithLabelValues(method).Inc() + } + return v, err +} + +func observeErr(s *InstrumentedStore, method string, fn func() error) error { + timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method)) + err := fn() + timer.ObserveDuration() + if err != nil { + s.queryErrors.WithLabelValues(method).Inc() + } + return err +} + +func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error { + return observeErr(s, "RecordLoginAttempt", func() error { + return s.store.RecordLoginAttempt(ctx, username, password, ip, country) + }) +} + +func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) { + return observe(s, "CreateSession", func() (string, error) { + return s.store.CreateSession(ctx, ip, username, shellName, country) + }) +} + +func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error { + return observeErr(s, "EndSession", func() error { + return s.store.EndSession(ctx, sessionID, disconnectedAt) + }) +} + +func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error { + return observeErr(s, "UpdateHumanScore", func() error { + return s.store.UpdateHumanScore(ctx, sessionID, score) + }) +} + +func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error { + return observeErr(s, "SetExecCommand", func() error { + return s.store.SetExecCommand(ctx, sessionID, command) + }) +} + +func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error { + return observeErr(s, "AppendSessionLog", func() error { + return s.store.AppendSessionLog(ctx, sessionID, input, output) + }) +} + +func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) { + return observe(s, "DeleteRecordsBefore", func() (int64, error) { + return s.store.DeleteRecordsBefore(ctx, cutoff) + }) +} + +func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { + return observe(s, "GetDashboardStats", func() (*DashboardStats, error) { + return s.store.GetDashboardStats(ctx) + }) +} + +func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) { + return observe(s, "GetTopUsernames", func() ([]TopEntry, error) { + return s.store.GetTopUsernames(ctx, limit) + }) +} + +func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) { + return observe(s, "GetTopPasswords", func() ([]TopEntry, error) { + return s.store.GetTopPasswords(ctx, limit) + }) +} + +func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) { + return observe(s, "GetTopIPs", func() ([]TopEntry, error) { + return s.store.GetTopIPs(ctx, limit) + }) +} + +func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) { + return observe(s, "GetTopCountries", func() ([]TopEntry, error) { + return s.store.GetTopCountries(ctx, limit) + }) +} + +func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) { + return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) { + return s.store.GetTopExecCommands(ctx, limit) + }) +} + +func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) { + return observe(s, "GetRecentSessions", func() ([]Session, error) { + return s.store.GetRecentSessions(ctx, limit, activeOnly) + }) +} + +func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) { + return observe(s, "GetFilteredSessions", func() ([]Session, error) { + return s.store.GetFilteredSessions(ctx, limit, activeOnly, f) + }) +} + +func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) { + return observe(s, "GetSession", func() (*Session, error) { + return s.store.GetSession(ctx, sessionID) + }) +} + +func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) { + return observe(s, "GetSessionLogs", func() ([]SessionLog, error) { + return s.store.GetSessionLogs(ctx, sessionID) + }) +} + +func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error { + return observeErr(s, "AppendSessionEvents", func() error { + return s.store.AppendSessionEvents(ctx, events) + }) +} + +func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) { + return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) { + return s.store.GetSessionEvents(ctx, sessionID) + }) +} + +func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) { + return observe(s, "CloseActiveSessions", func() (int64, error) { + return s.store.CloseActiveSessions(ctx, disconnectedAt) + }) +} + +func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) { + return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) { + return s.store.GetAttemptsOverTime(ctx, days, since, until) + }) +} + +func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) { + return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) { + return s.store.GetHourlyPattern(ctx, since, until) + }) +} + +func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) { + return observe(s, "GetCountryStats", func() ([]CountryCount, error) { + return s.store.GetCountryStats(ctx) + }) +} + +func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) { + return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) { + return s.store.GetFilteredDashboardStats(ctx, f) + }) +} + +func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) { + return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) { + return s.store.GetFilteredTopUsernames(ctx, limit, f) + }) +} + +func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) { + return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) { + return s.store.GetFilteredTopPasswords(ctx, limit, f) + }) +} + +func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) { + return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) { + return s.store.GetFilteredTopIPs(ctx, limit, f) + }) +} + +func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) { + return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) { + return s.store.GetFilteredTopCountries(ctx, limit, f) + }) +} + +func (s *InstrumentedStore) Close() error { + return s.store.Close() +} diff --git a/internal/storage/instrumented_test.go b/internal/storage/instrumented_test.go new file mode 100644 index 0000000..f883654 --- /dev/null +++ b/internal/storage/instrumented_test.go @@ -0,0 +1,163 @@ +package storage + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" +) + +func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) { + dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: "test_query_duration_seconds", + Help: "test", + Buckets: []float64{0.001, 0.01, 0.1, 1}, + }, []string{"method"}) + errs := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "test_query_errors_total", + Help: "test", + }, []string{"method"}) + + store := NewMemoryStore() + return NewInstrumentedStore(store, dur, errs), dur, errs +} + +func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 { + m := &dto.Metric{} + h.WithLabelValues(method).(prometheus.Histogram).Write(m) + return m.GetHistogram().GetSampleCount() +} + +func getCounterValue(c *prometheus.CounterVec, method string) float64 { + m := &dto.Metric{} + c.WithLabelValues(method).Write(m) + return m.GetCounter().GetValue() +} + +func TestInstrumentedStoreDelegation(t *testing.T) { + s, dur, _ := newTestInstrumented() + ctx := context.Background() + + // RecordLoginAttempt should delegate and record duration. + err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US") + if err != nil { + t.Fatalf("RecordLoginAttempt: %v", err) + } + if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 { + t.Fatalf("expected 1 observation, got %d", c) + } + + // CreateSession should delegate and return a valid ID. + id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if id == "" { + t.Fatal("CreateSession returned empty ID") + } + if c := getHistogramCount(dur, "CreateSession"); c != 1 { + t.Fatalf("expected 1 observation, got %d", c) + } + + // GetDashboardStats should delegate. + stats, err := s.GetDashboardStats(ctx) + if err != nil { + t.Fatalf("GetDashboardStats: %v", err) + } + if stats == nil { + t.Fatal("GetDashboardStats returned nil") + } + if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 { + t.Fatalf("expected 1 observation, got %d", c) + } +} + +func TestInstrumentedStoreErrorCounting(t *testing.T) { + dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: "test_ec_query_duration_seconds", + Help: "test", + Buckets: []float64{0.001, 0.01, 0.1, 1}, + }, []string{"method"}) + errs := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "test_ec_query_errors_total", + Help: "test", + }, []string{"method"}) + + es := &errorStore{} + s := NewInstrumentedStore(es, dur, errs) + ctx := context.Background() + + // Error should be counted. + err := s.EndSession(ctx, "nonexistent", time.Now()) + if !errors.Is(err, errFake) { + t.Fatalf("expected errFake, got %v", err) + } + if c := getHistogramCount(dur, "EndSession"); c != 1 { + t.Fatalf("expected 1 observation, got %d", c) + } + if c := getCounterValue(errs, "EndSession"); c != 1 { + t.Fatalf("expected error count 1, got %f", c) + } + + // Successful call should not increment error counter. + s2, _, errs2 := newTestInstrumented() + err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US") + if err != nil { + t.Fatalf("RecordLoginAttempt: %v", err) + } + if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 { + t.Fatalf("expected error count 0, got %f", c) + } +} + +// errorStore is a Store that returns errors for all methods. +type errorStore struct { + MemoryStore +} + +var errFake = errors.New("fake error") + +func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error { + return errFake +} + +func (s *errorStore) EndSession(context.Context, string, time.Time) error { + return errFake +} + +func TestInstrumentedStoreObserveErr(t *testing.T) { + dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: "test2_query_duration_seconds", + Help: "test", + Buckets: []float64{0.001, 0.01, 0.1, 1}, + }, []string{"method"}) + errs := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "test2_query_errors_total", + Help: "test", + }, []string{"method"}) + + es := &errorStore{} + s := NewInstrumentedStore(es, dur, errs) + ctx := context.Background() + + err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US") + if !errors.Is(err, errFake) { + t.Fatalf("expected errFake, got %v", err) + } + if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 { + t.Fatalf("expected error count 1, got %f", c) + } + if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 { + t.Fatalf("expected 1 observation, got %d", c) + } +} + +func TestInstrumentedStoreClose(t *testing.T) { + s, _, _ := newTestInstrumented() + if err := s.Close(); err != nil { + t.Fatalf("Close: %v", err) + } +}