package main import ( "context" "fmt" "net/http" "os" "os/signal" "time" "github.com/coreos/go-systemd/daemon" sshlib "github.com/gliderlabs/ssh" "github.com/urfave/cli/v2" "github.uio.no/torjus/apiary" "github.uio.no/torjus/apiary/config" "github.uio.no/torjus/apiary/honeypot/ports" "github.uio.no/torjus/apiary/honeypot/ssh" "github.uio.no/torjus/apiary/honeypot/ssh/store" "github.uio.no/torjus/apiary/web" "go.uber.org/zap" "go.uber.org/zap/zapcore" "golang.org/x/crypto/acme/autocert" ) func main() { app := &cli.App{ Name: "apiary", Version: apiary.FullVersion(), Authors: []*cli.Author{ { Name: "Torjus HÃ¥kestad", Email: "torjus@usit.uio.no", }, }, Commands: []*cli.Command{ { Name: "serve", Action: ActionServe, Usage: "Start Apiary server", }, }, } if err := app.Run(os.Args); err != nil { fmt.Printf("Error: %s\n", err) os.Exit(1) } } func ActionServe(c *cli.Context) error { cfg, err := getConfig() if err != nil { return err } // Setup logging loggers := setupLoggers(cfg) loggers.rootLogger.Infow("Starting apiary", "version", apiary.FullVersion()) // Setup store var s store.LoginAttemptStore switch cfg.Store.Type { case "MEMORY", "memory": loggers.rootLogger.Infow("Initialized store", "store_type", "memory") s = &store.MemoryStore{} case "POSTGRES", "postgres": pgStartTime := time.Now() loggers.rootLogger.Debugw("Initializing store", "store_type", "postgres") pgStore, err := store.NewPostgresStore(cfg.Store.Postgres.DSN) if err != nil { return err } if err := pgStore.InitDB(); err != nil { return err } loggers.rootLogger.Infow("Initialized store", "store_type", "postgres", "init_time", time.Since(pgStartTime)) if cfg.Store.EnableCache { loggers.rootLogger.Debugw("Initializing store", "store_type", "cache-postgres") startTime := time.Now() cachingStore := store.NewCachingStore(pgStore) s = cachingStore loggers.rootLogger.Infow("Initialized store", "store_type", "cache-postgres", "init_time", time.Since(startTime)) } else { s = pgStore } default: return fmt.Errorf("Invalid store configured") } // Setup interrupt handling interruptChan := make(chan os.Signal, 1) signal.Notify(interruptChan, os.Interrupt) rootCtx, rootCancel := context.WithCancel(c.Context) defer rootCancel() serversCtx, serversCancel := context.WithCancel(rootCtx) defer serversCancel() // Setup metrics collection s = store.NewMetricsCollectingStore(rootCtx, s) // Setup honeypot hs, err := ssh.NewHoneypotServer(cfg.Honeypot, s) if err != nil { return err } hs.Logger = loggers.honeypotLogger // Setup webserver web := web.NewServer(cfg.Frontend, hs, s) web.AccessLogger = loggers.webAccessLogger web.ServerLogger = loggers.webServerLogger if cfg.Frontend.Autocert.Enable { certManager := autocert.Manager{ Prompt: autocert.AcceptTOS, HostPolicy: autocert.HostWhitelist(cfg.Frontend.Autocert.Domains...), Email: cfg.Frontend.Autocert.Email, } if cfg.Frontend.Autocert.CacheDir != "" { certManager.Cache = autocert.DirCache(cfg.Frontend.Autocert.CacheDir) } tlsConfig := certManager.TLSConfig() web.TLSConfig = tlsConfig } // Setup portlistener, if configured if cfg.Ports.Enable { portsCtx, cancel := context.WithCancel(rootCtx) defer cancel() // TODO: Add more stores store := &ports.MemoryStore{} portsServer := ports.New(store) portsServer.Logger = loggers.portsLogger for _, port := range cfg.Ports.TCPPorts { portsServer.AddTCPPort(port) } go func() { loggers.rootLogger.Info("Starting ports server") portsServer.Start(portsCtx) }() } // Handle interrupt go func() { <-interruptChan loggers.rootLogger.Info("Interrupt received, shutting down") serversCancel() }() // Start ssh server go func() { loggers.rootLogger.Info("Starting SSH server") if err := hs.ListenAndServe(); err != nil && err != sshlib.ErrServerClosed { loggers.rootLogger.Warnw("SSH server returned error", "error", err) } }() // Start web server go func() { loggers.rootLogger.Info("Starting web server") if err := web.StartServe(); err != nil && err != http.ErrServerClosed { loggers.rootLogger.Warnw("Web server returned error", "error", err) } }() // If run by systemd, enable watchdog and notify ready go func() { notifyCtx, cancel := context.WithCancel(rootCtx) defer cancel() _, ok := os.LookupEnv("NOTIFY_SOCKET") if !ok { return } loggers.rootLogger.Info("Systemd notify socket detected. Sending ready and enabling watchdog.") ok, err := daemon.SdNotify(false, daemon.SdNotifyReady) if !ok { loggers.rootLogger.Info("Systemd notify not enabled.") return } if err != nil { loggers.rootLogger.Warnw("Unable to connect to NOTIFY_SOCKET.", "error", err) return } loggers.rootLogger.Debug("Sent READY=1 to NOTIFY_SOCKET.") if _, err := daemon.SdNotify(false, "WATCHDOG_USEC=10000000"); err != nil { loggers.rootLogger.Warnw("Unable to connect to NOTIFY_SOCKET to set watchdog timeout.", "error", err) return } loggers.rootLogger.Debug("Sent WATCHDOG_USEC=10000000 to NOTIFY_SOCKET.") if _, err := daemon.SdNotify(false, daemon.SdNotifyWatchdog); err != nil { loggers.rootLogger.Warnw("Unable to connect to NOTIFY_SOCKET to notify watchdog.", "error", err) return } // Setup timer timeout, err := daemon.SdWatchdogEnabled(false) if err != nil { loggers.rootLogger.Warnw("Unable to connect to NOTIFY_SOCKET to get watchdog timeout.", "error", err) return } if timeout < 0 { loggers.rootLogger.Debugw("Got invalid watchdog timeout", "timeout", timeout) return } ticker := time.NewTicker(timeout / 2) for { healthy := s.IsHealthy() select { case <-ticker.C: if healthy == nil { daemon.SdNotify(false, daemon.SdNotifyWatchdog) } case <-notifyCtx.Done(): loggers.rootLogger.Debugw("Notify context cancelled.") return } } }() go func() { <-serversCtx.Done() // Stop SSH server sshShutdownCtx, sshShutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) defer sshShutdownCancel() loggers.rootLogger.Info("SSH server shutdown started") if err := hs.Shutdown(sshShutdownCtx); err != nil { loggers.rootLogger.Infow("Error shutting down SSH server", "error", err) } loggers.rootLogger.Info("SSH server shutdown complete") // Stop Web server webShutdownCtx, webShutdownCancel := context.WithTimeout(context.Background(), 10*time.Second) defer webShutdownCancel() loggers.rootLogger.Info("Web server shutdown started") if err := web.Shutdown(webShutdownCtx); err != nil { loggers.rootLogger.Infow("Error shutting down web server", "error", err) } loggers.rootLogger.Info("Web server shutdown complete") rootCancel() }() <-rootCtx.Done() return nil } type loggerCollection struct { rootLogger *zap.SugaredLogger honeypotLogger *zap.SugaredLogger webAccessLogger *zap.SugaredLogger webServerLogger *zap.SugaredLogger portsLogger *zap.SugaredLogger } func setupLoggers(cfg config.Config) *loggerCollection { logEncoderCfg := zap.NewProductionEncoderConfig() logEncoderCfg.EncodeCaller = func(caller zapcore.EntryCaller, enc zapcore.PrimitiveArrayEncoder) {} level := zap.NewAtomicLevelAt(zap.InfoLevel) logEncoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder logEncoderCfg.EncodeTime = zapcore.ISO8601TimeEncoder logEncoderCfg.EncodeDuration = zapcore.StringDurationEncoder rootLoggerCfg := &zap.Config{ Level: level, OutputPaths: []string{"stdout"}, ErrorOutputPaths: []string{"stderr"}, Encoding: "console", EncoderConfig: logEncoderCfg, } rootLogger, err := rootLoggerCfg.Build() if err != nil { panic(err) } return &loggerCollection{ rootLogger: rootLogger.Named("APP").Sugar(), honeypotLogger: rootLogger.Named("HON").Sugar(), webAccessLogger: rootLogger.Named("ACC").Sugar(), webServerLogger: rootLogger.Named("WEB").Sugar(), portsLogger: rootLogger.Named("PRT").Sugar(), } } func getConfig() (config.Config, error) { defaultLocations := []string{ "apiary.toml", "/etc/apiary.toml", "/etc/apiary/apiary.toml", } for _, fname := range defaultLocations { if _, err := os.Stat(fname); os.IsNotExist(err) { continue } cfg, err := config.FromFile(fname) if err != nil { return config.Config{}, err } return cfg, nil } return config.Config{}, fmt.Errorf("Could not find config file") }