package web

import (
	"context"
	"encoding/json"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"git.t-juice.club/torjus/oubliette/internal/metrics"
	"git.t-juice.club/torjus/oubliette/internal/storage"
)

// newTestServer returns a Server backed by an empty in-memory store, with no
// metrics handler and no metrics bearer token configured.
func newTestServer(t *testing.T) *Server {
	t.Helper()
	store := storage.NewMemoryStore()
	logger := slog.Default()
	srv, err := NewServer(store, logger, nil, "")
	if err != nil {
		t.Fatalf("creating server: %v", err)
	}
	return srv
}

// newSeededTestServer returns a Server backed by an in-memory store seeded with
// six login attempts (five root/toor from 10.0.0.1, one admin/admin from
// 10.0.0.2) and one session for each of those IPs. No session is ended, and no
// metrics handler or bearer token is configured.
func newSeededTestServer(t *testing.T) *Server {
	t.Helper()
	store := storage.NewMemoryStore()
	ctx := context.Background()
	for range 5 {
		if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", ""); err != nil {
			t.Fatalf("seeding attempt: %v", err)
		}
	}
	if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", ""); err != nil {
		t.Fatalf("seeding attempt: %v", err)
	}
	if _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", ""); err != nil {
		t.Fatalf("creating session: %v", err)
	}
	if _, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", ""); err != nil {
		t.Fatalf("creating session: %v", err)
	}
	logger := slog.Default()
	srv, err := NewServer(store, logger, nil, "")
	if err != nil {
		t.Fatalf("creating server: %v", err)
	}
	return srv
}

func TestDashboardHandler(t *testing.T) {
	t.Run("empty store", func(t *testing.T) {
		srv := newTestServer(t)
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		w := httptest.NewRecorder()
		srv.ServeHTTP(w, req)
		if w.Code != http.StatusOK {
			t.Errorf("status = %d, want 200", w.Code)
		}
		body := w.Body.String()
		if !strings.Contains(body, "Oubliette") {
			t.Error("response should contain 'Oubliette'")
		}
		if !strings.Contains(body, "No data") {
			t.Error("response should contain 'No data' for empty tables")
		}
	})

	t.Run("with data", func(t *testing.T) {
		srv := newSeededTestServer(t)
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		w := httptest.NewRecorder()
		srv.ServeHTTP(w, req)
		if w.Code != http.StatusOK {
			t.Errorf("status = %d, want 200", w.Code)
		}
		body := w.Body.String()
		if !strings.Contains(body, "root") {
			t.Error("response should contain username 'root'")
		}
		if !strings.Contains(body, "10.0.0.1") {
			t.Error("response should contain IP '10.0.0.1'")
		}
	})
}

func TestFragmentStats(t *testing.T) {
	srv := newSeededTestServer(t)
	req := httptest.NewRequest(http.MethodGet, "/fragments/stats", nil)
	w := httptest.NewRecorder()
	srv.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", w.Code)
	}
	body := w.Body.String()
	// Should be a fragment, not a full HTML page. Look for a real document
	// marker: strings.Contains(s, "") is always true and would always fail.
	if strings.Contains(body, "<html") {
		t.Error("stats fragment should not contain full HTML document")
	}
	if !strings.Contains(body, "Total Attempts") {
		t.Error("stats fragment should contain 'Total Attempts'")
	}
}

func TestFragmentActiveSessions(t *testing.T) {
	srv := newSeededTestServer(t)
	req := httptest.NewRequest(http.MethodGet, "/fragments/active-sessions", nil)
	w := httptest.NewRecorder()
	srv.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", w.Code)
	}
	body := w.Body.String()
	// Should be a fragment, not a full HTML page.
	if strings.Contains(body, "<html") {
		t.Error("active sessions fragment should not contain full HTML document")
	}
	// Both seeded sessions are active (neither has been ended).
	for _, ip := range []string{"10.0.0.1", "10.0.0.2"} {
		if !strings.Contains(body, ip) {
			t.Errorf("active sessions should contain IP %q", ip)
		}
	}
}

