Add user store
This commit is contained in:
@@ -20,6 +20,7 @@ type BoltStore struct {
|
||||
var bktKey = []byte("files")
|
||||
var bktKeyCerts = []byte("certs")
|
||||
var bktKeyKeys = []byte("keys")
|
||||
var bktKeyUsers = []byte("users")
|
||||
|
||||
func NewBoltStore(path string) (*BoltStore, error) {
|
||||
s := &BoltStore{}
|
||||
@@ -38,7 +39,9 @@ func NewBoltStore(path string) (*BoltStore, error) {
|
||||
if _, err := t.CreateBucketIfNotExists(bktKeyKeys); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := t.CreateBucketIfNotExists(bktKeyUsers); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -188,14 +191,53 @@ func (s *BoltStore) ListCertificates() ([]string, error) {
|
||||
var ids []string
|
||||
err := s.db.View(func(tx *bolt.Tx) error {
|
||||
bkt := tx.Bucket(bktKeyCerts)
|
||||
bkt.ForEach(func(k, v []byte) error {
|
||||
return bkt.ForEach(func(k, v []byte) error {
|
||||
ids = append(ids, string(k))
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (s *BoltStore) StoreUser(user *pb.User) error {
|
||||
return s.db.Update(func(tx *bolt.Tx) error {
|
||||
bkt := tx.Bucket(bktKeyUsers)
|
||||
data, err := proto.Marshal(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return bkt.Put([]byte(user.Id), data)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BoltStore) GetUser(id string) (*pb.User, error) {
|
||||
var data []byte
|
||||
err := s.db.View(func(tx *bolt.Tx) error {
|
||||
bkt := tx.Bucket(bktKeyUsers)
|
||||
data = bkt.Get([]byte(id))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user pb.User
|
||||
err = proto.Unmarshal(data, &user)
|
||||
return &user, err
|
||||
}
|
||||
|
||||
func (s *BoltStore) ListUsers() ([]string, error) {
|
||||
var ids []string
|
||||
err := s.db.View(func(tx *bolt.Tx) error {
|
||||
bkt := tx.Bucket(bktKeyUsers)
|
||||
return bkt.ForEach(func(k, _ []byte) error {
|
||||
ids = append(ids, string(k))
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
return ids, err
|
||||
}
|
||||
|
@@ -10,19 +10,29 @@ import (
|
||||
func TestBoltStore(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "boltstore.db")
|
||||
s, err := store.NewBoltStore(path)
|
||||
defer s.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening store: %s", err)
|
||||
}
|
||||
doFileStoreTest(s, t)
|
||||
_ = s.Close()
|
||||
}
|
||||
|
||||
func TestBoltCertificateStore(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "boltstore.db")
|
||||
s, err := store.NewBoltStore(path)
|
||||
defer s.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening store: %s", err)
|
||||
}
|
||||
doCertificateStoreTest(s, t)
|
||||
_ = s.Close()
|
||||
}
|
||||
|
||||
func TestBoltUserStore(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "boltstore.db")
|
||||
s, err := store.NewBoltStore(path)
|
||||
defer s.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Error opening store: %s", err)
|
||||
}
|
||||
doUserStoreTests(s, t)
|
||||
}
|
||||
|
@@ -9,9 +9,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var _ FileStore = &MemoryStore{}
|
||||
var _ CertificateStore = &MemoryStore{}
|
||||
|
||||
type MemoryStore struct {
|
||||
filesLock sync.RWMutex
|
||||
files map[string]*pb.File
|
||||
@@ -19,6 +16,8 @@ type MemoryStore struct {
|
||||
certs map[string][]byte
|
||||
keyLock sync.RWMutex
|
||||
keys map[string][]byte
|
||||
usersLock sync.RWMutex
|
||||
users map[string]*pb.User
|
||||
}
|
||||
|
||||
func NewMemoryStore() *MemoryStore {
|
||||
@@ -26,9 +25,16 @@ func NewMemoryStore() *MemoryStore {
|
||||
files: make(map[string]*pb.File),
|
||||
certs: make(map[string][]byte),
|
||||
keys: make(map[string][]byte),
|
||||
users: make(map[string]*pb.User),
|
||||
}
|
||||
}
|
||||
|
||||
///////////////
|
||||
// FileStore //
|
||||
///////////////
|
||||
|
||||
var _ FileStore = &MemoryStore{}
|
||||
|
||||
func (s *MemoryStore) GetFile(id string) (*pb.File, error) {
|
||||
s.filesLock.RLock()
|
||||
defer s.filesLock.RUnlock()
|
||||
@@ -76,6 +82,12 @@ func (s *MemoryStore) ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, error)
|
||||
return response, nil
|
||||
}
|
||||
|
||||
//////////////////////
|
||||
// CertificateStore //
|
||||
//////////////////////
|
||||
|
||||
var _ CertificateStore = &MemoryStore{}
|
||||
|
||||
func (s *MemoryStore) GetCertificate(id string) (*x509.Certificate, error) {
|
||||
s.certLock.Lock()
|
||||
defer s.certLock.Unlock()
|
||||
@@ -134,3 +146,39 @@ func (s *MemoryStore) ListCertificates() ([]string, error) {
|
||||
}
|
||||
return certIDs, nil
|
||||
}
|
||||
|
||||
///////////////
|
||||
// UserStore //
|
||||
///////////////
|
||||
|
||||
var _ UserStore = &MemoryStore{}
|
||||
|
||||
func (s *MemoryStore) StoreUser(user *pb.User) error {
|
||||
s.usersLock.Lock()
|
||||
defer s.usersLock.Unlock()
|
||||
|
||||
s.users[user.Id] = user
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MemoryStore) GetUser(id string) (*pb.User, error) {
|
||||
s.usersLock.RLock()
|
||||
defer s.usersLock.RUnlock()
|
||||
user, ok := s.users[id]
|
||||
if !ok {
|
||||
// TODO: Update error
|
||||
return nil, ErrNoSuchFile
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *MemoryStore) ListUsers() ([]string, error) {
|
||||
s.usersLock.RLock()
|
||||
defer s.usersLock.RUnlock()
|
||||
|
||||
var ids []string
|
||||
for id := range s.users {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
@@ -14,3 +14,8 @@ func TestMemoryCertificateStore(t *testing.T) {
|
||||
s := store.NewMemoryStore()
|
||||
doCertificateStoreTest(s, t)
|
||||
}
|
||||
|
||||
func TestMemoryUserStore(t *testing.T) {
|
||||
s := store.NewMemoryStore()
|
||||
doUserStoreTests(s, t)
|
||||
}
|
||||
|
@@ -24,3 +24,9 @@ type CertificateStore interface {
|
||||
StoreKey(id string, key *ecdsa.PrivateKey) error
|
||||
ListCertificates() ([]string, error)
|
||||
}
|
||||
|
||||
type UserStore interface {
|
||||
StoreUser(user *pb.User) error
|
||||
GetUser(id string) (*pb.User, error)
|
||||
ListUsers() ([]string, error)
|
||||
}
|
||||
|
@@ -6,6 +6,8 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -140,3 +142,39 @@ func doCertificateStoreTest(s store.CertificateStore, t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func doUserStoreTests(s store.UserStore, t *testing.T) {
|
||||
t.Run("Basics", func(t *testing.T) {
|
||||
// Store user
|
||||
user := &pb.User{
|
||||
Id: uuid.Must(uuid.NewRandom()).String(),
|
||||
Username: "testuser",
|
||||
UserRole: pb.User_USER,
|
||||
Active: true,
|
||||
}
|
||||
|
||||
if err := s.StoreUser(user); err != nil {
|
||||
t.Fatalf("Error storing user: %s", err)
|
||||
}
|
||||
|
||||
retrieved, err := s.GetUser(user.Id)
|
||||
if err != nil {
|
||||
t.Fatalf("Retriving user returned error: %s", err)
|
||||
}
|
||||
|
||||
if !proto.Equal(user, retrieved) {
|
||||
t.Fatalf("Retrieved user does not match original")
|
||||
}
|
||||
|
||||
list, err := s.ListUsers()
|
||||
if err != nil {
|
||||
t.Fatalf("Error listing users")
|
||||
}
|
||||
if len(list) != 1 {
|
||||
t.Fatalf("User list wrong length")
|
||||
}
|
||||
if list[0] != user.Id {
|
||||
t.Fatalf("User list has wrong id")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user