diff --git a/files/filestore_fs_test.go b/files/filestore_fs_test.go index 08cf765..7653ddc 100644 --- a/files/filestore_fs_test.go +++ b/files/filestore_fs_test.go @@ -7,20 +7,23 @@ import ( ) func TestFSFileStore(t *testing.T) { - dir := t.TempDir() - s, err := files.NewFSFileStore(dir) - if err != nil { - t.Fatalf("Error creating store: %s", err) + newFunc := func() files.FileStore { + dir := t.TempDir() + s, err := files.NewFSFileStore(dir) + if err != nil { + t.Fatalf("Error creating store: %s", err) + } + return s } - RunFilestoreTest(s, t) + RunFilestoreTest(newFunc, t) persistentDir := t.TempDir() - newFunc := func() files.FileStore { + persistentFunc := func() files.FileStore { s, err := files.NewFSFileStore(persistentDir) if err != nil { t.Fatalf("Error creating store: %s", err) } return s } - RunPersistentFilestoreTest(newFunc, t) + RunPersistentFilestoreTest(persistentFunc, t) } diff --git a/files/filestore_memory.go b/files/filestore_memory.go index 9788d20..66090f1 100644 --- a/files/filestore_memory.go +++ b/files/filestore_memory.go @@ -57,11 +57,16 @@ func (s *MemoryFileStore) Get(id string) (*File, error) { return nil, fmt.Errorf("no such item") } + body := new(bytes.Buffer) + if _, err := body.Write(fd.Body.Bytes()); err != nil { + return nil, err + } + f := &File{ ID: fd.ID, MaxViews: fd.MaxViews, ExpiresOn: fd.ExpiresOn, - Body: io.NopCloser(&fd.Body), + Body: io.NopCloser(body), FileSize: fd.FileSize, } diff --git a/files/filestore_memory_test.go b/files/filestore_memory_test.go index 7e2b9d8..7790d6d 100644 --- a/files/filestore_memory_test.go +++ b/files/filestore_memory_test.go @@ -7,7 +7,9 @@ import ( ) func TestMemoryFileStore(t *testing.T) { - s := files.NewMemoryFileStore() + newFunc := func() files.FileStore { + return files.NewMemoryFileStore() + } - RunFilestoreTest(s, t) + RunFilestoreTest(newFunc, t) } diff --git a/files/filestore_test.go b/files/filestore_test.go index a3e9c4d..81cbf2c 100644 --- a/files/filestore_test.go +++ b/files/filestore_test.go @@ -12,8 +12,11 @@ import ( "github.com/google/uuid" ) -func RunFilestoreTest(s files.FileStore, t *testing.T) { +var ignoreBody = cmp.FilterPath(func(p cmp.Path) bool { return p.String() == "Body" }, cmp.Ignore()) + +func RunFilestoreTest(newStoreFunc func() files.FileStore, t *testing.T) { t.Run("Basic", func(t *testing.T) { + s := newStoreFunc() // Create dataString := "TEST_LOL_OMG" id := uuid.Must(uuid.NewRandom()).String() @@ -58,7 +61,6 @@ func RunFilestoreTest(s files.FileStore, t *testing.T) { FileSize: int64(len(dataString)), } - ignoreBody := cmp.FilterPath(func(p cmp.Path) bool { return p.String() == "Body" }, cmp.Ignore()) if diff := cmp.Diff(retrieved, expected, ignoreBody); diff != "" { t.Errorf("File comparison failed: %s", diff) } @@ -88,84 +90,128 @@ func RunFilestoreTest(s files.FileStore, t *testing.T) { t.Fatalf("List after delete has wrong length: %d", len(ids)) } }) + t.Run("MultipleGet", func(t *testing.T) { + s := newStoreFunc() + + fileContents := "multiple get test !" + body := io.NopCloser(strings.NewReader(fileContents)) + file := &files.File{ + ID: uuid.NewString(), + OriginalFilename: "multiple.txt", + MaxViews: 999, + ExpiresOn: time.Now().Add(1 * time.Hour), + Body: body, + FileSize: int64(len(fileContents)), + } + + if err := s.Store(file); err != nil { + t.Fatalf("Error storing file: %s", err) + } + + first, err := s.Get(file.ID) + if err != nil { + t.Errorf("Error retrieving first file: %s", err) + } + + firstBody := new(bytes.Buffer) + io.Copy(firstBody, first.Body) + first.Body.Close() + + if diff := cmp.Diff(firstBody.String(), fileContents); diff != "" { + t.Fatalf("File contents mismatch: %s", diff) + } + + second, err := s.Get(file.ID) + if err != nil { + t.Errorf("Error retrieving first file: %s", err) + } + + secondBody := new(bytes.Buffer) + io.Copy(secondBody, second.Body) + first.Body.Close() + + if diff := cmp.Diff(secondBody.String(), fileContents); diff != "" { + t.Fatalf("File contents mismatch: %s", diff) + } + }) } func RunPersistentFilestoreTest(newStoreFunc func() files.FileStore, t *testing.T) { - s := newStoreFunc() + t.Run("Basics", func(t *testing.T) { + s := newStoreFunc() - files := []struct { - File *files.File - ExpectedData string - }{ - { - File: &files.File{ - ID: uuid.NewString(), - OriginalFilename: "testfile.txt", - MaxViews: 5, - ExpiresOn: time.Now().Add(10 * time.Minute), - Body: io.NopCloser(strings.NewReader("cocks!")), - FileSize: 6, + files := []struct { + File *files.File + ExpectedData string + }{ + { + File: &files.File{ + ID: uuid.NewString(), + OriginalFilename: "testfile.txt", + MaxViews: 5, + ExpiresOn: time.Now().Add(10 * time.Minute), + Body: io.NopCloser(strings.NewReader("cocks!")), + FileSize: 6, + }, + ExpectedData: "cocks!", }, - ExpectedData: "cocks!", - }, - { - File: &files.File{ - ID: uuid.NewString(), - OriginalFilename: "testfile2.txt", - MaxViews: 5, - ExpiresOn: time.Now().Add(10 * time.Minute), - Body: io.NopCloser(strings.NewReader("derps!")), - FileSize: 6, + { + File: &files.File{ + ID: uuid.NewString(), + OriginalFilename: "testfile2.txt", + MaxViews: 5, + ExpiresOn: time.Now().Add(10 * time.Minute), + Body: io.NopCloser(strings.NewReader("derps!")), + FileSize: 6, + }, + ExpectedData: "derps!", }, - ExpectedData: "derps!", - }, - } - - for _, f := range files { - err := s.Store(f.File) - if err != nil { - t.Fatalf("Error storing file: %s", err) - } - } - for _, f := range files { - retrieved, err := s.Get(f.File.ID) - if err != nil { - t.Fatalf("Unable to retrieve file: %s", err) } - ignoreBody := cmp.FilterPath(func(p cmp.Path) bool { return p.String() == "Body" }, cmp.Ignore()) - if !cmp.Equal(retrieved, f.File, ignoreBody) { - t.Errorf("Mismatch: %s", cmp.Diff(retrieved, f.File)) + for _, f := range files { + err := s.Store(f.File) + if err != nil { + t.Fatalf("Error storing file: %s", err) + } } - buf := new(strings.Builder) - if _, err := io.Copy(buf, retrieved.Body); err != nil { - t.Fatalf("Error reading from body: %s", err) - } - retrieved.Body.Close() - if buf.String() != f.ExpectedData { - t.Fatalf("Data does not match. %s", cmp.Diff(buf.String(), f.ExpectedData)) - } - } + for _, f := range files { + retrieved, err := s.Get(f.File.ID) + if err != nil { + t.Fatalf("Unable to retrieve file: %s", err) + } - // Reopen store, and fetch again - s = newStoreFunc() - for _, f := range files { - retrieved, err := s.Get(f.File.ID) - if err != nil { - t.Fatalf("Unable to retrieve file: %s", err) + if !cmp.Equal(retrieved, f.File, ignoreBody) { + t.Errorf("Mismatch: %s", cmp.Diff(retrieved, f.File)) + } + buf := new(strings.Builder) + if _, err := io.Copy(buf, retrieved.Body); err != nil { + t.Fatalf("Error reading from body: %s", err) + } + retrieved.Body.Close() + if buf.String() != f.ExpectedData { + t.Fatalf("Data does not match. %s", cmp.Diff(buf.String(), f.ExpectedData)) + } } - ignoreBody := cmp.FilterPath(func(p cmp.Path) bool { return p.String() == "Body" }, cmp.Ignore()) - if !cmp.Equal(retrieved, f.File, ignoreBody) { - t.Errorf("Mismatch: %s", cmp.Diff(retrieved, f.File)) + // Reopen store, and fetch again + s = newStoreFunc() + for _, f := range files { + retrieved, err := s.Get(f.File.ID) + if err != nil { + t.Fatalf("Unable to retrieve file: %s", err) + } + + if !cmp.Equal(retrieved, f.File, ignoreBody) { + t.Errorf("Mismatch: %s", cmp.Diff(retrieved, f.File)) + } + buf := new(strings.Builder) + if _, err := io.Copy(buf, retrieved.Body); err != nil { + t.Fatalf("Error reading from body: %s", err) + } + retrieved.Body.Close() + if buf.String() != f.ExpectedData { + t.Fatalf("Data does not match. %s", cmp.Diff(buf.String(), f.ExpectedData)) + } } - buf := new(strings.Builder) - if _, err := io.Copy(buf, retrieved.Body); err != nil { - t.Fatalf("Error reading from body: %s", err) - } - retrieved.Body.Close() - if buf.String() != f.ExpectedData { - t.Fatalf("Data does not match. %s", cmp.Diff(buf.String(), f.ExpectedData)) - } - } + }) }