feature/use-channel-for-all #2
| @@ -18,11 +18,12 @@ func TestServer(t *testing.T) { | |||||||
| 	server := ports.New(store) | 	server := ports.New(store) | ||||||
| 	server.IP = "127.0.0.1" | 	server.IP = "127.0.0.1" | ||||||
|  |  | ||||||
| 	server.AddTCPPort("25") | 	server.AddTCPPort("2555") | ||||||
|  |  | ||||||
| 	go server.Start(ctx) | 	go server.Start(ctx) | ||||||
|  | 	time.Sleep(1 * time.Second) | ||||||
|  |  | ||||||
| 	rAddr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(server.IP, "25")) | 	rAddr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(server.IP, "2555")) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("Error resolving remote address: %s", err) | 		t.Fatalf("Error resolving remote address: %s", err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -2,6 +2,8 @@ package store | |||||||
|  |  | ||||||
| import "git.t-juice.club/torjus/apiary/models" | import "git.t-juice.club/torjus/apiary/models" | ||||||
|  |  | ||||||
|  | var _ LoginAttemptStore = &CachingStore{} | ||||||
|  |  | ||||||
| type CachingStore struct { | type CachingStore struct { | ||||||
| 	backend LoginAttemptStore | 	backend LoginAttemptStore | ||||||
|  |  | ||||||
| @@ -33,14 +35,16 @@ func NewCachingStore(backend LoginAttemptStore) *CachingStore { | |||||||
| 		//TODO: Handle better maybe? | 		//TODO: Handle better maybe? | ||||||
| 		panic(err) | 		panic(err) | ||||||
| 	} | 	} | ||||||
|  | 	var loginCount int | ||||||
|  |  | ||||||
| 	cs.totalLoginsCount = len(all) | 	for attempt := range all { | ||||||
| 	for _, attempt := range all { |  | ||||||
| 		cs.uniqueUsernames[attempt.Username] = struct{}{} | 		cs.uniqueUsernames[attempt.Username] = struct{}{} | ||||||
| 		cs.uniquePasswords[attempt.Password] = struct{}{} | 		cs.uniquePasswords[attempt.Password] = struct{}{} | ||||||
| 		cs.uniqueIPs[attempt.RemoteIP.String()] = struct{}{} | 		cs.uniqueIPs[attempt.RemoteIP.String()] = struct{}{} | ||||||
| 		cs.uniqueCountries[attempt.Country] = struct{}{} | 		cs.uniqueCountries[attempt.Country] = struct{}{} | ||||||
|  | 		loginCount++ | ||||||
| 	} | 	} | ||||||
|  | 	cs.totalLoginsCount = loginCount | ||||||
|  |  | ||||||
| 	return cs | 	return cs | ||||||
| } | } | ||||||
| @@ -58,7 +62,7 @@ func (s *CachingStore) AddAttempt(l *models.LoginAttempt) error { | |||||||
| 	return s.backend.AddAttempt(l) | 	return s.backend.AddAttempt(l) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *CachingStore) All() ([]models.LoginAttempt, error) { | func (s *CachingStore) All() (<-chan models.LoginAttempt, error) { | ||||||
| 	return s.backend.All() | 	return s.backend.All() | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -9,6 +9,8 @@ import ( | |||||||
| 	"git.t-juice.club/torjus/apiary/models" | 	"git.t-juice.club/torjus/apiary/models" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var _ LoginAttemptStore = &MemoryStore{} | ||||||
|  |  | ||||||
| type MemoryStore struct { | type MemoryStore struct { | ||||||
| 	lock      sync.RWMutex | 	lock      sync.RWMutex | ||||||
| 	attempts  []models.LoginAttempt | 	attempts  []models.LoginAttempt | ||||||
| @@ -32,8 +34,17 @@ func (ms *MemoryStore) AddAttempt(l *models.LoginAttempt) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ms *MemoryStore) All() ([]models.LoginAttempt, error) { | func (ms *MemoryStore) All() (<-chan models.LoginAttempt, error) { | ||||||
| 	return ms.attempts, nil | 	ch := make(chan models.LoginAttempt) | ||||||
|  | 	go func() { | ||||||
|  | 		ms.lock.RLock() | ||||||
|  | 		defer ms.lock.RUnlock() | ||||||
|  | 		for _, attempt := range ms.attempts { | ||||||
|  | 			ch <- attempt | ||||||
|  | 		} | ||||||
|  | 		close(ch) | ||||||
|  | 	}() | ||||||
|  | 	return ch, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (ms *MemoryStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) { | func (ms *MemoryStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) { | ||||||
|   | |||||||
| @@ -8,6 +8,8 @@ import ( | |||||||
| 	"github.com/prometheus/client_golang/prometheus" | 	"github.com/prometheus/client_golang/prometheus" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var _ LoginAttemptStore = &MetricsCollectingStore{} | ||||||
|  |  | ||||||
| const tickDuration = 5 * time.Second | const tickDuration = 5 * time.Second | ||||||
|  |  | ||||||
| type MetricsCollectingStore struct { | type MetricsCollectingStore struct { | ||||||
| @@ -89,7 +91,7 @@ func (s *MetricsCollectingStore) AddAttempt(l *models.LoginAttempt) error { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *MetricsCollectingStore) All() ([]models.LoginAttempt, error) { | func (s *MetricsCollectingStore) All() (<-chan models.LoginAttempt, error) { | ||||||
| 	return s.store.All() | 	return s.store.All() | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -9,6 +9,8 @@ import ( | |||||||
| 	_ "github.com/jackc/pgx/v4/stdlib" | 	_ "github.com/jackc/pgx/v4/stdlib" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var _ LoginAttemptStore = &PostgresStore{} | ||||||
|  |  | ||||||
| const DBSchema = ` | const DBSchema = ` | ||||||
| CREATE TABLE IF NOT EXISTS login_attempts( | CREATE TABLE IF NOT EXISTS login_attempts( | ||||||
| 	id serial PRIMARY KEY, | 	id serial PRIMARY KEY, | ||||||
| @@ -61,27 +63,30 @@ RETURNING id;` | |||||||
| 	return tx.Commit() | 	return tx.Commit() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *PostgresStore) All() ([]models.LoginAttempt, error) { | func (s *PostgresStore) All() (<-chan models.LoginAttempt, error) { | ||||||
| 	stmt := `SELECT date, remote_ip, username, password, client_version, connection_uuid, country FROM login_attempts` | 	stmt := `SELECT date, remote_ip, username, password, client_version, connection_uuid, country FROM login_attempts` | ||||||
|  |  | ||||||
| 	rows, err := s.db.Query(stmt) | 	rows, err := s.db.Query(stmt) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	defer rows.Close() |  | ||||||
|  |  | ||||||
| 	var attempts []models.LoginAttempt | 	ch := make(chan models.LoginAttempt) | ||||||
|  | 	go func() { | ||||||
|  | 		defer rows.Close() | ||||||
| 		for rows.Next() { | 		for rows.Next() { | ||||||
| 			var a models.LoginAttempt | 			var a models.LoginAttempt | ||||||
| 			var ip string | 			var ip string | ||||||
| 			if err := rows.Scan(&a.Date, &ip, &a.Username, &a.Password, &a.SSHClientVersion, &a.SSHClientVersion, &a.Country); err != nil { | 			if err := rows.Scan(&a.Date, &ip, &a.Username, &a.Password, &a.SSHClientVersion, &a.SSHClientVersion, &a.Country); err != nil { | ||||||
| 			return nil, err | 				panic(err) | ||||||
| 			} | 			} | ||||||
| 			a.RemoteIP = net.ParseIP(ip) | 			a.RemoteIP = net.ParseIP(ip) | ||||||
| 		attempts = append(attempts, a) | 			ch <- a | ||||||
| 		} | 		} | ||||||
|  | 		close(ch) | ||||||
|  | 	}() | ||||||
|  |  | ||||||
| 	return attempts, nil | 	return ch, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *PostgresStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) { | func (s *PostgresStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) { | ||||||
|   | |||||||
| @@ -38,7 +38,7 @@ type AttemptQuery struct { | |||||||
| } | } | ||||||
| type LoginAttemptStore interface { | type LoginAttemptStore interface { | ||||||
| 	AddAttempt(l *models.LoginAttempt) error | 	AddAttempt(l *models.LoginAttempt) error | ||||||
| 	All() ([]models.LoginAttempt, error) | 	All() (<-chan models.LoginAttempt, error) | ||||||
| 	Stats(statType LoginStats, limit int) ([]StatsResult, error) | 	Stats(statType LoginStats, limit int) ([]StatsResult, error) | ||||||
| 	Query(query AttemptQuery) ([]models.LoginAttempt, error) | 	Query(query AttemptQuery) ([]models.LoginAttempt, error) | ||||||
| 	HealthCheker | 	HealthCheker | ||||||
|   | |||||||
| @@ -25,8 +25,12 @@ func testLoginAttemptStore(s store.LoginAttemptStore, t *testing.T) { | |||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Fatalf("Error getting all attempts: %s", err) | 			t.Fatalf("Error getting all attempts: %s", err) | ||||||
| 		} | 		} | ||||||
| 		if len(all) != len(testAttempts) { | 		var count int | ||||||
| 			t.Errorf("All returned wrong amount. Got %d want %d", len(all), len(testAttempts)) | 		for range all { | ||||||
|  | 			count++ | ||||||
|  | 		} | ||||||
|  | 		if count != len(testAttempts) { | ||||||
|  | 			t.Errorf("All returned wrong amount. Got %d want %d", count, len(testAttempts)) | ||||||
| 		} | 		} | ||||||
| 		stats, err := s.Stats(store.LoginStatsTotals, 1) | 		stats, err := s.Stats(store.LoginStatsTotals, 1) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -229,7 +233,11 @@ func benchmarkLoginAttemptStore(setupFunc func() store.LoginAttemptStore, b *tes | |||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				b.Fatalf("Error fetchin all: %s", err) | 				b.Fatalf("Error fetchin all: %s", err) | ||||||
| 			} | 			} | ||||||
| 			_ = len(all) | 			var count int | ||||||
|  | 			for range all { | ||||||
|  | 				count++ | ||||||
|  | 			} | ||||||
|  | 			_ = count | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user