Reorganize packages

This commit is contained in:
2021-10-21 12:36:01 +02:00
parent a71ff52ab4
commit 94e7faae78
15 changed files with 24 additions and 24 deletions

Binary file not shown.

11
honeypot/ssh/actions.go Normal file
View File

@@ -0,0 +1,11 @@
package ssh
type ActionType int
const (
ActionTypeLogPassword ActionType = iota
ActionTypeLogPasswordSlow
ActionTypeLogCommandAndExit
ActionTypeSendGarbage
)
const ActionTypeDefault ActionType = ActionTypeLogPassword

67
honeypot/ssh/conn.go Normal file
View File

@@ -0,0 +1,67 @@
package ssh
import (
"net"
"time"
"github.com/fujiwara/shapeio"
"github.com/google/uuid"
)
type throttledConn struct {
ID uuid.UUID
conn net.Conn
writer *shapeio.Writer
reader *shapeio.Reader
CloseCallback func(c *throttledConn)
}
func newThrottledConn(conn net.Conn) *throttledConn {
id := uuid.Must(uuid.NewRandom())
return &throttledConn{
ID: id,
conn: conn,
writer: shapeio.NewWriter(conn),
reader: shapeio.NewReader(conn),
}
}
func (sc *throttledConn) SetSpeed(bytesPerSec float64) {
sc.writer.SetRateLimit(bytesPerSec)
sc.reader.SetRateLimit(bytesPerSec)
}
func (sc *throttledConn) Read(b []byte) (n int, err error) {
return sc.reader.Read(b)
}
func (sc *throttledConn) Write(b []byte) (n int, err error) {
return sc.writer.Write(b)
}
func (sc *throttledConn) Close() error {
if sc.CloseCallback != nil {
sc.CloseCallback(sc)
}
return sc.conn.Close()
}
func (sc *throttledConn) LocalAddr() net.Addr {
return sc.conn.LocalAddr()
}
func (sc *throttledConn) RemoteAddr() net.Addr {
return sc.conn.RemoteAddr()
}
func (sc *throttledConn) SetDeadline(t time.Time) error {
return sc.conn.SetDeadline(t)
}
func (sc *throttledConn) SetReadDeadline(t time.Time) error {
return sc.conn.SetReadDeadline(t)
}
func (sc *throttledConn) SetWriteDeadline(t time.Time) error {
return sc.conn.SetWriteDeadline(t)
}

35
honeypot/ssh/geolocate.go Normal file
View File

@@ -0,0 +1,35 @@
package ssh
import (
_ "embed"
"net"
"github.com/oschwald/maxminddb-golang"
)
//go:embed Geoacumen-Country.mmdb
var mmdb []byte
func (s *HoneypotServer) LookupCountry(ip net.IP) string {
db, err := maxminddb.FromBytes(mmdb)
if err != nil {
s.Logger.Warnw("Error opening geoip database", "error", err)
return "??"
}
var record struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
}
err = db.Lookup(ip, &record)
if err != nil {
s.Logger.Warnw("Error doing geoip lookup", "error", err)
return "??"
}
if record.Country.ISOCode == "None" {
return "??"
}
return record.Country.ISOCode
}

137
honeypot/ssh/server.go Normal file
View File

