diff --git a/actions/serve.go b/actions/serve.go index 8a91599..fe557f8 100644 --- a/actions/serve.go +++ b/actions/serve.go @@ -136,7 +136,7 @@ func ActionServe(c *cli.Context) error { grpcServer := grpc.NewServer( grpc.Creds(creds), - grpc.ChainUnaryInterceptor(interceptors.NewAuthInterceptor(userStore, authLogger)), + grpc.ChainUnaryInterceptor(interceptors.NewAuthInterceptor(userStore, certSvc, authLogger)), ) pb.RegisterFileServiceServer(grpcServer, grpcFileServer) pb.RegisterUserServiceServer(grpcServer, grpcUserServer) diff --git a/server/interceptors/auth.go b/server/interceptors/auth.go index 6270c28..c2d0ff2 100644 --- a/server/interceptors/auth.go +++ b/server/interceptors/auth.go @@ -3,12 +3,15 @@ package interceptors import ( "context" + "gitea.benny.dog/torjus/ezshare/certs" "gitea.benny.dog/torjus/ezshare/pb" "gitea.benny.dog/torjus/ezshare/store" "go.uber.org/zap" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" ) type ContextKey string @@ -16,9 +19,14 @@ type ContextKey string var ContextKeyRole ContextKey = "role" var ContextKeyUserID ContextKey = "userid" -func NewAuthInterceptor(s store.UserStore, logger *zap.SugaredLogger) grpc.UnaryServerInterceptor { +func NewAuthInterceptor(s store.UserStore, certSvc *certs.CertService, logger *zap.SugaredLogger) grpc.UnaryServerInterceptor { // TODO: Verify that cert is signed by our ca return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { + // Login doesn't need valid cert + if info.FullMethod == "/ezshare.UserService/Login" { + return handler(ctx, req) + } + p, ok := peer.FromContext(ctx) if ok { tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo) @@ -26,13 +34,18 @@ func NewAuthInterceptor(s store.UserStore, logger *zap.SugaredLogger) grpc.Unary if len(tlsInfo.State.PeerCertificates) == 1 { cert := tlsInfo.State.PeerCertificates[0] - id := cert.Subject.CommonName + // Check if valid + id, err := certSvc.VerifyClient(cert.Raw) + if err != nil { + logger.Infow("Rejected client due to invalid cert", "error", "err", "remote_addr", p.Addr.String(), "method", info.FullMethod) + return nil, status.Error(codes.Unauthenticated, "invalid client certificate") + } user, err := s.GetUser(id) if err == nil { newCtx := context.WithValue(ctx, ContextKeyRole, user.UserRole) newCtx = context.WithValue(newCtx, ContextKeyUserID, user.Id) - logger.Debugw("Authenticated user.", "username", user.Username, "role", user.UserRole.String()) + logger.Debugw("Authenticated user.", "username", user.Username, "role", user.UserRole.String(), "method", info.FullMethod) return handler(newCtx, req) } }