diff --git a/README.md b/README.md index d618a5e..ad42b89 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ Key settings: - `web.listen_addr` — web dashboard listen address (default `:8080`) - Session detail pages at `/sessions/{id}` include terminal replay via xterm.js - `web.metrics_enabled` — expose Prometheus metrics at `/metrics` (default `true`) +- `web.metrics_token` — bearer token to protect `/metrics`; empty means no auth (default empty) - `detection.enabled` — enable human detection scoring (default `false`) - `detection.threshold` — score threshold (0.0–1.0) for flagging sessions (default `0.6`) - `detection.update_interval` — how often to recompute scores (default `5s`) diff --git a/cmd/oubliette/main.go b/cmd/oubliette/main.go index 71b7fde..79d85ed 100644 --- a/cmd/oubliette/main.go +++ b/cmd/oubliette/main.go @@ -79,6 +79,7 @@ func run() error { go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger) m := metrics.New(Version) + m.RegisterStoreCollector(store) srv, err := server.New(*cfg, store, logger, m) if err != nil { @@ -94,7 +95,7 @@ func run() error { metricsHandler = m.Handler() } - webHandler, err := web.NewServer(store, logger.With("component", "web"), metricsHandler) + webHandler, err := web.NewServer(store, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken) if err != nil { return fmt.Errorf("create web server: %w", err) } diff --git a/internal/config/config.go b/internal/config/config.go index 97ed465..751e53f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,9 +21,10 @@ type Config struct { } type WebConfig struct { - Enabled bool `toml:"enabled"` + Enabled bool `toml:"enabled"` ListenAddr string `toml:"listen_addr"` - MetricsEnabled *bool `toml:"metrics_enabled"` + MetricsEnabled *bool `toml:"metrics_enabled"` + MetricsToken string `toml:"metrics_token"` } type ShellConfig struct { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 01a5813..4abeae2 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -282,6 +282,22 @@ password = "toor" } } +func TestLoadMetricsToken(t *testing.T) { + content := ` +[web] +enabled = true +metrics_token = "my-secret-token" +` + path := writeTemp(t, content) + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Web.MetricsToken != "my-secret-token" { + t.Errorf("metrics_token = %q, want %q", cfg.Web.MetricsToken, "my-secret-token") + } +} + func TestLoadMissingFile(t *testing.T) { _, err := Load("/nonexistent/path/config.toml") if err == nil { diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index ea60a9f..870a920 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -1,8 +1,10 @@ package metrics import ( + "context" "net/http" + "git.t-juice.club/torjus/oubliette/internal/storage" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -12,13 +14,16 @@ import ( type Metrics struct { registry *prometheus.Registry - SSHConnectionsTotal *prometheus.CounterVec - SSHConnectionsActive prometheus.Gauge - AuthAttemptsTotal *prometheus.CounterVec - SessionsTotal *prometheus.CounterVec - SessionsActive prometheus.Gauge - SessionDuration prometheus.Histogram - BuildInfo *prometheus.GaugeVec + SSHConnectionsTotal *prometheus.CounterVec + SSHConnectionsActive prometheus.Gauge + AuthAttemptsTotal *prometheus.CounterVec + AuthAttemptsByCountry *prometheus.CounterVec + CommandsExecuted *prometheus.CounterVec + HumanScore prometheus.Histogram + SessionsTotal *prometheus.CounterVec + SessionsActive prometheus.Gauge + SessionDuration prometheus.Histogram + BuildInfo *prometheus.GaugeVec } // New creates a new Metrics instance with all collectors registered. @@ -39,6 +44,19 @@ func New(version string) *Metrics { Name: "oubliette_auth_attempts_total", Help: "Total authentication attempts.", }, []string{"result", "reason"}), + AuthAttemptsByCountry: prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "oubliette_auth_attempts_by_country_total", + Help: "Total authentication attempts by country.", + }, []string{"country"}), + CommandsExecuted: prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "oubliette_commands_executed_total", + Help: "Total commands executed in shells.", + }, []string{"shell"}), + HumanScore: prometheus.NewHistogram(prometheus.HistogramOpts{ + Name: "oubliette_human_score", + Help: "Distribution of final human detection scores.", + Buckets: prometheus.LinearBuckets(0, 0.1, 11), // 0.0, 0.1, ..., 1.0 + }), SessionsTotal: prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "oubliette_sessions_total", Help: "Total sessions created.", @@ -64,6 +82,9 @@ func New(version string) *Metrics { m.SSHConnectionsTotal, m.SSHConnectionsActive, m.AuthAttemptsTotal, + m.AuthAttemptsByCountry, + m.CommandsExecuted, + m.HumanScore, m.SessionsTotal, m.SessionsActive, m.SessionDuration, @@ -80,14 +101,59 @@ func New(version string) *Metrics { m.AuthAttemptsTotal.WithLabelValues("accepted", reason) m.AuthAttemptsTotal.WithLabelValues("rejected", reason) } - for _, shell := range []string{"bash", "fridge", "banking", "adventure"} { - m.SessionsTotal.WithLabelValues(shell) + for _, sh := range []string{"bash", "fridge", "banking", "adventure", "cisco"} { + m.SessionsTotal.WithLabelValues(sh) + m.CommandsExecuted.WithLabelValues(sh) } return m } +// RegisterStoreCollector registers a collector that queries storage stats on each scrape. +func (m *Metrics) RegisterStoreCollector(store storage.Store) { + m.registry.MustRegister(&storeCollector{store: store}) +} + // Handler returns an http.Handler that serves Prometheus metrics. func (m *Metrics) Handler() http.Handler { return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{}) } + +// storeCollector implements prometheus.Collector, querying storage on each scrape. +type storeCollector struct { + store storage.Store +} + +var ( + storageLoginAttemptsDesc = prometheus.NewDesc( + "oubliette_storage_login_attempts_total", + "Total login attempts in storage.", + nil, nil, + ) + storageUniqueIPsDesc = prometheus.NewDesc( + "oubliette_storage_unique_ips", + "Unique IPs in storage.", + nil, nil, + ) + storageSessionsDesc = prometheus.NewDesc( + "oubliette_storage_sessions_total", + "Total sessions in storage.", + nil, nil, + ) +) + +func (c *storeCollector) Describe(ch chan<- *prometheus.Desc) { + ch <- storageLoginAttemptsDesc + ch <- storageUniqueIPsDesc + ch <- storageSessionsDesc +} + +func (c *storeCollector) Collect(ch chan<- prometheus.Metric) { + stats, err := c.store.GetDashboardStats(context.Background()) + if err != nil { + return + } + ch <- prometheus.MustNewConstMetric(storageLoginAttemptsDesc, prometheus.GaugeValue, float64(stats.TotalAttempts)) + ch <- prometheus.MustNewConstMetric(storageUniqueIPsDesc, prometheus.GaugeValue, float64(stats.UniqueIPs)) + ch <- prometheus.MustNewConstMetric(storageSessionsDesc, prometheus.GaugeValue, float64(stats.TotalSessions)) +} diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index 7b0c66a..15cb6ea 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -1,11 +1,14 @@ package metrics import ( + "context" "io" "net/http" "net/http/httptest" "strings" "testing" + + "git.t-juice.club/torjus/oubliette/internal/storage" ) func TestNew(t *testing.T) { @@ -21,10 +24,12 @@ func TestNew(t *testing.T) { "oubliette_ssh_connections_total": false, "oubliette_ssh_connections_active": false, "oubliette_auth_attempts_total": false, - "oubliette_sessions_total": false, - "oubliette_sessions_active": false, + "oubliette_commands_executed_total": false, + "oubliette_human_score": false, + "oubliette_sessions_total": false, + "oubliette_sessions_active": false, "oubliette_session_duration_seconds": false, - "oubliette_build_info": false, + "oubliette_build_info": false, } for _, f := range families { @@ -40,6 +45,31 @@ func TestNew(t *testing.T) { } } +func TestAuthAttemptsByCountry(t *testing.T) { + m := New("1.0.0") + m.AuthAttemptsByCountry.WithLabelValues("US").Inc() + m.AuthAttemptsByCountry.WithLabelValues("DE").Inc() + m.AuthAttemptsByCountry.WithLabelValues("US").Inc() + + families, err := m.registry.Gather() + if err != nil { + t.Fatalf("gather: %v", err) + } + + var found bool + for _, f := range families { + if f.GetName() == "oubliette_auth_attempts_by_country_total" { + found = true + if len(f.GetMetric()) != 2 { + t.Errorf("expected 2 label pairs (US, DE), got %d", len(f.GetMetric())) + } + } + } + if !found { + t.Error("oubliette_auth_attempts_by_country_total not found after incrementing") + } +} + func TestHandler(t *testing.T) { m := New("1.2.3") @@ -60,3 +90,53 @@ func TestHandler(t *testing.T) { t.Errorf("response should contain build_info metric, got:\n%s", body) } } + +func TestStoreCollector(t *testing.T) { + store := storage.NewMemoryStore() + ctx := context.Background() + + // Seed some data. + if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil { + t.Fatalf("RecordLoginAttempt: %v", err) + } + if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil { + t.Fatalf("RecordLoginAttempt: %v", err) + } + if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil { + t.Fatalf("CreateSession: %v", err) + } + + m := New("test") + m.RegisterStoreCollector(store) + + families, err := m.registry.Gather() + if err != nil { + t.Fatalf("gather: %v", err) + } + + wantMetrics := map[string]float64{ + "oubliette_storage_login_attempts_total": 2, + "oubliette_storage_unique_ips": 2, + "oubliette_storage_sessions_total": 1, + } + + for _, f := range families { + expected, ok := wantMetrics[f.GetName()] + if !ok { + continue + } + if len(f.GetMetric()) == 0 { + t.Errorf("metric %q has no samples", f.GetName()) + continue + } + got := f.GetMetric()[0].GetGauge().GetValue() + if got != expected { + t.Errorf("metric %q = %f, want %f", f.GetName(), got, expected) + } + delete(wantMetrics, f.GetName()) + } + + for name := range wantMetrics { + t.Errorf("metric %q not found in gather output", name) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index a446a17..a717071 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -264,6 +264,9 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request Banner: s.cfg.Shell.Banner, FakeUser: s.cfg.Shell.FakeUser, }, + OnCommand: func(sh string) { + s.metrics.CommandsExecuted.WithLabelValues(sh).Inc() + }, } // Wrap channel in RecordingChannel. @@ -299,6 +302,7 @@ func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request } if scorer != nil { finalScore := scorer.Score() + s.metrics.HumanScore.Observe(finalScore) if err := s.store.UpdateHumanScore(context.Background(), sessionID, finalScore); err != nil { s.logger.Error("failed to write final human score", "err", err, "session_id", sessionID) } @@ -362,6 +366,9 @@ func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh. ) country := s.geoip.Lookup(ip) + if country != "" { + s.metrics.AuthAttemptsByCountry.WithLabelValues(country).Inc() + } if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip, country); err != nil { s.logger.Error("failed to record login attempt", "err", err) } diff --git a/internal/shell/adventure/adventure.go b/internal/shell/adventure/adventure.go index 8f43927..5d3b86c 100644 --- a/internal/shell/adventure/adventure.go +++ b/internal/shell/adventure/adventure.go @@ -75,6 +75,9 @@ func (a *AdventureShell) Handle(ctx context.Context, sess *shell.SessionContext, return fmt.Errorf("append session log: %w", err) } } + if sess.OnCommand != nil { + sess.OnCommand("adventure") + } if result.exit { return nil diff --git a/internal/shell/banking/model.go b/internal/shell/banking/model.go index 090c099..e426c38 100644 --- a/internal/shell/banking/model.go +++ b/internal/shell/banking/model.go @@ -345,6 +345,9 @@ func logAction(sess *shell.SessionContext, input, output string) tea.Cmd { defer cancel() _ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output) } + if sess.OnCommand != nil { + sess.OnCommand("banking") + } return nil } } diff --git a/internal/shell/bash/bash.go b/internal/shell/bash/bash.go index bad169e..0234527 100644 --- a/internal/shell/bash/bash.go +++ b/internal/shell/bash/bash.go @@ -86,6 +86,9 @@ func (b *BashShell) Handle(ctx context.Context, sess *shell.SessionContext, rw i return fmt.Errorf("append session log: %w", err) } } + if sess.OnCommand != nil { + sess.OnCommand("bash") + } if result.exit { return nil diff --git a/internal/shell/cisco/cisco.go b/internal/shell/cisco/cisco.go index 4f96a4e..fb6fcc0 100644 --- a/internal/shell/cisco/cisco.go +++ b/internal/shell/cisco/cisco.go @@ -74,6 +74,9 @@ func (c *CiscoShell) Handle(ctx context.Context, sess *shell.SessionContext, rw return fmt.Errorf("append session log: %w", err) } } + if sess.OnCommand != nil { + sess.OnCommand("cisco") + } continue } @@ -92,6 +95,9 @@ func (c *CiscoShell) Handle(ctx context.Context, sess *shell.SessionContext, rw return fmt.Errorf("append session log: %w", err) } } + if sess.OnCommand != nil { + sess.OnCommand("cisco") + } if result.exit { return nil diff --git a/internal/shell/fridge/fridge.go b/internal/shell/fridge/fridge.go index f88f1c6..b24baee 100644 --- a/internal/shell/fridge/fridge.go +++ b/internal/shell/fridge/fridge.go @@ -69,6 +69,9 @@ func (f *FridgeShell) Handle(ctx context.Context, sess *shell.SessionContext, rw return fmt.Errorf("append session log: %w", err) } } + if sess.OnCommand != nil { + sess.OnCommand("fridge") + } if result.exit { return nil diff --git a/internal/shell/shell.go b/internal/shell/shell.go index 3c29a0f..74205eb 100644 --- a/internal/shell/shell.go +++ b/internal/shell/shell.go @@ -24,6 +24,7 @@ type SessionContext struct { Store storage.Store ShellConfig map[string]any CommonConfig ShellCommonConfig + OnCommand func(shell string) // called when a command is executed; may be nil } // ShellCommonConfig holds settings shared across all shell types. diff --git a/internal/web/web.go b/internal/web/web.go index c6c887e..4414053 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -1,9 +1,11 @@ package web import ( + "crypto/subtle" "embed" "log/slog" "net/http" + "strings" "git.t-juice.club/torjus/oubliette/internal/storage" ) @@ -21,7 +23,8 @@ type Server struct { // NewServer creates a new web Server with routes registered. // If metricsHandler is non-nil, a /metrics endpoint is registered. -func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Handler) (*Server, error) { +// If metricsToken is non-empty, the metrics endpoint requires Bearer token auth. +func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Handler, metricsToken string) (*Server, error) { tmpl, err := loadTemplates() if err != nil { return nil, err @@ -42,7 +45,11 @@ func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Han s.mux.HandleFunc("GET /fragments/active-sessions", s.handleFragmentActiveSessions) if metricsHandler != nil { - s.mux.Handle("GET /metrics", metricsHandler) + h := metricsHandler + if metricsToken != "" { + h = requireBearerToken(metricsToken, h) + } + s.mux.Handle("GET /metrics", h) } return s, nil @@ -52,3 +59,20 @@ func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Han func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.mux.ServeHTTP(w, r) } + +// requireBearerToken wraps a handler to require a valid Bearer token. +func requireBearerToken(token string, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + provided := auth[len("Bearer "):] + if subtle.ConstantTimeCompare([]byte(provided), []byte(token)) != 1 { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/web/web_test.go b/internal/web/web_test.go index c8b81a9..70ef0cb 100644 --- a/internal/web/web_test.go +++ b/internal/web/web_test.go @@ -18,7 +18,7 @@ func newTestServer(t *testing.T) *Server { t.Helper() store := storage.NewMemoryStore() logger := slog.Default() - srv, err := NewServer(store, logger, nil) + srv, err := NewServer(store, logger, nil, "") if err != nil { t.Fatalf("creating server: %v", err) } @@ -47,7 +47,7 @@ func newSeededTestServer(t *testing.T) *Server { } logger := slog.Default() - srv, err := NewServer(store, logger, nil) + srv, err := NewServer(store, logger, nil, "") if err != nil { t.Fatalf("creating server: %v", err) } @@ -155,7 +155,7 @@ func TestSessionDetailHandler(t *testing.T) { t.Fatalf("CreateSession: %v", err) } - srv, err := NewServer(store, slog.Default(), nil) + srv, err := NewServer(store, slog.Default(), nil, "") if err != nil { t.Fatalf("NewServer: %v", err) } @@ -195,7 +195,7 @@ func TestAPISessionEvents(t *testing.T) { t.Fatalf("AppendSessionEvents: %v", err) } - srv, err := NewServer(store, slog.Default(), nil) + srv, err := NewServer(store, slog.Default(), nil, "") if err != nil { t.Fatalf("NewServer: %v", err) } @@ -241,7 +241,7 @@ func TestMetricsEndpoint(t *testing.T) { t.Run("enabled", func(t *testing.T) { m := metrics.New("test") store := storage.NewMemoryStore() - srv, err := NewServer(store, slog.Default(), m.Handler()) + srv, err := NewServer(store, slog.Default(), m.Handler(), "") if err != nil { t.Fatalf("NewServer: %v", err) } @@ -261,7 +261,7 @@ func TestMetricsEndpoint(t *testing.T) { t.Run("disabled", func(t *testing.T) { store := storage.NewMemoryStore() - srv, err := NewServer(store, slog.Default(), nil) + srv, err := NewServer(store, slog.Default(), nil, "") if err != nil { t.Fatalf("NewServer: %v", err) } @@ -278,6 +278,68 @@ func TestMetricsEndpoint(t *testing.T) { }) } +func TestMetricsBearerToken(t *testing.T) { + m := metrics.New("test") + + t.Run("valid token", func(t *testing.T) { + store := storage.NewMemoryStore() + srv, err := NewServer(store, slog.Default(), m.Handler(), "secret") + if err != nil { + t.Fatalf("NewServer: %v", err) + } + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + req.Header.Set("Authorization", "Bearer secret") + w := httptest.NewRecorder() + srv.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200", w.Code) + } + }) + + t.Run("wrong token", func(t *testing.T) { + store := storage.NewMemoryStore() + srv, err := NewServer(store, slog.Default(), m.Handler(), "secret") + if err != nil { + t.Fatalf("NewServer: %v", err) + } + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + req.Header.Set("Authorization", "Bearer wrong") + w := httptest.NewRecorder() + srv.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", w.Code) + } + }) + + t.Run("missing header", func(t *testing.T) { + store := storage.NewMemoryStore() + srv, err := NewServer(store, slog.Default(), m.Handler(), "secret") + if err != nil { + t.Fatalf("NewServer: %v", err) + } + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + w := httptest.NewRecorder() + srv.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Errorf("status = %d, want 401", w.Code) + } + }) + + t.Run("no token configured", func(t *testing.T) { + store := storage.NewMemoryStore() + srv, err := NewServer(store, slog.Default(), m.Handler(), "") + if err != nil { + t.Fatalf("NewServer: %v", err) + } + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + w := httptest.NewRecorder() + srv.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Errorf("status = %d, want 200", w.Code) + } + }) +} + func TestStaticAssets(t *testing.T) { srv := newTestServer(t) diff --git a/oubliette.toml.example b/oubliette.toml.example index 27f427c..3960dd4 100644 --- a/oubliette.toml.example +++ b/oubliette.toml.example @@ -43,6 +43,7 @@ retention_interval = "1h" # enabled = true # listen_addr = ":8080" # metrics_enabled = true +# metrics_token = "" # bearer token for /metrics; empty = no auth [shell] hostname = "ubuntu-server"