ezshare/cmd/ezshare.go

364 lines
7.8 KiB
Go

package main
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"time"
"gitea.benny.dog/torjus/ezshare/certs"
"gitea.benny.dog/torjus/ezshare/config"
"gitea.benny.dog/torjus/ezshare/pb"
"gitea.benny.dog/torjus/ezshare/server"
"github.com/urfave/cli/v2"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// main builds the ezshare command-line application and runs it with the
// process arguments. If the selected command returns an error, the error is
// logged and the process exits with a non-zero status so that shells and
// supervisors can detect the failure.
func main() {
	app := cli.App{
		Name: "ezshare",
		Flags: []cli.Flag{
			&cli.StringFlag{
				Name:  "config",
				Usage: "Path to config-file.",
			},
		},
		Commands: []*cli.Command{
			{
				Name:  "serve",
				Usage: "Start ezshare server",
				Flags: []cli.Flag{
					&cli.BoolFlag{
						Name:  "no-grpc",
						Usage: "Do not enable grpc.",
					},
					&cli.BoolFlag{
						Name:  "no-http",
						Usage: "Do not enable http.",
					},
					&cli.StringFlag{
						Name:  "grpc-addr",
						Usage: "Address to listen for grpc.",
					},
					&cli.StringFlag{
						Name:  "http-addr",
						Usage: "Address to listen for http.",
					},
					&cli.StringFlag{
						Name:  "hostname",
						Usage: "Hostname used in links",
					},
				},
				Action: ActionServe,
			},
			{
				Name:  "client",
				Usage: "Client commands",
				Flags: []cli.Flag{
					&cli.StringFlag{
						Name:  "addr",
						Usage: "Address of server.",
					},
				},
				Subcommands: []*cli.Command{
					{
						Name:      "get",
						Usage:     "Get file with id",
						ArgsUsage: "ID [ID]..",
						Action:    ActionClientGet,
					},
					{
						Name:      "upload",
						Usage:     "Upload file(s)",
						ArgsUsage: "PATH [PATH]..",
						Action:    ActionClientUpload,
					},
					{
						Name:   "config-init",
						Usage:  "Initialize default config",
						Action: ActionInitConfig,
					},
				},
			},
			{
				Name:  "cert",
				Usage: "Certificate commands",
				Subcommands: []*cli.Command{
					{
						Name:  "gen-all",
						Usage: "Generate CA, Server and Client certificates",
						Flags: []cli.Flag{
							&cli.StringFlag{
								Name:  "out-dir",
								Usage: "Directory where certificates will be stored.",
							},
							&cli.StringFlag{
								Name:  "hostname",
								Usage: "Hostname used for server certificate.",
							},
						},
						Action: ActionGencerts,
					},
				},
			},
		},
	}

	// log.Fatalf logs the error and exits with status 1; previously the
	// process exited with status 0 even when the command failed.
	if err := app.Run(os.Args); err != nil {
		log.Fatalf("Error: %s\n", err)
	}
}
// ActionServe runs the ezshare servers until an interrupt signal is received
// or one of the servers fails.
//
// Two servers share the store built from the configuration. Each can be
// disabled with a flag:
//   - a TLS-secured gRPC file service (disabled by --no-grpc), listening on
//     the configured gRPC listen address or --grpc-addr; links it generates
//     use the configured hostname or --hostname;
//   - an HTTP server (disabled by --no-http), listening on ":8088" or
//     --http-addr.
//
// Configuration problems (unreadable certificates, invalid key pair, bad CA,
// listener failure) are returned before any server starts. ActionServe
// returns only after every started server has completed its shutdown, so the
// store is not closed while requests are still being handled. The first
// runtime failure of a server, if any, is returned.
func ActionServe(c *cli.Context) error {
	enableGRPC := !c.Bool("no-grpc")
	enableHTTP := !c.Bool("no-http")
	if !enableGRPC && !enableHTTP {
		return fmt.Errorf("both grpc and http are disabled, nothing to serve")
	}

	cfg, err := getConfig(c)
	if err != nil {
		return err
	}

	// Read certificates and build the TLS config up front so that bad
	// certificate settings are reported as an error instead of being
	// discovered (and ignored) inside the server goroutine.
	var grpcTLSConfig *tls.Config
	if enableGRPC {
		srvCertBytes, err := cfg.Server.GRPC.Certs.GetCertBytes()
		if err != nil {
			return err
		}
		srvKeyBytes, err := cfg.Server.GRPC.Certs.GetKeyBytes()
		if err != nil {
			return err
		}
		caCertBytes, err := cfg.Server.GRPC.CACerts.GetCertBytes()
		if err != nil {
			return err
		}
		grpcTLSConfig, err = newServerTLSConfig(srvCertBytes, srvKeyBytes, caCertBytes)
		if err != nil {
			return err
		}
	}

	// Setup store
	s, closeFunc, err := cfg.Server.StoreConfig.GetStore()
	if err != nil {
		return fmt.Errorf("unable to initialize store: %w", err)
	}
	defer closeFunc()

	// Cancelling rootCtx (interrupt signal or a failing server) shuts down
	// every server that was started.
	rootCtx, rootCancel := signal.NotifyContext(context.Background(), os.Interrupt)
	defer rootCancel()

	// Each server reports at most one failure; the buffer lets it do so
	// without blocking.
	serveErrs := make(chan error, 2)

	// Closed once the grpc server has fully shut down (immediately if disabled).
	grpcDone := make(chan struct{})
	if enableGRPC {
		grpcAddr := cfg.Server.GRPC.ListenAddr
		if c.IsSet("grpc-addr") {
			grpcAddr = c.String("grpc-addr")
		}
		lis, err := net.Listen("tcp", grpcAddr)
		if err != nil {
			return fmt.Errorf("unable to setup grpc listener: %w", err)
		}

		grpcFileServer := server.NewGRPCFileServiceServer(s)
		grpcFileServer.Hostname = cfg.Server.Hostname
		if c.IsSet("hostname") {
			grpcFileServer.Hostname = c.String("hostname")
		}
		grpcServer := grpc.NewServer(
			grpc.Creds(credentials.NewTLS(grpcTLSConfig)),
		)
		pb.RegisterFileServiceServer(grpcServer, grpcFileServer)

		serveDone := make(chan struct{})
		go func() {
			defer close(serveDone)
			log.Printf("Starting grpc server")
			// An error returned while shutdown is already under way (e.g.
			// GracefulStop ran before Serve started) is expected, not a failure.
			if err := grpcServer.Serve(lis); err != nil && rootCtx.Err() == nil {
				serveErrs <- fmt.Errorf("grpc server: %w", err)
				rootCancel()
			}
		}()
		go func() {
			defer close(grpcDone)
			<-rootCtx.Done()
			// GracefulStop blocks until pending RPCs are finished; completion
			// is signalled only after both it and Serve have returned.
			grpcServer.GracefulStop()
			<-serveDone
			log.Println("GRPC Shutdown")
		}()
	} else {
		close(grpcDone)
	}

	// Closed once the http server has fully shut down (immediately if disabled).
	httpDone := make(chan struct{})
	if enableHTTP {
		httpAddr := ":8088"
		if c.IsSet("http-addr") {
			httpAddr = c.String("http-addr")
		}
		httpServer := server.NewHTTPSever(s)
		httpServer.Addr = httpAddr

		serveDone := make(chan struct{})
		go func() {
			defer close(serveDone)
			log.Printf("Starting http server")
			if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
				serveErrs <- fmt.Errorf("http server: %w", err)
				rootCancel()
			}
		}()
		go func() {
			defer close(httpDone)
			<-rootCtx.Done()
			// ListenAndServe returns ErrServerClosed as soon as Shutdown is
			// called, while Shutdown keeps draining active connections. Wait
			// for Shutdown so the store is not closed under running requests.
			timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()
			if err := httpServer.Shutdown(timeoutCtx); err != nil {
				log.Printf("HTTP Server graceful shutdown failed: %s\n", err)
			}
			<-serveDone
			log.Println("HTTP Shutdown")
		}()
	} else {
		close(httpDone)
	}

	<-grpcDone
	<-httpDone

	// Every send to serveErrs happens before the corresponding done channel is
	// closed, so a non-blocking receive here sees any reported failure.
	select {
	case err := <-serveErrs:
		return err
	default:
		return nil
	}
}

