gpaste/api/middleware.go

155 lines
3.5 KiB
Go
Raw Normal View History

2022-01-20 02:44:33 +00:00
package api
2022-01-19 02:23:54 +00:00
import (
2022-01-20 00:04:44 +00:00
"context"
"fmt"
2022-01-19 02:23:54 +00:00
"net/http"
2022-01-20 00:04:44 +00:00
"strings"
2022-01-19 02:23:54 +00:00
"time"
2022-01-20 02:44:33 +00:00
"git.t-juice.club/torjus/gpaste"
2022-01-20 16:50:56 +00:00
"git.t-juice.club/torjus/gpaste/users"
2022-01-19 02:23:54 +00:00
"github.com/go-chi/chi/v5/middleware"
2022-01-24 21:53:46 +00:00
"go.uber.org/zap"
2022-01-19 02:23:54 +00:00
)
2022-01-20 00:04:44 +00:00
// authCtxKey is the unexported type used for the request-context keys this
// package sets. Because the type is unexported, other packages cannot
// construct values that collide with these keys.
type authCtxKey int

// Request-context keys populated by MiddlewareAuthentication.
const (
	// authCtxUsername holds the token subject (a string).
	authCtxUsername authCtxKey = iota
	// authCtxAuthLevel holds the authenticated user's role.
	authCtxAuthLevel
	// authCtxClaims holds the validated token claims.
	authCtxClaims
)
2022-01-19 02:23:54 +00:00
// MiddlewareAccessLogger returns middleware that logs each request to
// s.AccessLogger after the wrapped handler has returned.
//
// When debug logging is enabled the entry also carries the request
// Content-Type and request headers. Credential-bearing headers are redacted
// first, so bearer tokens and cookies are not written to the log. Otherwise,
// if info logging is enabled, only the basic fields are logged. The logger is
// synced after every request.
func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
		t1 := time.Now()
		reqID := middleware.GetReqID(r.Context())

		// TODO: Maybe desugar in HTTPServer to avoid doing for all requests
		logger := s.AccessLogger.Desugar()

		defer func() {
			// Prefer the detailed debug entry; fall back to info level.
			ce := logger.Check(zap.DebugLevel, r.Method)
			debug := ce != nil
			if !debug {
				ce = logger.Check(zap.InfoLevel, r.Method)
			}

			if ce != nil {
				// Evaluated after next.ServeHTTP has returned, so status and
				// byte count describe the response actually sent.
				fields := []zap.Field{
					zap.String("req_id", reqID),
					zap.String("path", r.URL.Path),
					zap.Int("status", ww.Status()),
					zap.String("remote_addr", r.RemoteAddr),
					zap.Int("bytes_written", ww.BytesWritten()),
					zap.Duration("processing_time", time.Since(t1)),
				}
				if debug {
					fields = append(fields,
						zap.String("content_type", r.Header.Get("Content-Type")),
						zap.Any("headers", redactedAccessLogHeaders(r.Header)),
					)
				}
				ce.Write(fields...)
			}

			// Best effort: a failed flush must not affect request handling.
			_ = logger.Sync()
		}()
		next.ServeHTTP(ww, r)
	}

	return http.HandlerFunc(fn)
}

