diff --git a/certs/certservice.go b/certs/certservice.go new file mode 100644 index 0000000..5a2e1fa --- /dev/null +++ b/certs/certservice.go @@ -0,0 +1,133 @@ +package certs + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "gitea.benny.dog/torjus/ezshare/store" + "math/big" + "time" +) + +type CertService struct { + caCert *x509.Certificate + caKey crypto.Signer + store store.CertificateStore +} + +func NewCertService(s store.CertificateStore, certBytes, keyBytes []byte) (*CertService, error) { + // Try to decode key as PEM + keyBlock, _ := pem.Decode(keyBytes) + if keyBlock != nil { + if keyBlock.Type != "EC PRIVATE KEY" { + return nil, fmt.Errorf("private key is not of type EC PRIVATE KEY: %s", keyBlock.Type) + } + keyBytes = keyBlock.Bytes + } + // Try to decode cert as PEM + certBlock, _ := pem.Decode(certBytes) + if certBlock != nil { + if certBlock.Type != "CERTIFICATE" { + return nil, fmt.Errorf("certificate is not of type CERTIFICATE: %s", certBlock.Type) + } + certBytes = certBlock.Bytes + } + + caCert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, fmt.Errorf("unable to parse certificate: %w", err) + } + + if !caCert.IsCA { + return nil, fmt.Errorf("certificate is not CA") + } + + caKey, err := x509.ParseECPrivateKey(keyBytes) + if err != nil { + return nil, fmt.Errorf("unable to parse private key: %w", err) + } + + return &CertService{caCert: caCert, caKey: caKey, store: s}, nil +} + +func (cs *CertService) NewClient(id string) ([]byte, []byte, error) { + cert := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().Unix()), + Subject: pkix.Name{ + CommonName: id, + Organization: []string{"ezshare"}, + Country: []string{"No"}, + Locality: []string{"Oslo"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + + certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + certPrivKeyBytes, err := x509.MarshalECPrivateKey(certPrivKey) + if err != nil { + return nil, nil, err + } + + certBytes, err := x509.CreateCertificate(rand.Reader, cert, cs.caCert, &certPrivKey.PublicKey, cs.caKey) + if err != nil { + return nil, nil, err + } + + keyPEM := new(bytes.Buffer) + if err := pem.Encode(keyPEM, &pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: certPrivKeyBytes, + }); err != nil { + return nil, nil, fmt.Errorf("unable to encode client private key: %w", err) + } + certPEM := new(bytes.Buffer) + if err := pem.Encode(certPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }); err != nil { + return nil, nil, fmt.Errorf("unable to encode client private key: %w", err) + } + + return certPEM.Bytes(), keyPEM.Bytes(), nil +} + +func (cs *CertService) VerifyClient(certBytes []byte) (string, error) { + + // Try to decode cert as PEM + certBlock, _ := pem.Decode(certBytes) + if certBlock != nil { + if certBlock.Type != "CERTIFICATE" { + return "", fmt.Errorf("certificate is not of type CERTIFICATE: %s", certBlock.Type) + } + certBytes = certBlock.Bytes + } + + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return "", fmt.Errorf("unable to parse certificate: %w", err) + } + + rootPool := x509.NewCertPool() + rootPool.AddCert(cs.caCert) + + if _, err := cert.Verify(x509.VerifyOptions{ + Roots: rootPool, + }); err != nil { + return "", fmt.Errorf("unable to verify: %w", err) + } + + return cert.Subject.CommonName, nil +} diff --git a/certs/certservice_test.go b/certs/certservice_test.go new file mode 100644 index 0000000..51f8fd2 --- /dev/null +++ b/certs/certservice_test.go @@ -0,0 +1,84 @@ +package certs_test + +import ( + "crypto/x509" + "encoding/pem" + "gitea.benny.dog/torjus/ezshare/certs" + "gitea.benny.dog/torjus/ezshare/store" + "github.com/google/uuid" + "testing" +) + +func TestCertService(t *testing.T) { + t.Run("TestManualVerifyClientCertificate", func(t *testing.T) { + + s := store.NewMemoryStore() + + caKeyBytes, caCertBytes, err := certs.GenCACert() + if err != nil { + t.Fatalf("Error generating ca cert: %s", err) + } + + svc, err := certs.NewCertService(s, caCertBytes, caKeyBytes) + if err != nil { + t.Fatalf("Unable to create service: %s", err) + } + + clientCertPEM, _, err := svc.NewClient("test") + if err != nil { + t.Fatalf("Unable to create client certificate: %s", err) + } + + caCert, err := x509.ParseCertificate(caCertBytes) + if err != nil { + t.Fatalf("Unable to parse CA certificate: %s", err) + } + certPool := x509.NewCertPool() + certPool.AddCert(caCert) + + clientCertPEMBlock, _ := pem.Decode(clientCertPEM) + if clientCertPEMBlock == nil { + t.Fatalf("Client does not contain PEM-encoded data") + } + if clientCertPEMBlock.Type != "CERTIFICATE" { + t.Fatal("Client cert is not certificate") + } + + clientCert, err := x509.ParseCertificate(clientCertPEMBlock.Bytes) + if err != nil { + t.Fatalf("Could not parse client certificate: %s", err) + } + + if _, err := clientCert.Verify(x509.VerifyOptions{Roots: certPool}); err != nil { + t.Fatalf("Could not verify client certificate: %s", err) + } + }) + t.Run("TestVerifyClientCertificate", func(t *testing.T) { + + s := store.NewMemoryStore() + + caKeyBytes, caCertBytes, err := certs.GenCACert() + if err != nil { + t.Fatalf("Error generating ca cert: %s", err) + } + + svc, err := certs.NewCertService(s, caCertBytes, caKeyBytes) + if err != nil { + t.Fatalf("Unable to create service: %s", err) + } + + clientID := uuid.Must(uuid.NewRandom()).String() + clientCertPEM, _, err := svc.NewClient(clientID) + if err != nil { + t.Fatalf("Unable to create client certificate: %s", err) + } + + id, err := svc.VerifyClient(clientCertPEM) + if err != nil { + t.Fatalf("Failed to verify certificate: %s", err) + } + if id != clientID { + t.Fatalf("Verify returned wrong id. Got %s want %s", id, clientID) + } + }) +} diff --git a/certs/generate.go b/certs/generate.go index 8d698ac..05ae11b 100644 --- a/certs/generate.go +++ b/certs/generate.go @@ -142,7 +142,7 @@ func GenCert(caPub, caPrivKey []byte, dnsNames []string) (priv, pub []byte, err } cert := &x509.Certificate{ - SerialNumber: big.NewInt(1658), + SerialNumber: big.NewInt(time.Now().Unix()), Subject: pkix.Name{ Organization: []string{"ezshare"}, Country: []string{"No"}, diff --git a/config/config.go b/config/config.go index ffe32e6..9928924 100644 --- a/config/config.go +++ b/config/config.go @@ -241,7 +241,7 @@ func (sc *ServerStoreConfig) GetStore() (store.FileStore, func() error, error) { return s, nopCloseFunc, nil } if strings.EqualFold(sc.Type, "memory") { - return store.NewMemoryFileStore(), nopCloseFunc, nil + return store.NewMemoryStore(), nopCloseFunc, nil } return nil, nil, fmt.Errorf("invalid store config") diff --git a/store/bolt.go b/store/bolt.go index e490b96..d1be574 100644 --- a/store/bolt.go +++ b/store/bolt.go @@ -1,6 +1,8 @@ package store import ( + "crypto/ecdsa" + "crypto/x509" "fmt" "gitea.benny.dog/torjus/ezshare/pb" @@ -16,6 +18,8 @@ type BoltStore struct { } var bktKey = []byte("files") +var bktKeyCerts = []byte("certs") +var bktKeyKeys = []byte("keys") func NewBoltStore(path string) (*BoltStore, error) { s := &BoltStore{} @@ -28,6 +32,12 @@ func NewBoltStore(path string) (*BoltStore, error) { if _, err := t.CreateBucketIfNotExists(bktKey); err != nil { return err } + if _, err := t.CreateBucketIfNotExists(bktKeyCerts); err != nil { + return err + } + if _, err := t.CreateBucketIfNotExists(bktKeyKeys); err != nil { + return err + } return nil }) @@ -111,3 +121,81 @@ func (s *BoltStore) ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, error) { } return response, nil } + +// Certificate store +var _ CertificateStore = &BoltStore{} + +func (s *BoltStore) GetCertificate(id string) (*x509.Certificate, error) { + var raw []byte + err := s.db.View(func(t *bolt.Tx) error { + bkt := t.Bucket(bktKeyCerts) + + raw = bkt.Get([]byte(id)) + return nil + }) + if err != nil { + return nil, err + } + if raw == nil { + return nil, ErrNoSuchFile + } + + cert, err := x509.ParseCertificate(raw) + if err != nil { + return nil, fmt.Errorf("unable to parse certificate: %w", err) + } + + return cert, nil +} + +func (s *BoltStore) StoreCertificate(id string, cert *x509.Certificate) error { + data := make([]byte, len(cert.Raw)) + copy(data, cert.Raw) + + return s.db.Update(func(t *bolt.Tx) error { + bkt := t.Bucket(bktKeyCerts) + return bkt.Put([]byte(id), cert.Raw) + }) +} + +func (s *BoltStore) GetKey(id string) (*ecdsa.PrivateKey, error) { + var data []byte + err := s.db.View(func(t *bolt.Tx) error { + bkt := t.Bucket(bktKeyKeys) + + data = bkt.Get([]byte(id)) + return nil + }) + if err != nil { + return nil, err + } + + return x509.ParseECPrivateKey(data) +} + +func (s *BoltStore) StoreKey(id string, key *ecdsa.PrivateKey) error { + data, err := x509.MarshalECPrivateKey(key) + if err != nil { + return fmt.Errorf("unable to marshal key: %w", err) + } + return s.db.Update(func(t *bolt.Tx) error { + bkt := t.Bucket(bktKeyKeys) + return bkt.Put([]byte(id), data) + }) +} + +func (s *BoltStore) ListCertificates() ([]string, error) { + var ids []string + err := s.db.View(func(tx *bolt.Tx) error { + bkt := tx.Bucket(bktKeyCerts) + bkt.ForEach(func(k, v []byte) error { + ids = append(ids, string(k)) + return nil + }) + return nil + }) + if err != nil { + return nil, err + } + return ids, nil +} diff --git a/store/bolt_test.go b/store/bolt_test.go index 4502bab..5c32414 100644 --- a/store/bolt_test.go +++ b/store/bolt_test.go @@ -14,5 +14,15 @@ func TestBoltStore(t *testing.T) { t.Fatalf("Error opening store: %s", err) } doFileStoreTest(s, t) - s.Close() + _ = s.Close() +} + +func TestBoltCertificateStore(t *testing.T) { + path := filepath.Join(t.TempDir(), "boltstore.db") + s, err := store.NewBoltStore(path) + if err != nil { + t.Fatalf("Error opening store: %s", err) + } + doCertificateStoreTest(s, t) + _ = s.Close() } diff --git a/store/memory.go b/store/memory.go index b571b33..f544885 100644 --- a/store/memory.go +++ b/store/memory.go @@ -1,24 +1,35 @@ package store import ( + "crypto/ecdsa" + "crypto/x509" "sync" "gitea.benny.dog/torjus/ezshare/pb" "github.com/google/uuid" ) -var _ FileStore = &MemoryFileStore{} +var _ FileStore = &MemoryStore{} +var _ CertificateStore = &MemoryStore{} -type MemoryFileStore struct { +type MemoryStore struct { filesLock sync.RWMutex files map[string]*pb.File + certLock sync.RWMutex + certs map[string][]byte + keyLock sync.RWMutex + keys map[string][]byte } -func NewMemoryFileStore() *MemoryFileStore { - return &MemoryFileStore{files: make(map[string]*pb.File)} +func NewMemoryStore() *MemoryStore { + return &MemoryStore{ + files: make(map[string]*pb.File), + certs: make(map[string][]byte), + keys: make(map[string][]byte), + } } -func (s *MemoryFileStore) GetFile(id string) (*pb.File, error) { +func (s *MemoryStore) GetFile(id string) (*pb.File, error) { s.filesLock.RLock() defer s.filesLock.RUnlock() @@ -29,7 +40,7 @@ func (s *MemoryFileStore) GetFile(id string) (*pb.File, error) { return nil, ErrNoSuchFile } -func (s *MemoryFileStore) StoreFile(file *pb.File) (string, error) { +func (s *MemoryStore) StoreFile(file *pb.File) (string, error) { s.filesLock.Lock() defer s.filesLock.Unlock() @@ -41,7 +52,7 @@ func (s *MemoryFileStore) StoreFile(file *pb.File) (string, error) { return id, nil } -func (s *MemoryFileStore) DeleteFile(id string) error { +func (s *MemoryStore) DeleteFile(id string) error { s.filesLock.Lock() defer s.filesLock.Unlock() if _, ok := s.files[id]; !ok { @@ -52,7 +63,7 @@ func (s *MemoryFileStore) DeleteFile(id string) error { return nil } -func (s *MemoryFileStore) ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, error) { +func (s *MemoryStore) ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, error) { s.filesLock.RLock() defer s.filesLock.RUnlock() @@ -64,3 +75,62 @@ func (s *MemoryFileStore) ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, err return response, nil } + +func (s *MemoryStore) GetCertificate(id string) (*x509.Certificate, error) { + s.certLock.Lock() + defer s.certLock.Unlock() + + data, ok := s.certs[id] + if !ok { + // TODO: Make separate error, or rename error + return nil, ErrNoSuchFile + } + + return x509.ParseCertificate(data) +} + +func (s *MemoryStore) StoreCertificate(id string, cert *x509.Certificate) error { + s.certLock.Lock() + defer s.certLock.Unlock() + + // Copy cert data + data := make([]byte, len(cert.Raw)) + copy(data, cert.Raw) + + s.certs[id] = data + return nil +} + +func (s *MemoryStore) GetKey(id string) (*ecdsa.PrivateKey, error) { + s.keyLock.RLock() + defer s.keyLock.RUnlock() + data, ok := s.keys[id] + if !ok { + return nil, ErrNoSuchFile + } + + return x509.ParseECPrivateKey(data) +} + +func (s *MemoryStore) StoreKey(id string, key *ecdsa.PrivateKey) error { + s.keyLock.Lock() + defer s.keyLock.Unlock() + + data, err := x509.MarshalECPrivateKey(key) + if err != nil { + return err + } + + s.keys[id] = data + return nil +} + +func (s *MemoryStore) ListCertificates() ([]string, error) { + s.certLock.RLock() + defer s.certLock.RUnlock() + var certIDs []string + for key := range s.certs { + certIDs = append(certIDs, key) + } + return certIDs, nil +} diff --git a/store/memory_test.go b/store/memory_test.go index fcd0091..4e742ae 100644 --- a/store/memory_test.go +++ b/store/memory_test.go @@ -6,7 +6,11 @@ import ( "gitea.benny.dog/torjus/ezshare/store" ) -func TestMemoryStore(t *testing.T) { - s := store.NewMemoryFileStore() +func TestMemoryFileStore(t *testing.T) { + s := store.NewMemoryStore() doFileStoreTest(s, t) } +func TestMemoryCertificateStore(t *testing.T) { + s := store.NewMemoryStore() + doCertificateStoreTest(s, t) +} diff --git a/store/store.go b/store/store.go index 5fab14b..f771ada 100644 --- a/store/store.go +++ b/store/store.go @@ -1,6 +1,8 @@ package store import ( + "crypto/ecdsa" + "crypto/x509" "fmt" "gitea.benny.dog/torjus/ezshare/pb" @@ -14,3 +16,11 @@ type FileStore interface { DeleteFile(id string) error ListFiles() ([]*pb.ListFilesResponse_ListFileInfo, error) } + +type CertificateStore interface { + GetCertificate(id string) (*x509.Certificate, error) + StoreCertificate(id string, cert *x509.Certificate) error + GetKey(id string) (*ecdsa.PrivateKey, error) + StoreKey(id string, key *ecdsa.PrivateKey) error + ListCertificates() ([]string, error) +} diff --git a/store/store_test.go b/store/store_test.go index 656ee7f..bf138d2 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -1,6 +1,12 @@ package store_test import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "math/big" "testing" "time" @@ -63,3 +69,74 @@ func doFileStoreTest(s store.FileStore, t *testing.T) { } }) } + +func doCertificateStoreTest(s store.CertificateStore, t *testing.T) { + t.Run("Basic", func(t *testing.T) { + + // Create cert and key + unsigned := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().Unix()), + Subject: pkix.Name{ + Organization: []string{"ezshare"}, + Country: []string{"No"}, + Locality: []string{"Oslo"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + } + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("unable to create private key: %s", err) + } + certBytes, err := x509.CreateCertificate(rand.Reader, unsigned, unsigned, &privateKey.PublicKey, privateKey) + if err != nil { + t.Fatalf("Error creating cert: %s", err) + } + + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + t.Fatalf("Error parsing created certificate: %s", err) + } + + // Store cert + if err := s.StoreCertificate("cert", cert); err != nil { + t.Fatalf("Error storing cert: %s", err) + } + + // Store key + if err := s.StoreKey("key", privateKey); err != nil { + t.Fatalf("Error storing key: %s", err) + } + + // List + ids, err := s.ListCertificates() + if err != nil { + t.Fatalf("List returned error: %s", err) + } + if len(ids) != 1 { + t.Fatalf("List has wrong length: %s", err) + } + if ids[0] != "cert" { + t.Fatalf("List has wrong id") + } + + retrievedCert, err := s.GetCertificate("cert") + if err != nil { + t.Fatalf("Unable to get certificate from store: %s", err) + } + if !retrievedCert.Equal(cert) { + t.Errorf("Retrieved certificate does not match stored.") + } + + retrievedKey, err := s.GetKey("key") + if err != nil { + t.Fatalf("Unable to get key from store: %s", err) + } + if !retrievedKey.Equal(privateKey) { + t.Errorf("Retrieved key does not match stored.") + } + }) +}