// Package tlsconmon periodically connects to TLS endpoints, records the
// certificate they present, and exports its expiry information as Prometheus
// gauges, with OpenTelemetry spans around each fetch.
package tlsconmon

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"io"
	"log/slog"
	"os"
	"sync"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/trace"
)

const (
	name = "git.t-juice.com/labmon/tlsconmon"

	// dialTimeout bounds a single TCP connect plus TLS handshake so that an
	// unresponsive peer cannot stall the monitor loop (and Shutdown) forever.
	dialTimeout = 30 * time.Second
)

var (
	// Prometheus metrics, all labelled by the monitored address.

	// gaugeCertTimeLeft is the number of seconds until the cached
	// certificate's NotAfter; 0 if no certificate has ever been fetched.
	gaugeCertTimeLeft = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: "labmon",
		Subsystem: "tlsconmon",
		Name:      "certificate_seconds_left",
		Help:      "Seconds left until the certificate expires.",
	}, []string{"address"})
	// gaugeCertError is 1 if the most recent fetch failed, 0 otherwise.
	gaugeCertError = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: "labmon",
		Subsystem: "tlsconmon",
		Name:      "certificate_check_error",
		Help:      "Error checking the certificate.",
	}, []string{"address"})
	// gaugeCertLifetime is NotAfter minus NotBefore of the last fetched
	// certificate, in seconds.
	gaugeCertLifetime = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: "labmon",
		Subsystem: "tlsconmon",
		Name:      "certificate_lifetime_seconds",
		Help:      "How long the certificate is valid in seconds.",
	}, []string{"address"})

	// OTEL tracing
	tracer = otel.Tracer(name)
)

// TLSConnectionMonitor watches the certificate served at Address. Create it
// with NewTLSConnectionMonitor, run Start in its own goroutine, and stop it
// with Shutdown (or by cancelling the context passed to Start).
type TLSConnectionMonitor struct {
	// Address is the host:port dialled over TCP for each check.
	Address string
	// Verify enables certificate chain/hostname verification when true.
	Verify bool
	// CheckDuration is the pause between the end of one fetch and the next.
	CheckDuration time.Duration

	extraCAs         []*x509.Certificate
	logger           *slog.Logger
	shutdownCh       chan struct{} // closed (once) by Shutdown to stop Start
	shutdownOnce     sync.Once
	shutdownComplete chan struct{} // closed by Start when it returns
	cert             *x509.Certificate
}

// NewTLSConnectionMonitor builds a monitor for address.
//
// Every PEM "CERTIFICATE" block in each file of extraCAPaths is added, on
// top of the system roots, to the trust pool used when verify is true. An
// error is returned if a file cannot be read, contains no PEM data, contains
// a non-certificate PEM block, or contains an unparsable certificate. The
// returned monitor logs nowhere until SetLogger is called.
func NewTLSConnectionMonitor(address string, verify bool, extraCAPaths []string, duration time.Duration) (*TLSConnectionMonitor, error) {
	var extraCAs []*x509.Certificate
	for _, path := range extraCAPaths {
		certs, err := loadCAFile(path)
		if err != nil {
			return nil, err
		}
		extraCAs = append(extraCAs, certs...)
	}

	return &TLSConnectionMonitor{
		Address:          address,
		Verify:           verify,
		CheckDuration:    duration,
		extraCAs:         extraCAs,
		logger:           slog.New(slog.NewTextHandler(io.Discard, nil)),
		shutdownCh:       make(chan struct{}),
		shutdownComplete: make(chan struct{}),
	}, nil
}

// loadCAFile reads the PEM file at path and parses every block in it as an
// X.509 certificate. Reading the whole file at once (instead of a deferred
// Close inside the caller's loop) releases the descriptor immediately.
func loadCAFile(path string) ([]*x509.Certificate, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("failed to read extra cert file %s: %w", path, err)
	}

	var certs []*x509.Certificate
	// CA files are frequently bundles, so walk every PEM block rather than
	// only the first one.
	for block, rest := pem.Decode(data); block != nil; block, rest = pem.Decode(rest) {
		if block.Type != "CERTIFICATE" {
			return nil, fmt.Errorf("invalid PEM block type in extra ca file %s: %s", path, block.Type)
		}
		cert, parseErr := x509.ParseCertificate(block.Bytes)
		if parseErr != nil {
			return nil, fmt.Errorf("failed to parse extra cert file %s: %w", path, parseErr)
		}
		certs = append(certs, cert)
	}
	if len(certs) == 0 {
		// pem.Decode returns a nil block when no PEM data is present.
		return nil, fmt.Errorf("no PEM certificate found in extra ca file %s", path)
	}
	return certs, nil
}

// SetLogger replaces the monitor's logger, tagging every record with the
// component name and the monitored address.
func (tm *TLSConnectionMonitor) SetLogger(logger *slog.Logger) {
	tm.logger = logger.With("component", "tlsconmon", "address", tm.Address)
}

