Files
Artyom Babiy 2485c599d5 Fix v25
2025-07-28 15:33:30 +02:00

407 lines
13 KiB
Go

package certificate
import (
"context"
"crypto/x509"
"errors"
"fmt"
"time"
"github.com/go-logr/logr"
mariadbv1alpha1 "github.com/mariadb-operator/mariadb-operator/v25/api/v1alpha1"
"github.com/mariadb-operator/mariadb-operator/v25/pkg/builder"
"github.com/mariadb-operator/mariadb-operator/v25/pkg/discovery"
"github.com/mariadb-operator/mariadb-operator/v25/pkg/metadata"
"github.com/mariadb-operator/mariadb-operator/v25/pkg/pki"
"github.com/mariadb-operator/mariadb-operator/v25/pkg/refresolver"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/tools/record"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sigs.k8s.io/controller-runtime/pkg/log"
)
// ErrSkipCertRenewal is a sentinel error. When the certificate handler's ShouldRenewCert
// returns an error matching it (checked with errors.Is), reconcileCert keeps the current
// certificate and reports no error instead of failing the reconciliation.
var ErrSkipCertRenewal = errors.New("skipping certificate renewal")
// CertReconciler reconciles CA and certificate keypairs stored in Kubernetes Secrets.
// It embeds a controller-runtime client, which is used to get, create and patch those Secrets.
type CertReconciler struct {
client.Client
// scheme is used when setting the controller reference on managed Secrets.
scheme *runtime.Scheme
// recorder emits Kubernetes events on the related object (e.g. when a Secret key is missing).
recorder record.EventRecorder
// refResolver resolves Secret key references, such as the CA bundle.
refResolver *refresolver.RefResolver
// discovery and builder are injected by NewCertReconciler; they are not used by the
// methods visible in this part of the file.
discovery *discovery.Discovery
builder *builder.Builder
}
// NewCertReconciler assembles a CertReconciler from its dependencies. The reference resolver
// is built from the same client that gets embedded in the reconciler.
func NewCertReconciler(client client.Client, scheme *runtime.Scheme, recorder record.EventRecorder,
	discovery *discovery.Discovery, builder *builder.Builder) *CertReconciler {
	reconciler := new(CertReconciler)
	reconciler.Client = client
	reconciler.scheme = scheme
	reconciler.recorder = recorder
	reconciler.refResolver = refresolver.New(client)
	reconciler.discovery = discovery
	reconciler.builder = builder
	return reconciler
}
// ReconcileResult is the outcome of CertReconciler.Reconcile.
type ReconcileResult struct {
// Result is the embedded controller-runtime result (for example a requeue request).
ctrl.Result
// CAKeyPair is the keypair returned by the CA reconciliation; it is left nil when
// Reconcile does not go through reconcileCA.
CAKeyPair *pki.KeyPair
// CertKeyPair is the keypair returned by the certificate reconciliation; it may be nil.
CertKeyPair *pki.KeyPair
}
// IsZero reports whether the receiver is nil or its embedded ctrl.Result is the zero value.
// It is safe to call on a nil *ReconcileResult.
func (r *ReconcileResult) IsZero() bool {
	return r == nil || r.Result.IsZero()
}
// Reconcile applies certOpts on top of the default options and reconciles certificates
// accordingly:
//   - when a certificate issuer reference is set, the cert-manager Certificate is reconciled;
//   - otherwise, when a CA and/or a certificate should be issued, the CA is reconciled first
//     and the certificate is then issued from it.
//
// On success it returns a non-nil ReconcileResult carrying the controller result and any
// keypairs produced. On failure it returns a nil result and an error that wraps the
// underlying cause, so callers can inspect it with errors.Is / errors.As.
func (r *CertReconciler) Reconcile(ctx context.Context, certOpts ...CertReconcilerOpt) (*ReconcileResult, error) {
	opts := NewDefaultCertificateOpts()
	for _, setOpt := range certOpts {
		setOpt(opts)
	}
	logger := log.FromContext(ctx).WithName("cert")

	result := &ReconcileResult{}
	var err error

	switch {
	case opts.certIssuerRef != nil:
		if result.Result, err = r.reconcileCertManagerCert(ctx, opts, logger); err != nil {
			return nil, fmt.Errorf("error reconciling cert-manager Certificate: %w", err)
		}
	case opts.shouldIssueCA || opts.shouldIssueCert:
		if result.CAKeyPair, err = r.reconcileCA(ctx, opts, logger); err != nil {
			return nil, fmt.Errorf("error reconciling CA: %w", err)
		}
		// The certificate is signed by the CA keypair obtained just above.
		result.Result, result.CertKeyPair, err = r.reconcileCert(ctx, result.CAKeyPair, opts, logger)
		if err != nil {
			return nil, fmt.Errorf("error reconciling certificate: %w", err)
		}
	}
	return result, nil
}
// reconcileCA returns the CA keypair to be used for issuing certificates.
//
//   - If neither a CA nor a certificate should be issued, it returns (nil, nil).
//   - If only a certificate should be issued, the CA is not created here: the existing CA
//     keypair is loaded from its Secret.
//   - Otherwise the CA Secret is created if missing, and the CA is reissued when it fails
//     validation or its renewal time has passed.
//
// Returned errors wrap their cause so callers can use errors.Is / errors.As.
func (r *CertReconciler) reconcileCA(ctx context.Context, opts *CertReconcilerOpts, logger logr.Logger) (*pki.KeyPair, error) {
	if !opts.shouldIssueCA {
		if !opts.shouldIssueCert {
			return nil, nil
		}
		caKeyPair, err := r.getCAKeyPair(ctx, opts)
		if err != nil {
			return nil, fmt.Errorf("error getting CA keypair: %w", err)
		}
		return caKeyPair, nil
	}

	createCA := r.createCAFn(opts)
	caKeyPair, err := r.reconcileKeyPair(ctx, opts.caSecretKey, opts.caSecretType, false, opts, createCA)
	if err != nil {
		return nil, fmt.Errorf("error reconciling CA keypair: %w", err)
	}

	caLeafCert, err := caKeyPair.LeafCertificate()
	if err != nil {
		return nil, fmt.Errorf("error getting CA leaf certificate: %w", err)
	}
	renewalTime, err := pki.RenewalTime(caLeafCert.NotBefore, caLeafCert.NotAfter, opts.renewBeforePercentage)
	if err != nil {
		return nil, fmt.Errorf("error getting CA renewal time: %w", err)
	}

	// Sample the clock once so validity and renewal are judged against the same instant.
	now := time.Now()
	// A validation error is not returned: together with !valid it triggers a renewal below.
	valid, err := pki.ValidateCA(caKeyPair, opts.caCommonName, now)
	afterRenewal := now.After(*renewalTime)

	caLogger := logger.WithValues(
		"common-name", caLeafCert.Subject.CommonName,
		"issuer", caLeafCert.Issuer.CommonName,
		"valid", valid,
		"err", err,
		"renewal-time", renewalTime,
		"after-renewal", afterRenewal,
	)
	caLogger.V(1).Info("CA cert status")

	if !valid || err != nil || afterRenewal {
		caLogger.Info("starting CA cert renewal")

		caKeyPair, err = r.reconcileKeyPair(ctx, opts.caSecretKey, opts.caSecretType, true, opts, createCA)
		if err != nil {
			return nil, fmt.Errorf("error reconciling CA keypair: %w", err)
		}
	}
	return caKeyPair, nil
}
// reconcileCert ensures the certificate keypair Secret exists and holds a certificate that
// validates against the CA bundle, reissuing it when required.
//
// Behaviour:
//   - opts.shouldIssueCert == false: returns a zero result, nil keypair and nil error.
//   - caKeyPair == nil: returns an error, since there is nothing to sign with.
//   - certificate invalid (or validation errored): it is reissued immediately and
//     certHandler.HandleExpiredCert is invoked.
//   - certificate valid and not yet past its renewal time: returned as is.
//   - certificate valid and past its renewal time: certHandler.ShouldRenewCert decides.
//     An ErrSkipCertRenewal error keeps the current certificate; a negative answer requeues
//     after 10 seconds with a nil keypair; a positive answer reissues the certificate.
//
// Returned errors wrap their cause so callers can use errors.Is / errors.As.
func (r *CertReconciler) reconcileCert(ctx context.Context, caKeyPair *pki.KeyPair, opts *CertReconcilerOpts,
	logger logr.Logger) (ctrl.Result, *pki.KeyPair, error) {
	if !opts.shouldIssueCert {
		return ctrl.Result{}, nil, nil
	}
	if caKeyPair == nil {
		return ctrl.Result{}, nil, errors.New("unable to issue cert: CA keypair is nil")
	}

	createCert := r.createCertFn(caKeyPair, opts)
	certKeyPair, err := r.reconcileKeyPair(ctx, opts.certSecretKey, SecretTypeTLS, false, opts, createCert)
	if err != nil {
		return ctrl.Result{}, nil, fmt.Errorf("error reconciling certificate keypair: %w", err)
	}

	caCerts, err := r.getCABundle(ctx, caKeyPair, opts, logger)
	if err != nil {
		return ctrl.Result{}, nil, fmt.Errorf("error getting CA bundle: %w", err)
	}
	leafCert, err := certKeyPair.LeafCertificate()
	if err != nil {
		return ctrl.Result{}, nil, fmt.Errorf("error getting leaf certificate: %w", err)
	}
	renewalTime, err := pki.RenewalTime(leafCert.NotBefore, leafCert.NotAfter, opts.renewBeforePercentage)
	if err != nil {
		return ctrl.Result{}, nil, fmt.Errorf("error getting cert renewal time: %w", err)
	}

	// Sample the clock once so validity and renewal are judged against the same instant.
	now := time.Now()
	// A validation error is not returned: together with !valid it triggers a renewal below.
	valid, err := pki.ValidateCert(caCerts, certKeyPair, opts.certCommonName, now)
	afterRenewal := now.After(*renewalTime)

	certLogger := logger.WithValues(
		"common-name", leafCert.Subject.CommonName,
		"issuer", leafCert.Issuer.CommonName,
		"valid", valid,
		"err", err,
		"renewal-time", renewalTime,
		"after-renewal", afterRenewal,
	)
	certLogger.V(1).Info("cert status")

	if !valid || err != nil {
		certLogger.Info("starting cert renewal", "reason", "Invalid cert")

		certKeyPair, err = r.reconcileKeyPair(ctx, opts.certSecretKey, SecretTypeTLS, true, opts, createCert)
		if err != nil {
			return ctrl.Result{}, nil, fmt.Errorf("error reconciling certificate KeyPair: %w", err)
		}
		if err := opts.certHandler.HandleExpiredCert(ctx); err != nil {
			return ctrl.Result{}, certKeyPair, fmt.Errorf("error handling expired certificate: %w", err)
		}
		return ctrl.Result{}, certKeyPair, nil
	}

	if !afterRenewal {
		return ctrl.Result{}, certKeyPair, nil
	}

	shouldRenew, reason, err := opts.certHandler.ShouldRenewCert(ctx, caKeyPair)
	if err != nil {
		if errors.Is(err, ErrSkipCertRenewal) {
			certLogger.V(1).Info("skipping cert renewal", "reason", reason)
			return ctrl.Result{}, certKeyPair, nil
		}
		return ctrl.Result{}, nil, fmt.Errorf("error checking whether certificate should be renewed: %w", err)
	}
	if !shouldRenew {
		certLogger.Info("waiting for cert renewal", "reason", reason)
		return ctrl.Result{RequeueAfter: 10 * time.Second}, nil, nil
	}

	certLogger.Info("starting cert renewal", "reason", reason)
	certKeyPair, err = r.reconcileKeyPair(ctx, opts.certSecretKey, SecretTypeTLS, true, opts, createCert)
	if err != nil {
		return ctrl.Result{}, nil, fmt.Errorf("error reconciling certificate KeyPair: %w", err)
	}
	return ctrl.Result{}, certKeyPair, nil
}
// reconcileKeyPair returns the keypair backing the Secret identified by key.
//
// A new keypair is generated with createKeyPairFn and persisted when the Secret does not
// exist (Secret is created), or when it has no data or shouldRenew is set (Secret is
// patched). Otherwise the keypair is parsed from the existing Secret, using the CA or TLS
// layout depending on secretType. Errors are returned unwrapped.
func (r *CertReconciler) reconcileKeyPair(ctx context.Context, key types.NamespacedName, secretType SecretType,
	shouldRenew bool, opts *CertReconcilerOpts, createKeyPairFn func() (*pki.KeyPair, error)) (keyPair *pki.KeyPair, err error) {
	var secret corev1.Secret
	getErr := r.Get(ctx, key, &secret)
	if getErr != nil && !apierrors.IsNotFound(getErr) {
		return nil, getErr
	}
	// At this point a non-nil getErr can only be a NotFound error.
	secretMissing := getErr != nil

	if secretMissing || secret.Data == nil || shouldRenew {
		issued, issueErr := createKeyPairFn()
		if issueErr != nil {
			return nil, issueErr
		}

		var persistErr error
		if secretMissing {
			persistErr = r.createSecret(ctx, key, secretType, &secret, issued, opts.relatedObject)
		} else {
			persistErr = r.patchSecret(ctx, secretType, &secret, issued, opts.relatedObject)
		}
		if persistErr != nil {
			return nil, persistErr
		}
		return issued, nil
	}

	parseOpts := opts.KeyPairOpts()
	var (
		loaded  *pki.KeyPair
		loadErr error
	)
	switch secretType {
	case SecretTypeCA:
		loaded, loadErr = pki.NewKeyPairFromCASecret(&secret, parseOpts...)
	default:
		loaded, loadErr = pki.NewKeyPairFromTLSSecret(&secret, parseOpts...)
	}
	if loadErr != nil {
		return nil, loadErr
	}
	return loaded, nil
}
// getCAKeyPair loads the CA keypair from the existing CA Secret without creating or
// renewing anything. The Secret is parsed with the CA or TLS layout depending on
// opts.caSecretType, and the parse outcome is passed to handleCAKeyPairResult.
func (r *CertReconciler) getCAKeyPair(ctx context.Context, opts *CertReconcilerOpts) (*pki.KeyPair, error) {
	secret := corev1.Secret{}
	if getErr := r.Get(ctx, opts.caSecretKey, &secret); getErr != nil {
		return nil, fmt.Errorf("error getting CA keypair Secret: %w", getErr)
	}

	var (
		caKeyPair *pki.KeyPair
		parseErr  error
	)
	parseOpts := opts.KeyPairOpts()
	if opts.caSecretType == SecretTypeCA {
		caKeyPair, parseErr = pki.NewKeyPairFromCASecret(&secret, parseOpts...)
	} else {
		caKeyPair, parseErr = pki.NewKeyPairFromTLSSecret(&secret, parseOpts...)
	}
	return r.handleCAKeyPairResult(caKeyPair, parseErr, opts.caSecretKey.Name, opts)
}
// handleCAKeyPairResult post-processes the outcome of parsing a CA keypair from a Secret.
//
// On success the keypair is returned unchanged. When err matches pki.ErrSecretKeyNotFound,
// a warning event is recorded on the related object (if any) and the error is returned with
// the Secret name as context. Any other error is returned with generic context. In both
// failure cases the original error is wrapped, so errors.Is / errors.As keep working.
func (r *CertReconciler) handleCAKeyPairResult(keyPair *pki.KeyPair, err error, secretName string,
	opts *CertReconcilerOpts) (*pki.KeyPair, error) {
	if err == nil {
		return keyPair, nil
	}

	if errors.Is(err, pki.ErrSecretKeyNotFound) {
		notFoundErr := fmt.Errorf("key not found in CA Secret \"%s\": %w", secretName, err)
		if relatedObj := opts.relatedObject; relatedObj != nil {
			r.recorder.Event(relatedObj, corev1.EventTypeWarning, mariadbv1alpha1.SecretKeyNotFound, notFoundErr.Error())
		}
		return nil, notFoundErr
	}
	return nil, fmt.Errorf("error getting CA Secret \"%s\": %w", secretName, err)
}
// createCAFn returns a factory that creates a new CA keypair from the x509 options derived
// from opts. The options are resolved each time the factory is invoked; a failure to resolve
// them is returned wrapped, so the cause stays inspectable with errors.Is / errors.As.
func (r *CertReconciler) createCAFn(opts *CertReconcilerOpts) func() (*pki.KeyPair, error) {
	return func() (*pki.KeyPair, error) {
		x509Opts, err := opts.CAx509Opts()
		if err != nil {
			return nil, fmt.Errorf("error getting CA x509 opts: %w", err)
		}
		return pki.CreateCA(x509Opts...)
	}
}
// createCertFn returns a factory that issues a new certificate keypair signed by caKeyPair,
// using the x509 options derived from opts. The options are resolved each time the factory
// is invoked; a failure to resolve them is returned wrapped, so the cause stays inspectable
// with errors.Is / errors.As.
func (r *CertReconciler) createCertFn(caKeyPair *pki.KeyPair, opts *CertReconcilerOpts) func() (*pki.KeyPair, error) {
	return func() (*pki.KeyPair, error) {
		x509Opts, err := opts.Certx509Opts()
		if err != nil {
			return nil, fmt.Errorf("error getting certificate x509 opts: %w", err)
		}
		return pki.CreateCert(caKeyPair, x509Opts...)
	}
}
// createSecret populates secret with the given keypair and creates it in the cluster.
//
// The Secret's object metadata is reset to the name and namespace in key. For SecretTypeCA
// the keypair is written with the CA layout; otherwise the Secret is typed as
// kubernetes.io/tls and written with the TLS layout. The watch label and, when owner is
// non-nil, a controller reference are added before creation.
//
// Errors wrap their cause, so API status errors (for example AlreadyExists) remain
// detectable by callers.
func (r *CertReconciler) createSecret(ctx context.Context, key types.NamespacedName, secretType SecretType, secret *corev1.Secret,
	keyPair *pki.KeyPair, owner metav1.Object) error {
	secret.ObjectMeta = metav1.ObjectMeta{
		Name:      key.Name,
		Namespace: key.Namespace,
	}
	if secretType == SecretTypeCA {
		keyPair.UpdateCASecret(secret)
	} else {
		secret.Type = corev1.SecretTypeTLS
		keyPair.UpdateTLSSecret(secret)
	}
	if err := r.updateSecretMetadata(secret, owner); err != nil {
		return fmt.Errorf("error updating Secret metadata: %w", err)
	}

	if err := r.Create(ctx, secret); err != nil {
		return fmt.Errorf("error creating TLS Secret: %w", err)
	}
	return nil
}
// patchSecret writes the given keypair into an existing Secret and sends the change as a
// merge patch.
//
// A deep copy of the Secret is taken before any mutation and used as the patch base. For
// SecretTypeCA the keypair is written with the CA layout; otherwise the Secret is typed as
// kubernetes.io/tls and written with the TLS layout. The watch label and, when owner is
// non-nil, a controller reference are added as part of the same patch.
//
// Errors wrap their cause, so API status errors (for example Conflict or NotFound) remain
// detectable by callers.
func (r *CertReconciler) patchSecret(ctx context.Context, secretType SecretType, secret *corev1.Secret,
	keyPair *pki.KeyPair, owner metav1.Object) error {
	// Snapshot the Secret before mutating it; the patch is computed against this copy.
	patch := client.MergeFrom(secret.DeepCopy())

	if secretType == SecretTypeCA {
		keyPair.UpdateCASecret(secret)
	} else {
		secret.Type = corev1.SecretTypeTLS
		keyPair.UpdateTLSSecret(secret)
	}
	if err := r.updateSecretMetadata(secret, owner); err != nil {
		return fmt.Errorf("error updating Secret metadata: %w", err)
	}

	if err := r.Patch(ctx, secret, patch); err != nil {
		return fmt.Errorf("error patching TLS Secret: %w", err)
	}
	return nil
}
// updateSecretMetadata mutates secret in memory: it sets the watch label (with an empty
// value), allocating the label map if needed, and, when owner is non-nil, sets owner as the
// Secret's controller reference. It does not talk to the API server.
//
// A failure to set the controller reference is returned wrapped, so the underlying typed
// error stays reachable through errors.As.
func (r *CertReconciler) updateSecretMetadata(secret *corev1.Secret, owner metav1.Object) error {
	if secret.Labels == nil {
		secret.Labels = make(map[string]string)
	}
	secret.Labels[metadata.WatchLabel] = ""

	if owner == nil {
		return nil
	}
	if err := controllerutil.SetControllerReference(owner, secret, r.scheme); err != nil {
		return fmt.Errorf("error setting controller reference to Secret: %w", err)
	}
	return nil
}
// getCABundle returns the CA certificates used to validate issued certificates.
//
// When both a CA bundle Secret key and namespace are configured, the bundle is resolved and
// parsed; a parse failure is an error. If the bundle cannot be resolved, the failure is only
// logged at debug level and the certificates of caKeyPair are used instead. An error is
// returned when neither source is available. Errors wrap their cause.
func (r *CertReconciler) getCABundle(ctx context.Context, caKeyPair *pki.KeyPair, opts *CertReconcilerOpts,
	logger logr.Logger) ([]*x509.Certificate, error) {
	if opts.caBundleSecretKey != nil && opts.caBundleNamespace != nil {
		bundle, err := r.refResolver.SecretKeyRef(ctx, *opts.caBundleSecretKey, *opts.caBundleNamespace)
		if err == nil {
			certs, parseErr := pki.ParseCertificates([]byte(bundle))
			if parseErr != nil {
				return nil, fmt.Errorf("error parsing bundle certificates: %w", parseErr)
			}
			return certs, nil
		}
		// Best effort: fall through to the CA keypair's own certificates.
		logger.V(1).Info("error getting CA bundle", "err", err)
	}

	if caKeyPair == nil {
		return nil, errors.New("unable to get CA bundle")
	}
	caCerts, err := caKeyPair.Certificates()
	if err != nil {
		return nil, fmt.Errorf("error getting CA certificates: %w", err)
	}
	return caCerts, nil
}