Compare commits
3 Commits
62de222488
...
664e79fce6
| Author | SHA1 | Date | |
|---|---|---|---|
|
664e79fce6
|
|||
|
c74313c195
|
|||
|
9783ae5865
|
@@ -20,7 +20,7 @@ import (
|
||||
"git.t-juice.club/torjus/oubliette/internal/web"
|
||||
)
|
||||
|
||||
const Version = "0.17.0"
|
||||
const Version = "0.18.0"
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
@@ -76,12 +76,13 @@ func run() error {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||
|
||||
m := metrics.New(Version)
|
||||
m.RegisterStoreCollector(store)
|
||||
instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors)
|
||||
m.RegisterStoreCollector(instrumentedStore)
|
||||
|
||||
srv, err := server.New(*cfg, store, logger, m)
|
||||
go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||
|
||||
srv, err := server.New(*cfg, instrumentedStore, logger, m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create server: %w", err)
|
||||
}
|
||||
@@ -95,7 +96,7 @@ func run() error {
|
||||
metricsHandler = m.Handler()
|
||||
}
|
||||
|
||||
webHandler, err := web.NewServer(store, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
|
||||
webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create web server: %w", err)
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -9,6 +9,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/oschwald/maxminddb-golang v1.13.1
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/prometheus/client_model v0.6.2
|
||||
golang.org/x/crypto v0.48.0
|
||||
modernc.org/sqlite v1.45.0
|
||||
)
|
||||
@@ -33,7 +34,6 @@ require (
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
|
||||
@@ -25,6 +25,8 @@ type Metrics struct {
|
||||
SessionDuration prometheus.Histogram
|
||||
ExecCommandsTotal prometheus.Counter
|
||||
BuildInfo *prometheus.GaugeVec
|
||||
StorageQueryDuration *prometheus.HistogramVec
|
||||
StorageQueryErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// New creates a new Metrics instance with all collectors registered.
|
||||
@@ -79,6 +81,15 @@ func New(version string) *Metrics {
|
||||
Name: "oubliette_build_info",
|
||||
Help: "Build information. Always 1.",
|
||||
}, []string{"version"}),
|
||||
StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "oubliette_storage_query_duration_seconds",
|
||||
Help: "Duration of storage query calls in seconds.",
|
||||
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||
}, []string{"method"}),
|
||||
StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_storage_query_errors_total",
|
||||
Help: "Total storage query errors.",
|
||||
}, []string{"method"}),
|
||||
}
|
||||
|
||||
reg.MustRegister(
|
||||
@@ -95,6 +106,8 @@ func New(version string) *Metrics {
|
||||
m.SessionDuration,
|
||||
m.ExecCommandsTotal,
|
||||
m.BuildInfo,
|
||||
m.StorageQueryDuration,
|
||||
m.StorageQueryErrors,
|
||||
)
|
||||
|
||||
m.BuildInfo.WithLabelValues(version).Set(1)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -216,22 +217,22 @@ func (s *roombaState) cmdHelp() commandResult {
|
||||
func (s *roombaState) cmdStatus() commandResult {
|
||||
var b strings.Builder
|
||||
b.WriteString("=== RoombaOS System Status ===\n")
|
||||
b.WriteString(fmt.Sprintf("Model: iRobot Roomba j7+\n"))
|
||||
b.WriteString("Model: iRobot Roomba j7+\n")
|
||||
b.WriteString(fmt.Sprintf("Status: %s\n", s.status))
|
||||
b.WriteString(fmt.Sprintf("Battery: %d%%\n", s.battery))
|
||||
b.WriteString(fmt.Sprintf("Dustbin: %d%% full\n", s.dustbin))
|
||||
b.WriteString(fmt.Sprintf("Side brush: OK (142 hrs)\n"))
|
||||
b.WriteString(fmt.Sprintf("Main brush: OK (98 hrs)\n"))
|
||||
b.WriteString("Side brush: OK (142 hrs)\n")
|
||||
b.WriteString("Main brush: OK (98 hrs)\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString(fmt.Sprintf("WiFi: Connected (SmartHome-5G)\n"))
|
||||
b.WriteString(fmt.Sprintf("Signal: -38 dBm\n"))
|
||||
b.WriteString(fmt.Sprintf("Alexa: Linked\n"))
|
||||
b.WriteString(fmt.Sprintf("Google Home: Linked\n"))
|
||||
b.WriteString(fmt.Sprintf("iRobot Home App: Connected\n"))
|
||||
b.WriteString("WiFi: Connected (SmartHome-5G)\n")
|
||||
b.WriteString("Signal: -38 dBm\n")
|
||||
b.WriteString("Alexa: Linked\n")
|
||||
b.WriteString("Google Home: Linked\n")
|
||||
b.WriteString("iRobot Home App: Connected\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString(fmt.Sprintf("Firmware: v4.3.7-stable\n"))
|
||||
b.WriteString(fmt.Sprintf("LIDAR: Operational\n"))
|
||||
b.WriteString(fmt.Sprintf("Clean Area Total: 12,847 sq ft (lifetime)"))
|
||||
b.WriteString("Firmware: v4.3.7-stable\n")
|
||||
b.WriteString("LIDAR: Operational\n")
|
||||
b.WriteString("Clean Area Total: 12,847 sq ft (lifetime)")
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
@@ -349,14 +350,7 @@ func (s *roombaState) scheduleList() commandResult {
|
||||
func (s *roombaState) scheduleAdd(day, t string) commandResult {
|
||||
day = capitalizeFirst(strings.ToLower(day))
|
||||
validDays := []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
|
||||
valid := false
|
||||
for _, d := range validDays {
|
||||
if d == day {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
if !slices.Contains(validDays, day) {
|
||||
return commandResult{output: fmt.Sprintf("Invalid day '%s'. Use a day of the week (e.g. Monday, Tuesday).", day)}
|
||||
}
|
||||
|
||||
@@ -459,7 +453,7 @@ func formatDuration(d time.Duration) string {
|
||||
minutes := int(d.Minutes()) % 60
|
||||
if hours >= 24 {
|
||||
days := hours / 24
|
||||
hours = hours % 24
|
||||
hours %= 24
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
}
|
||||
if hours > 0 {
|
||||
|
||||
217
internal/storage/instrumented.go
Normal file
217
internal/storage/instrumented.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// InstrumentedStore wraps a Store and records query duration and errors
|
||||
// as Prometheus metrics for each method call.
|
||||
type InstrumentedStore struct {
|
||||
store Store
|
||||
queryDuration *prometheus.HistogramVec
|
||||
queryErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store.
|
||||
func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore {
|
||||
return &InstrumentedStore{
|
||||
store: store,
|
||||
queryDuration: queryDuration,
|
||||
queryErrors: queryErrors,
|
||||
}
|
||||
}
|
||||
|
||||
func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) {
|
||||
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||
v, err := fn()
|
||||
timer.ObserveDuration()
|
||||
if err != nil {
|
||||
s.queryErrors.WithLabelValues(method).Inc()
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
func observeErr(s *InstrumentedStore, method string, fn func() error) error {
|
||||
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||
err := fn()
|
||||
timer.ObserveDuration()
|
||||
if err != nil {
|
||||
s.queryErrors.WithLabelValues(method).Inc()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
|
||||
return observeErr(s, "RecordLoginAttempt", func() error {
|
||||
return s.store.RecordLoginAttempt(ctx, username, password, ip, country)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
||||
return observe(s, "CreateSession", func() (string, error) {
|
||||
return s.store.CreateSession(ctx, ip, username, shellName, country)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
|
||||
return observeErr(s, "EndSession", func() error {
|
||||
return s.store.EndSession(ctx, sessionID, disconnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
|
||||
return observeErr(s, "UpdateHumanScore", func() error {
|
||||
return s.store.UpdateHumanScore(ctx, sessionID, score)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
||||
return observeErr(s, "SetExecCommand", func() error {
|
||||
return s.store.SetExecCommand(ctx, sessionID, command)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
||||
return observeErr(s, "AppendSessionLog", func() error {
|
||||
return s.store.AppendSessionLog(ctx, sessionID, input, output)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
|
||||
return observe(s, "DeleteRecordsBefore", func() (int64, error) {
|
||||
return s.store.DeleteRecordsBefore(ctx, cutoff)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
return observe(s, "GetDashboardStats", func() (*DashboardStats, error) {
|
||||
return s.store.GetDashboardStats(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopUsernames", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopUsernames(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopPasswords", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopPasswords(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopIPs", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopIPs(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopCountries", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopCountries(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopExecCommands(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||
return observe(s, "GetRecentSessions", func() ([]Session, error) {
|
||||
return s.store.GetRecentSessions(ctx, limit, activeOnly)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
return observe(s, "GetFilteredSessions", func() ([]Session, error) {
|
||||
return s.store.GetFilteredSessions(ctx, limit, activeOnly, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
||||
return observe(s, "GetSession", func() (*Session, error) {
|
||||
return s.store.GetSession(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
|
||||
return observe(s, "GetSessionLogs", func() ([]SessionLog, error) {
|
||||
return s.store.GetSessionLogs(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
|
||||
return observeErr(s, "AppendSessionEvents", func() error {
|
||||
return s.store.AppendSessionEvents(ctx, events)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
|
||||
return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) {
|
||||
return s.store.GetSessionEvents(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
return observe(s, "CloseActiveSessions", func() (int64, error) {
|
||||
return s.store.CloseActiveSessions(ctx, disconnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) {
|
||||
return s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) {
|
||||
return s.store.GetHourlyPattern(ctx, since, until)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
|
||||
return observe(s, "GetCountryStats", func() ([]CountryCount, error) {
|
||||
return s.store.GetCountryStats(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) {
|
||||
return s.store.GetFilteredDashboardStats(ctx, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopUsernames(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopPasswords(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopIPs(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopCountries(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) Close() error {
|
||||
return s.store.Close()
|
||||
}
|
||||
163
internal/storage/instrumented_test.go
Normal file
163
internal/storage/instrumented_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
)
|
||||
|
||||
func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
store := NewMemoryStore()
|
||||
return NewInstrumentedStore(store, dur, errs), dur, errs
|
||||
}
|
||||
|
||||
func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 {
|
||||
m := &dto.Metric{}
|
||||
h.WithLabelValues(method).(prometheus.Histogram).Write(m)
|
||||
return m.GetHistogram().GetSampleCount()
|
||||
}
|
||||
|
||||
func getCounterValue(c *prometheus.CounterVec, method string) float64 {
|
||||
m := &dto.Metric{}
|
||||
c.WithLabelValues(method).Write(m)
|
||||
return m.GetCounter().GetValue()
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreDelegation(t *testing.T) {
|
||||
s, dur, _ := newTestInstrumented()
|
||||
ctx := context.Background()
|
||||
|
||||
// RecordLoginAttempt should delegate and record duration.
|
||||
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
|
||||
// CreateSession should delegate and return a valid ID.
|
||||
id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Fatal("CreateSession returned empty ID")
|
||||
}
|
||||
if c := getHistogramCount(dur, "CreateSession"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
|
||||
// GetDashboardStats should delegate.
|
||||
stats, err := s.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetDashboardStats: %v", err)
|
||||
}
|
||||
if stats == nil {
|
||||
t.Fatal("GetDashboardStats returned nil")
|
||||
}
|
||||
if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreErrorCounting(t *testing.T) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test_ec_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test_ec_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
es := &errorStore{}
|
||||
s := NewInstrumentedStore(es, dur, errs)
|
||||
ctx := context.Background()
|
||||
|
||||
// Error should be counted.
|
||||
err := s.EndSession(ctx, "nonexistent", time.Now())
|
||||
if !errors.Is(err, errFake) {
|
||||
t.Fatalf("expected errFake, got %v", err)
|
||||
}
|
||||
if c := getHistogramCount(dur, "EndSession"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
if c := getCounterValue(errs, "EndSession"); c != 1 {
|
||||
t.Fatalf("expected error count 1, got %f", c)
|
||||
}
|
||||
|
||||
// Successful call should not increment error counter.
|
||||
s2, _, errs2 := newTestInstrumented()
|
||||
err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 {
|
||||
t.Fatalf("expected error count 0, got %f", c)
|
||||
}
|
||||
}
|
||||
|
||||
// errorStore is a Store that returns errors for all methods.
|
||||
type errorStore struct {
|
||||
MemoryStore
|
||||
}
|
||||
|
||||
var errFake = errors.New("fake error")
|
||||
|
||||
func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error {
|
||||
return errFake
|
||||
}
|
||||
|
||||
func (s *errorStore) EndSession(context.Context, string, time.Time) error {
|
||||
return errFake
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreObserveErr(t *testing.T) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test2_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test2_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
es := &errorStore{}
|
||||
s := NewInstrumentedStore(es, dur, errs)
|
||||
ctx := context.Background()
|
||||
|
||||
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if !errors.Is(err, errFake) {
|
||||
t.Fatalf("expected errFake, got %v", err)
|
||||
}
|
||||
if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected error count 1, got %f", c)
|
||||
}
|
||||
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreClose(t *testing.T) {
|
||||
s, _, _ := newTestInstrumented()
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
}
|
||||
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
CREATE INDEX idx_login_attempts_username ON login_attempts(username);
|
||||
CREATE INDEX idx_login_attempts_password ON login_attempts(password);
|
||||
CREATE INDEX idx_sessions_disconnected_at ON sessions(disconnected_at);
|
||||
@@ -25,8 +25,8 @@ func TestMigrateCreatesTablesAndVersion(t *testing.T) {
|
||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||
t.Fatalf("query version: %v", err)
|
||||
}
|
||||
if version != 4 {
|
||||
t.Errorf("version = %d, want 4", version)
|
||||
if version != 5 {
|
||||
t.Errorf("version = %d, want 5", version)
|
||||
}
|
||||
|
||||
// Verify tables exist by inserting into them.
|
||||
@@ -64,8 +64,8 @@ func TestMigrateIdempotent(t *testing.T) {
|
||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||
t.Fatalf("query version: %v", err)
|
||||
}
|
||||
if version != 4 {
|
||||
t.Errorf("version = %d after double migrate, want 4", version)
|
||||
if version != 5 {
|
||||
t.Errorf("version = %d after double migrate, want 5", version)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
@@ -10,6 +11,13 @@ import (
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// dbContext returns a context detached from the HTTP request lifecycle with a
|
||||
// 30-second timeout. This prevents HTMX polling from canceling in-flight DB
|
||||
// queries when the browser aborts the previous XHR.
|
||||
func dbContext(r *http.Request) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
|
||||
}
|
||||
|
||||
type dashboardData struct {
|
||||
Stats *storage.DashboardStats
|
||||
TopUsernames []storage.TopEntry
|
||||
@@ -22,7 +30,8 @@ type dashboardData struct {
|
||||
}
|
||||
|
||||
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
@@ -98,7 +107,10 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||
stats, err := s.store.GetDashboardStats(r.Context())
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get dashboard stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -112,7 +124,10 @@ func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Request) {
|
||||
sessions, err := s.store.GetRecentSessions(r.Context(), 50, true)
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
sessions, err := s.store.GetRecentSessions(ctx, 50, true)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get active sessions", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -126,8 +141,11 @@ func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
f := parseDashboardFilter(r)
|
||||
sessions, err := s.store.GetFilteredSessions(r.Context(), 50, false, f)
|
||||
sessions, err := s.store.GetFilteredSessions(ctx, 50, false, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered sessions", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -147,7 +165,8 @@ type sessionDetailData struct {
|
||||
}
|
||||
|
||||
func (s *Server) handleSessionDetail(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
sessionID := r.PathValue("id")
|
||||
|
||||
session, err := s.store.GetSession(ctx, sessionID)
|
||||
@@ -246,7 +265,10 @@ func (s *Server) handleAPIAttemptsOverTime(w http.ResponseWriter, r *http.Reques
|
||||
since := parseDateParam(r, "since")
|
||||
until := parseDateParam(r, "until")
|
||||
|
||||
points, err := s.store.GetAttemptsOverTime(r.Context(), days, since, until)
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
points, err := s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get attempts over time", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -277,10 +299,13 @@ type apiHourlyPatternResponse struct {
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIHourlyPattern(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
since := parseDateParam(r, "since")
|
||||
until := parseDateParam(r, "until")
|
||||
|
||||
counts, err := s.store.GetHourlyPattern(r.Context(), since, until)
|
||||
counts, err := s.store.GetHourlyPattern(ctx, since, until)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get hourly pattern", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -308,7 +333,10 @@ type apiCountryStatsResponse struct {
|
||||
}
|
||||
|
||||
func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
|
||||
counts, err := s.store.GetCountryStats(r.Context())
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
counts, err := s.store.GetCountryStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get country stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -327,7 +355,8 @@ func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
f := parseDashboardFilter(r)
|
||||
|
||||
stats, err := s.store.GetFilteredDashboardStats(ctx, f)
|
||||
@@ -380,7 +409,8 @@ func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.R
|
||||
}
|
||||
|
||||
func (s *Server) handleAPISessionEvents(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
sessionID := r.PathValue("id")
|
||||
|
||||
events, err := s.store.GetSessionEvents(ctx, sessionID)
|
||||
|
||||
@@ -54,6 +54,30 @@ func newSeededTestServer(t *testing.T) *Server {
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestDbContextNotCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
dbCtx, dbCancel := dbContext(req)
|
||||
defer dbCancel()
|
||||
|
||||
// Cancel the original request context.
|
||||
cancel()
|
||||
|
||||
// The DB context should still be usable.
|
||||
select {
|
||||
case <-dbCtx.Done():
|
||||
t.Fatal("dbContext should not be canceled when request context is canceled")
|
||||
default:
|
||||
}
|
||||
|
||||
// Verify the DB context has a deadline (from the timeout).
|
||||
if _, ok := dbCtx.Deadline(); !ok {
|
||||
t.Error("dbContext should have a deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardHandler(t *testing.T) {
|
||||
t.Run("empty store", func(t *testing.T) {
|
||||
srv := newTestServer(t)
|
||||
|
||||
Reference in New Issue
Block a user