Cache total stats in memory

Torjus Håkestad 2021-11-03 21:24:23 +01:00
parent 492c88c5b2
commit bd6eafd9f7
8 changed files with 177 additions and 36 deletions

View File

@@ -60,9 +60,11 @@ func ActionServe(c *cli.Context) error {
     var s store.LoginAttemptStore
     switch cfg.Store.Type {
     case "MEMORY", "memory":
-        loggers.rootLogger.Debugw("Initializing store", "store_type", "memory")
+        loggers.rootLogger.Infow("Initialized store", "store_type", "memory")
         s = &store.MemoryStore{}
     case "POSTGRES", "postgres":
+        pgStartTime := time.Now()
+        loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres")
         pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN)
         if err != nil {
             return err
@@ -70,12 +72,14 @@ func ActionServe(c *cli.Context) error {
         if err := pgStore.InitDB(); err != nil {
             return err
         }
+        loggers.rootLogger.Infow("Initialized store", "store_type", "postgres", "init_time", time.Since(pgStartTime))
         if cfg.Store.EnableCache {
             loggers.rootLogger.Debugw("Initializing store", "store_type", "cache-postgres")
+            startTime := time.Now()
             cachingStore := store.NewCachingStore(pgStore)
             s = cachingStore
+            loggers.rootLogger.Infow("Initialized store", "store_type", "cache-postgres", "init_time", time.Since(startTime))
         } else {
-            loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres")
             s = pgStore
         }
     default:
@@ -92,7 +96,7 @@ func ActionServe(c *cli.Context) error {
     defer serversCancel()

     // Setup metrics collection
-    s = store.NewMetricsCollectingStore(rootCtx, s)
+    s = store.NewMetricsCollectingStore(s)

     // Setup honeypot
     hs, err := ssh.NewHoneypotServer(cfg.Honeypot, s)

View File

@@ -9,15 +9,40 @@ type CachingStore struct {
     usernameQueryCache map[string][]models.LoginAttempt
     passwordQueryCache map[string][]models.LoginAttempt
     ipQueryCache       map[string][]models.LoginAttempt
+    uniqueUsernames    map[string]struct{}
+    uniquePasswords    map[string]struct{}
+    uniqueIPs          map[string]struct{}
+    uniqueCountries    map[string]struct{}
+    totalLoginsCount   int
 }

 func NewCachingStore(backend LoginAttemptStore) *CachingStore {
-    return &CachingStore{
+    cs := &CachingStore{
         backend:            backend,
         usernameQueryCache: make(map[string][]models.LoginAttempt),
         passwordQueryCache: make(map[string][]models.LoginAttempt),
         ipQueryCache:       make(map[string][]models.LoginAttempt),
+        uniqueUsernames:    make(map[string]struct{}),
+        uniquePasswords:    make(map[string]struct{}),
+        uniqueIPs:          make(map[string]struct{}),
+        uniqueCountries:    make(map[string]struct{}),
     }
+
+    all, err := backend.All()
+    if err != nil {
+        // TODO: Handle better maybe?
+        panic(err)
+    }
+    cs.totalLoginsCount = len(all)
+    for _, attempt := range all {
+        cs.uniqueUsernames[attempt.Username] = struct{}{}
+        cs.uniquePasswords[attempt.Password] = struct{}{}
+        cs.uniqueIPs[attempt.RemoteIP.String()] = struct{}{}
+        cs.uniqueCountries[attempt.Country] = struct{}{}
+    }
+    return cs
 }

 func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error {
@@ -25,6 +50,11 @@ func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error {
     delete(s.ipQueryCache, l.RemoteIP.String())
     delete(s.passwordQueryCache, l.Password)
     delete(s.usernameQueryCache, l.Username)
+    s.totalLoginsCount++
+    s.uniqueUsernames[l.Username] = struct{}{}
+    s.uniquePasswords[l.Password] = struct{}{}
+    s.uniqueIPs[l.RemoteIP.String()] = struct{}{}
+    s.uniqueCountries[l.Country] = struct{}{}
     return s.backend.AddAttempt(l)
 }
@@ -35,20 +65,15 @@ func (s *CachingStore) All() ([]models.LoginAttempt, error) {

 func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
     // Only cache totals for now, as they are the most queried
     if statType == LoginStatsTotals {
-        if s.totalStatsCache != nil {
-            return s.totalStatsCache, nil
-        }
-        stats, err := s.backend.Stats(statType, limit)
-        if err != nil {
-            return stats, err
-        }
-        s.totalStatsCache = stats
-        return stats, err
+        return []StatsResult{
+            {Name: "UniquePasswords", Count: len(s.uniquePasswords)},
+            {Name: "UniqueUsernames", Count: len(s.uniqueUsernames)},
+            {Name: "UniqueIPs", Count: len(s.uniqueIPs)},
+            {Name: "UniqueCountries", Count: len(s.uniqueCountries)},
+            {Name: "TotalLoginAttempts", Count: s.totalLoginsCount},
+        }, nil
     }
+    // TODO: Maybe cache the default limits
     return s.backend.Stats(statType, limit)
 }
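
Note: with this change, NewCachingStore pays a one-time cost to scan backend.All() and seed the unique-value sets, after which totals queries are answered from memory without a backend round trip. A minimal usage sketch (MemoryStore, NewCachingStore, Stats and LoginStatsTotals are this repository's API; the wiring below is illustrative, not part of the commit):

    package main

    import (
        "fmt"

        "github.uio.no/torjus/apiary/store"
    )

    func main() {
        backend := &store.MemoryStore{}
        // NewCachingStore scans backend.All() once to seed the unique sets;
        // note that it panics if that initial scan fails.
        s := store.NewCachingStore(backend)

        // Totals now come straight from the in-memory sets.
        totals, err := s.Stats(store.LoginStatsTotals, 0)
        if err != nil {
            panic(err)
        }
        for _, r := range totals {
            fmt.Printf("%s: %d\n", r.Name, r.Count)
        }
    }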

View File

@@ -11,3 +11,59 @@ func TestCacheStore(t *testing.T) {
     s := store.NewCachingStore(backend)
     testLoginAttemptStore(s, t)
 }
+
+func TestCacheTotalStats(t *testing.T) {
+    backend := &store.MemoryStore{}
+
+    // Add initial attempts, to ensure that the cache is correctly initialized with existing attempts
+    attempts := randomAttempts(50000)
+    for _, attempt := range attempts {
+        err := backend.AddAttempt(attempt)
+        if err != nil {
+            t.Fatalf("Error adding attempts: %s", err)
+        }
+    }
+
+    s := store.NewCachingStore(backend)
+    cacheTotals, err := s.Stats(store.LoginStatsTotals, 0)
+    if err != nil {
+        t.Fatalf("Error getting cached stats: %s", err)
+    }
+    backendTotals, err := backend.Stats(store.LoginStatsTotals, 0)
+    if err != nil {
+        t.Fatalf("Error getting backend stats: %s", err)
+    }
+    for i := range cacheTotals {
+        if cacheTotals[i].Count != backendTotals[i].Count || cacheTotals[i].Name != backendTotals[i].Name {
+            t.Fatalf("Mismatched totals: Cache: %+v Backend: %+v", cacheTotals[i], backendTotals[i])
+        }
+    }
+
+    // Add the same attempts again, to ensure that duplicates are handled correctly
+    for _, attempt := range attempts {
+        err := s.AddAttempt(attempt)
+        if err != nil {
+            t.Fatalf("Error adding attempts: %s", err)
+        }
+    }
+
+    // Add some new attempts
+    attempts = randomAttempts(10000)
+    for _, attempt := range attempts {
+        err := s.AddAttempt(attempt)
+        if err != nil {
+            t.Fatalf("Error adding attempts: %s", err)
+        }
+    }
+
+    cacheTotals, err = s.Stats(store.LoginStatsTotals, 0)
+    if err != nil {
+        t.Fatalf("Error getting cached stats: %s", err)
+    }
+    backendTotals, err = backend.Stats(store.LoginStatsTotals, 0)
+    if err != nil {
+        t.Fatalf("Error getting backend stats: %s", err)
+    }
+    for i := range cacheTotals {
+        if cacheTotals[i].Count != backendTotals[i].Count || cacheTotals[i].Name != backendTotals[i].Name {
+            t.Fatalf("Mismatched totals: Cache: %+v Backend: %+v", cacheTotals[i], backendTotals[i])
+        }
+    }
+}
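
The duplicate pass above relies on the unique counts being backed by map[string]struct{} sets, so re-adding an existing value is a no-op for len(), while TotalLoginAttempts counts every attempt, duplicates included, on both the cache and the backend. A tiny standalone illustration of the set semantics (not repository code):

    set := make(map[string]struct{})
    set["root"] = struct{}{}
    set["root"] = struct{}{} // re-inserting an existing key does not grow the set
    fmt.Println(len(set))    // prints 1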

View File

@@ -15,3 +15,10 @@ func TestMemoryStoreWithCache(t *testing.T) {
     s := store.NewCachingStore(backend)
     testLoginAttemptStore(s, t)
 }
+
+func BenchmarkMemoryStore(b *testing.B) {
+    setupFunc := func() store.LoginAttemptStore {
+        return &store.MemoryStore{}
+    }
+    benchmarkLoginAttemptStore(setupFunc, b)
+}

View File

@@ -1,9 +1,6 @@
 package store

 import (
-    "context"
-    "time"
-
     "github.com/prometheus/client_golang/prometheus"
     "github.uio.no/torjus/apiary/models"
 )
@@ -17,7 +14,7 @@ type MetricsCollectingStore struct {
     totalAttemptsCount prometheus.Gauge
 }

-func NewMetricsCollectingStore(ctx context.Context, store LoginAttemptStore) *MetricsCollectingStore {
+func NewMetricsCollectingStore(store LoginAttemptStore) *MetricsCollectingStore {
     mcs := &MetricsCollectingStore{store: store}

     mcs.attemptsCounter = prometheus.NewCounterVec(
@@ -62,17 +59,6 @@ func NewMetricsCollectingStore(ctx context.Context, store LoginAttemptStore) *MetricsCollectingStore {
     prometheus.MustRegister(mcs.uniqueIPsCount)
     prometheus.MustRegister(mcs.totalAttemptsCount)

-    // Kinda jank, we just fetch the stats every 10seconds, but it should be cached most of the time.
-    go func(ctx context.Context) {
-        ticker := time.NewTicker(10 * time.Second)
-        select {
-        case <-ctx.Done():
-            return
-        case <-ticker.C:
-            mcs.Stats(LoginStatsTotals, 0)
-        }
-    }(ctx)
-
     return mcs
 }
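
Worth noting about the goroutine deleted above: its select was not wrapped in a for loop, so the ticker fired at most once before the goroutine exited, and the periodic stats prefetch never actually repeated. With totals now computed synchronously by CachingStore, the prefetch and the context/time imports can go away entirely. Had the prefetch been kept, a correct periodic loop would look roughly like this (a sketch, not code from this commit):

    go func(ctx context.Context) {
        ticker := time.NewTicker(10 * time.Second)
        defer ticker.Stop() // release the ticker when the loop exits
        for {
            select {
            case <-ctx.Done():
                return
            case <-ticker.C:
                mcs.Stats(LoginStatsTotals, 0) // refresh the cached totals
            }
        }
    }(ctx)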

View File

@@ -47,6 +47,26 @@ func TestPostgresStoreWithCache(t *testing.T) {
     testLoginAttemptStore(s, t)
 }

+func BenchmarkPostgresStore(b *testing.B) {
+    var dsn string
+    var found bool
+    dsn, found = os.LookupEnv("APIARY_TEST_POSTGRES_DSN")
+    if !found {
+        b.Skipf("APIARY_TEST_POSTGRES_DSN not set. Skipping.")
+    }
+    dropPGDatabase(dsn)
+    setupFunc := func() store.LoginAttemptStore {
+        dropPGDatabase(dsn)
+        pgs, err := store.NewPostgresStore(dsn)
+        if err != nil {
+            b.Fatalf("Error getting store: %s", err)
+        }
+        pgs.InitDB()
+        return pgs
+    }
+    benchmarkLoginAttemptStore(setupFunc, b)
+    dropPGDatabase(dsn)
+}
+
 func dropPGDatabase(dsn string) {
     db, err := sql.Open("pgx", dsn)
@@ -54,8 +74,5 @@ func dropPGDatabase(dsn string) {
     if err != nil {
         panic(err)
     }
-    _, err = db.Exec("DROP TABLE login_attempts")
-    if err != nil {
-        panic(err)
-    }
+    _, _ = db.Exec("DROP TABLE login_attempts")
 }
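
A note on the dropPGDatabase change above: the DROP TABLE error is now deliberately discarded, so the benchmark can call the helper for cleanup even when login_attempts does not exist yet (previously that case panicked). An alternative that keeps error checking is DROP TABLE IF EXISTS, which succeeds either way; a sketch under that assumption, not what this commit does:

    func dropPGDatabase(dsn string) {
        db, err := sql.Open("pgx", dsn)
        if err != nil {
            panic(err)
        }
        defer db.Close()
        // IF EXISTS makes the cleanup idempotent, so no error needs ignoring.
        if _, err := db.Exec("DROP TABLE IF EXISTS login_attempts"); err != nil {
            panic(err)
        }
    }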

View File

@@ -188,6 +188,52 @@ func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) {
     })
 }
+
+func benchmarkLoginAttemptStore(setupFunc func() store.LoginAttemptStore, b *testing.B) {
+    b.Run("BenchmarkAdd", func(b *testing.B) {
+        s := setupFunc()
+        for i := 0; i < b.N; i++ {
+            attempt := randomAttempts(1)
+            err := s.AddAttempt(attempt[0])
+            if err != nil {
+                b.Fatalf("Error adding attempt: %s", err)
+            }
+        }
+    })
+
+    b.Run("BenchmarkAdd10k", func(b *testing.B) {
+        attempts := randomAttempts(10_000)
+        for i := 0; i < b.N; i++ {
+            b.StopTimer()
+            s := setupFunc()
+            b.StartTimer()
+            for _, attempt := range attempts {
+                err := s.AddAttempt(attempt)
+                if err != nil {
+                    b.Fatalf("Error adding attempt: %s", err)
+                }
+            }
+        }
+    })
+
+    b.Run("BenchmarkAll10k", func(b *testing.B) {
+        s := setupFunc()
+        attempts := randomAttempts(10_000)
+        for _, attempt := range attempts {
+            err := s.AddAttempt(attempt)
+            if err != nil {
+                b.Fatalf("Error adding attempt: %s", err)
+            }
+        }
+        b.ResetTimer()
+        for i := 0; i < b.N; i++ {
+            all, err := s.All()
+            if err != nil {
+                b.Fatalf("Error fetching all: %s", err)
+            }
+            _ = len(all)
+        }
+    })
+}

 func randomAttempts(count int) []*models.LoginAttempt {
     var attempts []*models.LoginAttempt
     for i := 0; i < count; i++ {

View File

@@ -5,7 +5,7 @@ import (
     "runtime"
 )

-var Version = "v0.1.17"
+var Version = "v0.1.18"
 var Build string

 func FullVersion() string {