@@ -0,0 +1,137 @@
package ssh
import (
"context"
"io"
"net"
"os"
"time"
"unicode/utf8"
gossh "golang.org/x/crypto/ssh"
"github.uio.no/torjus/apiary/config"
sshlib "github.com/gliderlabs/ssh"
"github.com/google/uuid"
"github.uio.no/torjus/apiary/honeypot/ssh/store"
"github.uio.no/torjus/apiary/models"
"go.uber.org/zap"
)
type HoneypotServer struct {
Logger *zap.SugaredLogger
sshServer *sshlib.Server
attemptStore store.LoginAttemptStore
attemptsCallbacks []func(l models.LoginAttempt)
throttleSpeed float64
}
func NewHoneypotServer(cfg config.HoneypotConfig, store store.LoginAttemptStore) (*HoneypotServer, error) {
var hs HoneypotServer
hs.attemptStore = store
hs.Logger = zap.NewNop().Sugar()
hs.sshServer = &sshlib.Server{
Addr: cfg.ListenAddr,
PasswordHandler: hs.passwordHandler,
ConnCallback: hs.connCallback,
Handler: handler,
Version: "OpenSSH_7.4p1 Debian-10+deb9u6",
}
if cfg.HostKeyPath != "" {
f, err := os.Open(cfg.HostKeyPath)
if err != nil {
return nil, err
}
pemBytes, err := io.ReadAll(f)
if err != nil {
return nil, err
}
signer, err := gossh.ParsePrivateKey(pemBytes)
if err != nil {
return nil, err
}
hs.sshServer.AddHostKey(signer)
}
return &hs, nil
}
func (hs *HoneypotServer) ListenAndServe() error {
return hs.sshServer.ListenAndServe()
}
func (hs *HoneypotServer) Shutdown(ctx context.Context) error {
return hs.sshServer.Shutdown(ctx)
}
func (hs *HoneypotServer) AddLoginCallback(c func(l models.LoginAttempt)) {
hs.attemptsCallbacks = append(hs.attemptsCallbacks, c)
}
func (hs *HoneypotServer) passwordHandler(ctx sshlib.Context, password string) bool {
sessUUID, ok := ctx.Value("uuid").(uuid.UUID)
if !ok {
hs.Logger.Warn("Unable to get session UUID")
return false
}
la := models.LoginAttempt{
Date: time.Now(),
RemoteIP: ipFromAddr(ctx.RemoteAddr().String()),
Username: ctx.User(),
Password: password,
SSHClientVersion: ctx.ClientVersion(),
ConnectionUUID: sessUUID,
}
country := hs.LookupCountry(la.RemoteIP)
if utf8.RuneCountInString(country) > 2 {
hs.Logger.Warnw("Too many characters in country", "country", country, "runecount", utf8.RuneCountInString(country))
country = "??"
}
la.Country = country
hs.Logger.Infow("Login attempt",
"remote_ip", la.RemoteIP.String(),
"username", la.Username,
"password", la.Password,
"country", la.Country)
if err := hs.attemptStore.AddAttempt(&la); err != nil {
hs.Logger.Warnw("Error adding attempt to store", "error", err)
}
for _, cFunc := range hs.attemptsCallbacks {
cFunc(la)
}
return false
}
func (s *HoneypotServer) connCallback(ctx sshlib.Context, conn net.Conn) net.Conn {
throttledConn := newThrottledConn(conn)
throttledConn.SetSpeed(s.throttleSpeed)
ctx.SetValue("uuid", throttledConn.ID)
throttledConn.SetSpeed(s.throttleSpeed)
return throttledConn
}
func handler(session sshlib.Session) {
_, _ = io.WriteString(session, "[root@hostname ~]#")
session.Exit(1)
}
func ipFromAddr(addr string) net.IP {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil
}
return net.ParseIP(host)
}

View File

@@ -0,0 +1,85 @@
package store
import "github.uio.no/torjus/apiary/models"
type CachingStore struct {
backend LoginAttemptStore
totalStatsCache []StatsResult
usernameQueryCache map[string][]models.LoginAttempt
passwordQueryCache map[string][]models.LoginAttempt
ipQueryCache map[string][]models.LoginAttempt
}
func NewCachingStore(backend LoginAttemptStore) *CachingStore {
return &CachingStore{
backend: backend,
usernameQueryCache: make(map[string][]models.LoginAttempt),
passwordQueryCache: make(map[string][]models.LoginAttempt),
ipQueryCache: make(map[string][]models.LoginAttempt),
}
}
func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error {
s.totalStatsCache = nil
delete(s.ipQueryCache, l.RemoteIP.String())
delete(s.passwordQueryCache, l.Password)
delete(s.usernameQueryCache, l.Username)
return s.backend.AddAttempt(l)
}
func (s *CachingStore) All() ([]models.LoginAttempt, error) {
return s.backend.All()
}
func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
// Only cache totals for now, as they are the most queried
if statType == LoginStatsTotals {
if s.totalStatsCache != nil {
return s.totalStatsCache, nil
}
stats, err := s.backend.Stats(statType, limit)
if err != nil {
return stats, err
}
s.totalStatsCache = stats
return stats, err
}
// TODO: Maybe cache the default limits
return s.backend.Stats(statType, limit)
}
func (s *CachingStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) {
switch query.QueryType {
case AttemptQueryTypeIP:
if attempts, ok := s.ipQueryCache[query.Query]; ok {
return attempts, nil
}
case AttemptQueryTypePassword:
if attempts, ok := s.passwordQueryCache[query.Query]; ok {
return attempts, nil
}
case AttemptQueryTypeUsername:
if attempts, ok := s.usernameQueryCache[query.Query]; ok {
return attempts, nil
}
}
attempts, err := s.backend.Query(query)
if err != nil {
return attempts, err
}
switch query.QueryType {
case AttemptQueryTypeIP:
s.ipQueryCache[query.Query] = attempts
case AttemptQueryTypeUsername:
s.ipQueryCache[query.Query] = attempts
case AttemptQueryTypePassword:
s.ipQueryCache[query.Query] = attempts
}
return attempts, err
}

