diff --git a/auth.go b/auth.go index cfd753f..e655bfd 100644 --- a/auth.go +++ b/auth.go @@ -45,17 +45,17 @@ func (as *AuthService) Login(username, password string) (string, error) { return signed, nil } -func (as *AuthService) ValidateToken(rawToken string) error { +func (as *AuthService) ValidateToken(rawToken string) (*jwt.StandardClaims, error) { claims := &jwt.StandardClaims{} token, err := jwt.ParseWithClaims(rawToken, claims, func(t *jwt.Token) (interface{}, error) { return as.hmacSecret, nil }) if err != nil { - return err + return nil, err } if !token.Valid { - return fmt.Errorf("invalid token") + return nil, fmt.Errorf("invalid token") } - return nil + return claims, nil } diff --git a/auth_test.go b/auth_test.go index 0c98f7c..7685ce3 100644 --- a/auth_test.go +++ b/auth_test.go @@ -28,11 +28,11 @@ func TestAuth(t *testing.T) { t.Fatalf("Error creating token: %s", err) } - if err := as.ValidateToken(token); err != nil { + if _, err := as.ValidateToken(token); err != nil { t.Fatalf("Error validating token: %s", err) } invalidToken := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2NDMyMjk3NjMsImp0aSI6ImUzNDk5NWI1LThiZmMtNDQyNy1iZDgxLWFmNmQ3OTRiYzM0YiIsImlhdCI6MTY0MjYyNDk2MywibmJmIjoxNjQyNjI0OTYzLCJzdWIiOiJYdE5Hemt5ZSJ9.VM6dkwSLaBv8cStkWRVVv9ADjdUrHGHrlB7GB7Ly7n8` - if err := as.ValidateToken(invalidToken); err == nil { + if _, err := as.ValidateToken(invalidToken); err == nil { t.Fatalf("Invalid token passed validation") } }) diff --git a/gpaste-server.toml b/gpaste-server.toml index 473dc41..79608df 100644 --- a/gpaste-server.toml +++ b/gpaste-server.toml @@ -1,4 +1,4 @@ -LogLevel = "INFO" +LogLevel = "DEBUG" URL = "http://paste.example.org" ListenAddr = ":8080" diff --git a/http.go b/http.go index e4dc9ad..85c64dc 100644 --- a/http.go +++ b/http.go @@ -42,6 +42,7 @@ func NewHTTPServer(cfg *ServerConfig) *HTTPServer { r.Use(middleware.RealIP) r.Use(middleware.RequestID) r.Use(srv.MiddlewareAccessLogger) + r.Use(srv.MiddlewareAuthentication) r.Get("/", srv.HandlerIndex) r.Post("/api/file", srv.HandlerAPIFilePost) r.Get("/api/file/{id}", srv.HandlerAPIFileGet) diff --git a/http_test.go b/http_test.go index b99d4aa..1b2005e 100644 --- a/http_test.go +++ b/http_test.go @@ -137,7 +137,7 @@ func TestHandlers(t *testing.T) { t.Fatalf("Error decoding response: %s", err) } - if err := hs.Auth.ValidateToken(responseData.Token); err != nil { + if _, err := hs.Auth.ValidateToken(responseData.Token); err != nil { t.Fatalf("Unable to validate received token: %s", err) } }) diff --git a/middleware.go b/middleware.go index 585fe2b..f658e12 100644 --- a/middleware.go +++ b/middleware.go @@ -1,12 +1,21 @@ package gpaste import ( + "context" + "fmt" "net/http" + "strings" "time" "github.com/go-chi/chi/v5/middleware" ) +type authCtxKey int + +const ( + authCtxUsername authCtxKey = iota +) + func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) @@ -28,3 +37,51 @@ func (s *HTTPServer) MiddlewareAccessLogger(next http.Handler) http.Handler { } return http.HandlerFunc(fn) } + +func (s *HTTPServer) MiddlewareAuthentication(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + reqID := middleware.GetReqID(r.Context()) + header := r.Header.Get("Authorization") + if header == "" { + s.Logger.Debugw("Request has no auth header.", "req_id", reqID) + next.ServeHTTP(w, r) + return + } + + splitHeader := strings.Split(header, "Bearer ") + if len(splitHeader) != 2 { + s.Logger.Debugw("Request has invalid token.", "req_id", reqID) + next.ServeHTTP(w, r) + return + } + token := splitHeader[1] + + claims, err := s.Auth.ValidateToken(token) + if err != nil { + s.Logger.Debugw("Request has invalid token.", "req_id", reqID) + next.ServeHTTP(w, r) + return + } + + ctx := context.WithValue(r.Context(), authCtxUsername, claims.Subject) + withCtx := r.WithContext(ctx) + s.Logger.Debugw("Request is authenticated.", "req_id", reqID, "username", claims.Subject) + + next.ServeHTTP(w, withCtx) + } + + return http.HandlerFunc(fn) +} + +func UsernameFromRequest(r *http.Request) (string, error) { + rawUsername := r.Context().Value(authCtxUsername) + if rawUsername == nil { + + return "", fmt.Errorf("no username") + } + username, ok := rawUsername.(string) + if !ok { + return "", fmt.Errorf("no username") + } + return username, nil +}