Compare commits
	
		
			4 Commits
		
	
	
		
			88b5b941df
			...
			v0.3.3
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| ce5584ba7e | |||
| 790cc43949 | |||
| a8a64d118c | |||
| fdf374d541 | 
							
								
								
									
										16
									
								
								auth.go
									
									
									
									
									
								
							
							
						
						
									
										16
									
								
								auth.go
									
									
									
									
									
								
							@@ -8,6 +8,14 @@ import (
 | 
				
			|||||||
	"github.com/google/uuid"
 | 
						"github.com/google/uuid"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type AuthLevel int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						AuthLevelUnset AuthLevel = iota
 | 
				
			||||||
 | 
						AuthLevelUser
 | 
				
			||||||
 | 
						AuthLevelAdmin
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type AuthService struct {
 | 
					type AuthService struct {
 | 
				
			||||||
	users      UserStore
 | 
						users      UserStore
 | 
				
			||||||
	hmacSecret []byte
 | 
						hmacSecret []byte
 | 
				
			||||||
@@ -45,17 +53,17 @@ func (as *AuthService) Login(username, password string) (string, error) {
 | 
				
			|||||||
	return signed, nil
 | 
						return signed, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (as *AuthService) ValidateToken(rawToken string) error {
 | 
					func (as *AuthService) ValidateToken(rawToken string) (*jwt.StandardClaims, error) {
 | 
				
			||||||
	claims := &jwt.StandardClaims{}
 | 
						claims := &jwt.StandardClaims{}
 | 
				
			||||||
	token, err := jwt.ParseWithClaims(rawToken, claims, func(t *jwt.Token) (interface{}, error) {
 | 
						token, err := jwt.ParseWithClaims(rawToken, claims, func(t *jwt.Token) (interface{}, error) {
 | 
				
			||||||
		return as.hmacSecret, nil
 | 
							return as.hmacSecret, nil
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if !token.Valid {
 | 
						if !token.Valid {
 | 
				
			||||||
		return fmt.Errorf("invalid token")
 | 
							return nil, fmt.Errorf("invalid token")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return claims, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -28,11 +28,11 @@ func TestAuth(t *testing.T) {
 | 
				
			|||||||
			t.Fatalf("Error creating token: %s", err)
 | 
								t.Fatalf("Error creating token: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if err := as.ValidateToken(token); err != nil {
 | 
							if _, err := as.ValidateToken(token); err != nil {
 | 
				
			||||||
			t.Fatalf("Error validating token: %s", err)
 | 
								t.Fatalf("Error validating token: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		invalidToken := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2NDMyMjk3NjMsImp0aSI6ImUzNDk5NWI1LThiZmMtNDQyNy1iZDgxLWFmNmQ3OTRiYzM0YiIsImlhdCI6MTY0MjYyNDk2MywibmJmIjoxNjQyNjI0OTYzLCJzdWIiOiJYdE5Hemt5ZSJ9.VM6dkwSLaBv8cStkWRVVv9ADjdUrHGHrlB7GB7Ly7n8`
 | 
							invalidToken := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2NDMyMjk3NjMsImp0aSI6ImUzNDk5NWI1LThiZmMtNDQyNy1iZDgxLWFmNmQ3OTRiYzM0YiIsImlhdCI6MTY0MjYyNDk2MywibmJmIjoxNjQyNjI0OTYzLCJzdWIiOiJYdE5Hemt5ZSJ9.VM6dkwSLaBv8cStkWRVVv9ADjdUrHGHrlB7GB7Ly7n8`
 | 
				
			||||||
		if err := as.ValidateToken(invalidToken); err == nil {
 | 
							if _, err := as.ValidateToken(invalidToken); err == nil {
 | 
				
			||||||
			t.Fatalf("Invalid token passed validation")
 | 
								t.Fatalf("Invalid token passed validation")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
LogLevel = "INFO"
 | 
					LogLevel = "DEBUG"
 | 
				
			||||||
URL = "http://paste.example.org"
 | 
					URL = "http://paste.example.org"
 | 
				
			||||||
ListenAddr = ":8080"
 | 
					ListenAddr = ":8080"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										1
									
								
								http.go
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								http.go
									
									
									
									
									
								
							@@ -42,6 +42,7 @@ func NewHTTPServer(cfg *ServerConfig) *HTTPServer {
 | 
				
			|||||||
	r.Use(middleware.RealIP)
 | 
						r.Use(middleware.RealIP)
 | 
				
			||||||
	r.Use(middleware.RequestID)
 | 
						r.Use(middleware.RequestID)
 | 
				
			||||||
	r.Use(srv.MiddlewareAccessLogger)
 | 
						r.Use(srv.MiddlewareAccessLogger)
 | 
				
			||||||
 | 
						r.Use(srv.MiddlewareAuthentication)
 | 
				
			||||||
	r.Get("/", srv.HandlerIndex)
 | 
						r.Get("/", srv.HandlerIndex)
 | 
				
			||||||
	r.Post("/api/file", srv.HandlerAPIFilePost)
 | 
						r.Post("/api/file", srv.HandlerAPIFilePost)
 | 
				
			||||||
	r.Get("/api/file/{id}", srv.HandlerAPIFileGet)
 | 
						r.Get("/api/file/{id}", srv.HandlerAPIFileGet)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -137,7 +137,7 @@ func TestHandlers(t *testing.T) {
 | 
				
			|||||||
			t.Fatalf("Error decoding response: %s", err)
 | 
								t.Fatalf("Error decoding response: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if err := hs.Auth.ValidateToken(responseData.Token); err != nil {
 | 
							if _, err := hs.Auth.ValidateToken(responseData.Token); err != nil {
 | 
				
			||||||
			t.Fatalf("Unable to validate received token: %s", err)
 | 
								t.Fatalf("Unable to validate received token: %s", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,12 +1,22 @@
 | 
				
			|||||||
package gpaste
 | 
					package gpaste
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/go-chi/chi/v5/middleware"
 | 
						"github.com/go-chi/chi/v5/middleware"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type authCtxKey int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						authCtxUsername authCtxKey = iota
 | 
				
			||||||
 | 
						authCtxAuthLevel
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
 | 
					func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
 | 
				
			||||||
	fn := func(w http.ResponseWriter, r *http.Request) {
 | 
						fn := func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
		ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
 | 
							ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
 | 
				
			||||||
@@ -28,3 +38,64 @@ func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return http.HandlerFunc(fn)
 | 
						return http.HandlerFunc(fn)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *HTTPServer) MiddlewareAuthentication(next http.Handler) http.Handler {
 | 
				
			||||||
 | 
						fn := func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							reqID := middleware.GetReqID(r.Context())
 | 
				
			||||||
 | 
							header := r.Header.Get("Authorization")
 | 
				
			||||||
 | 
							if header == "" {
 | 
				
			||||||
 | 
								s.Logger.Debugw("Request has no auth header.", "req_id", reqID)
 | 
				
			||||||
 | 
								next.ServeHTTP(w, r)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							splitHeader := strings.Split(header, "Bearer ")
 | 
				
			||||||
 | 
							if len(splitHeader) != 2 {
 | 
				
			||||||
 | 
								s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
 | 
				
			||||||
 | 
								next.ServeHTTP(w, r)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							token := splitHeader[1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							claims, err := s.Auth.ValidateToken(token)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								s.Logger.Debugw("Request has invalid token.", "req_id", reqID)
 | 
				
			||||||
 | 
								next.ServeHTTP(w, r)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							ctx := context.WithValue(r.Context(), authCtxUsername, claims.Subject)
 | 
				
			||||||
 | 
							ctx = context.WithValue(ctx, authCtxAuthLevel, AuthLevelUser)
 | 
				
			||||||
 | 
							withCtx := r.WithContext(ctx)
 | 
				
			||||||
 | 
							s.Logger.Debugw("Request is authenticated.", "req_id", reqID, "username", claims.Subject)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							next.ServeHTTP(w, withCtx)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return http.HandlerFunc(fn)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func UsernameFromRequest(r *http.Request) (string, error) {
 | 
				
			||||||
 | 
						rawUsername := r.Context().Value(authCtxUsername)
 | 
				
			||||||
 | 
						if rawUsername == nil {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return "", fmt.Errorf("no username")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						username, ok := rawUsername.(string)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("no username")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return username, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func AuthLevelFromRequest(r *http.Request) (AuthLevel, error) {
 | 
				
			||||||
 | 
						rawLevel := r.Context().Value(authCtxAuthLevel)
 | 
				
			||||||
 | 
						if rawLevel == nil {
 | 
				
			||||||
 | 
							return AuthLevelUnset, fmt.Errorf("no username")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						level, ok := rawLevel.(AuthLevel)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return AuthLevelUnset, fmt.Errorf("no username")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return level, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										9
									
								
								user.go
									
									
									
									
									
								
							
							
						
						
									
										9
									
								
								user.go
									
									
									
									
									
								
							@@ -2,9 +2,18 @@ package gpaste
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import "golang.org/x/crypto/bcrypt"
 | 
					import "golang.org/x/crypto/bcrypt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Role string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						RoleUnset Role = ""
 | 
				
			||||||
 | 
						RoleUser  Role = "user"
 | 
				
			||||||
 | 
						RoleAdmin Role = "admin"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type User struct {
 | 
					type User struct {
 | 
				
			||||||
	Username       string `json:"username"`
 | 
						Username       string `json:"username"`
 | 
				
			||||||
	HashedPassword []byte `json:"hashed_password"`
 | 
						HashedPassword []byte `json:"hashed_password"`
 | 
				
			||||||
 | 
						Roles          []Role `json:"roles"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type UserStore interface {
 | 
					type UserStore interface {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,6 +4,7 @@ import (
 | 
				
			|||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.t-juice.club/torjus/gpaste"
 | 
						"git.t-juice.club/torjus/gpaste"
 | 
				
			||||||
 | 
						"github.com/google/go-cmp/cmp"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func RunUserStoreTest(newFunc func() (func(), gpaste.UserStore), t *testing.T) {
 | 
					func RunUserStoreTest(newFunc func() (func(), gpaste.UserStore), t *testing.T) {
 | 
				
			||||||
@@ -11,31 +12,40 @@ func RunUserStoreTest(newFunc func() (func(), gpaste.UserStore), t *testing.T) {
 | 
				
			|||||||
		cleanup, s := newFunc()
 | 
							cleanup, s := newFunc()
 | 
				
			||||||
		t.Cleanup(cleanup)
 | 
							t.Cleanup(cleanup)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		userMap := make(map[string]string)
 | 
							userMap := make(map[string]*gpaste.User)
 | 
				
			||||||
 | 
							passwordMap := make(map[string]string)
 | 
				
			||||||
		for i := 0; i < 10; i++ {
 | 
							for i := 0; i < 10; i++ {
 | 
				
			||||||
			userMap[randomString(8)] = randomString(16)
 | 
								username := randomString(8)
 | 
				
			||||||
		}
 | 
								password := randomString(16)
 | 
				
			||||||
 | 
								passwordMap[username] = password
 | 
				
			||||||
		for k, v := range userMap {
 | 
					 | 
				
			||||||
			user := &gpaste.User{
 | 
								user := &gpaste.User{
 | 
				
			||||||
				Username: k,
 | 
									Username: username,
 | 
				
			||||||
 | 
									Roles:    []gpaste.Role{gpaste.RoleAdmin},
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if err := user.SetPassword(v); err != nil {
 | 
								if err := user.SetPassword(password); err != nil {
 | 
				
			||||||
				t.Fatalf("Error setting password: %s", err)
 | 
									t.Fatalf("Error setting password: %s", err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								userMap[username] = user
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							for _, user := range userMap {
 | 
				
			||||||
			if err := s.Store(user); err != nil {
 | 
								if err := s.Store(user); err != nil {
 | 
				
			||||||
				t.Fatalf("Error storing user: %s", err)
 | 
									t.Fatalf("Error storing user: %s", err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for k, v := range userMap {
 | 
							for k := range userMap {
 | 
				
			||||||
			user, err := s.Get(k)
 | 
								user, err := s.Get(k)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				t.Errorf("Error getting user: %s", err)
 | 
									t.Errorf("Error getting user: %s", err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if err := user.ValidatePassword(v); err != nil {
 | 
								if err := user.ValidatePassword(passwordMap[user.Username]); err != nil {
 | 
				
			||||||
				t.Errorf("Error verifying password: %s", err)
 | 
									t.Errorf("Error verifying password: %s", err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if !cmp.Equal(user, userMap[k]) {
 | 
				
			||||||
 | 
									t.Errorf("User mismatch: %s", cmp.Diff(user, userMap[k]))
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user