View File

@@ -0,0 +1,13 @@
package store_test
import (
"testing"
"github.uio.no/torjus/apiary/honeypot/ssh/store"
)
func TestCacheStore(t *testing.T) {
backend := &store.MemoryStore{}
s := store.NewCachingStore(backend)
testLoginAttemptStore(s, t)
}

View File

@@ -0,0 +1,153 @@
package store
import (
"fmt"
"sort"
"strings"
"sync"
"github.uio.no/torjus/apiary/models"
)
type MemoryStore struct {
lock sync.RWMutex
attempts []models.LoginAttempt
currentID int
}
type StatItem struct {
Key string
Count int
}
type StatItems []StatItem
func (ms *MemoryStore) AddAttempt(l *models.LoginAttempt) error {
ms.lock.Lock()
defer ms.lock.Unlock()
l.ID = ms.currentID + 1
ms.currentID = ms.currentID + 1
ms.attempts = append(ms.attempts, *l)
return nil
}
func (ms *MemoryStore) All() ([]models.LoginAttempt, error) {
return ms.attempts, nil
}
func (ms *MemoryStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
counts := make(map[string]int)
if statType == LoginStatsTotals {
return ms.statTotals()
}
ms.lock.RLock()
defer ms.lock.RUnlock()
for _, a := range ms.attempts {
switch statType {
case LoginStatsPasswords:
counts[a.Password]++
case LoginStatsCountry:
counts[a.Country]++
case LoginStatsIP:
counts[a.RemoteIP.String()]++
case LoginStatsUsername:
counts[a.Username]++
default:
return nil, fmt.Errorf("invalid stat type")
}
}
if limit < 1 {
return toResults(counts), nil
}
if limit >= len(counts) {
return toResults(counts), nil
}
var si StatItems
for key := range counts {
si = append(si, StatItem{Key: key, Count: counts[key]})
}
sort.Sort(si)
output := make(map[string]int)
for i := len(si) - 1; i > len(si)-limit-1; i-- {
output[si[i].Key] = si[i].Count
}
return toResults(output), nil
}
func (ss StatItems) Len() int {
return len(ss)
}
func (ss StatItems) Less(i, j int) bool {
return ss[i].Count < ss[j].Count
}
func (ss StatItems) Swap(i, j int) {
ss[i], ss[j] = ss[j], ss[i]
}
func (ms *MemoryStore) statTotals() ([]StatsResult, error) {
passwords := make(map[string]int)
usernames := make(map[string]int)
ips := make(map[string]int)
countries := make(map[string]int)
ms.lock.RLock()
defer ms.lock.RUnlock()
for _, val := range ms.attempts {
passwords[val.Password] += 1
usernames[val.Username] += 1
ips[val.RemoteIP.String()] += 1
countries[val.Country] += 1
}
stats := []StatsResult{
{Name: "UniquePasswords", Count: len(passwords)},
{Name: "UniqueUsernames", Count: len(usernames)},
{Name: "UniqueIPs", Count: len(ips)},
{Name: "UniqueCountries", Count: len(countries)},
{Name: "TotalLoginAttempts", Count: len(ms.attempts)},
}
return stats, nil
}
func (ms *MemoryStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) {
var results []models.LoginAttempt
ms.lock.Lock()
defer ms.lock.Unlock()
for _, la := range ms.attempts {
switch query.QueryType {
case AttemptQueryTypeIP:
if la.RemoteIP.String() == query.Query {
results = append(results, la)
}
case AttemptQueryTypePassword:
if strings.Contains(la.Password, query.Query) {
results = append(results, la)
}
case AttemptQueryTypeUsername:
if strings.Contains(la.Username, query.Query) {
results = append(results, la)
}
}
}
return results, nil
}
func toResults(m map[string]int) []StatsResult {
var results []StatsResult
for key, value := range m {
results = append(results, StatsResult{key, value})
}
return results
}

