Add some caching

This commit is contained in:
Torjus Håkestad 2021-09-17 02:01:43 +02:00
parent 36a487cb0c
commit a13e9c0eba
10 changed files with 224 additions and 123 deletions

View File

@ -4,6 +4,9 @@
# Must be "memory" or "postgres" # Must be "memory" or "postgres"
# Default: "memory" # Default: "memory"
Type = "memory" Type = "memory"
# Enable caching
# Default: false
EnableCache = false
[Store.Postgres] [Store.Postgres]
# Connection string for postgres # Connection string for postgres

View File

@ -53,11 +53,13 @@ func ActionServe(c *cli.Context) error {
// Setup logging // Setup logging
loggers := setupLoggers(cfg) loggers := setupLoggers(cfg)
loggers.rootLogger.Infow("Startin apiary", "version", apiary.FullVersion())
// Setup store // Setup store
var s store.LoginAttemptStore var s store.LoginAttemptStore
switch cfg.Store.Type { switch cfg.Store.Type {
case "MEMORY", "memory": case "MEMORY", "memory":
loggers.rootLogger.Debugw("Initializing store", "store_type", "memory")
s = &store.MemoryStore{} s = &store.MemoryStore{}
case "POSTGRES", "postgres": case "POSTGRES", "postgres":
pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN) pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN)
@ -67,7 +69,14 @@ func ActionServe(c *cli.Context) error {
if err := pgStore.InitDB(); err != nil { if err := pgStore.InitDB(); err != nil {
return err return err
} }
if cfg.Store.EnableCache {
loggers.rootLogger.Debugw("Initializing store", "store_type", "cache-postgres")
cachingStore := store.NewCachingStore(pgStore)
s = cachingStore
} else {
loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres")
s = pgStore s = pgStore
}
default: default:
return fmt.Errorf("Invalid store configured") return fmt.Errorf("Invalid store configured")
} }

View File

@ -16,6 +16,7 @@ type Config struct {
} }
type StoreConfig struct { type StoreConfig struct {
Type string `toml:"Type"` Type string `toml:"Type"`
EnableCache bool `toml:"EnableCache"`
Postgres PostgresStoreConfig `toml:"Postgres"` Postgres PostgresStoreConfig `toml:"Postgres"`
} }

85
honeypot/store/cache.go Normal file
View File

@ -0,0 +1,85 @@
package store
import "github.uio.no/torjus/apiary/models"
type CachingStore struct {
backend LoginAttemptStore
totalStatsCache []StatsResult
usernameQueryCache map[string][]models.LoginAttempt
passwordQueryCache map[string][]models.LoginAttempt
ipQueryCache map[string][]models.LoginAttempt
}
func NewCachingStore(backend LoginAttemptStore) *CachingStore {
return &CachingStore{
backend: backend,
usernameQueryCache: make(map[string][]models.LoginAttempt),
passwordQueryCache: make(map[string][]models.LoginAttempt),
ipQueryCache: make(map[string][]models.LoginAttempt),
}
}
func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error {
s.totalStatsCache = nil
delete(s.ipQueryCache, l.RemoteIP.String())
delete(s.passwordQueryCache, l.Password)
delete(s.usernameQueryCache, l.Username)
return s.backend.AddAttempt(l)
}
func (s *CachingStore) All() ([]models.LoginAttempt, error) {
return s.backend.All()
}
func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
// Only cache totals for now, as they are the most queried
if statType == LoginStatsTotals {
if s.totalStatsCache != nil {
return s.totalStatsCache, nil
}
stats, err := s.backend.Stats(statType, limit)
if err != nil {
return stats, err
}
s.totalStatsCache = stats
return stats, err
}
// TODO: Maybe cache the default limits
return s.backend.Stats(statType, limit)
}
func (s *CachingStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) {
switch query.QueryType {
case AttemptQueryTypeIP:
if attempts, ok := s.ipQueryCache[query.Query]; ok {
return attempts, nil
}
case AttemptQueryTypePassword:
if attempts, ok := s.passwordQueryCache[query.Query]; ok {
return attempts, nil
}
case AttemptQueryTypeUsername:
if attempts, ok := s.usernameQueryCache[query.Query]; ok {
return attempts, nil
}
}
attempts, err := s.backend.Query(query)
if err != nil {
return attempts, err
}
switch query.QueryType {
case AttemptQueryTypeIP:
s.ipQueryCache[query.Query] = attempts
case AttemptQueryTypeUsername:
s.ipQueryCache[query.Query] = attempts
case AttemptQueryTypePassword:
s.ipQueryCache[query.Query] = attempts
}
return attempts, err
}