func TestSessionDetailHandler(t *testing.T) {
	t.Run("not found", func(t *testing.T) {
		srv := newTestServer(t)
		req := httptest.NewRequest(http.MethodGet, "/sessions/nonexistent", nil)
		w := httptest.NewRecorder()
		srv.ServeHTTP(w, req)
		if w.Code != http.StatusNotFound {
			t.Errorf("status = %d, want 404", w.Code)
		}
	})

	t.Run("found", func(t *testing.T) {
		store := storage.NewMemoryStore()
		ctx := context.Background()
		id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
		if err != nil {
			t.Fatalf("CreateSession: %v", err)
		}
		srv, err := NewServer(store, slog.Default(), nil, "")
		if err != nil {
			t.Fatalf("NewServer: %v", err)
		}
		req := httptest.NewRequest(http.MethodGet, "/sessions/"+id, nil)
		w := httptest.NewRecorder()
		srv.ServeHTTP(w, req)
		if w.Code != http.StatusOK {
			t.Errorf("status = %d, want 200", w.Code)
		}
		body := w.Body.String()
		if !strings.Contains(body, "10.0.0.1") {
			t.Error("response should contain IP")
		}
		if !strings.Contains(body, "root") {
			t.Error("response should contain username")
		}
	})
}

func TestAPISessionEvents(t *testing.T) {
	store := storage.NewMemoryStore()
	ctx := context.Background()
	id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
	if err != nil {
		t.Fatalf("CreateSession: %v", err)
	}
	now := time.Now().UTC()
	events := []storage.SessionEvent{
		{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")},
		{SessionID: id, Timestamp: now.Add(500 * time.Millisecond), Direction: 1, Data: []byte("file1\n")},
	}
	if err := store.AppendSessionEvents(ctx, events); err != nil {
		t.Fatalf("AppendSessionEvents: %v", err)
	}
	srv, err := NewServer(store, slog.Default(), nil, "")
	if err != nil {
		t.Fatalf("NewServer: %v", err)
	}
	req := httptest.NewRequest(http.MethodGet, "/api/sessions/"+id+"/events", nil)
	w := httptest.NewRecorder()
	srv.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", w.Code)
	}
	ct := w.Header().Get("Content-Type")
	if !strings.Contains(ct, "application/json") {
		t.Errorf("Content-Type = %q, want application/json", ct)
	}
	var resp apiEventsResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding response: %v", err)
	}
	if len(resp.Events) != 2 {
		t.Fatalf("len = %d, want 2", len(resp.Events))
	}
	// First event should have t=0 (relative).
	if resp.Events[0].T != 0 {
		t.Errorf("events[0].T = %d, want 0", resp.Events[0].T)
	}
	// Second event should have t=500 (500ms later).
	if resp.Events[1].T != 500 {
		t.Errorf("events[1].T = %d, want 500", resp.Events[1].T)
	}
	if resp.Events[0].D != 0 {
		t.Errorf("events[0].D = %d, want 0", resp.Events[0].D)
	}
	if resp.Events[1].D != 1 {
		t.Errorf("events[1].D = %d, want 1", resp.Events[1].D)
	}
}

func TestMetricsEndpoint(t *testing.T) {
	t.Run("enabled", func(t *testing.T) {
		m := metrics.New("test")
		store := storage.NewMemoryStore()
		srv, err := NewServer(store, slog.Default(), m.Handler(), "")
		if err != nil {
			t.Fatalf("NewServer: %v", err)
		}
		req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
		w := httptest.NewRecorder()
		srv.ServeHTTP(w, req)
		if w.Code != http.StatusOK {
			t.Errorf("status = %d, want 200", w.Code)
		}
		body := w.Body.String()
		if !strings.Contains(body, `oubliette_build_info{version="test"} 1`) {
			t.Errorf("response should contain build_info metric, got:\n%s", body)
		}
	})

	t.Run("disabled", func(t *testing.T) {
		store := storage.NewMemoryStore()
		srv, err := NewServer(store, slog.Default(), nil, "")
		if err != nil {
			t.Fatalf("NewServer: %v", err)
		}
		req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
		w := httptest.NewRecorder()
		srv.ServeHTTP(w, req)
		// Without a metrics handler, /metrics falls through to the dashboard.
		body := w.Body.String()
		if strings.Contains(body, "oubliette_build_info") {
			t.Error("response should not contain prometheus metrics when disabled")
		}
	})
}

