package interceptors

import (
	"context"

	"gitea.benny.dog/torjus/ezshare/certs"
	"gitea.benny.dog/torjus/ezshare/pb"
	"gitea.benny.dog/torjus/ezshare/store"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"
)

// ContextKey is the type of the keys this package stores in request contexts.
// Using a dedicated named type keeps these keys distinct from plain string keys
// that other packages may place in the same context.
type ContextKey string

var (
	// ContextKeyRole is the context key under which the auth interceptor
	// stores the caller's pb.User_Role.
	ContextKeyRole ContextKey = "role"
	// ContextKeyUserID is the context key under which the auth interceptor
	// stores the authenticated user's ID.
	ContextKeyUserID ContextKey = "userid"
)

// NewAuthInterceptor returns a unary server interceptor that identifies the
// caller from its TLS client certificate. It stores the caller's role and user
// ID in the request context under ContextKeyRole and ContextKeyUserID.
//
// Per request:
//   - The "/ezshare.UserService/Login" method is passed straight to the handler
//     with the original context; Login does not need a valid certificate.
//   - If the peer presented exactly one certificate and certSvc.VerifyClient
//     rejects it, the call fails with codes.Unauthenticated and the handler is
//     not invoked.
//   - If the certificate is accepted and s.GetUser finds the user, the handler
//     runs with the user's role and ID in the context.
//   - Otherwise the handler still runs, with role pb.User_UNKNOWN and no user ID.
//     This covers no peer info, a non-TLS transport, zero or several peer
//     certificates, or a failed user lookup. Access control is presumably
//     enforced downstream, e.g. via RoleAtLeast.
func NewAuthInterceptor(s store.UserStore, certSvc *certs.CertService, logger *zap.SugaredLogger) grpc.UnaryServerInterceptor {
	// TODO: Verify that cert is signed by our ca
	// NOTE(review): confirm whether certSvc.VerifyClient already performs this check.
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		// Login doesn't need valid cert
		if info.FullMethod == "/ezshare.UserService/Login" {
			return handler(ctx, req)
		}

		// asUnknown runs the handler for a caller we could not identify.
		asUnknown := func() (interface{}, error) {
			return handler(context.WithValue(ctx, ContextKeyRole, pb.User_UNKNOWN), req)
		}

		p, ok := peer.FromContext(ctx)
		if !ok {
			return asUnknown()
		}
		tlsInfo, ok := p.AuthInfo.(credentials.TLSInfo)
		if !ok || len(tlsInfo.State.PeerCertificates) != 1 {
			return asUnknown()
		}
		cert := tlsInfo.State.PeerCertificates[0]

		// A certificate that is presented but fails verification is rejected
		// outright rather than downgraded to an anonymous caller.
		id, err := certSvc.VerifyClient(cert.Raw)
		if err != nil {
			logger.Infow("Rejected client due to invalid cert", "error", err, "remote_addr", p.Addr.String(), "method", info.FullMethod)
			return nil, status.Error(codes.Unauthenticated, "invalid client certificate")
		}

		user, err := s.GetUser(id)
		if err != nil {
			// Previously swallowed silently; keep the anonymous fallback but
			// make the failed lookup visible.
			logger.Infow("Unable to look up user for verified cert", "error", err, "user_id", id, "remote_addr", p.Addr.String(), "method", info.FullMethod)
			return asUnknown()
		}

		authCtx := context.WithValue(ctx, ContextKeyRole, user.UserRole)
		authCtx = context.WithValue(authCtx, ContextKeyUserID, user.Id)
		logger.Debugw("Authenticated user.", "username", user.Username, "role", user.UserRole.String(), "method", info.FullMethod)
		return handler(authCtx, req)
	}
}

// RoleFromContext returns the pb.User_Role stored in ctx under ContextKeyRole.
// If no value is stored under that key, or the stored value is not a
// pb.User_Role, it returns pb.User_UNKNOWN.
func RoleFromContext(ctx context.Context) pb.User_Role {
	// A comma-ok assertion on a missing (nil) value simply reports ok == false,
	// so the absent-key and wrong-type cases share one fallback.
	if r, found := ctx.Value(ContextKeyRole).(pb.User_Role); found {
		return r
	}
	return pb.User_UNKNOWN
}

// UserIDFromContext returns the user ID stored in ctx under ContextKeyUserID.
// It returns "" when no value is stored under that key or the value is not a
// string.
func UserIDFromContext(ctx context.Context) string {
	// On a failed assertion the first result is the string zero value "",
	// which is exactly the fallback we want.
	userID, _ := ctx.Value(ContextKeyUserID).(string)
	return userID
}

// RoleAtLeast reports whether the role in ctx, as returned by RoleFromContext,
// is greater than or equal to role. A context without a role is treated as
// pb.User_UNKNOWN.
//
// NOTE(review): this relies on the ordering of pb.User_Role values, presumably
// with higher values meaning more privilege. Confirm against the enum
// definition.
func RoleAtLeast(ctx context.Context, role pb.User_Role) bool {
	return RoleFromContext(ctx) >= role
}