// redactedAccessLogHeaders returns a copy of h in which the values of
// credential-bearing headers are replaced by a placeholder. h itself is left
// untouched, so nothing else observing the request sees the placeholder.
func redactedAccessLogHeaders(h http.Header) http.Header {
	out := h.Clone()
	for _, name := range []string{"Authorization", "Proxy-Authorization", "Cookie"} {
		if _, ok := out[name]; ok {
			out[name] = []string{"[REDACTED]"}
		}
	}

	return out
}
2022-01-20 00:04:44 +00:00
// MiddlewareAuthentication returns middleware that authenticates requests
// carrying an "Authorization: Bearer <token>" header.
//
// When the token passes s.Auth.ValidateToken, the claims' subject, role and
// the claims themselves are stored in the request context (keys
// authCtxUsername, authCtxAuthLevel and authCtxClaims) before next is called.
// Requests with no header, a malformed header, or a token that fails
// validation are passed to next unchanged. This middleware never rejects a
// request itself; downstream handlers decide whether authentication is
// required.
func (s *HTTPServer) MiddlewareAuthentication(next http.Handler) http.Handler {
	// The auth-scheme name is case-insensitive (RFC 7235 section 2.1).
	const bearerPrefix = "Bearer "

	fn := func(w http.ResponseWriter, r *http.Request) {
		reqID := middleware.GetReqID(r.Context())

		header := r.Header.Get("Authorization")
		if header == "" {
			s.Logger.Debugw("Request has no auth header.", "req_id", reqID)
			next.ServeHTTP(w, r)

			return
		}

		// The scheme must be at the very start of the header. Splitting on
		// "Bearer " would also accept values where it merely appears later,
		// e.g. "xBearer <token>".
		if len(header) < len(bearerPrefix) || !strings.EqualFold(header[:len(bearerPrefix)], bearerPrefix) {
			s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
			next.ServeHTTP(w, r)

			return
		}

		token := strings.TrimSpace(header[len(bearerPrefix):])
		if token == "" {
			s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
			next.ServeHTTP(w, r)

			return
		}

		claims, err := s.Auth.ValidateToken(token)
		if err != nil {
			s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
			next.ServeHTTP(w, r)

			return
		}

		ctx := context.WithValue(r.Context(), authCtxUsername, claims.Subject)
		ctx = context.WithValue(ctx, authCtxAuthLevel, claims.Role)
		ctx = context.WithValue(ctx, authCtxClaims, claims)
		withCtx := r.WithContext(ctx)

		s.Logger.Debugw("Request is authenticated.", "req_id", reqID, "username", claims.Subject, "role", claims.Role)
		next.ServeHTTP(w, withCtx)
	}

	return http.HandlerFunc(fn)
}
// UsernameFromRequest returns the username stored in r's context under
// authCtxUsername. It returns an error when no username is present or the
// stored value is not a string.
func UsernameFromRequest(r *http.Request) (string, error) {
	// A missing key yields a nil interface, for which this assertion fails
	// exactly as it does for a value of the wrong type.
	if name, ok := r.Context().Value(authCtxUsername).(string); ok {
		return name, nil
	}

	return "", fmt.Errorf("no username")
}
2022-01-20 00:11:40 +00:00
2022-01-20 16:50:56 +00:00
// RoleFromRequest returns the role stored in r's context under
// authCtxAuthLevel. It returns users.RoleUnset and an error when no role is
// present or the stored value is not a users.Role.
func RoleFromRequest(r *http.Request) (users.Role, error) {
	rawRole := r.Context().Value(authCtxAuthLevel)
	if rawRole == nil {
		// Previously reported "no username", a copy-paste slip from
		// UsernameFromRequest.
		return users.RoleUnset, fmt.Errorf("no role")
	}

	role, ok := rawRole.(users.Role)
	if !ok {
		return users.RoleUnset, fmt.Errorf("no role")
	}

	return role, nil
}
2022-01-20 12:33:11 +00:00
// ClaimsFromRequest returns the token claims stored in r's context under
// authCtxClaims. It returns nil when no claims are present or the stored value
// is not a *gpaste.Claims.
func ClaimsFromRequest(r *http.Request) *gpaste.Claims {
	// Read authCtxClaims, not authCtxAuthLevel. The latter holds only the role,
	// so asserting it to *gpaste.Claims always failed and this function always
	// returned nil.
	// NOTE(review): assumes s.Auth.ValidateToken returns *gpaste.Claims (the
	// value MiddlewareAuthentication stores); confirm against the Auth type.
	claims, ok := r.Context().Value(authCtxClaims).(*gpaste.Claims)
	if !ok {
		return nil
	}

	return claims
}