View File

@ -0,0 +1,13 @@
package store_test
import (
"testing"
"github.uio.no/torjus/apiary/honeypot/store"
)
func TestCacheStore(t *testing.T) {
backend := &store.MemoryStore{}
s := store.NewCachingStore(backend)
testLoginAttemptStore(s, t)
}

View File

@ -57,7 +57,7 @@ func (ms *MemoryStore) Stats(statType LoginStats, limit int) ([]StatsResult, err
case LoginStatsUsername: case LoginStatsUsername:
counts[a.Username]++ counts[a.Username]++
default: default:
return nil, fmt.Errorf("Invalid stat type") return nil, fmt.Errorf("invalid stat type")
} }
} }

View File

@ -1,119 +1,12 @@
package store package store_test
/* import (
func TestStatItems(t *testing.T) { "testing"
var tc = []struct {
Input StatItems
ExpectedOutput StatItems
}{
{
Input: StatItems{
{Key: "a", Count: 5},
{Key: "b", Count: 100},
{Key: "c", Count: 99},
{Key: "d", Count: 98},
{Key: "f", Count: 18},
},
ExpectedOutput: StatItems{
{Key: "a", Count: 5},
{Key: "f", Count: 18},
{Key: "d", Count: 98},
{Key: "c", Count: 99},
{Key: "b", Count: 100},
},
},
}
for _, testCase := range tc { "github.uio.no/torjus/apiary/honeypot/store"
sort.Sort(testCase.Input) )
for i := range testCase.Input { func TestMemoryStore(t *testing.T) {
if testCase.Input[i] != testCase.ExpectedOutput[i] { s := &store.MemoryStore{}
t.Fatalf("Not sorted correctly") testLoginAttemptStore(s, t)
} }
}
}
}
func TestStats(t *testing.T) {
var exampleAttempts = []models.LoginAttempt{
{Username: "root", Password: "root", Country: "NO"},
{Username: "root", Password: "root", Country: "US"},
{Username: "user", Password: "passWord", Country: "US"},
{Username: "ibm", Password: "ibm", Country: "US"},
{Username: "ubnt", Password: "ubnt", Country: "GB"},
{Username: "ubnt", Password: "ubnt", Country: "FI"},
{Username: "root", Password: "root", Country: "CH"},
{Username: "ubnt", Password: "12345", Country: "DE"},
{Username: "oracle", Password: "oracle", Country: "FI"},
}
var tc = []struct {
Attempts []models.LoginAttempt
StatType LoginStats
Limit int
ExpectedOutput map[string]int
}{
{
Attempts: exampleAttempts,
StatType: LoginStatsPasswords,
Limit: 2,
ExpectedOutput: map[string]int{
"root": 3,
"ubnt": 2,
},
},
{
Attempts: exampleAttempts,
StatType: LoginStatsPasswords,
Limit: 1,
ExpectedOutput: map[string]int{
"root": 3,
},
},
{
Attempts: exampleAttempts,
StatType: LoginStatsCountry,
Limit: 2,
ExpectedOutput: map[string]int{
"US": 3,
"FI": 2,
},
},
{
Attempts: exampleAttempts,
StatType: LoginStatsCountry,
Limit: 0,
ExpectedOutput: map[string]int{
"US": 3,
"FI": 2,
"NO": 1,
"GB": 1,
"CH": 1,
"DE": 1,
},
},
}
for _, c := range tc {
ms := MemoryStore{}
for _, a := range c.Attempts {
if err := ms.AddAttempt(a); err != nil {
t.Fatalf("Unable to add attempt: %s", err)
}
}
stats, err := ms.Stats(c.StatType, c.Limit)
if err != nil {
t.Fatalf("Error getting stats: %s", err)
}
if len(stats) != len(c.ExpectedOutput) {
t.Fatalf("Stats have wrong length")
}
for key := range stats {
if c.ExpectedOutput[key] != stats[key] {
t.Fatalf("Stats does not match expected output")
}
}
}
}
*/

