Compare commits
3 Commits
Author | SHA1 | Date | |
---|---|---|---|
ce5584ba7e | |||
790cc43949 | |||
a8a64d118c |
16
auth.go
16
auth.go
@@ -8,6 +8,14 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type AuthLevel int
|
||||||
|
|
||||||
|
const (
|
||||||
|
AuthLevelUnset AuthLevel = iota
|
||||||
|
AuthLevelUser
|
||||||
|
AuthLevelAdmin
|
||||||
|
)
|
||||||
|
|
||||||
type AuthService struct {
|
type AuthService struct {
|
||||||
users UserStore
|
users UserStore
|
||||||
hmacSecret []byte
|
hmacSecret []byte
|
||||||
@@ -45,17 +53,17 @@ func (as *AuthService) Login(username, password string) (string, error) {
|
|||||||
return signed, nil
|
return signed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (as *AuthService) ValidateToken(rawToken string) error {
|
func (as *AuthService) ValidateToken(rawToken string) (*jwt.StandardClaims, error) {
|
||||||
claims := &jwt.StandardClaims{}
|
claims := &jwt.StandardClaims{}
|
||||||
token, err := jwt.ParseWithClaims(rawToken, claims, func(t *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(rawToken, claims, func(t *jwt.Token) (interface{}, error) {
|
||||||
return as.hmacSecret, nil
|
return as.hmacSecret, nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !token.Valid {
|
if !token.Valid {
|
||||||
return fmt.Errorf("invalid token")
|
return nil, fmt.Errorf("invalid token")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
@@ -28,11 +28,11 @@ func TestAuth(t *testing.T) {
|
|||||||
t.Fatalf("Error creating token: %s", err)
|
t.Fatalf("Error creating token: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := as.ValidateToken(token); err != nil {
|
if _, err := as.ValidateToken(token); err != nil {
|
||||||
t.Fatalf("Error validating token: %s", err)
|
t.Fatalf("Error validating token: %s", err)
|
||||||
}
|
}
|
||||||
invalidToken := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2NDMyMjk3NjMsImp0aSI6ImUzNDk5NWI1LThiZmMtNDQyNy1iZDgxLWFmNmQ3OTRiYzM0YiIsImlhdCI6MTY0MjYyNDk2MywibmJmIjoxNjQyNjI0OTYzLCJzdWIiOiJYdE5Hemt5ZSJ9.VM6dkwSLaBv8cStkWRVVv9ADjdUrHGHrlB7GB7Ly7n8`
|
invalidToken := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2NDMyMjk3NjMsImp0aSI6ImUzNDk5NWI1LThiZmMtNDQyNy1iZDgxLWFmNmQ3OTRiYzM0YiIsImlhdCI6MTY0MjYyNDk2MywibmJmIjoxNjQyNjI0OTYzLCJzdWIiOiJYdE5Hemt5ZSJ9.VM6dkwSLaBv8cStkWRVVv9ADjdUrHGHrlB7GB7Ly7n8`
|
||||||
if err := as.ValidateToken(invalidToken); err == nil {
|
if _, err := as.ValidateToken(invalidToken); err == nil {
|
||||||
t.Fatalf("Invalid token passed validation")
|
t.Fatalf("Invalid token passed validation")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
LogLevel = "INFO"
|
LogLevel = "DEBUG"
|
||||||
URL = "http://paste.example.org"
|
URL = "http://paste.example.org"
|
||||||
ListenAddr = ":8080"
|
ListenAddr = ":8080"
|
||||||
|
|
||||||
|
1
http.go
1
http.go
@@ -42,6 +42,7 @@ func NewHTTPServer(cfg *ServerConfig) *HTTPServer {
|
|||||||
r.Use(middleware.RealIP)
|
r.Use(middleware.RealIP)
|
||||||
r.Use(middleware.RequestID)
|
r.Use(middleware.RequestID)
|
||||||
r.Use(srv.MiddlewareAccessLogger)
|
r.Use(srv.MiddlewareAccessLogger)
|
||||||
|
r.Use(srv.MiddlewareAuthentication)
|
||||||
r.Get("/", srv.HandlerIndex)
|
r.Get("/", srv.HandlerIndex)
|
||||||
r.Post("/api/file", srv.HandlerAPIFilePost)
|
r.Post("/api/file", srv.HandlerAPIFilePost)
|
||||||
r.Get("/api/file/{id}", srv.HandlerAPIFileGet)
|
r.Get("/api/file/{id}", srv.HandlerAPIFileGet)
|
||||||
|
@@ -137,7 +137,7 @@ func TestHandlers(t *testing.T) {
|
|||||||
t.Fatalf("Error decoding response: %s", err)
|
t.Fatalf("Error decoding response: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := hs.Auth.ValidateToken(responseData.Token); err != nil {
|
if _, err := hs.Auth.ValidateToken(responseData.Token); err != nil {
|
||||||
t.Fatalf("Unable to validate received token: %s", err)
|
t.Fatalf("Unable to validate received token: %s", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@@ -1,12 +1,22 @@
|
|||||||
package gpaste
|
package gpaste
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type authCtxKey int
|
||||||
|
|
||||||
|
const (
|
||||||
|
authCtxUsername authCtxKey = iota
|
||||||
|
authCtxAuthLevel
|
||||||
|
)
|
||||||
|
|
||||||
func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
|
func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
|
||||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
|
ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
|
||||||
@@ -28,3 +38,64 @@ func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
return http.HandlerFunc(fn)
|
return http.HandlerFunc(fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *HTTPServer) MiddlewareAuthentication(next http.Handler) http.Handler {
|
||||||
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
reqID := middleware.GetReqID(r.Context())
|
||||||
|
header := r.Header.Get("Authorization")
|
||||||
|
if header == "" {
|
||||||
|
s.Logger.Debugw("Request has no auth header.", "req_id", reqID)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
splitHeader := strings.Split(header, "Bearer ")
|
||||||
|
if len(splitHeader) != 2 {
|
||||||
|
s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token := splitHeader[1]
|
||||||
|
|
||||||
|
claims, err := s.Auth.ValidateToken(token)
|
||||||
|
if err != nil {
|
||||||
|
s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), authCtxUsername, claims.Subject)
|
||||||
|
ctx = context.WithValue(ctx, authCtxAuthLevel, AuthLevelUser)
|
||||||
|
withCtx := r.WithContext(ctx)
|
||||||
|
s.Logger.Debugw("Request is authenticated.", "req_id", reqID, "username", claims.Subject)
|
||||||
|
|
||||||
|
next.ServeHTTP(w, withCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.HandlerFunc(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UsernameFromRequest(r *http.Request) (string, error) {
|
||||||
|
rawUsername := r.Context().Value(authCtxUsername)
|
||||||
|
if rawUsername == nil {
|
||||||
|
|
||||||
|
return "", fmt.Errorf("no username")
|
||||||
|
}
|
||||||
|
username, ok := rawUsername.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("no username")
|
||||||
|
}
|
||||||
|
return username, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuthLevelFromRequest(r *http.Request) (AuthLevel, error) {
|
||||||
|
rawLevel := r.Context().Value(authCtxAuthLevel)
|
||||||
|
if rawLevel == nil {
|
||||||
|
return AuthLevelUnset, fmt.Errorf("no username")
|
||||||
|
}
|
||||||
|
level, ok := rawLevel.(AuthLevel)
|
||||||
|
if !ok {
|
||||||
|
return AuthLevelUnset, fmt.Errorf("no username")
|
||||||
|
}
|
||||||
|
return level, nil
|
||||||
|
}
|
||||||
|
9
user.go
9
user.go
@@ -2,9 +2,18 @@ package gpaste
|
|||||||
|
|
||||||
import "golang.org/x/crypto/bcrypt"
|
import "golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
type Role string
|
||||||
|
|
||||||
|
const (
|
||||||
|
RoleUnset Role = ""
|
||||||
|
RoleUser Role = "user"
|
||||||
|
RoleAdmin Role = "admin"
|
||||||
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
HashedPassword []byte `json:"hashed_password"`
|
HashedPassword []byte `json:"hashed_password"`
|
||||||
|
Roles []Role `json:"roles"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserStore interface {
|
type UserStore interface {
|
||||||
|
@@ -4,6 +4,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.t-juice.club/torjus/gpaste"
|
"git.t-juice.club/torjus/gpaste"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func RunUserStoreTest(newFunc func() (func(), gpaste.UserStore), t *testing.T) {
|
func RunUserStoreTest(newFunc func() (func(), gpaste.UserStore), t *testing.T) {
|
||||||
@@ -11,31 +12,40 @@ func RunUserStoreTest(newFunc func() (func(), gpaste.UserStore), t *testing.T) {
|
|||||||
cleanup, s := newFunc()
|
cleanup, s := newFunc()
|
||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
userMap := make(map[string]string)
|
userMap := make(map[string]*gpaste.User)
|
||||||
|
passwordMap := make(map[string]string)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
userMap[randomString(8)] = randomString(16)
|
username := randomString(8)
|
||||||
}
|
password := randomString(16)
|
||||||
|
passwordMap[username] = password
|
||||||
for k, v := range userMap {
|
|
||||||
user := &gpaste.User{
|
user := &gpaste.User{
|
||||||
Username: k,
|
Username: username,
|
||||||
|
Roles: []gpaste.Role{gpaste.RoleAdmin},
|
||||||
}
|
}
|
||||||
if err := user.SetPassword(v); err != nil {
|
if err := user.SetPassword(password); err != nil {
|
||||||
t.Fatalf("Error setting password: %s", err)
|
t.Fatalf("Error setting password: %s", err)
|
||||||
}
|
}
|
||||||
|
userMap[username] = user
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, user := range userMap {
|
||||||
if err := s.Store(user); err != nil {
|
if err := s.Store(user); err != nil {
|
||||||
t.Fatalf("Error storing user: %s", err)
|
t.Fatalf("Error storing user: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range userMap {
|
for k := range userMap {
|
||||||
user, err := s.Get(k)
|
user, err := s.Get(k)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error getting user: %s", err)
|
t.Errorf("Error getting user: %s", err)
|
||||||
}
|
}
|
||||||
if err := user.ValidatePassword(v); err != nil {
|
if err := user.ValidatePassword(passwordMap[user.Username]); err != nil {
|
||||||
t.Errorf("Error verifying password: %s", err)
|
t.Errorf("Error verifying password: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !cmp.Equal(user, userMap[k]) {
|
||||||
|
t.Errorf("User mismatch: %s", cmp.Diff(user, userMap[k]))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user