package actions

import (
	"bufio"
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"syscall"
	"time"

	"gitea.benny.dog/torjus/ezshare/certs"
	"gitea.benny.dog/torjus/ezshare/config"
	"gitea.benny.dog/torjus/ezshare/ezshare"
	"gitea.benny.dog/torjus/ezshare/pb"
	"gitea.benny.dog/torjus/ezshare/server"
	"github.com/urfave/cli/v2"
	"golang.org/x/term"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/protobuf/types/known/timestamppb"
)

// contextKey is the type of the keys this package stores in the cli
// context. Using an unexported type keeps them from colliding with keys set
// by other packages.
type contextKey string

const (
	// contextKeyConfig holds the *config.Config loaded by BeforeClient.
	contextKeyConfig contextKey = "config"
	// contextKeyClientConnFunc holds the func() (*grpc.ClientConn, error)
	// installed by BeforeClient and invoked by connFromContext.
	contextKeyClientConnFunc contextKey = "clientConnFunc"
)

// ActionClientGet downloads every file id given as an argument and writes
// it to the current directory. The file is named after the original filename
// reported by the server, or after the file id when no original filename is
// set. Existing files with the same name are overwritten.
func ActionClientGet(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	client := pb.NewFileServiceClient(conn)
	for _, id := range c.Args().Slice() {
		resp, err := client.GetFile(c.Context, &pb.GetFileRequest{Id: id})
		if err != nil {
			return err
		}

		// The getters are nil-safe, so a response without file or metadata
		// falls through to the "unusable filename" error instead of panicking.
		file := resp.GetFile()
		filename, err := localFilename(file.GetFileId(), file.GetMetadata().GetOriginalFilename())
		if err != nil {
			return err
		}

		// os.WriteFile creates/truncates with the same 0666 (pre-umask) mode
		// os.Create uses, and closes the file before the next iteration.
		if err := os.WriteFile(filename, file.GetData(), 0o666); err != nil {
			return err
		}
		fmt.Printf("Wrote file '%s'\n", filename)
	}
	return nil
}

// localFilename picks the local name for a downloaded file: originalFilename
// when it is non-empty, otherwise fileID. Both values come from the server,
// so only the final path element is kept. This prevents a name such as
// "../../.bashrc" from writing outside the current directory.
func localFilename(fileID, originalFilename string) (string, error) {
	name := fileID
	if originalFilename != "" {
		name = originalFilename
	}

	base := filepath.Base(name)
	switch base {
	case ".", "..", string(filepath.Separator):
		return "", fmt.Errorf("server returned unusable filename %q", name)
	}
	return base, nil
}

// ActionClientUpload uploads every file path given as an argument and prints
// the id and URL the server assigned to it. The flags --with-passcode,
// --max-views and --ttl are forwarded with each upload request.
func ActionClientUpload(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	client := pb.NewFileServiceClient(conn)
	for _, filename := range c.Args().Slice() {
		// os.ReadFile closes the file itself, so nothing leaks across iterations.
		data, err := os.ReadFile(filename)
		if err != nil {
			return err
		}

		req := &pb.UploadFileRequest{
			Data:             data,
			OriginalFilename: filepath.Base(filename),
			WithPasscode:     c.Bool("with-passcode"),
			MaxViews:         c.Int64("max-views"),
		}
		if c.IsSet("ttl") {
			req.ExpiresOn = timestamppb.New(time.Now().Add(c.Duration("ttl")))
		}

		resp, err := client.UploadFile(c.Context, req)
		if err != nil {
			return err
		}
		fmt.Printf("%s uploaded with id %s. Available at:\n%s\n", filename, resp.Id, resp.FileUrl)
	}
	return nil
}

// ActionClientList prints the id of every file returned by the server's
// ListFiles call, one per line.
func ActionClientList(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	client := pb.NewFileServiceClient(conn)
	resp, err := client.ListFiles(c.Context, &pb.ListFilesRequest{})
	if err != nil {
		return err
	}
	for _, elem := range resp.Files {
		fmt.Println(elem.FileId)
	}
	return nil
}

// ActionClientDelete deletes every file id given as an argument. It stops at
// the first failure.
func ActionClientDelete(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	client := pb.NewFileServiceClient(conn)
	for _, arg := range c.Args().Slice() {
		if _, err := client.DeleteFile(c.Context, &pb.DeleteFileRequest{Id: arg}); err != nil {
			return fmt.Errorf("error deleting file: %w", err)
		}
		fmt.Printf("Deleted file %s\n", arg)
	}
	return nil
}

