Completely refactor certificates and implement renewal & cleanup

This commit is contained in:
Moritz Marquardt 2021-11-20 15:30:58 +01:00
parent 33f7a5d0df
commit 2aaac2c52b
No known key found for this signature in database
GPG key ID: D5788327BEE388B6
7 changed files with 242 additions and 247 deletions

View file

@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width"> <meta name="viewport" content="width=device-width">
<title>%status</title> <title>%status</title>
<link rel="stylesheet" href="https://design.codeberg-test.org/design-kit/codeberg.css" /> <link rel="stylesheet" href="https://design.codeberg.org/design-kit/codeberg.css" />
<link href="https://fonts.codeberg.org/dist/inter/Inter%20Web/inter.css" rel="stylesheet" /> <link href="https://fonts.codeberg.org/dist/inter/Inter%20Web/inter.css" rel="stylesheet" />
<link href="https://fonts.codeberg.org/dist/fontawesome5/css/all.min.css" rel="stylesheet" /> <link href="https://fonts.codeberg.org/dist/fontawesome5/css/all.min.css" rel="stylesheet" />
@ -29,7 +29,7 @@
Sorry, this page doesn't exist or is inaccessible for other reasons (%status) Sorry, this page doesn't exist or is inaccessible for other reasons (%status)
</h5> </h5>
<small class="text-muted"> <small class="text-muted">
<img src="https://design.codeberg-test.org/logo-kit/icon.svg" class="align-top"> <img src="https://design.codeberg.org/logo-kit/icon.svg" class="align-top">
Static pages made easy - <a href="https://codeberg.page">Codeberg Pages</a> Static pages made easy - <a href="https://codeberg.page">Codeberg Pages</a>
</small> </small>
</body> </body>

View file

