diff --git a/cmd/apiary.go b/cmd/apiary.go index 2b5152e..bdedd0e 100644 --- a/cmd/apiary.go +++ b/cmd/apiary.go @@ -60,9 +60,11 @@ func ActionServe(c *cli.Context) error { var s store.LoginAttemptStore switch cfg.Store.Type { case "MEMORY", "memory": - loggers.rootLogger.Debugw("Initializing store", "store_type", "memory") + loggers.rootLogger.Infow("Initialized store", "store_type", "memory") s = &store.MemoryStore{} case "POSTGRES", "postgres": + pgStartTime := time.Now() + loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres") pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN) if err != nil { return err @@ -70,12 +72,14 @@ func ActionServe(c *cli.Context) error { if err := pgStore.InitDB(); err != nil { return err } + loggers.rootLogger.Infow("Initialized store", "store_type", "postgres", "init_time", time.Since(pgStartTime)) if cfg.Store.EnableCache { loggers.rootLogger.Debugw("Initializing store", "store_type", "cache-postgres") + startTime := time.Now() cachingStore := store.NewCachingStore(pgStore) s = cachingStore + loggers.rootLogger.Infow("Initialized store", "store_type", "cache-postgres", "init_time", time.Since(startTime)) } else { - loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres") s = pgStore } default: @@ -92,7 +96,7 @@ func ActionServe(c *cli.Context) error { defer serversCancel() // Setup metrics collection - s = store.NewMetricsCollectingStore(rootCtx, s) + s = store.NewMetricsCollectingStore(s) // Setup honeypot hs, err := ssh.NewHoneypotServer(cfg.Honeypot, s) diff --git a/honeypot/ssh/store/cache.go b/honeypot/ssh/store/cache.go index 16799b3..71b2d19 100644 --- a/honeypot/ssh/store/cache.go +++ b/honeypot/ssh/store/cache.go @@ -9,15 +9,40 @@ type CachingStore struct { usernameQueryCache map[string][]models.LoginAttempt passwordQueryCache map[string][]models.LoginAttempt ipQueryCache map[string][]models.LoginAttempt + uniqueUsernames map[string]struct{} + uniquePasswords map[string]struct{} + uniqueIPs map[string]struct{} + uniqueCountries map[string]struct{} + totalLoginsCount int } func NewCachingStore(backend LoginAttemptStore) *CachingStore { - return &CachingStore{ + cs := &CachingStore{ backend: backend, usernameQueryCache: make(map[string][]models.LoginAttempt), passwordQueryCache: make(map[string][]models.LoginAttempt), ipQueryCache: make(map[string][]models.LoginAttempt), + uniqueUsernames: make(map[string]struct{}), + uniquePasswords: make(map[string]struct{}), + uniqueIPs: make(map[string]struct{}), + uniqueCountries: make(map[string]struct{}), } + + all, err := backend.All() + if err != nil { + //TODO: Handle better maybe? + panic(err) + } + + cs.totalLoginsCount = len(all) + for _, attempt := range all { + cs.uniqueUsernames[attempt.Username] = struct{}{} + cs.uniquePasswords[attempt.Password] = struct{}{} + cs.uniqueIPs[attempt.RemoteIP.String()] = struct{}{} + cs.uniqueCountries[attempt.Country] = struct{}{} + } + + return cs } func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error { @@ -25,6 +50,11 @@ func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error { delete(s.ipQueryCache, l.RemoteIP.String()) delete(s.passwordQueryCache, l.Password) delete(s.usernameQueryCache, l.Username) + s.totalLoginsCount++ + s.uniqueUsernames[l.Username] = struct{}{} + s.uniquePasswords[l.Password] = struct{}{} + s.uniqueIPs[l.RemoteIP.String()] = struct{}{} + s.uniqueCountries[l.Country] = struct{}{} return s.backend.AddAttempt(l) } @@ -35,20 +65,15 @@ func (s *CachingStore) All() ([]models.LoginAttempt, error) { func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) { // Only cache totals for now, as they are the most queried if statType == LoginStatsTotals { - if s.totalStatsCache != nil { - return s.totalStatsCache, nil - } - stats, err := s.backend.Stats(statType, limit) - if err != nil { - return stats, err - } - s.totalStatsCache = stats - - return stats, err + return []StatsResult{ + {Name: "UniquePasswords", Count: len(s.uniquePasswords)}, + {Name: "UniqueUsernames", Count: len(s.uniqueUsernames)}, + {Name: "UniqueIPs", Count: len(s.uniqueIPs)}, + {Name: "UniqueCountries", Count: len(s.uniqueCountries)}, + {Name: "TotalLoginAttempts", Count: s.totalLoginsCount}, + }, nil } - // TODO: Maybe cache the default limits - return s.backend.Stats(statType, limit) } diff --git a/honeypot/ssh/store/cache_test.go b/honeypot/ssh/store/cache_test.go index 2334467..d8405eb 100644 --- a/honeypot/ssh/store/cache_test.go +++ b/honeypot/ssh/store/cache_test.go @@ -11,3 +11,59 @@ func TestCacheStore(t *testing.T) { s := store.NewCachingStore(backend) testLoginAttemptStore(s, t) } + +func TestCacheTotalStats(t *testing.T) { + backend := &store.MemoryStore{} + // Add initial attempts, to ensure that the cache is correcly initialized with existing attempts + attempts := randomAttempts(50000) + for _, attempt := range attempts { + err := backend.AddAttempt(attempt) + if err != nil { + t.Fatalf("Error adding attempts: %s", err) + } + } + s := store.NewCachingStore(backend) + cacheTotals, err := s.Stats(store.LoginStatsTotals, 0) + if err != nil { + t.Fatalf("Error getting cached stats: %s", err) + } + backendTotals, err := backend.Stats(store.LoginStatsTotals, 0) + if err != nil { + t.Fatalf("Error getting cached stats: %s", err) + } + for i := range cacheTotals { + if cacheTotals[i].Count != backendTotals[i].Count || cacheTotals[i].Name != backendTotals[i].Name { + t.Fatalf("Mismatched totals: Cache: %+v Backend: %+v", cacheTotals[i], backendTotals[i]) + } + } + + // Add the same attempts again, to ensure that duplicates are handled correctly + for _, attempt := range attempts { + err := s.AddAttempt(attempt) + if err != nil { + t.Fatalf("Error adding attempts: %s", err) + } + } + + // Add some new attempts + attempts = randomAttempts(10000) + for _, attempt := range attempts { + err := s.AddAttempt(attempt) + if err != nil { + t.Fatalf("Error adding attempts: %s", err) + } + } + cacheTotals, err = s.Stats(store.LoginStatsTotals, 0) + if err != nil { + t.Fatalf("Error getting cached stats: %s", err) + } + backendTotals, err = backend.Stats(store.LoginStatsTotals, 0) + if err != nil { + t.Fatalf("Error getting cached stats: %s", err) + } + for i := range cacheTotals { + if cacheTotals[i].Count != backendTotals[i].Count || cacheTotals[i].Name != backendTotals[i].Name { + t.Fatalf("Mismatched totals: Cache: %+v Backend: %+v", cacheTotals[i], backendTotals[i]) + } + } +} diff --git a/honeypot/ssh/store/memory_test.go b/honeypot/ssh/store/memory_test.go index 604f213..dbb3a50 100644 --- a/honeypot/ssh/store/memory_test.go +++ b/honeypot/ssh/store/memory_test.go @@ -15,3 +15,10 @@ func TestMemoryStoreWithCache(t *testing.T) { s := store.NewCachingStore(backend) testLoginAttemptStore(s, t) } + +func BenchmarkMemoryStore(b *testing.B) { + setupFunc := func() store.LoginAttemptStore { + return &store.MemoryStore{} + } + benchmarkLoginAttemptStore(setupFunc, b) +} diff --git a/honeypot/ssh/store/metrics.go b/honeypot/ssh/store/metrics.go index e2c53b2..ee2531b 100644 --- a/honeypot/ssh/store/metrics.go +++ b/honeypot/ssh/store/metrics.go @@ -1,9 +1,6 @@ package store import ( - "context" - "time" - "github.com/prometheus/client_golang/prometheus" "github.uio.no/torjus/apiary/models" ) @@ -17,7 +14,7 @@ type MetricsCollectingStore struct { totalAttemptsCount prometheus.Gauge } -func NewMetricsCollectingStore(ctx context.Context, store LoginAttemptStore) *MetricsCollectingStore { +func NewMetricsCollectingStore(store LoginAttemptStore) *MetricsCollectingStore { mcs := &MetricsCollectingStore{store: store} mcs.attemptsCounter = prometheus.NewCounterVec( @@ -62,17 +59,6 @@ func NewMetricsCollectingStore(ctx context.Context, store LoginAttemptStore) *Me prometheus.MustRegister(mcs.uniqueIPsCount) prometheus.MustRegister(mcs.totalAttemptsCount) - // Kinda jank, we just fetch the stats every 10seconds, but it should be cached most of the time. - go func(ctx context.Context) { - ticker := time.NewTicker(10 * time.Second) - select { - case <-ctx.Done(): - return - case <-ticker.C: - mcs.Stats(LoginStatsTotals, 0) - } - }(ctx) - return mcs } diff --git a/honeypot/ssh/store/postgres_test.go b/honeypot/ssh/store/postgres_test.go index 41f7250..14880c9 100644 --- a/honeypot/ssh/store/postgres_test.go +++ b/honeypot/ssh/store/postgres_test.go @@ -47,6 +47,26 @@ func TestPostgresStoreWithCache(t *testing.T) { testLoginAttemptStore(s, t) } +func BenchmarkPostgresStore(b *testing.B) { + var dsn string + var found bool + dsn, found = os.LookupEnv("APIARY_TEST_POSTGRES_DSN") + if !found { + b.Skipf("APIARY_TEST_POSTGRES_DSN not set. Skipping.") + } + dropPGDatabase(dsn) + setupFunc := func() store.LoginAttemptStore { + dropPGDatabase(dsn) + pgs, err := store.NewPostgresStore(dsn) + if err != nil { + b.Fatalf("Error getting store: %s", err) + } + pgs.InitDB() + return pgs + } + benchmarkLoginAttemptStore(setupFunc, b) + dropPGDatabase(dsn) +} func dropPGDatabase(dsn string) { db, err := sql.Open("pgx", dsn) @@ -54,8 +74,5 @@ func dropPGDatabase(dsn string) { panic(err) } - _, err = db.Exec("DROP TABLE login_attempts") - if err != nil { - panic(err) - } + _, _ = db.Exec("DROP TABLE login_attempts") } diff --git a/honeypot/ssh/store/store_test.go b/honeypot/ssh/store/store_test.go index eb98697..e990c2f 100644 --- a/honeypot/ssh/store/store_test.go +++ b/honeypot/ssh/store/store_test.go @@ -188,6 +188,52 @@ func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) { }) } +func benchmarkLoginAttemptStore(setupFunc func() store.LoginAttemptStore, b *testing.B) { + b.Run("BenchmarkAdd", func(b *testing.B) { + s := setupFunc() + for i := 0; i < b.N; i++ { + attempt := randomAttempts(1) + err := s.AddAttempt(attempt[0]) + if err != nil { + b.Fatalf("Error adding attempt: %s", err) + } + } + }) + b.Run("BenchmarkAdd10k", func(b *testing.B) { + attempts := randomAttempts(10_000) + for i := 0; i < b.N; i++ { + b.StopTimer() + s := setupFunc() + b.StartTimer() + for _, attempt := range attempts { + err := s.AddAttempt(attempt) + if err != nil { + b.Fatalf("Error adding attempt: %s", err) + } + } + } + }) + b.Run("BenchmarkAll10k", func(b *testing.B) { + s := setupFunc() + attempts := randomAttempts(10_000) + for _, attempt := range attempts { + err := s.AddAttempt(attempt) + if err != nil { + b.Fatalf("Error adding attempt: %s", err) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + all, err := s.All() + if err != nil { + b.Fatalf("Error fetchin all: %s", err) + } + _ = len(all) + } + }) +} + func randomAttempts(count int) []*models.LoginAttempt { var attempts []*models.LoginAttempt for i := 0; i < count; i++ { diff --git a/version.go b/version.go index f7a3ccf..cd39919 100644 --- a/version.go +++ b/version.go @@ -5,7 +5,7 @@ import ( "runtime" ) -var Version = "v0.1.17" +var Version = "v0.1.18" var Build string func FullVersion() string {