apiary/honeypot/store/store_test.go
2021-09-17 15:00:58 +02:00

272 lines
6.4 KiB
Go

package store_test
import (
"math/rand"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.uio.no/torjus/apiary/honeypot/store"
"github.uio.no/torjus/apiary/models"
)
// testLoginAttemptStore runs a shared suite of behavioural checks against a
// store.LoginAttemptStore implementation.
//
// The subtests run sequentially against the same store and see each other's
// data. "Simple" compares the result of All() with only the attempts it added
// itself, so s is expected to be empty when it is passed in.
func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) {
	// Simple: add a batch of random attempts and verify that All() and the
	// totals statistics both report the same number of attempts.
	t.Run("Simple", func(t *testing.T) {
		testAttempts := randomAttempts(10)
		for _, attempt := range testAttempts {
			if err := s.AddAttempt(attempt); err != nil {
				t.Fatalf("Error adding attempt: %s", err)
			}
		}

		all, err := s.All()
		if err != nil {
			t.Fatalf("Error getting all attempts: %s", err)
		}
		if len(all) != len(testAttempts) {
			t.Errorf("All returned wrong amount. Got %d want %d", len(all), len(testAttempts))
		}

		stats, err := s.Stats(store.LoginStatsTotals, 1)
		if err != nil {
			t.Errorf("Stats returned error: %s", err)
		}
		// Track whether the stat was present at all; otherwise a store that
		// never reports TotalLoginAttempts would pass this check vacuously.
		foundTotal := false
		for _, stat := range stats {
			if stat.Name != "TotalLoginAttempts" {
				continue
			}
			foundTotal = true
			if stat.Count != len(testAttempts) {
				t.Errorf("Stats for total attempts is wrong. Got %d want %d", stat.Count, len(testAttempts))
			}
		}
		if !foundTotal {
			t.Errorf("Stats did not include TotalLoginAttempts")
		}
	})

	// Query: add a known set of attempts and verify that username and
	// password queries return exactly the matching attempts.
	t.Run("Query", func(t *testing.T) {
		testAttempts := []*models.LoginAttempt{
			{
				Date:     time.Now(),
				RemoteIP: net.ParseIP("127.0.0.1"),
				Username: "corndog",
				Password: "corndog",
			},
			{
				Date:     time.Now(),
				RemoteIP: net.ParseIP("127.0.0.1"),
				Username: "corndog",
				Password: "c0rnd0g",
			},
			{
				Date:     time.Now(),
				RemoteIP: net.ParseIP("10.0.0.1"),
				Username: "root",
				Password: "password",
			},
			{
				Date:     time.Now(),
				RemoteIP: net.ParseIP("10.0.0.2"),
				Username: "ubnt",
				Password: "password",
			},
		}
		for _, attempt := range testAttempts {
			if err := s.AddAttempt(attempt); err != nil {
				t.Fatalf("Unable to add attempt: %s", err)
			}
		}

		testCases := []struct {
			Name           string
			Query          store.AttemptQuery
			ExpectedResult []models.LoginAttempt
		}{
			{
				Name:  "password one result",
				Query: store.AttemptQuery{QueryType: store.AttemptQueryTypePassword, Query: "corndog"},
				ExpectedResult: []models.LoginAttempt{
					{
						RemoteIP: net.ParseIP("127.0.0.1"),
						Username: "corndog",
						Password: "corndog",
					},
				},
			},
			{
				Name:  "username one result",
				Query: store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "root"},
				ExpectedResult: []models.LoginAttempt{
					{
						RemoteIP: net.ParseIP("10.0.0.1"),
						Username: "root",
						Password: "password",
					},
				},
			},
			{
				Name:  "username two results",
				Query: store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "corndog"},
				ExpectedResult: []models.LoginAttempt{
					{
						RemoteIP: net.ParseIP("127.0.0.1"),
						Username: "corndog",
						Password: "c0rnd0g",
					},
					{
						RemoteIP: net.ParseIP("127.0.0.1"),
						Username: "corndog",
						Password: "corndog",
					},
				},
			},
		}
		for _, tc := range testCases {
			// Run every case as a named subtest so a failure identifies the
			// case. The closure runs synchronously, so capturing tc is safe.
			t.Run(tc.Name, func(t *testing.T) {
				res, err := s.Query(tc.Query)
				if err != nil {
					// Comparing results after a failed query is meaningless;
					// abort only this case.
					t.Fatalf("Error performing query: %s", err)
				}
				if !equalAttempts(res, tc.ExpectedResult) {
					t.Errorf("Query did not return expected results")
					t.Logf("%+v", res)
					t.Logf("%+v", tc.ExpectedResult)
				}
			})
		}
	})

	// QueryCache: issue the same query before and after a second insert. The
	// second result must include the new attempt (the subtest name suggests
	// this guards against stale cached query results).
	t.Run("QueryCache", func(t *testing.T) {
		err := s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "test"})
		if err != nil {
			t.Fatalf("Error adding attempt: %s", err)
		}
		res, err := s.Query(store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "test"})
		if err != nil {
			t.Fatalf("Error performing query: %s", err)
		}
		if len(res) != 1 {
			t.Errorf("Wrong amount of results. Got %d want %d", len(res), 1)
		}

		err = s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "best"})
		if err != nil {
			t.Fatalf("Error adding attempt: %s", err)
		}
		res, err = s.Query(store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "test"})
		if err != nil {
			t.Fatalf("Error performing query: %s", err)
		}
		if len(res) != 2 {
			t.Errorf("Wrong amount of results. Got %d want %d", len(res), 2)
		}
	})

	// QueryStats: the TotalLoginAttempts statistic must grow by exactly one
	// after a single attempt is added.
	t.Run("QueryStats", func(t *testing.T) {
		firstStats, err := s.Stats(store.LoginStatsTotals, 1)
		if err != nil {
			t.Fatalf("Error getting stats: %s", err)
		}
		err = s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "best"})
		if err != nil {
			t.Fatalf("Error adding attempt: %s", err)
		}
		secondStats, err := s.Stats(store.LoginStatsTotals, 1)
		if err != nil {
			t.Fatalf("Error getting stats: %s", err)
		}

		// If the stat is missing from both snapshots both counts stay zero,
		// which also fails the comparison below.
		var firstCount, secondCount int
		for _, stat := range firstStats {
			if stat.Name == "TotalLoginAttempts" {
				firstCount = stat.Count
			}
		}
		for _, stat := range secondStats {
			if stat.Name == "TotalLoginAttempts" {
				secondCount = stat.Count
			}
		}
		if secondCount != firstCount+1 {
			t.Errorf("TotalLoginAttempts did not increment. Got %d want %d", secondCount, firstCount+1)
		}
	})
}
// randomAttempts generates count login attempts populated with random
// test data: the current time, a random IP, random 8-letter username and
// password, a random country code, a fresh random connection UUID and a fixed
// SSH client version string. It returns nil when count is zero or negative.
func randomAttempts(count int) []*models.LoginAttempt {
	var generated []*models.LoginAttempt
	for remaining := count; remaining > 0; remaining-- {
		generated = append(generated, &models.LoginAttempt{
			Date:             time.Now(),
			RemoteIP:         randomIP(),
			Username:         randomString(8),
			Password:         randomString(8),
			Country:          randomCountry(),
			ConnectionUUID:   uuid.Must(uuid.NewRandom()),
			SSHClientVersion: "SSH TEST LOL",
		})
	}
	return generated
}
// randomIP returns an IPv4 address whose four octets are each drawn
// independently from the range [0, 253] using the package-level math/rand
// source.
func randomIP() net.IP {
	var octets [4]byte
	for i := range octets {
		octets[i] = byte(rand.Intn(254))
	}
	return net.IPv4(octets[0], octets[1], octets[2], octets[3])
}
// randomString returns a string of n characters, each picked at random from
// the ASCII letters a-z and A-Z using the package-level math/rand source.
func randomString(n int) string {
	const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	out := make([]byte, 0, n)
	for len(out) < n {
		out = append(out, alphabet[rand.Intn(len(alphabet))])
	}
	return string(out)
}
// randomCountry returns a random two-letter country code. A number in
// [0, 9] is drawn from the package-level math/rand source and used as an
// index into a fixed table; 1 through 7 map to distinct codes, while 0, 8
// and 9 all map to "SE".
func randomCountry() string {
	codes := [...]string{
		"SE", // 0
		"CN", // 1
		"US", // 2
		"NO", // 3
		"RU", // 4
		"DE", // 5
		"FI", // 6
		"BR", // 7
		"SE", // 8
		"SE", // 9
	}
	return codes[rand.Intn(len(codes))]
}
// equalAttempts reports whether a and b contain the same login attempts,
// ignoring order. Two attempts are treated as the same when their Username,
// Password and the string form of their RemoteIP are equal; every other
// field is ignored.
//
// Duplicates are respected: each element of b can be paired with at most one
// element of a, so a = [x, x] and b = [x, y] are not equal.
func equalAttempts(a, b []models.LoginAttempt) bool {
	if len(a) != len(b) {
		return false
	}

	// used[j] is set once b[j] has been paired with an element of a, so the
	// same b element cannot satisfy several a elements.
	used := make([]bool, len(b))
	for i := range a {
		paired := false
		for j := range b {
			if used[j] {
				continue
			}
			if a[i].Username == b[j].Username &&
				a[i].Password == b[j].Password &&
				a[i].RemoteIP.String() == b[j].RemoteIP.String() {
				used[j] = true
				paired = true
				break
			}
		}
		if !paired {
			return false
		}
	}

	// The lengths are equal and every element of a consumed a distinct
	// element of b, so every element of b is matched as well.
	return true
}