auth/authmw/token.go

93 lines
2.3 KiB
Go
Raw Normal View History

package authmw
import (
	"context"
	"crypto/x509"
	"encoding/json"
	"fmt"
	"net/http"
	"slices"
	"strings"
	"time"

	"git.t-juice.club/microfilm/auth"
	"github.com/golang-jwt/jwt/v5"
)
// VerifyToken returns middleware that authenticates requests using a JWT
// bearer token signed by the auth service at authURL.
//
// The auth service's public key is fetched once, when VerifyToken is called,
// from "<authURL>/key". VerifyToken panics if the key cannot be fetched,
// decoded or parsed, so it is meant to be called during server setup.
//
// Requests without a Bearer Authorization header are passed through only when
// permittedRoles contains auth.RoleUnauthorized; otherwise they receive a 401
// JSON error. Requests whose token fails verification receive a 401 JSON
// error. For a verified token, the *auth.MicrofilmClaims are stored in the
// request context under the key "claims" before the next handler runs.
//
// NOTE(review): permittedRoles is only consulted for requests WITHOUT a
// token; the role carried by a verified token is never compared against it
// here. Confirm whether role enforcement happens elsewhere.
func VerifyToken(authURL string, permittedRoles []string) func(http.Handler) http.Handler {
	pub, err := fetchPubkey(authURL)
	if err != nil {
		panic(err)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			tokenString, ok := bearerToken(r.Header.Get("Authorization"))
			if !ok {
				// No token: allow the request only if anonymous access is permitted.
				if slices.Contains(permittedRoles, auth.RoleUnauthorized) {
					next.ServeHTTP(w, r)
					return
				}
				writeAuthError(w, fmt.Sprintf("Authorization required: %s", strings.Join(permittedRoles, ",")))
				return
			}

			// The keyfunc always hands back the auth service's public key.
			// NOTE(review): algorithm pinning presumably relies on jwt's
			// per-signing-method key type checks; consider jwt.WithValidMethods
			// once the key algorithm is fixed.
			token, err := jwt.ParseWithClaims(tokenString, &auth.MicrofilmClaims{}, func(*jwt.Token) (any, error) {
				return pub, nil
			})
			if err != nil {
				writeAuthError(w, fmt.Sprintf("Token verification failed: %s", err))
				return
			}

			claims, ok := token.Claims.(*auth.MicrofilmClaims)
			if !ok || !token.Valid {
				// Fail closed: never forward a request whose token could not be
				// fully validated.
				writeAuthError(w, "Token verification failed: invalid token")
				return
			}

			// NOTE(review): a bare string context key can collide with other
			// packages; it is kept so existing ctx.Value("claims") readers work.
			ctx := context.WithValue(r.Context(), "claims", claims)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// fetchPubkey retrieves the auth service's public key from "<authURL>/key"
// and parses the auth.PubkeyResponse's PubKey field as a DER-encoded PKIX
// public key.
func fetchPubkey(authURL string) (any, error) {
	keyURL := fmt.Sprintf("%s/key", authURL)

	// Bound the one-time fetch so an unresponsive auth service cannot hang
	// startup indefinitely (http.DefaultClient has no timeout).
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Get(keyURL)
	if err != nil {
		return nil, fmt.Errorf("fetching public key: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("fetching public key from %s: unexpected status %s", keyURL, resp.Status)
	}

	var pkResp auth.PubkeyResponse
	if err := json.NewDecoder(resp.Body).Decode(&pkResp); err != nil {
		return nil, fmt.Errorf("decoding public key response: %w", err)
	}

	pub, err := x509.ParsePKIXPublicKey(pkResp.PubKey)
	if err != nil {
		return nil, fmt.Errorf("parsing public key: %w", err)
	}
	return pub, nil
}

// bearerToken extracts the credentials from an Authorization header value
// that uses the Bearer scheme. The scheme name is matched case-insensitively.
// ok is false when the header is absent or uses a different scheme; a Bearer
// header with empty credentials yields ("", true) so verification rejects it.
func bearerToken(header string) (token string, ok bool) {
	scheme, rest, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return "", false
	}
	return strings.TrimSpace(rest), true
}

// writeAuthError writes a 401 Unauthorized response whose JSON body is an
// auth.ErrorResponse carrying msg.
func writeAuthError(w http.ResponseWriter, msg string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusUnauthorized)
	errResp := auth.ErrorResponse{Message: msg, Status: http.StatusUnauthorized}
	// The status line is already sent; an encode failure (e.g. a disconnected
	// client) has nowhere useful to be reported.
	_ = json.NewEncoder(w).Encode(&errResp)
}