diff --git a/honeypot/store/memory_test.go b/honeypot/store/memory_test.go index ed13710..810a962 100644 --- a/honeypot/store/memory_test.go +++ b/honeypot/store/memory_test.go @@ -10,3 +10,8 @@ func TestMemoryStore(t *testing.T) { s := &store.MemoryStore{} testLoginAttemptStore(s, t) } +func TestMemoryStoreWithCache(t *testing.T) { + backend := &store.MemoryStore{} + s := store.NewCachingStore(backend) + testLoginAttemptStore(s, t) +} diff --git a/honeypot/store/postgres_test.go b/honeypot/store/postgres_test.go new file mode 100644 index 0000000..c05ce1a --- /dev/null +++ b/honeypot/store/postgres_test.go @@ -0,0 +1,61 @@ +package store_test + +import ( + "database/sql" + "os" + "testing" + + "github.uio.no/torjus/apiary/honeypot/store" +) + +func TestPostgresStore(t *testing.T) { + var dsn string + var found bool + dsn, found = os.LookupEnv("APIARY_TEST_POSTGRES_DSN") + if !found { + t.Skipf("APIARY_TEST_POSTGRES_DSN not set. Skipping.") + } + + dropPGDatabase(dsn) + + s, err := store.NewPostgresStore(dsn) + if err != nil { + t.Fatalf("Error getting store: %s", err) + } + + s.InitDB() + + testLoginAttemptStore(s, t) +} +func TestPostgresStoreWithCache(t *testing.T) { + var dsn string + var found bool + dsn, found = os.LookupEnv("APIARY_TEST_POSTGRES_DSN") + if !found { + t.Skipf("APIARY_TEST_POSTGRES_DSN not set. Skipping.") + } + + dropPGDatabase(dsn) + + pgs, err := store.NewPostgresStore(dsn) + if err != nil { + t.Fatalf("Error getting store: %s", err) + } + + pgs.InitDB() + s := store.NewCachingStore(pgs) + + testLoginAttemptStore(s, t) +} + +func dropPGDatabase(dsn string) { + db, err := sql.Open("pgx", dsn) + if err != nil { + panic(err) + } + + _, err = db.Exec("DROP TABLE login_attempts") + if err != nil { + panic(err) + } +} diff --git a/honeypot/store/store_test.go b/honeypot/store/store_test.go index fef91e8..fd99a16 100644 --- a/honeypot/store/store_test.go +++ b/honeypot/store/store_test.go @@ -38,6 +38,154 @@ func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) { } } }) + t.Run("Query", func(t *testing.T) { + testAttempts := []*models.LoginAttempt{ + { + Date: time.Now(), + RemoteIP: net.ParseIP("127.0.0.1"), + Username: "corndog", + Password: "corndog", + }, + { + Date: time.Now(), + RemoteIP: net.ParseIP("127.0.0.1"), + Username: "corndog", + Password: "c0rnd0g", + }, + { + Date: time.Now(), + RemoteIP: net.ParseIP("10.0.0.1"), + Username: "root", + Password: "password", + }, + { + Date: time.Now(), + RemoteIP: net.ParseIP("10.0.0.2"), + Username: "ubnt", + Password: "password", + }, + } + + for _, attempt := range testAttempts { + err := s.AddAttempt(attempt) + if err != nil { + t.Fatalf("Unable to add attempt: %s", err) + } + } + testCases := []struct { + Name string + Query store.AttemptQuery + ExpectedResult []models.LoginAttempt + }{ + { + Name: "password one result", + Query: store.AttemptQuery{QueryType: store.AttemptQueryTypePassword, Query: "corndog"}, + ExpectedResult: []models.LoginAttempt{ + { + RemoteIP: net.ParseIP("127.0.0.1"), + Username: "corndog", + Password: "corndog", + }, + }, + }, + { + Name: "username one result", + Query: store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "root"}, + ExpectedResult: []models.LoginAttempt{ + { + RemoteIP: net.ParseIP("10.0.0.1"), + Username: "root", + Password: "password", + }, + }, + }, + { + Name: "username two results", + Query: store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "corndog"}, + ExpectedResult: []models.LoginAttempt{ + { + RemoteIP: net.ParseIP("127.0.0.1"), + Username: "corndog", + Password: "c0rnd0g", + }, + { + RemoteIP: net.ParseIP("127.0.0.1"), + Username: "corndog", + Password: "corndog", + }, + }, + }, + } + + for _, tc := range testCases { + res, err := s.Query(tc.Query) + if err != nil { + t.Errorf("Error performing query: %s", err) + } + if !equalAttempts(res, tc.ExpectedResult) { + t.Errorf("Query did not return expected results") + t.Logf("%+v", res) + t.Logf("%+v", tc.ExpectedResult) + + } + } + }) + + t.Run("QueryCache", func(t *testing.T) { + err := s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "test"}) + if err != nil { + t.Fatalf("Error adding attempt: %s", err) + } + + res, err := s.Query(store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "test"}) + if err != nil { + t.Fatalf("Error adding attempt: %s", err) + } + if len(res) != 1 { + t.Errorf("Wrong amount of results") + } + + err = s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "best"}) + if err != nil { + t.Fatalf("Error adding attempt: %s", err) + } + res, err = s.Query(store.AttemptQuery{QueryType: store.AttemptQueryTypeUsername, Query: "test"}) + if err != nil { + t.Fatalf("Error adding attempt: %s", err) + } + if len(res) != 2 { + t.Errorf("Wrong amount of results") + } + }) + t.Run("QueryStats", func(t *testing.T) { + firstStats, err := s.Stats(store.LoginStatsTotals, 1) + if err != nil { + t.Fatalf("Error getting stats: %s", err) + } + err = s.AddAttempt(&models.LoginAttempt{RemoteIP: net.ParseIP("127.0.0.1"), Username: "test", Password: "best"}) + if err != nil { + t.Fatalf("Error adding attempt: %s", err) + } + secondStats, err := s.Stats(store.LoginStatsTotals, 1) + if err != nil { + t.Fatalf("Error getting stats: %s", err) + } + var firstCount, secondCount int + for _, stat := range firstStats { + if stat.Name == "TotalLoginAttempts" { + firstCount = stat.Count + } + } + for _, stat := range secondStats { + if stat.Name == "TotalLoginAttempts" { + secondCount = stat.Count + } + } + + if secondCount != firstCount+1 { + t.Errorf("TotalLoginAttempts did not increment") + } + }) } func randomAttempts(count int) []*models.LoginAttempt { @@ -95,3 +243,29 @@ func randomCountry() string { return "SE" } } + +func equalAttempts(a, b []models.LoginAttempt) bool { + if len(a) != len(b) { + return false + } + + aFound := make([]bool, len(a)) + + for i, aAttempt := range a { + for _, bAttempt := range b { + if aAttempt.Username == bAttempt.Username && + aAttempt.Password == bAttempt.Password && + aAttempt.RemoteIP.String() == bAttempt.RemoteIP.String() { + aFound[i] = true + } + } + } + + for _, found := range aFound { + if !found { + return false + } + } + + return true +} diff --git a/version.go b/version.go index e84afe6..e309def 100644 --- a/version.go +++ b/version.go @@ -5,7 +5,7 @@ import ( "runtime" ) -var Version = "v0.1.9" +var Version = "v0.1.10" var Build string func FullVersion() string {