Add some caching
This commit is contained in:
parent
36a487cb0c
commit
a13e9c0eba
@ -4,6 +4,9 @@
|
|||||||
# Must be "memory" or "postgres"
|
# Must be "memory" or "postgres"
|
||||||
# Default: "memory"
|
# Default: "memory"
|
||||||
Type = "memory"
|
Type = "memory"
|
||||||
|
# Enable caching
|
||||||
|
# Default: false
|
||||||
|
EnableCache = false
|
||||||
|
|
||||||
[Store.Postgres]
|
[Store.Postgres]
|
||||||
# Connection string for postgres
|
# Connection string for postgres
|
||||||
|
@ -53,11 +53,13 @@ func ActionServe(c *cli.Context) error {
|
|||||||
|
|
||||||
// Setup logging
|
// Setup logging
|
||||||
loggers := setupLoggers(cfg)
|
loggers := setupLoggers(cfg)
|
||||||
|
loggers.rootLogger.Infow("Startin apiary", "version", apiary.FullVersion())
|
||||||
|
|
||||||
// Setup store
|
// Setup store
|
||||||
var s store.LoginAttemptStore
|
var s store.LoginAttemptStore
|
||||||
switch cfg.Store.Type {
|
switch cfg.Store.Type {
|
||||||
case "MEMORY", "memory":
|
case "MEMORY", "memory":
|
||||||
|
loggers.rootLogger.Debugw("Initializing store", "store_type", "memory")
|
||||||
s = &store.MemoryStore{}
|
s = &store.MemoryStore{}
|
||||||
case "POSTGRES", "postgres":
|
case "POSTGRES", "postgres":
|
||||||
pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN)
|
pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN)
|
||||||
@ -67,7 +69,14 @@ func ActionServe(c *cli.Context) error {
|
|||||||
if err := pgStore.InitDB(); err != nil {
|
if err := pgStore.InitDB(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if cfg.Store.EnableCache {
|
||||||
|
loggers.rootLogger.Debugw("Initializing store", "store_type", "cache-postgres")
|
||||||
|
cachingStore := store.NewCachingStore(pgStore)
|
||||||
|
s = cachingStore
|
||||||
|
} else {
|
||||||
|
loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres")
|
||||||
s = pgStore
|
s = pgStore
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("Invalid store configured")
|
return fmt.Errorf("Invalid store configured")
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,7 @@ type Config struct {
|
|||||||
}
|
}
|
||||||
type StoreConfig struct {
|
type StoreConfig struct {
|
||||||
Type string `toml:"Type"`
|
Type string `toml:"Type"`
|
||||||
|
EnableCache bool `toml:"EnableCache"`
|
||||||
Postgres PostgresStoreConfig `toml:"Postgres"`
|
Postgres PostgresStoreConfig `toml:"Postgres"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
85
honeypot/store/cache.go
Normal file
85
honeypot/store/cache.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import "github.uio.no/torjus/apiary/models"
|
||||||
|
|
||||||
|
type CachingStore struct {
|
||||||
|
backend LoginAttemptStore
|
||||||
|
|
||||||
|
totalStatsCache []StatsResult
|
||||||
|
usernameQueryCache map[string][]models.LoginAttempt
|
||||||
|
passwordQueryCache map[string][]models.LoginAttempt
|
||||||
|
ipQueryCache map[string][]models.LoginAttempt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCachingStore(backend LoginAttemptStore) *CachingStore {
|
||||||
|
return &CachingStore{
|
||||||
|
backend: backend,
|
||||||
|
usernameQueryCache: make(map[string][]models.LoginAttempt),
|
||||||
|
passwordQueryCache: make(map[string][]models.LoginAttempt),
|
||||||
|
ipQueryCache: make(map[string][]models.LoginAttempt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error {
|
||||||
|
s.totalStatsCache = nil
|
||||||
|
delete(s.ipQueryCache, l.RemoteIP.String())
|
||||||
|
delete(s.passwordQueryCache, l.Password)
|
||||||
|
delete(s.usernameQueryCache, l.Username)
|
||||||
|
return s.backend.AddAttempt(l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CachingStore) All() ([]models.LoginAttempt, error) {
|
||||||
|
return s.backend.All()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
|
||||||
|
// Only cache totals for now, as they are the most queried
|
||||||
|
if statType == LoginStatsTotals {
|
||||||
|
if s.totalStatsCache != nil {
|
||||||
|
return s.totalStatsCache, nil
|
||||||
|
}
|
||||||
|
stats, err := s.backend.Stats(statType, limit)
|
||||||
|
if err != nil {
|
||||||
|
return stats, err
|
||||||
|
}
|
||||||
|
s.totalStatsCache = stats
|
||||||
|
|
||||||
|
return stats, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Maybe cache the default limits
|
||||||
|
|
||||||
|
return s.backend.Stats(statType, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *CachingStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) {
|
||||||
|
switch query.QueryType {
|
||||||
|
case AttemptQueryTypeIP:
|
||||||
|
if attempts, ok := s.ipQueryCache[query.Query]; ok {
|
||||||
|
return attempts, nil
|
||||||
|
}
|
||||||
|
case AttemptQueryTypePassword:
|
||||||
|
if attempts, ok := s.passwordQueryCache[query.Query]; ok {
|
||||||
|
return attempts, nil
|
||||||
|
}
|
||||||
|
case AttemptQueryTypeUsername:
|
||||||
|
if attempts, ok := s.usernameQueryCache[query.Query]; ok {
|
||||||
|
return attempts, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attempts, err := s.backend.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return attempts, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch query.QueryType {
|
||||||
|
case AttemptQueryTypeIP:
|
||||||
|
s.ipQueryCache[query.Query] = attempts
|
||||||
|
case AttemptQueryTypeUsername:
|
||||||
|
s.ipQueryCache[query.Query] = attempts
|
||||||
|
case AttemptQueryTypePassword:
|
||||||
|
s.ipQueryCache[query.Query] = attempts
|
||||||
|
}
|
||||||
|
|
||||||
|
return attempts, err
|
||||||
|
}
|
13
honeypot/store/cache_test.go
Normal file
13
honeypot/store/cache_test.go
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
package store_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.uio.no/torjus/apiary/honeypot/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCacheStore(t *testing.T) {
|
||||||
|
backend := &store.MemoryStore{}
|
||||||
|
s := store.NewCachingStore(backend)
|
||||||
|
testLoginAttemptStore(s, t)
|
||||||
|
}
|
@ -57,7 +57,7 @@ func (ms *MemoryStore) Stats(statType LoginStats, limit int) ([]StatsResult, err
|
|||||||
case LoginStatsUsername:
|
case LoginStatsUsername:
|
||||||
counts[a.Username]++
|
counts[a.Username]++
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Invalid stat type")
|
return nil, fmt.Errorf("invalid stat type")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,119 +1,12 @@
|
|||||||
package store
|
package store_test
|
||||||
|
|
||||||
/*
|
import (
|
||||||
func TestStatItems(t *testing.T) {
|
"testing"
|
||||||
var tc = []struct {
|
|
||||||
Input StatItems
|
|
||||||
ExpectedOutput StatItems
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
Input: StatItems{
|
|
||||||
{Key: "a", Count: 5},
|
|
||||||
{Key: "b", Count: 100},
|
|
||||||
{Key: "c", Count: 99},
|
|
||||||
{Key: "d", Count: 98},
|
|
||||||
{Key: "f", Count: 18},
|
|
||||||
},
|
|
||||||
ExpectedOutput: StatItems{
|
|
||||||
{Key: "a", Count: 5},
|
|
||||||
{Key: "f", Count: 18},
|
|
||||||
{Key: "d", Count: 98},
|
|
||||||
{Key: "c", Count: 99},
|
|
||||||
{Key: "b", Count: 100},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, testCase := range tc {
|
"github.uio.no/torjus/apiary/honeypot/store"
|
||||||
sort.Sort(testCase.Input)
|
)
|
||||||
|
|
||||||
for i := range testCase.Input {
|
func TestMemoryStore(t *testing.T) {
|
||||||
if testCase.Input[i] != testCase.ExpectedOutput[i] {
|
s := &store.MemoryStore{}
|
||||||
t.Fatalf("Not sorted correctly")
|
testLoginAttemptStore(s, t)
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStats(t *testing.T) {
|
|
||||||
var exampleAttempts = []models.LoginAttempt{
|
|
||||||
{Username: "root", Password: "root", Country: "NO"},
|
|
||||||
{Username: "root", Password: "root", Country: "US"},
|
|
||||||
{Username: "user", Password: "passWord", Country: "US"},
|
|
||||||
{Username: "ibm", Password: "ibm", Country: "US"},
|
|
||||||
{Username: "ubnt", Password: "ubnt", Country: "GB"},
|
|
||||||
{Username: "ubnt", Password: "ubnt", Country: "FI"},
|
|
||||||
{Username: "root", Password: "root", Country: "CH"},
|
|
||||||
{Username: "ubnt", Password: "12345", Country: "DE"},
|
|
||||||
{Username: "oracle", Password: "oracle", Country: "FI"},
|
|
||||||
}
|
|
||||||
var tc = []struct {
|
|
||||||
Attempts []models.LoginAttempt
|
|
||||||
StatType LoginStats
|
|
||||||
Limit int
|
|
||||||
ExpectedOutput map[string]int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
Attempts: exampleAttempts,
|
|
||||||
StatType: LoginStatsPasswords,
|
|
||||||
Limit: 2,
|
|
||||||
ExpectedOutput: map[string]int{
|
|
||||||
"root": 3,
|
|
||||||
"ubnt": 2,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Attempts: exampleAttempts,
|
|
||||||
StatType: LoginStatsPasswords,
|
|
||||||
Limit: 1,
|
|
||||||
ExpectedOutput: map[string]int{
|
|
||||||
"root": 3,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Attempts: exampleAttempts,
|
|
||||||
StatType: LoginStatsCountry,
|
|
||||||
Limit: 2,
|
|
||||||
ExpectedOutput: map[string]int{
|
|
||||||
"US": 3,
|
|
||||||
"FI": 2,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Attempts: exampleAttempts,
|
|
||||||
StatType: LoginStatsCountry,
|
|
||||||
Limit: 0,
|
|
||||||
ExpectedOutput: map[string]int{
|
|
||||||
"US": 3,
|
|
||||||
"FI": 2,
|
|
||||||
"NO": 1,
|
|
||||||
"GB": 1,
|
|
||||||
"CH": 1,
|
|
||||||
"DE": 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range tc {
|
|
||||||
ms := MemoryStore{}
|
|
||||||
for _, a := range c.Attempts {
|
|
||||||
if err := ms.AddAttempt(a); err != nil {
|
|
||||||
t.Fatalf("Unable to add attempt: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stats, err := ms.Stats(c.StatType, c.Limit)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Error getting stats: %s", err)
|
|
||||||
}
|
|
||||||
if len(stats) != len(c.ExpectedOutput) {
|
|
||||||
t.Fatalf("Stats have wrong length")
|
|
||||||
}
|
|
||||||
for key := range stats {
|
|
||||||
if c.ExpectedOutput[key] != stats[key] {
|
|
||||||
t.Fatalf("Stats does not match expected output")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
@ -177,12 +177,12 @@ func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error)
|
|||||||
FROM login_attempts WHERE username like $1`
|
FROM login_attempts WHERE username like $1`
|
||||||
queryString = fmt.Sprintf("%%%s%%", queryString)
|
queryString = fmt.Sprintf("%%%s%%", queryString)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Invalid query type")
|
return nil, fmt.Errorf("invalid query type")
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := s.db.Query(stmt, queryString)
|
rows, err := s.db.Query(stmt, queryString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Unable to query database: %w", err)
|
return nil, fmt.Errorf("unable to query database: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
@ -191,7 +191,7 @@ func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error)
|
|||||||
var la models.LoginAttempt
|
var la models.LoginAttempt
|
||||||
var ipString string
|
var ipString string
|
||||||
if err := rows.Scan(&la.ID, &la.Date, &ipString, &la.Username, &la.Password, &la.SSHClientVersion, &la.ConnectionUUID, &la.Country); err != nil {
|
if err := rows.Scan(&la.ID, &la.Date, &ipString, &la.Username, &la.Password, &la.SSHClientVersion, &la.ConnectionUUID, &la.Country); err != nil {
|
||||||
return nil, fmt.Errorf("Unable to unmarshal data from database: %w", err)
|
return nil, fmt.Errorf("unable to unmarshal data from database: %w", err)
|
||||||
}
|
}
|
||||||
la.RemoteIP = net.ParseIP(ipString)
|
la.RemoteIP = net.ParseIP(ipString)
|
||||||
results = append(results, la)
|
results = append(results, la)
|
||||||
|
97
honeypot/store/store_test.go
Normal file
97
honeypot/store/store_test.go
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
package store_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.uio.no/torjus/apiary/honeypot/store"
|
||||||
|
"github.uio.no/torjus/apiary/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) {
|
||||||
|
t.Run("Simple", func(t *testing.T) {
|
||||||
|
testAttempts := randomAttempts(10)
|
||||||
|
|
||||||
|
for _, attempt := range testAttempts {
|
||||||
|
if err := s.AddAttempt(attempt); err != nil {
|
||||||
|
t.Fatalf("Error adding attempt: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
all, err := s.All()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error getting all attempts: %s", err)
|
||||||
|
}
|
||||||
|
if len(all) != len(testAttempts) {
|
||||||
|
t.Errorf("All returned wrong amount. Got %d want %d", len(all), len(testAttempts))
|
||||||
|
}
|
||||||
|
stats, err := s.Stats(store.LoginStatsTotals, 1)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Stats returned error: %s", err)
|
||||||
|
}
|
||||||
|
for _, stat := range stats {
|
||||||
|
if stat.Name == "TotalLoginAttempts" && stat.Count != len(testAttempts) {
|
||||||
|
t.Errorf("Stats for total attempts is wrong. Got %d want %d", stat.Count, len(testAttempts))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomAttempts(count int) []*models.LoginAttempt {
|
||||||
|
var attempts []*models.LoginAttempt
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
attempt := &models.LoginAttempt{
|
||||||
|
Date: time.Now(),
|
||||||
|
RemoteIP: randomIP(),
|
||||||
|
Username: randomString(8),
|
||||||
|
Password: randomString(8),
|
||||||
|
Country: randomCountry(),
|
||||||
|
ConnectionUUID: uuid.Must(uuid.NewRandom()),
|
||||||
|
SSHClientVersion: "SSH TEST LOL",
|
||||||
|
}
|
||||||
|
attempts = append(attempts, attempt)
|
||||||
|
}
|
||||||
|
return attempts
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomIP() net.IP {
|
||||||
|
a := byte(rand.Intn(254))
|
||||||
|
b := byte(rand.Intn(254))
|
||||||
|
c := byte(rand.Intn(254))
|
||||||
|
d := byte(rand.Intn(254))
|
||||||
|
return net.IPv4(a, b, c, d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomString(n int) string {
|
||||||
|
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||||
|
|
||||||
|
b := make([]byte, n)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomCountry() string {
|
||||||
|
switch rand.Intn(10) {
|
||||||
|
case 1:
|
||||||
|
return "CN"
|
||||||
|
case 2:
|
||||||
|
return "US"
|
||||||
|
case 3:
|
||||||
|
return "NO"
|
||||||
|
case 4:
|
||||||
|
return "RU"
|
||||||
|
case 5:
|
||||||
|
return "DE"
|
||||||
|
case 6:
|
||||||
|
return "FI"
|
||||||
|
case 7:
|
||||||
|
return "BR"
|
||||||
|
default:
|
||||||
|
return "SE"
|
||||||
|
}
|
||||||
|
}
|
@ -5,7 +5,7 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Version = "v0.1.7"
|
var Version = "v0.1.8"
|
||||||
var Build string
|
var Build string
|
||||||
|
|
||||||
func FullVersion() string {
|
func FullVersion() string {
|
||||||
|
Loading…
Reference in New Issue
Block a user