package storage import ( "context" "time" "github.com/prometheus/client_golang/prometheus" ) // InstrumentedStore wraps a Store and records query duration and errors // as Prometheus metrics for each method call. type InstrumentedStore struct { store Store queryDuration *prometheus.HistogramVec queryErrors *prometheus.CounterVec } // NewInstrumentedStore returns a new InstrumentedStore wrapping the given store. func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore { return &InstrumentedStore{ store: store, queryDuration: queryDuration, queryErrors: queryErrors, } } func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) { timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method)) v, err := fn() timer.ObserveDuration() if err != nil { s.queryErrors.WithLabelValues(method).Inc() } return v, err } func observeErr(s *InstrumentedStore, method string, fn func() error) error { timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method)) err := fn() timer.ObserveDuration() if err != nil { s.queryErrors.WithLabelValues(method).Inc() } return err } func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error { return observeErr(s, "RecordLoginAttempt", func() error { return s.store.RecordLoginAttempt(ctx, username, password, ip, country) }) } func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) { return observe(s, "CreateSession", func() (string, error) { return s.store.CreateSession(ctx, ip, username, shellName, country) }) } func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error { return observeErr(s, "EndSession", func() error { return s.store.EndSession(ctx, sessionID, disconnectedAt) }) } func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error { return observeErr(s, "UpdateHumanScore", func() error { return s.store.UpdateHumanScore(ctx, sessionID, score) }) } func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error { return observeErr(s, "SetExecCommand", func() error { return s.store.SetExecCommand(ctx, sessionID, command) }) } func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error { return observeErr(s, "AppendSessionLog", func() error { return s.store.AppendSessionLog(ctx, sessionID, input, output) }) } func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) { return observe(s, "DeleteRecordsBefore", func() (int64, error) { return s.store.DeleteRecordsBefore(ctx, cutoff) }) } func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { return observe(s, "GetDashboardStats", func() (*DashboardStats, error) { return s.store.GetDashboardStats(ctx) }) } func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) { return observe(s, "GetTopUsernames", func() ([]TopEntry, error) { return s.store.GetTopUsernames(ctx, limit) }) } func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) { return observe(s, "GetTopPasswords", func() ([]TopEntry, error) { return s.store.GetTopPasswords(ctx, limit) }) } func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) { return observe(s, "GetTopIPs", func() ([]TopEntry, error) { return s.store.GetTopIPs(ctx, limit) }) } func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) { return observe(s, "GetTopCountries", func() ([]TopEntry, error) { return s.store.GetTopCountries(ctx, limit) }) } func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) { return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) { return s.store.GetTopExecCommands(ctx, limit) }) } func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) { return observe(s, "GetRecentSessions", func() ([]Session, error) { return s.store.GetRecentSessions(ctx, limit, activeOnly) }) } func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) { return observe(s, "GetFilteredSessions", func() ([]Session, error) { return s.store.GetFilteredSessions(ctx, limit, activeOnly, f) }) } func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) { return observe(s, "GetSession", func() (*Session, error) { return s.store.GetSession(ctx, sessionID) }) } func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) { return observe(s, "GetSessionLogs", func() ([]SessionLog, error) { return s.store.GetSessionLogs(ctx, sessionID) }) } func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error { return observeErr(s, "AppendSessionEvents", func() error { return s.store.AppendSessionEvents(ctx, events) }) } func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) { return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) { return s.store.GetSessionEvents(ctx, sessionID) }) } func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) { return observe(s, "CloseActiveSessions", func() (int64, error) { return s.store.CloseActiveSessions(ctx, disconnectedAt) }) } func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) { return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) { return s.store.GetAttemptsOverTime(ctx, days, since, until) }) } func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) { return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) { return s.store.GetHourlyPattern(ctx, since, until) }) } func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) { return observe(s, "GetCountryStats", func() ([]CountryCount, error) { return s.store.GetCountryStats(ctx) }) } func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) { return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) { return s.store.GetFilteredDashboardStats(ctx, f) }) } func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) { return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) { return s.store.GetFilteredTopUsernames(ctx, limit, f) }) } func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) { return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) { return s.store.GetFilteredTopPasswords(ctx, limit, f) }) } func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) { return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) { return s.store.GetFilteredTopIPs(ctx, limit, f) }) } func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) { return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) { return s.store.GetFilteredTopCountries(ctx, limit, f) }) } func (s *InstrumentedStore) Close() error { return s.store.Close() }