View File

@@ -0,0 +1,17 @@
package store_test
import (
"testing"
"github.uio.no/torjus/apiary/honeypot/ssh/store"
)
func TestMemoryStore(t *testing.T) {
s := &store.MemoryStore{}
testLoginAttemptStore(s, t)
}
func TestMemoryStoreWithCache(t *testing.T) {
backend := &store.MemoryStore{}
s := store.NewCachingStore(backend)
testLoginAttemptStore(s, t)
}

View File

@@ -0,0 +1,204 @@
package store
import (
"database/sql"
"fmt"
"net"
_ "github.com/jackc/pgx/v4/stdlib"
"github.uio.no/torjus/apiary/models"
)
const DBSchema = `
CREATE TABLE IF NOT EXISTS login_attempts(
id serial PRIMARY KEY,
date timestamptz,
remote_ip inet,
username text,
password text,
client_version text,
connection_uuid uuid,
country varchar(2)
);
`
type PostgresStore struct {
db *sql.DB
}
func NewPostgresStore(dsn string) (*PostgresStore, error) {
db, err := sql.Open("pgx", dsn)
if err != nil {
return nil, err
}
rs := &PostgresStore{
db: db,
}
return rs, nil
}
func (s *PostgresStore) InitDB() error {
_, err := s.db.Exec(DBSchema)
return err
}
func (s *PostgresStore) AddAttempt(l *models.LoginAttempt) error {
stmt := `INSERT INTO
login_attempts(date, remote_ip, username, password, client_version, connection_uuid, country)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id;`
tx, err := s.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
var id int
if err := tx.QueryRow(stmt, l.Date, l.RemoteIP.String(), l.Username, l.Password, l.SSHClientVersion, l.ConnectionUUID, l.Country).Scan(&id); err != nil {
return err
}
l.ID = id
return tx.Commit()
}
func (s *PostgresStore) All() ([]models.LoginAttempt, error) {
stmt := `SELECT date, remote_ip, username, password, client_version, connection_uuid, country FROM login_attempts`
rows, err := s.db.Query(stmt)
if err != nil {
return nil, err
}
defer rows.Close()
var attempts []models.LoginAttempt
for rows.Next() {
var a models.LoginAttempt
var ip string
if err := rows.Scan(&a.Date, &ip, &a.Username, &a.Password, &a.SSHClientVersion, &a.SSHClientVersion, &a.Country); err != nil {
return nil, err
}
a.RemoteIP = net.ParseIP(ip)
attempts = append(attempts, a)
}
return attempts, nil
}
func (s *PostgresStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
var stmt string
if statType == LoginStatsTotals {
return s.statsTotal(limit)
}
switch statType {
case LoginStatsCountry:
stmt = `select country, count(country) from login_attempts group by country order by count desc`
case LoginStatsIP:
stmt = `select remote_ip, count(remote_ip) from login_attempts group by remote_ip order by count desc`
case LoginStatsPasswords:
stmt = `select password, count(password) from login_attempts group by password order by count desc`
case LoginStatsUsername:
stmt = `select username, count(username) from login_attempts group by username order by count desc`
default:
return nil, fmt.Errorf("invalid stat type")
}
if limit > 0 {
stmt = fmt.Sprintf("%s limit %d", stmt, limit)
}
rows, err := s.db.Query(stmt)
if err != nil {
return nil, err
}
defer rows.Close()
var results []StatsResult
for rows.Next() {
var r StatsResult
if err := rows.Scan(&r.Name, &r.Count); err != nil {
return nil, err
}
results = append(results, r)
}
return results, nil
}
func (s *PostgresStore) statsTotal(limit int) ([]StatsResult, error) {
uniquePasswordStmt := `select count(*) from (select distinct password from login_attempts) as temp`
uniqueUsernameStmt := `select count(*) from (select distinct username from login_attempts) as temp`
uniqueIPStmt := `select count(*) from (select distinct remote_ip from login_attempts) as temp`
uniqueCountryStmt := `select count(*) from (select distinct country from login_attempts) as temp`
attemptsCountStmt := `select count(1) from login_attempts`
var uniquePasswordsCount int
if err := s.db.QueryRow(uniquePasswordStmt).Scan(&uniquePasswordsCount); err != nil {
return nil, err
}
var uniqueUsernameCount int
if err := s.db.QueryRow(uniqueUsernameStmt).Scan(&uniqueUsernameCount); err != nil {
return nil, err
}
var uniqueIPCount int
if err := s.db.QueryRow(uniqueIPStmt).Scan(&uniqueIPCount); err != nil {
return nil, err
}
var uniqueCountryCount int
if err := s.db.QueryRow(uniqueCountryStmt).Scan(&uniqueCountryCount); err != nil {
return nil, err
}
var attemptsCount int
if err := s.db.QueryRow(attemptsCountStmt).Scan(&attemptsCount); err != nil {
return nil, err
}
return []StatsResult{
{Name: "UniquePasswords", Count: uniquePasswordsCount},
{Name: "UniqueUsernames", Count: uniqueUsernameCount},
{Name: "UniqueIPs", Count: uniqueIPCount},
{Name: "UniqueCountries", Count: uniqueCountryCount},
{Name: "TotalLoginAttempts", Count: attemptsCount},
}, nil
}
func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) {
var stmt string
queryString := query.Query
switch query.QueryType {
case AttemptQueryTypeIP:
stmt = `SELECT id, date, remote_ip, username, password, client_version, connection_uuid, country
FROM login_attempts WHERE remote_ip = $1`
case AttemptQueryTypePassword:
stmt = `SELECT id, date, remote_ip, username, password, client_version, connection_uuid, country
FROM login_attempts WHERE password like $1`
queryString = fmt.Sprintf("%%%s%%", queryString)
case AttemptQueryTypeUsername:
stmt = `SELECT id, date, remote_ip, username, password, client_version, connection_uuid, country
FROM login_attempts WHERE username like $1`
queryString = fmt.Sprintf("%%%s%%", queryString)
default:
return nil, fmt.Errorf("invalid query type")
}
rows, err := s.db.Query(stmt, queryString)
if err != nil {
return nil, fmt.Errorf("unable to query database: %w", err)
}
defer rows.Close()
var results []models.LoginAttempt
for rows.Next() {
var la models.LoginAttempt
var ipString string
if err := rows.Scan(&la.ID, &la.Date, &ipString, &la.Username, &la.Password, &la.SSHClientVersion, &la.ConnectionUUID, &la.Country); err != nil {
return nil, fmt.Errorf("unable to unmarshal data from database: %w", err)
}
la.RemoteIP = net.ParseIP(ipString)
results = append(results, la)
}
return results, nil
}

