2022-01-20 02:44:33 +00:00
|
|
|
package api
|
2022-01-19 02:23:54 +00:00
|
|
|
|
|
|
|
import (
|
2022-01-20 00:04:44 +00:00
|
|
|
"context"
|
|
|
|
"fmt"
|
2022-01-19 02:23:54 +00:00
|
|
|
"net/http"
|
2022-01-20 00:04:44 +00:00
|
|
|
"strings"
|
2022-01-19 02:23:54 +00:00
|
|
|
"time"
|
|
|
|
|
2022-01-20 02:44:33 +00:00
|
|
|
"git.t-juice.club/torjus/gpaste"
|
2022-01-20 16:50:56 +00:00
|
|
|
"git.t-juice.club/torjus/gpaste/users"
|
2022-01-19 02:23:54 +00:00
|
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
|
|
)
|
|
|
|
|
2022-01-20 00:04:44 +00:00
|
|
|
// authCtxKey is an unexported key type for values placed in the request
// context by MiddlewareAuthentication. Using a private type keeps these keys
// from colliding with context keys defined by other packages.
type authCtxKey int

// Context keys written by MiddlewareAuthentication and read back by the
// *FromRequest helpers.
const (
	authCtxUsername  authCtxKey = 0 // claims.Subject of the validated token
	authCtxAuthLevel authCtxKey = 1 // claims.Role of the validated token
	authCtxClaims    authCtxKey = 2 // the validated claims value itself
)
|
|
|
|
|
2022-01-19 02:23:54 +00:00
|
|
|
// MiddlewareAccessLogger wraps next so that each request produces one entry in
// s.AccessLogger. The entry is written from a deferred call after next returns.
// It is keyed by the HTTP method and records the URL path, the response status,
// the number of bytes written, the remote address, the processing time in
// milliseconds and the request ID from chi's RequestID middleware (if any).
func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Wrap the writer so status and byte count can be read after the
		// downstream handler has run.
		wrapped := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
		start := time.Now()
		requestID := middleware.GetReqID(r.Context())

		defer func() {
			elapsed := time.Since(start).Milliseconds()
			s.AccessLogger.Infow(
				r.Method,
				"path", r.URL.Path,
				"status", wrapped.Status(),
				"written", wrapped.BytesWritten(),
				"remote_addr", r.RemoteAddr,
				"processing_time_ms", elapsed,
				"req_id", requestID,
			)
		}()

		next.ServeHTTP(wrapped, r)
	})
}
|
2022-01-20 00:04:44 +00:00
|
|
|
|
|
|
|
// MiddlewareAuthentication inspects the Authorization header. When it carries
// a bearer token accepted by s.Auth.ValidateToken, the token's subject, role
// and claims are stored in the request context before next is called.
//
// This middleware never rejects a request itself. Requests with no header, a
// header that does not start with "Bearer ", or a token that fails validation
// are passed to next unchanged (only a debug message is logged). Handlers that
// require authentication must check the context, e.g. via UsernameFromRequest.
func (s *HTTPServer) MiddlewareAuthentication(next http.Handler) http.Handler {
	// bearerPrefix is the scheme prefix expected at the start of the header.
	const bearerPrefix = "Bearer "

	fn := func(w http.ResponseWriter, r *http.Request) {
		reqID := middleware.GetReqID(r.Context())

		header := r.Header.Get("Authorization")
		if header == "" {
			s.Logger.Debugw("Request has no auth header.", "req_id", reqID)
			next.ServeHTTP(w, r)
			return
		}

		// Require the scheme at the very start of the header. Splitting on
		// "Bearer " would also accept values such as "xBearer <token>".
		if !strings.HasPrefix(header, bearerPrefix) {
			s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
			next.ServeHTTP(w, r)
			return
		}
		token := strings.TrimPrefix(header, bearerPrefix)

		claims, err := s.Auth.ValidateToken(token)
		if err != nil {
			s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
			next.ServeHTTP(w, r)
			return
		}

		ctx := context.WithValue(r.Context(), authCtxUsername, claims.Subject)
		ctx = context.WithValue(ctx, authCtxAuthLevel, claims.Role)
		ctx = context.WithValue(ctx, authCtxClaims, claims)
		withCtx := r.WithContext(ctx)

		s.Logger.Debugw("Request is authenticated.", "req_id", reqID, "username", claims.Subject, "role", claims.Role)

		next.ServeHTTP(w, withCtx)
	}

	return http.HandlerFunc(fn)
}
|
|
|
|
|
|
|
|
// UsernameFromRequest returns the username that MiddlewareAuthentication
// stored in the request context. It returns an error if no username is present
// (the request was not authenticated) or if the stored value is not a string.
func UsernameFromRequest(r *http.Request) (string, error) {
	// A comma-ok assertion on a missing (nil) value yields ok == false, so it
	// covers both the "absent" and the "wrong type" cases.
	if username, ok := r.Context().Value(authCtxUsername).(string); ok {
		return username, nil
	}
	return "", fmt.Errorf("no username")
}
|
2022-01-20 00:11:40 +00:00
|
|
|
|
2022-01-20 16:50:56 +00:00
|
|
|
// RoleFromRequest returns the role that MiddlewareAuthentication stored in the
// request context. It returns users.RoleUnset and an error if no role is
// present (the request was not authenticated) or if the stored value is not a
// users.Role.
func RoleFromRequest(r *http.Request) (users.Role, error) {
	// A comma-ok assertion on a missing (nil) value yields ok == false, so a
	// separate nil check is unnecessary.
	role, ok := r.Context().Value(authCtxAuthLevel).(users.Role)
	if !ok {
		return users.RoleUnset, fmt.Errorf("no role")
	}
	return role, nil
}
|
2022-01-20 12:33:11 +00:00
|
|
|
|
|
|
|
// ClaimsFromRequest returns the validated token claims that
// MiddlewareAuthentication stored in the request context. It returns nil if
// the request was not authenticated or if the stored value is not a
// *gpaste.Claims.
//
// NOTE(review): this assumes s.Auth.ValidateToken returns *gpaste.Claims (the
// value stored under authCtxClaims). If it returns a value type, this will
// always return nil — confirm.
func ClaimsFromRequest(r *http.Request) *gpaste.Claims {
	// Read the claims key. Reading authCtxAuthLevel here would return a
	// users.Role, so the assertion below would always fail.
	claims, ok := r.Context().Value(authCtxClaims).(*gpaste.Claims)
	if !ok {
		return nil
	}
	return claims
}
|