(Ab)use CSR field to store try-again date for renewals (instead of showing a mock cert), must be tested when the first renewals are due

This commit is contained in:
Moritz Marquardt 2021-12-01 22:49:48 +01:00
parent f29ebc57d3
commit 544b3f7321
No known key found for this signature in database
GPG key ID: D5788327BEE388B6
2 changed files with 73 additions and 56 deletions

View file

@ -24,6 +24,7 @@ import (
"log" "log"
"math/big" "math/big"
"os" "os"
"strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -207,21 +208,9 @@ func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error {
func retrieveCertFromDB(sni []byte) (tls.Certificate, bool) { func retrieveCertFromDB(sni []byte) (tls.Certificate, bool) {
// parse certificate from database // parse certificate from database
resBytes, err := keyDatabase.Get(sni)
if err != nil {
// key database is not working
panic(err)
}
if resBytes == nil {
return tls.Certificate{}, false
}
resGob := bytes.NewBuffer(resBytes)
resDec := gob.NewDecoder(resGob)
res := &certificate.Resource{} res := &certificate.Resource{}
err = resDec.Decode(res) if !PogrebGet(keyDatabase, sni, res) {
if err != nil { return tls.Certificate{}, false
panic(err)
} }
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
@ -237,7 +226,15 @@ func retrieveCertFromDB(sni []byte) (tls.Certificate, bool) {
// renew certificates 7 days before they expire // renew certificates 7 days before they expire
if !tlsCertificate.Leaf.NotAfter.After(time.Now().Add(-7 * 24 * time.Hour)) { if !tlsCertificate.Leaf.NotAfter.After(time.Now().Add(-7 * 24 * time.Hour)) {
if res.CSR != nil && len(res.CSR) > 0 {
// CSR stores the time when the renewal shall be tried again
nextTryUnix, err := strconv.ParseInt(string(res.CSR), 10, 64)
if err == nil && time.Now().Before(time.Unix(nextTryUnix, 0)) {
return tlsCertificate, true
}
}
go (func() { go (func() {
res.CSR = nil // acme client doesn't like CSR to be set
tlsCertificate, err = obtainCert(acmeClient, []string{string(sni)}, res, "") tlsCertificate, err = obtainCert(acmeClient, []string{string(sni)}, res, "")
if err != nil { if err != nil {
log.Printf("Couldn't renew certificate for %s: %s", sni, err) log.Printf("Couldn't renew certificate for %s: %s", sni, err)
@ -310,18 +307,21 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
} }
if err != nil { if err != nil {
log.Printf("Couldn't obtain certificate for %v: %s", domains, err) log.Printf("Couldn't obtain certificate for %v: %s", domains, err)
return mockCert(domains[0], err.Error()), err if renew != nil && renew.CertURL != "" {
tlsCertificate, err := tls.X509KeyPair(renew.Certificate, renew.PrivateKey)
if err == nil && tlsCertificate.Leaf.NotAfter.After(time.Now()) {
// avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at
renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6 * time.Hour).Unix(), 10))
PogrebPut(keyDatabase, []byte(name), renew)
return tlsCertificate, nil
}
} else {
return mockCert(domains[0], err.Error()), err
}
} }
log.Printf("Obtained certificate for %v", domains) log.Printf("Obtained certificate for %v", domains)
var resGob bytes.Buffer PogrebPut(keyDatabase, []byte(name), res)
resEnc := gob.NewEncoder(&resGob)
err = resEnc.Encode(res)
if err != nil {
panic(err)
}
err = keyDatabase.Put([]byte(name), resGob.Bytes())
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil { if err != nil {
return tls.Certificate{}, err return tls.Certificate{}, err
@ -382,20 +382,11 @@ func mockCert(domain string, msg string) tls.Certificate {
IssuerCertificate: outBytes, IssuerCertificate: outBytes,
Domain: domain, Domain: domain,
} }
var resGob bytes.Buffer
resEnc := gob.NewEncoder(&resGob)
err = resEnc.Encode(res)
if err != nil {
panic(err)
}
databaseName := domain databaseName := domain
if domain == "*" + string(MainDomainSuffix) || domain == string(MainDomainSuffix[1:]) { if domain == "*" + string(MainDomainSuffix) || domain == string(MainDomainSuffix[1:]) {
databaseName = string(MainDomainSuffix) databaseName = string(MainDomainSuffix)
} }
err = keyDatabase.Put([]byte(databaseName), resGob.Bytes()) PogrebPut(keyDatabase, []byte(databaseName), res)
if err != nil {
panic(err)
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey) tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil { if err != nil {
@ -585,30 +576,21 @@ func setupCertificates() {
} }
// update main cert // update main cert
resBytes, err = keyDatabase.Get(MainDomainSuffix)
if err != nil {
// key database is not working
panic(err)
}
resGob := bytes.NewBuffer(resBytes)
resDec := gob.NewDecoder(resGob)
res := &certificate.Resource{} res := &certificate.Resource{}
err = resDec.Decode(res) if !PogrebGet(keyDatabase, MainDomainSuffix, res) {
if err != nil { log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", "expected main domain cert to exist, but it's missing - seems like the database is corrupted")
panic(err) } else {
} tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate) // renew main certificate 30 days before it expires
if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) {
// renew main certificate 30 days before it expires go (func() {
if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) { _, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(MainDomainSuffix), string(MainDomainSuffix[1:])}, res, "")
go (func() { if err != nil {
_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(MainDomainSuffix), string(MainDomainSuffix[1:])}, res, "") log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", err)
if err != nil { }
log.Printf("Couldn't renew certificate for *%s: %s", MainDomainSuffix, err) })()
} }
})()
} }
time.Sleep(12 * time.Hour) time.Sleep(12 * time.Hour)

View file

@ -1,6 +1,10 @@
package main package main
import "bytes" import (
"bytes"
"encoding/gob"
"github.com/akrylysov/pogreb"
)
// GetHSTSHeader returns a HSTS header with includeSubdomains & preload for MainDomainSuffix and RawDomain, or an empty // GetHSTSHeader returns a HSTS header with includeSubdomains & preload for MainDomainSuffix and RawDomain, or an empty
// string for custom domains. // string for custom domains.
@ -19,3 +23,34 @@ func TrimHostPort(host []byte) []byte {
} }
return host return host
} }
func PogrebPut(db *pogreb.DB, name []byte, obj interface{}) {
var resGob bytes.Buffer
resEnc := gob.NewEncoder(&resGob)
err := resEnc.Encode(obj)
if err != nil {
panic(err)
}
err = db.Put(name, resGob.Bytes())
if err != nil {
panic(err)
}
}
func PogrebGet(db *pogreb.DB, name []byte, obj interface{}) bool {
resBytes, err := db.Get(name)
if err != nil {
panic(err)
}
if resBytes == nil {
return false
}
resGob := bytes.NewBuffer(resBytes)
resDec := gob.NewDecoder(resGob)
err = resDec.Decode(obj)
if err != nil {
panic(err)
}
return true
}