// ActionClientLogin bootstraps the client configuration from a server.
//
// It expects exactly one argument, the server's HTTP base URL, and proceeds
// as follows:
//   - Fetches "<url>/server.pem" (the server certificate) and
//     "<url>/metadata" (which advertises the gRPC endpoint).
//   - Prompts for a username and password.
//   - Calls Login over TLS against that endpoint.
//   - Writes the returned client key and certificate, the server certificate
//     and a config pointing at them into the default config directory.
//
// An existing config is only replaced when --overwrite is set.
func ActionClientLogin(c *cli.Context) error {
	// Validate arguments before touching the filesystem.
	if c.Args().Len() != 1 {
		return cli.Exit("Need 1 argument", 1)
	}
	baseURL := c.Args().First()

	// Ensure config-dir exists
	if err := config.CreateDefaultConfigDir(); err != nil {
		return err
	}
	configFilePath, err := config.DefaultConfigFilePath()
	if err != nil {
		return err
	}

	// Refuse to clobber an existing config unless explicitly asked to.
	switch _, err := os.Stat(configFilePath); {
	case err == nil:
		if !c.Bool("overwrite") {
			return cli.Exit("Config already exists. To overwrite run with --overwrite\n", 1)
		}
	case !errors.Is(err, os.ErrNotExist):
		return err
	}

	configDirPath := filepath.Dir(configFilePath)
	clientCertPath := filepath.Join(configDirPath, "client.pem")
	clientKeyPath := filepath.Join(configDirPath, "client.key")
	serverCertPath := filepath.Join(configDirPath, "server.pem")

	// Both HTTP requests share a single 10 second budget.
	httpCtx, cancel := context.WithTimeout(c.Context, 10*time.Second)
	defer cancel()

	// Fetch server certificate
	certEndpoint := fmt.Sprintf("%s/%s", baseURL, "server.pem")
	verbosePrint(c, fmt.Sprintf("fetching cert from %s", certEndpoint))
	serverCert, err := httpGet(httpCtx, certEndpoint)
	if err != nil {
		return cli.Exit(fmt.Sprintf("error fetching server cert: %s", err), 1)
	}

	// Fetch metadata
	mdEndpoint := fmt.Sprintf("%s/%s", baseURL, "metadata")
	verbosePrint(c, fmt.Sprintf("fetching metadata from %s", mdEndpoint))
	mdBody, err := httpGet(httpCtx, mdEndpoint)
	if err != nil {
		return cli.Exit(fmt.Sprintf("error fetching metadata: %s", err), 1)
	}
	var md server.MetadataResponse
	if err := json.Unmarshal(mdBody, &md); err != nil {
		return cli.Exit(fmt.Sprintf("unable to decode metadata response: %s", err), 1)
	}

	// Prompt for username and password
	username, err := promptUsername()
	if err != nil {
		return cli.Exit(err.Error(), 1)
	}
	password, err := readPassword("password: ")
	if err != nil {
		return cli.Exit(err.Error(), 1)
	}

	// Trust only the certificate fetched above for the gRPC connection.
	certPool := x509.NewCertPool()
	if !certPool.AppendCertsFromPEM(serverCert) {
		return cli.Exit("unable to use server certificate: no usable PEM certificate found", 1)
	}
	clientCert, err := selfSignedClientCert()
	if err != nil {
		return cli.Exit(err.Error(), 1)
	}
	creds := credentials.NewTLS(&tls.Config{RootCAs: certPool, Certificates: []tls.Certificate{clientCert}})

	// Connect to grpc-endpoint
	verbosePrint(c, fmt.Sprintf("dialing grpc at %s", md.GRPCEndpoint))
	conn, err := grpc.DialContext(c.Context, md.GRPCEndpoint, grpc.WithTransportCredentials(creds))
	if err != nil {
		return err
	}
	defer conn.Close()

	client := pb.NewUserServiceClient(conn)
	resp, err := client.Login(c.Context, &pb.LoginUserRequest{Username: username, Password: password})
	if err != nil {
		return err
	}

	// Persist the credentials. The private key must only be readable by its
	// owner; the certificates are public.
	outputs := []struct {
		desc string
		path string
		data []byte
		perm os.FileMode
	}{
		{"client certificate key", clientKeyPath, resp.ClientKey, 0o600},
		{"client certificate", clientCertPath, resp.ClientCert, 0o644},
		{"server certificate", serverCertPath, serverCert, 0o644},
	}
	for _, out := range outputs {
		verbosePrint(c, fmt.Sprintf("Writing %s to %s", out.desc, out.path))
		if err := writeFileWithMode(out.path, out.data, out.perm); err != nil {
			return cli.Exit(fmt.Sprintf("unable to write %s to file: %s", out.desc, err), 1)
		}
	}

	// Write config
	cfg := config.FromDefault()
	cfg.Client.Certs.CertificateKeyPath = clientKeyPath
	cfg.Client.Certs.CertificatePath = clientCertPath
	cfg.Client.DefaultServer = md.GRPCEndpoint
	cfg.Client.ServerCertPath = serverCertPath

	verbosePrint(c, fmt.Sprintf("Writing config to %s", configFilePath))
	f, err := os.Create(configFilePath)
	if err != nil {
		return cli.Exit(fmt.Sprintf("Unable to open config-file for writing: %s", err), 1)
	}
	if err := cfg.ToWriter(f); err != nil {
		f.Close() // the write error is the one worth reporting
		return cli.Exit(fmt.Sprintf("Unable to write config to file: %s", err), 1)
	}
	// Close can surface delayed write errors, so it is checked explicitly.
	if err := f.Close(); err != nil {
		return cli.Exit(fmt.Sprintf("Unable to write config to file: %s", err), 1)
	}
	return nil
}

