Add user store

This commit is contained in:
Torjus Håkestad 2021-12-05 11:08:09 +01:00
parent fa32f76a61
commit be230233dc
9 changed files with 1488 additions and 164 deletions

File diff suppressed because it is too large Load Diff

View File

@ -207,3 +207,197 @@ var FileService_ServiceDesc = grpc.ServiceDesc{
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "protos/ezshare.proto", Metadata: "protos/ezshare.proto",
} }
// UserServiceClient is the client API for UserService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type UserServiceClient interface {
Register(ctx context.Context, in *RegisterUserRequest, opts ...grpc.CallOption) (*RegisterUserResponse, error)
Login(ctx context.Context, in *LoginUserRequest, opts ...grpc.CallOption) (*LoginUserResponse, error)
List(ctx context.Context, in *ListUsersRequest, opts ...grpc.CallOption) (*ListUsersResponse, error)
Approve(ctx context.Context, in *ApproveUserRequest, opts ...grpc.CallOption) (*Empty, error)
}
type userServiceClient struct {
cc grpc.ClientConnInterface
}
func NewUserServiceClient(cc grpc.ClientConnInterface) UserServiceClient {
return &userServiceClient{cc}
}
func (c *userServiceClient) Register(ctx context.Context, in *RegisterUserRequest, opts ...grpc.CallOption) (*RegisterUserResponse, error) {
out := new(RegisterUserResponse)
err := c.cc.Invoke(ctx, "/ezshare.UserService/Register", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *userServiceClient) Login(ctx context.Context, in *LoginUserRequest, opts ...grpc.CallOption) (*LoginUserResponse, error) {
out := new(LoginUserResponse)
err := c.cc.Invoke(ctx, "/ezshare.UserService/Login", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *userServiceClient) List(ctx context.Context, in *ListUsersRequest, opts ...grpc.CallOption) (*ListUsersResponse, error) {
out := new(ListUsersResponse)
err := c.cc.Invoke(ctx, "/ezshare.UserService/List", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *userServiceClient) Approve(ctx context.Context, in *ApproveUserRequest, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/ezshare.UserService/Approve", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// UserServiceServer is the server API for UserService service.
// All implementations must embed UnimplementedUserServiceServer
// for forward compatibility
type UserServiceServer interface {
Register(context.Context, *RegisterUserRequest) (*RegisterUserResponse, error)
Login(context.Context, *LoginUserRequest) (*LoginUserResponse, error)
List(context.Context, *ListUsersRequest) (*ListUsersResponse, error)
Approve(context.Context, *ApproveUserRequest) (*Empty, error)
mustEmbedUnimplementedUserServiceServer()
}
// UnimplementedUserServiceServer must be embedded to have forward compatible implementations.
type UnimplementedUserServiceServer struct {
}
func (UnimplementedUserServiceServer) Register(context.Context, *RegisterUserRequest) (*RegisterUserResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Register not implemented")
}
func (UnimplementedUserServiceServer) Login(context.Context, *LoginUserRequest) (*LoginUserResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Login not implemented")
}
func (UnimplementedUserServiceServer) List(context.Context, *ListUsersRequest) (*ListUsersResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method List not implemented")
}
func (UnimplementedUserServiceServer) Approve(context.Context, *ApproveUserRequest) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method Approve not implemented")
}
func (UnimplementedUserServiceServer) mustEmbedUnimplementedUserServiceServer() {}
// UnsafeUserServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to UserServiceServer will
// result in compilation errors.
type UnsafeUserServiceServer interface {
mustEmbedUnimplementedUserServiceServer()
}
func RegisterUserServiceServer(s grpc.ServiceRegistrar, srv UserServiceServer) {
s.RegisterService(&UserService_ServiceDesc, srv)
}
func _UserService_Register_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RegisterUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserServiceServer).Register(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/ezshare.UserService/Register",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserServiceServer).Register(ctx, req.(*RegisterUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _UserService_Login_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(LoginUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserServiceServer).Login(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/ezshare.UserService/Login",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserServiceServer).Login(ctx, req.(*LoginUserRequest))
}
return interceptor(ctx, in, info, handler)
}
func _UserService_List_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListUsersRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserServiceServer).List(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/ezshare.UserService/List",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserServiceServer).List(ctx, req.(*ListUsersRequest))
}
return interceptor(ctx, in, info, handler)
}
func _UserService_Approve_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ApproveUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(UserServiceServer).Approve(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/ezshare.UserService/Approve",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(UserServiceServer).Approve(ctx, req.(*ApproveUserRequest))
}
return interceptor(ctx, in, info, handler)
}
// UserService_ServiceDesc is the grpc.ServiceDesc for UserService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var UserService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "ezshare.UserService",
HandlerType: (*UserServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Register",
Handler: _UserService_Register_Handler,
},
{
MethodName: "Login",
Handler: _UserService_Login_Handler,
},
{
MethodName: "List",
Handler: _UserService_List_Handler,
},
{
MethodName: "Approve",
Handler: _UserService_Approve_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "protos/ezshare.proto",
}