func TestMetricsBearerToken(t *testing.T) {
	m := metrics.New("test")

	t.Run("valid token", func(t *testing.T) {
		store := storage.NewMemoryStore()
		srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
		if err != nil {
			t.Fatalf("NewServer: %v", err)
		}
		req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
		req.Header.Set("Authorization", "Bearer secret")
		w := httptest.NewRecorder()
		srv.ServeHTTP(w, req)
		if w.Code != http.StatusOK {
			t.Errorf("status = %d, want 200", w.Code)
		}
	})

	t.Run("wrong token", func(t *testing.T) {
		store := storage.NewMemoryStore()
		srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
		if err != nil {
			t.Fatalf("NewServer: %v", err)
		}
		req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
		req.Header.Set("Authorization", "Bearer wrong")
		w := httptest.NewRecorder()
		srv.ServeHTTP(w, req)
		if w.Code != http.StatusUnauthorized {
			t.Errorf("status = %d, want 401", w.Code)
		}
	})

	t.Run("missing header", func(t *testing.T) {
		store := storage.NewMemoryStore()
		srv, err := NewServer(store, slog.Default(), m.Handler(), "secret")
		if err != nil {
			t.Fatalf("NewServer: %v", err)
		}
		req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
		w := httptest.NewRecorder()
		srv.ServeHTTP(w, req)
		if w.Code != http.StatusUnauthorized {
			t.Errorf("status = %d, want 401", w.Code)
		}
	})

	t.Run("no token configured", func(t *testing.T) {
		store := storage.NewMemoryStore()
		srv, err := NewServer(store, slog.Default(), m.Handler(), "")
		if err != nil {
			t.Fatalf("NewServer: %v", err)
		}
		req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
		w := httptest.NewRecorder()
		srv.ServeHTTP(w, req)
		if w.Code != http.StatusOK {
			t.Errorf("status = %d, want 200", w.Code)
		}
	})
}

func TestTruncateCommand(t *testing.T) {
	funcMap := templateFuncMap()
	// Use the comma-ok form so a missing or mistyped entry fails the test
	// with a message instead of panicking.
	fn, ok := funcMap["truncateCommand"].(func(string) string)
	if !ok {
		t.Fatalf("truncateCommand has type %T, want func(string) string", funcMap["truncateCommand"])
	}

	tests := []struct {
		input string
		want  string
	}{
		{"short", "short"},
		// NOTE(review): despite its wording this literal is 49 bytes, so the
		// exact 50-byte boundary is not exercised here — confirm intent.
		{"exactly fifty characters long! that is what it i.", "exactly fifty characters long! that is what it i."},
		{"this string is definitely longer than fifty characters and should be truncated", "this string is definitely longer than fifty charac..."},
		{"", ""},
	}
	for _, tt := range tests {
		got := fn(tt.input)
		if got != tt.want {
			t.Errorf("truncateCommand(%q) = %q, want %q", tt.input, got, tt.want)
		}
	}
}

func TestDashboardExecCommands(t *testing.T) {
	store := storage.NewMemoryStore()
	ctx := context.Background()
	id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
	if err != nil {
		t.Fatalf("creating session: %v", err)
	}
	if err := store.SetExecCommand(ctx, id, "uname -a"); err != nil {
		t.Fatalf("setting exec command: %v", err)
	}
	srv, err := NewServer(store, slog.Default(), nil, "")
	if err != nil {
		t.Fatalf("NewServer: %v", err)
	}
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	w := httptest.NewRecorder()
	srv.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", w.Code)
	}
	body := w.Body.String()
	if !strings.Contains(body, "Top Exec Commands") {
		t.Error("response should contain 'Top Exec Commands'")
	}
	if !strings.Contains(body, "uname -a") {
		t.Error("response should contain exec command 'uname -a'")
	}
}