View File

@@ -0,0 +1,61 @@
package store_test
import (
"database/sql"
"os"
"testing"
"github.uio.no/torjus/apiary/honeypot/ssh/store"
)
func TestPostgresStore(t *testing.T) {
var dsn string
var found bool
dsn, found = os.LookupEnv("APIARY_TEST_POSTGRES_DSN")
if !found {
t.Skipf("APIARY_TEST_POSTGRES_DSN not set. Skipping.")
}
dropPGDatabase(dsn)
s, err := store.NewPostgresStore(dsn)
if err != nil {
t.Fatalf("Error getting store: %s", err)
}
s.InitDB()
testLoginAttemptStore(s, t)
}
func TestPostgresStoreWithCache(t *testing.T) {
var dsn string
var found bool
dsn, found = os.LookupEnv("APIARY_TEST_POSTGRES_DSN")
if !found {
t.Skipf("APIARY_TEST_POSTGRES_DSN not set. Skipping.")
}
dropPGDatabase(dsn)
pgs, err := store.NewPostgresStore(dsn)
if err != nil {
t.Fatalf("Error getting store: %s", err)
}
pgs.InitDB()
s := store.NewCachingStore(pgs)
testLoginAttemptStore(s, t)
}
func dropPGDatabase(dsn string) {
db, err := sql.Open("pgx", dsn)
if err != nil {
panic(err)
}
_, err = db.Exec("DROP TABLE login_attempts")
if err != nil {
panic(err)
}
}

