diff --git a/api/middleware.go b/api/middleware.go index c5296b6..eb7f835 100644 --- a/api/middleware.go +++ b/api/middleware.go @@ -16,6 +16,7 @@ type authCtxKey int const ( authCtxUsername authCtxKey = iota authCtxAuthLevel + authCtxClaims ) func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler { @@ -66,7 +67,8 @@ func (s *HTTPServer) MiddlewareAuthentication(next http.Handler) http.Handler { } ctx := context.WithValue(r.Context(), authCtxUsername, claims.Subject) - ctx = context.WithValue(ctx, authCtxAuthLevel, gpaste.AuthLevelUser) + ctx = context.WithValue(ctx, authCtxAuthLevel, claims.Role) + ctx = context.WithValue(ctx, authCtxClaims, claims) withCtx := r.WithContext(ctx) s.Logger.Debugw("Request is authenticated.", "req_id", reqID, "username", claims.Subject) @@ -79,7 +81,6 @@ func (s *HTTPServer) MiddlewareAuthentication(next http.Handler) http.Handler { func UsernameFromRequest(r *http.Request) (string, error) { rawUsername := r.Context().Value(authCtxUsername) if rawUsername == nil { - return "", fmt.Errorf("no username") } username, ok := rawUsername.(string) @@ -100,3 +101,15 @@ func AuthLevelFromRequest(r *http.Request) (gpaste.AuthLevel, error) { } return level, nil } + +func ClaimsFromRequest(r *http.Request) *gpaste.Claims { + rawClaims := r.Context().Value(authCtxAuthLevel) + if rawClaims == nil { + return nil + } + claims, ok := rawClaims.(*gpaste.Claims) + if !ok { + return nil + } + return claims +} diff --git a/auth.go b/auth.go index 9e9748d..04aa080 100644 --- a/auth.go +++ b/auth.go @@ -22,6 +22,12 @@ type AuthService struct { hmacSecret []byte } +type Claims struct { + Role users.Role `json:"role,omitempty"` + + jwt.StandardClaims +} + func NewAuthService(store users.UserStore, signingSecret []byte) *AuthService { return &AuthService{users: store, hmacSecret: signingSecret} } @@ -37,13 +43,13 @@ func (as *AuthService) Login(username, password string) (string, error) { } // TODO: Set iss and aud - claims := jwt.StandardClaims{ - Subject: user.Username, - ExpiresAt: time.Now().Add(7 * 24 * time.Hour).Unix(), - NotBefore: time.Now().Unix(), - IssuedAt: time.Now().Unix(), - Id: uuid.NewString(), - } + claims := new(Claims) + claims.Subject = user.Username + claims.ExpiresAt = time.Now().Add(7 * 24 * time.Hour).Unix() + claims.NotBefore = time.Now().Unix() + claims.IssuedAt = time.Now().Unix() + claims.Id = uuid.NewString() + claims.Role = user.Role token := jwt.NewWithClaims(jwt.GetSigningMethod("HS256"), claims) signed, err := token.SignedString(as.hmacSecret) @@ -54,8 +60,8 @@ func (as *AuthService) Login(username, password string) (string, error) { return signed, nil } -func (as *AuthService) ValidateToken(rawToken string) (*jwt.StandardClaims, error) { - claims := &jwt.StandardClaims{} +func (as *AuthService) ValidateToken(rawToken string) (*Claims, error) { + claims := &Claims{} token, err := jwt.ParseWithClaims(rawToken, claims, func(t *jwt.Token) (interface{}, error) { return as.hmacSecret, nil }) diff --git a/auth_test.go b/auth_test.go index d4251f3..310e96b 100644 --- a/auth_test.go +++ b/auth_test.go @@ -6,6 +6,7 @@ import ( "git.t-juice.club/torjus/gpaste" "git.t-juice.club/torjus/gpaste/users" + "github.com/google/go-cmp/cmp" ) func TestAuth(t *testing.T) { @@ -17,7 +18,7 @@ func TestAuth(t *testing.T) { username := randomString(8) password := randomString(16) - user := &users.User{Username: username} + user := &users.User{Username: username, Role: users.RoleAdmin} if err := user.SetPassword(password); err != nil { t.Fatalf("error setting user password: %s", err) } @@ -30,9 +31,13 @@ func TestAuth(t *testing.T) { t.Fatalf("Error creating token: %s", err) } - if _, err := as.ValidateToken(token); err != nil { + claims, err := as.ValidateToken(token) + if err != nil { t.Fatalf("Error validating token: %s", err) } + if claims.Role != user.Role { + t.Fatalf("Token role is not correct: %s", cmp.Diff(claims.Role, user.Role)) + } invalidToken := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2NDMyMjk3NjMsImp0aSI6ImUzNDk5NWI1LThiZmMtNDQyNy1iZDgxLWFmNmQ3OTRiYzM0YiIsImlhdCI6MTY0MjYyNDk2MywibmJmIjoxNjQyNjI0OTYzLCJzdWIiOiJYdE5Hemt5ZSJ9.VM6dkwSLaBv8cStkWRVVv9ADjdUrHGHrlB7GB7Ly7n8` if _, err := as.ValidateToken(invalidToken); err == nil { t.Fatalf("Invalid token passed validation") diff --git a/users/user.go b/users/user.go index c25fedd..762264c 100644 --- a/users/user.go +++ b/users/user.go @@ -13,7 +13,7 @@ const ( type User struct { Username string `json:"username"` HashedPassword []byte `json:"hashed_password"` - Roles []Role `json:"roles"` + Role Role `json:"role"` } type UserStore interface { diff --git a/users/userstore_test.go b/users/userstore_test.go index 6146067..7808633 100644 --- a/users/userstore_test.go +++ b/users/userstore_test.go @@ -20,7 +20,7 @@ func RunUserStoreTest(newFunc func() (func(), users.UserStore), t *testing.T) { passwordMap[username] = password user := &users.User{ Username: username, - Roles: []users.Role{users.RoleAdmin}, + Role: users.RoleAdmin, } if err := user.SetPassword(password); err != nil { t.Fatalf("Error setting password: %s", err)