// httpGet performs a GET request for endpoint bound to ctx and returns the
// full response body. Responses outside the 2xx range are returned as errors,
// so an error page is never mistaken for the requested document.
func httpGet(ctx context.Context, endpoint string) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("unable to create http request: %w", err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("unexpected status %q from %s", resp.Status, endpoint)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("unable to read response body: %w", err)
	}
	return body, nil
}

// promptUsername prints a prompt and reads one line from stdin. End of input
// without a line is reported as an error rather than as an empty username.
func promptUsername() (string, error) {
	fmt.Print("username: ")
	scanner := bufio.NewScanner(os.Stdin)
	if scanner.Scan() {
		return scanner.Text(), nil
	}
	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("unable to read username: %w", err)
	}
	return "", errors.New("unable to read username: unexpected end of input")
}

// readPassword prints prompt and reads a line from the terminal without
// echoing it. A newline is printed afterwards because the user's Enter key
// is not echoed either.
func readPassword(prompt string) (string, error) {
	fmt.Print(prompt)
	passwordBytes, err := term.ReadPassword(int(syscall.Stdin))
	fmt.Println()
	if err != nil {
		return "", fmt.Errorf("unable to read password: %w", err)
	}
	return string(passwordBytes), nil
}

// selfSignedClientCert generates a throwaway key pair (via certs.GenCACert)
// that is presented as the client certificate during login.
// NOTE(review): presumably needed because the server's TLS setup requests a
// client certificate before one has been issued; confirm.
func selfSignedClientCert() (tls.Certificate, error) {
	keyBytes, certBytes, err := certs.GenCACert()
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("unable to generate self-signed certificate: %w", err)
	}
	keyPem, err := certs.ToPEM(keyBytes, "EC PRIVATE KEY")
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("unable to pem-encode key: %w", err)
	}
	certPem, err := certs.ToPEM(certBytes, "CERTIFICATE")
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("unable to pem-encode certificate: %w", err)
	}
	cert, err := tls.X509KeyPair(certPem, keyPem)
	if err != nil {
		return tls.Certificate{}, fmt.Errorf("unable to use self-signed certificate: %w", err)
	}
	return cert, nil
}

// writeFileWithMode creates or truncates path, sets its permission bits to
// perm, and only then writes data. Unlike os.WriteFile, perm is also applied
// when the file already exists (e.g. when logging in again with
// --overwrite), and it is applied before any data is written.
func writeFileWithMode(path string, data []byte, perm os.FileMode) error {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	if err := f.Chmod(perm); err != nil {
		f.Close()
		return err
	}
	if _, err := f.Write(data); err != nil {
		f.Close()
		return err
	}
	return f.Close()
}

// ActionClientPassword prompts for the current and a new password and asks
// the server to change the logged-in user's password.
func ActionClientPassword(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	oldPassword, err := readPassword("current password: ")
	if err != nil {
		return cli.Exit(err.Error(), 1)
	}
	newPassword, err := readPassword("new password: ")
	if err != nil {
		return cli.Exit(err.Error(), 1)
	}

	client := pb.NewUserServiceClient(conn)
	req := &pb.ChangePasswordRequest{OldPassword: oldPassword, NewPassword: newPassword}
	if _, err := client.ChangePassword(c.Context, req); err != nil {
		return cli.Exit(fmt.Sprintf("unable to change password: %s", err), 1)
	}
	return nil
}