View File

@ -177,12 +177,12 @@ func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error)
FROM login_attempts WHERE username like $1` FROM login_attempts WHERE username like $1`
queryString = fmt.Sprintf("%%%s%%", queryString) queryString = fmt.Sprintf("%%%s%%", queryString)
default: default:
return nil, fmt.Errorf("Invalid query type") return nil, fmt.Errorf("invalid query type")
} }
rows, err := s.db.Query(stmt, queryString) rows, err := s.db.Query(stmt, queryString)
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to query database: %w", err) return nil, fmt.Errorf("unable to query database: %w", err)
} }
defer rows.Close() defer rows.Close()
@ -191,7 +191,7 @@ func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error)
var la models.LoginAttempt var la models.LoginAttempt
var ipString string var ipString string
if err := rows.Scan(&la.ID, &la.Date, &ipString, &la.Username, &la.Password, &la.SSHClientVersion, &la.ConnectionUUID, &la.Country); err != nil { if err := rows.Scan(&la.ID, &la.Date, &ipString, &la.Username, &la.Password, &la.SSHClientVersion, &la.ConnectionUUID, &la.Country); err != nil {
return nil, fmt.Errorf("Unable to unmarshal data from database: %w", err) return nil, fmt.Errorf("unable to unmarshal data from database: %w", err)
} }
la.RemoteIP = net.ParseIP(ipString) la.RemoteIP = net.ParseIP(ipString)
results = append(results, la) results = append(results, la)

View File

@ -0,0 +1,97 @@
package store_test
import (
"math/rand"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.uio.no/torjus/apiary/honeypot/store"
"github.uio.no/torjus/apiary/models"
)
func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) {
t.Run("Simple", func(t *testing.T) {
testAttempts := randomAttempts(10)
for _, attempt := range testAttempts {
if err := s.AddAttempt(attempt); err != nil {
t.Fatalf("Error adding attempt: %s", err)
}
}
all, err := s.All()
if err != nil {
t.Fatalf("Error getting all attempts: %s", err)
}
if len(all) != len(testAttempts) {
t.Errorf("All returned wrong amount. Got %d want %d", len(all), len(testAttempts))
}
stats, err := s.Stats(store.LoginStatsTotals, 1)
if err != nil {
t.Errorf("Stats returned error: %s", err)
}
for _, stat := range stats {
if stat.Name == "TotalLoginAttempts" && stat.Count != len(testAttempts) {
t.Errorf("Stats for total attempts is wrong. Got %d want %d", stat.Count, len(testAttempts))
}
}
})
}
func randomAttempts(count int) []*models.LoginAttempt {
var attempts []*models.LoginAttempt
for i := 0; i < count; i++ {
attempt := &models.LoginAttempt{
Date: time.Now(),
RemoteIP: randomIP(),
Username: randomString(8),
Password: randomString(8),
Country: randomCountry(),
ConnectionUUID: uuid.Must(uuid.NewRandom()),
SSHClientVersion: "SSH TEST LOL",
}
attempts = append(attempts, attempt)
}
return attempts
}
func randomIP() net.IP {
a := byte(rand.Intn(254))
b := byte(rand.Intn(254))
c := byte(rand.Intn(254))
d := byte(rand.Intn(254))
return net.IPv4(a, b, c, d)
}
func randomString(n int) string {
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
b := make([]byte, n)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return string(b)
}
func randomCountry() string {
switch rand.Intn(10) {
case 1:
return "CN"
case 2:
return "US"
case 3:
return "NO"
case 4:
return "RU"
case 5:
return "DE"
case 6:
return "FI"
case 7:
return "BR"
default:
return "SE"
}
}

View File

@ -5,7 +5,7 @@ import (
"runtime" "runtime"
) )
var Version = "v0.1.7" var Version = "v0.1.8"
var Build string var Build string
func FullVersion() string { func FullVersion() string {