View File

@ -1,11 +1,18 @@
syntax = "proto3"; syntax = "proto3";
option go_package = "gitea.benny.dog/torjus/ezshare/pb";
package ezshare; package ezshare;
import "google/protobuf/timestamp.proto"; import "google/protobuf/timestamp.proto";
option go_package = "gitea.benny.dog/torjus/ezshare/pb";
/////////////////////
// Common messages //
/////////////////////
message Empty {}
////////////////////////
// FILE RELATED STUFF //
////////////////////////
message File { message File {
string file_id = 1; string file_id = 1;
bytes data = 2; bytes data = 2;
@ -62,4 +69,71 @@ service FileService {
rpc GetFile(GetFileRequest) returns (GetFileResponse) {} rpc GetFile(GetFileRequest) returns (GetFileResponse) {}
rpc DeleteFile(DeleteFileRequest) returns (DeleteFileResponse) {} rpc DeleteFile(DeleteFileRequest) returns (DeleteFileResponse) {}
rpc ListFiles(ListFilesRequest) returns (ListFilesResponse) {} rpc ListFiles(ListFilesRequest) returns (ListFilesResponse) {}
} }
////////////////////////
// USER RELATED STUFF //
////////////////////////
message User {
string id = 1;
string username = 2;
bytes hashed_password = 3;
enum Role {
UNAPPROVED = 0;
VIEWONLY = 1;
USER = 2;
ADMIN = 3;
}
Role user_role = 4;
bool active = 5;
}
// Register
message RegisterUserRequest {
string username = 1;
string password = 2;
}
message RegisterUserResponse {
string id = 1;
string token = 2;
}
// Login
message LoginUserRequest {
message UserPasswordLogin {
string username = 1;
string password = 2;
}
message TokenLogin {
string token = 1;
}
oneof requested_login {
TokenLogin with_token = 1;
UserPasswordLogin with_password = 2;
}
}
message LoginUserResponse {
bytes server_cert = 1;
bytes client_cert = 2;
bytes client_key = 3;
}
// List
message ListUsersRequest {
}
message ListUsersResponse {
repeated User users = 1;
}
// Approve
message ApproveUserRequest {
string user_id = 1;
}
service UserService {
rpc Register(RegisterUserRequest) returns (RegisterUserResponse) {}
rpc Login(LoginUserRequest) returns (LoginUserResponse) {}
rpc List(ListUsersRequest) returns (ListUsersResponse) {}
rpc Approve(ApproveUserRequest) returns (Empty) {}
}

View File

@ -20,6 +20,7 @@ type BoltStore struct {
var bktKey = []byte("files") var bktKey = []byte("files")
var bktKeyCerts = []byte("certs") var bktKeyCerts = []byte("certs")
var bktKeyKeys = []byte("keys") var bktKeyKeys = []byte("keys")
var bktKeyUsers = []byte("users")
func NewBoltStore(path string) (*BoltStore, error) { func NewBoltStore(path string) (*BoltStore, error) {
s := &BoltStore{} s := &BoltStore{}
@ -38,7 +39,9 @@ func NewBoltStore(path string) (*BoltStore, error) {
if _, err := t.CreateBucketIfNotExists(bktKeyKeys); err != nil { if _, err := t.CreateBucketIfNotExists(bktKeyKeys); err != nil {
return err return err
} }
if _, err := t.CreateBucketIfNotExists(bktKeyUsers); err != nil {
return err
}
return nil return nil
}) })
if err != nil { if err != nil {
@ -188,14 +191,53 @@ func (s *BoltStore) ListCertificates() ([]string, error) {
var ids []string var ids []string
err := s.db.View(func(tx *bolt.Tx) error { err := s.db.View(func(tx *bolt.Tx) error {
bkt := tx.Bucket(bktKeyCerts) bkt := tx.Bucket(bktKeyCerts)
bkt.ForEach(func(k, v []byte) error { return bkt.ForEach(func(k, v []byte) error {
ids = append(ids, string(k)) ids = append(ids, string(k))
return nil return nil
}) })
return nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ids, nil return ids, nil
} }
func (s *BoltStore) StoreUser(user *pb.User) error {
return s.db.Update(func(tx *bolt.Tx) error {
bkt := tx.Bucket(bktKeyUsers)
data, err := proto.Marshal(user)
if err != nil {
return err
}
return bkt.Put([]byte(user.Id), data)
})
}
func (s *BoltStore) GetUser(id string) (*pb.User, error) {
var data []byte
err := s.db.View(func(tx *bolt.Tx) error {
bkt := tx.Bucket(bktKeyUsers)
data = bkt.Get([]byte(id))
return nil
})
if err != nil {
return nil, err
}
var user pb.User
err = proto.Unmarshal(data, &user)
return &user, err
}
func (s *BoltStore) ListUsers() ([]string, error) {
var ids []string
err := s.db.View(func(tx *bolt.Tx) error {
bkt := tx.Bucket(bktKeyUsers)
return bkt.ForEach(func(k, _ []byte) error {
ids = append(ids, string(k))
return nil
})
})
return ids, err
}

View File

@ -10,19 +10,29 @@ import (
func TestBoltStore(t *testing.T) { func TestBoltStore(t *testing.T) {
path := filepath.Join(t.TempDir(), "boltstore.db") path := filepath.Join(t.TempDir(), "boltstore.db")
s, err := store.NewBoltStore(path) s, err := store.NewBoltStore(path)
defer s.Close()
if err != nil { if err != nil {
t.Fatalf("Error opening store: %s", err) t.Fatalf("Error opening store: %s", err)
} }
doFileStoreTest(s, t) doFileStoreTest(s, t)
_ = s.Close()
} }
func TestBoltCertificateStore(t *testing.T) { func TestBoltCertificateStore(t *testing.T) {
path := filepath.Join(t.TempDir(), "boltstore.db") path := filepath.Join(t.TempDir(), "boltstore.db")
s, err := store.NewBoltStore(path) s, err := store.NewBoltStore(path)
defer s.Close()
if err != nil { if err != nil {
t.Fatalf("Error opening store: %s", err) t.Fatalf("Error opening store: %s", err)
} }
doCertificateStoreTest(s, t) doCertificateStoreTest(s, t)
_ = s.Close() }
func TestBoltUserStore(t *testing.T) {
path := filepath.Join(t.TempDir(), "boltstore.db")
s, err := store.NewBoltStore(path)
defer s.Close()
if err != nil {
t.Fatalf("Error opening store: %s", err)
}
doUserStoreTests(s, t)
} }

View File

@ -9,9 +9,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
var _ FileStore = &MemoryStore{}
var _ CertificateStore = &MemoryStore{}
type MemoryStore struct { type MemoryStore struct {
filesLock sync.RWMutex filesLock sync.RWMutex
files map[string]*pb.File files map[string]*pb.File
@ -19,6 +16,8 @@ type MemoryStore struct {
certs map[string][]byte certs map[string][]byte
keyLock sync.RWMutex keyLock sync.RWMutex
keys map[string][]byte keys map[string][]byte
usersLock sync.RWMutex
users map[string]*pb.User
} }
func NewMemoryStore() *MemoryStore { func NewMemoryStore() *MemoryStore {
@ -26,9 +25,16 @@ func NewMemoryStore() *MemoryStore {
files: make(map[string]*pb.File), files: make(map[string]*pb.File),
certs: make(map[string][]byte), certs: make(map[string][]byte),
keys: make(map[string][]byte), keys: make(map[string][]byte),
users: make(map[string]*pb.User),
} }
} }
///////////////
// FileStore //
///////////////
var _ FileStore = &MemoryStore{}
func (s *MemoryStore) GetFile(id string) (*pb.File, error) { func (s *MemoryStore) GetFile(id string) (*pb.File, error) {
s.filesLock.RLock() s.filesLock.RLock()
defer s.filesLock.RUnlock() defer s.filesLock.RUnlock()
@ -76,6 +82,12 @@ func (s *MemoryStore) ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, error)
return response, nil return response, nil
} }
//////////////////////
// CertificateStore //
//////////////////////
var _ CertificateStore = &MemoryStore{}
func (s *MemoryStore) GetCertificate(id string) (*x509.Certificate, error) { func (s *MemoryStore) GetCertificate(id string) (*x509.Certificate, error) {
s.certLock.Lock() s.certLock.Lock()
defer s.certLock.Unlock() defer s.certLock.Unlock()
@ -134,3 +146,39 @@ func (s *MemoryStore) ListCertificates() ([]string, error) {
} }
return certIDs, nil return certIDs, nil
} }
///////////////
// UserStore //
///////////////
var _ UserStore = &MemoryStore{}
func (s *MemoryStore) StoreUser(user *pb.User) error {
s.usersLock.Lock()
defer s.usersLock.Unlock()
s.users[user.Id] = user
return nil
}
func (s *MemoryStore) GetUser(id string) (*pb.User, error) {
s.usersLock.RLock()
defer s.usersLock.RUnlock()
user, ok := s.users[id]
if !ok {
// TODO: Update error
return nil, ErrNoSuchFile
}
return user, nil
}
func (s *MemoryStore) ListUsers() ([]string, error) {
s.usersLock.RLock()
defer s.usersLock.RUnlock()
var ids []string
for id := range s.users {
ids = append(ids, id)
}
return ids, nil
}

View File

@ -14,3 +14,8 @@ func TestMemoryCertificateStore(t *testing.T) {
s := store.NewMemoryStore() s := store.NewMemoryStore()
doCertificateStoreTest(s, t) doCertificateStoreTest(s, t)
} }
func TestMemoryUserStore(t *testing.T) {
s := store.NewMemoryStore()
doUserStoreTests(s, t)
}

View File

@ -24,3 +24,9 @@ type CertificateStore interface {
StoreKey(id string, key *ecdsa.PrivateKey) error StoreKey(id string, key *ecdsa.PrivateKey) error
ListCertificates() ([]string, error) ListCertificates() ([]string, error)
} }
type UserStore interface {
StoreUser(user *pb.User) error
GetUser(id string) (*pb.User, error)
ListUsers() ([]string, error)
}

View File

@ -6,6 +6,8 @@ import (
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"github.com/google/uuid"
"google.golang.org/protobuf/proto"
"math/big" "math/big"
"testing" "testing"
"time" "time"
@ -140,3 +142,39 @@ func doCertificateStoreTest(s store.CertificateStore, t *testing.T) {
} }
}) })
} }
func doUserStoreTests(s store.UserStore, t *testing.T) {
t.Run("Basics", func(t *testing.T) {
// Store user
user := &pb.User{
Id: uuid.Must(uuid.NewRandom()).String(),
Username: "testuser",
UserRole: pb.User_USER,
Active: true,
}
if err := s.StoreUser(user); err != nil {
t.Fatalf("Error storing user: %s", err)
}
retrieved, err := s.GetUser(user.Id)
if err != nil {
t.Fatalf("Retriving user returned error: %s", err)
}
if !proto.Equal(user, retrieved) {
t.Fatalf("Retrieved user does not match original")
}
list, err := s.ListUsers()
if err != nil {
t.Fatalf("Error listing users")
}
if len(list) != 1 {
t.Fatalf("User list wrong length")
}
if list[0] != user.Id {
t.Fatalf("User list has wrong id")
}
})
}