rm certDB helper and build in

This commit is contained in:
6543 2021-12-05 19:00:57 +01:00
parent a0534f1fde
commit 5fe51d8621
No known key found for this signature in database
GPG key ID: C99B82E40B027BAE
5 changed files with 53 additions and 59 deletions

View file

@ -188,8 +188,11 @@ func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error {
func retrieveCertFromDB(sni, mainDomainSuffix []byte, dnsProvider string, acmeUseRateLimits bool, keyDatabase database.CertDB) (tls.Certificate, bool) { func retrieveCertFromDB(sni, mainDomainSuffix []byte, dnsProvider string, acmeUseRateLimits bool, keyDatabase database.CertDB) (tls.Certificate, bool) {
// parse certificate from database // parse certificate from database
res := &certificate.Resource{} res, err := keyDatabase.Get(sni)
if !database.PogrebGet(keyDatabase, sni, res) { if err != nil {
panic(err) // TODO: no panic
}
if res == nil {
return tls.Certificate{}, false return tls.Certificate{}, false
} }
@ -294,7 +297,9 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
if err == nil && tlsCertificate.Leaf.NotAfter.After(time.Now()) { if err == nil && tlsCertificate.Leaf.NotAfter.After(time.Now()) {
// avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at // avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at
renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6*time.Hour).Unix(), 10)) renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6*time.Hour).Unix(), 10))
database.PogrebPut(keyDatabase, []byte(name), renew) if err := keyDatabase.Put(name, renew); err != nil {
return mockCert(domains[0], err.Error(), string(mainDomainSuffix), keyDatabase), err
}
return tlsCertificate, nil return tlsCertificate, nil
} }
} }
@ -302,7 +307,9 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
} }
log.Printf("Obtained certificate for %v", domains) log.Printf("Obtained certificate for %v", domains)
database.PogrebPut(keyDatabase, []byte(name), res) if err := keyDatabase.Put(name, res); err != nil {
return tls.Certificate{}, err
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil { if err != nil {
return tls.Certificate{}, err return tls.Certificate{}, err
@ -447,12 +454,12 @@ func SetupCertificates(mainDomainSuffix []byte, dnsProvider string, acmeConfig *
} }
} }
func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffix []byte, dnsProvider string, acmeUseRateLimits bool, keyDatabase database.CertDB) { func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffix []byte, dnsProvider string, acmeUseRateLimits bool, certDB database.CertDB) {
for { for {
// clean up expired certs // clean up expired certs
now := time.Now() now := time.Now()
expiredCertCount := 0 expiredCertCount := 0
keyDatabaseIterator := keyDatabase.Items() keyDatabaseIterator := certDB.Items()
key, resBytes, err := keyDatabaseIterator.Next() key, resBytes, err := keyDatabaseIterator.Next()
for err == nil { for err == nil {
if !bytes.Equal(key, mainDomainSuffix) { if !bytes.Equal(key, mainDomainSuffix) {
@ -466,7 +473,7 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi
tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate) tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
if err != nil || !tlsCertificates[0].NotAfter.After(now) { if err != nil || !tlsCertificates[0].NotAfter.After(now) {
err := keyDatabase.Delete(key) err := certDB.Delete(key)
if err != nil { if err != nil {
log.Printf("[ERROR] Deleting expired certificate for %s failed: %s", string(key), err) log.Printf("[ERROR] Deleting expired certificate for %s failed: %s", string(key), err)
} else { } else {
@ -479,7 +486,7 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi
log.Printf("[INFO] Removed %d expired certificates from the database", expiredCertCount) log.Printf("[INFO] Removed %d expired certificates from the database", expiredCertCount)
// compact the database // compact the database
result, err := keyDatabase.Compact() result, err := certDB.Compact()
if err != nil { if err != nil {
log.Printf("[ERROR] Compacting key database failed: %s", err) log.Printf("[ERROR] Compacting key database failed: %s", err)
} else { } else {
@ -487,16 +494,18 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi
} }
// update main cert // update main cert
res := &certificate.Resource{} res, err := certDB.Get(mainDomainSuffix)
if !database.PogrebGet(keyDatabase, mainDomainSuffix, res) { if err != nil {
log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", "expected main domain cert to exist, but it's missing - seems like the database is corrupted") log.Err(err).Msgf("could not get cert for domain '%s'", mainDomainSuffix)
} else if res == nil {
log.Error().Msgf("Couldn't renew certificate for main domain: %s", "expected main domain cert to exist, but it's missing - seems like the database is corrupted")
} else { } else {
tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate) tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
// renew main certificate 30 days before it expires // renew main certificate 30 days before it expires
if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) { if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) {
go (func() { go (func() {
_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, keyDatabase) _, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB)
if err != nil { if err != nil {
log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", err) log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", err)
} }

View file

@ -74,7 +74,9 @@ func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB)
if domain == "*"+mainDomainSuffix || domain == mainDomainSuffix[1:] { if domain == "*"+mainDomainSuffix || domain == mainDomainSuffix[1:] {
databaseName = mainDomainSuffix databaseName = mainDomainSuffix
} }
database.PogrebPut(keyDatabase, []byte(databaseName), res) if err := keyDatabase.Put(databaseName, res); err != nil {
panic(err)
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil { if err != nil {

View file

@ -1,37 +0,0 @@
package database
import (
"bytes"
"encoding/gob"
)
func PogrebPut(db CertDB, name []byte, obj interface{}) {
var resGob bytes.Buffer
resEnc := gob.NewEncoder(&resGob)
err := resEnc.Encode(obj)
if err != nil {
panic(err)
}
err = db.Put(name, resGob.Bytes())
if err != nil {
panic(err)
}
}
func PogrebGet(db CertDB, name []byte, obj interface{}) bool {
resBytes, err := db.Get(name)
if err != nil {
panic(err)
}
if resBytes == nil {
return false
}
resGob := bytes.NewBuffer(resBytes)
resDec := gob.NewDecoder(resGob)
err = resDec.Decode(obj)
if err != nil {
panic(err)
}
return true
}

View file

@ -1,11 +1,14 @@
package database package database
import "github.com/akrylysov/pogreb" import (
"github.com/akrylysov/pogreb"
"github.com/go-acme/lego/v4/certificate"
)
type CertDB interface { type CertDB interface {
Close() error Close() error
Put(key []byte, value []byte) error Put(name string, cert *certificate.Resource) error
Get(key []byte) ([]byte, error) Get(name []byte) (*certificate.Resource, error)
Delete(key []byte) error Delete(key []byte) error
Compact() (pogreb.CompactionResult, error) Compact() (pogreb.CompactionResult, error)
Items() *pogreb.ItemIterator Items() *pogreb.ItemIterator

View file

@ -1,14 +1,16 @@
package database package database
import ( import (
"bytes"
"context" "context"
"encoding/gob"
"fmt" "fmt"
"time" "time"
"github.com/rs/zerolog/log"
"github.com/akrylysov/pogreb" "github.com/akrylysov/pogreb"
"github.com/akrylysov/pogreb/fs" "github.com/akrylysov/pogreb/fs"
"github.com/go-acme/lego/v4/certificate"
"github.com/rs/zerolog/log"
) )
type aDB struct { type aDB struct {
@ -23,12 +25,27 @@ func (p aDB) Close() error {
return p.intern.Sync() return p.intern.Sync()
} }
func (p aDB) Put(key []byte, value []byte) error { func (p aDB) Put(name string, cert *certificate.Resource) error {
return p.intern.Put(key, value) var resGob bytes.Buffer
if err := gob.NewEncoder(&resGob).Encode(cert); err != nil {
return err
}
return p.intern.Put([]byte(name), resGob.Bytes())
} }
func (p aDB) Get(key []byte) ([]byte, error) { func (p aDB) Get(name []byte) (*certificate.Resource, error) {
return p.intern.Get(key) cert := &certificate.Resource{}
resBytes, err := p.intern.Get(name)
if err != nil {
return nil, err
}
if resBytes == nil {
return nil, nil
}
if err = gob.NewDecoder(bytes.NewBuffer(resBytes)).Decode(cert); err != nil {
return nil, err
}
return cert, nil
} }
func (p aDB) Delete(key []byte) error { func (p aDB) Delete(key []byte) error {