package config import ( "crypto/tls" "crypto/x509" "errors" "fmt" "io" "io/fs" "io/ioutil" "os" "path/filepath" "github.com/pelletier/go-toml" "google.golang.org/grpc/credentials" ) type Config struct { LogLevel string `toml:"LogLevel"` Server *ServerConfig `toml:"Server"` Client *ClientConfig `toml:"Client"` location string } type CertificatePaths struct { CertificatePath string `toml:"CertificatePath"` CertificateKeyPath string `toml:"CertificateKeyPath"` } type ServerConfig struct { GRPC *ServerGRPCConfig `toml:"GRPC"` HTTP *ServerHTTPConfig `toml:"HTTP"` } type ServerStoreConfig struct { Type string `toml:"Type"` FSStoreConfig *FSStoreConfig `toml:"Filesystem"` } type FSStoreConfig struct { Dir string `toml:"Dir"` } type ServerGRPCConfig struct { ListenAddr string `toml:"ListenAddr"` CACerts *CertificatePaths `toml:"CACerts"` Certs *CertificatePaths `toml:"Certs"` } type ServerHTTPConfig struct { ListenAddr string `toml:"ListenAddr"` } type ClientConfig struct { DefaultServer string `toml:"DefaultServer"` ServerCertPath string `toml:"ServerCertPath"` Certs *CertificatePaths `toml:"Certs"` } func FromDefault() *Config { cfg := &Config{ LogLevel: "INFO", Server: &ServerConfig{ GRPC: &ServerGRPCConfig{ ListenAddr: ":50051", CACerts: &CertificatePaths{}, Certs: &CertificatePaths{}, }, HTTP: &ServerHTTPConfig{ ListenAddr: ":8089", }, }, Client: &ClientConfig{}, } return cfg } func FromReader(r io.Reader) (*Config, error) { decoder := toml.NewDecoder(r) c := FromDefault() if err := decoder.Decode(c); err != nil { return nil, fmt.Errorf("unable to read config: %w", err) } return c, nil } func FromFile(path string) (*Config, error) { f, err := os.Open(path) if err != nil { return nil, fmt.Errorf("unable to open config-file: %w", err) } defer f.Close() cfg, err := FromReader(f) if err == nil { cfg.location = path } return cfg, err } func FromDefaultLocations() (*Config, error) { defaultLocations := []string{ "ezshare.toml", } userConfigDir, err := os.UserConfigDir() if err != nil { defaultLocations = append(defaultLocations, filepath.Join(userConfigDir, "ezshare", "ezshare.toml")) } for _, location := range defaultLocations { if _, err := os.Stat(location); err == nil { return FromFile(location) } } return nil, fmt.Errorf("config not found") } func (c *Config) Location() string { return c.location } func (cp *CertificatePaths) GetCertBytes() ([]byte, error) { f, err := os.Open(cp.CertificatePath) if err != nil { return nil, err } return ioutil.ReadAll(f) } func (cp *CertificatePaths) GetKeyBytes() ([]byte, error) { f, err := os.Open(cp.CertificateKeyPath) if err != nil { return nil, err } return ioutil.ReadAll(f) } func (cc *ClientConfig) ServerCertBytes() ([]byte, error) { f, err := os.Open(cc.ServerCertPath) if err != nil { return nil, fmt.Errorf("unable to open server certificate: %w", err) } defer f.Close() data, err := ioutil.ReadAll(f) if err != nil { return nil, fmt.Errorf("unable to read client server certificate: %w", err) } return data, nil } func (cc *ClientConfig) Creds() (credentials.TransportCredentials, error) { srvCertBytes, err := cc.ServerCertBytes() if err != nil { return nil, err } clientCertBytes, err := cc.Certs.GetCertBytes() if err != nil { return nil, fmt.Errorf("unable to read client cert: %w", err) } clientKeyBytes, err := cc.Certs.GetKeyBytes() if err != nil { return nil, fmt.Errorf("unable to read client cert: %w", err) } certPool := x509.NewCertPool() if !certPool.AppendCertsFromPEM(srvCertBytes) { return nil, fmt.Errorf("unable to load ca cert") } clientCert, err := tls.X509KeyPair(clientCertBytes, clientKeyBytes) if err != nil { return nil, fmt.Errorf("unable to load client cert: %s", err) } config := &tls.Config{ Certificates: []tls.Certificate{clientCert}, RootCAs: certPool, } return credentials.NewTLS(config), nil } func (c *Config) ToDefaultFile() error { userConfigDir, err := os.UserConfigDir() if err != nil { return err } configDirPath := filepath.Join(userConfigDir, "ezshare") info, err := os.Stat(configDirPath) if err != nil { if !errors.Is(err, fs.ErrNotExist) { return err } if err := os.Mkdir(configDirPath, 0755); err != nil { return fmt.Errorf("unable to create config-dir: %w", err) } } else { if !info.IsDir() { return fmt.Errorf("config-directory is not a directory") } } configFilePath := filepath.Join(configDirPath, "ezshare.toml") _, err = os.Stat(configFilePath) if err != nil { if !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("error stating config-file: %w", err) } f, err := os.Create(configFilePath) if err != nil { return fmt.Errorf("unable to create config-file: %w", err) } encoder := toml.NewEncoder(f) fmt.Printf("Writing config to '%s'", configFilePath) return encoder.Encode(c) } return fmt.Errorf("config-file already exists") }