Add cert store option based on sqlite3, mysql & postgres (#173)

Deprecate **pogreb**!

close #169

Reviewed-on: https://codeberg.org/Codeberg/pages-server/pulls/173
This commit is contained in:
6543 2023-02-10 03:00:14 +00:00
parent 7fce7cf68b
commit 7b35a192bf
22 changed files with 1000 additions and 255 deletions

View file

@ -1,14 +1,12 @@
package certificates
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
@ -100,10 +98,9 @@ func TLSConfig(mainDomainSuffix string,
return tlsCertificate.(*tls.Certificate), nil
}
var tlsCertificate tls.Certificate
var tlsCertificate *tls.Certificate
var err error
var ok bool
if tlsCertificate, ok = retrieveCertFromDB(sni, mainDomainSuffix, dnsProvider, acmeUseRateLimits, certDB); !ok {
if tlsCertificate, err = retrieveCertFromDB(sni, mainDomainSuffix, dnsProvider, acmeUseRateLimits, certDB); err != nil {
// request a new certificate
if strings.EqualFold(sni, mainDomainSuffix) {
return nil, errors.New("won't request certificate for main domain, something really bad has happened")
@ -119,12 +116,11 @@ func TLSConfig(mainDomainSuffix string,
}
}
if err := keyCache.Set(sni, &tlsCertificate, 15*time.Minute); err != nil {
if err := keyCache.Set(sni, tlsCertificate, 15*time.Minute); err != nil {
return nil, err
}
return &tlsCertificate, nil
return tlsCertificate, nil
},
PreferServerCipherSuites: true,
NextProtos: []string{
"h2",
"http/1.1",
@ -205,54 +201,53 @@ func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error {
return nil
}
func retrieveCertFromDB(sni, mainDomainSuffix, dnsProvider string, acmeUseRateLimits bool, certDB database.CertDB) (tls.Certificate, bool) {
func retrieveCertFromDB(sni, mainDomainSuffix, dnsProvider string, acmeUseRateLimits bool, certDB database.CertDB) (*tls.Certificate, error) {
// parse certificate from database
res, err := certDB.Get(sni)
if err != nil {
panic(err) // TODO: no panic
}
if res == nil {
return tls.Certificate{}, false
return nil, err
} else if res == nil {
return nil, database.ErrNotFound
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil {
panic(err)
return nil, err
}
// TODO: document & put into own function
if !strings.EqualFold(sni, mainDomainSuffix) {
tlsCertificate.Leaf, err = x509.ParseCertificate(tlsCertificate.Certificate[0])
if err != nil {
panic(err)
return nil, fmt.Errorf("error parsin leaf tlsCert: %w", err)
}
// renew certificates 7 days before they expire
if tlsCertificate.Leaf.NotAfter.Before(time.Now().Add(7 * 24 * time.Hour)) {
// TODO: add ValidUntil to custom res struct
// TODO: use ValidTill of custom cert struct
if res.CSR != nil && len(res.CSR) > 0 {
// CSR stores the time when the renewal shall be tried again
nextTryUnix, err := strconv.ParseInt(string(res.CSR), 10, 64)
if err == nil && time.Now().Before(time.Unix(nextTryUnix, 0)) {
return tlsCertificate, true
return &tlsCertificate, nil
}
}
// TODO: make a queue ?
go (func() {
res.CSR = nil // acme client doesn't like CSR to be set
tlsCertificate, err = obtainCert(acmeClient, []string{sni}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB)
if err != nil {
if _, err := obtainCert(acmeClient, []string{sni}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB); err != nil {
log.Error().Msgf("Couldn't renew certificate for %s: %v", sni, err)
}
})()
}
}
return tlsCertificate, true
return &tlsCertificate, nil
}
var obtainLocks = sync.Map{}
func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Resource, user, dnsProvider, mainDomainSuffix string, acmeUseRateLimits bool, keyDatabase database.CertDB) (tls.Certificate, error) {
func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Resource, user, dnsProvider, mainDomainSuffix string, acmeUseRateLimits bool, keyDatabase database.CertDB) (*tls.Certificate, error) {
name := strings.TrimPrefix(domains[0], "*")
if dnsProvider == "" && len(domains[0]) > 0 && domains[0][0] == '*' {
domains = domains[1:]
@ -265,16 +260,16 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
time.Sleep(100 * time.Millisecond)
_, working = obtainLocks.Load(name)
}
cert, ok := retrieveCertFromDB(name, mainDomainSuffix, dnsProvider, acmeUseRateLimits, keyDatabase)
if !ok {
return tls.Certificate{}, errors.New("certificate failed in synchronous request")
cert, err := retrieveCertFromDB(name, mainDomainSuffix, dnsProvider, acmeUseRateLimits, keyDatabase)
if err != nil {
return nil, fmt.Errorf("certificate failed in synchronous request: %w", err)
}
return cert, nil
}
defer obtainLocks.Delete(name)
if acmeClient == nil {
return mockCert(domains[0], "ACME client uninitialized. This is a server error, please report!", mainDomainSuffix, keyDatabase), nil
return mockCert(domains[0], "ACME client uninitialized. This is a server error, please report!", mainDomainSuffix, keyDatabase)
}
// request actual cert
@ -297,7 +292,7 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
if res == nil {
if user != "" {
if err := checkUserLimit(user); err != nil {
return tls.Certificate{}, err
return nil, err
}
}
@ -320,33 +315,42 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
if renew != nil && renew.CertURL != "" {
tlsCertificate, err := tls.X509KeyPair(renew.Certificate, renew.PrivateKey)
if err != nil {
return mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase), err
mockC, err2 := mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase)
if err2 != nil {
return nil, errors.Join(err, err2)
}
return mockC, err
}
leaf, err := leaf(&tlsCertificate)
if err == nil && leaf.NotAfter.After(time.Now()) {
// avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at
renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6*time.Hour).Unix(), 10))
if err := keyDatabase.Put(name, renew); err != nil {
return mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase), err
mockC, err2 := mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase)
if err2 != nil {
return nil, errors.Join(err, err2)
}
return mockC, err
}
return tlsCertificate, nil
return &tlsCertificate, nil
}
}
return mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase), err
return mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase)
}
log.Debug().Msgf("Obtained certificate for %v", domains)
if err := keyDatabase.Put(name, res); err != nil {
return tls.Certificate{}, err
return nil, err
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil {
return tls.Certificate{}, err
return nil, err
}
return tlsCertificate, nil
return &tlsCertificate, nil
}
func SetupAcmeConfig(acmeAPI, acmeMail, acmeEabHmac, acmeEabKID string, acmeAcceptTerms bool) (*lego.Config, error) {
// TODO: make it a config flag
const configFile = "acme-account.json"
var myAcmeAccount AcmeAccount
var myAcmeConfig *lego.Config
@ -431,8 +435,8 @@ func SetupAcmeConfig(acmeAPI, acmeMail, acmeEabHmac, acmeEabKID string, acmeAcce
func SetupCertificates(mainDomainSuffix, dnsProvider string, acmeConfig *lego.Config, acmeUseRateLimits, enableHTTPServer bool, challengeCache cache.SetGetKey, certDB database.CertDB) error {
// getting main cert before ACME account so that we can fail here without hitting rate limits
mainCertBytes, err := certDB.Get(mainDomainSuffix)
if err != nil {
return fmt.Errorf("cert database is not working")
if err != nil && !errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("cert database is not working: %w", err)
}
acmeClient, err = lego.NewClient(acmeConfig)
@ -485,41 +489,35 @@ func SetupCertificates(mainDomainSuffix, dnsProvider string, acmeConfig *lego.Co
func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffix, dnsProvider string, acmeUseRateLimits bool, certDB database.CertDB) {
for {
// clean up expired certs
now := time.Now()
// delete expired certs that will be invalid until next clean up
threshold := time.Now().Add(interval)
expiredCertCount := 0
keyDatabaseIterator := certDB.Items()
key, resBytes, err := keyDatabaseIterator.Next()
for err == nil {
if !strings.EqualFold(string(key), mainDomainSuffix) {
resGob := bytes.NewBuffer(resBytes)
resDec := gob.NewDecoder(resGob)
res := &certificate.Resource{}
err = resDec.Decode(res)
if err != nil {
panic(err)
}
tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
if err != nil || tlsCertificates[0].NotAfter.Before(now) {
err := certDB.Delete(string(key))
if err != nil {
log.Error().Err(err).Msgf("Deleting expired certificate for %q failed", string(key))
} else {
expiredCertCount++
certs, err := certDB.Items(0, 0)
if err != nil {
log.Error().Err(err).Msg("could not get certs from list")
} else {
for _, cert := range certs {
if !strings.EqualFold(cert.Domain, strings.TrimPrefix(mainDomainSuffix, ".")) {
if time.Unix(cert.ValidTill, 0).Before(threshold) {
err := certDB.Delete(cert.Domain)
if err != nil {
log.Error().Err(err).Msgf("Deleting expired certificate for %q failed", cert.Domain)
} else {
expiredCertCount++
}
}
}
}
key, resBytes, err = keyDatabaseIterator.Next()
}
log.Debug().Msgf("Removed %d expired certificates from the database", expiredCertCount)
log.Debug().Msgf("Removed %d expired certificates from the database", expiredCertCount)
// compact the database
msg, err := certDB.Compact()
if err != nil {
log.Error().Err(err).Msg("Compacting key database failed")
} else {
log.Debug().Msgf("Compacted key database: %s", msg)
// compact the database
msg, err := certDB.Compact()
if err != nil {
log.Error().Err(err).Msg("Compacting key database failed")
} else {
log.Debug().Msgf("Compacted key database: %s", msg)
}
}
// update main cert
@ -530,9 +528,10 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi
log.Error().Msgf("Couldn't renew certificate for main domain %q expected main domain cert to exist, but it's missing - seems like the database is corrupted", mainDomainSuffix)
} else {
tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
// renew main certificate 30 days before it expires
if tlsCertificates[0].NotAfter.Before(time.Now().Add(30 * 24 * time.Hour)) {
if err != nil {
log.Error().Err(fmt.Errorf("could not parse cert for mainDomainSuffix: %w", err))
} else if tlsCertificates[0].NotAfter.Before(time.Now().Add(30 * 24 * time.Hour)) {
// renew main certificate 30 days before it expires
go (func() {
_, err = obtainCert(mainDomainAcmeClient, []string{"*" + mainDomainSuffix, mainDomainSuffix[1:]}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB)
if err != nil {

View file

@ -13,14 +13,15 @@ import (
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
"github.com/rs/zerolog/log"
"codeberg.org/codeberg/pages/server/database"
)
func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB) tls.Certificate {
func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB) (*tls.Certificate, error) {
key, err := certcrypto.GeneratePrivateKey(certcrypto.RSA2048)
if err != nil {
panic(err)
return nil, err
}
template := x509.Certificate{
@ -52,7 +53,7 @@ func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB)
key,
)
if err != nil {
panic(err)
return nil, err
}
out := &bytes.Buffer{}
@ -61,7 +62,7 @@ func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB)
Type: "CERTIFICATE",
})
if err != nil {
panic(err)
return nil, err
}
outBytes := out.Bytes()
res := &certificate.Resource{
@ -75,12 +76,12 @@ func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB)
databaseName = mainDomainSuffix
}
if err := keyDatabase.Put(databaseName, res); err != nil {
panic(err)
log.Error().Err(err)
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil {
panic(err)
return nil, err
}
return tlsCertificate
return &tlsCertificate, nil
}

View file

@ -10,7 +10,8 @@ import (
func TestMockCert(t *testing.T) {
db, err := database.NewTmpDB()
assert.NoError(t, err)
cert := mockCert("example.com", "some error msg", "codeberg.page", db)
cert, err := mockCert("example.com", "some error msg", "codeberg.page", db)
assert.NoError(t, err)
if assert.NotEmpty(t, cert) {
assert.NotEmpty(t, cert.Certificate)
}