// Start fetches the certificate immediately, then re-fetches it every
// CheckDuration and refreshes the time-left gauge every second. It blocks
// until Shutdown is called or ctx is cancelled, and must be called at most
// once per monitor.
func (tm *TLSConnectionMonitor) Start(ctx context.Context) {
	// Closing (rather than sending on) shutdownComplete lets Shutdown return
	// even if Start already exited because ctx was cancelled.
	defer close(tm.shutdownComplete)

	tm.refresh(ctx)

	fetchTimer := time.NewTimer(tm.CheckDuration)
	defer fetchTimer.Stop()
	updateTicker := time.NewTicker(1 * time.Second)
	defer updateTicker.Stop()

	for {
		select {
		case <-fetchTimer.C:
			tm.refresh(ctx)
			// Reset after the fetch so the interval is measured between
			// fetches, not from their start.
			fetchTimer.Reset(tm.CheckDuration)
		case <-updateTicker.C:
			tm.updateTimeLeft()
		case <-tm.shutdownCh:
			return
		case <-ctx.Done():
			return
		}
	}
}

// refresh fetches the certificate once and publishes the result to the
// error, lifetime and time-left gauges. On failure the previously fetched
// certificate (if any) keeps driving the time-left gauge.
func (tm *TLSConnectionMonitor) refresh(ctx context.Context) {
	if err := tm.fetchCert(ctx); err != nil {
		gaugeCertError.WithLabelValues(tm.Address).Set(1)
		if tm.cert == nil {
			gaugeCertTimeLeft.WithLabelValues(tm.Address).Set(0)
		}
		return
	}

	gaugeCertError.WithLabelValues(tm.Address).Set(0)
	lifetime := tm.cert.NotAfter.Sub(tm.cert.NotBefore).Seconds()
	gaugeCertLifetime.WithLabelValues(tm.Address).Set(lifetime)
	tm.updateTimeLeft()
}

// updateTimeLeft publishes the seconds remaining until the cached
// certificate expires. It does nothing while no certificate has been
// fetched, which avoids dereferencing a nil tm.cert after a failed first
// fetch.
func (tm *TLSConnectionMonitor) updateTimeLeft() {
	if tm.cert == nil {
		return
	}
	gaugeCertTimeLeft.WithLabelValues(tm.Address).Set(time.Until(tm.cert.NotAfter).Seconds())
}

// Shutdown stops Start and waits for it to return. Calling it more than once
// is safe. If Start was never called, Shutdown blocks forever.
func (tm *TLSConnectionMonitor) Shutdown() {
	tm.shutdownOnce.Do(func() { close(tm.shutdownCh) })
	<-tm.shutdownComplete
}

// fetchCert connects to tm.Address and, on success, stores the leaf
// certificate presented by the server in tm.cert. The whole operation is
// traced as a "fetch_cert" span. tm.cert is left unchanged on error.
func (tm *TLSConnectionMonitor) fetchCert(ctx context.Context) error {
	ctx, span := tracer.Start(ctx, "fetch_cert")
	defer span.End()
	span.SetAttributes(attribute.String("cert_address", tm.Address))

	tlsConf := &tls.Config{InsecureSkipVerify: !tm.Verify}

	// A custom pool is only needed to add the extra CAs; with RootCAs left
	// nil, crypto/tls uses the host's root set on its own.
	if len(tm.extraCAs) > 0 {
		span.AddEvent("load_system_cert_pool")
		pool, err := x509.SystemCertPool()
		if err != nil {
			tm.logger.Error("Failed to load system cert pool", "error", err)
			span.RecordError(err)
			span.SetStatus(codes.Error, "Failed to fetch certificate")
			return fmt.Errorf("failed to load system cert pool: %w", err)
		}
		for _, ca := range tm.extraCAs {
			span.AddEvent("add_extra_ca", trace.WithAttributes(attribute.String("ca_cn", ca.Subject.CommonName)))
			pool.AddCert(ca)
		}
		tlsConf.RootCAs = pool
	}

	cert, err := tm.dialLeafCert(ctx, tlsConf)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "Failed to fetch certificate")
		return err
	}

	tm.cert = cert
	tm.logger.Info("Fetched certificate", "not_after", tm.cert.NotAfter, "subject", tm.cert.Subject)
	span.SetStatus(codes.Ok, "Certificate fetched successfully")
	return nil
}

// dialLeafCert performs a TLS handshake with tm.Address under a "dial_tls"
// span and returns the first (leaf) certificate the server presented. The
// connection is closed before returning. The dial honours ctx cancellation
// and is additionally bounded by dialTimeout.
func (tm *TLSConnectionMonitor) dialLeafCert(ctx context.Context, conf *tls.Config) (*x509.Certificate, error) {
	ctx, span := tracer.Start(ctx, "dial_tls")
	defer span.End()

	ctx, cancel := context.WithTimeout(ctx, dialTimeout)
	defer cancel()

	dialer := &tls.Dialer{Config: conf}
	conn, err := dialer.DialContext(ctx, "tcp", tm.Address)
	if err != nil {
		tm.logger.Error("Failed to connect to TLS server", "error", err)
		span.RecordError(err)
		span.SetStatus(codes.Error, "Failed to fetch certificate")
		return nil, fmt.Errorf("failed to connect to TLS server: %w", err)
	}
	defer conn.Close()

	// tls.Dialer documents that the returned net.Conn is always a *tls.Conn.
	peerCerts := conn.(*tls.Conn).ConnectionState().PeerCertificates
	if len(peerCerts) == 0 {
		err := fmt.Errorf("server %s presented no certificates", tm.Address)
		tm.logger.Error("Failed to connect to TLS server", "error", err)
		span.RecordError(err)
		span.SetStatus(codes.Error, "Failed to fetch certificate")
		return nil, err
	}

	span.SetStatus(codes.Ok, "Fetched certificate successfully")
	return peerCerts[0], nil
}