Add some caching
This commit is contained in:
parent
36a487cb0c
commit
a13e9c0eba
@ -4,6 +4,9 @@
|
||||
# Must be "memory" or "postgres"
|
||||
# Default: "memory"
|
||||
Type = "memory"
|
||||
# Enable caching
|
||||
# Default: false
|
||||
EnableCache = false
|
||||
|
||||
[Store.Postgres]
|
||||
# Connection string for postgres
|
||||
|
@ -53,11 +53,13 @@ func ActionServe(c *cli.Context) error {
|
||||
|
||||
// Setup logging
|
||||
loggers := setupLoggers(cfg)
|
||||
loggers.rootLogger.Infow("Startin apiary", "version", apiary.FullVersion())
|
||||
|
||||
// Setup store
|
||||
var s store.LoginAttemptStore
|
||||
switch cfg.Store.Type {
|
||||
case "MEMORY", "memory":
|
||||
loggers.rootLogger.Debugw("Initializing store", "store_type", "memory")
|
||||
s = &store.MemoryStore{}
|
||||
case "POSTGRES", "postgres":
|
||||
pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN)
|
||||
@ -67,7 +69,14 @@ func ActionServe(c *cli.Context) error {
|
||||
if err := pgStore.InitDB(); err != nil {
|
||||
return err
|
||||
}
|
||||
s = pgStore
|
||||
if cfg.Store.EnableCache {
|
||||
loggers.rootLogger.Debugw("Initializing store", "store_type", "cache-postgres")
|
||||
cachingStore := store.NewCachingStore(pgStore)
|
||||
s = cachingStore
|
||||
} else {
|
||||
loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres")
|
||||
s = pgStore
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("Invalid store configured")
|
||||
}
|
||||
|
@ -15,8 +15,9 @@ type Config struct {
|
||||
Frontend FrontendConfig `toml:"Frontend"`
|
||||
}
|
||||
type StoreConfig struct {
|
||||
Type string `toml:"Type"`
|
||||
Postgres PostgresStoreConfig `toml:"Postgres"`
|
||||
Type string `toml:"Type"`
|
||||
EnableCache bool `toml:"EnableCache"`
|
||||
Postgres PostgresStoreConfig `toml:"Postgres"`
|
||||
}
|
||||
|
||||
type PostgresStoreConfig struct {
|
||||
|
85
honeypot/store/cache.go
Normal file
85
honeypot/store/cache.go
Normal file
@ -0,0 +1,85 @@
|
||||
package store
|
||||
|
||||
import "github.uio.no/torjus/apiary/models"
|
||||
|
||||
type CachingStore struct {
|
||||
backend LoginAttemptStore
|
||||
|
||||
totalStatsCache []StatsResult
|
||||
usernameQueryCache map[string][]models.LoginAttempt
|
||||
passwordQueryCache map[string][]models.LoginAttempt
|
||||
ipQueryCache map[string][]models.LoginAttempt
|
||||
}
|
||||
|
||||
func NewCachingStore(backend LoginAttemptStore) *CachingStore {
|
||||
return &CachingStore{
|
||||
backend: backend,
|
||||
usernameQueryCache: make(map[string][]models.LoginAttempt),
|
||||
passwordQueryCache: make(map[string][]models.LoginAttempt),
|
||||
ipQueryCache: make(map[string][]models.LoginAttempt),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error {
|
||||
s.totalStatsCache = nil
|
||||
delete(s.ipQueryCache, l.RemoteIP.String())
|
||||
delete(s.passwordQueryCache, l.Password)
|
||||
delete(s.usernameQueryCache, l.Username)
|
||||
return s.backend.AddAttempt(l)
|
||||
}
|
||||
|
||||
func (s *CachingStore) All() ([]models.LoginAttempt, error) {
|
||||
return s.backend.All()
|
||||
}
|
||||
|
||||
func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
|
||||
// Only cache totals for now, as they are the most queried
|
||||
if statType == LoginStatsTotals {
|
||||
if s.totalStatsCache != nil {
|
||||
return s.totalStatsCache, nil
|
||||
}
|
||||
stats, err := s.backend.Stats(statType, limit)
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
s.totalStatsCache = stats
|
||||
|
||||
return stats, err
|
||||
}
|
||||
|
||||
// TODO: Maybe cache the default limits
|
||||
|
||||
return s.backend.Stats(statType, limit)
|
||||
}
|
||||
|
||||
func (s *CachingStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) {
|
||||
switch query.QueryType {
|
||||
case AttemptQueryTypeIP:
|
||||
if attempts, ok := s.ipQueryCache[query.Query]; ok {
|
||||
return attempts, nil
|
||||
}
|
||||
case AttemptQueryTypePassword:
|
||||
if attempts, ok := s.passwordQueryCache[query.Query]; ok {
|
||||
return attempts, nil
|
||||
}
|
||||
case AttemptQueryTypeUsername:
|
||||
if attempts, ok := s.usernameQueryCache[query.Query]; ok {
|
||||
return attempts, nil
|
||||
}
|
||||
}
|
||||
attempts, err := s.backend.Query(query)
|
||||
if err != nil {
|
||||
return attempts, err
|
||||
}
|
||||
|
||||
switch query.QueryType {
|
||||
case AttemptQueryTypeIP:
|
||||
s.ipQueryCache[query.Query] = attempts
|
||||
case AttemptQueryTypeUsername:
|
||||
s.ipQueryCache[query.Query] = attempts
|
||||
case AttemptQueryTypePassword:
|
||||
s.ipQueryCache[query.Query] = attempts
|
||||
}
|
||||
|
||||
return attempts, err
|
||||
}
|
13
honeypot/store/cache_test.go
Normal file
13
honeypot/store/cache_test.go
Normal file
@ -0,0 +1,13 @@
|
||||
package store_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.uio.no/torjus/apiary/honeypot/store"
|
||||
)
|
||||
|
||||
func TestCacheStore(t *testing.T) {
|
||||
backend := &store.MemoryStore{}
|
||||
s := store.NewCachingStore(backend)
|
||||
testLoginAttemptStore(s, t)
|
||||
}
|
@ -57,7 +57,7 @@ func (ms *MemoryStore) Stats(statType LoginStats, limit int) ([]StatsResult, err
|
||||
case LoginStatsUsername:
|
||||
counts[a.Username]++
|
||||
default:
|
||||
return nil, fmt.Errorf("Invalid stat type")
|
||||
return nil, fmt.Errorf("invalid stat type")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,119 +1,12 @@
|
||||
package store
|
||||
package store_test
|
||||
|
||||
/*
|
||||
func TestStatItems(t *testing.T) {
|
||||
var tc = []struct {
|
||||
Input StatItems
|
||||
ExpectedOutput StatItems
|
||||
}{
|
||||
{
|
||||
Input: StatItems{
|
||||
{Key: "a", Count: 5},
|
||||
{Key: "b", Count: 100},
|
||||
{Key: "c", Count: 99},
|
||||
{Key: "d", Count: 98},
|
||||
{Key: "f", Count: 18},
|
||||
},
|
||||
ExpectedOutput: StatItems{
|
||||
{Key: "a", Count: 5},
|
||||
{Key: "f", Count: 18},
|
||||
{Key: "d", Count: 98},
|
||||
{Key: "c", Count: 99},
|
||||
{Key: "b", Count: 100},
|
||||
},
|
||||
},
|
||||
}
|
||||
import (
|
||||
"testing"
|
||||
|
||||
for _, testCase := range tc {
|
||||
sort.Sort(testCase.Input)
|
||||
"github.uio.no/torjus/apiary/honeypot/store"
|
||||
)
|
||||
|
||||
for i := range testCase.Input {
|
||||
if testCase.Input[i] != testCase.ExpectedOutput[i] {
|
||||
t.Fatalf("Not sorted correctly")
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestMemoryStore(t *testing.T) {
|
||||
s := &store.MemoryStore{}
|
||||
testLoginAttemptStore(s, t)
|
||||
}
|
||||
|
||||
func TestStats(t *testing.T) {
|
||||
var exampleAttempts = []models.LoginAttempt{
|
||||
{Username: "root", Password: "root", Country: "NO"},
|
||||
{Username: "root", Password: "root", Country: "US"},
|
||||
{Username: "user", Password: "passWord", Country: "US"},
|
||||
{Username: "ibm", Password: "ibm", Country: "US"},
|
||||
{Username: "ubnt", Password: "ubnt", Country: "GB"},
|
||||
{Username: "ubnt", Password: "ubnt", Country: "FI"},
|
||||
{Username: "root", Password: "root", Country: "CH"},
|
||||
{Username: "ubnt", Password: "12345", Country: "DE"},
|
||||
{Username: "oracle", Password: "oracle", Country: "FI"},
|
||||
}
|
||||
var tc = []struct {
|
||||
Attempts []models.LoginAttempt
|
||||
StatType LoginStats
|
||||
Limit int
|
||||
ExpectedOutput map[string]int
|
||||
}{
|
||||
{
|
||||
Attempts: exampleAttempts,
|
||||
StatType: LoginStatsPasswords,
|
||||
Limit: 2,
|
||||
ExpectedOutput: map[string]int{
|
||||
"root": 3,
|
||||
"ubnt": 2,
|
||||
},
|
||||
},
|
||||
{
|
||||
Attempts: exampleAttempts,
|
||||
StatType: LoginStatsPasswords,
|
||||
Limit: 1,
|
||||
ExpectedOutput: map[string]int{
|
||||
"root": 3,
|
||||
},
|
||||
},
|
||||
{
|
||||
Attempts: exampleAttempts,
|
||||
StatType: LoginStatsCountry,
|
||||
Limit: 2,
|
||||
ExpectedOutput: map[string]int{
|
||||
"US": 3,
|
||||
"FI": 2,
|
||||
},
|
||||
},
|
||||
{
|
||||
Attempts: exampleAttempts,
|
||||
StatType: LoginStatsCountry,
|
||||
Limit: 0,
|
||||
ExpectedOutput: map[string]int{
|
||||
"US": 3,
|
||||
"FI": 2,
|
||||
"NO": 1,
|
||||
"GB": 1,
|
||||
"CH": 1,
|
||||
"DE": 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range tc {
|
||||
ms := MemoryStore{}
|
||||
for _, a := range c.Attempts {
|
||||
if err := ms.AddAttempt(a); err != nil {
|
||||
t.Fatalf("Unable to add attempt: %s", err)
|
||||
}
|
||||
}
|
||||
stats, err := ms.Stats(c.StatType, c.Limit)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting stats: %s", err)
|
||||
}
|
||||
if len(stats) != len(c.ExpectedOutput) {
|
||||
t.Fatalf("Stats have wrong length")
|
||||
}
|
||||
for key := range stats {
|
||||
if c.ExpectedOutput[key] != stats[key] {
|
||||
t.Fatalf("Stats does not match expected output")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
*/
|
||||
|
@ -177,12 +177,12 @@ func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error)
|
||||
FROM login_attempts WHERE username like $1`
|
||||
queryString = fmt.Sprintf("%%%s%%", queryString)
|
||||
default:
|
||||
return nil, fmt.Errorf("Invalid query type")
|
||||
return nil, fmt.Errorf("invalid query type")
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(stmt, queryString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to query database: %w", err)
|
||||
return nil, fmt.Errorf("unable to query database: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
@ -191,7 +191,7 @@ func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error)
|
||||
var la models.LoginAttempt
|
||||
var ipString string
|
||||
if err := rows.Scan(&la.ID, &la.Date, &ipString, &la.Username, &la.Password, &la.SSHClientVersion, &la.ConnectionUUID, &la.Country); err != nil {
|
||||
return nil, fmt.Errorf("Unable to unmarshal data from database: %w", err)
|
||||
return nil, fmt.Errorf("unable to unmarshal data from database: %w", err)
|
||||
}
|
||||
la.RemoteIP = net.ParseIP(ipString)
|
||||
results = append(results, la)
|
||||
|
97
honeypot/store/store_test.go
Normal file
97
honeypot/store/store_test.go
Normal file
@ -0,0 +1,97 @@
|
||||
package store_test
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.uio.no/torjus/apiary/honeypot/store"
|
||||
"github.uio.no/torjus/apiary/models"
|
||||
)
|
||||
|
||||
func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) {
|
||||
t.Run("Simple", func(t *testing.T) {
|
||||
testAttempts := randomAttempts(10)
|
||||
|
||||
for _, attempt := range testAttempts {
|
||||
if err := s.AddAttempt(attempt); err != nil {
|
||||
t.Fatalf("Error adding attempt: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
all, err := s.All()
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting all attempts: %s", err)
|
||||
}
|
||||
if len(all) != len(testAttempts) {
|
||||
t.Errorf("All returned wrong amount. Got %d want %d", len(all), len(testAttempts))
|
||||
}
|
||||
stats, err := s.Stats(store.LoginStatsTotals, 1)
|
||||
if err != nil {
|
||||
t.Errorf("Stats returned error: %s", err)
|
||||
}
|
||||
for _, stat := range stats {
|
||||
if stat.Name == "TotalLoginAttempts" && stat.Count != len(testAttempts) {
|
||||
t.Errorf("Stats for total attempts is wrong. Got %d want %d", stat.Count, len(testAttempts))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func randomAttempts(count int) []*models.LoginAttempt {
|
||||
var attempts []*models.LoginAttempt
|
||||
for i := 0; i < count; i++ {
|
||||
attempt := &models.LoginAttempt{
|
||||
Date: time.Now(),
|
||||
RemoteIP: randomIP(),
|
||||
Username: randomString(8),
|
||||
Password: randomString(8),
|
||||
Country: randomCountry(),
|
||||
ConnectionUUID: uuid.Must(uuid.NewRandom()),
|
||||
SSHClientVersion: "SSH TEST LOL",
|
||||
}
|
||||
attempts = append(attempts, attempt)
|
||||
}
|
||||
return attempts
|
||||
}
|
||||
|
||||
func randomIP() net.IP {
|
||||
a := byte(rand.Intn(254))
|
||||
b := byte(rand.Intn(254))
|
||||
c := byte(rand.Intn(254))
|
||||
d := byte(rand.Intn(254))
|
||||
return net.IPv4(a, b, c, d)
|
||||
}
|
||||
|
||||
func randomString(n int) string {
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func randomCountry() string {
|
||||
switch rand.Intn(10) {
|
||||
case 1:
|
||||
return "CN"
|
||||
case 2:
|
||||
return "US"
|
||||
case 3:
|
||||
return "NO"
|
||||
case 4:
|
||||
return "RU"
|
||||
case 5:
|
||||
return "DE"
|
||||
case 6:
|
||||
return "FI"
|
||||
case 7:
|
||||
return "BR"
|
||||
default:
|
||||
return "SE"
|
||||
}
|
||||
}
|
@ -5,7 +5,7 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var Version = "v0.1.7"
|
||||
var Version = "v0.1.8"
|
||||
var Build string
|
||||
|
||||
func FullVersion() string {
|
||||
|
Loading…
Reference in New Issue
Block a user