diff --git a/internal/storage/memstore.go b/internal/storage/memstore.go index c22da89..244b3e0 100644 --- a/internal/storage/memstore.go +++ b/internal/storage/memstore.go @@ -336,10 +336,26 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly m.mu.Lock() defer m.mu.Unlock() - // Count events per session. + return m.collectSessions(limit, activeOnly, DashboardFilter{}), nil +} + +func (m *MemoryStore) GetFilteredSessions(_ context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) { + m.mu.Lock() + defer m.mu.Unlock() + + return m.collectSessions(limit, activeOnly, f), nil +} + +// collectSessions gathers sessions matching filter criteria. Must be called with m.mu held. +func (m *MemoryStore) collectSessions(limit int, activeOnly bool, f DashboardFilter) []Session { + // Compute event counts and input bytes per session. eventCounts := make(map[string]int) + inputBytes := make(map[string]int64) for _, e := range m.SessionEvents { eventCounts[e.SessionID]++ + if e.Direction == 0 { + inputBytes[e.SessionID] += int64(len(e.Data)) + } } var sessions []Session @@ -347,17 +363,54 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly if activeOnly && s.DisconnectedAt != nil { continue } + if !matchesSessionFilter(s, f) { + continue + } sess := *s sess.EventCount = eventCounts[s.ID] + sess.InputBytes = inputBytes[s.ID] sessions = append(sessions, sess) } - sort.Slice(sessions, func(i, j int) bool { - return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt) - }) + + if f.SortBy == "input_bytes" { + sort.Slice(sessions, func(i, j int) bool { + return sessions[i].InputBytes > sessions[j].InputBytes + }) + } else { + sort.Slice(sessions, func(i, j int) bool { + return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt) + }) + } + if limit > 0 && len(sessions) > limit { sessions = sessions[:limit] } - return sessions, nil + return sessions +} + +// matchesSessionFilter returns true if the session matches the given filter. +func matchesSessionFilter(s *Session, f DashboardFilter) bool { + if f.Since != nil && s.ConnectedAt.Before(*f.Since) { + return false + } + if f.Until != nil && s.ConnectedAt.After(*f.Until) { + return false + } + if f.IP != "" && s.IP != f.IP { + return false + } + if f.Country != "" && s.Country != f.Country { + return false + } + if f.Username != "" && s.Username != f.Username { + return false + } + if f.HumanScoreAboveZero { + if s.HumanScore == nil || *s.HumanScore <= 0 { + return false + } + } + return true } func (m *MemoryStore) GetTopExecCommands(_ context.Context, limit int) ([]TopEntry, error) { diff --git a/internal/storage/sqlite.go b/internal/storage/sqlite.go index c5e4359..0299c58 100644 --- a/internal/storage/sqlite.go +++ b/internal/storage/sqlite.go @@ -383,40 +383,104 @@ func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) ( } func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) { - query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` if activeOnly { query += ` WHERE s.disconnected_at IS NULL` } query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?` - rows, err := s.db.QueryContext(ctx, query, limit) + return s.scanSessions(ctx, query, limit) +} + +// buildSessionWhereClause builds a dynamic WHERE clause for session filtering. +func buildSessionWhereClause(f DashboardFilter, activeOnly bool) (string, []any) { + var clauses []string + var args []any + + if activeOnly { + clauses = append(clauses, "s.disconnected_at IS NULL") + } + if f.Since != nil { + clauses = append(clauses, "s.connected_at >= ?") + args = append(args, f.Since.UTC().Format(time.RFC3339)) + } + if f.Until != nil { + clauses = append(clauses, "s.connected_at <= ?") + args = append(args, f.Until.UTC().Format(time.RFC3339)) + } + if f.IP != "" { + clauses = append(clauses, "s.ip = ?") + args = append(args, f.IP) + } + if f.Country != "" { + clauses = append(clauses, "s.country = ?") + args = append(args, f.Country) + } + if f.Username != "" { + clauses = append(clauses, "s.username = ?") + args = append(args, f.Username) + } + if f.HumanScoreAboveZero { + clauses = append(clauses, "s.human_score > 0") + } + + if len(clauses) == 0 { + return "", nil + } + return " WHERE " + strings.Join(clauses, " AND "), args +} + +// validSessionSorts maps allowed SortBy values to SQL ORDER BY clauses. +var validSessionSorts = map[string]string{ + "connected_at": "s.connected_at DESC", + "input_bytes": "input_bytes DESC", +} + +func (s *SQLiteStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) { + where, args := buildSessionWhereClause(f, activeOnly) + args = append(args, limit) + + orderBy := validSessionSorts["connected_at"] + if mapped, ok := validSessionSorts[f.SortBy]; ok { + orderBy = mapped + } + + //nolint:gosec // where/order clauses built from allowlisted constants, not raw user input + query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + where + ` GROUP BY s.id ORDER BY ` + orderBy + ` LIMIT ?` + + return s.scanSessions(ctx, query, args...) +} + +// scanSessions executes a session query and scans the results. +func (s *SQLiteStore) scanSessions(ctx context.Context, query string, args ...any) ([]Session, error) { + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("querying recent sessions: %w", err) + return nil, fmt.Errorf("querying sessions: %w", err) } defer func() { _ = rows.Close() }() var sessions []Session for rows.Next() { - var s Session + var sess Session var connectedAt string var disconnectedAt sql.NullString var humanScore sql.NullFloat64 var execCommand sql.NullString - if err := rows.Scan(&s.ID, &s.IP, &s.Country, &s.Username, &s.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &s.EventCount); err != nil { + if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &sess.EventCount, &sess.InputBytes); err != nil { return nil, fmt.Errorf("scanning session: %w", err) } - s.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt) + sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt) if disconnectedAt.Valid { t, _ := time.Parse(time.RFC3339, disconnectedAt.String) - s.DisconnectedAt = &t + sess.DisconnectedAt = &t } if humanScore.Valid { - s.HumanScore = &humanScore.Float64 + sess.HumanScore = &humanScore.Float64 } if execCommand.Valid { - s.ExecCommand = &execCommand.String + sess.ExecCommand = &execCommand.String } - sessions = append(sessions, s) + sessions = append(sessions, sess) } return sessions, rows.Err() } diff --git a/internal/storage/store.go b/internal/storage/store.go index a2b1d16..3ed300f 100644 --- a/internal/storage/store.go +++ b/internal/storage/store.go @@ -29,6 +29,7 @@ type Session struct { HumanScore *float64 ExecCommand *string EventCount int + InputBytes int64 } // SessionLog represents a single log entry for a session. @@ -76,11 +77,13 @@ type CountryCount struct { // DashboardFilter contains optional filters for dashboard queries. type DashboardFilter struct { - Since *time.Time - Until *time.Time - IP string - Country string - Username string + Since *time.Time + Until *time.Time + IP string + Country string + Username string + HumanScoreAboveZero bool + SortBy string } // TopEntry represents a value and its count for top-N queries. @@ -137,6 +140,10 @@ type Store interface { // If activeOnly is true, only sessions with no disconnected_at are returned. GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) + // GetFilteredSessions returns sessions matching the given filter, ordered + // by the filter's SortBy field (default: connected_at DESC). + GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) + // GetSession returns a single session by ID. GetSession(ctx context.Context, sessionID string) (*Session, error) diff --git a/internal/storage/store_test.go b/internal/storage/store_test.go index 075b102..91b0333 100644 --- a/internal/storage/store_test.go +++ b/internal/storage/store_test.go @@ -700,3 +700,192 @@ func TestGetRecentSessions(t *testing.T) { }) }) } + +func TestInputBytes(t *testing.T) { + testStores(t, func(t *testing.T, newStore storeFactory) { + t.Run("counts only input direction", func(t *testing.T) { + store := newStore(t) + ctx := context.Background() + + id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + now := time.Now().UTC() + events := []SessionEvent{ + {SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")}, // 3 bytes input + {SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")}, // 11 bytes output + {SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")}, // 4 bytes input + } + if err := store.AppendSessionEvents(ctx, events); err != nil { + t.Fatalf("AppendSessionEvents: %v", err) + } + + sessions, err := store.GetRecentSessions(ctx, 10, false) + if err != nil { + t.Fatalf("GetRecentSessions: %v", err) + } + if len(sessions) != 1 { + t.Fatalf("len = %d, want 1", len(sessions)) + } + // Only direction=0 data: "ls\n" (3) + "pwd\n" (4) = 7 + if sessions[0].InputBytes != 7 { + t.Errorf("InputBytes = %d, want 7", sessions[0].InputBytes) + } + }) + + t.Run("zero when no events", func(t *testing.T) { + store := newStore(t) + ctx := context.Background() + + _, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + sessions, err := store.GetRecentSessions(ctx, 10, false) + if err != nil { + t.Fatalf("GetRecentSessions: %v", err) + } + if len(sessions) != 1 { + t.Fatalf("len = %d, want 1", len(sessions)) + } + if sessions[0].InputBytes != 0 { + t.Errorf("InputBytes = %d, want 0", sessions[0].InputBytes) + } + }) + }) +} + +func TestGetFilteredSessions(t *testing.T) { + testStores(t, func(t *testing.T, newStore storeFactory) { + t.Run("filter by human score", func(t *testing.T) { + store := newStore(t) + ctx := context.Background() + + // Create two sessions, one with human score > 0. + id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if err := store.UpdateHumanScore(ctx, id1, 0.75); err != nil { + t.Fatalf("UpdateHumanScore: %v", err) + } + + _, err = store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{HumanScoreAboveZero: true}) + if err != nil { + t.Fatalf("GetFilteredSessions: %v", err) + } + if len(sessions) != 1 { + t.Fatalf("len = %d, want 1", len(sessions)) + } + if sessions[0].ID != id1 { + t.Errorf("expected session %s, got %s", id1, sessions[0].ID) + } + }) + + t.Run("sort by input bytes", func(t *testing.T) { + store := newStore(t) + ctx := context.Background() + + // Session with more input (created first). + id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + now := time.Now().UTC() + if err := store.AppendSessionEvents(ctx, []SessionEvent{ + {SessionID: id1, Timestamp: now, Direction: 0, Data: []byte("ls -la /tmp\n")}, + {SessionID: id1, Timestamp: now.Add(time.Millisecond), Direction: 0, Data: []byte("cat /etc/passwd\n")}, + }); err != nil { + t.Fatalf("AppendSessionEvents: %v", err) + } + + // Session with less input (created after id1, so would be first by connected_at). + // Sleep >1s to ensure different RFC3339 timestamps in SQLite. + time.Sleep(1100 * time.Millisecond) + id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if err := store.AppendSessionEvents(ctx, []SessionEvent{ + {SessionID: id2, Timestamp: now.Add(2 * time.Second), Direction: 0, Data: []byte("x\n")}, + }); err != nil { + t.Fatalf("AppendSessionEvents: %v", err) + } + + // Default sort (connected_at DESC) should show id2 first. + sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{}) + if err != nil { + t.Fatalf("GetFilteredSessions: %v", err) + } + if len(sessions) != 2 { + t.Fatalf("len = %d, want 2", len(sessions)) + } + if sessions[0].ID != id2 { + t.Errorf("default sort: expected %s first, got %s", id2, sessions[0].ID) + } + + // Sort by input_bytes should show id1 first (more input). + sessions, err = store.GetFilteredSessions(ctx, 50, false, DashboardFilter{SortBy: "input_bytes"}) + if err != nil { + t.Fatalf("GetFilteredSessions: %v", err) + } + if len(sessions) != 2 { + t.Fatalf("len = %d, want 2", len(sessions)) + } + if sessions[0].ID != id1 { + t.Errorf("input_bytes sort: expected %s first, got %s", id1, sessions[0].ID) + } + }) + + t.Run("combined filters", func(t *testing.T) { + store := newStore(t) + ctx := context.Background() + + id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if err := store.UpdateHumanScore(ctx, id1, 0.5); err != nil { + t.Fatalf("UpdateHumanScore: %v", err) + } + + // Different country, also has score. + id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + if err := store.UpdateHumanScore(ctx, id2, 0.8); err != nil { + t.Fatalf("UpdateHumanScore: %v", err) + } + + // Same country CN but no score. + _, err = store.CreateSession(ctx, "10.0.0.3", "test", "bash", "CN") + if err != nil { + t.Fatalf("CreateSession: %v", err) + } + + // Filter: CN + human score > 0 -> only id1. + sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{ + Country: "CN", + HumanScoreAboveZero: true, + }) + if err != nil { + t.Fatalf("GetFilteredSessions: %v", err) + } + if len(sessions) != 1 { + t.Fatalf("len = %d, want 1", len(sessions)) + } + if sessions[0].ID != id1 { + t.Errorf("expected session %s, got %s", id1, sessions[0].ID) + } + }) + }) +} diff --git a/internal/web/handlers.go b/internal/web/handlers.go index 5cf4fd1..3534d82 100644 --- a/internal/web/handlers.go +++ b/internal/web/handlers.go @@ -125,6 +125,21 @@ func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Req } } +func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) { + f := parseDashboardFilter(r) + sessions, err := s.store.GetFilteredSessions(r.Context(), 50, false, f) + if err != nil { + s.logger.Error("failed to get filtered sessions", "err", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := s.tmpl.dashboard.ExecuteTemplate(w, "recent_sessions", sessions); err != nil { + s.logger.Error("failed to render recent sessions fragment", "err", err) + } +} + type sessionDetailData struct { Session *storage.Session Logs []storage.SessionLog @@ -201,11 +216,13 @@ func parseDateParam(r *http.Request, name string) *time.Time { func parseDashboardFilter(r *http.Request) storage.DashboardFilter { return storage.DashboardFilter{ - Since: parseDateParam(r, "since"), - Until: parseDateParam(r, "until"), - IP: r.URL.Query().Get("ip"), - Country: r.URL.Query().Get("country"), - Username: r.URL.Query().Get("username"), + Since: parseDateParam(r, "since"), + Until: parseDateParam(r, "until"), + IP: r.URL.Query().Get("ip"), + Country: r.URL.Query().Get("country"), + Username: r.URL.Query().Get("username"), + HumanScoreAboveZero: r.URL.Query().Get("human_score") == "1", + SortBy: r.URL.Query().Get("sort"), } } diff --git a/internal/web/static/dashboard.js b/internal/web/static/dashboard.js index d55ba14..3a89ca9 100644 --- a/internal/web/static/dashboard.js +++ b/internal/web/static/dashboard.js @@ -16,6 +16,10 @@ var until = form.elements['until'].value; if (since) params.set('since', since); if (until) params.set('until', until); + var humanScore = form.elements['human_score']; + if (humanScore && humanScore.checked) params.set('human_score', '1'); + var sortBy = form.elements['sort']; + if (sortBy && sortBy.value) params.set('sort', sortBy.value); return params.toString(); } @@ -228,33 +232,20 @@ if (val) params.set(name, val); }); + var humanScore = form.elements['human_score']; + if (humanScore && humanScore.checked) params.set('human_score', '1'); + var sortBy = form.elements['sort']; + if (sortBy && sortBy.value) params.set('sort', sortBy.value); + var target = document.getElementById('dashboard-content'); if (target) { var url = '/fragments/dashboard-content?' + params.toString(); htmx.ajax('GET', url, {target: '#dashboard-content', swap: 'innerHTML'}); } - // Client-side filter for recent sessions table - filterSessionsTable(form); - } - - function filterSessionsTable(form) { - var ip = form.elements['ip'].value.toLowerCase(); - var country = form.elements['country'].value.toLowerCase(); - var username = form.elements['username'].value.toLowerCase(); - - var rows = document.querySelectorAll('#recent-sessions-table tbody tr'); - rows.forEach(function(row) { - var cells = row.querySelectorAll('td'); - if (cells.length < 4) { row.style.display = ''; return; } - - var show = true; - if (ip && cells[1].textContent.toLowerCase().indexOf(ip) === -1) show = false; - if (country && cells[2].textContent.toLowerCase().indexOf(country) === -1) show = false; - if (username && cells[3].textContent.toLowerCase().indexOf(username) === -1) show = false; - - row.style.display = show ? '' : 'none'; - }); + // Server-side filter for recent sessions table + var sessionsUrl = '/fragments/recent-sessions?' + params.toString(); + htmx.ajax('GET', sessionsUrl, {target: '#recent-sessions-table tbody', swap: 'innerHTML'}); } window.clearFilters = function() { diff --git a/internal/web/templates.go b/internal/web/templates.go index 0b579fb..bf9bf9f 100644 --- a/internal/web/templates.go +++ b/internal/web/templates.go @@ -56,6 +56,20 @@ func templateFuncMap() template.FuncMap { } return s }, + "formatBytes": func(b int64) string { + const ( + kb = 1024 + mb = 1024 * kb + ) + switch { + case b >= mb: + return fmt.Sprintf("%.1f MB", float64(b)/float64(mb)) + case b >= kb: + return fmt.Sprintf("%.1f KB", float64(b)/float64(kb)) + default: + return fmt.Sprintf("%d B", b) + } + }, } } @@ -67,6 +81,7 @@ func loadTemplates() (*templateSet, error) { "templates/dashboard.html", "templates/fragments/stats.html", "templates/fragments/active_sessions.html", + "templates/fragments/recent_sessions.html", ) if err != nil { return nil, fmt.Errorf("parsing dashboard templates: %w", err) diff --git a/internal/web/templates/dashboard.html b/internal/web/templates/dashboard.html index fb26047..08ca8a7 100644 --- a/internal/web/templates/dashboard.html +++ b/internal/web/templates/dashboard.html @@ -13,6 +13,10 @@ +
{{truncateID .ID}}{{if gt .EventCount 0}} replay{{end}}{{truncateID .ID}}{{if gt .EventCount 0}} replay{{end}}