Add some caching
This commit is contained in:
		| @@ -4,6 +4,9 @@ | ||||
| # Must be "memory" or "postgres" | ||||
| # Default: "memory" | ||||
| Type = "memory" | ||||
| # Enable caching | ||||
| # Default: false | ||||
| EnableCache = false | ||||
|  | ||||
| [Store.Postgres] | ||||
| # Connection string for postgres | ||||
|   | ||||
| @@ -53,11 +53,13 @@ func ActionServe(c *cli.Context) error { | ||||
|  | ||||
| 	// Setup logging | ||||
| 	loggers := setupLoggers(cfg) | ||||
| 	loggers.rootLogger.Infow("Startin apiary", "version", apiary.FullVersion()) | ||||
|  | ||||
| 	// Setup store | ||||
| 	var s store.LoginAttemptStore | ||||
| 	switch cfg.Store.Type { | ||||
| 	case "MEMORY", "memory": | ||||
| 		loggers.rootLogger.Debugw("Initializing store", "store_type", "memory") | ||||
| 		s = &store.MemoryStore{} | ||||
| 	case "POSTGRES", "postgres": | ||||
| 		pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN) | ||||
| @@ -67,7 +69,14 @@ func ActionServe(c *cli.Context) error { | ||||
| 		if err := pgStore.InitDB(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		s = pgStore | ||||
| 		if cfg.Store.EnableCache { | ||||
| 			loggers.rootLogger.Debugw("Initializing store", "store_type", "cache-postgres") | ||||
| 			cachingStore := store.NewCachingStore(pgStore) | ||||
| 			s = cachingStore | ||||
| 		} else { | ||||
| 			loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres") | ||||
| 			s = pgStore | ||||
| 		} | ||||
| 	default: | ||||
| 		return fmt.Errorf("Invalid store configured") | ||||
| 	} | ||||
|   | ||||
| @@ -15,8 +15,9 @@ type Config struct { | ||||
| 	Frontend FrontendConfig `toml:"Frontend"` | ||||
| } | ||||
| type StoreConfig struct { | ||||
| 	Type     string              `toml:"Type"` | ||||
| 	Postgres PostgresStoreConfig `toml:"Postgres"` | ||||
| 	Type        string              `toml:"Type"` | ||||
| 	EnableCache bool                `toml:"EnableCache"` | ||||
| 	Postgres    PostgresStoreConfig `toml:"Postgres"` | ||||
| } | ||||
|  | ||||
| type PostgresStoreConfig struct { | ||||
|   | ||||
							
								
								
									
										85
									
								
								honeypot/store/cache.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								honeypot/store/cache.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,85 @@ | ||||
| package store | ||||
|  | ||||
| import "github.uio.no/torjus/apiary/models" | ||||
|  | ||||
| type CachingStore struct { | ||||
| 	backend LoginAttemptStore | ||||
|  | ||||
| 	totalStatsCache    []StatsResult | ||||
| 	usernameQueryCache map[string][]models.LoginAttempt | ||||
| 	passwordQueryCache map[string][]models.LoginAttempt | ||||
| 	ipQueryCache       map[string][]models.LoginAttempt | ||||
| } | ||||
|  | ||||
| func NewCachingStore(backend LoginAttemptStore) *CachingStore { | ||||
| 	return &CachingStore{ | ||||
| 		backend:            backend, | ||||
| 		usernameQueryCache: make(map[string][]models.LoginAttempt), | ||||
| 		passwordQueryCache: make(map[string][]models.LoginAttempt), | ||||
| 		ipQueryCache:       make(map[string][]models.LoginAttempt), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error { | ||||
| 	s.totalStatsCache = nil | ||||
| 	delete(s.ipQueryCache, l.RemoteIP.String()) | ||||
| 	delete(s.passwordQueryCache, l.Password) | ||||
| 	delete(s.usernameQueryCache, l.Username) | ||||
| 	return s.backend.AddAttempt(l) | ||||
| } | ||||
|  | ||||
| func (s *CachingStore) All() ([]models.LoginAttempt, error) { | ||||
| 	return s.backend.All() | ||||
| } | ||||
|  | ||||
| func (s *CachingStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) { | ||||
| 	// Only cache totals for now, as they are the most queried | ||||
| 	if statType == LoginStatsTotals { | ||||
| 		if s.totalStatsCache != nil { | ||||
| 			return s.totalStatsCache, nil | ||||
| 		} | ||||
| 		stats, err := s.backend.Stats(statType, limit) | ||||
| 		if err != nil { | ||||
| 			return stats, err | ||||
| 		} | ||||
| 		s.totalStatsCache = stats | ||||
|  | ||||
| 		return stats, err | ||||
| 	} | ||||
|  | ||||
| 	// TODO: Maybe cache the default limits | ||||
|  | ||||
| 	return s.backend.Stats(statType, limit) | ||||
| } | ||||
|  | ||||
| func (s *CachingStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) { | ||||
| 	switch query.QueryType { | ||||
| 	case AttemptQueryTypeIP: | ||||
| 		if attempts, ok := s.ipQueryCache[query.Query]; ok { | ||||
| 			return attempts, nil | ||||
| 		} | ||||
| 	case AttemptQueryTypePassword: | ||||
| 		if attempts, ok := s.passwordQueryCache[query.Query]; ok { | ||||
| 			return attempts, nil | ||||
| 		} | ||||
| 	case AttemptQueryTypeUsername: | ||||
| 		if attempts, ok := s.usernameQueryCache[query.Query]; ok { | ||||
| 			return attempts, nil | ||||
| 		} | ||||
| 	} | ||||
| 	attempts, err := s.backend.Query(query) | ||||
| 	if err != nil { | ||||
| 		return attempts, err | ||||
| 	} | ||||
|  | ||||
| 	switch query.QueryType { | ||||
| 	case AttemptQueryTypeIP: | ||||
| 		s.ipQueryCache[query.Query] = attempts | ||||
| 	case AttemptQueryTypeUsername: | ||||
| 		s.ipQueryCache[query.Query] = attempts | ||||
| 	case AttemptQueryTypePassword: | ||||
| 		s.ipQueryCache[query.Query] = attempts | ||||
| 	} | ||||
|  | ||||
| 	return attempts, err | ||||
| } | ||||
							
								
								
									
										13
									
								
								honeypot/store/cache_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								honeypot/store/cache_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| package store_test | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.uio.no/torjus/apiary/honeypot/store" | ||||
| ) | ||||
|  | ||||
| func TestCacheStore(t *testing.T) { | ||||
| 	backend := &store.MemoryStore{} | ||||
| 	s := store.NewCachingStore(backend) | ||||
| 	testLoginAttemptStore(s, t) | ||||
| } | ||||
| @@ -57,7 +57,7 @@ func (ms *MemoryStore) Stats(statType LoginStats, limit int) ([]StatsResult, err | ||||
| 		case LoginStatsUsername: | ||||
| 			counts[a.Username]++ | ||||
| 		default: | ||||
| 			return nil, fmt.Errorf("Invalid stat type") | ||||
| 			return nil, fmt.Errorf("invalid stat type") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
|   | ||||
| @@ -1,119 +1,12 @@ | ||||
| package store | ||||
| package store_test | ||||
|  | ||||
| /* | ||||
| func TestStatItems(t *testing.T) { | ||||
| 	var tc = []struct { | ||||
| 		Input          StatItems | ||||
| 		ExpectedOutput StatItems | ||||
| 	}{ | ||||
| 		{ | ||||
| 			Input: StatItems{ | ||||
| 				{Key: "a", Count: 5}, | ||||
| 				{Key: "b", Count: 100}, | ||||
| 				{Key: "c", Count: 99}, | ||||
| 				{Key: "d", Count: 98}, | ||||
| 				{Key: "f", Count: 18}, | ||||
| 			}, | ||||
| 			ExpectedOutput: StatItems{ | ||||
| 				{Key: "a", Count: 5}, | ||||
| 				{Key: "f", Count: 18}, | ||||
| 				{Key: "d", Count: 98}, | ||||
| 				{Key: "c", Count: 99}, | ||||
| 				{Key: "b", Count: 100}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	for _, testCase := range tc { | ||||
| 		sort.Sort(testCase.Input) | ||||
| 	"github.uio.no/torjus/apiary/honeypot/store" | ||||
| ) | ||||
|  | ||||
| 		for i := range testCase.Input { | ||||
| 			if testCase.Input[i] != testCase.ExpectedOutput[i] { | ||||
| 				t.Fatalf("Not sorted correctly") | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| func TestMemoryStore(t *testing.T) { | ||||
| 	s := &store.MemoryStore{} | ||||
| 	testLoginAttemptStore(s, t) | ||||
| } | ||||
|  | ||||
| func TestStats(t *testing.T) { | ||||
| 	var exampleAttempts = []models.LoginAttempt{ | ||||
| 		{Username: "root", Password: "root", Country: "NO"}, | ||||
| 		{Username: "root", Password: "root", Country: "US"}, | ||||
| 		{Username: "user", Password: "passWord", Country: "US"}, | ||||
| 		{Username: "ibm", Password: "ibm", Country: "US"}, | ||||
| 		{Username: "ubnt", Password: "ubnt", Country: "GB"}, | ||||
| 		{Username: "ubnt", Password: "ubnt", Country: "FI"}, | ||||
| 		{Username: "root", Password: "root", Country: "CH"}, | ||||
| 		{Username: "ubnt", Password: "12345", Country: "DE"}, | ||||
| 		{Username: "oracle", Password: "oracle", Country: "FI"}, | ||||
| 	} | ||||
| 	var tc = []struct { | ||||
| 		Attempts       []models.LoginAttempt | ||||
| 		StatType       LoginStats | ||||
| 		Limit          int | ||||
| 		ExpectedOutput map[string]int | ||||
| 	}{ | ||||
| 		{ | ||||
| 			Attempts: exampleAttempts, | ||||
| 			StatType: LoginStatsPasswords, | ||||
| 			Limit:    2, | ||||
| 			ExpectedOutput: map[string]int{ | ||||
| 				"root": 3, | ||||
| 				"ubnt": 2, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Attempts: exampleAttempts, | ||||
| 			StatType: LoginStatsPasswords, | ||||
| 			Limit:    1, | ||||
| 			ExpectedOutput: map[string]int{ | ||||
| 				"root": 3, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Attempts: exampleAttempts, | ||||
| 			StatType: LoginStatsCountry, | ||||
| 			Limit:    2, | ||||
| 			ExpectedOutput: map[string]int{ | ||||
| 				"US": 3, | ||||
| 				"FI": 2, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			Attempts: exampleAttempts, | ||||
| 			StatType: LoginStatsCountry, | ||||
| 			Limit:    0, | ||||
| 			ExpectedOutput: map[string]int{ | ||||
| 				"US": 3, | ||||
| 				"FI": 2, | ||||
| 				"NO": 1, | ||||
| 				"GB": 1, | ||||
| 				"CH": 1, | ||||
| 				"DE": 1, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, c := range tc { | ||||
| 		ms := MemoryStore{} | ||||
| 		for _, a := range c.Attempts { | ||||
| 			if err := ms.AddAttempt(a); err != nil { | ||||
| 				t.Fatalf("Unable to add attempt: %s", err) | ||||
| 			} | ||||
| 		} | ||||
| 		stats, err := ms.Stats(c.StatType, c.Limit) | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("Error getting stats: %s", err) | ||||
| 		} | ||||
| 		if len(stats) != len(c.ExpectedOutput) { | ||||
| 			t.Fatalf("Stats have wrong length") | ||||
| 		} | ||||
| 		for key := range stats { | ||||
| 			if c.ExpectedOutput[key] != stats[key] { | ||||
| 				t.Fatalf("Stats does not match expected output") | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| } | ||||
| */ | ||||
|   | ||||
| @@ -177,12 +177,12 @@ func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) | ||||
| 					FROM login_attempts WHERE username like $1` | ||||
| 		queryString = fmt.Sprintf("%%%s%%", queryString) | ||||
| 	default: | ||||
| 		return nil, fmt.Errorf("Invalid query type") | ||||
| 		return nil, fmt.Errorf("invalid query type") | ||||
| 	} | ||||
|  | ||||
| 	rows, err := s.db.Query(stmt, queryString) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("Unable to query database: %w", err) | ||||
| 		return nil, fmt.Errorf("unable to query database: %w", err) | ||||
| 	} | ||||
| 	defer rows.Close() | ||||
|  | ||||
| @@ -191,7 +191,7 @@ func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) | ||||
| 		var la models.LoginAttempt | ||||
| 		var ipString string | ||||
| 		if err := rows.Scan(&la.ID, &la.Date, &ipString, &la.Username, &la.Password, &la.SSHClientVersion, &la.ConnectionUUID, &la.Country); err != nil { | ||||
| 			return nil, fmt.Errorf("Unable to unmarshal data from database: %w", err) | ||||
| 			return nil, fmt.Errorf("unable to unmarshal data from database: %w", err) | ||||
| 		} | ||||
| 		la.RemoteIP = net.ParseIP(ipString) | ||||
| 		results = append(results, la) | ||||
|   | ||||
							
								
								
									
										97
									
								
								honeypot/store/store_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								honeypot/store/store_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,97 @@ | ||||
| package store_test | ||||
|  | ||||
| import ( | ||||
| 	"math/rand" | ||||
| 	"net" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/google/uuid" | ||||
| 	"github.uio.no/torjus/apiary/honeypot/store" | ||||
| 	"github.uio.no/torjus/apiary/models" | ||||
| ) | ||||
|  | ||||
| func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) { | ||||
| 	t.Run("Simple", func(t *testing.T) { | ||||
| 		testAttempts := randomAttempts(10) | ||||
|  | ||||
| 		for _, attempt := range testAttempts { | ||||
| 			if err := s.AddAttempt(attempt); err != nil { | ||||
| 				t.Fatalf("Error adding attempt: %s", err) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		all, err := s.All() | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("Error getting all attempts: %s", err) | ||||
| 		} | ||||
| 		if len(all) != len(testAttempts) { | ||||
| 			t.Errorf("All returned wrong amount. Got %d want %d", len(all), len(testAttempts)) | ||||
| 		} | ||||
| 		stats, err := s.Stats(store.LoginStatsTotals, 1) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("Stats returned error: %s", err) | ||||
| 		} | ||||
| 		for _, stat := range stats { | ||||
| 			if stat.Name == "TotalLoginAttempts" && stat.Count != len(testAttempts) { | ||||
| 				t.Errorf("Stats for total attempts is wrong. Got %d want %d", stat.Count, len(testAttempts)) | ||||
| 			} | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func randomAttempts(count int) []*models.LoginAttempt { | ||||
| 	var attempts []*models.LoginAttempt | ||||
| 	for i := 0; i < count; i++ { | ||||
| 		attempt := &models.LoginAttempt{ | ||||
| 			Date:             time.Now(), | ||||
| 			RemoteIP:         randomIP(), | ||||
| 			Username:         randomString(8), | ||||
| 			Password:         randomString(8), | ||||
| 			Country:          randomCountry(), | ||||
| 			ConnectionUUID:   uuid.Must(uuid.NewRandom()), | ||||
| 			SSHClientVersion: "SSH TEST LOL", | ||||
| 		} | ||||
| 		attempts = append(attempts, attempt) | ||||
| 	} | ||||
| 	return attempts | ||||
| } | ||||
|  | ||||
| func randomIP() net.IP { | ||||
| 	a := byte(rand.Intn(254)) | ||||
| 	b := byte(rand.Intn(254)) | ||||
| 	c := byte(rand.Intn(254)) | ||||
| 	d := byte(rand.Intn(254)) | ||||
| 	return net.IPv4(a, b, c, d) | ||||
| } | ||||
|  | ||||
| func randomString(n int) string { | ||||
| 	const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" | ||||
|  | ||||
| 	b := make([]byte, n) | ||||
| 	for i := range b { | ||||
| 		b[i] = letterBytes[rand.Intn(len(letterBytes))] | ||||
| 	} | ||||
| 	return string(b) | ||||
| } | ||||
|  | ||||
| func randomCountry() string { | ||||
| 	switch rand.Intn(10) { | ||||
| 	case 1: | ||||
| 		return "CN" | ||||
| 	case 2: | ||||
| 		return "US" | ||||
| 	case 3: | ||||
| 		return "NO" | ||||
| 	case 4: | ||||
| 		return "RU" | ||||
| 	case 5: | ||||
| 		return "DE" | ||||
| 	case 6: | ||||
| 		return "FI" | ||||
| 	case 7: | ||||
| 		return "BR" | ||||
| 	default: | ||||
| 		return "SE" | ||||
| 	} | ||||
| } | ||||
| @@ -5,7 +5,7 @@ import ( | ||||
| 	"runtime" | ||||
| ) | ||||
|  | ||||
| var Version = "v0.1.7" | ||||
| var Version = "v0.1.8" | ||||
| var Build string | ||||
|  | ||||
| func FullVersion() string { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user