ezshare/config/config.go
2021-12-04 04:31:19 +01:00

216 lines
5.0 KiB
Go

package config
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/fs"
"io/ioutil"
"os"
"path/filepath"
"github.com/pelletier/go-toml"
"google.golang.org/grpc/credentials"
)
// Config is the top-level ezshare configuration. It is decoded from TOML by
// FromReader and encoded back to TOML by ToDefaultFile.
type Config struct {
	// Server holds the settings under the TOML key "Server".
	Server *ServerConfig `toml:"Server"`
	// Client holds the settings under the TOML key "Client".
	Client *ClientConfig `toml:"Client"`
	// location is the path this config was loaded from. It is set by FromFile
	// and exposed through Location; it stays empty for configs built otherwise.
	location string
}
// CertificatePaths names a certificate file and its private-key file on disk.
// The file contents are loaded by GetCertBytes and GetKeyBytes; ClientConfig.Creds
// passes them to tls.X509KeyPair, so for the client they are expected to be PEM.
type CertificatePaths struct {
	// CertificatePath is the path to the certificate file.
	CertificatePath string `toml:"CertificatePath"`
	// CertificateKeyPath is the path to the certificate's private key file.
	CertificateKeyPath string `toml:"CertificateKeyPath"`
}
// ServerConfig holds the server-side settings (TOML key "Server").
type ServerConfig struct {
	// LogLevel is a log-level name; FromDefault sets it to "INFO".
	LogLevel string `toml:"LogLevel"`
	// Hostname is not used in this file; presumably the externally visible
	// host name of the server — confirm against callers.
	Hostname string `toml:"Hostname"`
	// GRPC configures the gRPC listener and its certificates.
	GRPC *ServerGRPCConfig `toml:"GRPC"`
	// HTTP configures the HTTP listener.
	HTTP *ServerHTTPConfig `toml:"HTTP"`
}
// ServerStoreConfig describes a storage backend.
// NOTE(review): no field of ServerConfig in this file references this type;
// confirm where (or whether) it is decoded.
type ServerStoreConfig struct {
	// Type presumably selects the backend kind — verify accepted values in callers.
	Type string `toml:"Type"`
	// FSStoreConfig holds the filesystem backend settings (TOML key "Filesystem").
	FSStoreConfig *FSStoreConfig `toml:"Filesystem"`
}
// FSStoreConfig configures the filesystem storage backend.
type FSStoreConfig struct {
	// Dir is presumably the directory the store keeps its data in — confirm.
	Dir string `toml:"Dir"`
}
// ServerGRPCConfig configures the server's gRPC endpoint.
type ServerGRPCConfig struct {
	// ListenAddr is the listen address; FromDefault sets it to ":50051".
	ListenAddr string `toml:"ListenAddr"`
	// CACerts presumably points at the CA used to verify client certificates
	// (the client side presents one, see ClientConfig.Creds) — confirm.
	CACerts *CertificatePaths `toml:"CACerts"`
	// Certs is the server's own certificate and key.
	Certs *CertificatePaths `toml:"Certs"`
}
// ServerHTTPConfig configures the server's HTTP endpoint.
type ServerHTTPConfig struct {
	// ListenAddr is the listen address; FromDefault sets it to ":8089".
	ListenAddr string `toml:"ListenAddr"`
}
// ClientConfig holds the client-side settings (TOML key "Client").
type ClientConfig struct {
	// DefaultServer is not used in this file; presumably the server address the
	// client connects to when none is given — confirm against callers.
	DefaultServer string `toml:"DefaultServer"`
	// ServerCertPath is the certificate Creds trusts as the root CA pool when
	// verifying the server.
	ServerCertPath string `toml:"ServerCertPath"`
	// Certs is the client certificate and key presented during the TLS handshake.
	// FromDefault leaves it nil.
	Certs *CertificatePaths `toml:"Certs"`
}
// FromDefault returns a freshly allocated Config filled with built-in defaults:
// log level "INFO", gRPC on ":50051" with empty certificate paths, HTTP on
// ":8089", and an empty client section (Client.Certs is left nil).
// Each call returns independent pointers, so callers may mutate the result.
func FromDefault() *Config {
	grpcDefaults := &ServerGRPCConfig{
		ListenAddr: ":50051",
		CACerts:    &CertificatePaths{},
		Certs:      &CertificatePaths{},
	}
	httpDefaults := &ServerHTTPConfig{ListenAddr: ":8089"}

	return &Config{
		Server: &ServerConfig{
			LogLevel: "INFO",
			GRPC:     grpcDefaults,
			HTTP:     httpDefaults,
		},
		Client: &ClientConfig{},
	}
}
// FromReader decodes TOML from r on top of the defaults from FromDefault, so
// any key absent from the input keeps its default value.
// On a decode failure it returns a nil Config and the wrapped decoder error.
func FromReader(r io.Reader) (*Config, error) {
	cfg := FromDefault()
	if decodeErr := toml.NewDecoder(r).Decode(cfg); decodeErr != nil {
		return nil, fmt.Errorf("unable to read config: %w", decodeErr)
	}
	return cfg, nil
}
// FromFile loads the TOML config stored at path (see FromReader) and records
// path so that Location reports it. The file is closed before returning.
func FromFile(path string) (*Config, error) {
	file, openErr := os.Open(path)
	if openErr != nil {
		return nil, fmt.Errorf("unable to open config-file: %w", openErr)
	}
	defer file.Close()

	cfg, readErr := FromReader(file)
	if readErr != nil {
		return nil, readErr
	}
	cfg.location = path
	return cfg, nil
}
// FromDefaultLocations loads the first config file that exists among, in order:
// "ezshare.toml" in the working directory, then
// <os.UserConfigDir()>/ezshare/ezshare.toml (skipped if the user config dir
// cannot be determined). A location whose os.Stat fails is skipped.
// If none is found it returns the error "config not found".
func FromDefaultLocations() (*Config, error) {
	candidates := []string{"ezshare.toml"}
	if dir, dirErr := os.UserConfigDir(); dirErr == nil {
		candidates = append(candidates, filepath.Join(dir, "ezshare", "ezshare.toml"))
	}

	for _, candidate := range candidates {
		if _, statErr := os.Stat(candidate); statErr != nil {
			continue
		}
		return FromFile(candidate)
	}
	return nil, errors.New("config not found")
}
// Location returns the path the config was loaded from by FromFile, or "" if
// the Config was created some other way (e.g. FromDefault or FromReader).
func (c *Config) Location() string {
	return c.location
}
// GetCertBytes returns the full contents of the file at cp.CertificatePath.
// Errors are returned unwrapped from os.ReadFile (typically *fs.PathError),
// matching the previous Open+ReadAll behaviour.
func (cp *CertificatePaths) GetCertBytes() ([]byte, error) {
	// os.ReadFile closes the file for us; the former os.Open + ReadAll never
	// closed it and leaked a file descriptor on every call.
	return os.ReadFile(cp.CertificatePath)
}
// GetKeyBytes returns the full contents of the file at cp.CertificateKeyPath.
// Errors are returned unwrapped from os.ReadFile (typically *fs.PathError),
// matching the previous Open+ReadAll behaviour.
func (cp *CertificatePaths) GetKeyBytes() ([]byte, error) {
	// os.ReadFile closes the file for us; the former os.Open + ReadAll never
	// closed it and leaked a file descriptor on every call.
	return os.ReadFile(cp.CertificateKeyPath)
}
// ServerCertBytes returns the contents of the file at cc.ServerCertPath.
// Open failures are wrapped as "unable to open server certificate: ..." and
// read failures as "unable to read client server certificate: ..."; both keep
// the underlying error reachable via errors.Is/As. The file is always closed.
func (cc *ClientConfig) ServerCertBytes() ([]byte, error) {
	certFile, openErr := os.Open(cc.ServerCertPath)
	if openErr != nil {
		return nil, fmt.Errorf("unable to open server certificate: %w", openErr)
	}
	defer certFile.Close()

	contents, readErr := ioutil.ReadAll(certFile)
	if readErr != nil {
		return nil, fmt.Errorf("unable to read client server certificate: %w", readErr)
	}
	return contents, nil
}
// Creds builds mutual-TLS gRPC transport credentials for the client:
//   - the certificate at ServerCertPath becomes the only trusted root (RootCAs)
//     used to verify the server;
//   - the key pair named by Certs is presented as the client certificate.
//
// It returns an error if any file cannot be read, if Certs is not configured,
// if the server certificate contains no usable PEM certificate, or if the
// client key pair cannot be parsed.
func (cc *ClientConfig) Creds() (credentials.TransportCredentials, error) {
	srvCertBytes, err := cc.ServerCertBytes()
	if err != nil {
		return nil, err
	}
	// FromDefault leaves Client.Certs nil; without this guard a config lacking
	// the client certificate section made the calls below panic.
	if cc.Certs == nil {
		return nil, fmt.Errorf("client certificate paths are not configured")
	}
	clientCertBytes, err := cc.Certs.GetCertBytes()
	if err != nil {
		return nil, fmt.Errorf("unable to read client cert: %w", err)
	}
	clientKeyBytes, err := cc.Certs.GetKeyBytes()
	if err != nil {
		return nil, fmt.Errorf("unable to read client key: %w", err)
	}

	certPool := x509.NewCertPool()
	if !certPool.AppendCertsFromPEM(srvCertBytes) {
		return nil, fmt.Errorf("unable to load ca cert")
	}

	clientCert, err := tls.X509KeyPair(clientCertBytes, clientKeyBytes)
	if err != nil {
		// %w (not %s) so callers can still inspect the underlying parse error.
		return nil, fmt.Errorf("unable to load client cert: %w", err)
	}

	config := &tls.Config{
		Certificates: []tls.Certificate{clientCert},
		RootCAs:      certPool,
	}
	return credentials.NewTLS(config), nil
}
// ToDefaultFile writes c as TOML to <os.UserConfigDir()>/ezshare/ezshare.toml,
// creating the ezshare directory (and any missing parents) when needed.
// It never overwrites: if the file already exists it returns
// "config-file already exists". Errors are also returned when the user config
// directory cannot be determined, when the ezshare path exists but is not a
// directory, and when creating, encoding to, or closing the file fails.
func (c *Config) ToDefaultFile() error {
	userConfigDir, err := os.UserConfigDir()
	if err != nil {
		return err
	}

	configDirPath := filepath.Join(userConfigDir, "ezshare")
	info, err := os.Stat(configDirPath)
	switch {
	case errors.Is(err, fs.ErrNotExist):
		// MkdirAll (not Mkdir) also creates userConfigDir itself if it does not
		// exist yet, e.g. on a fresh account with no ~/.config.
		if err := os.MkdirAll(configDirPath, 0755); err != nil {
			return fmt.Errorf("unable to create config-dir: %w", err)
		}
	case err != nil:
		return err
	case !info.IsDir():
		return fmt.Errorf("config-directory is not a directory")
	}

	configFilePath := filepath.Join(configDirPath, "ezshare.toml")
	// O_EXCL makes "create only if absent" one atomic step; the former
	// Stat-then-Create could truncate a file created in between.
	f, err := os.OpenFile(configFilePath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666)
	if err != nil {
		if errors.Is(err, fs.ErrExist) {
			return fmt.Errorf("config-file already exists")
		}
		return fmt.Errorf("unable to create config-file: %w", err)
	}

	fmt.Printf("Writing config to '%s'\n", configFilePath)
	if err := toml.NewEncoder(f).Encode(c); err != nil {
		// Best-effort cleanup: a half-written file would otherwise make every
		// later attempt fail with "config-file already exists".
		_ = f.Close()
		_ = os.Remove(configFilePath)
		return fmt.Errorf("unable to write config-file: %w", err)
	}
	// Report Close errors instead of dropping them: on some filesystems a
	// failed write only surfaces when the file is closed.
	if err := f.Close(); err != nil {
		return fmt.Errorf("unable to close config-file: %w", err)
	}
	return nil
}