diff --git a/api/http.go b/api/http.go index e85d662..0485261 100644 --- a/api/http.go +++ b/api/http.go @@ -4,7 +4,9 @@ import ( "encoding/json" "io" "net/http" + "strconv" "strings" + "time" "git.t-juice.club/torjus/gpaste" "git.t-juice.club/torjus/gpaste/files" @@ -62,10 +64,6 @@ func (s *HTTPServer) HandlerIndex(w http.ResponseWriter, r *http.Request) { } func (s *HTTPServer) HandlerAPIFilePost(w http.ResponseWriter, r *http.Request) { - f := &files.File{ - ID: uuid.Must(uuid.NewRandom()).String(), - Body: r.Body, - } reqID := middleware.GetReqID(r.Context()) // Check if multipart form @@ -74,6 +72,11 @@ func (s *HTTPServer) HandlerAPIFilePost(w http.ResponseWriter, r *http.Request) s.processMultiPartFormUpload(w, r) return } + + f := fileFromParams(r) + f.ID = uuid.NewString() + f.Body = r.Body + err := s.Files.Store(f) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -149,11 +152,10 @@ func (s *HTTPServer) processMultiPartFormUpload(w http.ResponseWriter, r *http.R s.Logger.Warnw("Error reading file from multipart form.", "req_id", reqID, "error", err) return } - f := &files.File{ - ID: uuid.Must(uuid.NewRandom()).String(), - OriginalFilename: fh.Filename, - Body: ff, - } + f := fileFromParams(r) + f.ID = uuid.NewString() + f.OriginalFilename = fh.Filename + f.Body = ff if err := s.Files.Store(f); err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -251,3 +253,29 @@ func (s *HTTPServer) HandlerAPIUserList(w http.ResponseWriter, r *http.Request) s.Logger.Warnw("Error encoding response.", "req_id", "error", err) } } + +func fileFromParams(r *http.Request) *files.File { + const ( + keyMaxViews = "max_views" + keyExpiresOn = "exp" + ) + var f files.File + + q := r.URL.Query() + + if q.Has(keyMaxViews) { + views, err := strconv.ParseUint(q.Get(keyMaxViews), 10, 64) + if err == nil { + f.MaxViews = uint(views) + } + } + + if q.Has(keyExpiresOn) { + exp, err := time.Parse(time.RFC3339, q.Get(keyExpiresOn)) + if err == nil { + f.ExpiresOn = exp + } + } + + return &f +} diff --git a/api/http_test.go b/api/http_test.go index a826040..e1e277c 100644 --- a/api/http_test.go +++ b/api/http_test.go @@ -64,7 +64,7 @@ func TestHandlers(t *testing.T) { } mw.Close() - req := httptest.NewRequest(http.MethodPost, "/api/file", buf) + req := httptest.NewRequest(http.MethodPost, "/api/file?max_views=99", buf) req.Header.Add("Content-Type", mw.FormDataContentType()) hs.Handler.ServeHTTP(rr, req) @@ -103,6 +103,10 @@ func TestHandlers(t *testing.T) { if diff := cmp.Diff(retBuf.String(), expectedData); diff != "" { t.Errorf("Retrieved file mismatch: %s", diff) } + + if retrieved.MaxViews != 99 { + t.Errorf("Uploaded file has wrong max_views: %d", retrieved.MaxViews) + } }) // GET /api/file/id t.Run("GET", func(t *testing.T) {