Added tests for CreateCA and CreateCert

This commit is contained in:
Martin Montes
2024-12-21 14:19:18 +01:00
parent 1dbccf18cf
commit 2c0f0e5d64
4 changed files with 344 additions and 148 deletions

View File

@ -98,7 +98,7 @@ func (r *CertReconciler) Reconcile(ctx context.Context) (*ReconcileResult, error
return nil, fmt.Errorf("Error reconciling CA KeyPair: %v", err)
}
valid, err := pki.ValidateCACert(result.CAKeyPair, r.caCommonName, r.lookaheadTime())
valid, err := pki.ValidateCA(result.CAKeyPair, r.caCommonName, r.lookaheadTime())
if !valid || err != nil {
result.CAKeyPair, result.RefreshedCA, err = r.reconcileKeyPair(ctx, r.caSecretKey, true, r.createCA)
if err != nil {

View File

@ -390,7 +390,7 @@ wMfXbaIBSyNnT+e9/glHQsUmYVLu5MskmA==
CommonName: "cert",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(defaultCertLifetimeDuration),
NotAfter: time.Now().Add(defaultCertLifetime),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
},
@ -410,7 +410,7 @@ wMfXbaIBSyNnT+e9/glHQsUmYVLu5MskmA==
CommonName: "ca",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(-defaultCertLifetimeDuration), // Invalid NotAfter
NotAfter: time.Now().Add(-defaultCertLifetime), // Invalid NotAfter
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IsCA: true,
@ -428,7 +428,7 @@ wMfXbaIBSyNnT+e9/glHQsUmYVLu5MskmA==
CommonName: "cert",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(defaultCertLifetimeDuration),
NotAfter: time.Now().Add(defaultCertLifetime),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
},

View File

