package api

import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"time"

	"git.t-juice.club/torjus/gpaste"
	"git.t-juice.club/torjus/gpaste/users"
	"github.com/go-chi/chi/v5/middleware"
	"go.uber.org/zap"
)

// authCtxKey is the unexported type used for the request-context keys that
// MiddlewareAuthentication sets, so they cannot collide with keys defined by
// other packages.
type authCtxKey int

const (
	// authCtxUsername keys the authenticated subject (claims.Subject).
	authCtxUsername authCtxKey = iota
	// authCtxAuthLevel keys the authenticated role (claims.Role).
	authCtxAuthLevel
	// authCtxClaims keys the full claims value returned by token validation.
	authCtxClaims
)

// MiddlewareAccessLogger returns a handler that wraps next and writes one
// access-log entry per request to s.AccessLogger once next has returned.
//
// The entry's message is the HTTP method. When the logger has debug level
// enabled the entry is written at debug level and additionally carries the
// request Content-Type and all request headers; otherwise, if info level is
// enabled, a shorter info-level entry is written.
func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		// Wrap the writer so status code and byte count can be read afterwards.
		ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
		t1 := time.Now()
		reqID := middleware.GetReqID(r.Context())

		// TODO: Maybe desugar in HTTPServer to avoid doing for all requests
		logger := s.AccessLogger.Desugar()

		// Deferred so the entry is written after next.ServeHTTP has returned
		// and the status, byte count and duration are known.
		defer func() {
			// DEBUG level
			if ce := logger.Check(zap.DebugLevel, r.Method); ce != nil {
				ct := r.Header.Get("Content-Type")
				// NOTE(review): "headers" logs every request header verbatim,
				// which includes Authorization (the bearer token) when present.
				ce.Write(
					zap.String("req_id", reqID),
					zap.String("path", r.URL.Path),
					zap.Int("status", ww.Status()),
					zap.String("remote_addr", r.RemoteAddr),
					zap.Int("bytes_written", ww.BytesWritten()),
					zap.Duration("processing_time", time.Since(t1)),
					zap.String("content_type", ct),
					zap.Any("headers", r.Header),
				)
			} else if ce := logger.Check(zap.InfoLevel, r.Method); ce != nil {
				// INFO level
				ce.Write(
					zap.String("req_id", reqID),
					zap.String("path", r.URL.Path),
					zap.Int("status", ww.Status()),
					zap.String("remote_addr", r.RemoteAddr),
					zap.Int("bytes_written", ww.BytesWritten()),
					zap.Duration("processing_time", time.Since(t1)),
				)
			}

			// Best-effort flush; a Sync error is ignored because there is
			// nothing useful to do with it at this point.
			_ = logger.Sync()
		}()

		next.ServeHTTP(ww, r)
	}

	return http.HandlerFunc(fn)
}

// MiddlewareAuthentication returns a handler that inspects the Authorization
// header of each request before calling next.
//
// If the header has the form "Bearer <token>" and s.Auth.ValidateToken accepts
// the token, the username (claims.Subject), role (claims.Role) and the claims
// themselves are stored in the request context for UsernameFromRequest,
// RoleFromRequest and ClaimsFromRequest. In every other case (no header,
// wrong scheme, invalid token) the request is passed to next unchanged; this
// middleware never rejects a request itself.
func (s *HTTPServer) MiddlewareAuthentication(next http.Handler) http.Handler {
	// The scheme must be the prefix of the header value. Merely containing
	// "Bearer " somewhere (e.g. "junkBearer tok") is not a valid header.
	const bearerPrefix = "Bearer "

	fn := func(w http.ResponseWriter, r *http.Request) {
		reqID := middleware.GetReqID(r.Context())

		header := r.Header.Get("Authorization")
		if header == "" {
			s.Logger.Debugw("Request has no auth header.", "req_id", reqID)
			next.ServeHTTP(w, r)

			return
		}

		if !strings.HasPrefix(header, bearerPrefix) {
			s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
			next.ServeHTTP(w, r)

			return
		}

		token := strings.TrimPrefix(header, bearerPrefix)

		claims, err := s.Auth.ValidateToken(token)
		if err != nil {
			s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
			next.ServeHTTP(w, r)

			return
		}

		ctx := context.WithValue(r.Context(), authCtxUsername, claims.Subject)
		ctx = context.WithValue(ctx, authCtxAuthLevel, claims.Role)
		ctx = context.WithValue(ctx, authCtxClaims, claims)
		withCtx := r.WithContext(ctx)

		s.Logger.Debugw("Request is authenticated.", "req_id", reqID, "username", claims.Subject, "role", claims.Role)

		next.ServeHTTP(w, withCtx)
	}

	return http.HandlerFunc(fn)
}

// UsernameFromRequest returns the username that MiddlewareAuthentication
// stored in the request context. It returns an error if no username is
// present or the stored value is not a string.
func UsernameFromRequest(r *http.Request) (string, error) {
	rawUsername := r.Context().Value(authCtxUsername)
	if rawUsername == nil {
		return "", fmt.Errorf("no username")
	}

	username, ok := rawUsername.(string)
	if !ok {
		return "", fmt.Errorf("no username")
	}

	return username, nil
}

// RoleFromRequest returns the role that MiddlewareAuthentication stored in
// the request context. It returns users.RoleUnset and an error if no role is
// present or the stored value is not a users.Role.
func RoleFromRequest(r *http.Request) (users.Role, error) {
	rawLevel := r.Context().Value(authCtxAuthLevel)
	if rawLevel == nil {
		return users.RoleUnset, fmt.Errorf("no role")
	}

	level, ok := rawLevel.(users.Role)
	if !ok {
		return users.RoleUnset, fmt.Errorf("no role")
	}

	return level, nil
}

// ClaimsFromRequest returns the claims that MiddlewareAuthentication stored
// in the request context, or nil if none are present or the stored value is
// not a *gpaste.Claims.
func ClaimsFromRequest(r *http.Request) *gpaste.Claims {
	// Must read authCtxClaims; authCtxAuthLevel holds the role, not the claims.
	rawClaims := r.Context().Value(authCtxClaims)
	if rawClaims == nil {
		return nil
	}

	claims, ok := rawClaims.(*gpaste.Claims)
	if !ok {
		return nil
	}

	return claims
}