@ -1,11 +1,13 @@
## Environment ## Environment
- `HOST` & `PORT` (default: `[::]` & `443`): listen address.
- `PAGES_DOMAIN` (default: `codeberg.page`): main domain for pages. - `PAGES_DOMAIN` (default: `codeberg.page`): main domain for pages.
- `RAW_DOMAIN` (default: `raw.codeberg.org`): domain for raw resources. - `RAW_DOMAIN` (default: `raw.codeberg.org`): domain for raw resources.
- `GITEA_ROOT` (default: `https://codeberg.org`): root of the upstream Gitea instance. - `GITEA_ROOT` (default: `https://codeberg.org`): root of the upstream Gitea instance.
- `REDIRECT_BROKEN_DNS` (default: "https://docs.codeberg.org/pages/custom-domains/"): info page for setting up DNS, shown for invalid DNS setups. - `REDIRECT_BROKEN_DNS` (default: https://docs.codeberg.org/pages/custom-domains/): info page for setting up DNS, shown for invalid DNS setups.
- `REDIRECT_RAW_INFO` (default: https://docs.codeberg.org/pages/raw-content/): info page for raw resources, shown if no resource is provided. - `REDIRECT_RAW_INFO` (default: https://docs.codeberg.org/pages/raw-content/): info page for raw resources, shown if no resource is provided.
- `ACME_API` (default: https://acme-v02.api.letsencrypt.org/directory): Set this to "https://acme-staging-v02.api.letsencrypt.org/directory" to use the staging API of Let's Encrypt instead. - `ACME_API` (default: https://acme.zerossl.com/v2/DV90): set this to https://acme.mock.director to use invalid certificates without any verification (great for debugging). ZeroSSL is used as it doesn't have rate limits and doesn't clash with the official Codeberg certificates (which are using Let's Encrypt).
- `ACME_EMAIL` (default: `noreply@example.email`): Set this to "true" to accept the Terms of Service of your ACME provider.
- `ACME_ACCEPT_TERMS` (default: use self-signed certificate): Set this to "true" to accept the Terms of Service of your ACME provider. - `ACME_ACCEPT_TERMS` (default: use self-signed certificate): Set this to "true" to accept the Terms of Service of your ACME provider.
- `DNS_PROVIDER` (default: use self-signed certificate): Code of the ACME DNS provider for the main domain wildcard. - `DNS_PROVIDER` (default: use self-signed certificate): Code of the ACME DNS provider for the main domain wildcard.
See https://go-acme.github.io/lego/dns/ for available values & additional environment variables. See https://go-acme.github.io/lego/dns/ for available values & additional environment variables.

View file

@ -6,20 +6,17 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors" "errors"
"github.com/OrlovEvgeny/go-mcache" "github.com/OrlovEvgeny/go-mcache"
"github.com/akrylysov/pogreb/fs" "github.com/akrylysov/pogreb/fs"
"github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/challenge/resolver"
"github.com/go-acme/lego/v4/challenge/tlsalpn01" "github.com/go-acme/lego/v4/challenge/tlsalpn01"
"github.com/go-acme/lego/v4/providers/dns" "github.com/go-acme/lego/v4/providers/dns"
"log" "log"
"math/big"
"os" "os"
"strings" "strings"
"time" "time"
@ -36,10 +33,6 @@ import (
var tlsConfig = &tls.Config{ var tlsConfig = &tls.Config{
// check DNS name & get certificate from Let's Encrypt // check DNS name & get certificate from Let's Encrypt
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if os.Getenv("ACME_ACCEPT_TERMS") != "true" {
return FallbackCertificate(), nil
}
sni := strings.ToLower(strings.TrimSpace(info.ServerName)) sni := strings.ToLower(strings.TrimSpace(info.ServerName))
sniBytes := []byte(sni) sniBytes := []byte(sni)
if len(sni) < 1 { if len(sni) < 1 {
@ -63,7 +56,7 @@ var tlsConfig = &tls.Config{
} }
targetOwner := "" targetOwner := ""
if bytes.HasSuffix(sniBytes, MainDomainSuffix) { if bytes.HasSuffix(sniBytes, MainDomainSuffix) || bytes.Equal(sniBytes, MainDomainSuffix[1:]) {
// deliver default certificate for the main domain (*.codeberg.page) // deliver default certificate for the main domain (*.codeberg.page)
sniBytes = MainDomainSuffix sniBytes = MainDomainSuffix
sni = string(sniBytes) sni = string(sniBytes)
@ -71,90 +64,68 @@ var tlsConfig = &tls.Config{
var targetRepo, targetBranch string var targetRepo, targetBranch string
targetOwner, targetRepo, targetBranch = getTargetFromDNS(sni) targetOwner, targetRepo, targetBranch = getTargetFromDNS(sni)
if targetOwner == "" { if targetOwner == "" {
// DNS not set up, return a self-signed certificate to redirect to the docs // DNS not set up, return main certificate to redirect to the docs
return FallbackCertificate(), nil sniBytes = MainDomainSuffix
sni = string(sniBytes)
} else {
_, _ = targetRepo, targetBranch
_, valid := checkCanonicalDomain(targetOwner, targetRepo, targetBranch, sni)
if !valid {
sniBytes = MainDomainSuffix
sni = string(sniBytes)
}
} }
// TODO: use .domains file to list all domains, to keep users from getting rate-limited
_, _ = targetRepo, targetBranch
/*canonicalDomain := checkCanonicalDomain(targetOwner, targetRepo, targetBranch)
if sni != canonicalDomain {
return FallbackCertificate(), nil
}*/
} }
// limit users to 1 certificate per week
var cert, key []byte
if tlsCertificate, ok := keyCache.Get(sni); ok { if tlsCertificate, ok := keyCache.Get(sni); ok {
// we can use an existing certificate object // we can use an existing certificate object
return tlsCertificate.(*tls.Certificate), nil return tlsCertificate.(*tls.Certificate), nil
} else if ok, err := keyDatabase.Has(sniBytes); err != nil { }
var tlsCertificate tls.Certificate
if ok, err := keyDatabase.Has(sniBytes); err != nil {
// key database is not working // key database is not working
panic(err) panic(err)
} else if ok { } else if ok {
// parse certificate from database // parse certificate from database
certPem, err := keyDatabase.Get(sniBytes)
cert, err = keyDatabase.Get(sniBytes)
if err != nil { if err != nil {
// key database is not working // key database is not working
panic(err) panic(err)
} }
key, err = keyDatabase.Get(append(sniBytes, '/', 'k', 'e', 'y')) keyPem, err := keyDatabase.Get(append(sniBytes, '/', 'k', 'e', 'y'))
if err != nil { if err != nil {
// key database is not working or key doesn't exist // key database is not working or key doesn't exist
panic(err) panic(err)
} }
} else {
// request a new certificate
tlsCertificate, err = tls.X509KeyPair(certPem, keyPem)
if err != nil {
panic(err)
}
tlsCertificate.Leaf, err = x509.ParseCertificate(tlsCertificate.Certificate[0])
if err != nil {
panic(err)
}
}
if tlsCertificate.Certificate == nil || !tlsCertificate.Leaf.NotAfter.After(time.Now().Add(-24 * time.Hour)) {
// request a new certificate
if bytes.Equal(sniBytes, MainDomainSuffix) { if bytes.Equal(sniBytes, MainDomainSuffix) {
return nil, errors.New("won't request certificate for main domain, something really bad has happened") return nil, errors.New("won't request certificate for main domain, something really bad has happened")
} }
log.Printf("Requesting new certificate for %s", sni) err := CheckUserLimit(targetOwner)
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
return nil, err return nil, err
} }
key = x509.MarshalPKCS1PrivateKey(privateKey)
acmeClient, err := acmeClientFromPool(targetOwner) tlsCertificate, err = obtainCert(acmeClient, []string{sni})
if err != nil {
// TODO
}
res, err := acmeClient.Certificate.Obtain(certificate.ObtainRequest{
Domains: []string{sni},
PrivateKey: key,
Bundle: true,
MustStaple: true,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Printf("Obtained certificate for %s", sni)
err = keyDatabase.Put(append(sniBytes, '/', 'k', 'e', 'y'), key)
if err != nil {
return nil, err
}
err = keyDatabase.Put(sniBytes, res.Certificate)
if err != nil {
_ = keyDatabase.Delete(append(sniBytes, '/', 'k', 'e', 'y'))
return nil, err
}
cert = res.Certificate
}
tlsCertificate, err := tls.X509KeyPair(pem.EncodeToMemory(&pem.Block{
Bytes: cert,
Type: "CERTIFICATE",
}), pem.EncodeToMemory(&pem.Block{
Bytes: key,
Type: "RSA PRIVATE KEY",
}))
if err != nil {
panic(err)
} }
err = keyCache.Set(sni, &tlsCertificate, 15 * time.Minute) err := keyCache.Set(sni, &tlsCertificate, 15 * time.Minute)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -178,76 +149,28 @@ var tlsConfig = &tls.Config{
}, },
} }
// GetHSTSHeader returns a HSTS header with includeSubdomains & preload for MainDomainSuffix and RawDomain, or an empty
// string for custom domains.
func GetHSTSHeader(host []byte) string {
if bytes.HasSuffix(host, MainDomainSuffix) || bytes.Equal(host, RawDomain) {
return "max-age=63072000; includeSubdomains; preload"
} else {
return ""
}
}
var challengeCache = mcache.New() var challengeCache = mcache.New()
var keyCache = mcache.New() var keyCache = mcache.New()
var keyDatabase *pogreb.DB var keyDatabase *pogreb.DB
var fallbackCertificate *tls.Certificate func CheckUserLimit(user string) (error) {
// FallbackCertificate generates a new self-signed TLS certificate on demand. userLimit, ok := acmeClientCertificateLimitPerUser[user]
func FallbackCertificate() *tls.Certificate { if !ok {
if fallbackCertificate != nil { // Each Codeberg user can only add 10 new domains per day.
return fallbackCertificate userLimit = equalizer.NewTokenBucket(10, time.Hour * 24)
acmeClientCertificateLimitPerUser[user] = userLimit
} }
if !userLimit.Ask() {
fallbackSerial, err := rand.Int(rand.Reader, (&big.Int{}).Lsh(big.NewInt(1), 159)) return errors.New("rate limit exceeded: 10 certificates per user per 24 hours")
if err != nil {
panic(err)
} }
return nil
fallbackCertKey, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
fallbackCertSpecification := &x509.Certificate{
Subject: pkix.Name{
CommonName: strings.TrimPrefix(string(MainDomainSuffix), "."),
},
SerialNumber: fallbackSerial,
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(100, 0, 0),
}
fallbackCertBytes, err := x509.CreateCertificate(
rand.Reader,
fallbackCertSpecification,
fallbackCertSpecification,
fallbackCertKey.Public(),
fallbackCertKey,
)
if err != nil {
panic(err)
}
fallbackCert, err := tls.X509KeyPair(pem.EncodeToMemory(&pem.Block{
Bytes: fallbackCertBytes,
Type: "CERTIFICATE",
}), pem.EncodeToMemory(&pem.Block{
Bytes: x509.MarshalPKCS1PrivateKey(fallbackCertKey),
Type: "RSA PRIVATE KEY",
}))
if err != nil {
panic(err)
}
fallbackCertificate = &fallbackCert
return fallbackCertificate
} }
type AcmeAccount struct { type AcmeAccount struct {
Email string Email string
Registration *registration.Resource Registration *registration.Resource
key crypto.PrivateKey key crypto.PrivateKey
limit equalizer.Limiter
} }
func (u *AcmeAccount) GetEmail() string { func (u *AcmeAccount) GetEmail() string {
return u.Email return u.Email
@ -259,17 +182,54 @@ func (u *AcmeAccount) GetPrivateKey() crypto.PrivateKey {
return u.key return u.key
} }
// rate-limit certificates per owner, based on LE Rate Limits: func newAcmeClient(configureChallenge func(*resolver.SolverManager) error) *lego.Client {
// - 300 new orders per account per 3 hours privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
// - 20 requests per second if err != nil {
// - 10 Accounts per IP per 3 hours panic(err)
var acmeClientPool []*lego.Client }
var lastAcmeClient = 0 myUser := AcmeAccount{
var acmeClientRequestLimit = equalizer.NewTokenBucket(10, time.Second) // LE allows 20 requests per second, but we want to give other applications a chancem so we want 10 here at most. Email: envOr("ACME_EMAIL", "noreply@example.email"),
var acmeClientRegistrationLimit = equalizer.NewTokenBucket(5, time.Hour * 3) // LE allows 10 registrations in 3 hours per IP, we want at most 5 of them. key: privateKey,
var acmeClientCertificateLimitPerRegistration = []*equalizer.TokenBucket{} }
config := lego.NewConfig(&myUser)
config.CADirURL = envOr("ACME_API", "https://acme.zerossl.com/v2/DV90")
config.Certificate.KeyType = certcrypto.RSA2048
acmeClient, err := lego.NewClient(config)
if err != nil {
panic(err)
}
err = configureChallenge(acmeClient.Challenge)
if err != nil {
panic(err)
}
// accept terms
reg, err := acmeClient.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: os.Getenv("ACME_ACCEPT_TERMS") == "true"})
if err != nil {
panic(err)
}
myUser.Registration = reg
return acmeClient
}
var acmeClient = newAcmeClient(func(challenge *resolver.SolverManager) error {
return challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{})
})
var acmeClientCertificateLimitPerUser = map[string]*equalizer.TokenBucket{} var acmeClientCertificateLimitPerUser = map[string]*equalizer.TokenBucket{}
var mainDomainAcmeClient = newAcmeClient(func(challenge *resolver.SolverManager) error {
if os.Getenv("DNS_PROVIDER") == "" {
// using mock server, don't use wildcard certs
return challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{})
}
provider, err := dns.NewDNSChallengeProviderByName(os.Getenv("DNS_PROVIDER"))
if err != nil {
return err
}
return challenge.SetDNS01Provider(provider)
})
type AcmeTLSChallengeProvider struct{} type AcmeTLSChallengeProvider struct{}
var _ challenge.Provider = AcmeTLSChallengeProvider{} var _ challenge.Provider = AcmeTLSChallengeProvider{}
func (a AcmeTLSChallengeProvider) Present(domain, _, keyAuth string) error { func (a AcmeTLSChallengeProvider) Present(domain, _, keyAuth string) error {
@ -280,67 +240,42 @@ func (a AcmeTLSChallengeProvider) CleanUp(domain, _, _ string) error {
return nil return nil
} }
func acmeClientFromPool(user string) (*lego.Client, error) { func obtainCert(acmeClient *lego.Client, domains []string) (tls.Certificate, error) {
userLimit, ok := acmeClientCertificateLimitPerUser[user] name := domains[0]
if !ok { if os.Getenv("DNS_PROVIDER") == "" && len(domains[0]) > 0 && domains[0][0] == '*' {
// Each Codeberg user can only add 10 new domains per day. domains = domains[1:]
userLimit = equalizer.NewTokenBucket(10, time.Hour * 24)
acmeClientCertificateLimitPerUser[user] = userLimit
}
if !userLimit.Ask() {
return nil, errors.New("rate limit exceeded: 10 certificates per user per 24 hours")
} }
if len(acmeClientPool) < 1 { log.Printf("Requesting new certificate for %v", domains)
acmeClientPool = append(acmeClientPool, newAcmeClient()) res, err := acmeClient.Certificate.Obtain(certificate.ObtainRequest{
acmeClientCertificateLimitPerRegistration = append(acmeClientCertificateLimitPerRegistration, equalizer.NewTokenBucket(290, time.Hour * 3)) Domains: domains,
Bundle: true,
MustStaple: true,
})
if err != nil {
log.Printf("Couldn't obtain certificate for %v: %s", domains, err)
return tls.Certificate{}, err
} }
if !acmeClientCertificateLimitPerRegistration[(lastAcmeClient + 1) % len(acmeClientPool)].Ask() { log.Printf("Obtained certificate for %v", domains)
} err = keyDatabase.Put([]byte(name + "/key"), res.PrivateKey)
equalizer.NewTokenBucket(290, time.Hour * 3) // LE allows 300 certificates per account, to be sure to catch it earlier, we limit that to 290.
// TODO: limit domains by file in repo
}
func newAcmeClient() *lego.Client {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
panic(err) panic(err)
} }
myUser := AcmeAccount{ err = keyDatabase.Put([]byte(name), res.Certificate)
Email: "",
key: privateKey,
}
config := lego.NewConfig(&myUser)
config.CADirURL = envOr("ACME_API", "https://acme-v02.api.letsencrypt.org/directory")
config.Certificate.KeyType = certcrypto.RSA2048
acmeClient, err := lego.NewClient(config)
if err != nil {
panic(err)
}
err = acmeClient.Challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{})
if err != nil { if err != nil {
_ = keyDatabase.Delete([]byte(name + "/key"))
panic(err) panic(err)
} }
// accept terms tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if os.Getenv("ACME_ACCEPT_TERMS") == "true" { if err != nil {
reg, err := acmeClient.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: os.Getenv("ACME_ACCEPT_TERMS") == "true"}) panic(err)
if err != nil {
panic(err)
}
myUser.Registration = reg
} else {
log.Printf("Warning: not using ACME certificates as ACME_ACCEPT_TERMS is false!")
} }
return acmeClient return tlsCertificate, nil
} }
func init() { func init() {
FallbackCertificate()
var err error var err error
keyDatabase, err = pogreb.Open("key-database.pogreb", &pogreb.Options{ keyDatabase, err = pogreb.Open("key-database.pogreb", &pogreb.Options{
BackgroundSyncInterval: 30 * time.Second, BackgroundSyncInterval: 30 * time.Second,
@ -351,50 +286,62 @@ func init() {
panic(err) panic(err)
} }
// generate certificate for main domain if os.Getenv("ACME_ACCEPT_TERMS") != "true" || (os.Getenv("DNS_PROVIDER") == "" && os.Getenv("ACME_API") != "https://acme.mock.directory") {
if os.Getenv("ACME_ACCEPT_TERMS") != "true" || os.Getenv("DNS_PROVIDER") == "" { panic(errors.New("you must set ACME_ACCEPT_TERMS and DNS_PROVIDER, unless ACME_API is set to https://acme.mock.directory"))
err = keyCache.Set(string(MainDomainSuffix), FallbackCertificate(), mcache.TTL_FOREVER)
if err != nil {
panic(err)
}
} else {
log.Printf("Requesting new certificate for *%s", MainDomainSuffix)
dnsAcmeClient, err := lego.NewClient(config)
if err != nil {
panic(err)
}
provider, err := dns.NewDNSChallengeProviderByName(os.Getenv("DNS_PROVIDER"))
if err != nil {
panic(err)
}
err = dnsAcmeClient.Challenge.SetDNS01Provider(provider)
if err != nil {
panic(err)
}
mainPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(err)
}
mainKey := x509.MarshalPKCS1PrivateKey(mainPrivateKey)
res, err := dnsAcmeClient.Certificate.Obtain(certificate.ObtainRequest{
Domains: []string{"*" + string(MainDomainSuffix), string(MainDomainSuffix[1:])},
PrivateKey: mainKey,
Bundle: true,
MustStaple: true,
})
if err != nil {
panic(err)
}
err = keyDatabase.Put(append(MainDomainSuffix, '/', 'k', 'e', 'y'), mainKey)
if err != nil {
panic(err)
}
err = keyDatabase.Put(MainDomainSuffix, res.Certificate)
if err != nil {
_ = keyDatabase.Delete(append(MainDomainSuffix, '/', 'k', 'e', 'y'))
panic(err)
}
} }
}
// TODO: renew & revoke go (func() {
for {
err := keyDatabase.Sync()
if err != nil {
log.Printf("Syncinc key database failed: %s", err)
}
time.Sleep(5 * time.Minute)
}
})()
go (func() {
for {
// clean up expired certs
keySuffix := []byte("/key")
now := time.Now()
expiredCertCount := 0
key, value, err := keyDatabase.Items().Next()
for err == nil {
if !bytes.HasSuffix(key, keySuffix) {
tlsCertificates, err := certcrypto.ParsePEMBundle(value)
if err != nil || !tlsCertificates[0].NotAfter.After(now) {
err := keyDatabase.Delete(key)
if err != nil {
log.Printf("Deleting expired certificate for %s failed: %s", string(key), err)
} else {
expiredCertCount++
}
}
}
key, value, err = keyDatabase.Items().Next()
}
log.Printf("Removed %d expired certificates from the database", expiredCertCount)
// compact the database
result, err := keyDatabase.Compact()
if err != nil {
log.Printf("Compacting key database failed: %s", err)
} else {
log.Printf("Compacted key database (%+v)", result)
}
// update main cert
certPem, err := keyDatabase.Get(MainDomainSuffix)
if err != nil {
// key database is not working
panic(err)
}
tlsCertificates, err := certcrypto.ParsePEMBundle(certPem)
if err != nil || !tlsCertificates[0].NotAfter.After(time.Now().Add(-48 * time.Hour)) {
_, _ = obtainCert(mainDomainAcmeClient, []string{"*" + string(MainDomainSuffix), string(MainDomainSuffix[1:])})
}
time.Sleep(12 * time.Hour)
}
})()
}

