package interceptors import ( "context" "gitea.benny.dog/torjus/ezshare/certs" "gitea.benny.dog/torjus/ezshare/pb" "gitea.benny.dog/torjus/ezshare/store" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" ) type ContextKey string var ContextKeyRole ContextKey = "role" var ContextKeyUserID ContextKey = "userid" func NewAuthInterceptor(s store.UserStore, certSvc *certs.CertService, logger *zap.SugaredLogger) grpc.UnaryServerInterceptor { // TODO: Verify that cert is signed by our ca return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { // Login doesn't need valid cert if info.FullMethod == "/ezshare.UserService/Login" { return handler(ctx, req) } p, ok := peer.FromContext(ctx) if ok { tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo) if ok { if len(tlsInfo.State.PeerCertificates) == 1 { cert := tlsInfo.State.PeerCertificates[0] // Check if valid id, err := certSvc.VerifyClient(cert.Raw) if err != nil { logger.Infow("Rejected client due to invalid cert", "error", "err", "remote_addr", p.Addr.String(), "method", info.FullMethod) return nil, status.Error(codes.Unauthenticated, "invalid client certificate") } user, err := s.GetUser(id) if err == nil { newCtx := context.WithValue(ctx, ContextKeyRole, user.UserRole) newCtx = context.WithValue(newCtx, ContextKeyUserID, user.Id) logger.Debugw("Authenticated user.", "username", user.Username, "role", user.UserRole.String(), "method", info.FullMethod) return handler(newCtx, req) } } } } newCtx := context.WithValue(ctx, ContextKeyRole, pb.User_UNKNOWN) return handler(newCtx, req) } } func RoleFromContext(ctx context.Context) pb.User_Role { value := ctx.Value(ContextKeyRole) if value == nil { return pb.User_UNKNOWN } role, ok := value.(pb.User_Role) if ok { return role } return pb.User_UNKNOWN } func UserIDFromContext(ctx context.Context) string { value := ctx.Value(ContextKeyUserID) if value == nil { return "" } id, ok := value.(string) if ok { return id } return "" } func RoleAtLeast(ctx context.Context, role pb.User_Role) bool { ctxRole := RoleFromContext(ctx) return ctxRole > role }