package store_test import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/x509" "crypto/x509/pkix" "math/big" "testing" "time" "gitea.benny.dog/torjus/ezshare/pb" "gitea.benny.dog/torjus/ezshare/store" "google.golang.org/protobuf/types/known/timestamppb" ) func doFileStoreTest(s store.FileStore, t *testing.T) { t.Run("Basics", func(t *testing.T) { // Create file := &pb.File{ Data: []byte("testdata lol!"), Metadata: &pb.File_Metadata{ UploadedOn: timestamppb.New(time.Now()), ExpiresOn: timestamppb.New(time.Now().Add(24 * time.Hour)), OriginalFilename: "data.txt", }, } id, err := s.StoreFile(file) if err != nil { t.Fatalf("Unable to store file: %s", err) } // List list, err := s.ListFiles() if err != nil { t.Fatalf("error listing files: %s", err) } if len(list) != 1 { t.Fatalf("List returned unexpected amount. Got %d want %d", len(list), 1) } if list[0].FileId != id { t.Fatalf("List contains wrong id") } retrieved, err := s.GetFile(id) if err != nil { t.Fatalf("Unable to get file: %s", err) } if len(file.Data) != len(retrieved.Data) { t.Fatalf("Mismatch in size between stored and retrieved. Got %d want %d", len(retrieved.Data), len(file.Data)) } for i := range file.Data { if file.Data[i] != retrieved.Data[i] { t.Fatalf("Mismatch at %d", i) } } if err := s.DeleteFile(id); err != nil { t.Fatalf("Unable to delete file: %s", err) } if _, err := s.GetFile(id); err != store.ErrNoSuchFile { t.Fatalf("Getting deleted file returned wrong error: %s", err) } }) } func doCertificateStoreTest(s store.CertificateStore, t *testing.T) { t.Run("Basic", func(t *testing.T) { // Create cert and key unsigned := &x509.Certificate{ SerialNumber: big.NewInt(time.Now().Unix()), Subject: pkix.Name{ Organization: []string{"ezshare"}, Country: []string{"No"}, Locality: []string{"Oslo"}, }, NotBefore: time.Now(), NotAfter: time.Now().AddDate(10, 0, 0), SubjectKeyId: []byte{1, 2, 3, 4, 6}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature, } privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatalf("unable to create private key: %s", err) } certBytes, err := x509.CreateCertificate(rand.Reader, unsigned, unsigned, &privateKey.PublicKey, privateKey) if err != nil { t.Fatalf("Error creating cert: %s", err) } cert, err := x509.ParseCertificate(certBytes) if err != nil { t.Fatalf("Error parsing created certificate: %s", err) } // Store cert if err := s.StoreCertificate("cert", cert); err != nil { t.Fatalf("Error storing cert: %s", err) } // Store key if err := s.StoreKey("key", privateKey); err != nil { t.Fatalf("Error storing key: %s", err) } // List ids, err := s.ListCertificates() if err != nil { t.Fatalf("List returned error: %s", err) } if len(ids) != 1 { t.Fatalf("List has wrong length: %s", err) } if ids[0] != "cert" { t.Fatalf("List has wrong id") } retrievedCert, err := s.GetCertificate("cert") if err != nil { t.Fatalf("Unable to get certificate from store: %s", err) } if !retrievedCert.Equal(cert) { t.Errorf("Retrieved certificate does not match stored.") } retrievedKey, err := s.GetKey("key") if err != nil { t.Fatalf("Unable to get key from store: %s", err) } if !retrievedKey.Equal(privateKey) { t.Errorf("Retrieved key does not match stored.") } }) }