3 Commits

Author SHA1 Message Date
ce5584ba7e Add role to users
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
2022-01-20 01:19:32 +01:00
790cc43949 Add authlevel to middleware
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
2022-01-20 01:11:40 +01:00
a8a64d118c Add auth middleware
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
2022-01-20 01:04:44 +01:00
8 changed files with 116 additions and 17 deletions

16
auth.go
View File

@@ -8,6 +8,14 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
type AuthLevel int
const (
AuthLevelUnset AuthLevel = iota
AuthLevelUser
AuthLevelAdmin
)
type AuthService struct { type AuthService struct {
users UserStore users UserStore
hmacSecret []byte hmacSecret []byte
@@ -45,17 +53,17 @@ func (as *AuthService) Login(username, password string) (string, error) {
return signed, nil return signed, nil
} }
func (as *AuthService) ValidateToken(rawToken string) error { func (as *AuthService) ValidateToken(rawToken string) (*jwt.StandardClaims, error) {
claims := &jwt.StandardClaims{} claims := &jwt.StandardClaims{}
token, err := jwt.ParseWithClaims(rawToken, claims, func(t *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(rawToken, claims, func(t *jwt.Token) (interface{}, error) {
return as.hmacSecret, nil return as.hmacSecret, nil
}) })
if err != nil { if err != nil {
return err return nil, err
} }
if !token.Valid { if !token.Valid {
return fmt.Errorf("invalid token") return nil, fmt.Errorf("invalid token")
} }
return nil return claims, nil
} }

View File

@@ -28,11 +28,11 @@ func TestAuth(t *testing.T) {
t.Fatalf("Error creating token: %s", err) t.Fatalf("Error creating token: %s", err)
} }
if err := as.ValidateToken(token); err != nil { if _, err := as.ValidateToken(token); err != nil {
t.Fatalf("Error validating token: %s", err) t.Fatalf("Error validating token: %s", err)
} }
invalidToken := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2NDMyMjk3NjMsImp0aSI6ImUzNDk5NWI1LThiZmMtNDQyNy1iZDgxLWFmNmQ3OTRiYzM0YiIsImlhdCI6MTY0MjYyNDk2MywibmJmIjoxNjQyNjI0OTYzLCJzdWIiOiJYdE5Hemt5ZSJ9.VM6dkwSLaBv8cStkWRVVv9ADjdUrHGHrlB7GB7Ly7n8` invalidToken := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2NDMyMjk3NjMsImp0aSI6ImUzNDk5NWI1LThiZmMtNDQyNy1iZDgxLWFmNmQ3OTRiYzM0YiIsImlhdCI6MTY0MjYyNDk2MywibmJmIjoxNjQyNjI0OTYzLCJzdWIiOiJYdE5Hemt5ZSJ9.VM6dkwSLaBv8cStkWRVVv9ADjdUrHGHrlB7GB7Ly7n8`
if err := as.ValidateToken(invalidToken); err == nil { if _, err := as.ValidateToken(invalidToken); err == nil {
t.Fatalf("Invalid token passed validation") t.Fatalf("Invalid token passed validation")
} }
}) })

View File

@@ -1,4 +1,4 @@
LogLevel = "INFO" LogLevel = "DEBUG"
URL = "http://paste.example.org" URL = "http://paste.example.org"
ListenAddr = ":8080" ListenAddr = ":8080"

View File

@@ -42,6 +42,7 @@ func NewHTTPServer(cfg *ServerConfig) *HTTPServer {
r.Use(middleware.RealIP) r.Use(middleware.RealIP)
r.Use(middleware.RequestID) r.Use(middleware.RequestID)
r.Use(srv.MiddlewareAccessLogger) r.Use(srv.MiddlewareAccessLogger)
r.Use(srv.MiddlewareAuthentication)
r.Get("/", srv.HandlerIndex) r.Get("/", srv.HandlerIndex)
r.Post("/api/file", srv.HandlerAPIFilePost) r.Post("/api/file", srv.HandlerAPIFilePost)
r.Get("/api/file/{id}", srv.HandlerAPIFileGet) r.Get("/api/file/{id}", srv.HandlerAPIFileGet)

View File

@@ -137,7 +137,7 @@ func TestHandlers(t *testing.T) {
t.Fatalf("Error decoding response: %s", err) t.Fatalf("Error decoding response: %s", err)
} }
if err := hs.Auth.ValidateToken(responseData.Token); err != nil { if _, err := hs.Auth.ValidateToken(responseData.Token); err != nil {
t.Fatalf("Unable to validate received token: %s", err) t.Fatalf("Unable to validate received token: %s", err)
} }
}) })

View File

