apiary/web/server.go

285 lines
7.6 KiB
Go

package web
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"sync"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
"github.uio.no/torjus/apiary"
"github.uio.no/torjus/apiary/config"
"github.uio.no/torjus/apiary/honeypot"
"github.uio.no/torjus/apiary/honeypot/store"
"github.uio.no/torjus/apiary/models"
"go.uber.org/zap"
"golang.org/x/crypto/acme/autocert"
)
// streamKeepAliveDuration is how long an attempt stream may stay idle before
// HandlerAttemptStream writes a keep-alive comment to it.
const streamKeepAliveDuration = 30 * time.Second
// Server is the web frontend. It embeds the http.Server that serves the
// frontend and API routes, and fans login attempts reported by the honeypot
// out to connected event-stream clients.
type Server struct {
// Embedded main server; NewServer sets its Addr, Handler and (with autocert) TLSConfig.
http.Server
// httpRedirectServer is configured on ":80" with the autocert HTTP handler
// when both autocert and RedirectHTTP are enabled; StartServe starts it.
httpRedirectServer http.Server
// cfg is the frontend configuration the server was built from.
cfg config.FrontendConfig
// honeypotServer is the honeypot whose login callback feeds the attempt stream.
honeypotServer *honeypot.HoneypotServer
// store backs the stats and query API handlers.
store store.LoginAttemptStore
// ServerLogger receives server diagnostics. NewServer installs a no-op logger.
ServerLogger *zap.SugaredLogger
// AccessLogger is initialised to a no-op logger by NewServer.
// NOTE(review): presumably used by LoggingMiddleware, which is not visible here — confirm.
AccessLogger *zap.SugaredLogger
// attemptListenersLock guards attemptListeners.
attemptListenersLock sync.RWMutex
// attemptListeners maps a listener ID to the channel on which that
// listener's stream handler receives login attempts.
attemptListeners map[string]chan models.LoginAttempt
// streamContext is cancelled by a shutdown hook registered in NewServer so
// that open attempt streams terminate when the server shuts down.
streamContext context.Context
}
// NewServer builds the web frontend server from cfg.
//
// With cfg.Autocert.Enable set, the server listens on ":443" using a TLS
// configuration from an autocert manager (TLS 1.2 minimum); if
// cfg.Autocert.RedirectHTTP is also set, a secondary server on ":80" is
// configured with the manager's HTTP handler and is shut down together with
// the main server. Otherwise the server listens on cfg.ListenAddr.
//
// hs must be non-nil: a login callback is registered on it that forwards each
// login attempt to all currently registered stream listeners. store backs the
// stats and query handlers.
//
// ServerLogger and AccessLogger are initialised to no-op loggers.
func NewServer(cfg config.FrontendConfig, hs *honeypot.HoneypotServer, store store.LoginAttemptStore) *Server {
	s := &Server{
		ServerLogger:     zap.NewNop().Sugar(),
		AccessLogger:     zap.NewNop().Sugar(),
		store:            store,
		cfg:              cfg,
		honeypotServer:   hs,
		attemptListeners: make(map[string]chan models.LoginAttempt),
	}

	if cfg.Autocert.Enable {
		certManager := autocert.Manager{
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(cfg.Autocert.Domains...),
			Email:      cfg.Autocert.Email,
		}
		if cfg.Autocert.CacheDir != "" {
			certManager.Cache = autocert.DirCache(cfg.Autocert.CacheDir)
		}

		tlsConfig := certManager.TLSConfig()
		tlsConfig.MinVersion = tls.VersionTLS12
		s.TLSConfig = tlsConfig
		s.Addr = ":443"

		if cfg.Autocert.RedirectHTTP {
			s.httpRedirectServer.Addr = ":80"
			s.httpRedirectServer.Handler = certManager.HTTPHandler(nil)

			// The redirect server only exists in this configuration, so its
			// shutdown hook is registered here as well.
			s.RegisterOnShutdown(func() {
				timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
				defer cancel()
				if err := s.httpRedirectServer.Shutdown(timeoutCtx); err != nil {
					s.ServerLogger.Warnw("Error shutting down HTTP redirect server", "error", err)
				}
			})
		}
	} else {
		s.Addr = cfg.ListenAddr
	}

	r := chi.NewRouter()

	// Setup middleware
	r.Use(middleware.RealIP)
	r.Use(middleware.RequestID)
	r.Use(s.LoggingMiddleware)
	r.Use(middleware.SetHeader("Server", fmt.Sprintf("apiary/%s", apiary.FullVersion())))

	r.Route("/", func(r chi.Router) {
		r.Get("/*", s.IndexHandler("web/vue-frontend/dist"))
		r.Get("/stream", s.HandlerAttemptStream)
		r.Route("/api", func(r chi.Router) {
			r.Get("/stats", s.HandlerStats)
			r.Get("/stream", s.HandlerAttemptStream)
			r.Get("/query", s.HandlerQuery)
		})
	})
	s.Handler = r

	// Open attempt streams watch this context so they end on shutdown.
	streamCtx, streamCancel := context.WithCancel(context.Background())
	s.streamContext = streamCtx
	s.RegisterOnShutdown(streamCancel)

	hs.AddLoginCallback(func(l models.LoginAttempt) {
		s.attemptListenersLock.RLock()
		defer s.attemptListenersLock.RUnlock()
		for id, ch := range s.attemptListeners {
			// Never block while holding the read lock. A stream handler that
			// has stopped receiving is waiting for the write lock in
			// closeAttemptListener, so a blocking send here would deadlock
			// both sides and stall the honeypot callback. A listener that is
			// not ready to receive misses this attempt instead.
			select {
			case ch <- l:
			default:
				s.ServerLogger.Debugw("Dropping login attempt for busy stream listener", "listener", id)
			}
		}
	})

	return s
}
// StartServe starts listening and blocks until the main server stops,
// returning the error from ListenAndServe or ListenAndServeTLS.
//
// Without autocert the server is served over plain HTTP. With autocert it is
// served over TLS, and, when HTTP redirection is configured, the redirect
// server is started in a background goroutine first.
func (s *Server) StartServe() error {
	autocertCfg := s.cfg.Autocert

	if !autocertCfg.Enable {
		return s.ListenAndServe()
	}

	if autocertCfg.RedirectHTTP {
		s.ServerLogger.Debug("Starting HTTP redirect server")
		go func() {
			serveErr := s.httpRedirectServer.ListenAndServe()
			// ErrServerClosed is the expected result of a regular shutdown.
			if serveErr == nil || serveErr == http.ErrServerClosed {
				return
			}
			s.ServerLogger.Warnw("HTTP redirect server returned error", "error", serveErr)
		}()
	}

	// Certificates come from the TLS configuration, so no files are passed.
	return s.ListenAndServeTLS("", "")
}
// addAttemptListener registers a new login-attempt listener and returns its
// random ID together with the channel on which attempts are delivered. The
// caller must pass the ID to closeAttemptListener when it stops listening.
//
// The channel has a small buffer so that the sender does not have to
// rendezvous with a receiver that is momentarily busy, for example while it
// is writing a previous event to a slow client.
func (s *Server) addAttemptListener() (string, chan models.LoginAttempt) {
	const listenerBufferSize = 16

	id := uuid.Must(uuid.NewRandom()).String()
	ch := make(chan models.LoginAttempt, listenerBufferSize)

	s.attemptListenersLock.Lock()
	defer s.attemptListenersLock.Unlock()
	s.attemptListeners[id] = ch

	return id, ch
}
// closeAttemptListener unregisters the listener with the given ID and closes
// its channel. Calling it with an unknown or already removed ID is a no-op;
// without that check the map lookup would yield a nil channel and close
// would panic.
func (s *Server) closeAttemptListener(id string) {
	s.attemptListenersLock.Lock()
	defer s.attemptListenersLock.Unlock()

	ch, ok := s.attemptListeners[id]
	if !ok {
		return
	}
	// Remove the entry before closing so the map never holds a closed channel.
	delete(s.attemptListeners, id)
	close(ch)
}
// HandlerAttemptStream streams login attempts to the client as server-sent
// events. Each attempt is written as a JSON-encoded "data:" event, and a
// keep-alive comment is written whenever the stream has been idle for
// streamKeepAliveDuration.
//
// The stream ends when the client disconnects, the server shuts down, the
// listener channel is closed, or an event cannot be encoded or written. If
// the ResponseWriter does not support flushing, the handler responds with
// 500 Internal Server Error instead of streaming.
func (s *Server) HandlerAttemptStream(w http.ResponseWriter, r *http.Request) {
	// Verify streaming support before committing to a 200 response; an
	// unchecked type assertion would panic on a writer without Flush.
	flusher, ok := w.(http.Flusher)
	if !ok {
		s.ServerLogger.Warnw("Response writer does not support flushing, cannot stream events")
		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	header := w.Header()
	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	header.Set("Access-Control-Allow-Origin", "*")

	id, ch := s.addAttemptListener()
	defer s.closeAttemptListener(id)

	w.WriteHeader(http.StatusOK)
	// Push the headers out now so the client sees the stream open without
	// waiting for the first event or keep-alive.
	flusher.Flush()

	ticker := time.NewTicker(streamKeepAliveDuration)
	defer ticker.Stop()

	for {
		select {
		case l, open := <-ch:
			if !open {
				// A closed channel would otherwise yield zero values forever.
				return
			}
			data, err := json.Marshal(l)
			if err != nil {
				s.ServerLogger.Warnw("Error encoding event", "error", err)
				return
			}
			if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
				// The connection is unusable; stop instead of retrying on it.
				s.ServerLogger.Warnw("Error writing event", "error", err)
				return
			}
			flusher.Flush()
			// Real traffic counts as activity, so restart the idle timer.
			ticker.Reset(streamKeepAliveDuration)
		case <-s.streamContext.Done():
			return
		case <-r.Context().Done():
			return
		case <-ticker.C:
			if _, err := io.WriteString(w, ": keep-alive\n\n"); err != nil {
				s.ServerLogger.Warnw("Error writing event", "error", err)
				return
			}
			flusher.Flush()
		}
	}
}
// HandlerStats responds with JSON-encoded login statistics from the store.
//
// Query parameters:
//   - type: the kind of statistics to fetch; when it equals
//     store.LoginStatsUndefined, store.LoginStatsPasswords is used.
//   - limit: integer passed to the store; a missing or unparsable value is
//     treated as 0.
//     NOTE(review): 0 presumably means "no limit" — confirm against the store.
//
// A store failure is logged and answered with a 500 API error.
func (s *Server) HandlerStats(w http.ResponseWriter, r *http.Request) {
	w.Header().Add("Content-Type", "application/json")

	params := r.URL.Query()

	statType := store.LoginStats(params.Get("type"))
	if statType == store.LoginStatsUndefined {
		statType = store.LoginStatsPasswords
	}

	// Atoi returns a clamped non-zero value on range errors, so reset
	// explicitly rather than relying on its result.
	limit, err := strconv.Atoi(params.Get("limit"))
	if err != nil {
		limit = 0
	}

	stats, err := s.store.Stats(statType, limit)
	if err != nil {
		s.ServerLogger.Warnw("Error fetching stats", "error", err)
		s.WriteAPIError(w, r, http.StatusInternalServerError, "Error fetching stats")
		return
	}

	if err := json.NewEncoder(w).Encode(stats); err != nil {
		// Debugw, not Debugf: the arguments are key/value pairs, not format operands.
		s.ServerLogger.Debugw("Error encoding or writing response", "remote_ip", r.RemoteAddr, "error", err)
	}
}
// HandlerQuery searches stored login attempts and responds with the matches
// as a JSON array (an empty array when nothing matches).
//
// Query parameters:
//   - query: the search term; an empty value is answered with a 400 API error.
//   - type: the attempt field to search. When empty, a password query is run
//     first and a username query second, and both result sets are
//     concatenated in that order.
//
// Any store failure is answered with a 500 API error and logged.
func (s *Server) HandlerQuery(w http.ResponseWriter, r *http.Request) {
	w.Header().Add("Content-Type", "application/json")

	params := r.URL.Query()
	queryType := params.Get("type")
	query := params.Get("query")

	if query == "" {
		s.WriteAPIError(w, r, http.StatusBadRequest, "Invalid query or query type")
		return
	}

	// Collect the store queries to run, in execution order.
	var pending []store.AttemptQuery
	switch queryType {
	case "":
		pending = append(pending,
			store.AttemptQuery{
				QueryType: store.AttemptQueryType(store.AttemptQueryTypePassword),
				Query:     query,
			},
			store.AttemptQuery{
				QueryType: store.AttemptQueryType(store.AttemptQueryTypeUsername),
				Query:     query,
			},
		)
	default:
		pending = append(pending, store.AttemptQuery{
			QueryType: store.AttemptQueryType(queryType),
			Query:     query,
		})
	}

	// Start from a non-nil slice so that no matches encodes as [] and not null.
	results := []models.LoginAttempt{}
	for _, attemptQuery := range pending {
		matches, err := s.store.Query(attemptQuery)
		if err != nil {
			s.WriteAPIError(w, r, http.StatusInternalServerError, "Unable to perform query")
			s.ServerLogger.Warnw("Error performing query", "error", err)
			return
		}
		results = append(results, matches...)
	}

	if err := json.NewEncoder(w).Encode(&results); err != nil {
		s.ServerLogger.Warnw("Error writing query results", "error", err)
	}
}
// APIErrorResponse is the JSON body that WriteAPIError sends to the client.
type APIErrorResponse struct {
// Error is the error message; it is serialized under the "error" key.
Error string `json:"error"`
}
// WriteAPIError writes the given HTTP status code followed by message as a
// JSON-encoded APIErrorResponse body. It does not set a Content-Type header;
// callers are expected to have set one before calling. A failure to encode
// or write the body is logged at debug level, since the response can no
// longer be changed at that point.
func (s *Server) WriteAPIError(w http.ResponseWriter, r *http.Request, status int, message string) {
	w.WriteHeader(status)

	apiErr := APIErrorResponse{Error: message}
	if err := json.NewEncoder(w).Encode(&apiErr); err != nil {
		// Debugw, not Debugf: the arguments are key/value pairs, not format operands.
		s.ServerLogger.Debugw("Error encoding or writing error response", "remote_ip", r.RemoteAddr, "error", err)
	}
}