View File

@@ -0,0 +1,38 @@
package store
import "github.uio.no/torjus/apiary/models"
type LoginStats string
const (
LoginStatsUndefined LoginStats = ""
LoginStatsPasswords LoginStats = "password"
LoginStatsCountry LoginStats = "country"
LoginStatsIP LoginStats = "ip"
LoginStatsUsername LoginStats = "username"
LoginStatsTotals LoginStats = "total"
)
type AttemptQueryType string
const (
AttemptQueryTypeUsername AttemptQueryType = "username"
AttemptQueryTypePassword AttemptQueryType = "password"
AttemptQueryTypeIP AttemptQueryType = "ip"
)
type StatsResult struct {
Name string `json:"name"`
Count int `json:"count"`
}
type AttemptQuery struct {
QueryType AttemptQueryType
Query string
}
type LoginAttemptStore interface {
AddAttempt(l *models.LoginAttempt) error
All() ([]models.LoginAttempt, error)
Stats(statType LoginStats, limit int) ([]StatsResult, error)
Query(query AttemptQuery) ([]models.LoginAttempt, error)
}

View File

@@ -0,0 +1,271 @@
package store_test
import (
"math/rand"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.uio.no/torjus/apiary/honeypot/ssh/store"
"github.uio.no/torjus/apiary/models"
)
func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) {
t.Run("Simple", func(t *testing.T) {
testAttempts := randomAttempts(10)
for _, attempt := range testAttempts {
if err := s.AddAttempt(attempt); err != nil {
t.Fatalf("Error adding attempt: %s", err)
}
}
all, err := s.All()
if err != nil {
t.Fatalf("Error getting all attempts: %s", err)
}
if len(all) != len(testAttempts) {
t.Errorf("All returned wrong amount. Got %d want %d", len(all), len(testAttempts))
}
stats, err := s.Stats(store.LoginStatsTotals, 1)
if err != nil {
t.Errorf("Stats returned error: %s", err)
}
for _, stat := range stats {
if stat.Name == "TotalLoginAttempts" && stat.Count != len(testAttempts) {
t.Errorf("Stats for total attempts is wrong. Got %d want %d", stat.Count, len(testAttempts))
}
}
})
t.Run("Query", func(t *testing.T) {
testAttempts := []*models.LoginAttempt{
{
Date: time.Now(),
RemoteIP: net.ParseIP("127.0.0.1"),
Username: "corndog",
Password: "corndog",
},
{
Date: time.Now(),
RemoteIP: net.ParseIP("127.0.0.1"),
Username: "corndog",
Password: "c0rnd0g",
},
{
Date: time.Now(),
RemoteIP: net.ParseIP("10.0.0.1"),
Username: "root",
Password: "password",
},
{
Date: time.Now(),
RemoteIP: net.ParseIP("10.0.0.2"),
Username: "ubnt",
Password: "password",
},
}
for _, attempt := range testAttempts {
err := s.AddAttempt(attempt)
if err != nil {
t.Fatalf("Unable to add attempt: %s", err)
}
}
testCases := []struct {
Name string
Query store.AttemptQuery
ExpectedResult []models.LoginAttempt
}{
{
Name: "password one result",
Query: store.AttemptQuery{QueryType: store.AttemptQueryTypePassword, Query: "corndog"},
ExpectedResult: []models.LoginAttempt{
{
RemoteIP: net.ParseIP("127.0.0.1"),
Username: "corndog",
Password: "corndog",
},
},
},
{
Name: "username one result",
Query: store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "root"},
ExpectedResult: []models.LoginAttempt{
{
RemoteIP: net.ParseIP("10.0.0.1"),
Username: "root",
Password: "password",
},
},
},
{
Name: "username two results",
Query: store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "corndog"},
ExpectedResult: []models.LoginAttempt{
{
RemoteIP: net.ParseIP("127.0.0.1"),
Username: "corndog",
Password: "c0rnd0g",
},
{
RemoteIP: net.ParseIP("127.0.0.1"),
Username: "corndog",
Password: "corndog",
},
},
},
}
for _, tc := range testCases {
res, err := s.Query(tc.Query)
if err != nil {
t.Errorf("Error performing query: %s", err)
}
if !equalAttempts(res, tc.ExpectedResult) {
t.Errorf("Query did not return expected results")
t.Logf("%+v", res)
t.Logf("%+v", tc.ExpectedResult)
}
}
})
t.Run("QueryCache", func(t *testing.T) {
err := s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "test"})
if err != nil {
t.Fatalf("Error adding attempt: %s", err)
}
res, err := s.Query(store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "test"})
if err != nil {
t.Fatalf("Error adding attempt: %s", err)
}
if len(res) != 1 {
t.Errorf("Wrong amount of results")
}
err = s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "best"})
if err != nil {
t.Fatalf("Error adding attempt: %s", err)
}
res, err = s.Query(store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "test"})
if err != nil {
t.Fatalf("Error adding attempt: %s", err)
}
if len(res) != 2 {
t.Errorf("Wrong amount of results")
}
})
t.Run("QueryStats", func(t *testing.T) {
firstStats, err := s.Stats(store.LoginStatsTotals, 1)
if err != nil {
t.Fatalf("Error getting stats: %s", err)
}
err = s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "best"})
if err != nil {
t.Fatalf("Error adding attempt: %s", err)
}
secondStats, err := s.Stats(store.LoginStatsTotals, 1)
if err != nil {
t.Fatalf("Error getting stats: %s", err)
}
var firstCount, secondCount int
for _, stat := range firstStats {
if stat.Name == "TotalLoginAttempts" {
firstCount = stat.Count
}
}
for _, stat := range secondStats {
if stat.Name == "TotalLoginAttempts" {
secondCount = stat.Count
}
}
if secondCount != firstCount+1 {
t.Errorf("TotalLoginAttempts did not increment")
}
})
}
func randomAttempts(count int) []*models.LoginAttempt {
var attempts []*models.LoginAttempt
for i := 0; i < count; i++ {
attempt := &models.LoginAttempt{
Date: time.Now(),
RemoteIP: randomIP(),
Username: randomString(8),
Password: randomString(8),
Country: randomCountry(),
ConnectionUUID: uuid.Must(uuid.NewRandom()),
SSHClientVersion: "SSH TEST LOL",
}
attempts = append(attempts, attempt)
}
return attempts
}
func randomIP() net.IP {
a := byte(rand.Intn(254))
b := byte(rand.Intn(254))
c := byte(rand.Intn(254))
d := byte(rand.Intn(254))
return net.IPv4(a, b, c, d)
}
func randomString(n int) string {
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
b := make([]byte, n)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return string(b)
}
func randomCountry() string {
switch rand.Intn(10) {
case 1:
return "CN"
case 2:
return "US"
case 3:
return "NO"
case 4:
return "RU"
case 5:
return "DE"
case 6:
return "FI"
case 7:
return "BR"
default:
return "SE"
}
}
func equalAttempts(a, b []models.LoginAttempt) bool {
if len(a) != len(b) {
return false
}
aFound := make([]bool, len(a))
for i, aAttempt := range a {
for _, bAttempt := range b {
if aAttempt.Username == bAttempt.Username &&
aAttempt.Password == bAttempt.Password &&
aAttempt.RemoteIP.String() == bAttempt.RemoteIP.String() {
aFound[i] = true
}
}
}
for _, found := range aFound {
if !found {
return false
}
}
return true
}