package store_test

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"math/big"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/google/uuid"
	"google.golang.org/protobuf/testing/protocmp"

	"gitea.benny.dog/torjus/ezshare/pb"
	"gitea.benny.dog/torjus/ezshare/store"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// doFileStoreTest runs the shared FileStore test suite against s.
//
// The "Basics" subtest stores a single file, checks that ListFiles reports
// exactly that file, reads it back and compares the payload byte for byte,
// then deletes it and expects a subsequent GetFile to return
// store.ErrNoSuchItem.
//
// The list assertion requires exactly one entry, so s must not already
// contain files when this function is called.
func doFileStoreTest(s store.FileStore, t *testing.T) {
	t.Run("Basics", func(t *testing.T) {
		// Create
		now := time.Now()
		file := &pb.File{
			Data: []byte("testdata lol!"),
			Metadata: &pb.File_Metadata{
				UploadedOn:       timestamppb.New(now),
				ExpiresOn:        timestamppb.New(now.Add(24 * time.Hour)),
				OriginalFilename: "data.txt",
			},
		}

		id, err := s.StoreFile(file)
		if err != nil {
			t.Fatalf("Unable to store file: %s", err)
		}

		// List
		list, err := s.ListFiles()
		if err != nil {
			t.Fatalf("error listing files: %s", err)
		}
		if len(list) != 1 {
			t.Fatalf("List returned unexpected amount. Got %d want %d", len(list), 1)
		}
		if list[0].FileId != id {
			t.Fatalf("List contains wrong id")
		}

		// Get
		retrieved, err := s.GetFile(id)
		if err != nil {
			t.Fatalf("Unable to get file: %s", err)
		}

		// Check the length first so the per-byte loop below cannot index
		// past the end of the retrieved payload.
		if got, want := len(retrieved.Data), len(file.Data); got != want {
			t.Fatalf("Mismatch in size between stored and retrieved. Got %d want %d", got, want)
		}
		for i, want := range file.Data {
			if retrieved.Data[i] != want {
				t.Fatalf("Mismatch at %d", i)
			}
		}

		// Delete
		if err := s.DeleteFile(id); err != nil {
			t.Fatalf("Unable to delete file: %s", err)
		}

		// NOTE(review): this is an exact comparison against the sentinel, so
		// a store that wraps ErrNoSuchItem would fail here; confirm whether
		// that strictness is intended.
		_, err = s.GetFile(id)
		if err != store.ErrNoSuchItem {
			// %v rather than %s: err is nil when the deleted file is still
			// retrievable, and %s would render that as "%!s(<nil>)".
			t.Fatalf("Getting deleted file returned wrong error: %v", err)
		}
	})
}

