Reorganize packages
This commit is contained in:
85
honeypot/ssh/store/cache.go
Normal file
85
honeypot/ssh/store/cache.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package store
|
||||
|
||||
import "github.uio.no/torjus/apiary/models"
|
||||
|
||||
type CachingStore struct {
|
||||
backend LoginAttemptStore
|
||||
|
||||
totalStatsCache []StatsResult
|
||||
usernameQueryCache map[string][]models.LoginAttempt
|
||||
passwordQueryCache map[string][]models.LoginAttempt
|
||||
ipQueryCache map[string][]models.LoginAttempt
|
||||
}
|
||||
|
||||
func NewCachingStore(backend LoginAttemptStore) *CachingStore {
|
||||
return &CachingStore{
|
||||
backend: backend,
|
||||
usernameQueryCache: make(map[string][]models.LoginAttempt),
|
||||
passwordQueryCache: make(map[string][]models.LoginAttempt),
|
||||
ipQueryCache: make(map[string][]models.LoginAttempt),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error {
|
||||
s.totalStatsCache = nil
|
||||
delete(s.ipQueryCache, l.RemoteIP.String())
|
||||
delete(s.passwordQueryCache, l.Password)
|
||||
delete(s.usernameQueryCache, l.Username)
|
||||
return s.backend.AddAttempt(l)
|
||||
}
|
||||
|
||||
func (s *CachingStore) All() ([]models.LoginAttempt, error) {
|
||||
return s.backend.All()
|
||||
}
|
||||
|
||||
func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
|
||||
// Only cache totals for now, as they are the most queried
|
||||
if statType == LoginStatsTotals {
|
||||
if s.totalStatsCache != nil {
|
||||
return s.totalStatsCache, nil
|
||||
}
|
||||
stats, err := s.backend.Stats(statType, limit)
|
||||
if err != nil {
|
||||
return stats, err
|
||||
}
|
||||
s.totalStatsCache = stats
|
||||
|
||||
return stats, err
|
||||
}
|
||||
|
||||
// TODO: Maybe cache the default limits
|
||||
|
||||
return s.backend.Stats(statType, limit)
|
||||
}
|
||||
|
||||
func (s *CachingStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) {
|
||||
switch query.QueryType {
|
||||
case AttemptQueryTypeIP:
|
||||
if attempts, ok := s.ipQueryCache[query.Query]; ok {
|
||||
return attempts, nil
|
||||
}
|
||||
case AttemptQueryTypePassword:
|
||||
if attempts, ok := s.passwordQueryCache[query.Query]; ok {
|
||||
return attempts, nil
|
||||
}
|
||||
case AttemptQueryTypeUsername:
|
||||
if attempts, ok := s.usernameQueryCache[query.Query]; ok {
|
||||
return attempts, nil
|
||||
}
|
||||
}
|
||||
attempts, err := s.backend.Query(query)
|
||||
if err != nil {
|
||||
return attempts, err
|
||||
}
|
||||
|
||||
switch query.QueryType {
|
||||
case AttemptQueryTypeIP:
|
||||
s.ipQueryCache[query.Query] = attempts
|
||||
case AttemptQueryTypeUsername:
|
||||
s.ipQueryCache[query.Query] = attempts
|
||||
case AttemptQueryTypePassword:
|
||||
s.ipQueryCache[query.Query] = attempts
|
||||
}
|
||||
|
||||
return attempts, err
|
||||
}
|
13
honeypot/ssh/store/cache_test.go
Normal file
13
honeypot/ssh/store/cache_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package store_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.uio.no/torjus/apiary/honeypot/ssh/store"
|
||||
)
|
||||
|
||||
func TestCacheStore(t *testing.T) {
|
||||
backend := &store.MemoryStore{}
|
||||
s := store.NewCachingStore(backend)
|
||||
testLoginAttemptStore(s, t)
|
||||
}
|
153
honeypot/ssh/store/memory.go
Normal file
153
honeypot/ssh/store/memory.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.uio.no/torjus/apiary/models"
|
||||
)
|
||||
|
||||
type MemoryStore struct {
|
||||
lock sync.RWMutex
|
||||
attempts []models.LoginAttempt
|
||||
currentID int
|
||||
}
|
||||
|
||||
type StatItem struct {
|
||||
Key string
|
||||
Count int
|
||||
}
|
||||
|
||||
type StatItems []StatItem
|
||||
|
||||
func (ms *MemoryStore) AddAttempt(l *models.LoginAttempt) error {
|
||||
ms.lock.Lock()
|
||||
defer ms.lock.Unlock()
|
||||
l.ID = ms.currentID + 1
|
||||
ms.currentID = ms.currentID + 1
|
||||
|
||||
ms.attempts = append(ms.attempts, *l)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *MemoryStore) All() ([]models.LoginAttempt, error) {
|
||||
return ms.attempts, nil
|
||||
}
|
||||
|
||||
func (ms *MemoryStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
|
||||
counts := make(map[string]int)
|
||||
|
||||
if statType == LoginStatsTotals {
|
||||
return ms.statTotals()
|
||||
}
|
||||
|
||||
ms.lock.RLock()
|
||||
defer ms.lock.RUnlock()
|
||||
|
||||
for _, a := range ms.attempts {
|
||||
switch statType {
|
||||
case LoginStatsPasswords:
|
||||
counts[a.Password]++
|
||||
case LoginStatsCountry:
|
||||
counts[a.Country]++
|
||||
case LoginStatsIP:
|
||||
counts[a.RemoteIP.String()]++
|
||||
case LoginStatsUsername:
|
||||
counts[a.Username]++
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid stat type")
|
||||
}
|
||||
}
|
||||
|
||||
if limit < 1 {
|
||||
return toResults(counts), nil
|
||||
}
|
||||
if limit >= len(counts) {
|
||||
return toResults(counts), nil
|
||||
}
|
||||
|
||||
var si StatItems
|
||||
for key := range counts {
|
||||
si = append(si, StatItem{Key: key, Count: counts[key]})
|
||||
}
|
||||
sort.Sort(si)
|
||||
|
||||
output := make(map[string]int)
|
||||
for i := len(si) - 1; i > len(si)-limit-1; i-- {
|
||||
output[si[i].Key] = si[i].Count
|
||||
}
|
||||
return toResults(output), nil
|
||||
}
|
||||
|
||||
func (ss StatItems) Len() int {
|
||||
return len(ss)
|
||||
}
|
||||
func (ss StatItems) Less(i, j int) bool {
|
||||
return ss[i].Count < ss[j].Count
|
||||
}
|
||||
func (ss StatItems) Swap(i, j int) {
|
||||
ss[i], ss[j] = ss[j], ss[i]
|
||||
}
|
||||
|
||||
func (ms *MemoryStore) statTotals() ([]StatsResult, error) {
|
||||
passwords := make(map[string]int)
|
||||
usernames := make(map[string]int)
|
||||
ips := make(map[string]int)
|
||||
countries := make(map[string]int)
|
||||
|
||||
ms.lock.RLock()
|
||||
defer ms.lock.RUnlock()
|
||||
|
||||
for _, val := range ms.attempts {
|
||||
passwords[val.Password] += 1
|
||||
usernames[val.Username] += 1
|
||||
ips[val.RemoteIP.String()] += 1
|
||||
countries[val.Country] += 1
|
||||
}
|
||||
|
||||
stats := []StatsResult{
|
||||
{Name: "UniquePasswords", Count: len(passwords)},
|
||||
{Name: "UniqueUsernames", Count: len(usernames)},
|
||||
{Name: "UniqueIPs", Count: len(ips)},
|
||||
{Name: "UniqueCountries", Count: len(countries)},
|
||||
{Name: "TotalLoginAttempts", Count: len(ms.attempts)},
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (ms *MemoryStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) {
|
||||
var results []models.LoginAttempt
|
||||
ms.lock.Lock()
|
||||
defer ms.lock.Unlock()
|
||||
|
||||
for _, la := range ms.attempts {
|
||||
switch query.QueryType {
|
||||
case AttemptQueryTypeIP:
|
||||
if la.RemoteIP.String() == query.Query {
|
||||
results = append(results, la)
|
||||
}
|
||||
case AttemptQueryTypePassword:
|
||||
if strings.Contains(la.Password, query.Query) {
|
||||
results = append(results, la)
|
||||
}
|
||||
case AttemptQueryTypeUsername:
|
||||
if strings.Contains(la.Username, query.Query) {
|
||||
results = append(results, la)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func toResults(m map[string]int) []StatsResult {
|
||||
var results []StatsResult
|
||||
|
||||
for key, value := range m {
|
||||
results = append(results, StatsResult{key, value})
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
17
honeypot/ssh/store/memory_test.go
Normal file
17
honeypot/ssh/store/memory_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package store_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.uio.no/torjus/apiary/honeypot/ssh/store"
|
||||
)
|
||||
|
||||
func TestMemoryStore(t *testing.T) {
|
||||
s := &store.MemoryStore{}
|
||||
testLoginAttemptStore(s, t)
|
||||
}
|
||||
func TestMemoryStoreWithCache(t *testing.T) {
|
||||
backend := &store.MemoryStore{}
|
||||
s := store.NewCachingStore(backend)
|
||||
testLoginAttemptStore(s, t)
|
||||
}
|
204
honeypot/ssh/store/postgres.go
Normal file
204
honeypot/ssh/store/postgres.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
"github.uio.no/torjus/apiary/models"
|
||||
)
|
||||
|
||||
const DBSchema = `
|
||||
CREATE TABLE IF NOT EXISTS login_attempts(
|
||||
id serial PRIMARY KEY,
|
||||
date timestamptz,
|
||||
remote_ip inet,
|
||||
username text,
|
||||
password text,
|
||||
client_version text,
|
||||
connection_uuid uuid,
|
||||
country varchar(2)
|
||||
);
|
||||
`
|
||||
|
||||
type PostgresStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewPostgresStore(dsn string) (*PostgresStore, error) {
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rs := &PostgresStore{
|
||||
db: db,
|
||||
}
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
func (s *PostgresStore) InitDB() error {
|
||||
_, err := s.db.Exec(DBSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *PostgresStore) AddAttempt(l *models.LoginAttempt) error {
|
||||
stmt := `INSERT INTO
|
||||
login_attempts(date, remote_ip, username, password, client_version, connection_uuid, country)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
RETURNING id;`
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
var id int
|
||||
if err := tx.QueryRow(stmt, l.Date, l.RemoteIP.String(), l.Username, l.Password, l.SSHClientVersion, l.ConnectionUUID, l.Country).Scan(&id); err != nil {
|
||||
return err
|
||||
}
|
||||
l.ID = id
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *PostgresStore) All() ([]models.LoginAttempt, error) {
|
||||
stmt := `SELECT date, remote_ip, username, password, client_version, connection_uuid, country FROM login_attempts`
|
||||
|
||||
rows, err := s.db.Query(stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var attempts []models.LoginAttempt
|
||||
for rows.Next() {
|
||||
var a models.LoginAttempt
|
||||
var ip string
|
||||
if err := rows.Scan(&a.Date, &ip, &a.Username, &a.Password, &a.SSHClientVersion, &a.SSHClientVersion, &a.Country); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.RemoteIP = net.ParseIP(ip)
|
||||
attempts = append(attempts, a)
|
||||
}
|
||||
|
||||
return attempts, nil
|
||||
}
|
||||
|
||||
func (s *PostgresStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
|
||||
var stmt string
|
||||
|
||||
if statType == LoginStatsTotals {
|
||||
return s.statsTotal(limit)
|
||||
}
|
||||
|
||||
switch statType {
|
||||
case LoginStatsCountry:
|
||||
stmt = `select country, count(country) from login_attempts group by country order by count desc`
|
||||
case LoginStatsIP:
|
||||
stmt = `select remote_ip, count(remote_ip) from login_attempts group by remote_ip order by count desc`
|
||||
case LoginStatsPasswords:
|
||||
stmt = `select password, count(password) from login_attempts group by password order by count desc`
|
||||
case LoginStatsUsername:
|
||||
stmt = `select username, count(username) from login_attempts group by username order by count desc`
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid stat type")
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
stmt = fmt.Sprintf("%s limit %d", stmt, limit)
|
||||
}
|
||||
rows, err := s.db.Query(stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []StatsResult
|
||||
for rows.Next() {
|
||||
var r StatsResult
|
||||
|
||||
if err := rows.Scan(&r.Name, &r.Count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, r)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *PostgresStore) statsTotal(limit int) ([]StatsResult, error) {
|
||||
uniquePasswordStmt := `select count(*) from (select distinct password from login_attempts) as temp`
|
||||
uniqueUsernameStmt := `select count(*) from (select distinct username from login_attempts) as temp`
|
||||
uniqueIPStmt := `select count(*) from (select distinct remote_ip from login_attempts) as temp`
|
||||
uniqueCountryStmt := `select count(*) from (select distinct country from login_attempts) as temp`
|
||||
attemptsCountStmt := `select count(1) from login_attempts`
|
||||
|
||||
var uniquePasswordsCount int
|
||||
if err := s.db.QueryRow(uniquePasswordStmt).Scan(&uniquePasswordsCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var uniqueUsernameCount int
|
||||
if err := s.db.QueryRow(uniqueUsernameStmt).Scan(&uniqueUsernameCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var uniqueIPCount int
|
||||
if err := s.db.QueryRow(uniqueIPStmt).Scan(&uniqueIPCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var uniqueCountryCount int
|
||||
if err := s.db.QueryRow(uniqueCountryStmt).Scan(&uniqueCountryCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var attemptsCount int
|
||||
if err := s.db.QueryRow(attemptsCountStmt).Scan(&attemptsCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []StatsResult{
|
||||
{Name: "UniquePasswords", Count: uniquePasswordsCount},
|
||||
{Name: "UniqueUsernames", Count: uniqueUsernameCount},
|
||||
{Name: "UniqueIPs", Count: uniqueIPCount},
|
||||
{Name: "UniqueCountries", Count: uniqueCountryCount},
|
||||
{Name: "TotalLoginAttempts", Count: attemptsCount},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) {
|
||||
var stmt string
|
||||
queryString := query.Query
|
||||
|
||||
switch query.QueryType {
|
||||
case AttemptQueryTypeIP:
|
||||
stmt = `SELECT id, date, remote_ip, username, password, client_version, connection_uuid, country
|
||||
FROM login_attempts WHERE remote_ip = $1`
|
||||
case AttemptQueryTypePassword:
|
||||
stmt = `SELECT id, date, remote_ip, username, password, client_version, connection_uuid, country
|
||||
FROM login_attempts WHERE password like $1`
|
||||
queryString = fmt.Sprintf("%%%s%%", queryString)
|
||||
case AttemptQueryTypeUsername:
|
||||
stmt = `SELECT id, date, remote_ip, username, password, client_version, connection_uuid, country
|
||||
FROM login_attempts WHERE username like $1`
|
||||
queryString = fmt.Sprintf("%%%s%%", queryString)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid query type")
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(stmt, queryString)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to query database: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []models.LoginAttempt
|
||||
for rows.Next() {
|
||||
var la models.LoginAttempt
|
||||
var ipString string
|
||||
if err := rows.Scan(&la.ID, &la.Date, &ipString, &la.Username, &la.Password, &la.SSHClientVersion, &la.ConnectionUUID, &la.Country); err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal data from database: %w", err)
|
||||
}
|
||||
la.RemoteIP = net.ParseIP(ipString)
|
||||
results = append(results, la)
|
||||
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
61
honeypot/ssh/store/postgres_test.go
Normal file
61
honeypot/ssh/store/postgres_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package store_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.uio.no/torjus/apiary/honeypot/ssh/store"
|
||||
)
|
||||
|
||||
func TestPostgresStore(t *testing.T) {
|
||||
var dsn string
|
||||
var found bool
|
||||
dsn, found = os.LookupEnv("APIARY_TEST_POSTGRES_DSN")
|
||||
if !found {
|
||||
t.Skipf("APIARY_TEST_POSTGRES_DSN not set. Skipping.")
|
||||
}
|
||||
|
||||
dropPGDatabase(dsn)
|
||||
|
||||
s, err := store.NewPostgresStore(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting store: %s", err)
|
||||
}
|
||||
|
||||
s.InitDB()
|
||||
|
||||
testLoginAttemptStore(s, t)
|
||||
}
|
||||
func TestPostgresStoreWithCache(t *testing.T) {
|
||||
var dsn string
|
||||
var found bool
|
||||
dsn, found = os.LookupEnv("APIARY_TEST_POSTGRES_DSN")
|
||||
if !found {
|
||||
t.Skipf("APIARY_TEST_POSTGRES_DSN not set. Skipping.")
|
||||
}
|
||||
|
||||
dropPGDatabase(dsn)
|
||||
|
||||
pgs, err := store.NewPostgresStore(dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting store: %s", err)
|
||||
}
|
||||
|
||||
pgs.InitDB()
|
||||
s := store.NewCachingStore(pgs)
|
||||
|
||||
testLoginAttemptStore(s, t)
|
||||
}
|
||||
|
||||
func dropPGDatabase(dsn string) {
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("DROP TABLE login_attempts")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
38
honeypot/ssh/store/store.go
Normal file
38
honeypot/ssh/store/store.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package store
|
||||
|
||||
import "github.uio.no/torjus/apiary/models"
|
||||
|
||||
type LoginStats string
|
||||
|
||||
const (
|
||||
LoginStatsUndefined LoginStats = ""
|
||||
LoginStatsPasswords LoginStats = "password"
|
||||
LoginStatsCountry LoginStats = "country"
|
||||
LoginStatsIP LoginStats = "ip"
|
||||
LoginStatsUsername LoginStats = "username"
|
||||
LoginStatsTotals LoginStats = "total"
|
||||
)
|
||||
|
||||
type AttemptQueryType string
|
||||
|
||||
const (
|
||||
AttemptQueryTypeUsername AttemptQueryType = "username"
|
||||
AttemptQueryTypePassword AttemptQueryType = "password"
|
||||
AttemptQueryTypeIP AttemptQueryType = "ip"
|
||||
)
|
||||
|
||||
type StatsResult struct {
|
||||
Name string `json:"name"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
type AttemptQuery struct {
|
||||
QueryType AttemptQueryType
|
||||
Query string
|
||||
}
|
||||
type LoginAttemptStore interface {
|
||||
AddAttempt(l *models.LoginAttempt) error
|
||||
All() ([]models.LoginAttempt, error)
|
||||
Stats(statType LoginStats, limit int) ([]StatsResult, error)
|
||||
Query(query AttemptQuery) ([]models.LoginAttempt, error)
|
||||
}
|
271
honeypot/ssh/store/store_test.go
Normal file
271
honeypot/ssh/store/store_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package store_test
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.uio.no/torjus/apiary/honeypot/ssh/store"
|
||||
"github.uio.no/torjus/apiary/models"
|
||||
)
|
||||
|
||||
func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) {
|
||||
t.Run("Simple", func(t *testing.T) {
|
||||
testAttempts := randomAttempts(10)
|
||||
|
||||
for _, attempt := range testAttempts {
|
||||
if err := s.AddAttempt(attempt); err != nil {
|
||||
t.Fatalf("Error adding attempt: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
all, err := s.All()
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting all attempts: %s", err)
|
||||
}
|
||||
if len(all) != len(testAttempts) {
|
||||
t.Errorf("All returned wrong amount. Got %d want %d", len(all), len(testAttempts))
|
||||
}
|
||||
stats, err := s.Stats(store.LoginStatsTotals, 1)
|
||||
if err != nil {
|
||||
t.Errorf("Stats returned error: %s", err)
|
||||
}
|
||||
for _, stat := range stats {
|
||||
if stat.Name == "TotalLoginAttempts" && stat.Count != len(testAttempts) {
|
||||
t.Errorf("Stats for total attempts is wrong. Got %d want %d", stat.Count, len(testAttempts))
|
||||
}
|
||||
}
|
||||
})
|
||||
t.Run("Query", func(t *testing.T) {
|
||||
testAttempts := []*models.LoginAttempt{
|
||||
{
|
||||
Date: time.Now(),
|
||||
RemoteIP: net.ParseIP("127.0.0.1"),
|
||||
Username: "corndog",
|
||||
Password: "corndog",
|
||||
},
|
||||
{
|
||||
Date: time.Now(),
|
||||
RemoteIP: net.ParseIP("127.0.0.1"),
|
||||
Username: "corndog",
|
||||
Password: "c0rnd0g",
|
||||
},
|
||||
{
|
||||
Date: time.Now(),
|
||||
RemoteIP: net.ParseIP("10.0.0.1"),
|
||||
Username: "root",
|
||||
Password: "password",
|
||||
},
|
||||
{
|
||||
Date: time.Now(),
|
||||
RemoteIP: net.ParseIP("10.0.0.2"),
|
||||
Username: "ubnt",
|
||||
Password: "password",
|
||||
},
|
||||
}
|
||||
|
||||
for _, attempt := range testAttempts {
|
||||
err := s.AddAttempt(attempt)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to add attempt: %s", err)
|
||||
}
|
||||
}
|
||||
testCases := []struct {
|
||||
Name string
|
||||
Query store.AttemptQuery
|
||||
ExpectedResult []models.LoginAttempt
|
||||
}{
|
||||
{
|
||||
Name: "password one result",
|
||||
Query: store.AttemptQuery{QueryType: store.AttemptQueryTypePassword, Query: "corndog"},
|
||||
ExpectedResult: []models.LoginAttempt{
|
||||
{
|
||||
RemoteIP: net.ParseIP("127.0.0.1"),
|
||||
Username: "corndog",
|
||||
Password: "corndog",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "username one result",
|
||||
Query: store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "root"},
|
||||
ExpectedResult: []models.LoginAttempt{
|
||||
{
|
||||
RemoteIP: net.ParseIP("10.0.0.1"),
|
||||
Username: "root",
|
||||
Password: "password",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "username two results",
|
||||
Query: store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "corndog"},
|
||||
ExpectedResult: []models.LoginAttempt{
|
||||
{
|
||||
RemoteIP: net.ParseIP("127.0.0.1"),
|
||||
Username: "corndog",
|
||||
Password: "c0rnd0g",
|
||||
},
|
||||
{
|
||||
RemoteIP: net.ParseIP("127.0.0.1"),
|
||||
Username: "corndog",
|
||||
Password: "corndog",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
res, err := s.Query(tc.Query)
|
||||
if err != nil {
|
||||
t.Errorf("Error performing query: %s", err)
|
||||
}
|
||||
if !equalAttempts(res, tc.ExpectedResult) {
|
||||
t.Errorf("Query did not return expected results")
|
||||
t.Logf("%+v", res)
|
||||
t.Logf("%+v", tc.ExpectedResult)
|
||||
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("QueryCache", func(t *testing.T) {
|
||||
err := s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("Error adding attempt: %s", err)
|
||||
}
|
||||
|
||||
res, err := s.Query(store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("Error adding attempt: %s", err)
|
||||
}
|
||||
if len(res) != 1 {
|
||||
t.Errorf("Wrong amount of results")
|
||||
}
|
||||
|
||||
err = s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "best"})
|
||||
if err != nil {
|
||||
t.Fatalf("Error adding attempt: %s", err)
|
||||
}
|
||||
res, err = s.Query(store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("Error adding attempt: %s", err)
|
||||
}
|
||||
if len(res) != 2 {
|
||||
t.Errorf("Wrong amount of results")
|
||||
}
|
||||
})
|
||||
t.Run("QueryStats", func(t *testing.T) {
|
||||
firstStats, err := s.Stats(store.LoginStatsTotals, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting stats: %s", err)
|
||||
}
|
||||
err = s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "best"})
|
||||
if err != nil {
|
||||
t.Fatalf("Error adding attempt: %s", err)
|
||||
}
|
||||
secondStats, err := s.Stats(store.LoginStatsTotals, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting stats: %s", err)
|
||||
}
|
||||
var firstCount, secondCount int
|
||||
for _, stat := range firstStats {
|
||||
if stat.Name == "TotalLoginAttempts" {
|
||||
firstCount = stat.Count
|
||||
}
|
||||
}
|
||||
for _, stat := range secondStats {
|
||||
if stat.Name == "TotalLoginAttempts" {
|
||||
secondCount = stat.Count
|
||||
}
|
||||
}
|
||||
|
||||
if secondCount != firstCount+1 {
|
||||
t.Errorf("TotalLoginAttempts did not increment")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func randomAttempts(count int) []*models.LoginAttempt {
|
||||
var attempts []*models.LoginAttempt
|
||||
for i := 0; i < count; i++ {
|
||||
attempt := &models.LoginAttempt{
|
||||
Date: time.Now(),
|
||||
RemoteIP: randomIP(),
|
||||
Username: randomString(8),
|
||||
Password: randomString(8),
|
||||
Country: randomCountry(),
|
||||
ConnectionUUID: uuid.Must(uuid.NewRandom()),
|
||||
SSHClientVersion: "SSH TEST LOL",
|
||||
}
|
||||
attempts = append(attempts, attempt)
|
||||
}
|
||||
return attempts
|
||||
}
|
||||
|
||||
func randomIP() net.IP {
|
||||
a := byte(rand.Intn(254))
|
||||
b := byte(rand.Intn(254))
|
||||
c := byte(rand.Intn(254))
|
||||
d := byte(rand.Intn(254))
|
||||
return net.IPv4(a, b, c, d)
|
||||
}
|
||||
|
||||
func randomString(n int) string {
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
b := make([]byte, n)
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func randomCountry() string {
|
||||
switch rand.Intn(10) {
|
||||
case 1:
|
||||
return "CN"
|
||||
case 2:
|
||||
return "US"
|
||||
case 3:
|
||||
return "NO"
|
||||
case 4:
|
||||
return "RU"
|
||||
case 5:
|
||||
return "DE"
|
||||
case 6:
|
||||
return "FI"
|
||||
case 7:
|
||||
return "BR"
|
||||
default:
|
||||
return "SE"
|
||||
}
|
||||
}
|
||||
|
||||
func equalAttempts(a, b []models.LoginAttempt) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
aFound := make([]bool, len(a))
|
||||
|
||||
for i, aAttempt := range a {
|
||||
for _, bAttempt := range b {
|
||||
if aAttempt.Username == bAttempt.Username &&
|
||||
aAttempt.Password == bAttempt.Password &&
|
||||
aAttempt.RemoteIP.String() == bAttempt.RemoteIP.String() {
|
||||
aFound[i] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, found := range aFound {
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
Reference in New Issue
Block a user