package server

import (
	"context"
	"errors"
	"fmt"

	"git.t-juice.club/torjus/ezshare/certs"
	"git.t-juice.club/torjus/ezshare/pb"
	"git.t-juice.club/torjus/ezshare/server/interceptors"
	"git.t-juice.club/torjus/ezshare/store"
	"github.com/google/uuid"
	"go.uber.org/zap"
	"golang.org/x/crypto/bcrypt"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// GRPCUserServiceServer implements the gRPC user service on top of a
// store.UserStore and a certs.CertService.
type GRPCUserServiceServer struct {
	// Logger receives structured log output from the handlers.
	// NewGRPCUserServiceServer initializes it to a no-op logger.
	Logger *zap.SugaredLogger

	store       store.UserStore
	certService *certs.CertService
	pb.UnimplementedUserServiceServer
}

// NewGRPCUserServiceServer returns a GRPCUserServiceServer that uses store
// for user persistence and certSvc for issuing client certificates. The
// returned server logs to a no-op logger until Logger is replaced.
func NewGRPCUserServiceServer(store store.UserStore, certSvc *certs.CertService) *GRPCUserServiceServer {
	return &GRPCUserServiceServer{
		store:       store,
		certService: certSvc,
		Logger:      zap.NewNop().Sugar(),
	}
}

// Register creates a new active user with role pb.User_USER, a freshly
// generated UUID and a bcrypt hash of the requested password.
//
// It returns codes.AlreadyExists if a user with the requested username is
// already stored, and codes.Internal if the store lookup, the password
// hashing or the store write fails. The Token field of the response is
// always empty.
func (s *GRPCUserServiceServer) Register(ctx context.Context, req *pb.RegisterUserRequest) (*pb.RegisterUserResponse, error) {
	// Only a successful lookup means the username is taken. Any error other
	// than ErrNoSuchItem is a store failure and must not be reported to the
	// client as "already exists".
	_, err := s.store.GetUserByUsername(req.Username)
	switch {
	case err == nil:
		return nil, status.Error(codes.AlreadyExists, "user already exists")
	case !errors.Is(err, store.ErrNoSuchItem):
		s.Logger.Warnw("Error retrieving user from store.", "error", err)
		return nil, status.Error(codes.Internal, "error getting user from store")
	}

	pw, err := hashPassword(req.Password)
	if err != nil {
		// Return a status error so the client sees codes.Internal rather
		// than the codes.Unknown gRPC assigns to plain errors.
		s.Logger.Warnw("Error hashing password.", "error", err)
		return nil, status.Error(codes.Internal, "unable to hash password")
	}

	user := &pb.User{
		Id:             uuid.Must(uuid.NewRandom()).String(),
		Username:       req.Username,
		HashedPassword: pw,
		UserRole:       pb.User_USER,
		Active:         true,
	}

	if err := s.store.StoreUser(user); err != nil {
		s.Logger.Warnw("Error storing registered user.", "error", err)
		return nil, status.Error(codes.Internal, fmt.Sprintf("unable to store user: %s", err))
	}

	s.Logger.Infow("Registered new user.", "username", user.Username)

	return &pb.RegisterUserResponse{Id: user.Id, Token: ""}, nil
}

// Login verifies the supplied username and password and, on success, returns
// a client certificate and key generated for the user's ID.
//
// It returns codes.NotFound if no user has the given username,
// codes.Unauthenticated if the password does not match the stored hash, and
// codes.Internal if the store lookup or certificate generation fails.
//
// NOTE(review): the distinct NotFound and Unauthenticated responses let a
// caller tell unknown usernames apart from wrong passwords; confirm this is
// intended before changing it, since clients may rely on the codes.
func (s *GRPCUserServiceServer) Login(_ context.Context, req *pb.LoginUserRequest) (*pb.LoginUserResponse, error) {
	user, err := s.store.GetUserByUsername(req.Username)
	if err != nil {
		if errors.Is(err, store.ErrNoSuchItem) {
			return nil, status.Error(codes.NotFound, "no such user")
		}
		s.Logger.Warnw("Error retrieving user from store.", "error", err)
		return nil, status.Error(codes.Internal, "error getting user from store")
	}

	if err := bcrypt.CompareHashAndPassword(user.HashedPassword, []byte(req.Password)); err != nil {
		return nil, status.Error(codes.Unauthenticated, "wrong username and or password")
	}

	cert, key, err := s.certService.NewClient(user.Id)
	if err != nil {
		s.Logger.Warnw("Error generating client certificate.", "error", err)
		return nil, status.Error(codes.Internal, "unable to generate client certificate")
	}

	resp := &pb.LoginUserResponse{
		ClientCert: cert,
		ClientKey:  key,
	}

	s.Logger.Infow("Logged in user.", "username", user.Username)

	return resp, nil
}

// List is not implemented yet and always returns codes.Unimplemented.
func (s *GRPCUserServiceServer) List(_ context.Context, _ *pb.ListUsersRequest) (*pb.ListUsersResponse, error) {
	return nil, status.Error(codes.Unimplemented, "not yet implemented")
}

// Approve is not implemented yet and always returns codes.Unimplemented.
func (s *GRPCUserServiceServer) Approve(_ context.Context, _ *pb.ApproveUserRequest) (*pb.Empty, error) {
	return nil, status.Error(codes.Unimplemented, "not yet implemented")
}

// ChangePassword replaces the password of the user whose ID is carried in
// ctx, after verifying req.OldPassword against the stored hash.
//
// It returns codes.Unauthenticated if ctx carries no user ID, if the user
// cannot be loaded from the store, or if the old password does not match,
// and codes.Internal if hashing or storing the new password fails.
func (s *GRPCUserServiceServer) ChangePassword(ctx context.Context, req *pb.ChangePasswordRequest) (*pb.Empty, error) {
	// The user ID is read from the request context.
	userID := interceptors.UserIDFromContext(ctx)
	if userID == "" {
		return nil, status.Error(codes.Unauthenticated, "not authenticated")
	}

	user, err := s.store.GetUser(userID)
	if err != nil {
		return nil, status.Error(codes.Unauthenticated, "user not found")
	}

	if err := bcrypt.CompareHashAndPassword(user.HashedPassword, []byte(req.OldPassword)); err != nil {
		return nil, status.Error(codes.Unauthenticated, "wrong password")
	}

	// Use the same helper as Register so both paths hash identically.
	newPasswordHash, err := hashPassword(req.NewPassword)
	if err != nil {
		return nil, status.Error(codes.Internal, "unable to hash new password")
	}

	user.HashedPassword = newPasswordHash

	if err := s.store.StoreUser(user); err != nil {
		s.Logger.Warnw("Error storing user with new password.", "error", err)
		return nil, status.Error(codes.Internal, "unable to store new password")
	}

	s.Logger.Infow("Set new password for user.", "username", user.Username)

	return &pb.Empty{}, nil
}

// hashPassword returns the bcrypt hash of password computed with
// bcrypt.DefaultCost.
func hashPassword(password string) ([]byte, error) {
	return bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
}