// doCertificateStoreTest runs the shared CertificateStore test suite against s.
//
// The "Basic" subtest generates a self-signed ECDSA P-256 certificate, stores
// the certificate and its private key, and then verifies that:
//   - ListCertificates reports exactly the stored certificate's serial number,
//   - GetCertificate and GetKey return values equal to what was stored,
//   - the certificate is not revoked until Revoke is called, and is revoked
//     afterwards.
//
// The list assertion requires exactly one entry, so s must not already
// contain certificates when this function is called.
func doCertificateStoreTest(s store.CertificateStore, t *testing.T) {
	t.Run("Basic", func(t *testing.T) {
		// Create cert and key
		unsigned := &x509.Certificate{
			SerialNumber: big.NewInt(time.Now().UnixMilli()),
			Subject: pkix.Name{
				Organization: []string{"ezshare"},
				Country:      []string{"No"},
				Locality:     []string{"Oslo"},
			},
			NotBefore:    time.Now(),
			NotAfter:     time.Now().AddDate(10, 0, 0),
			SubjectKeyId: []byte{1, 2, 3, 4, 6},
			ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
			KeyUsage:     x509.KeyUsageDigitalSignature,
		}
		privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		if err != nil {
			t.Fatalf("unable to create private key: %s", err)
		}
		// The template is passed as both subject and parent, which makes the
		// certificate self-signed with privateKey.
		certBytes, err := x509.CreateCertificate(rand.Reader, unsigned, unsigned, &privateKey.PublicKey, privateKey)
		if err != nil {
			t.Fatalf("Error creating cert: %s", err)
		}

		cert, err := x509.ParseCertificate(certBytes)
		if err != nil {
			t.Fatalf("Error parsing created certificate: %s", err)
		}
		// The decimal serial number is the identifier used for every lookup below.
		serial := cert.SerialNumber.String()

		// Store cert
		if err := s.StoreCertificate(cert); err != nil {
			t.Fatalf("Error storing cert: %s", err)
		}

		// Store key
		if err := s.StoreKey("key", privateKey); err != nil {
			t.Fatalf("Error storing key: %s", err)
		}

		// List
		ids, err := s.ListCertificates()
		if err != nil {
			t.Fatalf("List returned error: %s", err)
		}
		if len(ids) != 1 {
			// Report the observed length; err is always nil at this point.
			t.Fatalf("List has wrong length. Got %d want %d", len(ids), 1)
		}
		if ids[0] != serial {
			t.Fatalf("List has wrong id. Got %q want %q", ids[0], serial)
		}

		// Get
		retrievedCert, err := s.GetCertificate(serial)
		if err != nil {
			t.Fatalf("Unable to get certificate from store: %s", err)
		}
		if !retrievedCert.Equal(cert) {
			t.Errorf("Retrieved certificate does not match stored.")
		}

		retrievedKey, err := s.GetKey("key")
		if err != nil {
			t.Fatalf("Unable to get key from store: %s", err)
		}
		if !retrievedKey.Equal(privateKey) {
			t.Errorf("Retrieved key does not match stored.")
		}

		// Revoke
		isRevoked, err := s.IsRevoked(serial)
		if err != nil {
			t.Fatalf("Error checking if certificate is revoked: %s", err)
		}
		if isRevoked {
			t.Fatalf("Unrevoked certificate is revoked")
		}
		if err := s.Revoke(serial); err != nil {
			t.Fatalf("Error revoking certificate: %s", err)
		}
		isRevoked, err = s.IsRevoked(serial)
		if err != nil {
			t.Fatalf("Error checking if certificate is revoked: %s", err)
		}
		if !isRevoked {
			t.Fatalf("Revoked certificate is not revoked")
		}
	})
}

// doUserStoreTests runs the shared UserStore test suite against s.
//
// The "Basics" subtest stores one user, fetches it back both by id and by
// username, compares each result to the original with protocmp, and checks
// that ListUsers reports exactly that user's id.
//
// The list assertion requires exactly one entry, so s must not already
// contain users when this function is called.
func doUserStoreTests(s store.UserStore, t *testing.T) {
	t.Run("Basics", func(t *testing.T) {
		// Store user
		user := &pb.User{
			Id:       uuid.Must(uuid.NewRandom()).String(),
			Username: "testuser",
			UserRole: pb.User_USER,
			Active:   true,
		}

		if err := s.StoreUser(user); err != nil {
			t.Fatalf("Error storing user: %s", err)
		}

		// Get by id
		retrieved, err := s.GetUser(user.Id)
		if err != nil {
			t.Fatalf("Retrieving user returned error: %s", err)
		}
		if diff := cmp.Diff(user, retrieved, protocmp.Transform()); diff != "" {
			t.Errorf("User retrieved by id difference:\n%v", diff)
		}

		// Get by username
		named, err := s.GetUserByUsername(user.Username)
		if err != nil {
			t.Fatalf("Retrieving user by username returned error: %s", err)
		}
		if diff := cmp.Diff(user, named, protocmp.Transform()); diff != "" {
			t.Errorf("User retrieved by name difference:\n%v", diff)
		}

		// List
		list, err := s.ListUsers()
		if err != nil {
			t.Fatalf("Error listing users: %s", err)
		}
		if len(list) != 1 {
			t.Fatalf("User list wrong length. Got %d want %d", len(list), 1)
		}
		if list[0] != user.Id {
			t.Fatalf("User list has wrong id. Got %q want %q", list[0], user.Id)
		}
	})
}