package store

import (
	"database/sql"
	"fmt"

	_ "github.com/jackc/pgx/v4/stdlib"
	"github.uio.no/torjus/apiary/models"
)

// DBSchema is the DDL executed by InitDB. It uses CREATE TABLE IF NOT EXISTS,
// so running it against an already-initialised database is a no-op.
const DBSchema = `
CREATE TABLE IF NOT EXISTS login_attempts(
	id serial PRIMARY KEY,
	date timestamptz,
	remote_ip inet,
	username text,
	password text,
	client_version text,
	connection_uuid uuid,
	country varchar(2)
);
`

// PostgresStore persists login attempts in a PostgreSQL database accessed
// through the pgx database/sql driver.
type PostgresStore struct {
	db *sql.DB
}

// NewPostgresStore opens a database handle for dsn using the "pgx" driver.
//
// sql.Open does not necessarily establish a connection, so connectivity or
// authentication problems may only surface on the first query (e.g. InitDB).
func NewPostgresStore(dsn string) (*PostgresStore, error) {
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		return nil, err
	}
	return &PostgresStore{db: db}, nil
}

// InitDB creates the login_attempts table if it does not already exist.
func (s *PostgresStore) InitDB() error {
	_, err := s.db.Exec(DBSchema)
	return err
}

// AddAttempt inserts l into login_attempts and, on success, sets l.ID to the
// id generated by the database.
func (s *PostgresStore) AddAttempt(l *models.LoginAttempt) error {
	stmt := `INSERT INTO login_attempts(date, remote_ip, username, password, client_version, connection_uuid, country)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id;`

	tx, err := s.db.Begin()
	if err != nil {
		return err
	}
	// Rollback is a no-op once Commit has succeeded; it only matters on the
	// error paths below.
	defer tx.Rollback()

	var id int
	if err := tx.QueryRow(stmt,
		l.Date,
		l.RemoteIP.String(),
		l.Username,
		l.Password,
		l.SSHClientVersion,
		l.ConnectionUUID,
		l.Country,
	).Scan(&id); err != nil {
		return err
	}
	l.ID = id

	return tx.Commit()
}

// All returns every stored login attempt. The ID field of the returned values
// is not populated, because the query does not select the id column.
func (s *PostgresStore) All() ([]models.LoginAttempt, error) {
	stmt := `SELECT date, remote_ip, username, password, client_version, connection_uuid, country FROM login_attempts`

	rows, err := s.db.Query(stmt)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var attempts []models.LoginAttempt
	for rows.Next() {
		var a models.LoginAttempt
		// Destinations must line up one-to-one with the SELECT column list.
		if err := rows.Scan(
			&a.Date,
			&a.RemoteIP,
			&a.Username,
			&a.Password,
			&a.SSHClientVersion,
			&a.ConnectionUUID,
			&a.Country,
		); err != nil {
			return nil, err
		}
		attempts = append(attempts, a)
	}
	// rows.Next returns false on both exhaustion and failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return attempts, nil
}

// Stats returns aggregate statistics of the requested kind.
//
// For LoginStatsTotals it delegates to statsTotal. For the per-column kinds
// (country, IP, password, username) it returns one result per distinct value
// with its occurrence count, ordered by count descending. A limit greater
// than zero caps the number of rows returned; zero or negative means no limit.
// An unrecognised statType yields an error.
func (s *PostgresStore) Stats(statType LoginStats, limit int) ([]StatsResult, error) {
	var stmt string
	switch statType {
	case LoginStatsTotals:
		return s.statsTotal(limit)
	case LoginStatsCountry:
		stmt = `select country, count(country) from login_attempts group by country order by count desc`
	case LoginStatsIP:
		stmt = `select remote_ip, count(remote_ip) from login_attempts group by remote_ip order by count desc`
	case LoginStatsPasswords:
		stmt = `select password, count(password) from login_attempts group by password order by count desc`
	case LoginStatsUsername:
		stmt = `select username, count(username) from login_attempts group by username order by count desc`
	default:
		// Without this, an empty statement would be sent to the database.
		return nil, fmt.Errorf("unsupported stats type: %v", statType)
	}

	if limit > 0 {
		// limit is an int, so formatting it into the SQL cannot inject anything.
		stmt = fmt.Sprintf("%s limit %d", stmt, limit)
	}

	rows, err := s.db.Query(stmt)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var results []StatsResult
	for rows.Next() {
		var r StatsResult
		if err := rows.Scan(&r.Name, &r.Count); err != nil {
			return nil, err
		}
		results = append(results, r)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return results, nil
}

// statsTotal returns the number of distinct passwords, usernames, IPs and
// countries, plus the total number of login attempts, in that order.
// limit is accepted for signature symmetry with Stats and is ignored.
func (s *PostgresStore) statsTotal(limit int) ([]StatsResult, error) {
	queries := []struct {
		name string
		stmt string
	}{
		{"UniquePasswords", `select count(*) from (select distinct password from login_attempts) as temp`},
		{"UniqueUsernames", `select count(*) from (select distinct username from login_attempts) as temp`},
		{"UniqueIPs", `select count(*) from (select distinct remote_ip from login_attempts) as temp`},
		{"UniqueCountries", `select count(*) from (select distinct country from login_attempts) as temp`},
		{"TotalLoginAttempts", `select count(1) from login_attempts`},
	}

	results := make([]StatsResult, 0, len(queries))
	for _, q := range queries {
		var count int
		if err := s.db.QueryRow(q.stmt).Scan(&count); err != nil {
			return nil, err
		}
		results = append(results, StatsResult{Name: q.name, Count: count})
	}

	return results, nil
}