diff --git a/cmd/oubliette/main.go b/cmd/oubliette/main.go index c4af8ad..b2fe79f 100644 --- a/cmd/oubliette/main.go +++ b/cmd/oubliette/main.go @@ -65,6 +65,13 @@ func run() error { } defer store.Close() + // Clean up sessions left active by a previous unclean shutdown. + if n, err := store.CloseActiveSessions(context.Background(), time.Now()); err != nil { + return fmt.Errorf("close stale sessions: %w", err) + } else if n > 0 { + logger.Info("closed stale sessions from previous run", "count", n) + } + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() diff --git a/internal/storage/memstore.go b/internal/storage/memstore.go index 9afff6a..15b483e 100644 --- a/internal/storage/memstore.go +++ b/internal/storage/memstore.go @@ -286,6 +286,21 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly return sessions, nil } +func (m *MemoryStore) CloseActiveSessions(_ context.Context, disconnectedAt time.Time) (int64, error) { + m.mu.Lock() + defer m.mu.Unlock() + + var count int64 + t := disconnectedAt.UTC() + for _, s := range m.Sessions { + if s.DisconnectedAt == nil { + s.DisconnectedAt = &t + count++ + } + } + return count, nil +} + func (m *MemoryStore) Close() error { return nil } diff --git a/internal/storage/sqlite.go b/internal/storage/sqlite.go index 4344ede..c6e585e 100644 --- a/internal/storage/sqlite.go +++ b/internal/storage/sqlite.go @@ -351,6 +351,16 @@ func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOn return sessions, rows.Err() } +func (s *SQLiteStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) { + res, err := s.db.ExecContext(ctx, ` + UPDATE sessions SET disconnected_at = ? WHERE disconnected_at IS NULL`, + disconnectedAt.UTC().Format(time.RFC3339)) + if err != nil { + return 0, fmt.Errorf("closing active sessions: %w", err) + } + return res.RowsAffected() +} + func (s *SQLiteStore) Close() error { return s.db.Close() } diff --git a/internal/storage/store.go b/internal/storage/store.go index 2faf91f..e600bc2 100644 --- a/internal/storage/store.go +++ b/internal/storage/store.go @@ -108,6 +108,11 @@ type Store interface { // GetSessionEvents returns all events for a session ordered by id. GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) + // CloseActiveSessions sets disconnected_at for all sessions that are + // still marked as active. This should be called at startup to clean up + // sessions left over from a previous unclean shutdown. + CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) + // Close releases any resources held by the store. Close() error } diff --git a/internal/storage/store_test.go b/internal/storage/store_test.go index 4652a7e..deb91a8 100644 --- a/internal/storage/store_test.go +++ b/internal/storage/store_test.go @@ -316,6 +316,51 @@ func TestSessionEvents(t *testing.T) { }) } +func TestCloseActiveSessions(t *testing.T) { + testStores(t, func(t *testing.T, newStore storeFactory) { + t.Run("no active sessions", func(t *testing.T) { + store := newStore(t) + ctx := context.Background() + + n, err := store.CloseActiveSessions(ctx, time.Now()) + if err != nil { + t.Fatalf("CloseActiveSessions: %v", err) + } + if n != 0 { + t.Errorf("closed %d, want 0", n) + } + }) + + t.Run("closes only active sessions", func(t *testing.T) { + store := newStore(t) + ctx := context.Background() + + // Create 3 sessions: end one, leave two active. + id1, _ := store.CreateSession(ctx, "10.0.0.1", "root", "bash") + store.CreateSession(ctx, "10.0.0.2", "admin", "bash") + store.CreateSession(ctx, "10.0.0.3", "test", "bash") + store.EndSession(ctx, id1, time.Now()) + + n, err := store.CloseActiveSessions(ctx, time.Now()) + if err != nil { + t.Fatalf("CloseActiveSessions: %v", err) + } + if n != 2 { + t.Errorf("closed %d, want 2", n) + } + + // Verify no active sessions remain. + active, err := store.GetRecentSessions(ctx, 10, true) + if err != nil { + t.Fatalf("GetRecentSessions: %v", err) + } + if len(active) != 0 { + t.Errorf("active sessions = %d, want 0", len(active)) + } + }) + }) +} + func TestGetRecentSessions(t *testing.T) { testStores(t, func(t *testing.T, newStore storeFactory) { t.Run("empty", func(t *testing.T) {