package actions import ( "bufio" "context" "crypto/tls" "crypto/x509" "encoding/json" "errors" "fmt" "io" "net/http" "os" "path/filepath" "syscall" "time" "gitea.benny.dog/torjus/ezshare/certs" "gitea.benny.dog/torjus/ezshare/config" "gitea.benny.dog/torjus/ezshare/pb" "gitea.benny.dog/torjus/ezshare/server" "github.com/urfave/cli/v2" "golang.org/x/term" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) func ActionClientGet(c *cli.Context) error { addr := c.String("addr") conn, err := grpc.DialContext(c.Context, addr, grpc.WithInsecure()) if err != nil { return err } defer conn.Close() client := pb.NewFileServiceClient(conn) for _, arg := range c.Args().Slice() { req := &pb.GetFileRequest{Id: arg} resp, err := client.GetFile(c.Context, req) if err != nil { return err } filename := resp.File.FileId if resp.File.Metadata.OriginalFilename != "" { filename = resp.File.Metadata.OriginalFilename } f, err := os.Create(filename) if err != nil { return err } defer f.Close() if _, err := f.Write(resp.File.Data); err != nil { return err } fmt.Printf("Wrote file '%s'\n", filename) } return nil } func ActionClientUpload(c *cli.Context) error { cfg, err := getConfig(c) if err != nil { return err } addr := cfg.Client.DefaultServer if c.IsSet("addr") { addr = c.String("addr") } clientCreds, err := cfg.Client.Creds() if err != nil { return err } conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) if err != nil { return err } defer conn.Close() client := pb.NewFileServiceClient(conn) for _, arg := range c.Args().Slice() { f, err := os.Open(arg) if err != nil { return err } data, err := io.ReadAll(f) if err != nil { return err } req := &pb.UploadFileRequest{Data: data, OriginalFilename: filepath.Base(arg)} resp, err := client.UploadFile(c.Context, req) if err != nil { return err } fmt.Printf("%s uploaded with id %s. Available at %s\n", arg, resp.Id, resp.FileUrl) } return nil } func ActionClientList(c *cli.Context) error { cfg, err := getConfig(c) if err != nil { return err } addr := cfg.Client.DefaultServer if c.IsSet("addr") { addr = c.String("addr") } clientCreds, err := cfg.Client.Creds() if err != nil { return err } conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) if err != nil { return err } defer conn.Close() client := pb.NewFileServiceClient(conn) resp, err := client.ListFiles(c.Context, &pb.ListFilesRequest{}) if err != nil { return err } for _, elem := range resp.Files { fmt.Println(elem.FileId) } return nil } func ActionClientDelete(c *cli.Context) error { cfg, err := getConfig(c) if err != nil { return err } addr := cfg.Client.DefaultServer if c.IsSet("addr") { addr = c.String("addr") } clientCreds, err := cfg.Client.Creds() if err != nil { return err } conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) if err != nil { return err } defer conn.Close() client := pb.NewFileServiceClient(conn) for _, arg := range c.Args().Slice() { _, err := client.DeleteFile(c.Context, &pb.DeleteFileRequest{Id: arg}) if err != nil { return fmt.Errorf("error deleting file: %w", err) } fmt.Printf("Deleted file %s\n", arg) } return nil } func ActionClientLogin(c *cli.Context) error { // Ensure config-dir exists if err := config.CreateDefaultConfigDir(); err != nil { return err } configFilePath, err := config.DefaultConfigFilePath() if err != nil { return err } // Check if config already exists if _, err := os.Stat(configFilePath); !errors.Is(err, os.ErrNotExist) { if err == nil { if !c.Bool("overwrite") { return cli.Exit("Config already exists. To overwrite run with --overwrite\n", 1) } } else { return err } } configDirPath := filepath.Dir(configFilePath) clientCertPath := filepath.Join(configDirPath, "client.pem") clientKeyPath := filepath.Join(configDirPath, "client.key") serverCertPath := filepath.Join(configDirPath, "server.pem") if c.Args().Len() != 1 { return cli.Exit("Need 1 argument", 1) } // Fetch server certificate ctx, cancel := context.WithTimeout(c.Context, 10*time.Second) defer cancel() certEndpoint := fmt.Sprintf("%s/%s", c.Args().First(), "server.pem") certReq, err := http.NewRequestWithContext(ctx, http.MethodGet, certEndpoint, nil) if err != nil { return cli.Exit(fmt.Sprintf("unable to create http request: %s", err), 1) } verbosePrint(c, fmt.Sprintf("fetching cert from %s", certEndpoint)) certResp, err := http.DefaultClient.Do(certReq) if err != nil { return cli.Exit(fmt.Sprintf("error fetching server cert: %s", err), 1) } defer certResp.Body.Close() serverCert, err := io.ReadAll(certResp.Body) if err != nil { return cli.Exit(fmt.Sprintf("error reading certificate from server: %s", err), 1) } // Fetch metadata mdEndpoint := fmt.Sprintf("%s/%s", c.Args().First(), "metadata") mdReq, err := http.NewRequestWithContext(ctx, http.MethodGet, mdEndpoint, nil) if err != nil { return cli.Exit(fmt.Sprintf("unable to create http request: %s", err), 1) } verbosePrint(c, fmt.Sprintf("fetching metadata from %s", mdEndpoint)) mdResp, err := http.DefaultClient.Do(mdReq) if err != nil { return cli.Exit(fmt.Sprintf("error fetching server cert: %s", err), 1) } defer mdResp.Body.Close() decoder := json.NewDecoder(mdResp.Body) var md server.MetadataResponse if err := decoder.Decode(&md); err != nil { return cli.Exit(fmt.Sprintf("unable to decode metadata response: %s", err), 1) } // Prompt for username and password scanner := bufio.NewScanner(os.Stdin) fmt.Printf("username: ") scanner.Scan() username := scanner.Text() fmt.Printf("password: ") passwordBytes, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { return cli.Exit(fmt.Sprintf("unable to read password: %s", err), 1) } password := string(passwordBytes) // Setup certificate credentials certPool := x509.NewCertPool() if !certPool.AppendCertsFromPEM(serverCert) { return cli.Exit(fmt.Sprintf("unable to use server certificate: %s", err), 1) } // Generate temporary self-signed cert keyBytes, certBytes, err := certs.GenCACert() if err != nil { return cli.Exit(fmt.Sprintf("unable to generate self-signed certificate: %s", err), 1) } keyPem, err := certs.ToPEM(keyBytes, "EC PRIVATE KEY") if err != nil { return cli.Exit(fmt.Sprintf("unable to pem-encode key: %s", err), 1) } certPem, err := certs.ToPEM(certBytes, "CERTIFICATE") if err != nil { return cli.Exit(fmt.Sprintf("unable to pem-encode key: %s", err), 1) } clientCert, err := tls.X509KeyPair(certPem, keyPem) if err != nil { return cli.Exit(fmt.Sprintf("unable to use self-signed certificate: %s", err), 1) } creds := credentials.NewTLS(&tls.Config{RootCAs: certPool, Certificates: []tls.Certificate{clientCert}}) // Connect to grpc-endpoint verbosePrint(c, fmt.Sprintf("dialing grpc at %s", md.GRPCEndpoint)) conn, err := grpc.DialContext(c.Context, md.GRPCEndpoint, grpc.WithTransportCredentials(creds)) if err != nil { return err } defer conn.Close() client := pb.NewUserServiceClient(conn) resp, err := client.Login(c.Context, &pb.LoginUserRequest{Username: username, Password: password}) if err != nil { return err } // Write key to file verbosePrint(c, fmt.Sprintf("Writing client certificate key to %s", clientKeyPath)) keyFile, err := os.Create(clientKeyPath) if err != nil { return cli.Exit(fmt.Sprintf("unable create file for client key: %s", err), 1) } defer keyFile.Close() if _, err := keyFile.Write(resp.ClientKey); err != nil { return cli.Exit(fmt.Sprintf("unable write client key to file: %s", err), 1) } // Write client cert to file verbosePrint(c, fmt.Sprintf("Writing client certificate to %s", clientKeyPath)) clientCertFile, err := os.Create(clientCertPath) if err != nil { return cli.Exit(fmt.Sprintf("unable create file for client cert: %s", err), 1) } defer clientCertFile.Close() if _, err := clientCertFile.Write(resp.ClientCert); err != nil { return cli.Exit(fmt.Sprintf("unable write client cert to file: %s", err), 1) } // Write server cer to file verbosePrint(c, fmt.Sprintf("Writing server certificate to %s", serverCertPath)) serverCertFile, err := os.Create(serverCertPath) if err != nil { return cli.Exit(fmt.Sprintf("unable create file for client key: %s", err), 1) } defer serverCertFile.Close() if _, err := serverCertFile.Write(serverCert); err != nil { return cli.Exit(fmt.Sprintf("unable write client cert to file: %s", err), 1) } // Write config cfg := config.FromDefault() cfg.Client.Certs.CertificateKeyPath = clientKeyPath cfg.Client.Certs.CertificatePath = clientCertPath cfg.Client.DefaultServer = md.GRPCEndpoint cfg.Client.ServerCertPath = serverCertPath verbosePrint(c, fmt.Sprintf("Writing config to %s", configFilePath)) f, err := os.Create(configFilePath) if err != nil { return cli.Exit(fmt.Sprintf("Unable to open config-file for writing: %s", err), 1) } defer f.Close() if err := cfg.ToWriter(f); err != nil { return cli.Exit(fmt.Sprintf("Unable write config to file: %s", err), 1) } return nil } func ActionClientChangePassword(c *cli.Context) error { cfg, err := getConfig(c) if err != nil { return err } addr := cfg.Client.DefaultServer if c.IsSet("addr") { addr = c.String("addr") } clientCreds, err := cfg.Client.Creds() if err != nil { return err } conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) if err != nil { return err } defer conn.Close() fmt.Printf("current password: ") oldPasswordBytes, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { return cli.Exit(fmt.Sprintf("unable to read password: %s", err), 1) } fmt.Println() oldPassword := string(oldPasswordBytes) fmt.Printf("new password: ") newPasswordBytes, err := term.ReadPassword(int(syscall.Stdin)) if err != nil { return cli.Exit(fmt.Sprintf("unable to read password: %s", err), 1) } fmt.Println() newPassword := string(newPasswordBytes) client := pb.NewUserServiceClient(conn) if _, err := client.ChangePassword(c.Context, &pb.ChangePasswordRequest{OldPassword: oldPassword, NewPassword: newPassword}); err != nil { return cli.Exit(fmt.Sprintf("unable to change password: %s", err), 1) } return nil } func ActionClientCertList(c *cli.Context) error { cfg, err := getConfig(c) if err != nil { return err } addr := cfg.Client.DefaultServer if c.IsSet("addr") { addr = c.String("addr") } clientCreds, err := cfg.Client.Creds() if err != nil { return err } conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) if err != nil { return err } defer conn.Close() client := pb.NewCertificateServiceClient(conn) resp, err := client.ListCertificates(c.Context, &pb.Empty{}) if err != nil { return cli.Exit(fmt.Sprintf("unable to list certificates: %s", err), 1) } for _, info := range resp.Certificates { fmt.Printf("%s - %s", info.Serial, info.OwnerUsername) } return nil } func ActionClientCertRevoke(c *cli.Context) error { if c.Args().Len() < 1 { return cli.Exit("need at least 1 argument", 1) } cfg, err := getConfig(c) if err != nil { return err } addr := cfg.Client.DefaultServer if c.IsSet("addr") { addr = c.String("addr") } clientCreds, err := cfg.Client.Creds() if err != nil { return err } conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) if err != nil { return err } defer conn.Close() client := pb.NewCertificateServiceClient(conn) for _, serial := range c.Args().Slice() { if _, err := client.RevokeCertificate(c.Context, &pb.RevokeCertificateRequest{Serial: serial}); err != nil { fmt.Printf("Revoked %s\n", serial) } } return nil }