package store

import (
	"database/sql"
	"fmt"
	"net"

	"git.t-juice.club/torjus/apiary/models"
	_ "github.com/jackc/pgx/v4/stdlib" // registers the "pgx" database/sql driver
)

// Compile-time check that PostgresStore satisfies LoginAttemptStore.
var _ LoginAttemptStore = (*PostgresStore)(nil)

// DBSchema is the DDL executed by InitDB. It uses CREATE TABLE IF NOT EXISTS,
// so running it against a database that already has the table is a no-op.
const DBSchema = `
CREATE TABLE IF NOT EXISTS login_attempts(
	id serial PRIMARY KEY,
	date timestamptz,
	remote_ip inet,
	username text,
	password text,
	client_version text,
	connection_uuid uuid,
	country varchar(2)
);
`

// PostgresStore is a LoginAttemptStore backed by a PostgreSQL database,
// accessed through database/sql with the pgx driver.
type PostgresStore struct {
	db *sql.DB
}

// NewPostgresStore creates a store for the database described by dsn.
//
// sql.Open only validates its arguments; no connection is made here, so an
// unreachable database surfaces on first use (or via IsHealthy).
func NewPostgresStore(dsn string) (*PostgresStore, error) {
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		return nil, err
	}
	return &PostgresStore{db: db}, nil
}

// InitDB creates the login_attempts table if it does not already exist.
func (s *PostgresStore) InitDB() error {
	_, err := s.db.Exec(DBSchema)
	return err
}

// AddAttempt inserts l and, on success, sets l.ID to the id assigned by the
// database. On error l is left unmodified.
func (s *PostgresStore) AddAttempt(l *models.LoginAttempt) error {
	// A single INSERT ... RETURNING is atomic on its own, so no explicit
	// transaction is needed.
	const stmt = `INSERT INTO login_attempts(date, remote_ip, username, password, client_version, connection_uuid, country)
		VALUES ($1, $2, $3, $4, $5, $6, $7)
		RETURNING id;`

	var id int
	if err := s.db.QueryRow(stmt,
		l.Date,
		l.RemoteIP.String(),
		l.Username,
		l.Password,
		l.SSHClientVersion,
		l.ConnectionUUID,
		l.Country,
	).Scan(&id); err != nil {
		return err
	}
	l.ID = id
	return nil
}

// All streams every stored login attempt over the returned channel. The
// channel is closed once all rows have been sent.
//
// The query runs before All returns, so query errors are reported through the
// error result. Rows are scanned in a background goroutine. The channel cannot
// carry errors, so:
//   - a row that fails to scan (for example, a NULL in a column scanned into a
//     string) panics in that goroutine, which terminates the process;
//   - an error while advancing the cursor ends the stream early.
//
// The goroutine holds its database connection until it finishes. Callers must
// drain the channel, or the goroutine and connection are leaked.
func (s *PostgresStore) All() (<-chan models.LoginAttempt, error) {
	const stmt = `SELECT id, date, remote_ip, username, password, client_version, connection_uuid, country FROM login_attempts`

	rows, err := s.db.Query(stmt)
	if err != nil {
		return nil, err
	}

	ch := make(chan models.LoginAttempt)
	go func() {
		defer close(ch)
		defer rows.Close()

		for rows.Next() {
			var a models.LoginAttempt
			var ip string
			// Column order must match the SELECT above.
			if err := rows.Scan(&a.ID, &a.Date, &ip, &a.Username, &a.Password, &a.SSHClientVersion, &a.ConnectionUUID, &a.Country); err != nil {
				panic(fmt.Errorf("scanning login attempt: %w", err))
			}
			a.RemoteIP = net.ParseIP(ip)
			ch <- a
		}
		// NOTE(review): rows.Err() is not surfaced here because the channel API
		// has no error path; an iteration error just truncates the stream.
	}()
	return ch, nil
}