View file

@ -68,10 +68,16 @@ var CanonicalDomainCacheTimeout = 15*time.Minute
var canonicalDomainCache = mcache.New() var canonicalDomainCache = mcache.New()
// checkCanonicalDomain returns the canonical domain specified in the repo (using the file `.canonical-domain`). // checkCanonicalDomain returns the canonical domain specified in the repo (using the file `.canonical-domain`).
func checkCanonicalDomain(targetOwner, targetRepo, targetBranch string) (canonicalDomain string) { func checkCanonicalDomain(targetOwner, targetRepo, targetBranch, actualDomain string) (canonicalDomain string, valid bool) {
// Check if the canonical domain matches domains := []string{}
if cachedValue, ok := canonicalDomainCache.Get(targetOwner + "/" + targetRepo + "/" + targetBranch); ok { if cachedValue, ok := canonicalDomainCache.Get(targetOwner + "/" + targetRepo + "/" + targetBranch); ok {
canonicalDomain = cachedValue.(string) domains = cachedValue.([]string)
for _, domain := range domains {
if domain == actualDomain {
valid = true
break
}
}
} else { } else {
req := fasthttp.AcquireRequest() req := fasthttp.AcquireRequest()
req.SetRequestURI(string(GiteaRoot) + "/api/v1/repos/" + targetOwner + "/" + targetRepo + "/raw/" + targetBranch + "/.domains") req.SetRequestURI(string(GiteaRoot) + "/api/v1/repos/" + targetOwner + "/" + targetRepo + "/raw/" + targetBranch + "/.domains")
@ -79,18 +85,28 @@ func checkCanonicalDomain(targetOwner, targetRepo, targetBranch string) (canonic
err := upstreamClient.Do(req, res) err := upstreamClient.Do(req, res)
if err == nil && res.StatusCode() == fasthttp.StatusOK { if err == nil && res.StatusCode() == fasthttp.StatusOK {
canonicalDomain = strings.TrimSpace(string(res.Body())) for _, domain := range strings.Split(string(res.Body()), "\n") {
if strings.Contains(canonicalDomain, "/") { domain = strings.ToLower(domain)
canonicalDomain = "" domain = strings.TrimSpace(domain)
domain = strings.TrimPrefix(domain, "http://")
domain = strings.TrimPrefix(domain, "https://")
if len(domain) > 0 && !strings.HasPrefix(domain, "#") && !strings.ContainsAny(domain, "\t /") && strings.ContainsRune(domain, '.') {
domains = append(domains, domain)
}
if domain == actualDomain {
valid = true
}
} }
} }
if canonicalDomain == "" { domains = append(domains, targetOwner + string(MainDomainSuffix))
canonicalDomain = targetOwner + string(MainDomainSuffix) if domains[len(domains) - 1] == actualDomain {
if targetRepo != "" && targetRepo != "pages" { valid = true
canonicalDomain += "/" + targetRepo
}
} }
_ = canonicalDomainCache.Set(targetOwner + "/" + targetRepo + "/" + targetBranch, canonicalDomain, CanonicalDomainCacheTimeout) if targetRepo != "" && targetRepo != "pages" {
domains[len(domains) - 1] += "/" + targetRepo
}
_ = canonicalDomainCache.Set(targetOwner + "/" + targetRepo + "/" + targetBranch, domains, CanonicalDomainCacheTimeout)
} }
canonicalDomain = domains[0]
return return
} }

View file

@ -28,8 +28,10 @@ func handler(ctx *fasthttp.RequestCtx) {
// Enable caching, but require revalidation to reduce confusion // Enable caching, but require revalidation to reduce confusion
ctx.Response.Header.Set("Cache-Control", "must-revalidate") ctx.Response.Header.Set("Cache-Control", "must-revalidate")
trimmedHost := TrimHostPort(ctx.Request.Host())
// Add HSTS for RawDomain and MainDomainSuffix // Add HSTS for RawDomain and MainDomainSuffix
if hsts := GetHSTSHeader(ctx.Host()); hsts != "" { if hsts := GetHSTSHeader(trimmedHost); hsts != "" {
ctx.Response.Header.Set("Strict-Transport-Security", hsts) ctx.Response.Header.Set("Strict-Transport-Security", hsts)
} }
@ -52,7 +54,7 @@ func handler(ctx *fasthttp.RequestCtx) {
if ctx.IsOptions() { if ctx.IsOptions() {
allowCors := false allowCors := false
for _, allowedCorsDomain := range AllowedCorsDomains { for _, allowedCorsDomain := range AllowedCorsDomains {
if bytes.Equal(ctx.Request.Host(), allowedCorsDomain) { if bytes.Equal(trimmedHost, allowedCorsDomain) {
allowCors = true allowCors = true
break break
} }
@ -109,8 +111,8 @@ func handler(ctx *fasthttp.RequestCtx) {
// tryUpstream forwards the target request to the Gitea API, and shows an error page on failure. // tryUpstream forwards the target request to the Gitea API, and shows an error page on failure.
var tryUpstream = func() { var tryUpstream = func() {
// check if a canonical domain exists on a request on MainDomain // check if a canonical domain exists on a request on MainDomain
if bytes.HasSuffix(ctx.Request.Host(), MainDomainSuffix) { if bytes.HasSuffix(trimmedHost, MainDomainSuffix) {
canonicalDomain := checkCanonicalDomain(targetOwner, targetRepo, targetBranch) canonicalDomain, _ := checkCanonicalDomain(targetOwner, targetRepo, targetBranch, "")
if !strings.HasSuffix(strings.SplitN(canonicalDomain, "/", 2)[0], string(MainDomainSuffix)) { if !strings.HasSuffix(strings.SplitN(canonicalDomain, "/", 2)[0], string(MainDomainSuffix)) {
canonicalPath := string(ctx.RequestURI()) canonicalPath := string(ctx.RequestURI())
if targetRepo != "pages" { if targetRepo != "pages" {
@ -129,7 +131,7 @@ func handler(ctx *fasthttp.RequestCtx) {
s.Step("preparations") s.Step("preparations")
if RawDomain != nil && bytes.Equal(ctx.Request.Host(), RawDomain) { if RawDomain != nil && bytes.Equal(trimmedHost, RawDomain) {
// Serve raw content from RawDomain // Serve raw content from RawDomain
s.Debug("raw domain") s.Debug("raw domain")
@ -169,12 +171,12 @@ func handler(ctx *fasthttp.RequestCtx) {
return return
} }
} else if bytes.HasSuffix(ctx.Request.Host(), MainDomainSuffix) { } else if bytes.HasSuffix(trimmedHost, MainDomainSuffix) {
// Serve pages from subdomains of MainDomainSuffix // Serve pages from subdomains of MainDomainSuffix
s.Debug("main domain suffix") s.Debug("main domain suffix")
pathElements := strings.Split(string(bytes.Trim(ctx.Request.URI().Path(), "/")), "/") pathElements := strings.Split(string(bytes.Trim(ctx.Request.URI().Path(), "/")), "/")
targetOwner = string(bytes.TrimSuffix(ctx.Request.Host(), MainDomainSuffix)) targetOwner = string(bytes.TrimSuffix(trimmedHost, MainDomainSuffix))
targetRepo = pathElements[0] targetRepo = pathElements[0]
targetPath = strings.Trim(strings.Join(pathElements[1:], "/"), "/") targetPath = strings.Trim(strings.Join(pathElements[1:], "/"), "/")
@ -235,8 +237,10 @@ func handler(ctx *fasthttp.RequestCtx) {
returnErrorPage(ctx, fasthttp.StatusFailedDependency) returnErrorPage(ctx, fasthttp.StatusFailedDependency)
return return
} else { } else {
trimmedHostStr := string(trimmedHost)
// Serve pages from external domains // Serve pages from external domains
targetOwner, targetRepo, targetBranch = getTargetFromDNS(string(ctx.Request.Host())) targetOwner, targetRepo, targetBranch = getTargetFromDNS(trimmedHostStr)
if targetOwner == "" { if targetOwner == "" {
ctx.Redirect(BrokenDNSPage, fasthttp.StatusTemporaryRedirect) ctx.Redirect(BrokenDNSPage, fasthttp.StatusTemporaryRedirect)
return return
@ -253,8 +257,11 @@ func handler(ctx *fasthttp.RequestCtx) {
// Try to use the given repo on the given branch or the default branch // Try to use the given repo on the given branch or the default branch
s.Step("custom domain preparations, now trying with details from DNS") s.Step("custom domain preparations, now trying with details from DNS")
if tryBranch(targetRepo, targetBranch, pathElements, canonicalLink) { if tryBranch(targetRepo, targetBranch, pathElements, canonicalLink) {
canonicalDomain := checkCanonicalDomain(targetOwner, targetRepo, targetBranch) canonicalDomain, valid := checkCanonicalDomain(targetOwner, targetRepo, targetBranch, trimmedHostStr)
if canonicalDomain != string(ctx.Request.Host()) { if !valid {
returnErrorPage(ctx, fasthttp.StatusMisdirectedRequest)
return
} else if canonicalDomain != trimmedHostStr {
// only redirect if the target is also a codeberg page! // only redirect if the target is also a codeberg page!
targetOwner, _, _ = getTargetFromDNS(strings.SplitN(canonicalDomain, "/", 2)[0]) targetOwner, _, _ = getTargetFromDNS(strings.SplitN(canonicalDomain, "/", 2)[0])
if targetOwner != "" { if targetOwner != "" {
@ -282,6 +289,9 @@ func returnErrorPage(ctx *fasthttp.RequestCtx, code int) {
ctx.Response.SetStatusCode(code) ctx.Response.SetStatusCode(code)
ctx.Response.Header.SetContentType("text/html; charset=utf-8") ctx.Response.Header.SetContentType("text/html; charset=utf-8")
message := fasthttp.StatusMessage(code) message := fasthttp.StatusMessage(code)
if code == fasthttp.StatusMisdirectedRequest {
message += " - domain not specified in <code>.domains</code> file"
}
if code == fasthttp.StatusFailedDependency { if code == fasthttp.StatusFailedDependency {
message += " - owner, repo or branch doesn't exist" message += " - owner, repo or branch doesn't exist"
} }

21
helpers.go Normal file
View file

@ -0,0 +1,21 @@
package main
import "bytes"
// GetHSTSHeader returns a HSTS header with includeSubdomains & preload for MainDomainSuffix and RawDomain, or an empty
// string for custom domains.
func GetHSTSHeader(host []byte) string {
if bytes.HasSuffix(host, MainDomainSuffix) || bytes.Equal(host, RawDomain) {
return "max-age=63072000; includeSubdomains; preload"
} else {
return ""
}
}
func TrimHostPort(host []byte) []byte {
i := bytes.IndexByte(host, ':')
if i >= 0 {
return host[:i]
}
return host
}

View file

@ -94,7 +94,6 @@ func main() {
Concurrency: 1024 * 32, // TODO: adjust bottlenecks for best performance with Gitea! Concurrency: 1024 * 32, // TODO: adjust bottlenecks for best performance with Gitea!
MaxConnsPerIP: 100, MaxConnsPerIP: 100,
} }
//fasthttp2.ConfigureServerAndConfig(server, tlsConfig)
// Setup listener and TLS // Setup listener and TLS
listener, err := net.Listen("tcp", address) listener, err := net.Listen("tcp", address)