@ -12,9 +12,15 @@ import (
)
var (
defaultCACommonName = "mariadb-operator"
defaultCALifetimeDuration = 3 * 365 * 24 * time.Hour // 3 years
defaultCertLifetimeDuration = 3 * 30 * 24 * time.Hour // 3 months
defaultCACommonName = "mariadb-operator"
defaultCALifetime = 3 * 365 * 24 * time.Hour // 3 years
defaultCertLifetime = 3 * 30 * 24 * time.Hour // 3 months
caMinLifetime = 1 * time.Hour
caMaxLifetime = 10 * 365 * 24 * time.Hour // 10 years
certMinLifetime = 1 * time.Hour
certMaxLifetime = 1 * 365 * 24 * time.Hour // 3 years
)
type X509Opts struct {
@ -68,11 +74,14 @@ func CreateCA(x509Opts ...X509Opt) (*KeyPair, error) {
opts := X509Opts{
CommonName: defaultCACommonName,
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(defaultCALifetimeDuration),
NotAfter: time.Now().Add(defaultCALifetime),
}
for _, setOpt := range x509Opts {
setOpt(&opts)
}
if err := validateLifetime(opts.NotBefore, opts.NotAfter, caMinLifetime, caMaxLifetime); err != nil {
return nil, fmt.Errorf("invalid CA lifetime: %v", err)
}
serialNumber, err := getSerialNumber()
if err != nil {
@ -99,7 +108,7 @@ func CreateCA(x509Opts ...X509Opt) (*KeyPair, error) {
func CreateCert(caKeyPair *KeyPair, x509Opts ...X509Opt) (*KeyPair, error) {
opts := X509Opts{
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(defaultCertLifetimeDuration),
NotAfter: time.Now().Add(defaultCertLifetime),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
}
for _, setOpt := range x509Opts {
@ -108,6 +117,9 @@ func CreateCert(caKeyPair *KeyPair, x509Opts ...X509Opt) (*KeyPair, error) {
if opts.CommonName == "" || opts.DNSNames == nil {
return nil, errors.New("CommonName and DNSNames are mandatory")
}
if err := validateLifetime(opts.NotBefore, opts.NotAfter, certMinLifetime, certMaxLifetime); err != nil {
return nil, fmt.Errorf("invalid certificate lifetime: %v", err)
}
serialNumber, err := getSerialNumber()
if err != nil {
@ -130,41 +142,12 @@ func CreateCert(caKeyPair *KeyPair, x509Opts ...X509Opt) (*KeyPair, error) {
return NewKeyPairFromTemplate(tpl, caKeyPair)
}
func ParseCertificate(bytes []byte) (*x509.Certificate, error) {
certs, err := ParseCertificates(bytes)
func ValidateCA(keyPair *KeyPair, dnsName string, at time.Time) (bool, error) {
certs, err := keyPair.Certificates()
if err != nil {
return nil, err
return false, fmt.Errorf("error getting certificates: %v", err)
}
return certs[0], nil
}
func ParseCertificates(bytes []byte) ([]*x509.Certificate, error) {
var (
certs []*x509.Certificate
block *pem.Block
)
pemBytes := bytes
for len(pemBytes) > 0 {
block, pemBytes = pem.Decode(pemBytes)
if block == nil {
return nil, errors.New("invalid PEM block")
}
if block.Type != pemBlockCertificate {
return nil, fmt.Errorf("invalid PEM certificate block, got block type: %v", block.Type)
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, err
}
certs = append(certs, cert)
}
if len(certs) == 0 {
return nil, errors.New("no valid certificates found")
}
return certs, nil
return ValidateCert(certs, keyPair, dnsName, at)
}
func ValidateCert(caCerts []*x509.Certificate, certKeyPair *KeyPair, dnsName string, at time.Time) (bool, error) {
@ -206,12 +189,58 @@ func ValidateCert(caCerts []*x509.Certificate, certKeyPair *KeyPair, dnsName str
return true, nil
}
func ValidateCACert(keyPair *KeyPair, dnsName string, at time.Time) (bool, error) {
certs, err := keyPair.Certificates()
func ParseCertificate(bytes []byte) (*x509.Certificate, error) {
certs, err := ParseCertificates(bytes)
if err != nil {
return false, fmt.Errorf("error getting certificates: %v", err)
return nil, err
}
return ValidateCert(certs, keyPair, dnsName, at)
return certs[0], nil
}
func ParseCertificates(bytes []byte) ([]*x509.Certificate, error) {
var (
certs []*x509.Certificate
block *pem.Block
)
pemBytes := bytes
for len(pemBytes) > 0 {
block, pemBytes = pem.Decode(pemBytes)
if block == nil {
return nil, errors.New("invalid PEM block")
}
if block.Type != pemBlockCertificate {
return nil, fmt.Errorf("invalid PEM certificate block, got block type: %v", block.Type)
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, err
}
certs = append(certs, cert)
}
if len(certs) == 0 {
return nil, errors.New("no valid certificates found")
}
return certs, nil
}
func validateLifetime(notBefore, notAfter time.Time, minDuration, maxDuration time.Duration) error {
if notBefore.After(notAfter) {
return fmt.Errorf("NotBefore (%v) cannot be after NotAfter (%v)", notBefore, notAfter)
}
duration := notAfter.Sub(notBefore)
if duration < minDuration {
return fmt.Errorf("lifetime duration (%v) is less than the minimum allowed duration (%v)", duration, minDuration)
}
if duration > maxDuration {
return fmt.Errorf("lifetime duration (%v) exceeds the maximum allowed duration (%v)", duration, maxDuration)
}
return nil
}
var serialNumberLimit = new(big.Int).Lsh(big.NewInt(1), 128)

View File

@ -2,61 +2,62 @@ package pki
import (
"crypto/x509"
"reflect"
"testing"
"time"
)
func TestValidateCACert(t *testing.T) {
caName := "test-mariadb-operator"
x509Opts := []X509Opt{
WithCommonName(caName),
WithNotBefore(time.Now()),
WithNotAfter(time.Now().Add(24 * time.Hour)),
}
caKeyPair, err := CreateCA(x509Opts...)
if err != nil {
t.Fatalf("CA cert creation should succeed. Got error: %v", err)
}
valid, err := ValidateCACert(caKeyPair, caName, time.Now())
if err != nil {
t.Fatalf("CA cert validation should succeed. Got error: %v", err)
}
if !valid {
t.Fatal("Expected CA cert to be valid")
}
valid, err = ValidateCACert(caKeyPair, caName, time.Now().Add(-1*time.Hour))
if err == nil {
t.Fatalf("CA cert validation should return an error. Got nil")
}
if valid {
t.Fatal("Expected CA cert to be invalid")
}
valid, err = ValidateCACert(caKeyPair, "foo", time.Now())
if err == nil {
t.Fatalf("CA cert validation should return an error. Got nil")
}
if valid {
t.Fatal("Expected CA cert to be invalid")
}
caKeyPair, err = CreateCA(x509Opts...)
if err != nil {
t.Fatalf("CA cert renewal should succeed. Got error: %v", err)
}
valid, err = ValidateCACert(caKeyPair, caName, time.Now())
if err != nil {
t.Fatalf("CA cert validation should succeed after renewal. Got error: %v", err)
}
if !valid {
t.Fatal("Expected CA cert to be valid after renewal")
}
func TestCreateCA(t *testing.T) {
testCreateCert(
t,
[]testCaseCreateCert{
{
name: "Invalid Lifetime",
x509Opts: []X509Opt{
WithNotBefore(time.Now().Add(2 * time.Hour)),
WithNotAfter(time.Now().Add(1 * time.Hour)),
},
wantErr: true,
},
{
name: "Default CA",
x509Opts: []X509Opt{},
wantErr: false,
wantCommonName: defaultCACommonName,
wantIssuer: defaultCACommonName,
wantDNSNames: []string{defaultCACommonName},
wantKeyUsage: x509.KeyUsageCertSign,
},
{
name: "Custom CommonName",
x509Opts: []X509Opt{
WithCommonName("custom-ca"),
},
wantErr: false,
wantCommonName: "custom-ca",
wantIssuer: "custom-ca",
wantDNSNames: []string{"custom-ca"},
wantKeyUsage: x509.KeyUsageCertSign,
},
{
name: "Custom Lifetime",
x509Opts: []X509Opt{
WithNotBefore(time.Now().Add(-2 * time.Hour)),
WithNotAfter(time.Now().Add(5 * 365 * 24 * time.Hour)),
},
wantErr: false,
wantCommonName: defaultCACommonName,
wantIssuer: defaultCACommonName,
wantDNSNames: []string{defaultCACommonName},
wantKeyUsage: x509.KeyUsageCertSign,
},
},
CreateCA,
ValidateCA,
)
}
func TestValidateCert(t *testing.T) {
func TestCreateCert(t *testing.T) {
caKeyPair, err := CreateCA()
if err != nil {
t.Fatalf("CA cert creation should succeed. Got error: %v", err)
@ -66,59 +67,80 @@ func TestValidateCert(t *testing.T) {
t.Fatalf("Unable to get CA certificates: %v", err)
}
commonName := "mariadb-operator.default.svc"
x509Opts := []X509Opt{
WithCommonName(commonName),
WithDNSNames([]string{
"mariadb-operator",
"mariadb-operator.default",
commonName,
}),
WithNotBefore(time.Now()),
WithNotAfter(time.Now().Add(24 * time.Hour)),
WithExtKeyUsage(x509.ExtKeyUsageServerAuth),
}
keyPairPEM, err := CreateCert(caKeyPair, x509Opts...)
if err != nil {
t.Fatalf("Certificate creation should succeed. Got error: %v", err)
}
valid, err := ValidateCert(caCerts, keyPairPEM, commonName, time.Now())
if err != nil {
t.Fatalf("Cert validation should succeed. Got error: %v", err)
}
if !valid {
t.Fatal("Expected cert to be valid")
}
valid, err = ValidateCert(caCerts, keyPairPEM, commonName, time.Now().Add(-1*time.Hour))
if err == nil {
t.Fatalf("Cert validation should return an error. Got nil")
}
if valid {
t.Fatal("Expected cert to be invalid")
}
valid, err = ValidateCert(caCerts, keyPairPEM, "foo", time.Now())
if err == nil {
t.Fatalf("Cert validation should return an error. Got nil")
}
if valid {
t.Fatal("Expected cert to be invalid")
}
keyPairPEM, err = CreateCert(caKeyPair, x509Opts...)
if err != nil {
t.Fatalf("Certificate renewal should succeed. Got error: %v", err)
}
valid, err = ValidateCert(caCerts, keyPairPEM, commonName, time.Now())
if err != nil {
t.Fatalf("Cert validation should succeed after renewal. Got error: %v", err)
}
if !valid {
t.Fatal("Expected cert to be valid")
}
testCreateCert(
t,
[]testCaseCreateCert{
{
name: "Missing CommonName",
x509Opts: []X509Opt{
WithDNSNames([]string{"missing-common-name"}),
},
wantErr: true,
},
{
name: "Missing DNSNames",
x509Opts: []X509Opt{
WithCommonName("missing-dns-names"),
},
wantErr: true,
},
{
name: "Invalid Lifetime",
x509Opts: []X509Opt{
WithCommonName("invalid-lifetime"),
WithDNSNames([]string{"invalid-lifetime"}),
WithNotBefore(time.Now().Add(2 * time.Hour)),
WithNotAfter(time.Now().Add(1 * time.Hour)),
},
wantErr: true,
},
{
name: "Default Cert",
x509Opts: []X509Opt{
WithCommonName("default-cert"),
WithDNSNames([]string{"default-cert"}),
},
wantErr: false,
wantCommonName: "default-cert",
wantIssuer: defaultCACommonName,
wantDNSNames: []string{"default-cert"},
wantKeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
},
{
name: "Custom Key Usage",
x509Opts: []X509Opt{
WithCommonName("custom-key-usage"),
WithDNSNames([]string{"custom-key-usage"}),
WithKeyUsage(x509.KeyUsageKeyEncipherment),
},
wantErr: false,
wantCommonName: "custom-key-usage",
wantIssuer: defaultCACommonName,
wantDNSNames: []string{"custom-key-usage"},
wantKeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment,
},
{
name: "Custom Ext Key Usage",
x509Opts: []X509Opt{
WithCommonName("custom-ext-key-usage"),
WithDNSNames([]string{"custom-ext-key-usage"}),
WithExtKeyUsage(x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth),
},
wantErr: false,
wantCommonName: "custom-ext-key-usage",
wantIssuer: defaultCACommonName,
wantDNSNames: []string{"custom-ext-key-usage"},
wantKeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement,
wantExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
},
},
func(opts ...X509Opt) (*KeyPair, error) {
return CreateCert(caKeyPair, opts...)
},
func(kp *KeyPair, dnsName string, at time.Time) (bool, error) {
return ValidateCert(caCerts, kp, dnsName, at)
},
)
}
func TestParseCertificates(t *testing.T) {
@ -261,3 +283,148 @@ invalid
})
}
}
func TestValidateLifetime(t *testing.T) {
minLifetime := 1 * time.Hour
maxLifetime := 5 * 365 * 24 * time.Hour // 5 years
tests := []struct {
name string
notBefore time.Time
notAfter time.Time
minDuration time.Duration
maxDuration time.Duration
wantErr bool
}{
{
name: "Valid lifetime",
notBefore: time.Now(),
notAfter: time.Now().Add(2 * time.Hour),
minDuration: minLifetime,
maxDuration: maxLifetime,
wantErr: false,
},
{
name: "NotBefore after NotAfter",
notBefore: time.Now().Add(2 * time.Hour),
notAfter: time.Now(),
minDuration: minLifetime,
maxDuration: maxLifetime,
wantErr: true,
},
{
name: "Duration less than minimum",
notBefore: time.Now(),
notAfter: time.Now().Add(30 * time.Minute),
minDuration: minLifetime,
maxDuration: maxLifetime,
wantErr: true,
},
{
name: "Duration exceeds maximum",
notBefore: time.Now(),
notAfter: time.Now().Add(6 * 365 * 24 * time.Hour), // 6 years
minDuration: minLifetime,
maxDuration: maxLifetime,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateLifetime(tt.notBefore, tt.notAfter, tt.minDuration, tt.maxDuration)
if tt.wantErr && err == nil {
t.Fatalf("Expecting error to be non nil for test '%s'", tt.name)
}
if !tt.wantErr && err != nil {
t.Fatalf("Expecting error to be nil for test '%s'. Got: %v", tt.name, err)
}
})
}
}
type testCaseCreateCert struct {
name string
x509Opts []X509Opt
wantErr bool
wantCommonName string
wantIssuer string
wantDNSNames []string
wantKeyUsage x509.KeyUsage
wantExtKeyUsage []x509.ExtKeyUsage
}
func testCreateCert(
t *testing.T,
tests []testCaseCreateCert,
createCertFn func(...X509Opt) (*KeyPair, error),
validateCertFn func(*KeyPair, string, time.Time) (bool, error),
) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
keyPair, err := createCertFn(tt.x509Opts...)
if tt.wantErr && err == nil {
t.Fatalf("Expecting error to be non nil when creating cert '%s'", tt.name)
}
if !tt.wantErr && err != nil {
t.Fatalf("Expecting error to be nil when creating cert '%s'. Got: %v", tt.name, err)
}
if tt.wantErr {
return
}
certs, err := keyPair.Certificates()
if err != nil {
t.Fatalf("error getting certificates: %v", err)
}
cert := certs[0] // we are only creating certificates with a single PEM block
commonName := cert.Subject.CommonName
notBefore := cert.NotBefore
if commonName != tt.wantCommonName {
t.Fatalf("unexpected common name, got: %v, want: %v", commonName, tt.wantCommonName)
}
if cert.Issuer.CommonName != tt.wantIssuer {
t.Fatalf("unexpected issuer, got: %v, want: %v", cert.Issuer.CommonName, tt.wantIssuer)
}
if !reflect.DeepEqual(cert.DNSNames, tt.wantDNSNames) {
t.Fatalf("unexpected DNS names, got: %v, want: %v", cert.DNSNames, tt.wantDNSNames)
}
if !reflect.DeepEqual(cert.KeyUsage, tt.wantKeyUsage) {
t.Fatalf("unexpected key usage, got: %v, want: %v", cert.KeyUsage, tt.wantKeyUsage)
}
if !reflect.DeepEqual(cert.ExtKeyUsage, tt.wantExtKeyUsage) {
t.Fatalf("unexpected extended key usage, got: %v, want: %v", cert.ExtKeyUsage, tt.wantExtKeyUsage)
}
valid, err := validateCertFn(keyPair, commonName, notBefore.Add(-1*time.Hour))
if err == nil {
t.Fatalf("Cert validation should return an error. Got nil")
}
if valid {
t.Fatal("Expected cert to be invalid")
}
valid, err = validateCertFn(keyPair, "foo", time.Now())
if err == nil {
t.Fatalf("Cert validation should return an error. Got nil")
}
if valid {
t.Fatal("Expected cert to be invalid")
}
keyPair, err = createCertFn(tt.x509Opts...)
if err != nil {
t.Fatalf("Certificate renewal should succeed. Got error: %v", err)
}
valid, err = validateCertFn(keyPair, commonName, time.Now())
if err != nil {
t.Fatalf("Cert validation should succeed after renewal. Got error: %v", err)
}
if !valid {
t.Fatal("Expected cert to be valid")
}
})
}
}