// newServerTLSConfig builds the TLS configuration for the grpc server from a
// PEM-encoded server certificate, its PEM-encoded private key and PEM-encoded
// CA certificate(s). It returns an error if the key pair cannot be loaded or
// no CA certificate could be parsed.
//
// Client certificates are requested with tls.RequestClientCert and the CA pool
// is installed as ClientCAs.
// NOTE(review): with RequestClientCert, crypto/tls does not verify the client
// certificate against ClientCAs; presumably the grpc service inspects the
// peer certificate itself — confirm.
func newServerTLSConfig(certPEM, keyPEM, caPEM []byte) (*tls.Config, error) {
	srvCert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		return nil, fmt.Errorf("unable to load server certs: %w", err)
	}
	certPool := x509.NewCertPool()
	if !certPool.AppendCertsFromPEM(caPEM) {
		return nil, fmt.Errorf("unable to load CA cert")
	}
	return &tls.Config{
		Certificates: []tls.Certificate{srvCert},
		ClientAuth:   tls.RequestClientCert,
		ClientCAs:    certPool,
	}, nil
}
// ActionClientGet downloads every file whose id is given as an argument and
// writes it to the current directory.
//
// The server address comes from the client config (DefaultServer) unless
// --addr is set. The connection uses the client TLS credentials from the
// config, the same way ActionClientUpload connects; the grpc server started
// by ActionServe only accepts TLS connections.
//
// Each file is saved under its original upload name when the server reports
// one, otherwise under its id. Only the final path element of that name is
// used, so a server-supplied name cannot write outside the current directory.
// Existing files with the same name are overwritten.
func ActionClientGet(c *cli.Context) error {
	cfg, err := getConfig(c)
	if err != nil {
		return err
	}
	addr := cfg.Client.DefaultServer
	if c.IsSet("addr") {
		addr = c.String("addr")
	}
	clientCreds, err := cfg.Client.Creds()
	if err != nil {
		return err
	}
	conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds))
	if err != nil {
		return err
	}
	defer conn.Close()

	client := pb.NewFileServiceClient(conn)
	for _, id := range c.Args().Slice() {
		resp, err := client.GetFile(c.Context, &pb.GetFileRequest{Id: id})
		if err != nil {
			return fmt.Errorf("getting file %q: %w", id, err)
		}
		// The generated getters are nil-safe, so a response without file or
		// metadata does not panic.
		file := resp.GetFile()
		filename, err := localFilename(file.GetMetadata().GetOriginalFilename(), file.GetFileId())
		if err != nil {
			return fmt.Errorf("getting file %q: %w", id, err)
		}
		// os.WriteFile closes the file before returning: no handle outlives
		// its iteration and close errors (e.g. failed flush) are reported.
		if err := os.WriteFile(filename, file.GetData(), 0o666); err != nil {
			return err
		}
		fmt.Printf("Wrote file '%s'\n", filename)
	}
	return nil
}

