labmon/tlsconmon/tlsconmon.go

161 lines
4.2 KiB
Go

package tlsconmon
import (
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"io"
	"log/slog"
	"net"
	"os"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
)
// Prometheus gauges published by the TLS connection monitor. Both vectors
// carry a single "address" label and are created through promauto.
var (
	// gaugeCertTimeLeft holds the number of seconds remaining until the
	// certificate's NotAfter time.
	gaugeCertTimeLeft = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Namespace: "labmon",
			Subsystem: "tlsconmon",
			Name:      "certificate_seconds_left",
			Help:      "Seconds left until the certificate expires.",
		},
		[]string{"address"},
	)

	// gaugeCertError is set to 1 when a certificate fetch fails and to 0 when
	// it succeeds.
	gaugeCertError = promauto.NewGaugeVec(
		prometheus.GaugeOpts{
			Namespace: "labmon",
			Subsystem: "tlsconmon",
			Name:      "certificate_check_error",
			Help:      "Error checking the certificate.",
		},
		[]string{"address"},
	)
)
// TLSConnectionMonitor periodically connects to a TLS endpoint, records the
// certificate the peer presents and exports its remaining lifetime as
// Prometheus gauges.
type TLSConnectionMonitor struct {
	// Address is the "host:port" target dialled over TCP.
	Address string
	// Verify selects whether the peer's certificate chain is verified;
	// when false the connection is made with InsecureSkipVerify.
	Verify bool
	// CheckDuration is the delay between two certificate fetches.
	CheckDuration time.Duration

	// extraCAs are additional roots added to the system pool when dialling.
	extraCAs []*x509.Certificate
	// logger receives fetch results; it discards output until SetLogger is
	// called.
	logger *slog.Logger

	// shutdownCh asks the Start loop to stop; shutdownComplete is signalled
	// by that loop just before it returns.
	shutdownCh, shutdownComplete chan struct{}

	// cert is the most recently fetched peer certificate, nil until the
	// first successful fetch.
	cert *x509.Certificate
}
// NewTLSConnectionMonitor builds a monitor for the TLS endpoint at address.
//
// verify controls whether the peer certificate chain is validated when
// connecting. Each entry of extraCAPaths must name a file whose first PEM
// block is a CERTIFICATE; those certificates are trusted in addition to the
// system roots. duration is the interval between certificate fetches.
//
// The returned monitor logs to io.Discard until SetLogger is called. An error
// is returned if any CA file cannot be read, contains no PEM data, has a
// first block of another type, or fails to parse as an X.509 certificate.
func NewTLSConnectionMonitor(address string, verify bool, extraCAPaths []string, duration time.Duration) (*TLSConnectionMonitor, error) {
	var extraCAs []*x509.Certificate
	for _, path := range extraCAPaths {
		// os.ReadFile opens, reads and closes the file in one step, so no
		// descriptor is held open across loop iterations.
		data, err := os.ReadFile(path)
		if err != nil {
			return nil, fmt.Errorf("failed to read extra cert file %s: %w", path, err)
		}

		// pem.Decode returns a nil block when the input holds no PEM data.
		pemBlock, _ := pem.Decode(data)
		if pemBlock == nil {
			return nil, fmt.Errorf("no PEM data found in extra ca file %s", path)
		}
		if pemBlock.Type != "CERTIFICATE" {
			return nil, fmt.Errorf("invalid PEM block type in extra ca file %s: %s", path, pemBlock.Type)
		}

		cert, err := x509.ParseCertificate(pemBlock.Bytes)
		if err != nil {
			return nil, fmt.Errorf("failed to parse extra cert file %s: %w", path, err)
		}
		extraCAs = append(extraCAs, cert)
	}

	return &TLSConnectionMonitor{
		Address:       address,
		Verify:        verify,
		extraCAs:      extraCAs,
		CheckDuration: duration,
		logger:        slog.New(slog.NewTextHandler(io.Discard, nil)),
		shutdownCh:    make(chan struct{}),
		// Buffered so the Start loop can signal completion without blocking.
		shutdownComplete: make(chan struct{}, 1),
	}, nil
}
// SetLogger replaces the monitor's logger with one derived from logger that
// carries the attributes component="tlsconmon" and address=<tm.Address> on
// every record.
func (tm *TLSConnectionMonitor) SetLogger(logger *slog.Logger) {
	scoped := logger.With(
		"component", "tlsconmon",
		"address", tm.Address,
	)
	tm.logger = scoped
}
// Start runs the monitoring loop and blocks until Shutdown is called.
//
// The peer certificate is fetched once immediately and then again each time
// CheckDuration has elapsed since the previous fetch finished. The error gauge
// is set to 1 after a failed fetch and 0 after a successful one. The
// seconds-left gauge is refreshed every second from the most recently fetched
// certificate and stays at 0 until a fetch has succeeded.
func (tm *TLSConnectionMonitor) Start() {
	// refreshCert fetches the certificate and mirrors the outcome into the
	// error gauge. fetchCert leaves tm.cert untouched on failure.
	refreshCert := func() {
		if err := tm.fetchCert(); err != nil {
			gaugeCertError.WithLabelValues(tm.Address).Set(1)
			return
		}
		gaugeCertError.WithLabelValues(tm.Address).Set(0)
	}

	// publishTimeLeft exports the remaining lifetime of the last fetched
	// certificate. tm.cert is nil until the first successful fetch, so 0 is
	// reported in that case instead of dereferencing a nil pointer.
	publishTimeLeft := func() {
		var timeLeft float64
		if tm.cert != nil {
			timeLeft = time.Until(tm.cert.NotAfter).Seconds()
		}
		gaugeCertTimeLeft.WithLabelValues(tm.Address).Set(timeLeft)
	}

	refreshCert()
	publishTimeLeft()

	timerCertFetch := time.NewTimer(tm.CheckDuration)
	defer timerCertFetch.Stop()
	timerUpdateMonitor := time.NewTimer(1 * time.Second)
	defer timerUpdateMonitor.Stop()

	for {
		select {
		case <-timerCertFetch.C:
			refreshCert()
			// Re-arm after the fetch so the interval is measured from its end.
			timerCertFetch.Reset(tm.CheckDuration)
		case <-timerUpdateMonitor.C:
			publishTimeLeft()
			timerUpdateMonitor.Reset(1 * time.Second)
		case <-tm.shutdownCh:
			// Tell Shutdown that the loop has finished.
			tm.shutdownComplete <- struct{}{}
			return
		}
	}
}
// Shutdown stops a running Start loop and waits for it to finish.
//
// It hands a stop signal to the loop over the unbuffered shutdown channel,
// which blocks until the loop receives it, then waits for the loop's
// completion signal. Both channels are closed afterwards.
func (tm *TLSConnectionMonitor) Shutdown() {
	stop, done := tm.shutdownCh, tm.shutdownComplete

	// Request the stop; this send returns once the loop has picked it up.
	stop <- struct{}{}
	close(stop)

	// Wait for the loop's acknowledgement before releasing the channel.
	<-done
	close(done)
}
// fetchCert connects to tm.Address over TLS and stores the first certificate
// the peer presents in tm.cert.
//
// When extra CAs are configured, they are added to a copy of the system cert
// pool and that pool is used as the trust root; otherwise the TLS defaults
// apply. When tm.Verify is false the chain is not verified. On any failure the
// error is logged and returned, and tm.cert keeps its previous value.
func (tm *TLSConnectionMonitor) fetchCert() error {
	// dialTimeout bounds the TCP connect and TLS handshake as a whole, so an
	// unresponsive peer cannot stall the caller indefinitely.
	const dialTimeout = 10 * time.Second

	tlsConf := &tls.Config{InsecureSkipVerify: !tm.Verify}

	// Only build a custom root pool when there is something to add to it;
	// a nil RootCAs already means "use the system roots".
	if len(tm.extraCAs) > 0 {
		pool, err := x509.SystemCertPool()
		if err != nil {
			tm.logger.Error("Failed to load system cert pool", "error", err)
			return fmt.Errorf("failed to load system cert pool: %w", err)
		}
		for _, cert := range tm.extraCAs {
			pool.AddCert(cert)
		}
		tlsConf.RootCAs = pool
	}

	dialer := &net.Dialer{Timeout: dialTimeout}
	conn, err := tls.DialWithDialer(dialer, "tcp", tm.Address, tlsConf)
	if err != nil {
		tm.logger.Error("Failed to connect to TLS server", "error", err)
		return fmt.Errorf("failed to connect to TLS server: %w", err)
	}
	defer conn.Close()

	// Guard the index below rather than assume the peer sent a certificate.
	peerCerts := conn.ConnectionState().PeerCertificates
	if len(peerCerts) == 0 {
		tm.logger.Error("TLS server presented no certificate")
		return fmt.Errorf("TLS server presented no certificate")
	}

	tm.cert = peerCerts[0]
	tm.logger.Info("Fetched certificate", "not_after", tm.cert.NotAfter, "subject", tm.cert.Subject)
	return nil
}