diff --git a/apiary.toml b/apiary.toml index bb9fb84..27197be 100644 --- a/apiary.toml +++ b/apiary.toml @@ -4,6 +4,9 @@ # Must be "memory" or "postgres" # Default: "memory" Type = "memory" +# Enable caching +# Default: false +EnableCache = false [Store.Postgres] # Connection string for postgres diff --git a/cmd/apiary.go b/cmd/apiary.go index 3e347d9..fba9ff9 100644 --- a/cmd/apiary.go +++ b/cmd/apiary.go @@ -53,11 +53,13 @@ func ActionServe(c *cli.Context) error { // Setup logging loggers := setupLoggers(cfg) + loggers.rootLogger.Infow("Startin apiary", "version", apiary.FullVersion()) // Setup store var s store.LoginAttemptStore switch cfg.Store.Type { case "MEMORY", "memory": + loggers.rootLogger.Debugw("Initializing store", "store_type", "memory") s = &store.MemoryStore{} case "POSTGRES", "postgres": pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN) @@ -67,7 +69,14 @@ func ActionServe(c *cli.Context) error { if err := pgStore.InitDB(); err != nil { return err } - s = pgStore + if cfg.Store.EnableCache { + loggers.rootLogger.Debugw("Initializing store", "store_type", "cache-postgres") + cachingStore := store.NewCachingStore(pgStore) + s = cachingStore + } else { + loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres") + s = pgStore + } default: return fmt.Errorf("Invalid store configured") } diff --git a/config/config.go b/config/config.go index 085a763..fe652dc 100644 --- a/config/config.go +++ b/config/config.go @@ -15,8 +15,9 @@ type Config struct { Frontend FrontendConfig `toml:"Frontend"` } type StoreConfig struct { - Type string `toml:"Type"` - Postgres PostgresStoreConfig `toml:"Postgres"` + Type string `toml:"Type"` + EnableCache bool `toml:"EnableCache"` + Postgres PostgresStoreConfig `toml:"Postgres"` } type PostgresStoreConfig struct { diff --git a/honeypot/store/cache.go b/honeypot/store/cache.go new file mode 100644 index 0000000..16799b3 --- /dev/null +++ b/honeypot/store/cache.go @@ -0,0 +1,85 @@ +package store + +import "github.uio.no/torjus/apiary/models" + +type CachingStore struct { + backend LoginAttemptStore + + totalStatsCache []StatsResult + usernameQueryCache map[string][]models.LoginAttempt + passwordQueryCache map[string][]models.LoginAttempt + ipQueryCache map[string][]models.LoginAttempt +} + +func NewCachingStore(backend LoginAttemptStore) *CachingStore { + return &CachingStore{ + backend: backend, + usernameQueryCache: make(map[string][]models.LoginAttempt), + passwordQueryCache: make(map[string][]models.LoginAttempt), + ipQueryCache: make(map[string][]models.LoginAttempt), + } +} + +func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error { + s.totalStatsCache = nil + delete(s.ipQueryCache, l.RemoteIP.String()) + delete(s.passwordQueryCache, l.Password) + delete(s.usernameQueryCache, l.Username) + return s.backend.AddAttempt(l) +} + +func (s *CachingStore) All() ([]models.LoginAttempt, error) { + return s.backend.All() +} + +func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) { + // Only cache totals for now, as they are the most queried + if statType == LoginStatsTotals { + if s.totalStatsCache != nil { + return s.totalStatsCache, nil + } + stats, err := s.backend.Stats(statType, limit) + if err != nil { + return stats, err + } + s.totalStatsCache = stats + + return stats, err + } + + // TODO: Maybe cache the default limits + + return s.backend.Stats(statType, limit) +} + +func (s *CachingStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) { + switch query.QueryType { + case AttemptQueryTypeIP: + if attempts, ok := s.ipQueryCache[query.Query]; ok { + return attempts, nil + } + case AttemptQueryTypePassword: + if attempts, ok := s.passwordQueryCache[query.Query]; ok { + return attempts, nil + } + case AttemptQueryTypeUsername: + if attempts, ok := s.usernameQueryCache[query.Query]; ok { + return attempts, nil + } + } + attempts, err := s.backend.Query(query) + if err != nil { + return attempts, err + } + + switch query.QueryType { + case AttemptQueryTypeIP: + s.ipQueryCache[query.Query] = attempts + case AttemptQueryTypeUsername: + s.ipQueryCache[query.Query] = attempts + case AttemptQueryTypePassword: + s.ipQueryCache[query.Query] = attempts + } + + return attempts, err +} diff --git a/honeypot/store/cache_test.go b/honeypot/store/cache_test.go new file mode 100644 index 0000000..460ac68 --- /dev/null +++ b/honeypot/store/cache_test.go @@ -0,0 +1,13 @@ +package store_test + +import ( + "testing" + + "github.uio.no/torjus/apiary/honeypot/store" +) + +func TestCacheStore(t *testing.T) { + backend := &store.MemoryStore{} + s := store.NewCachingStore(backend) + testLoginAttemptStore(s, t) +} diff --git a/honeypot/store/memory.go b/honeypot/store/memory.go index b68fbc2..9440b4a 100644 --- a/honeypot/store/memory.go +++ b/honeypot/store/memory.go @@ -57,7 +57,7 @@ func (ms *MemoryStore) Stats(statType LoginStats, limit int) ([]StatsResult, err case LoginStatsUsername: counts[a.Username]++ default: - return nil, fmt.Errorf("Invalid stat type") + return nil, fmt.Errorf("invalid stat type") } } diff --git a/honeypot/store/memory_test.go b/honeypot/store/memory_test.go index 45d6585..ed13710 100644 --- a/honeypot/store/memory_test.go +++ b/honeypot/store/memory_test.go @@ -1,119 +1,12 @@ -package store +package store_test -/* -func TestStatItems(t *testing.T) { - var tc = []struct { - Input StatItems - ExpectedOutput StatItems - }{ - { - Input: StatItems{ - {Key: "a", Count: 5}, - {Key: "b", Count: 100}, - {Key: "c", Count: 99}, - {Key: "d", Count: 98}, - {Key: "f", Count: 18}, - }, - ExpectedOutput: StatItems{ - {Key: "a", Count: 5}, - {Key: "f", Count: 18}, - {Key: "d", Count: 98}, - {Key: "c", Count: 99}, - {Key: "b", Count: 100}, - }, - }, - } +import ( + "testing" - for _, testCase := range tc { - sort.Sort(testCase.Input) + "github.uio.no/torjus/apiary/honeypot/store" +) - for i := range testCase.Input { - if testCase.Input[i] != testCase.ExpectedOutput[i] { - t.Fatalf("Not sorted correctly") - } - } - } +func TestMemoryStore(t *testing.T) { + s := &store.MemoryStore{} + testLoginAttemptStore(s, t) } - -func TestStats(t *testing.T) { - var exampleAttempts = []models.LoginAttempt{ - {Username: "root", Password: "root", Country: "NO"}, - {Username: "root", Password: "root", Country: "US"}, - {Username: "user", Password: "passWord", Country: "US"}, - {Username: "ibm", Password: "ibm", Country: "US"}, - {Username: "ubnt", Password: "ubnt", Country: "GB"}, - {Username: "ubnt", Password: "ubnt", Country: "FI"}, - {Username: "root", Password: "root", Country: "CH"}, - {Username: "ubnt", Password: "12345", Country: "DE"}, - {Username: "oracle", Password: "oracle", Country: "FI"}, - } - var tc = []struct { - Attempts []models.LoginAttempt - StatType LoginStats - Limit int - ExpectedOutput map[string]int - }{ - { - Attempts: exampleAttempts, - StatType: LoginStatsPasswords, - Limit: 2, - ExpectedOutput: map[string]int{ - "root": 3, - "ubnt": 2, - }, - }, - { - Attempts: exampleAttempts, - StatType: LoginStatsPasswords, - Limit: 1, - ExpectedOutput: map[string]int{ - "root": 3, - }, - }, - { - Attempts: exampleAttempts, - StatType: LoginStatsCountry, - Limit: 2, - ExpectedOutput: map[string]int{ - "US": 3, - "FI": 2, - }, - }, - { - Attempts: exampleAttempts, - StatType: LoginStatsCountry, - Limit: 0, - ExpectedOutput: map[string]int{ - "US": 3, - "FI": 2, - "NO": 1, - "GB": 1, - "CH": 1, - "DE": 1, - }, - }, - } - - for _, c := range tc { - ms := MemoryStore{} - for _, a := range c.Attempts { - if err := ms.AddAttempt(a); err != nil { - t.Fatalf("Unable to add attempt: %s", err) - } - } - stats, err := ms.Stats(c.StatType, c.Limit) - if err != nil { - t.Fatalf("Error getting stats: %s", err) - } - if len(stats) != len(c.ExpectedOutput) { - t.Fatalf("Stats have wrong length") - } - for key := range stats { - if c.ExpectedOutput[key] != stats[key] { - t.Fatalf("Stats does not match expected output") - } - } - } - -} -*/ diff --git a/honeypot/store/postgres.go b/honeypot/store/postgres.go index 83bc571..5876f26 100644 --- a/honeypot/store/postgres.go +++ b/honeypot/store/postgres.go @@ -177,12 +177,12 @@ func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) FROM login_attempts WHERE username like $1` queryString = fmt.Sprintf("%%%s%%", queryString) default: - return nil, fmt.Errorf("Invalid query type") + return nil, fmt.Errorf("invalid query type") } rows, err := s.db.Query(stmt, queryString) if err != nil { - return nil, fmt.Errorf("Unable to query database: %w", err) + return nil, fmt.Errorf("unable to query database: %w", err) } defer rows.Close() @@ -191,7 +191,7 @@ func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) var la models.LoginAttempt var ipString string if err := rows.Scan(&la.ID, &la.Date, &ipString, &la.Username, &la.Password, &la.SSHClientVersion, &la.ConnectionUUID, &la.Country); err != nil { - return nil, fmt.Errorf("Unable to unmarshal data from database: %w", err) + return nil, fmt.Errorf("unable to unmarshal data from database: %w", err) } la.RemoteIP = net.ParseIP(ipString) results = append(results, la) diff --git a/honeypot/store/store_test.go b/honeypot/store/store_test.go new file mode 100644 index 0000000..fef91e8 --- /dev/null +++ b/honeypot/store/store_test.go @@ -0,0 +1,97 @@ +package store_test + +import ( + "math/rand" + "net" + "testing" + "time" + + "github.com/google/uuid" + "github.uio.no/torjus/apiary/honeypot/store" + "github.uio.no/torjus/apiary/models" +) + +func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) { + t.Run("Simple", func(t *testing.T) { + testAttempts := randomAttempts(10) + + for _, attempt := range testAttempts { + if err := s.AddAttempt(attempt); err != nil { + t.Fatalf("Error adding attempt: %s", err) + } + } + + all, err := s.All() + if err != nil { + t.Fatalf("Error getting all attempts: %s", err) + } + if len(all) != len(testAttempts) { + t.Errorf("All returned wrong amount. Got %d want %d", len(all), len(testAttempts)) + } + stats, err := s.Stats(store.LoginStatsTotals, 1) + if err != nil { + t.Errorf("Stats returned error: %s", err) + } + for _, stat := range stats { + if stat.Name == "TotalLoginAttempts" && stat.Count != len(testAttempts) { + t.Errorf("Stats for total attempts is wrong. Got %d want %d", stat.Count, len(testAttempts)) + } + } + }) +} + +func randomAttempts(count int) []*models.LoginAttempt { + var attempts []*models.LoginAttempt + for i := 0; i < count; i++ { + attempt := &models.LoginAttempt{ + Date: time.Now(), + RemoteIP: randomIP(), + Username: randomString(8), + Password: randomString(8), + Country: randomCountry(), + ConnectionUUID: uuid.Must(uuid.NewRandom()), + SSHClientVersion: "SSH TEST LOL", + } + attempts = append(attempts, attempt) + } + return attempts +} + +func randomIP() net.IP { + a := byte(rand.Intn(254)) + b := byte(rand.Intn(254)) + c := byte(rand.Intn(254)) + d := byte(rand.Intn(254)) + return net.IPv4(a, b, c, d) +} + +func randomString(n int) string { + const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + b := make([]byte, n) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return string(b) +} + +func randomCountry() string { + switch rand.Intn(10) { + case 1: + return "CN" + case 2: + return "US" + case 3: + return "NO" + case 4: + return "RU" + case 5: + return "DE" + case 6: + return "FI" + case 7: + return "BR" + default: + return "SE" + } +} diff --git a/version.go b/version.go index dbb73cf..56a1314 100644 --- a/version.go +++ b/version.go @@ -5,7 +5,7 @@ import ( "runtime" ) -var Version = "v0.1.7" +var Version = "v0.1.8" var Build string func FullVersion() string {