package web

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/google/uuid"
	"github.uio.no/torjus/apiary/config"
	"github.uio.no/torjus/apiary/honeypot"
	"github.uio.no/torjus/apiary/honeypot/store"
	"github.uio.no/torjus/apiary/models"
	"go.uber.org/zap"
)

// attemptListenerBuffer is the number of login attempts queued per stream
// listener before further attempts are dropped for that listener.
const attemptListenerBuffer = 16

// Server is the HTTP frontend. It embeds http.Server, serves the index
// handler, exposes the stats API and streams login attempts to connected
// clients as server-sent events.
type Server struct {
	http.Server
	store          store.LoginAttemptStore
	ServerLogger   *zap.SugaredLogger
	AccessLogger   *zap.SugaredLogger
	honeypotServer *honeypot.HoneypotServer

	// attemptListenersLock guards attemptListeners.
	attemptListenersLock sync.RWMutex
	// attemptListeners maps a listener id to the channel its stream handler
	// reads login attempts from.
	attemptListeners map[string]chan models.LoginAttempt
	// streamContext is cancelled when the server shuts down so that open
	// event streams terminate.
	streamContext context.Context
}

// NewServer builds a Server listening on cfg.ListenAddr, wires up the routes
// and middleware, and registers a login callback on hs that fans every login
// attempt out to the currently connected stream listeners.
//
// Both loggers default to no-op loggers; callers may replace them after
// construction.
func NewServer(cfg config.FrontendConfig, hs *honeypot.HoneypotServer, store store.LoginAttemptStore) *Server {
	s := &Server{
		ServerLogger: zap.NewNop().Sugar(),
		AccessLogger: zap.NewNop().Sugar(),
		store:        store,
	}
	s.Addr = cfg.ListenAddr

	r := chi.NewRouter()

	// Setup middleware
	r.Use(middleware.RealIP)
	r.Use(middleware.RequestID)
	r.Use(s.LoggingMiddleware)

	r.Route("/", func(r chi.Router) {
		r.Get("/*", s.IndexHandler("web/vue-frontend/dist"))
		r.Get("/stream", s.HandlerAttemptStream)
		r.Route("/api", func(r chi.Router) {
			r.Use(middleware.SetHeader("Content-Type", "application/json"))
			r.Get("/stats", s.HandlerStats)
		})
	})
	s.Handler = r

	s.honeypotServer = hs
	s.attemptListeners = make(map[string]chan models.LoginAttempt)

	// Cancel the stream context on shutdown so long-lived stream handlers
	// return.
	streamCtx, streamCancel := context.WithCancel(context.Background())
	s.streamContext = streamCtx
	s.RegisterOnShutdown(streamCancel)

	hs.AddLoginCallback(func(l models.LoginAttempt) {
		s.attemptListenersLock.RLock()
		defer s.attemptListenersLock.RUnlock()

		for _, ch := range s.attemptListeners {
			// The send must never block while the read lock is held: a
			// stream handler that has stopped reading needs the write lock
			// to unregister itself, so a blocking send here would deadlock
			// both sides. A listener whose buffer is full misses the event.
			select {
			case ch <- l:
			default:
			}
		}
	})

	return s
}

// addAttemptListener registers a new buffered listener channel under a random
// UUID and returns the id together with the channel.
func (s *Server) addAttemptListener() (string, chan models.LoginAttempt) {
	ch := make(chan models.LoginAttempt, attemptListenerBuffer)
	id := uuid.Must(uuid.NewRandom()).String()

	s.attemptListenersLock.Lock()
	defer s.attemptListenersLock.Unlock()

	s.attemptListeners[id] = ch
	return id, ch
}

// closeAttemptListener closes and unregisters the listener with the given id.
// Unknown ids are ignored.
func (s *Server) closeAttemptListener(id string) {
	s.attemptListenersLock.Lock()
	defer s.attemptListenersLock.Unlock()

	ch, ok := s.attemptListeners[id]
	if !ok {
		// Closing the nil channel returned for a missing key would panic.
		return
	}
	close(ch)
	delete(s.attemptListeners, id)
}

// HandlerAttemptStream streams login attempts to the client as server-sent
// events. Each attempt is sent as a JSON-encoded "data:" event, and a comment
// line is sent every 30 seconds as a keep-alive. The handler returns when the
// request context or the server's stream context is done, or when encoding or
// writing an event fails.
func (s *Server) HandlerAttemptStream(w http.ResponseWriter, r *http.Request) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		// Without flushing, events would sit in the response buffer.
		http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("Access-Control-Allow-Origin", "*")

	id, ch := s.addAttemptListener()
	defer s.closeAttemptListener(id)

	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()

	for {
		select {
		case l := <-ch:
			data, err := json.Marshal(l)
			if err != nil {
				s.ServerLogger.Warnw("Error encoding event", "error", err)
				return
			}
			if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
				s.ServerLogger.Warnw("Error writing event", "error", err)
				return
			}
			flusher.Flush()
		case <-s.streamContext.Done():
			return
		case <-r.Context().Done():
			return
		case <-ticker.C:
			// Lines starting with ':' are comments in the event-stream
			// format; this one only keeps the connection from idling.
			if _, err := io.WriteString(w, ": keepalive\n\n"); err != nil {
				s.ServerLogger.Warnw("Error writing event", "error", err)
				return
			}
			flusher.Flush()
		}
	}
}

// HandlerStats writes login statistics as JSON. The "type" query parameter
// selects the statistic and falls back to store.LoginStatsPasswords when it
// equals store.LoginStatsUndefined; the "limit" query parameter is passed to
// the store and falls back to 0 when missing or not an integer.
func (s *Server) HandlerStats(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()

	statType := store.LoginStats(query.Get("type"))
	if statType == store.LoginStatsUndefined {
		statType = store.LoginStatsPasswords
	}

	limit, err := strconv.Atoi(query.Get("limit"))
	if err != nil {
		limit = 0
	}

	stats, err := s.store.Stats(statType, limit)
	if err != nil {
		s.ServerLogger.Warnw("Error fetching stats", "remote_ip", r.RemoteAddr, "error", err)
		s.WriteAPIError(w, r, http.StatusInternalServerError, "Error fetching stats")
		return
	}

	if err := json.NewEncoder(w).Encode(stats); err != nil {
		s.ServerLogger.Debugw("Error encoding or writing response", "remote_ip", r.RemoteAddr, "error", err)
	}
}

// APIErrorResponse is the JSON body written for API errors.
type APIErrorResponse struct {
	Error string `json:"error"`
}

// WriteAPIError writes the given status code followed by a JSON-encoded
// APIErrorResponse carrying message. Failures to encode or write the body are
// logged at debug level.
func (s *Server) WriteAPIError(w http.ResponseWriter, r *http.Request, status int, message string) {
	apiErr := APIErrorResponse{Error: message}

	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(&apiErr); err != nil {
		s.ServerLogger.Debugw("Error encoding or writing error response", "remote_ip", r.RemoteAddr, "error", err)
	}
}