Improve authmw
This commit is contained in:
@@ -16,6 +16,12 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
)
|
||||
|
||||
type ctxType string
|
||||
|
||||
var ctxKeyClaims ctxType = "claims"
|
||||
|
||||
var ErrNoClaimsInRequest = fmt.Errorf("no claims in request")
|
||||
|
||||
func VerifyToken(authURL string, permittedRoles []string) func(http.Handler) http.Handler {
|
||||
fn := func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -128,7 +134,7 @@ func VerifyToken(authURL string, permittedRoles []string) func(http.Handler) htt
|
||||
|
||||
// Add claims to request context
|
||||
if claims, ok := token.Claims.(*auth.MicrofilmClaims); ok && token.Valid {
|
||||
ctx := context.WithValue(r.Context(), "claims", claims)
|
||||
ctx := context.WithValue(r.Context(), ctxKeyClaims, claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
@@ -140,3 +146,13 @@ func VerifyToken(authURL string, permittedRoles []string) func(http.Handler) htt
|
||||
|
||||
return fn
|
||||
}
|
||||
|
||||
func ClaimsFromCtx(ctx context.Context) (*auth.MicrofilmClaims, error) {
|
||||
rawValue := ctx.Value(ctxKeyClaims)
|
||||
value, ok := rawValue.(*auth.MicrofilmClaims)
|
||||
if ok {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
return nil, ErrNoClaimsInRequest
|
||||
}
|
||||
|
31
authmw/token_test.go
Normal file
31
authmw/token_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package authmw
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"git.t-juice.club/microfilm/auth"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestClaimsFromContext(t *testing.T) {
|
||||
claims := &auth.MicrofilmClaims{
|
||||
Role: "admin",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "test",
|
||||
Subject: "subject",
|
||||
},
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), ctxKeyClaims, claims)
|
||||
|
||||
retrieved, err := ClaimsFromCtx(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to retrieve claims")
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(claims, retrieved); diff != "" {
|
||||
t.Fatalf("Claims diff: %s", diff)
|
||||
}
|
||||
return
|
||||
}
|
Reference in New Issue
Block a user