func TestAPIAttemptsOverTime(t *testing.T) {
	srv := newSeededTestServer(t)
	req := httptest.NewRequest(http.MethodGet, "/api/charts/attempts-over-time", nil)
	w := httptest.NewRecorder()
	srv.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", w.Code)
	}
	ct := w.Header().Get("Content-Type")
	if !strings.Contains(ct, "application/json") {
		t.Errorf("Content-Type = %q, want application/json", ct)
	}
	var resp apiAttemptsOverTimeResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding response: %v", err)
	}
	// Seeded data inserted today -> at least 1 point.
	if len(resp.Points) == 0 {
		t.Error("expected at least one data point")
	}
}

func TestAPIHourlyPattern(t *testing.T) {
	srv := newSeededTestServer(t)
	req := httptest.NewRequest(http.MethodGet, "/api/charts/hourly-pattern", nil)
	w := httptest.NewRecorder()
	srv.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", w.Code)
	}
	var resp apiHourlyPatternResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding response: %v", err)
	}
	if len(resp.Hours) == 0 {
		t.Error("expected at least one hourly data point")
	}
}

func TestAPICountryStats(t *testing.T) {
	store := storage.NewMemoryStore()
	ctx := context.Background()
	if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
		t.Fatalf("seeding: %v", err)
	}
	if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
		t.Fatalf("seeding: %v", err)
	}
	srv, err := NewServer(store, slog.Default(), nil, "")
	if err != nil {
		t.Fatalf("NewServer: %v", err)
	}
	req := httptest.NewRequest(http.MethodGet, "/api/charts/country-stats", nil)
	w := httptest.NewRecorder()
	srv.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", w.Code)
	}
	var resp apiCountryStatsResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decoding response: %v", err)
	}
	if len(resp.Countries) != 2 {
		t.Fatalf("len = %d, want 2", len(resp.Countries))
	}
}

func TestFragmentDashboardContent(t *testing.T) {
	srv := newSeededTestServer(t)
	req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content", nil)
	w := httptest.NewRecorder()
	srv.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", w.Code)
	}
	body := w.Body.String()
	// Should be a fragment, not a full HTML page.
	if strings.Contains(body, "<html") {
		t.Error("dashboard content fragment should not contain full HTML document")
	}
	if !strings.Contains(body, "Top Usernames") {
		t.Error("dashboard content fragment should contain 'Top Usernames'")
	}
}

func TestFragmentDashboardContentWithFilter(t *testing.T) {
	store := storage.NewMemoryStore()
	ctx := context.Background()
	for range 5 {
		if err := store.RecordLoginAttempt(ctx, "root", "toor", "10.0.0.1", "CN"); err != nil {
			t.Fatalf("seeding: %v", err)
		}
	}
	for range 3 {
		if err := store.RecordLoginAttempt(ctx, "admin", "admin", "10.0.0.2", "RU"); err != nil {
			t.Fatalf("seeding: %v", err)
		}
	}
	srv, err := NewServer(store, slog.Default(), nil, "")
	if err != nil {
		t.Fatalf("NewServer: %v", err)
	}
	req := httptest.NewRequest(http.MethodGet, "/fragments/dashboard-content?country=CN", nil)
	w := httptest.NewRecorder()
	srv.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", w.Code)
	}
	body := w.Body.String()
	// When filtered by CN, should show root but not admin.
	if !strings.Contains(body, "root") {
		t.Error("response should contain 'root' when filtered by CN")
	}
	if strings.Contains(body, "admin") {
		t.Error("response should not contain 'admin' (RU-only) when filtered by CN")
	}
}

func TestStaticAssets(t *testing.T) {
	srv := newTestServer(t)
	tests := []struct {
		path        string
		contentType string
	}{
		{"/static/pico.min.css", "text/css"},
		{"/static/htmx.min.js", "text/javascript"},
		{"/static/chart.min.js", "text/javascript"},
		{"/static/dashboard.js", "text/javascript"},
		{"/static/world.svg", "image/svg+xml"},
	}
	for _, tt := range tests {
		t.Run(tt.path, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, tt.path, nil)
			w := httptest.NewRecorder()
			srv.ServeHTTP(w, req)
			if w.Code != http.StatusOK {
				t.Errorf("status = %d, want 200", w.Code)
			}
			ct := w.Header().Get("Content-Type")
			if !strings.Contains(ct, tt.contentType) {
				t.Errorf("Content-Type = %q, want to contain %q", ct, tt.contentType)
			}
		})
	}
}