diff --git a/actions/client.go b/actions/client.go index 7b811b3..a5ed700 100644 --- a/actions/client.go +++ b/actions/client.go @@ -28,9 +28,15 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +type contextKey string + +const ( + contextKeyConfig contextKey = "config" + contextKeyClientConnFunc contextKey = "clientConnFunc" +) + func ActionClientGet(c *cli.Context) error { - addr := c.String("addr") - conn, err := grpc.DialContext(c.Context, addr, grpc.WithInsecure()) + conn, err := connFromContext(c) if err != nil { return err } @@ -65,22 +71,7 @@ func ActionClientGet(c *cli.Context) error { } func ActionClientUpload(c *cli.Context) error { - cfg, err := getConfig(c) - if err != nil { - return err - } - - addr := cfg.Client.DefaultServer - if c.IsSet("addr") { - addr = c.String("addr") - } - - clientCreds, err := cfg.Client.Creds() - if err != nil { - return err - } - conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) - + conn, err := connFromContext(c) if err != nil { return err } @@ -118,22 +109,7 @@ func ActionClientUpload(c *cli.Context) error { } func ActionClientList(c *cli.Context) error { - cfg, err := getConfig(c) - if err != nil { - return err - } - - addr := cfg.Client.DefaultServer - if c.IsSet("addr") { - addr = c.String("addr") - } - - clientCreds, err := cfg.Client.Creds() - if err != nil { - return err - } - conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) - + conn, err := connFromContext(c) if err != nil { return err } @@ -153,22 +129,7 @@ func ActionClientList(c *cli.Context) error { } func ActionClientDelete(c *cli.Context) error { - cfg, err := getConfig(c) - if err != nil { - return err - } - - addr := cfg.Client.DefaultServer - if c.IsSet("addr") { - addr = c.String("addr") - } - - clientCreds, err := cfg.Client.Creds() - if err != nil { - return err - } - conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) - + conn, err := connFromContext(c) if err != nil { return err } @@ -362,22 +323,7 @@ func ActionClientLogin(c *cli.Context) error { } func ActionClientPassword(c *cli.Context) error { - cfg, err := getConfig(c) - if err != nil { - return err - } - - addr := cfg.Client.DefaultServer - if c.IsSet("addr") { - addr = c.String("addr") - } - - clientCreds, err := cfg.Client.Creds() - if err != nil { - return err - } - conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) - + conn, err := connFromContext(c) if err != nil { return err } @@ -406,21 +352,7 @@ func ActionClientPassword(c *cli.Context) error { } func ActionClientUpdate(c *cli.Context) error { - cfg, err := getConfig(c) - if err != nil { - return err - } - - addr := cfg.Client.DefaultServer - if c.IsSet("addr") { - addr = c.String("addr") - } - - clientCreds, err := cfg.Client.Creds() - if err != nil { - return err - } - conn, err := grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) + conn, err := connFromContext(c) if err != nil { return err @@ -467,3 +399,40 @@ func ActionClientUpdate(c *cli.Context) error { fmt.Printf("Wrote latest binary to %s", outputPath) return nil } + +func BeforeClient(c *cli.Context) error { + cfg, err := getConfig(c) + if err != nil { + return nil + } + + c.Context = context.WithValue(c.Context, contextKeyConfig, cfg) + + addr := cfg.Client.DefaultServer + if c.IsSet("addr") { + addr = c.String("addr") + } + + clientCreds, err := cfg.Client.Creds() + if err != nil { + return nil + } + connFunc := func() (*grpc.ClientConn, error) { + return grpc.DialContext(c.Context, addr, grpc.WithTransportCredentials(clientCreds)) + } + c.Context = context.WithValue(c.Context, contextKeyClientConnFunc, connFunc) + return nil +} + +func connFromContext(c *cli.Context) (*grpc.ClientConn, error) { + connFunc := c.Context.Value(contextKeyClientConnFunc).(func() (*grpc.ClientConn, error)) + return connFunc() +} + +func configFromContext(c *cli.Context) (*config.Config, error) { + cfg, ok := c.Context.Value(contextKeyConfig).(*config.Config) + if !ok { + return getConfig(c) + } + return cfg, nil +} diff --git a/main.go b/main.go index 7b554b2..2735cdb 100644 --- a/main.go +++ b/main.go @@ -62,6 +62,7 @@ func main() { Usage: "Address of server.", }, }, + Before: actions.BeforeClient, Subcommands: []*cli.Command{ { Name: "get",