Cache total stats in memory
This commit is contained in:
parent
492c88c5b2
commit
bd6eafd9f7
@ -60,9 +60,11 @@ func ActionServe(c *cli.Context) error {
|
|||||||
var s store.LoginAttemptStore
|
var s store.LoginAttemptStore
|
||||||
switch cfg.Store.Type {
|
switch cfg.Store.Type {
|
||||||
case "MEMORY", "memory":
|
case "MEMORY", "memory":
|
||||||
loggers.rootLogger.Debugw("Initializing store", "store_type", "memory")
|
loggers.rootLogger.Infow("Initialized store", "store_type", "memory")
|
||||||
s = &store.MemoryStore{}
|
s = &store.MemoryStore{}
|
||||||
case "POSTGRES", "postgres":
|
case "POSTGRES", "postgres":
|
||||||
|
pgStartTime := time.Now()
|
||||||
|
loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres")
|
||||||
pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN)
|
pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -70,12 +72,14 @@ func ActionServe(c *cli.Context) error {
|
|||||||
if err := pgStore.InitDB(); err != nil {
|
if err := pgStore.InitDB(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
loggers.rootLogger.Infow("Initialized store", "store_type", "postgres", "init_time", time.Since(pgStartTime))
|
||||||
if cfg.Store.EnableCache {
|
if cfg.Store.EnableCache {
|
||||||
loggers.rootLogger.Debugw("Initializing store", "store_type", "cache-postgres")
|
loggers.rootLogger.Debugw("Initializing store", "store_type", "cache-postgres")
|
||||||
|
startTime := time.Now()
|
||||||
cachingStore := store.NewCachingStore(pgStore)
|
cachingStore := store.NewCachingStore(pgStore)
|
||||||
s = cachingStore
|
s = cachingStore
|
||||||
|
loggers.rootLogger.Infow("Initialized store", "store_type", "cache-postgres", "init_time", time.Since(startTime))
|
||||||
} else {
|
} else {
|
||||||
loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres")
|
|
||||||
s = pgStore
|
s = pgStore
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@ -92,7 +96,7 @@ func ActionServe(c *cli.Context) error {
|
|||||||
defer serversCancel()
|
defer serversCancel()
|
||||||
|
|
||||||
// Setup metrics collection
|
// Setup metrics collection
|
||||||
s = store.NewMetricsCollectingStore(rootCtx, s)
|
s = store.NewMetricsCollectingStore(s)
|
||||||
|
|
||||||
// Setup honeypot
|
// Setup honeypot
|
||||||
hs, err := ssh.NewHoneypotServer(cfg.Honeypot, s)
|
hs, err := ssh.NewHoneypotServer(cfg.Honeypot, s)
|
||||||
|
@ -9,15 +9,40 @@ type CachingStore struct {
|
|||||||
usernameQueryCache map[string][]models.LoginAttempt
|
usernameQueryCache map[string][]models.LoginAttempt
|
||||||
passwordQueryCache map[string][]models.LoginAttempt
|
passwordQueryCache map[string][]models.LoginAttempt
|
||||||
ipQueryCache map[string][]models.LoginAttempt
|
ipQueryCache map[string][]models.LoginAttempt
|
||||||
|
uniqueUsernames map[string]struct{}
|
||||||
|
uniquePasswords map[string]struct{}
|
||||||
|
uniqueIPs map[string]struct{}
|
||||||
|
uniqueCountries map[string]struct{}
|
||||||
|
totalLoginsCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCachingStore(backend LoginAttemptStore) *CachingStore {
|
func NewCachingStore(backend LoginAttemptStore) *CachingStore {
|
||||||
return &CachingStore{
|
cs := &CachingStore{
|
||||||
backend: backend,
|
backend: backend,
|
||||||
usernameQueryCache: make(map[string][]models.LoginAttempt),
|
usernameQueryCache: make(map[string][]models.LoginAttempt),
|
||||||
passwordQueryCache: make(map[string][]models.LoginAttempt),
|
passwordQueryCache: make(map[string][]models.LoginAttempt),
|
||||||
ipQueryCache: make(map[string][]models.LoginAttempt),
|
ipQueryCache: make(map[string][]models.LoginAttempt),
|
||||||
|
uniqueUsernames: make(map[string]struct{}),
|
||||||
|
uniquePasswords: make(map[string]struct{}),
|
||||||
|
uniqueIPs: make(map[string]struct{}),
|
||||||
|
uniqueCountries: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
all, err := backend.All()
|
||||||
|
if err != nil {
|
||||||
|
//TODO: Handle better maybe?
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cs.totalLoginsCount = len(all)
|
||||||
|
for _, attempt := range all {
|
||||||
|
cs.uniqueUsernames[attempt.Username] = struct{}{}
|
||||||
|
cs.uniquePasswords[attempt.Password] = struct{}{}
|
||||||
|
cs.uniqueIPs[attempt.RemoteIP.String()] = struct{}{}
|
||||||
|
cs.uniqueCountries[attempt.Country] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error {
|
func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error {
|
||||||
@ -25,6 +50,11 @@ func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error {
|
|||||||
delete(s.ipQueryCache, l.RemoteIP.String())
|
delete(s.ipQueryCache, l.RemoteIP.String())
|
||||||
delete(s.passwordQueryCache, l.Password)
|
delete(s.passwordQueryCache, l.Password)
|
||||||
delete(s.usernameQueryCache, l.Username)
|
delete(s.usernameQueryCache, l.Username)
|
||||||
|
s.totalLoginsCount++
|
||||||
|
s.uniqueUsernames[l.Username] = struct{}{}
|
||||||
|
s.uniquePasswords[l.Password] = struct{}{}
|
||||||
|
s.uniqueIPs[l.RemoteIP.String()] = struct{}{}
|
||||||
|
s.uniqueCountries[l.Country] = struct{}{}
|
||||||
return s.backend.AddAttempt(l)
|
return s.backend.AddAttempt(l)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -35,19 +65,14 @@ func (s *CachingStore) All() ([]models.LoginAttempt, error) {
|
|||||||
func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
|
func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
|
||||||
// Only cache totals for now, as they are the most queried
|
// Only cache totals for now, as they are the most queried
|
||||||
if statType == LoginStatsTotals {
|
if statType == LoginStatsTotals {
|
||||||
if s.totalStatsCache != nil {
|
return []StatsResult{
|
||||||
return s.totalStatsCache, nil
|
{Name: "UniquePasswords", Count: len(s.uniquePasswords)},
|
||||||
|
{Name: "UniqueUsernames", Count: len(s.uniqueUsernames)},
|
||||||
|
{Name: "UniqueIPs", Count: len(s.uniqueIPs)},
|
||||||
|
{Name: "UniqueCountries", Count: len(s.uniqueCountries)},
|
||||||
|
{Name: "TotalLoginAttempts", Count: s.totalLoginsCount},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
stats, err := s.backend.Stats(statType, limit)
|
|
||||||
if err != nil {
|
|
||||||
return stats, err
|
|
||||||
}
|
|
||||||
s.totalStatsCache = stats
|
|
||||||
|
|
||||||
return stats, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Maybe cache the default limits
|
|
||||||
|
|
||||||
return s.backend.Stats(statType, limit)
|
return s.backend.Stats(statType, limit)
|
||||||
}
|
}
|
||||||
|
@ -11,3 +11,59 @@ func TestCacheStore(t *testing.T) {
|
|||||||
s := store.NewCachingStore(backend)
|
s := store.NewCachingStore(backend)
|
||||||
testLoginAttemptStore(s, t)
|
testLoginAttemptStore(s, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCacheTotalStats(t *testing.T) {
|
||||||
|
backend := &store.MemoryStore{}
|
||||||
|
// Add initial attempts, to ensure that the cache is correcly initialized with existing attempts
|
||||||
|
attempts := randomAttempts(50000)
|
||||||
|
for _, attempt := range attempts {
|
||||||
|
err := backend.AddAttempt(attempt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error adding attempts: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s := store.NewCachingStore(backend)
|
||||||
|
cacheTotals, err := s.Stats(store.LoginStatsTotals, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error getting cached stats: %s", err)
|
||||||
|
}
|
||||||
|
backendTotals, err := backend.Stats(store.LoginStatsTotals, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error getting cached stats: %s", err)
|
||||||
|
}
|
||||||
|
for i := range cacheTotals {
|
||||||
|
if cacheTotals[i].Count != backendTotals[i].Count || cacheTotals[i].Name != backendTotals[i].Name {
|
||||||
|
t.Fatalf("Mismatched totals: Cache: %+v Backend: %+v", cacheTotals[i], backendTotals[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the same attempts again, to ensure that duplicates are handled correctly
|
||||||
|
for _, attempt := range attempts {
|
||||||
|
err := s.AddAttempt(attempt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error adding attempts: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some new attempts
|
||||||
|
attempts = randomAttempts(10000)
|
||||||
|
for _, attempt := range attempts {
|
||||||
|
err := s.AddAttempt(attempt)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error adding attempts: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cacheTotals, err = s.Stats(store.LoginStatsTotals, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error getting cached stats: %s", err)
|
||||||
|
}
|
||||||
|
backendTotals, err = backend.Stats(store.LoginStatsTotals, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error getting cached stats: %s", err)
|
||||||
|
}
|
||||||
|
for i := range cacheTotals {
|
||||||
|
if cacheTotals[i].Count != backendTotals[i].Count || cacheTotals[i].Name != backendTotals[i].Name {
|
||||||
|
t.Fatalf("Mismatched totals: Cache: %+v Backend: %+v", cacheTotals[i], backendTotals[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -15,3 +15,10 @@ func TestMemoryStoreWithCache(t *testing.T) {
|
|||||||
s := store.NewCachingStore(backend)
|
s := store.NewCachingStore(backend)
|
||||||
testLoginAttemptStore(s, t)
|
testLoginAttemptStore(s, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkMemoryStore(b *testing.B) {
|
||||||
|
setupFunc := func() store.LoginAttemptStore {
|
||||||
|
return &store.MemoryStore{}
|
||||||
|
}
|
||||||
|
benchmarkLoginAttemptStore(setupFunc, b)
|
||||||
|
}
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
package store
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.uio.no/torjus/apiary/models"
|
"github.uio.no/torjus/apiary/models"
|
||||||
)
|
)
|
||||||
@ -17,7 +14,7 @@ type MetricsCollectingStore struct {
|
|||||||
totalAttemptsCount prometheus.Gauge
|
totalAttemptsCount prometheus.Gauge
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMetricsCollectingStore(ctx context.Context, store LoginAttemptStore) *MetricsCollectingStore {
|
func NewMetricsCollectingStore(store LoginAttemptStore) *MetricsCollectingStore {
|
||||||
mcs := &MetricsCollectingStore{store: store}
|
mcs := &MetricsCollectingStore{store: store}
|
||||||
|
|
||||||
mcs.attemptsCounter = prometheus.NewCounterVec(
|
mcs.attemptsCounter = prometheus.NewCounterVec(
|
||||||
@ -62,17 +59,6 @@ func NewMetricsCollectingStore(ctx context.Context, store LoginAttemptStore) *Me
|
|||||||
prometheus.MustRegister(mcs.uniqueIPsCount)
|
prometheus.MustRegister(mcs.uniqueIPsCount)
|
||||||
prometheus.MustRegister(mcs.totalAttemptsCount)
|
prometheus.MustRegister(mcs.totalAttemptsCount)
|
||||||
|
|
||||||
// Kinda jank, we just fetch the stats every 10seconds, but it should be cached most of the time.
|
|
||||||
go func(ctx context.Context) {
|
|
||||||
ticker := time.NewTicker(10 * time.Second)
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
mcs.Stats(LoginStatsTotals, 0)
|
|
||||||
}
|
|
||||||
}(ctx)
|
|
||||||
|
|
||||||
return mcs
|
return mcs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,6 +47,26 @@ func TestPostgresStoreWithCache(t *testing.T) {
|
|||||||
|
|
||||||
testLoginAttemptStore(s, t)
|
testLoginAttemptStore(s, t)
|
||||||
}
|
}
|
||||||
|
func BenchmarkPostgresStore(b *testing.B) {
|
||||||
|
var dsn string
|
||||||
|
var found bool
|
||||||
|
dsn, found = os.LookupEnv("APIARY_TEST_POSTGRES_DSN")
|
||||||
|
if !found {
|
||||||
|
b.Skipf("APIARY_TEST_POSTGRES_DSN not set. Skipping.")
|
||||||
|
}
|
||||||
|
dropPGDatabase(dsn)
|
||||||
|
setupFunc := func() store.LoginAttemptStore {
|
||||||
|
dropPGDatabase(dsn)
|
||||||
|
pgs, err := store.NewPostgresStore(dsn)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Error getting store: %s", err)
|
||||||
|
}
|
||||||
|
pgs.InitDB()
|
||||||
|
return pgs
|
||||||
|
}
|
||||||
|
benchmarkLoginAttemptStore(setupFunc, b)
|
||||||
|
dropPGDatabase(dsn)
|
||||||
|
}
|
||||||
|
|
||||||
func dropPGDatabase(dsn string) {
|
func dropPGDatabase(dsn string) {
|
||||||
db, err := sql.Open("pgx", dsn)
|
db, err := sql.Open("pgx", dsn)
|
||||||
@ -54,8 +74,5 @@ func dropPGDatabase(dsn string) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = db.Exec("DROP TABLE login_attempts")
|
_, _ = db.Exec("DROP TABLE login_attempts")
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -188,6 +188,52 @@ func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func benchmarkLoginAttemptStore(setupFunc func() store.LoginAttemptStore, b *testing.B) {
|
||||||
|
b.Run("BenchmarkAdd", func(b *testing.B) {
|
||||||
|
s := setupFunc()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
attempt := randomAttempts(1)
|
||||||
|
err := s.AddAttempt(attempt[0])
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Error adding attempt: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
b.Run("BenchmarkAdd10k", func(b *testing.B) {
|
||||||
|
attempts := randomAttempts(10_000)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
b.StopTimer()
|
||||||
|
s := setupFunc()
|
||||||
|
b.StartTimer()
|
||||||
|
for _, attempt := range attempts {
|
||||||
|
err := s.AddAttempt(attempt)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Error adding attempt: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
b.Run("BenchmarkAll10k", func(b *testing.B) {
|
||||||
|
s := setupFunc()
|
||||||
|
attempts := randomAttempts(10_000)
|
||||||
|
for _, attempt := range attempts {
|
||||||
|
err := s.AddAttempt(attempt)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Error adding attempt: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
all, err := s.All()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Error fetchin all: %s", err)
|
||||||
|
}
|
||||||
|
_ = len(all)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func randomAttempts(count int) []*models.LoginAttempt {
|
func randomAttempts(count int) []*models.LoginAttempt {
|
||||||
var attempts []*models.LoginAttempt
|
var attempts []*models.LoginAttempt
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Version = "v0.1.17"
|
var Version = "v0.1.18"
|
||||||
var Build string
|
var Build string
|
||||||
|
|
||||||
func FullVersion() string {
|
func FullVersion() string {
|
||||||
|
Loading…
Reference in New Issue
Block a user