2021-12-03 22:51:48 +00:00
|
|
|
package store_test
|
|
|
|
|
|
|
|
import (
|
2021-12-05 00:00:32 +00:00
|
|
|
"crypto/ecdsa"
|
|
|
|
"crypto/elliptic"
|
|
|
|
"crypto/rand"
|
|
|
|
"crypto/x509"
|
|
|
|
"crypto/x509/pkix"
|
|
|
|
"math/big"
|
2021-12-03 22:51:48 +00:00
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
2021-12-05 11:39:28 +00:00
|
|
|
"github.com/google/go-cmp/cmp"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"google.golang.org/protobuf/testing/protocmp"
|
|
|
|
|
2021-12-03 22:51:48 +00:00
|
|
|
"gitea.benny.dog/torjus/ezshare/pb"
|
|
|
|
"gitea.benny.dog/torjus/ezshare/store"
|
|
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
)
|
|
|
|
|
|
|
|
func doFileStoreTest(s store.FileStore, t *testing.T) {
|
|
|
|
t.Run("Basics", func(t *testing.T) {
|
|
|
|
// Create
|
|
|
|
file := &pb.File{
|
|
|
|
Data: []byte("testdata lol!"),
|
|
|
|
Metadata: &pb.File_Metadata{
|
|
|
|
UploadedOn: timestamppb.New(time.Now()),
|
|
|
|
ExpiresOn: timestamppb.New(time.Now().Add(24 * time.Hour)),
|
|
|
|
OriginalFilename: "data.txt",
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
id, err := s.StoreFile(file)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Unable to store file: %s", err)
|
|
|
|
}
|
|
|
|
|
2021-12-04 10:30:42 +00:00
|
|
|
// List
|
|
|
|
list, err := s.ListFiles()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("error listing files: %s", err)
|
|
|
|
}
|
|
|
|
if len(list) != 1 {
|
|
|
|
t.Fatalf("List returned unexpected amount. Got %d want %d", len(list), 1)
|
|
|
|
}
|
|
|
|
|
|
|
|
if list[0].FileId != id {
|
|
|
|
t.Fatalf("List contains wrong id")
|
|
|
|
}
|
|
|
|
|
2021-12-03 22:51:48 +00:00
|
|
|
retrieved, err := s.GetFile(id)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Unable to get file: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(file.Data) != len(retrieved.Data) {
|
|
|
|
t.Fatalf("Mismatch in size between stored and retrieved. Got %d want %d", len(retrieved.Data), len(file.Data))
|
|
|
|
}
|
|
|
|
|
|
|
|
for i := range file.Data {
|
|
|
|
if file.Data[i] != retrieved.Data[i] {
|
|
|
|
t.Fatalf("Mismatch at %d", i)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := s.DeleteFile(id); err != nil {
|
|
|
|
t.Fatalf("Unable to delete file: %s", err)
|
|
|
|
}
|
|
|
|
|
2021-12-05 11:42:22 +00:00
|
|
|
if _, err := s.GetFile(id); err != store.ErrNoSuchItem {
|
2021-12-03 22:51:48 +00:00
|
|
|
t.Fatalf("Getting deleted file returned wrong error: %s", err)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
2021-12-05 00:00:32 +00:00
|
|
|
|
|
|
|
func doCertificateStoreTest(s store.CertificateStore, t *testing.T) {
|
|
|
|
t.Run("Basic", func(t *testing.T) {
|
|
|
|
|
|
|
|
// Create cert and key
|
|
|
|
unsigned := &x509.Certificate{
|
2021-12-06 16:28:48 +00:00
|
|
|
SerialNumber: big.NewInt(time.Now().UnixMilli()),
|
2021-12-05 00:00:32 +00:00
|
|
|
Subject: pkix.Name{
|
|
|
|
Organization: []string{"ezshare"},
|
|
|
|
Country: []string{"No"},
|
|
|
|
Locality: []string{"Oslo"},
|
|
|
|
},
|
|
|
|
NotBefore: time.Now(),
|
|
|
|
NotAfter: time.Now().AddDate(10, 0, 0),
|
|
|
|
SubjectKeyId: []byte{1, 2, 3, 4, 6},
|
|
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
|
|
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
|
|
|
}
|
|
|
|
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("unable to create private key: %s", err)
|
|
|
|
}
|
|
|
|
certBytes, err := x509.CreateCertificate(rand.Reader, unsigned, unsigned, &privateKey.PublicKey, privateKey)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Error creating cert: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
cert, err := x509.ParseCertificate(certBytes)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Error parsing created certificate: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Store cert
|
2021-12-06 18:14:39 +00:00
|
|
|
if err := s.StoreCertificate(cert); err != nil {
|
2021-12-05 00:00:32 +00:00
|
|
|
t.Fatalf("Error storing cert: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Store key
|
|
|
|
if err := s.StoreKey("key", privateKey); err != nil {
|
|
|
|
t.Fatalf("Error storing key: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// List
|
|
|
|
ids, err := s.ListCertificates()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("List returned error: %s", err)
|
|
|
|
}
|
|
|
|
if len(ids) != 1 {
|
|
|
|
t.Fatalf("List has wrong length: %s", err)
|
|
|
|
}
|
2021-12-06 18:14:39 +00:00
|
|
|
if ids[0] != cert.SerialNumber.String() {
|
2021-12-05 00:00:32 +00:00
|
|
|
t.Fatalf("List has wrong id")
|
|
|
|
}
|
|
|
|
|
2021-12-06 18:14:39 +00:00
|
|
|
retrievedCert, err := s.GetCertificate(cert.SerialNumber.String())
|
2021-12-05 00:00:32 +00:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Unable to get certificate from store: %s", err)
|
|
|
|
}
|
|
|
|
if !retrievedCert.Equal(cert) {
|
|
|
|
t.Errorf("Retrieved certificate does not match stored.")
|
|
|
|
}
|
|
|
|
|
|
|
|
retrievedKey, err := s.GetKey("key")
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Unable to get key from store: %s", err)
|
|
|
|
}
|
|
|
|
if !retrievedKey.Equal(privateKey) {
|
|
|
|
t.Errorf("Retrieved key does not match stored.")
|
|
|
|
}
|
2021-12-06 16:28:48 +00:00
|
|
|
|
|
|
|
// Revoke
|
|
|
|
isRevoked, err := s.IsRevoked(cert.SerialNumber.String())
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Error checking if certificate is revoked: %s", err)
|
|
|
|
}
|
|
|
|
if isRevoked {
|
|
|
|
t.Fatalf("Unrevoked certificate is revoked")
|
|
|
|
}
|
|
|
|
if err := s.Revoke(cert.SerialNumber.String()); err != nil {
|
|
|
|
t.Fatalf("Error revoking certificate: %s", err)
|
|
|
|
}
|
|
|
|
isRevoked, err = s.IsRevoked(cert.SerialNumber.String())
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Error checking if certificate is revoked: %s", err)
|
|
|
|
}
|
|
|
|
if !isRevoked {
|
|
|
|
t.Fatalf("Revoked certificate is not revoked")
|
|
|
|
}
|
|
|
|
|
2021-12-05 00:00:32 +00:00
|
|
|
})
|
|
|
|
}
|
2021-12-05 10:08:09 +00:00
|
|
|
|
|
|
|
func doUserStoreTests(s store.UserStore, t *testing.T) {
|
|
|
|
t.Run("Basics", func(t *testing.T) {
|
|
|
|
// Store user
|
|
|
|
user := &pb.User{
|
|
|
|
Id: uuid.Must(uuid.NewRandom()).String(),
|
|
|
|
Username: "testuser",
|
|
|
|
UserRole: pb.User_USER,
|
|
|
|
Active: true,
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := s.StoreUser(user); err != nil {
|
|
|
|
t.Fatalf("Error storing user: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
retrieved, err := s.GetUser(user.Id)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Retriving user returned error: %s", err)
|
|
|
|
}
|
|
|
|
|
2021-12-05 11:39:28 +00:00
|
|
|
if diff := cmp.Diff(user, retrieved, protocmp.Transform()); diff != "" {
|
|
|
|
t.Errorf("User retrieved by name difference:\n%v", diff)
|
|
|
|
}
|
|
|
|
|
|
|
|
named, err := s.GetUserByUsername(user.Username)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Retrieving user by username returned error: %s", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if diff := cmp.Diff(user, named, protocmp.Transform()); diff != "" {
|
|
|
|
t.Errorf("User retrieved by name difference:\n%v", diff)
|
2021-12-05 10:08:09 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
list, err := s.ListUsers()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Error listing users")
|
|
|
|
}
|
|
|
|
if len(list) != 1 {
|
|
|
|
t.Fatalf("User list wrong length")
|
|
|
|
}
|
|
|
|
if list[0] != user.Id {
|
|
|
|
t.Fatalf("User list has wrong id")
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|