// ActionClientUpdate asks the server for the latest version. If it differs
// from the running ezshare.Version, it downloads the binary for the current
// OS/architecture into --out-dir (default: the current directory).
func ActionClientUpdate(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	client := pb.NewBinaryServiceClient(conn)

	// Check if we have the latest version
	latest, err := client.GetLatestVersion(c.Context, &pb.Empty{})
	if err != nil {
		return cli.Exit(fmt.Sprintf("Error getting latest version: %s", err), 1)
	}
	if latest.Version == ezshare.Version {
		fmt.Println("Already running latest version.")
		return nil
	}

	// Fetch latest version
	bin, err := client.GetBinary(c.Context, &pb.GetBinaryRequest{Version: "latest", Arch: runtime.GOARCH, Os: runtime.GOOS})
	if err != nil {
		return cli.Exit(fmt.Sprintf("Error getting binary: %s", err), 1)
	}

	outDir := "."
	if c.IsSet("out-dir") {
		outDir = c.String("out-dir")
	}

	// Versions are presumably tagged like "v1.2.3"; the "v" is dropped from
	// the file name. TrimPrefix is safe on empty or unprefixed versions.
	filename := fmt.Sprintf("ezshare-%s-%s-%s", strings.TrimPrefix(bin.Version, "v"), bin.Os, bin.Arch)
	if runtime.GOOS == "windows" {
		filename += ".exe"
	}
	// Version, Os and Arch come from the server; refuse names that would
	// escape outDir.
	if filepath.Base(filename) != filename {
		return cli.Exit(fmt.Sprintf("Server returned an unusable binary name %q", filename), 1)
	}

	// Make the binary executable so it can be run directly; Windows ignores
	// the exec bits.
	outputPath := filepath.Join(outDir, filename)
	if err := os.WriteFile(outputPath, bin.Data, 0o755); err != nil {
		return cli.Exit(fmt.Sprintf("Unable to write latest binary: %s", err), 1)
	}
	fmt.Printf("Wrote latest binary to %s\n", outputPath)
	return nil
}

// BeforeClient loads the client config and stores it in c.Context. It also
// stores a function that dials the configured server (or --addr) using the
// config's client credentials.
//
// Failing to load the config or credentials is not fatal here. Presumably
// this is so that commands which create the config (such as login) still
// work. The error is instead returned by connFromContext when a command
// actually needs a connection.
func BeforeClient(c *cli.Context) error {
	connFunc, err := newClientConnFunc(c)
	if err != nil {
		setupErr := err
		connFunc = func() (*grpc.ClientConn, error) { return nil, setupErr }
	}
	c.Context = context.WithValue(c.Context, contextKeyClientConnFunc, connFunc)
	return nil
}

// newClientConnFunc loads the config, stores it under contextKeyConfig and
// returns a function that dials the target server with the client's TLS
// credentials. The returned function reads c.Context when it is called, not
// when it is created.
func newClientConnFunc(c *cli.Context) (func() (*grpc.ClientConn, error), error) {
	cfg, err := getConfig(c)
	if err != nil {
		return nil, fmt.Errorf("unable to load config: %w", err)
	}
	c.Context = context.WithValue(c.Context, contextKeyConfig, cfg)

	addr := cfg.Client.DefaultServer
	if c.IsSet("addr") {
		addr = c.String("addr")
	}

	clientCreds, err := cfg.Client.Creds()
	if err != nil {
		return nil, fmt.Errorf("unable to load client credentials: %w", err)
	}

	return func() (*grpc.ClientConn, error) {
		return grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds))
	}, nil
}

// connFromContext dials the server using the function installed by
// BeforeClient. It returns an error instead of panicking when none was
// installed.
func connFromContext(c *cli.Context) (*grpc.ClientConn, error) {
	connFunc, ok := c.Context.Value(contextKeyClientConnFunc).(func() (*grpc.ClientConn, error))
	if !ok || connFunc == nil {
		return nil, errors.New("client connection not initialized")
	}
	return connFunc()
}

// configFromContext returns the config stored by BeforeClient. It falls back
// to getConfig when none is stored in the context.
func configFromContext(c *cli.Context) (*config.Config, error) {
	cfg, ok := c.Context.Value(contextKeyConfig).(*config.Config)
	if !ok {
		return getConfig(c)
	}
	return cfg, nil
}