// localFilename chooses the local name for a downloaded file: the final path
// element of originalFilename, or of fileID when the original name is empty
// or does not yield a plain file name ("." / ".." / a bare separator). It
// returns an error when neither candidate is usable.
func localFilename(originalFilename, fileID string) (string, error) {
	for _, candidate := range []string{originalFilename, fileID} {
		if candidate == "" {
			continue
		}
		name := filepath.Base(candidate)
		if name == "." || name == ".." || name == string(filepath.Separator) {
			continue
		}
		return name, nil
	}
	return "", fmt.Errorf("server returned no usable filename")
}
// ActionClientUpload uploads every file given as an argument and prints the
// id and URL the server returned for it.
//
// The server address comes from the client config (DefaultServer) unless
// --addr is set, and the connection uses the client TLS credentials from the
// config. Each file is sent under its base name. The first failure stops
// processing of the remaining arguments.
func ActionClientUpload(c *cli.Context) error {
	cfg, err := getConfig(c)
	if err != nil {
		return err
	}
	addr := cfg.Client.DefaultServer
	if c.IsSet("addr") {
		addr = c.String("addr")
	}
	clientCreds, err := cfg.Client.Creds()
	if err != nil {
		return err
	}
	conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds))
	if err != nil {
		return err
	}
	defer conn.Close()

	client := pb.NewFileServiceClient(conn)
	for _, arg := range c.Args().Slice() {
		f, err := os.Open(arg)
		if err != nil {
			return err
		}
		data, err := io.ReadAll(f)
		// Close right away instead of deferring: a defer in this loop would
		// keep every file open until all uploads are done. The handle is
		// read-only, so its Close error carries nothing worth reporting.
		f.Close()
		if err != nil {
			return err
		}
		req := &pb.UploadFileRequest{Data: data, OriginalFilename: filepath.Base(arg)}
		resp, err := client.UploadFile(c.Context, req)
		if err != nil {
			return fmt.Errorf("uploading %s: %w", arg, err)
		}
		fmt.Printf("%s uploaded with id %s. Available at %s\n", arg, resp.Id, resp.FileUrl)
	}
	return nil
}
// ActionGencerts generates a CA together with server and client certificates
// via certs.GenAllCerts.
//
// --hostname is mandatory and is the hostname used for the server
// certificate. --out-dir selects the output directory and defaults to the
// current directory.
func ActionGencerts(c *cli.Context) error {
	dir := "."
	if c.IsSet("out-dir") {
		dir = c.String("out-dir")
	}
	if c.IsSet("hostname") {
		return certs.GenAllCerts(dir, c.String("hostname"))
	}
	return fmt.Errorf("--hostname required")
}
// ActionInitConfig initializes a default configuration: it takes the
// built-in defaults from config.FromDefault and persists them with
// ToDefaultFile, returning any error from writing.
func ActionInitConfig(c *cli.Context) error {
	defaults := config.FromDefault()
	err := defaults.ToDefaultFile()
	return err
}
// getConfig loads the ezshare configuration. When --config is set, the file
// at that path is loaded; otherwise config.FromDefaultLocations is used. On
// error the returned config is nil.
//
// For an automatically located config, its location is logged to stderr so
// it does not mix with client command output on stdout. The full config is
// intentionally not printed: it holds the certificate/key settings and
// possibly other sensitive values.
func getConfig(c *cli.Context) (*config.Config, error) {
	if c.IsSet("config") {
		return config.FromFile(c.String("config"))
	}
	cfg, err := config.FromDefaultLocations()
	if err != nil {
		return nil, err
	}
	log.Printf("Config loaded from %s\n", cfg.Location())
	return cfg, nil
}