ezshare/actions/client.go
2022-01-13 18:40:15 +01:00

439 lines
12 KiB
Go

package actions
import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"syscall"
"time"
"git.t-juice.club/torjus/ezshare/certs"
"git.t-juice.club/torjus/ezshare/config"
"git.t-juice.club/torjus/ezshare/ezshare"
"git.t-juice.club/torjus/ezshare/pb"
"git.t-juice.club/torjus/ezshare/server"
"github.com/urfave/cli/v2"
"golang.org/x/term"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/protobuf/types/known/timestamppb"
)
// contextKey is an unexported type for keys stored in a context.Context, so
// the values set here cannot collide with keys defined by other packages.
type contextKey string

// Context keys populated by BeforeClient.
const (
	// contextKeyConfig holds the loaded *config.Config.
	contextKeyConfig = contextKey("config")
	// contextKeyClientConnFunc holds a func() (*grpc.ClientConn, error) that
	// dials the configured server.
	contextKeyClientConnFunc = contextKey("clientConnFunc")
)
// ActionClientGet downloads every file ID passed as an argument and writes it
// into the current working directory. The local name is the original filename
// recorded in the file's metadata, falling back to the file ID when that is
// empty. It stops at the first RPC or write error.
func ActionClientGet(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	client := pb.NewFileServiceClient(conn)
	for _, arg := range c.Args().Slice() {
		resp, err := client.GetFile(c.Context, &pb.GetFileRequest{Id: arg})
		if err != nil {
			return err
		}
		// Getters are nil-safe, so a response without File/Metadata does not
		// panic; it just yields empty values.
		file := resp.GetFile()
		filename := file.GetFileId()
		if original := file.GetMetadata().GetOriginalFilename(); original != "" {
			filename = original
		}
		// The name comes from the server, so strip any directory components
		// to keep the write inside the working directory (uploads only ever
		// record filepath.Base of the local path anyway).
		filename = filepath.Base(filename)
		if filename == "." || filename == ".." || filename == string(filepath.Separator) {
			return fmt.Errorf("refusing to write file %s: invalid filename %q", arg, filename)
		}
		// os.WriteFile closes the file before the next iteration and reports
		// Close errors, unlike a deferred Close inside the loop.
		if err := os.WriteFile(filename, file.GetData(), 0o666); err != nil {
			return err
		}
		fmt.Printf("Wrote file '%s'\n", filename)
	}
	return nil
}
// ActionClientUpload uploads every local path given as an argument and prints
// the resulting file ID and URL. The --with-passcode and --max-views flags are
// forwarded as-is. When --ttl is set, the file expires that long after the
// moment it is uploaded. It stops at the first read or RPC error.
func ActionClientUpload(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	client := pb.NewFileServiceClient(conn)
	for _, arg := range c.Args().Slice() {
		// os.ReadFile opens and closes the file itself, so no descriptor is
		// leaked per uploaded argument.
		data, err := os.ReadFile(arg)
		if err != nil {
			return err
		}
		req := &pb.UploadFileRequest{
			Data:             data,
			OriginalFilename: filepath.Base(arg),
			WithPasscode:     c.Bool("with-passcode"),
			MaxViews:         c.Int64("max-views"),
		}
		if c.IsSet("ttl") {
			req.ExpiresOn = timestamppb.New(time.Now().Add(c.Duration("ttl")))
		}
		resp, err := client.UploadFile(c.Context, req)
		if err != nil {
			return err
		}
		fmt.Printf("%s uploaded with id %s. Available at:\n%s\n", arg, resp.Id, resp.FileUrl)
	}
	return nil
}
// ActionClientList calls the server's ListFiles RPC and prints the ID of
// every returned file, one per line.
func ActionClientList(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	listing, err := pb.NewFileServiceClient(conn).ListFiles(c.Context, &pb.ListFilesRequest{})
	if err != nil {
		return err
	}
	for i := range listing.Files {
		fmt.Println(listing.Files[i].FileId)
	}
	return nil
}
// ActionClientDelete deletes every file ID given as an argument, in order,
// printing a confirmation for each. It stops at the first failure and returns
// that error wrapped with context.
func ActionClientDelete(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	files := pb.NewFileServiceClient(conn)
	ids := c.Args().Slice()
	for i := 0; i < len(ids); i++ {
		id := ids[i]
		req := &pb.DeleteFileRequest{Id: id}
		if _, err := files.DeleteFile(c.Context, req); err != nil {
			return fmt.Errorf("error deleting file: %w", err)
		}
		fmt.Printf("Deleted file %s\n", id)
	}
	return nil
}
// ActionClientLogin bootstraps a client configuration from the server URL
// given as the single argument. It:
//   - fetches <url>/server.pem (used as the only trusted root for gRPC) and
//     <url>/metadata (which names the gRPC endpoint),
//   - prompts for a username and password on the terminal,
//   - connects to the gRPC endpoint with a throwaway self-signed client
//     certificate and calls UserService.Login,
//   - writes the returned client key and certificate, the server certificate
//     and a config file referencing them into the default config directory.
//
// An existing config file is only replaced when --overwrite is set.
func ActionClientLogin(c *cli.Context) error {
	// Ensure config-dir exists
	if err := config.CreateDefaultConfigDir(); err != nil {
		return err
	}
	configFilePath, err := config.DefaultConfigFilePath()
	if err != nil {
		return err
	}
	// Check if config already exists
	if _, err := os.Stat(configFilePath); err == nil {
		if !c.Bool("overwrite") {
			return cli.Exit("Config already exists. To overwrite run with --overwrite\n", 1)
		}
	} else if !errors.Is(err, os.ErrNotExist) {
		return err
	}
	configDirPath := filepath.Dir(configFilePath)
	clientCertPath := filepath.Join(configDirPath, "client.pem")
	clientKeyPath := filepath.Join(configDirPath, "client.key")
	serverCertPath := filepath.Join(configDirPath, "server.pem")

	if c.Args().Len() != 1 {
		return cli.Exit("Need 1 argument", 1)
	}
	serverURL := c.Args().First()

	// The timeout bounds the two HTTP requests; ctx is not used afterwards.
	ctx, cancel := context.WithTimeout(c.Context, 10*time.Second)
	defer cancel()

	// Fetch server certificate
	certEndpoint := fmt.Sprintf("%s/%s", serverURL, "server.pem")
	verbosePrint(c, fmt.Sprintf("fetching cert from %s", certEndpoint))
	serverCert, err := loginFetch(ctx, certEndpoint)
	if err != nil {
		return cli.Exit(fmt.Sprintf("error fetching server cert: %s", err), 1)
	}

	// Fetch metadata
	mdEndpoint := fmt.Sprintf("%s/%s", serverURL, "metadata")
	verbosePrint(c, fmt.Sprintf("fetching metadata from %s", mdEndpoint))
	mdBody, err := loginFetch(ctx, mdEndpoint)
	if err != nil {
		return cli.Exit(fmt.Sprintf("error fetching metadata: %s", err), 1)
	}
	var md server.MetadataResponse
	if err := json.Unmarshal(mdBody, &md); err != nil {
		return cli.Exit(fmt.Sprintf("unable to decode metadata response: %s", err), 1)
	}

	// Prompt for username and password
	scanner := bufio.NewScanner(os.Stdin)
	fmt.Printf("username: ")
	if !scanner.Scan() {
		scanErr := scanner.Err()
		if scanErr == nil {
			// Scan returns false with a nil Err at end of input.
			scanErr = io.ErrUnexpectedEOF
		}
		return cli.Exit(fmt.Sprintf("unable to read username: %s", scanErr), 1)
	}
	username := scanner.Text()
	fmt.Printf("password: ")
	passwordBytes, err := term.ReadPassword(int(syscall.Stdin))
	if err != nil {
		return cli.Exit(fmt.Sprintf("unable to read password: %s", err), 1)
	}
	// Echo is off while reading, so terminate the prompt line ourselves.
	fmt.Println()
	password := string(passwordBytes)

	// Setup certificate credentials: trust only the fetched server cert.
	certPool := x509.NewCertPool()
	if !certPool.AppendCertsFromPEM(serverCert) {
		return cli.Exit("unable to use server certificate: no valid PEM certificate found", 1)
	}
	// Generate temporary self-signed cert to present during login.
	keyBytes, certBytes, err := certs.GenCACert()
	if err != nil {
		return cli.Exit(fmt.Sprintf("unable to generate self-signed certificate: %s", err), 1)
	}
	keyPem, err := certs.ToPEM(keyBytes, "EC PRIVATE KEY")
	if err != nil {
		return cli.Exit(fmt.Sprintf("unable to pem-encode key: %s", err), 1)
	}
	certPem, err := certs.ToPEM(certBytes, "CERTIFICATE")
	if err != nil {
		return cli.Exit(fmt.Sprintf("unable to pem-encode certificate: %s", err), 1)
	}
	clientCert, err := tls.X509KeyPair(certPem, keyPem)
	if err != nil {
		return cli.Exit(fmt.Sprintf("unable to use self-signed certificate: %s", err), 1)
	}
	creds := credentials.NewTLS(&tls.Config{RootCAs: certPool, Certificates: []tls.Certificate{clientCert}})

	// Connect to grpc-endpoint
	verbosePrint(c, fmt.Sprintf("dialing grpc at %s", md.GRPCEndpoint))
	conn, err := grpc.DialContext(c.Context, md.GRPCEndpoint, grpc.WithTransportCredentials(creds))
	if err != nil {
		return err
	}
	defer conn.Close()
	client := pb.NewUserServiceClient(conn)
	resp, err := client.Login(c.Context, &pb.LoginUserRequest{Username: username, Password: password})
	if err != nil {
		return err
	}

	// Write key to file. Remove any key left by an earlier login first:
	// OpenFile's permission bits only apply when the file is created, and the
	// private key must not be readable by other users.
	verbosePrint(c, fmt.Sprintf("Writing client certificate key to %s", clientKeyPath))
	if err := os.Remove(clientKeyPath); err != nil && !errors.Is(err, os.ErrNotExist) {
		return cli.Exit(fmt.Sprintf("unable to remove old client key: %s", err), 1)
	}
	if err := writeLoginFile(clientKeyPath, "client key", resp.ClientKey, 0o600); err != nil {
		return err
	}
	// Write client cert to file
	verbosePrint(c, fmt.Sprintf("Writing client certificate to %s", clientCertPath))
	if err := writeLoginFile(clientCertPath, "client cert", resp.ClientCert, 0o666); err != nil {
		return err
	}
	// Write server cert to file
	verbosePrint(c, fmt.Sprintf("Writing server certificate to %s", serverCertPath))
	if err := writeLoginFile(serverCertPath, "server cert", serverCert, 0o666); err != nil {
		return err
	}

	// Write config
	cfg := config.FromDefault()
	cfg.Client.Certs.CertificateKeyPath = clientKeyPath
	cfg.Client.Certs.CertificatePath = clientCertPath
	cfg.Client.DefaultServer = md.GRPCEndpoint
	cfg.Client.ServerCertPath = serverCertPath
	verbosePrint(c, fmt.Sprintf("Writing config to %s", configFilePath))
	f, err := os.Create(configFilePath)
	if err != nil {
		return cli.Exit(fmt.Sprintf("Unable to open config-file for writing: %s", err), 1)
	}
	if err := cfg.ToWriter(f); err != nil {
		f.Close() // the ToWriter error is the one worth reporting
		return cli.Exit(fmt.Sprintf("Unable write config to file: %s", err), 1)
	}
	// A failed Close can mean the config was not fully written.
	if err := f.Close(); err != nil {
		return cli.Exit(fmt.Sprintf("Unable write config to file: %s", err), 1)
	}
	return nil
}

