ezshare/store/store_test.go

192 lines
4.7 KiB
Go

package store_test
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"google.golang.org/protobuf/testing/protocmp"
"gitea.benny.dog/torjus/ezshare/pb"
"gitea.benny.dog/torjus/ezshare/store"
"google.golang.org/protobuf/types/known/timestamppb"
)
// doFileStoreTest exercises a FileStore implementation with a single
// store/list/get/delete round trip.
//
// The store must be empty when this is called: the listing step asserts that
// exactly one file is present after a single StoreFile call.
func doFileStoreTest(s store.FileStore, t *testing.T) {
	t.Run("Basics", func(t *testing.T) {
		// Create
		file := &pb.File{
			Data: []byte("testdata lol!"),
			Metadata: &pb.File_Metadata{
				UploadedOn:       timestamppb.New(time.Now()),
				ExpiresOn:        timestamppb.New(time.Now().Add(24 * time.Hour)),
				OriginalFilename: "data.txt",
			},
		}
		id, err := s.StoreFile(file)
		if err != nil {
			t.Fatalf("Unable to store file: %s", err)
		}

		// List
		list, err := s.ListFiles()
		if err != nil {
			t.Fatalf("error listing files: %s", err)
		}
		if len(list) != 1 {
			t.Fatalf("List returned unexpected amount. Got %d want %d", len(list), 1)
		}
		if list[0].FileId != id {
			t.Fatalf("List contains wrong id. Got %v want %v", list[0].FileId, id)
		}

		// Get
		retrieved, err := s.GetFile(id)
		if err != nil {
			t.Fatalf("Unable to get file: %s", err)
		}
		if len(file.Data) != len(retrieved.Data) {
			t.Fatalf("Mismatch in size between stored and retrieved. Got %d want %d", len(retrieved.Data), len(file.Data))
		}
		// Compare byte by byte so a failure reports where the data diverges.
		// The lengths are equal at this point, so indexing retrieved.Data is safe.
		for i := range file.Data {
			if file.Data[i] != retrieved.Data[i] {
				t.Fatalf("Mismatch at %d", i)
			}
		}

		// Delete
		if err := s.DeleteFile(id); err != nil {
			t.Fatalf("Unable to delete file: %s", err)
		}
		// Use %v: if the store still returns the file, err is nil, and %s would
		// render it as "%!s(<nil>)".
		if _, err := s.GetFile(id); err != store.ErrNoSuchItem {
			t.Fatalf("Getting deleted file returned wrong error. Got %v want %v", err, store.ErrNoSuchItem)
		}
	})
}
// doCertificateStoreTest exercises a CertificateStore implementation.
//
// It generates a self-signed P-256 ECDSA certificate, stores the certificate
// under "cert" and its private key under "key", and checks three things:
// ListCertificates returns only "cert", and both items read back equal to
// what was stored. The store must hold no certificates beforehand.
func doCertificateStoreTest(s store.CertificateStore, t *testing.T) {
	t.Run("Basic", func(t *testing.T) {
		// Create cert and key
		unsigned := &x509.Certificate{
			SerialNumber: big.NewInt(time.Now().Unix()),
			Subject: pkix.Name{
				Organization: []string{"ezshare"},
				Country:      []string{"No"},
				Locality:     []string{"Oslo"},
			},
			NotBefore:    time.Now(),
			NotAfter:     time.Now().AddDate(10, 0, 0),
			SubjectKeyId: []byte{1, 2, 3, 4, 6},
			ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
			KeyUsage:     x509.KeyUsageDigitalSignature,
		}
		privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		if err != nil {
			t.Fatalf("unable to create private key: %s", err)
		}
		// Self-signed: the template is both subject and issuer.
		certBytes, err := x509.CreateCertificate(rand.Reader, unsigned, unsigned, &privateKey.PublicKey, privateKey)
		if err != nil {
			t.Fatalf("Error creating cert: %s", err)
		}
		cert, err := x509.ParseCertificate(certBytes)
		if err != nil {
			t.Fatalf("Error parsing created certificate: %s", err)
		}

		// Store cert
		if err := s.StoreCertificate("cert", cert); err != nil {
			t.Fatalf("Error storing cert: %s", err)
		}
		// Store key
		if err := s.StoreKey("key", privateKey); err != nil {
			t.Fatalf("Error storing key: %s", err)
		}

		// List
		ids, err := s.ListCertificates()
		if err != nil {
			t.Fatalf("List returned error: %s", err)
		}
		if len(ids) != 1 {
			// err is always nil here; report the actual length instead.
			t.Fatalf("List has wrong length. Got %d want %d", len(ids), 1)
		}
		if ids[0] != "cert" {
			t.Fatalf("List has wrong id. Got %q want %q", ids[0], "cert")
		}

		// Get cert
		retrievedCert, err := s.GetCertificate("cert")
		if err != nil {
			t.Fatalf("Unable to get certificate from store: %s", err)
		}
		if !retrievedCert.Equal(cert) {
			t.Errorf("Retrieved certificate does not match stored.")
		}

		// Get key
		retrievedKey, err := s.GetKey("key")
		if err != nil {
			t.Fatalf("Unable to get key from store: %s", err)
		}
		if !retrievedKey.Equal(privateKey) {
			t.Errorf("Retrieved key does not match stored.")
		}
	})
}
// doUserStoreTests exercises a UserStore implementation.
//
// It stores one active user with a random UUID id, reads it back by id and
// by username, and compares each result to the original using protocmp. It
// then checks that ListUsers returns exactly that user's id. The store must
// hold no users beforehand.
func doUserStoreTests(s store.UserStore, t *testing.T) {
	t.Run("Basics", func(t *testing.T) {
		// Store user
		user := &pb.User{
			Id:       uuid.Must(uuid.NewRandom()).String(),
			Username: "testuser",
			UserRole: pb.User_USER,
			Active:   true,
		}
		if err := s.StoreUser(user); err != nil {
			t.Fatalf("Error storing user: %s", err)
		}

		// Get by id
		retrieved, err := s.GetUser(user.Id)
		if err != nil {
			t.Fatalf("Retrieving user returned error: %s", err)
		}
		// protocmp.Transform lets cmp compare protobuf messages by content.
		if diff := cmp.Diff(user, retrieved, protocmp.Transform()); diff != "" {
			t.Errorf("User retrieved by id difference:\n%v", diff)
		}

		// Get by username
		named, err := s.GetUserByUsername(user.Username)
		if err != nil {
			t.Fatalf("Retrieving user by username returned error: %s", err)
		}
		if diff := cmp.Diff(user, named, protocmp.Transform()); diff != "" {
			t.Errorf("User retrieved by name difference:\n%v", diff)
		}

		// List
		list, err := s.ListUsers()
		if err != nil {
			t.Fatalf("Error listing users: %s", err)
		}
		if len(list) != 1 {
			t.Fatalf("User list wrong length. Got %d want %d", len(list), 1)
		}
		if list[0] != user.Id {
			t.Fatalf("User list has wrong id. Got %q want %q", list[0], user.Id)
		}
	})
}