From a1737dd2f69c1283fa7858d4515f52f2127af6b3 Mon Sep 17 00:00:00 2001 From: = Date: Sun, 5 Dec 2021 01:00:32 +0100 Subject: [PATCH] Add certificate store --- config/config.go | 2 +- store/bolt.go | 88 ++++++++++++++++++++++++++++++++++++++++++++ store/bolt_test.go | 12 +++++- store/memory.go | 86 +++++++++++++++++++++++++++++++++++++++---- store/memory_test.go | 8 +++- store/store.go | 10 +++++ store/store_test.go | 77 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 271 insertions(+), 12 deletions(-) diff --git a/config/config.go b/config/config.go index ffe32e6..9928924 100644 --- a/config/config.go +++ b/config/config.go @@ -241,7 +241,7 @@ func (sc *ServerStoreConfig) GetStore() (store.FileStore, func() error, error) { return s, nopCloseFunc, nil } if strings.EqualFold(sc.Type, "memory") { - return store.NewMemoryFileStore(), nopCloseFunc, nil + return store.NewMemoryStore(), nopCloseFunc, nil } return nil, nil, fmt.Errorf("invalid store config") diff --git a/store/bolt.go b/store/bolt.go index e490b96..d1be574 100644 --- a/store/bolt.go +++ b/store/bolt.go @@ -1,6 +1,8 @@ package store import ( + "crypto/ecdsa" + "crypto/x509" "fmt" "gitea.benny.dog/torjus/ezshare/pb" @@ -16,6 +18,8 @@ type BoltStore struct { } var bktKey = []byte("files") +var bktKeyCerts = []byte("certs") +var bktKeyKeys = []byte("keys") func NewBoltStore(path string) (*BoltStore, error) { s := &BoltStore{} @@ -28,6 +32,12 @@ func NewBoltStore(path string) (*BoltStore, error) { if _, err := t.CreateBucketIfNotExists(bktKey); err != nil { return err } + if _, err := t.CreateBucketIfNotExists(bktKeyCerts); err != nil { + return err + } + if _, err := t.CreateBucketIfNotExists(bktKeyKeys); err != nil { + return err + } return nil }) @@ -111,3 +121,81 @@ func (s *BoltStore) ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, error) { } return response, nil } + +// Certificate store +var _ CertificateStore = &BoltStore{} + +func (s *BoltStore) GetCertificate(id string) (*x509.Certificate, error) { + var raw []byte + err := s.db.View(func(t *bolt.Tx) error { + bkt := t.Bucket(bktKeyCerts) + + raw = bkt.Get([]byte(id)) + return nil + }) + if err != nil { + return nil, err + } + if raw == nil { + return nil, ErrNoSuchFile + } + + cert, err := x509.ParseCertificate(raw) + if err != nil { + return nil, fmt.Errorf("unable to parse certificate: %w", err) + } + + return cert, nil +} + +func (s *BoltStore) StoreCertificate(id string, cert *x509.Certificate) error { + data := make([]byte, len(cert.Raw)) + copy(data, cert.Raw) + + return s.db.Update(func(t *bolt.Tx) error { + bkt := t.Bucket(bktKeyCerts) + return bkt.Put([]byte(id), cert.Raw) + }) +} + +func (s *BoltStore) GetKey(id string) (*ecdsa.PrivateKey, error) { + var data []byte + err := s.db.View(func(t *bolt.Tx) error { + bkt := t.Bucket(bktKeyKeys) + + data = bkt.Get([]byte(id)) + return nil + }) + if err != nil { + return nil, err + } + + return x509.ParseECPrivateKey(data) +} + +func (s *BoltStore) StoreKey(id string, key *ecdsa.PrivateKey) error { + data, err := x509.MarshalECPrivateKey(key) + if err != nil { + return fmt.Errorf("unable to marshal key: %w", err) + } + return s.db.Update(func(t *bolt.Tx) error { + bkt := t.Bucket(bktKeyKeys) + return bkt.Put([]byte(id), data) + }) +} + +func (s *BoltStore) ListCertificates() ([]string, error) { + var ids []string + err := s.db.View(func(tx *bolt.Tx) error { + bkt := tx.Bucket(bktKeyCerts) + bkt.ForEach(func(k, v []byte) error { + ids = append(ids, string(k)) + return nil + }) + return nil + }) + if err != nil { + return nil, err + } + return ids, nil +} diff --git a/store/bolt_test.go b/store/bolt_test.go index 4502bab..5c32414 100644 --- a/store/bolt_test.go +++ b/store/bolt_test.go @@ -14,5 +14,15 @@ func TestBoltStore(t *testing.T) { t.Fatalf("Error opening store: %s", err) } doFileStoreTest(s, t) - s.Close() + _ = s.Close() +} + +func TestBoltCertificateStore(t *testing.T) { + path := filepath.Join(t.TempDir(), "boltstore.db") + s, err := store.NewBoltStore(path) + if err != nil { + t.Fatalf("Error opening store: %s", err) + } + doCertificateStoreTest(s, t) + _ = s.Close() } diff --git a/store/memory.go b/store/memory.go index b571b33..f544885 100644 --- a/store/memory.go +++ b/store/memory.go @@ -1,24 +1,35 @@ package store import ( + "crypto/ecdsa" + "crypto/x509" "sync" "gitea.benny.dog/torjus/ezshare/pb" "github.com/google/uuid" ) -var _ FileStore = &MemoryFileStore{} +var _ FileStore = &MemoryStore{} +var _ CertificateStore = &MemoryStore{} -type MemoryFileStore struct { +type MemoryStore struct { filesLock sync.RWMutex files map[string]*pb.File + certLock sync.RWMutex + certs map[string][]byte + keyLock sync.RWMutex + keys map[string][]byte } -func NewMemoryFileStore() *MemoryFileStore { - return &MemoryFileStore{files: make(map[string]*pb.File)} +func NewMemoryStore() *MemoryStore { + return &MemoryStore{ + files: make(map[string]*pb.File), + certs: make(map[string][]byte), + keys: make(map[string][]byte), + } } -func (s *MemoryFileStore) GetFile(id string) (*pb.File, error) { +func (s *MemoryStore) GetFile(id string) (*pb.File, error) { s.filesLock.RLock() defer s.filesLock.RUnlock() @@ -29,7 +40,7 @@ func (s *MemoryFileStore) GetFile(id string) (*pb.File, error) { return nil, ErrNoSuchFile } -func (s *MemoryFileStore) StoreFile(file *pb.File) (string, error) { +func (s *MemoryStore) StoreFile(file *pb.File) (string, error) { s.filesLock.Lock() defer s.filesLock.Unlock() @@ -41,7 +52,7 @@ func (s *MemoryFileStore) StoreFile(file *pb.File) (string, error) { return id, nil } -func (s *MemoryFileStore) DeleteFile(id string) error { +func (s *MemoryStore) DeleteFile(id string) error { s.filesLock.Lock() defer s.filesLock.Unlock() if _, ok := s.files[id]; !ok { @@ -52,7 +63,7 @@ func (s *MemoryFileStore) DeleteFile(id string) error { return nil } -func (s *MemoryFileStore) ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, error) { +func (s *MemoryStore) ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, error) { s.filesLock.RLock() defer s.filesLock.RUnlock() @@ -64,3 +75,62 @@ func (s *MemoryFileStore) ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, err return response, nil } + +func (s *MemoryStore) GetCertificate(id string) (*x509.Certificate, error) { + s.certLock.Lock() + defer s.certLock.Unlock() + + data, ok := s.certs[id] + if !ok { + // TODO: Make separate error, or rename error + return nil, ErrNoSuchFile + } + + return x509.ParseCertificate(data) +} + +func (s *MemoryStore) StoreCertificate(id string, cert *x509.Certificate) error { + s.certLock.Lock() + defer s.certLock.Unlock() + + // Copy cert data + data := make([]byte, len(cert.Raw)) + copy(data, cert.Raw) + + s.certs[id] = data + return nil +} + +func (s *MemoryStore) GetKey(id string) (*ecdsa.PrivateKey, error) { + s.keyLock.RLock() + defer s.keyLock.RUnlock() + data, ok := s.keys[id] + if !ok { + return nil, ErrNoSuchFile + } + + return x509.ParseECPrivateKey(data) +} + +func (s *MemoryStore) StoreKey(id string, key *ecdsa.PrivateKey) error { + s.keyLock.Lock() + defer s.keyLock.Unlock() + + data, err := x509.MarshalECPrivateKey(key) + if err != nil { + return err + } + + s.keys[id] = data + return nil +} + +func (s *MemoryStore) ListCertificates() ([]string, error) { + s.certLock.RLock() + defer s.certLock.RUnlock() + var certIDs []string + for key := range s.certs { + certIDs = append(certIDs, key) + } + return certIDs, nil +} diff --git a/store/memory_test.go b/store/memory_test.go index fcd0091..4e742ae 100644 --- a/store/memory_test.go +++ b/store/memory_test.go @@ -6,7 +6,11 @@ import ( "gitea.benny.dog/torjus/ezshare/store" ) -func TestMemoryStore(t *testing.T) { - s := store.NewMemoryFileStore() +func TestMemoryFileStore(t *testing.T) { + s := store.NewMemoryStore() doFileStoreTest(s, t) } +func TestMemoryCertificateStore(t *testing.T) { + s := store.NewMemoryStore() + doCertificateStoreTest(s, t) +} diff --git a/store/store.go b/store/store.go index 5fab14b..f771ada 100644 --- a/store/store.go +++ b/store/store.go @@ -1,6 +1,8 @@ package store import ( + "crypto/ecdsa" + "crypto/x509" "fmt" "gitea.benny.dog/torjus/ezshare/pb" @@ -14,3 +16,11 @@ type FileStore interface { DeleteFile(id string) error ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, error) } + +type CertificateStore interface { + GetCertificate(id string) (*x509.Certificate, error) + StoreCertificate(id string, cert *x509.Certificate) error + GetKey(id string) (*ecdsa.PrivateKey, error) + StoreKey(id string, key *ecdsa.PrivateKey) error + ListCertificates() ([]string, error) +} diff --git a/store/store_test.go b/store/store_test.go index 656ee7f..bf138d2 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -1,6 +1,12 @@ package store_test import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "math/big" "testing" "time" @@ -63,3 +69,74 @@ func doFileStoreTest(s store.FileStore, t *testing.T) { } }) } + +func doCertificateStoreTest(s store.CertificateStore, t *testing.T) { + t.Run("Basic", func(t *testing.T) { + + // Create cert and key + unsigned := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().Unix()), + Subject: pkix.Name{ + Organization: []string{"ezshare"}, + Country: []string{"No"}, + Locality: []string{"Oslo"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("unable to create private key: %s", err) + } + certBytes, err := x509.CreateCertificate(rand.Reader, unsigned, unsigned, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("Error creating cert: %s", err) + } + + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + t.Fatalf("Error parsing created certificate: %s", err) + } + + // Store cert + if err := s.StoreCertificate("cert", cert); err != nil { + t.Fatalf("Error storing cert: %s", err) + } + + // Store key + if err := s.StoreKey("key", privateKey); err != nil { + t.Fatalf("Error storing key: %s", err) + } + + // List + ids, err := s.ListCertificates() + if err != nil { + t.Fatalf("List returned error: %s", err) + } + if len(ids) != 1 { + t.Fatalf("List has wrong length: %s", err) + } + if ids[0] != "cert" { + t.Fatalf("List has wrong id") + } + + retrievedCert, err := s.GetCertificate("cert") + if err != nil { + t.Fatalf("Unable to get certificate from store: %s", err) + } + if !retrievedCert.Equal(cert) { + t.Errorf("Retrieved certificate does not match stored.") + } + + retrievedKey, err := s.GetKey("key") + if err != nil { + t.Fatalf("Unable to get key from store: %s", err) + } + if !retrievedKey.Equal(privateKey) { + t.Errorf("Retrieved key does not match stored.") + } + }) +}