@@ -1,12 +1,22 @@
package gpaste package gpaste
import ( import (
"context"
"fmt"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
) )
type authCtxKey int
const (
authCtxUsername authCtxKey = iota
authCtxAuthLevel
)
func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler { func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) { fn := func(w http.ResponseWriter, r *http.Request) {
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
@@ -28,3 +38,64 @@ func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
} }
return http.HandlerFunc(fn) return http.HandlerFunc(fn)
} }
func (s *HTTPServer) MiddlewareAuthentication(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
reqID := middleware.GetReqID(r.Context())
header := r.Header.Get("Authorization")
if header == "" {
s.Logger.Debugw("Request has no auth header.", "req_id", reqID)
next.ServeHTTP(w, r)
return
}
splitHeader := strings.Split(header, "Bearer ")
if len(splitHeader) != 2 {
s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
next.ServeHTTP(w, r)
return
}
token := splitHeader[1]
claims, err := s.Auth.ValidateToken(token)
if err != nil {
s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
next.ServeHTTP(w, r)
return
}
ctx := context.WithValue(r.Context(), authCtxUsername, claims.Subject)
ctx = context.WithValue(ctx, authCtxAuthLevel, AuthLevelUser)
withCtx := r.WithContext(ctx)
s.Logger.Debugw("Request is authenticated.", "req_id", reqID, "username", claims.Subject)
next.ServeHTTP(w, withCtx)
}
return http.HandlerFunc(fn)
}
func UsernameFromRequest(r *http.Request) (string, error) {
rawUsername := r.Context().Value(authCtxUsername)
if rawUsername == nil {
return "", fmt.Errorf("no username")
}
username, ok := rawUsername.(string)
if !ok {
return "", fmt.Errorf("no username")
}
return username, nil
}
func AuthLevelFromRequest(r *http.Request) (AuthLevel, error) {
rawLevel := r.Context().Value(authCtxAuthLevel)
if rawLevel == nil {
return AuthLevelUnset, fmt.Errorf("no username")
}
level, ok := rawLevel.(AuthLevel)
if !ok {
return AuthLevelUnset, fmt.Errorf("no username")
}
return level, nil
}

View File

@@ -2,9 +2,18 @@ package gpaste
import "golang.org/x/crypto/bcrypt" import "golang.org/x/crypto/bcrypt"
type Role string
const (
RoleUnset Role = ""
RoleUser Role = "user"
RoleAdmin Role = "admin"
)
type User struct { type User struct {
Username string `json:"username"` Username string `json:"username"`
HashedPassword []byte `json:"hashed_password"` HashedPassword []byte `json:"hashed_password"`
Roles []Role `json:"roles"`
} }
type UserStore interface { type UserStore interface {

View File

@@ -4,6 +4,7 @@ import (
"testing" "testing"
"git.t-juice.club/torjus/gpaste" "git.t-juice.club/torjus/gpaste"
"github.com/google/go-cmp/cmp"
) )
func RunUserStoreTest(newFunc func() (func(), gpaste.UserStore), t *testing.T) { func RunUserStoreTest(newFunc func() (func(), gpaste.UserStore), t *testing.T) {
@@ -11,31 +12,40 @@ func RunUserStoreTest(newFunc func() (func(), gpaste.UserStore), t *testing.T) {
cleanup, s := newFunc() cleanup, s := newFunc()
t.Cleanup(cleanup) t.Cleanup(cleanup)
userMap := make(map[string]string) userMap := make(map[string]*gpaste.User)
passwordMap := make(map[string]string)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
userMap[randomString(8)] = randomString(16) username := randomString(8)
} password := randomString(16)
passwordMap[username] = password
for k, v := range userMap {
user := &gpaste.User{ user := &gpaste.User{
Username: k, Username: username,
Roles: []gpaste.Role{gpaste.RoleAdmin},
} }
if err := user.SetPassword(v); err != nil { if err := user.SetPassword(password); err != nil {
t.Fatalf("Error setting password: %s", err) t.Fatalf("Error setting password: %s", err)
} }
userMap[username] = user
}
for _, user := range userMap {
if err := s.Store(user); err != nil { if err := s.Store(user); err != nil {
t.Fatalf("Error storing user: %s", err) t.Fatalf("Error storing user: %s", err)
} }
} }
for k, v := range userMap { for k := range userMap {
user, err := s.Get(k) user, err := s.Get(k)
if err != nil { if err != nil {
t.Errorf("Error getting user: %s", err) t.Errorf("Error getting user: %s", err)
} }
if err := user.ValidatePassword(v); err != nil { if err := user.ValidatePassword(passwordMap[user.Username]); err != nil {
t.Errorf("Error verifying password: %s", err) t.Errorf("Error verifying password: %s", err)
} }
if !cmp.Equal(user, userMap[k]) {
t.Errorf("User mismatch: %s", cmp.Diff(user, userMap[k]))
}
} }
}) })
} }