// Stats returns aggregate counts for statType.
//
// For LoginStatsTotals it delegates to statsTotal, which ignores limit. For
// the per-column stat types it returns one entry per distinct value of the
// column, ordered by descending count. When limit > 0, at most limit entries
// are returned; otherwise all entries are returned. Any other statType yields
// an "invalid stat type" error.
func (s *PostgresStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
	var stmt string
	switch statType {
	case LoginStatsTotals:
		return s.statsTotal(limit)
	case LoginStatsCountry:
		stmt = `select country, count(country) from login_attempts group by country order by count desc`
	case LoginStatsIP:
		stmt = `select remote_ip, count(remote_ip) from login_attempts group by remote_ip order by count desc`
	case LoginStatsPasswords:
		stmt = `select password, count(password) from login_attempts group by password order by count desc`
	case LoginStatsUsername:
		stmt = `select username, count(username) from login_attempts group by username order by count desc`
	default:
		return nil, fmt.Errorf("invalid stat type")
	}

	if limit > 0 {
		// limit is an int, so interpolating it cannot inject SQL.
		stmt = fmt.Sprintf("%s limit %d", stmt, limit)
	}

	rows, err := s.db.Query(stmt)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var results []StatsResult
	for rows.Next() {
		var r StatsResult
		if err := rows.Scan(&r.Name, &r.Count); err != nil {
			return nil, err
		}
		results = append(results, r)
	}
	// rows.Next returns false on error as well as at the end of the result set;
	// without this check a failure mid-iteration would look like a short result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}

// IsHealthy pings the database and returns ErrStoreUnhealthy if the ping
// fails. The underlying ping error is not included.
func (s *PostgresStore) IsHealthy() error {
	if err := s.db.Ping(); err != nil {
		return ErrStoreUnhealthy
	}
	return nil
}

// statsTotal returns five summary counters, in a fixed order: the number of
// distinct passwords, usernames, IPs and countries, and the total number of
// login attempts.
//
// limit is accepted so the signature parallels the other Stats paths, but it
// is unused: the result always has exactly five entries. Each counter is a
// separate query, so the counts are not taken from a single consistent
// snapshot.
func (s *PostgresStore) statsTotal(limit int) ([]StatsResult, error) {
	counters := []struct {
		name string
		stmt string
	}{
		{"UniquePasswords", `select count(*) from (select distinct password from login_attempts) as temp`},
		{"UniqueUsernames", `select count(*) from (select distinct username from login_attempts) as temp`},
		{"UniqueIPs", `select count(*) from (select distinct remote_ip from login_attempts) as temp`},
		{"UniqueCountries", `select count(*) from (select distinct country from login_attempts) as temp`},
		{"TotalLoginAttempts", `select count(1) from login_attempts`},
	}

	results := make([]StatsResult, 0, len(counters))
	for _, c := range counters {
		var count int
		if err := s.db.QueryRow(c.stmt).Scan(&count); err != nil {
			return nil, err
		}
		results = append(results, StatsResult{Name: c.name, Count: count})
	}
	return results, nil
}

// Query returns login attempts whose column selected by query.QueryType
// (remote IP, password or username) equals query.Query exactly. Results are
// newest first and capped at 10000 rows. Unknown query types yield an
// "invalid query type" error.
func (s *PostgresStore) Query(query AttemptQuery) ([]models.LoginAttempt, error) {
	const limit = 10000

	var stmt string
	switch query.QueryType {
	case AttemptQueryTypeIP:
		stmt = `SELECT id, date, remote_ip, username, password, client_version, connection_uuid, country FROM login_attempts WHERE remote_ip = $1 order by date desc limit $2`
	case AttemptQueryTypePassword:
		stmt = `SELECT id, date, remote_ip, username, password, client_version, connection_uuid, country FROM login_attempts WHERE password = $1 order by date desc limit $2`
	case AttemptQueryTypeUsername:
		stmt = `SELECT id, date, remote_ip, username, password, client_version, connection_uuid, country FROM login_attempts WHERE username = $1 order by date desc limit $2`
	default:
		return nil, fmt.Errorf("invalid query type")
	}

	rows, err := s.db.Query(stmt, query.Query, limit)
	if err != nil {
		return nil, fmt.Errorf("unable to query database: %w", err)
	}
	defer rows.Close()

	var results []models.LoginAttempt
	for rows.Next() {
		var la models.LoginAttempt
		var ipString string
		if err := rows.Scan(&la.ID, &la.Date, &ipString, &la.Username, &la.Password, &la.SSHClientVersion, &la.ConnectionUUID, &la.Country); err != nil {
			return nil, fmt.Errorf("unable to unmarshal data from database: %w", err)
		}
		la.RemoteIP = net.ParseIP(ipString)
		results = append(results, la)
	}
	// Distinguish end-of-results from an iteration failure.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("unable to query database: %w", err)
	}
	return results, nil
}