package web import ( "context" "crypto/tls" "encoding/json" "fmt" "io" "net/http" "strconv" "sync" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/google/uuid" "github.uio.no/torjus/apiary/config" "github.uio.no/torjus/apiary/honeypot" "github.uio.no/torjus/apiary/honeypot/store" "github.uio.no/torjus/apiary/models" "go.uber.org/zap" "golang.org/x/crypto/acme/autocert" ) type Server struct { http.Server cfg config.FrontendConfig store store.LoginAttemptStore ServerLogger *zap.SugaredLogger AccessLogger *zap.SugaredLogger honeypotServer *honeypot.HoneypotServer attemptListenersLock sync.RWMutex attemptListeners map[string]chan models.LoginAttempt streamContext context.Context httpRedirectServer http.Server } func NewServer(cfg config.FrontendConfig, hs *honeypot.HoneypotServer, store store.LoginAttemptStore) *Server { s := &Server{ ServerLogger: zap.NewNop().Sugar(), AccessLogger: zap.NewNop().Sugar(), store: store, cfg: cfg, } if cfg.Autocert.Enable { certManager := autocert.Manager{ Prompt: autocert.AcceptTOS, HostPolicy: autocert.HostWhitelist(cfg.Autocert.Domains...), Email: cfg.Autocert.Email, } if cfg.Autocert.CacheDir != "" { certManager.Cache = autocert.DirCache(cfg.Autocert.CacheDir) } tlsConfig := certManager.TLSConfig() tlsConfig.MinVersion = tls.VersionTLS12 s.TLSConfig = tlsConfig s.RegisterOnShutdown(func() { timeoutCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() s.httpRedirectServer.Shutdown(timeoutCtx) }) s.Addr = ":443" if cfg.Autocert.RedirectHTTP { s.httpRedirectServer.Addr = ":80" s.httpRedirectServer.Handler = certManager.HTTPHandler(nil) } } else { s.Addr = cfg.ListenAddr } r := chi.NewRouter() // Setup middleware r.Use(middleware.RealIP) r.Use(middleware.RequestID) r.Use(s.LoggingMiddleware) r.Route("/", func(r chi.Router) { r.Get("/*", s.IndexHandler("web/vue-frontend/dist")) r.Get("/stream", s.HandlerAttemptStream) r.Route("/api", func(r chi.Router) { r.Use(middleware.SetHeader("Content-Type", "application/json")) r.Get("/stats", s.HandlerStats) r.Get("/stream", s.HandlerAttemptStream) }) }) s.Handler = r s.honeypotServer = hs s.attemptListeners = make(map[string]chan models.LoginAttempt) streamCtx, streamCancel := context.WithCancel(context.Background()) s.streamContext = streamCtx s.RegisterOnShutdown(func() { streamCancel() }) hs.AddLoginCallback(func(l models.LoginAttempt) { s.attemptListenersLock.RLock() defer s.attemptListenersLock.RUnlock() for i := range s.attemptListeners { s.attemptListeners[i] <- l } }) return s } func (s *Server) StartServe() error { if s.cfg.Autocert.Enable { if s.cfg.Autocert.RedirectHTTP { s.ServerLogger.Debug("Starting HTTP redirect server") go func() { if err := s.httpRedirectServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { s.ServerLogger.Warnw("HTTP redirect server returned error", "error", err) } }() } return s.ListenAndServeTLS("", "") } else { return s.ListenAndServe() } } func (s *Server) addAttemptListener() (string, chan models.LoginAttempt) { ch := make(chan models.LoginAttempt) s.attemptListenersLock.Lock() defer s.attemptListenersLock.Unlock() id := uuid.Must(uuid.NewRandom()) s.attemptListeners[id.String()] = ch return id.String(), ch } func (s *Server) closeAttemptListener(id string) { s.attemptListenersLock.Lock() defer s.attemptListenersLock.Unlock() ch := s.attemptListeners[id] close(ch) delete(s.attemptListeners, id) } func (s *Server) HandlerAttemptStream(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") id, ch := s.addAttemptListener() defer s.closeAttemptListener(id) flusher := w.(http.Flusher) ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case l := <-ch: data, err := json.Marshal(l) if err != nil { return } _, err = io.WriteString(w, fmt.Sprintf("data: %s\n\n", string(data))) if err != nil { s.ServerLogger.Warnw("Error writing event", "error", err) } flusher.Flush() case <-s.streamContext.Done(): return case <-r.Context().Done(): return case <-ticker.C: if _, err := io.WriteString(w, ": keep-alive\n\n"); err != nil { s.ServerLogger.Warnw("Error writing event", "error", err) } flusher.Flush() } } } func (s *Server) HandlerStats(w http.ResponseWriter, r *http.Request) { statType := store.LoginStats(r.URL.Query().Get("type")) if statType == store.LoginStatsUndefined { statType = store.LoginStatsPasswords } var limit int limitString := r.URL.Query().Get("limit") limit, err := strconv.Atoi(limitString) if err != nil { limit = 0 } stats, err := s.store.Stats(statType, limit) if err != nil { s.ServerLogger.Warnw("Error fetching stats", "error", err) s.WriteAPIError(w, r, http.StatusInternalServerError, "Error fetching stats") return } encoder := json.NewEncoder(w) if err := encoder.Encode(stats); err != nil { s.ServerLogger.Debugf("Error encoding or writing response", "remote_ip", r.RemoteAddr, "error", err) } } type APIErrorResponse struct { Error string `json:"error"` } func (s *Server) WriteAPIError(w http.ResponseWriter, r *http.Request, status int, message string) { encoder := json.NewEncoder(w) apiErr := APIErrorResponse{Error: message} w.WriteHeader(status) if err := encoder.Encode(&apiErr); err != nil { s.ServerLogger.Debugf("Error encoding or writing error response", "remote_ip", r.RemoteAddr, "error", err) } }