package web

import (
	"net/http"
	"time"

	"github.com/go-chi/chi/v5/middleware"
	"github.com/prometheus/client_golang/prometheus"
)

// LoggingMiddleware writes one entry per request to the server's configured
// AccessLogger. The HTTP method is the log message; the entry also records the
// path, response status, bytes written, remote address, processing time in
// milliseconds, and the chi request ID (if any).
//
// When cfg.AccessLogIgnoreMetrics is set, successful (200) requests to
// /metrics are not logged.
func (s *Server) LoggingMiddleware(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
		t1 := time.Now()

		reqID := middleware.GetReqID(r.Context())

		// completed stays false if next panics, letting the deferred logger
		// distinguish a normal return from an unwinding panic.
		completed := false
		defer func() {
			status := ww.Status()
			// The wrapper reports 0 when the handler wrote neither a header nor a
			// body; net/http then normally sends an implicit 200, so log that.
			// During a panic keep 0: an outer recoverer decides the real status.
			if status == 0 && completed {
				status = http.StatusOK
			}
			// If AccessLogIgnoreMetrics is true, do not log successful requests to the metrics endpoint.
			if s.cfg.AccessLogIgnoreMetrics && r.URL.Path == "/metrics" && status == http.StatusOK {
				return
			}
			s.AccessLogger.Infow(r.Method,
				"path", r.URL.Path,
				"status", status,
				"written", ww.BytesWritten(),
				"remote_addr", r.RemoteAddr,
				"processing_time_ms", time.Since(t1).Milliseconds(),
				"req_id", reqID)
		}()

		next.ServeHTTP(ww, r)
		completed = true
	}
	return http.HandlerFunc(fn)
}

// MetricsMiddleware holds the Prometheus collectors updated by its handler
// method. Both are labelled by code, method and path; they are created and
// registered by NewMetricsMiddleware.
type MetricsMiddleware struct {
	// requests counts processed requests (incremented once per request).
	requests *prometheus.CounterVec
	// latency observes request processing time in milliseconds.
	latency  *prometheus.HistogramVec
}

// handler wraps next so that every request increments mm.requests and records
// its processing time (in fractional milliseconds) in mm.latency. Both metrics
// are labelled with the response status text (e.g. "OK", "Not Found"), the
// HTTP method, and the request path.
func (mm *MetricsMiddleware) handler(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
		start := time.Now()
		// completed stays false if next panics, so the deferred recorder can
		// tell a normal return apart from an unwinding panic.
		completed := false
		defer func() {
			// Keep sub-millisecond precision: Duration.Milliseconds() truncates,
			// which would record every fast request as 0 ms.
			elapsedMs := float64(time.Since(start)) / float64(time.Millisecond)

			status := ww.Status()
			// The wrapper reports 0 when nothing was written; net/http then
			// normally sends an implicit 200. During a panic keep 0, since an
			// outer recoverer decides the real status.
			if status == 0 && completed {
				status = http.StatusOK
			}

			code := http.StatusText(status)
			// NOTE(review): r.URL.Path is client-controlled, so each distinct path
			// (including 404 probes) creates new time series. Consider labelling
			// with the matched route pattern instead.
			path := r.URL.Path
			mm.requests.WithLabelValues(code, r.Method, path).Inc()
			mm.latency.WithLabelValues(code, r.Method, path).Observe(elapsedMs)
		}()
		next.ServeHTTP(ww, r)
		completed = true
	}

	return http.HandlerFunc(fn)
}

func NewMetricsMiddleware() func(next http.Handler) http.Handler {
	mm := &MetricsMiddleware{}
	mm.requests = prometheus.NewCounterVec(prometheus.CounterOpts{
		Name:        "apiary_http_requests_total",
		Help:        "Total requests processed.",
		ConstLabels: prometheus.Labels{"service": "http"},
	},
		[]string{"code", "method", "path"},
	)

	mm.latency = prometheus.NewHistogramVec(
		prometheus.HistogramOpts{
			Name:        "apiary_http_request_duration_milliseconds",
			Help:        "Request processing time.",
			ConstLabels: prometheus.Labels{"service": "http"},
			Buckets:     []float64{100, 500, 1500},
		},
		[]string{"code", "method", "path"},
	)
	prometheus.MustRegister(mm.requests)
	prometheus.MustRegister(mm.latency)
	return mm.handler
}