// loginFetch GETs endpoint using ctx and returns the complete response body.
// Any status other than 200 OK is an error, so an HTTP error page is never
// mistaken for a certificate or a metadata document.
func loginFetch(ctx context.Context, endpoint string) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("unable to create http request: %w", err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected response status %q", resp.Status)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("unable to read response body: %w", err)
	}
	return body, nil
}

// writeLoginFile creates (or truncates) path, writes data and closes it.
// perm only takes effect when the file is newly created. Failures are
// returned as cli exit errors that name the file by what. Close errors are
// reported too, since they can mean the data never reached disk.
func writeLoginFile(path, what string, data []byte, perm os.FileMode) error {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return cli.Exit(fmt.Sprintf("unable create file for %s: %s", what, err), 1)
	}
	if _, err := f.Write(data); err != nil {
		f.Close() // the write error is the one worth reporting
		return cli.Exit(fmt.Sprintf("unable write %s to file: %s", what, err), 1)
	}
	if err := f.Close(); err != nil {
		return cli.Exit(fmt.Sprintf("unable write %s to file: %s", what, err), 1)
	}
	return nil
}
// ActionClientPassword asks on the terminal for the current and the new
// password (without echo) and sends them to the server's ChangePassword RPC.
// Read or RPC failures are reported as cli exit errors with status 1.
func ActionClientPassword(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()

	// readSecret shows label, reads a line without echo, then ends the
	// prompt line (the user's Enter is not echoed either).
	readSecret := func(label string) (string, error) {
		fmt.Print(label)
		raw, readErr := term.ReadPassword(int(syscall.Stdin))
		if readErr != nil {
			return "", cli.Exit(fmt.Sprintf("unable to read password: %s", readErr), 1)
		}
		fmt.Println()
		return string(raw), nil
	}

	oldPassword, err := readSecret("current password: ")
	if err != nil {
		return err
	}
	newPassword, err := readSecret("new password: ")
	if err != nil {
		return err
	}

	req := &pb.ChangePasswordRequest{OldPassword: oldPassword, NewPassword: newPassword}
	if _, err := pb.NewUserServiceClient(conn).ChangePassword(c.Context, req); err != nil {
		return cli.Exit(fmt.Sprintf("unable to change password: %s", err), 1)
	}
	return nil
}
// ActionClientUpdate asks the server for the latest released version. If it
// differs from the running ezshare.Version, it downloads the binary for the
// current OS/architecture into --out-dir (default "."). The file is named
// ezshare-<version>-<os>-<arch>, with ".exe" appended on Windows.
func ActionClientUpdate(c *cli.Context) error {
	conn, err := connFromContext(c)
	if err != nil {
		return err
	}
	defer conn.Close()
	client := pb.NewBinaryServiceClient(conn)

	// Check if we have the latest version
	resp, err := client.GetLatestVersion(c.Context, &pb.Empty{})
	if err != nil {
		return cli.Exit(fmt.Sprintf("Error getting latest version: %s", err), 1)
	}
	if resp.Version == ezshare.Version {
		fmt.Println("Already running latest version.")
		return nil
	}

	// Fetch latest version
	bin, err := client.GetBinary(c.Context, &pb.GetBinaryRequest{Version: "latest", Arch: runtime.GOARCH, Os: runtime.GOOS})
	if err != nil {
		return cli.Exit(fmt.Sprintf("Error getting binary: %s", err), 1)
	}
	// The first character of the version is dropped for the file name
	// (presumably a "v" prefix — confirm). Guard the slice so an empty
	// version is reported instead of panicking.
	if bin.Version == "" {
		return cli.Exit("Error getting binary: server returned an empty version", 1)
	}

	outDir := "."
	if c.IsSet("out-dir") {
		outDir = c.String("out-dir")
	}
	filename := fmt.Sprintf("ezshare-%s-%s-%s", bin.Version[1:], bin.Os, bin.Arch)
	if runtime.GOOS == "windows" {
		filename = fmt.Sprintf("%s.exe", filename)
	}
	outputPath := filepath.Join(outDir, filename)

	// os.WriteFile closes the file and reports Close errors, so a truncated
	// binary is not reported as success. 0o755 makes a newly created binary
	// directly executable on Unix-like systems.
	if err := os.WriteFile(outputPath, bin.Data, 0o755); err != nil {
		return cli.Exit(fmt.Sprintf("Unable to write latest binary: %s", err), 1)
	}
	fmt.Printf("Wrote latest binary to %s\n", outputPath)
	return nil
}
// BeforeClient is the Before hook for client commands. It loads the client
// config and, on success, stores it in c.Context under contextKeyConfig. When
// client credentials can also be built from that config, it stores a dialing
// function under contextKeyClientConnFunc for connFromContext to call. That
// function targets --addr if set, otherwise the config's default server.
//
// Failures are intentionally swallowed and the hook always returns nil, so
// commands that run without a usable config are not blocked here.
// NOTE(review): presumably this is what lets login run before any config
// exists — confirm.
func BeforeClient(c *cli.Context) error {
	cfg, err := getConfig(c)
	if err != nil {
		return nil
	}
	c.Context = context.WithValue(c.Context, contextKeyConfig, cfg)

	var addr string
	if c.IsSet("addr") {
		addr = c.String("addr")
	} else {
		addr = cfg.Client.DefaultServer
	}

	creds, err := cfg.Client.Creds()
	if err != nil {
		return nil
	}

	dial := func() (*grpc.ClientConn, error) {
		// c.Context is read at dial time, not captured here.
		return grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(creds))
	}
	c.Context = context.WithValue(c.Context, contextKeyClientConnFunc, dial)
	return nil
}
// connFromContext dials the server using the function BeforeClient stored in
// c.Context. BeforeClient skips storing it when the config or the client
// credentials could not be loaded. In that case an error is returned instead
// of panicking on a failed type assertion.
func connFromContext(c *cli.Context) (*grpc.ClientConn, error) {
	connFunc, ok := c.Context.Value(contextKeyClientConnFunc).(func() (*grpc.ClientConn, error))
	if !ok || connFunc == nil {
		return nil, errors.New("no client connection configured: unable to load client config or credentials")
	}
	return connFunc()
}
// configFromContext returns the config cached in c.Context by BeforeClient.
// If none is cached, it falls back to loading the config via getConfig.
func configFromContext(c *cli.Context) (*config.Config, error) {
	if cached, found := c.Context.Value(contextKeyConfig).(*config.Config); found {
		return cached, nil
	}
	return getConfig(c)
}