package store_test import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/x509" "crypto/x509/pkix" "math/big" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/uuid" "google.golang.org/protobuf/testing/protocmp" "gitea.benny.dog/torjus/ezshare/pb" "gitea.benny.dog/torjus/ezshare/store" "google.golang.org/protobuf/types/known/timestamppb" ) func doFileStoreTest(s store.FileStore, t *testing.T) { t.Run("Basics", func(t *testing.T) { // Create file := &pb.File{ Data: []byte("testdata lol!"), Metadata: &pb.File_Metadata{ UploadedOn: timestamppb.New(time.Now()), ExpiresOn: timestamppb.New(time.Now().Add(24 * time.Hour)), OriginalFilename: "data.txt", }, } id, err := s.StoreFile(file) if err != nil { t.Fatalf("Unable to store file: %s", err) } // List list, err := s.ListFiles() if err != nil { t.Fatalf("error listing files: %s", err) } if len(list) != 1 { t.Fatalf("List returned unexpected amount. Got %d want %d", len(list), 1) } if list[0].FileId != id { t.Fatalf("List contains wrong id") } retrieved, err := s.GetFile(id) if err != nil { t.Fatalf("Unable to get file: %s", err) } if len(file.Data) != len(retrieved.Data) { t.Fatalf("Mismatch in size between stored and retrieved. Got %d want %d", len(retrieved.Data), len(file.Data)) } for i := range file.Data { if file.Data[i] != retrieved.Data[i] { t.Fatalf("Mismatch at %d", i) } } if err := s.DeleteFile(id); err != nil { t.Fatalf("Unable to delete file: %s", err) } if _, err := s.GetFile(id); err != store.ErrNoSuchItem { t.Fatalf("Getting deleted file returned wrong error: %s", err) } }) } func doCertificateStoreTest(s store.CertificateStore, t *testing.T) { t.Run("Basic", func(t *testing.T) { // Create cert and key unsigned := &x509.Certificate{ SerialNumber: big.NewInt(time.Now().UnixMilli()), Subject: pkix.Name{ Organization: []string{"ezshare"}, Country: []string{"No"}, Locality: []string{"Oslo"}, }, NotBefore: time.Now(), NotAfter: time.Now().AddDate(10, 0, 0), SubjectKeyId: []byte{1, 2, 3, 4, 6}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature, } privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatalf("unable to create private key: %s", err) } certBytes, err := x509.CreateCertificate(rand.Reader, unsigned, unsigned, &privateKey.PublicKey, privateKey) if err != nil { t.Fatalf("Error creating cert: %s", err) } cert, err := x509.ParseCertificate(certBytes) if err != nil { t.Fatalf("Error parsing created certificate: %s", err) } // Store cert if err := s.StoreCertificate(cert); err != nil { t.Fatalf("Error storing cert: %s", err) } // Store key if err := s.StoreKey("key", privateKey); err != nil { t.Fatalf("Error storing key: %s", err) } // List ids, err := s.ListCertificates() if err != nil { t.Fatalf("List returned error: %s", err) } if len(ids) != 1 { t.Fatalf("List has wrong length: %s", err) } if ids[0] != cert.SerialNumber.String() { t.Fatalf("List has wrong id") } retrievedCert, err := s.GetCertificate(cert.SerialNumber.String()) if err != nil { t.Fatalf("Unable to get certificate from store: %s", err) } if !retrievedCert.Equal(cert) { t.Errorf("Retrieved certificate does not match stored.") } retrievedKey, err := s.GetKey("key") if err != nil { t.Fatalf("Unable to get key from store: %s", err) } if !retrievedKey.Equal(privateKey) { t.Errorf("Retrieved key does not match stored.") } // Revoke isRevoked, err := s.IsRevoked(cert.SerialNumber.String()) if err != nil { t.Fatalf("Error checking if certificate is revoked: %s", err) } if isRevoked { t.Fatalf("Unrevoked certificate is revoked") } if err := s.Revoke(cert.SerialNumber.String()); err != nil { t.Fatalf("Error revoking certificate: %s", err) } isRevoked, err = s.IsRevoked(cert.SerialNumber.String()) if err != nil { t.Fatalf("Error checking if certificate is revoked: %s", err) } if !isRevoked { t.Fatalf("Revoked certificate is not revoked") } }) } func doUserStoreTests(s store.UserStore, t *testing.T) { t.Run("Basics", func(t *testing.T) { // Store user user := &pb.User{ Id: uuid.Must(uuid.NewRandom()).String(), Username: "testuser", UserRole: pb.User_USER, Active: true, } if err := s.StoreUser(user); err != nil { t.Fatalf("Error storing user: %s", err) } retrieved, err := s.GetUser(user.Id) if err != nil { t.Fatalf("Retriving user returned error: %s", err) } if diff := cmp.Diff(user, retrieved, protocmp.Transform()); diff != "" { t.Errorf("User retrieved by name difference:\n%v", diff) } named, err := s.GetUserByUsername(user.Username) if err != nil { t.Fatalf("Retrieving user by username returned error: %s", err) } if diff := cmp.Diff(user, named, protocmp.Transform()); diff != "" { t.Errorf("User retrieved by name difference:\n%v", diff) } list, err := s.ListUsers() if err != nil { t.Fatalf("Error listing users") } if len(list) != 1 { t.Fatalf("User list wrong length") } if list[0] != user.Id { t.Fatalf("User list has wrong id") } }) } func doBinaryStoreTests(s store.BinaryStore, t *testing.T) { winBinary := &pb.Binary{ Version: "v1.0.0", Arch: "amd64", Os: "windows", Data: []byte("WINDOWS"), } linuxBinary := &pb.Binary{ Version: "v1.0.0", Arch: "arm", Os: "linux", Data: []byte("LINUXLOL"), } // Store if err := s.StoreBinary(winBinary); err != nil { t.Fatalf("Error storing binary: %s", err) } if err := s.StoreBinary(linuxBinary); err != nil { t.Fatalf("Error storing binary: %s", err) } // Get linux, err := s.GetBinary(linuxBinary.Version, linuxBinary.Os, linuxBinary.Arch) if err != nil { t.Fatalf("Error geting linux binary: %s", err) } windows, err := s.GetBinary(winBinary.Version, winBinary.Os, winBinary.Arch) if err != nil { t.Fatalf("Error geting linux binary: %s", err) } if string(linux.Data) != string(linuxBinary.Data) { t.Fatalf("Data is not the same in linux-binary") } if string(windows.Data) != string(winBinary.Data) { t.Fatalf("Data is not the same in linux-binary") } // List list, err := s.List() if err != nil { t.Fatalf("List returned error: %s", err) } var lFound, wFound bool for _, item := range list { if item == "ezshare-1.0.0-linux-arm" { lFound = true } if item == "ezshare-1.0.0-windows-amd64" { wFound = true } } if !lFound { t.Fatalf("Linux binary not in list. %+v", list) } if !wFound { t.Fatalf("Windows binary not in list. %+v", list) } }