dont access global vars inject them

This commit is contained in:
6543 2021-12-05 14:45:17 +01:00
parent fb5726bd20
commit bdc2d0c259
No known key found for this signature in database
GPG key ID: C99B82E40B027BAE
11 changed files with 730 additions and 706 deletions

604
server/certificates.go Normal file
View file

@ -0,0 +1,604 @@
package server
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/gob"
"encoding/json"
"encoding/pem"
"errors"
"github.com/OrlovEvgeny/go-mcache"
"github.com/akrylysov/pogreb/fs"
"github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/challenge/tlsalpn01"
"github.com/go-acme/lego/v4/providers/dns"
"io/ioutil"
"log"
"math/big"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/akrylysov/pogreb"
"github.com/reugn/equalizer"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration"
)
// TlsConfig returns the configuration for generating, serving and cleaning up Let's Encrypt certificates.
func TlsConfig(mainDomainSuffix []byte, giteaRoot, giteaApiToken string) *tls.Config {
return &tls.Config{
// check DNS name & get certificate from Let's Encrypt
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
sni := strings.ToLower(strings.TrimSpace(info.ServerName))
sniBytes := []byte(sni)
if len(sni) < 1 {
return nil, errors.New("missing sni")
}
if info.SupportedProtos != nil {
for _, proto := range info.SupportedProtos {
if proto == tlsalpn01.ACMETLS1Protocol {
challenge, ok := ChallengeCache.Get(sni)
if !ok {
return nil, errors.New("no challenge for this domain")
}
cert, err := tlsalpn01.ChallengeCert(sni, challenge.(string))
if err != nil {
return nil, err
}
return cert, nil
}
}
}
targetOwner := ""
if bytes.HasSuffix(sniBytes, mainDomainSuffix) || bytes.Equal(sniBytes, mainDomainSuffix[1:]) {
// deliver default certificate for the main domain (*.codeberg.page)
sniBytes = mainDomainSuffix
sni = string(sniBytes)
} else {
var targetRepo, targetBranch string
targetOwner, targetRepo, targetBranch = getTargetFromDNS(sni, string(mainDomainSuffix))
if targetOwner == "" {
// DNS not set up, return main certificate to redirect to the docs
sniBytes = mainDomainSuffix
sni = string(sniBytes)
} else {
_, _ = targetRepo, targetBranch
_, valid := checkCanonicalDomain(targetOwner, targetRepo, targetBranch, sni, string(mainDomainSuffix), giteaRoot, giteaApiToken)
if !valid {
sniBytes = mainDomainSuffix
sni = string(sniBytes)
}
}
}
if tlsCertificate, ok := keyCache.Get(sni); ok {
// we can use an existing certificate object
return tlsCertificate.(*tls.Certificate), nil
}
var tlsCertificate tls.Certificate
var err error
var ok bool
if tlsCertificate, ok = retrieveCertFromDB(sniBytes, mainDomainSuffix); !ok {
// request a new certificate
if bytes.Equal(sniBytes, mainDomainSuffix) {
return nil, errors.New("won't request certificate for main domain, something really bad has happened")
}
tlsCertificate, err = obtainCert(acmeClient, []string{sni}, nil, targetOwner, mainDomainSuffix)
if err != nil {
return nil, err
}
}
err = keyCache.Set(sni, &tlsCertificate, 15*time.Minute)
if err != nil {
panic(err)
}
return &tlsCertificate, nil
},
PreferServerCipherSuites: true,
NextProtos: []string{
"http/1.1",
tlsalpn01.ACMETLS1Protocol,
},
// generated 2021-07-13, Mozilla Guideline v5.6, Go 1.14.4, intermediate configuration
// https://ssl-config.mozilla.org/#server=go&version=1.14.4&config=intermediate&guideline=5.6
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
},
}
}
// TODO: clean up & move to init
var keyCache = mcache.New()
var KeyDatabase, KeyDatabaseErr = pogreb.Open("key-database.pogreb", &pogreb.Options{
BackgroundSyncInterval: 30 * time.Second,
BackgroundCompactionInterval: 6 * time.Hour,
FileSystem: fs.OSMMap,
})
func CheckUserLimit(user string) error {
userLimit, ok := acmeClientCertificateLimitPerUser[user]
if !ok {
// Each Codeberg user can only add 10 new domains per day.
userLimit = equalizer.NewTokenBucket(10, time.Hour*24)
acmeClientCertificateLimitPerUser[user] = userLimit
}
if !userLimit.Ask() {
return errors.New("rate limit exceeded: 10 certificates per user per 24 hours")
}
return nil
}
var myAcmeAccount AcmeAccount
var myAcmeConfig *lego.Config
type AcmeAccount struct {
Email string
Registration *registration.Resource
Key crypto.PrivateKey `json:"-"`
KeyPEM string `json:"Key"`
}
func (u *AcmeAccount) GetEmail() string {
return u.Email
}
func (u AcmeAccount) GetRegistration() *registration.Resource {
return u.Registration
}
func (u *AcmeAccount) GetPrivateKey() crypto.PrivateKey {
return u.Key
}
var acmeClient, mainDomainAcmeClient *lego.Client
var acmeClientCertificateLimitPerUser = map[string]*equalizer.TokenBucket{}
// rate limit is 300 / 3 hours, we want 200 / 2 hours but to refill more often, so that's 25 new domains every 15 minutes
// TODO: when this is used a lot, we probably have to think of a somewhat better solution?
var acmeClientOrderLimit = equalizer.NewTokenBucket(25, 15*time.Minute)
// rate limit is 20 / second, we want 5 / second (especially as one cert takes at least two requests)
var acmeClientRequestLimit = equalizer.NewTokenBucket(5, 1*time.Second)
var ChallengeCache = mcache.New()
type AcmeTLSChallengeProvider struct{}
var _ challenge.Provider = AcmeTLSChallengeProvider{}
func (a AcmeTLSChallengeProvider) Present(domain, _, keyAuth string) error {
return ChallengeCache.Set(domain, keyAuth, 1*time.Hour)
}
func (a AcmeTLSChallengeProvider) CleanUp(domain, _, _ string) error {
ChallengeCache.Remove(domain)
return nil
}
type AcmeHTTPChallengeProvider struct{}
var _ challenge.Provider = AcmeHTTPChallengeProvider{}
func (a AcmeHTTPChallengeProvider) Present(domain, token, keyAuth string) error {
return ChallengeCache.Set(domain+"/"+token, keyAuth, 1*time.Hour)
}
func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error {
ChallengeCache.Remove(domain + "/" + token)
return nil
}
func retrieveCertFromDB(sni, mainDomainSuffix []byte) (tls.Certificate, bool) {
// parse certificate from database
res := &certificate.Resource{}
if !PogrebGet(KeyDatabase, sni, res) {
return tls.Certificate{}, false
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil {
panic(err)
}
// TODO: document & put into own function
if !bytes.Equal(sni, mainDomainSuffix) {
tlsCertificate.Leaf, err = x509.ParseCertificate(tlsCertificate.Certificate[0])
if err != nil {
panic(err)
}
// renew certificates 7 days before they expire
if !tlsCertificate.Leaf.NotAfter.After(time.Now().Add(-7 * 24 * time.Hour)) {
// TODO: add ValidUntil to custom res struct
if res.CSR != nil && len(res.CSR) > 0 {
// CSR stores the time when the renewal shall be tried again
nextTryUnix, err := strconv.ParseInt(string(res.CSR), 10, 64)
if err == nil && time.Now().Before(time.Unix(nextTryUnix, 0)) {
return tlsCertificate, true
}
}
go (func() {
res.CSR = nil // acme client doesn't like CSR to be set
tlsCertificate, err = obtainCert(acmeClient, []string{string(sni)}, res, "", mainDomainSuffix)
if err != nil {
log.Printf("Couldn't renew certificate for %s: %s", sni, err)
}
})()
}
}
return tlsCertificate, true
}
var obtainLocks = sync.Map{}
func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Resource, user string, mainDomainSuffix []byte) (tls.Certificate, error) {
name := strings.TrimPrefix(domains[0], "*")
if os.Getenv("DNS_PROVIDER") == "" && len(domains[0]) > 0 && domains[0][0] == '*' {
domains = domains[1:]
}
// lock to avoid simultaneous requests
_, working := obtainLocks.LoadOrStore(name, struct{}{})
if working {
for working {
time.Sleep(100 * time.Millisecond)
_, working = obtainLocks.Load(name)
}
cert, ok := retrieveCertFromDB([]byte(name), mainDomainSuffix)
if !ok {
return tls.Certificate{}, errors.New("certificate failed in synchronous request")
}
return cert, nil
}
defer obtainLocks.Delete(name)
if acmeClient == nil {
return mockCert(domains[0], "ACME client uninitialized. This is a server error, please report!", string(mainDomainSuffix)), nil
}
// request actual cert
var res *certificate.Resource
var err error
if renew != nil && renew.CertURL != "" {
if os.Getenv("ACME_USE_RATE_LIMITS") != "false" {
acmeClientRequestLimit.Take()
}
log.Printf("Renewing certificate for %v", domains)
res, err = acmeClient.Certificate.Renew(*renew, true, false, "")
if err != nil {
log.Printf("Couldn't renew certificate for %v, trying to request a new one: %s", domains, err)
res = nil
}
}
if res == nil {
if user != "" {
if err := CheckUserLimit(user); err != nil {
return tls.Certificate{}, err
}
}
if os.Getenv("ACME_USE_RATE_LIMITS") != "false" {
acmeClientOrderLimit.Take()
acmeClientRequestLimit.Take()
}
log.Printf("Requesting new certificate for %v", domains)
res, err = acmeClient.Certificate.Obtain(certificate.ObtainRequest{
Domains: domains,
Bundle: true,
MustStaple: false,
})
}
if err != nil {
log.Printf("Couldn't obtain certificate for %v: %s", domains, err)
if renew != nil && renew.CertURL != "" {
tlsCertificate, err := tls.X509KeyPair(renew.Certificate, renew.PrivateKey)
if err == nil && tlsCertificate.Leaf.NotAfter.After(time.Now()) {
// avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at
renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6*time.Hour).Unix(), 10))
PogrebPut(KeyDatabase, []byte(name), renew)
return tlsCertificate, nil
}
}
return mockCert(domains[0], err.Error(), string(mainDomainSuffix)), err
}
log.Printf("Obtained certificate for %v", domains)
PogrebPut(KeyDatabase, []byte(name), res)
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil {
return tls.Certificate{}, err
}
return tlsCertificate, nil
}
func mockCert(domain, msg, mainDomainSuffix string) tls.Certificate {
key, err := certcrypto.GeneratePrivateKey(certcrypto.RSA2048)
if err != nil {
panic(err)
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: domain,
Organization: []string{"Codeberg Pages Error Certificate (couldn't obtain ACME certificate)"},
OrganizationalUnit: []string{
"Will not try again for 6 hours to avoid hitting rate limits for your domain.",
"Check https://docs.codeberg.org/codeberg-pages/troubleshooting/ for troubleshooting tips, and feel " +
"free to create an issue at https://codeberg.org/Codeberg/pages-server if you can't solve it.\n",
"Error message: " + msg,
},
},
// certificates younger than 7 days are renewed, so this enforces the cert to not be renewed for a 6 hours
NotAfter: time.Now().Add(time.Hour*24*7 + time.Hour*6),
NotBefore: time.Now(),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
certBytes, err := x509.CreateCertificate(
rand.Reader,
&template,
&template,
&key.(*rsa.PrivateKey).PublicKey,
key,
)
if err != nil {
panic(err)
}
out := &bytes.Buffer{}
err = pem.Encode(out, &pem.Block{
Bytes: certBytes,
Type: "CERTIFICATE",
})
if err != nil {
panic(err)
}
outBytes := out.Bytes()
res := &certificate.Resource{
PrivateKey: certcrypto.PEMEncode(key),
Certificate: outBytes,
IssuerCertificate: outBytes,
Domain: domain,
}
databaseName := domain
if domain == "*"+mainDomainSuffix || domain == mainDomainSuffix[1:] {
databaseName = mainDomainSuffix
}
PogrebPut(KeyDatabase, []byte(databaseName), res)
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil {
panic(err)
}
return tlsCertificate
}
func SetupCertificates(mainDomainSuffix []byte) {
if KeyDatabaseErr != nil {
panic(KeyDatabaseErr)
}
if os.Getenv("ACME_ACCEPT_TERMS") != "true" || (os.Getenv("DNS_PROVIDER") == "" && os.Getenv("ACME_API") != "https://acme.mock.directory") {
panic(errors.New("you must set ACME_ACCEPT_TERMS and DNS_PROVIDER, unless ACME_API is set to https://acme.mock.directory"))
}
// getting main cert before ACME account so that we can panic here on database failure without hitting rate limits
mainCertBytes, err := KeyDatabase.Get(mainDomainSuffix)
if err != nil {
// key database is not working
panic(err)
}
if account, err := ioutil.ReadFile("acme-account.json"); err == nil {
err = json.Unmarshal(account, &myAcmeAccount)
if err != nil {
panic(err)
}
myAcmeAccount.Key, err = certcrypto.ParsePEMPrivateKey([]byte(myAcmeAccount.KeyPEM))
if err != nil {
panic(err)
}
myAcmeConfig = lego.NewConfig(&myAcmeAccount)
myAcmeConfig.CADirURL = EnvOr("ACME_API", "https://acme-v02.api.letsencrypt.org/directory")
myAcmeConfig.Certificate.KeyType = certcrypto.RSA2048
_, err := lego.NewClient(myAcmeConfig)
if err != nil {
log.Printf("[ERROR] Can't create ACME client, continuing with mock certs only: %s", err)
}
} else if os.IsNotExist(err) {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}
myAcmeAccount = AcmeAccount{
Email: EnvOr("ACME_EMAIL", "noreply@example.email"),
Key: privateKey,
KeyPEM: string(certcrypto.PEMEncode(privateKey)),
}
myAcmeConfig = lego.NewConfig(&myAcmeAccount)
myAcmeConfig.CADirURL = EnvOr("ACME_API", "https://acme-v02.api.letsencrypt.org/directory")
myAcmeConfig.Certificate.KeyType = certcrypto.RSA2048
tempClient, err := lego.NewClient(myAcmeConfig)
if err != nil {
log.Printf("[ERROR] Can't create ACME client, continuing with mock certs only: %s", err)
} else {
// accept terms & log in to EAB
if os.Getenv("ACME_EAB_KID") == "" || os.Getenv("ACME_EAB_HMAC") == "" {
reg, err := tempClient.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: os.Getenv("ACME_ACCEPT_TERMS") == "true"})
if err != nil {
log.Printf("[ERROR] Can't register ACME account, continuing with mock certs only: %s", err)
} else {
myAcmeAccount.Registration = reg
}
} else {
reg, err := tempClient.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{
TermsOfServiceAgreed: os.Getenv("ACME_ACCEPT_TERMS") == "true",
Kid: os.Getenv("ACME_EAB_KID"),
HmacEncoded: os.Getenv("ACME_EAB_HMAC"),
})
if err != nil {
log.Printf("[ERROR] Can't register ACME account, continuing with mock certs only: %s", err)
} else {
myAcmeAccount.Registration = reg
}
}
if myAcmeAccount.Registration != nil {
acmeAccountJson, err := json.Marshal(myAcmeAccount)
if err != nil {
log.Printf("[FAIL] Error during json.Marshal(myAcmeAccount), waiting for manual restart to avoid rate limits: %s", err)
select {}
}
err = ioutil.WriteFile("acme-account.json", acmeAccountJson, 0600)
if err != nil {
log.Printf("[FAIL] Error during ioutil.WriteFile(\"acme-account.json\"), waiting for manual restart to avoid rate limits: %s", err)
select {}
}
}
}
} else {
panic(err)
}
acmeClient, err = lego.NewClient(myAcmeConfig)
if err != nil {
log.Printf("[ERROR] Can't create ACME client, continuing with mock certs only: %s", err)
} else {
err = acmeClient.Challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{})
if err != nil {
log.Printf("[ERROR] Can't create TLS-ALPN-01 provider: %s", err)
}
if os.Getenv("ENABLE_HTTP_SERVER") == "true" {
err = acmeClient.Challenge.SetHTTP01Provider(AcmeHTTPChallengeProvider{})
if err != nil {
log.Printf("[ERROR] Can't create HTTP-01 provider: %s", err)
}
}
}
mainDomainAcmeClient, err = lego.NewClient(myAcmeConfig)
if err != nil {
log.Printf("[ERROR] Can't create ACME client, continuing with mock certs only: %s", err)
} else {
if os.Getenv("DNS_PROVIDER") == "" {
// using mock server, don't use wildcard certs
err := mainDomainAcmeClient.Challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{})
if err != nil {
log.Printf("[ERROR] Can't create TLS-ALPN-01 provider: %s", err)
}
} else {
provider, err := dns.NewDNSChallengeProviderByName(os.Getenv("DNS_PROVIDER"))
if err != nil {
log.Printf("[ERROR] Can't create DNS Challenge provider: %s", err)
}
err = mainDomainAcmeClient.Challenge.SetDNS01Provider(provider)
if err != nil {
log.Printf("[ERROR] Can't create DNS-01 provider: %s", err)
}
}
}
if mainCertBytes == nil {
_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, nil, "", mainDomainSuffix)
if err != nil {
log.Printf("[ERROR] Couldn't renew main domain certificate, continuing with mock certs only: %s", err)
}
}
go (func() {
for {
err := KeyDatabase.Sync()
if err != nil {
log.Printf("[ERROR] Syncing key database failed: %s", err)
}
time.Sleep(5 * time.Minute)
// TODO: graceful exit
}
})()
go (func() {
for {
// clean up expired certs
now := time.Now()
expiredCertCount := 0
keyDatabaseIterator := KeyDatabase.Items()
key, resBytes, err := keyDatabaseIterator.Next()
for err == nil {
if !bytes.Equal(key, mainDomainSuffix) {
resGob := bytes.NewBuffer(resBytes)
resDec := gob.NewDecoder(resGob)
res := &certificate.Resource{}
err = resDec.Decode(res)
if err != nil {
panic(err)
}
tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
if err != nil || !tlsCertificates[0].NotAfter.After(now) {
err := KeyDatabase.Delete(key)
if err != nil {
log.Printf("[ERROR] Deleting expired certificate for %s failed: %s", string(key), err)
} else {
expiredCertCount++
}
}
}
key, resBytes, err = keyDatabaseIterator.Next()
}
log.Printf("[INFO] Removed %d expired certificates from the database", expiredCertCount)
// compact the database
result, err := KeyDatabase.Compact()
if err != nil {
log.Printf("[ERROR] Compacting key database failed: %s", err)
} else {
log.Printf("[INFO] Compacted key database (%+v)", result)
}
// update main cert
res := &certificate.Resource{}
if !PogrebGet(KeyDatabase, mainDomainSuffix, res) {
log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", "expected main domain cert to exist, but it's missing - seems like the database is corrupted")
} else {
tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
// renew main certificate 30 days before it expires
if !tlsCertificates[0].NotAfter.After(time.Now().Add(-30 * 24 * time.Hour)) {
go (func() {
_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, res, "", mainDomainSuffix)
if err != nil {
log.Printf("[ERROR] Couldn't renew certificate for main domain: %s", err)
}
})()
}
}
time.Sleep(12 * time.Hour)
}
})()
}

113
server/domains.go Normal file
View file

@ -0,0 +1,113 @@
package server
import (
"github.com/OrlovEvgeny/go-mcache"
"github.com/valyala/fasthttp"
"net"
"strings"
"time"
)
// DnsLookupCacheTimeout specifies the timeout for the DNS lookup cache.
var DnsLookupCacheTimeout = 15 * time.Minute
// dnsLookupCache stores DNS lookups for custom domains
var dnsLookupCache = mcache.New()
// getTargetFromDNS searches for CNAME or TXT entries on the request domain ending with MainDomainSuffix.
// If everything is fine, it returns the target data.
func getTargetFromDNS(domain, mainDomainSuffix string) (targetOwner, targetRepo, targetBranch string) {
// Get CNAME or TXT
var cname string
var err error
if cachedName, ok := dnsLookupCache.Get(domain); ok {
cname = cachedName.(string)
} else {
cname, err = net.LookupCNAME(domain)
cname = strings.TrimSuffix(cname, ".")
if err != nil || !strings.HasSuffix(cname, mainDomainSuffix) {
cname = ""
// TODO: check if the A record matches!
names, err := net.LookupTXT(domain)
if err == nil {
for _, name := range names {
name = strings.TrimSuffix(name, ".")
if strings.HasSuffix(name, mainDomainSuffix) {
cname = name
break
}
}
}
}
_ = dnsLookupCache.Set(domain, cname, DnsLookupCacheTimeout)
}
if cname == "" {
return
}
cnameParts := strings.Split(strings.TrimSuffix(cname, mainDomainSuffix), ".")
targetOwner = cnameParts[len(cnameParts)-1]
if len(cnameParts) > 1 {
targetRepo = cnameParts[len(cnameParts)-2]
}
if len(cnameParts) > 2 {
targetBranch = cnameParts[len(cnameParts)-3]
}
if targetRepo == "" {
targetRepo = "pages"
}
if targetBranch == "" && targetRepo != "pages" {
targetBranch = "pages"
}
// if targetBranch is still empty, the caller must find the default branch
return
}
// CanonicalDomainCacheTimeout specifies the timeout for the canonical domain cache.
var CanonicalDomainCacheTimeout = 15 * time.Minute
// canonicalDomainCache stores canonical domains
var canonicalDomainCache = mcache.New()
// checkCanonicalDomain returns the canonical domain specified in the repo (using the file `.canonical-domain`).
func checkCanonicalDomain(targetOwner, targetRepo, targetBranch, actualDomain, mainDomainSuffix, giteaRoot, giteaApiToken string) (canonicalDomain string, valid bool) {
domains := []string{}
if cachedValue, ok := canonicalDomainCache.Get(targetOwner + "/" + targetRepo + "/" + targetBranch); ok {
domains = cachedValue.([]string)
for _, domain := range domains {
if domain == actualDomain {
valid = true
break
}
}
} else {
req := fasthttp.AcquireRequest()
req.SetRequestURI(giteaRoot + "/api/v1/repos/" + targetOwner + "/" + targetRepo + "/raw/" + targetBranch + "/.domains" + "?access_token=" + giteaApiToken)
res := fasthttp.AcquireResponse()
err := upstreamClient.Do(req, res)
if err == nil && res.StatusCode() == fasthttp.StatusOK {
for _, domain := range strings.Split(string(res.Body()), "\n") {
domain = strings.ToLower(domain)
domain = strings.TrimSpace(domain)
domain = strings.TrimPrefix(domain, "http://")
domain = strings.TrimPrefix(domain, "https://")
if len(domain) > 0 && !strings.HasPrefix(domain, "#") && !strings.ContainsAny(domain, "\t /") && strings.ContainsRune(domain, '.') {
domains = append(domains, domain)
}
if domain == actualDomain {
valid = true
}
}
}
domains = append(domains, targetOwner+mainDomainSuffix)
if domains[len(domains)-1] == actualDomain {
valid = true
}
if targetRepo != "" && targetRepo != "pages" {
domains[len(domains)-1] += "/" + targetRepo
}
_ = canonicalDomainCache.Set(targetOwner+"/"+targetRepo+"/"+targetBranch, domains, CanonicalDomainCacheTimeout)
}
canonicalDomain = domains[0]
return
}

553
server/handler.go Normal file
View file

@ -0,0 +1,553 @@
package server
import (
"bytes"
"fmt"
"io"
"mime"
"path"
"strconv"
"strings"
"time"
"github.com/OrlovEvgeny/go-mcache"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
"github.com/valyala/fastjson"
"codeberg.org/codeberg/pages/html"
)
// Handler handles a single HTTP request to the web server.
func Handler(mainDomainSuffix, rawDomain, giteaRoot []byte, rawInfoPage, giteaApiToken string, blacklistedPaths, allowedCorsDomains [][]byte) func(ctx *fasthttp.RequestCtx) {
return func(ctx *fasthttp.RequestCtx) {
log := log.With().Str("Handler", string(ctx.Request.Header.RequestURI())).Logger()
ctx.Response.Header.Set("Server", "Codeberg Pages")
// Force new default from specification (since November 2020) - see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy#strict-origin-when-cross-origin
ctx.Response.Header.Set("Referrer-Policy", "strict-origin-when-cross-origin")
// Enable browser caching for up to 10 minutes
ctx.Response.Header.Set("Cache-Control", "public, max-age=600")
trimmedHost := TrimHostPort(ctx.Request.Host())
// Add HSTS for RawDomain and MainDomainSuffix
if hsts := GetHSTSHeader(trimmedHost, mainDomainSuffix, rawDomain); hsts != "" {
ctx.Response.Header.Set("Strict-Transport-Security", hsts)
}
// Block all methods not required for static pages
if !ctx.IsGet() && !ctx.IsHead() && !ctx.IsOptions() {
ctx.Response.Header.Set("Allow", "GET, HEAD, OPTIONS")
ctx.Error("Method not allowed", fasthttp.StatusMethodNotAllowed)
return
}
// Block blacklisted paths (like ACME challenges)
for _, blacklistedPath := range blacklistedPaths {
if bytes.HasPrefix(ctx.Path(), blacklistedPath) {
returnErrorPage(ctx, fasthttp.StatusForbidden)
return
}
}
// Allow CORS for specified domains
if ctx.IsOptions() {
allowCors := false
for _, allowedCorsDomain := range allowedCorsDomains {
if bytes.Equal(trimmedHost, allowedCorsDomain) {
allowCors = true
break
}
}
if allowCors {
ctx.Response.Header.Set("Access-Control-Allow-Origin", "*")
ctx.Response.Header.Set("Access-Control-Allow-Methods", "GET, HEAD")
}
ctx.Response.Header.Set("Allow", "GET, HEAD, OPTIONS")
ctx.Response.Header.SetStatusCode(fasthttp.StatusNoContent)
return
}
// Prepare request information to Gitea
var targetOwner, targetRepo, targetBranch, targetPath string
var targetOptions = &upstreamOptions{
ForbiddenMimeTypes: map[string]struct{}{},
TryIndexPages: true,
}
// tryBranch checks if a branch exists and populates the target variables. If canonicalLink is non-empty, it will
// also disallow search indexing and add a Link header to the canonical URL.
var tryBranch = func(repo string, branch string, path []string, canonicalLink string) bool {
if repo == "" {
return false
}
// Check if the branch exists, otherwise treat it as a file path
branchTimestampResult := getBranchTimestamp(targetOwner, repo, branch, string(giteaRoot), giteaApiToken)
if branchTimestampResult == nil {
// branch doesn't exist
return false
}
// Branch exists, use it
targetRepo = repo
targetPath = strings.Trim(strings.Join(path, "/"), "/")
targetBranch = branchTimestampResult.branch
targetOptions.BranchTimestamp = branchTimestampResult.timestamp
if canonicalLink != "" {
// Hide from search machines & add canonical link
ctx.Response.Header.Set("X-Robots-Tag", "noarchive, noindex")
ctx.Response.Header.Set("Link",
strings.NewReplacer("%b", targetBranch, "%p", targetPath).Replace(canonicalLink)+
"; rel=\"canonical\"",
)
}
return true
}
// tryUpstream forwards the target request to the Gitea API, and shows an error page on failure.
var tryUpstream = func() {
// check if a canonical domain exists on a request on MainDomain
if bytes.HasSuffix(trimmedHost, mainDomainSuffix) {
canonicalDomain, _ := checkCanonicalDomain(targetOwner, targetRepo, targetBranch, "", string(mainDomainSuffix), string(giteaRoot), giteaApiToken)
if !strings.HasSuffix(strings.SplitN(canonicalDomain, "/", 2)[0], string(mainDomainSuffix)) {
canonicalPath := string(ctx.RequestURI())
if targetRepo != "pages" {
canonicalPath = "/" + strings.SplitN(canonicalPath, "/", 3)[2]
}
ctx.Redirect("https://"+canonicalDomain+canonicalPath, fasthttp.StatusTemporaryRedirect)
return
}
}
// Try to request the file from the Gitea API
if !upstream(ctx, targetOwner, targetRepo, targetBranch, targetPath, string(giteaRoot), giteaApiToken, targetOptions) {
returnErrorPage(ctx, ctx.Response.StatusCode())
}
}
log.Debug().Msg("preparations")
if rawDomain != nil && bytes.Equal(trimmedHost, rawDomain) {
// Serve raw content from RawDomain
log.Debug().Msg("raw domain")
targetOptions.TryIndexPages = false
targetOptions.ForbiddenMimeTypes["text/html"] = struct{}{}
targetOptions.DefaultMimeType = "text/plain; charset=utf-8"
pathElements := strings.Split(string(bytes.Trim(ctx.Request.URI().Path(), "/")), "/")
if len(pathElements) < 2 {
// https://{RawDomain}/{owner}/{repo}[/@{branch}]/{path} is required
ctx.Redirect(rawInfoPage, fasthttp.StatusTemporaryRedirect)
return
}
targetOwner = pathElements[0]
targetRepo = pathElements[1]
// raw.codeberg.org/example/myrepo/@main/index.html
if len(pathElements) > 2 && strings.HasPrefix(pathElements[2], "@") {
log.Debug().Msg("raw domain preparations, now trying with specified branch")
if tryBranch(targetRepo, pathElements[2][1:], pathElements[3:],
string(giteaRoot)+"/"+targetOwner+"/"+targetRepo+"/src/branch/%b/%p",
) {
log.Debug().Msg("tryBranch, now trying upstream")
tryUpstream()
return
}
log.Debug().Msg("missing branch")
returnErrorPage(ctx, fasthttp.StatusFailedDependency)
return
} else {
log.Debug().Msg("raw domain preparations, now trying with default branch")
tryBranch(targetRepo, "", pathElements[2:],
string(giteaRoot)+"/"+targetOwner+"/"+targetRepo+"/src/branch/%b/%p",
)
log.Debug().Msg("tryBranch, now trying upstream")
tryUpstream()
return
}
} else if bytes.HasSuffix(trimmedHost, mainDomainSuffix) {
// Serve pages from subdomains of MainDomainSuffix
log.Debug().Msg("main domain suffix")
pathElements := strings.Split(string(bytes.Trim(ctx.Request.URI().Path(), "/")), "/")
targetOwner = string(bytes.TrimSuffix(trimmedHost, mainDomainSuffix))
targetRepo = pathElements[0]
targetPath = strings.Trim(strings.Join(pathElements[1:], "/"), "/")
if targetOwner == "www" {
// www.codeberg.page redirects to codeberg.page
ctx.Redirect("https://"+string(mainDomainSuffix[1:])+string(ctx.Path()), fasthttp.StatusPermanentRedirect)
return
}
// Check if the first directory is a repo with the second directory as a branch
// example.codeberg.page/myrepo/@main/index.html
if len(pathElements) > 1 && strings.HasPrefix(pathElements[1], "@") {
if targetRepo == "pages" {
// example.codeberg.org/pages/@... redirects to example.codeberg.org/@...
ctx.Redirect("/"+strings.Join(pathElements[1:], "/"), fasthttp.StatusTemporaryRedirect)
return
}
log.Debug().Msg("main domain preparations, now trying with specified repo & branch")
if tryBranch(pathElements[0], pathElements[1][1:], pathElements[2:],
"/"+pathElements[0]+"/%p",
) {
log.Debug().Msg("tryBranch, now trying upstream")
tryUpstream()
} else {
returnErrorPage(ctx, fasthttp.StatusFailedDependency)
}
return
}
// Check if the first directory is a branch for the "pages" repo
// example.codeberg.page/@main/index.html
if strings.HasPrefix(pathElements[0], "@") {
log.Debug().Msg("main domain preparations, now trying with specified branch")
if tryBranch("pages", pathElements[0][1:], pathElements[1:], "/%p") {
log.Debug().Msg("tryBranch, now trying upstream")
tryUpstream()
} else {
returnErrorPage(ctx, fasthttp.StatusFailedDependency)
}
return
}
// Check if the first directory is a repo with a "pages" branch
// example.codeberg.page/myrepo/index.html
// example.codeberg.page/pages/... is not allowed here.
log.Debug().Msg("main domain preparations, now trying with specified repo")
if pathElements[0] != "pages" && tryBranch(pathElements[0], "pages", pathElements[1:], "") {
log.Debug().Msg("tryBranch, now trying upstream")
tryUpstream()
return
}
// Try to use the "pages" repo on its default branch
// example.codeberg.page/index.html
log.Debug().Msg("main domain preparations, now trying with default repo/branch")
if tryBranch("pages", "", pathElements, "") {
log.Debug().Msg("tryBranch, now trying upstream")
tryUpstream()
return
}
// Couldn't find a valid repo/branch
returnErrorPage(ctx, fasthttp.StatusFailedDependency)
return
} else {
trimmedHostStr := string(trimmedHost)
// Serve pages from external domains
targetOwner, targetRepo, targetBranch = getTargetFromDNS(trimmedHostStr, string(mainDomainSuffix))
if targetOwner == "" {
returnErrorPage(ctx, fasthttp.StatusFailedDependency)
return
}
pathElements := strings.Split(string(bytes.Trim(ctx.Request.URI().Path(), "/")), "/")
canonicalLink := ""
if strings.HasPrefix(pathElements[0], "@") {
targetBranch = pathElements[0][1:]
pathElements = pathElements[1:]
canonicalLink = "/%p"
}
// Try to use the given repo on the given branch or the default branch
log.Debug().Msg("custom domain preparations, now trying with details from DNS")
if tryBranch(targetRepo, targetBranch, pathElements, canonicalLink) {
canonicalDomain, valid := checkCanonicalDomain(targetOwner, targetRepo, targetBranch, trimmedHostStr, string(mainDomainSuffix), string(giteaRoot), giteaApiToken)
if !valid {
returnErrorPage(ctx, fasthttp.StatusMisdirectedRequest)
return
} else if canonicalDomain != trimmedHostStr {
// only redirect if the target is also a codeberg page!
targetOwner, _, _ = getTargetFromDNS(strings.SplitN(canonicalDomain, "/", 2)[0], string(mainDomainSuffix))
if targetOwner != "" {
ctx.Redirect("https://"+canonicalDomain+string(ctx.RequestURI()), fasthttp.StatusTemporaryRedirect)
return
} else {
returnErrorPage(ctx, fasthttp.StatusFailedDependency)
return
}
}
log.Debug().Msg("tryBranch, now trying upstream")
tryUpstream()
return
} else {
returnErrorPage(ctx, fasthttp.StatusFailedDependency)
return
}
}
}
}
// returnErrorPage sets the response status code and writes NotFoundPage to the response body, with "%status" replaced
// with the provided status code.
func returnErrorPage(ctx *fasthttp.RequestCtx, code int) {
ctx.Response.SetStatusCode(code)
ctx.Response.Header.SetContentType("text/html; charset=utf-8")
message := fasthttp.StatusMessage(code)
if code == fasthttp.StatusMisdirectedRequest {
message += " - domain not specified in <code>.domains</code> file"
}
if code == fasthttp.StatusFailedDependency {
message += " - target repo/branch doesn't exist or is private"
}
// TODO: use template engine?
ctx.Response.SetBody(bytes.ReplaceAll(html.NotFoundPage, []byte("%status"), []byte(strconv.Itoa(code)+" "+message)))
}
// DefaultBranchCacheTimeout specifies the timeout for the default branch cache. It can be quite long.
var DefaultBranchCacheTimeout = 15 * time.Minute
// BranchExistanceCacheTimeout specifies the timeout for the branch timestamp & existance cache. It should be shorter
// than FileCacheTimeout, as that gets invalidated if the branch timestamp has changed. That way, repo changes will be
// picked up faster, while still allowing the content to be cached longer if nothing changes.
var BranchExistanceCacheTimeout = 5 * time.Minute
// branchTimestampCache stores branch timestamps for faster cache checking
var branchTimestampCache = mcache.New()
type branchTimestamp struct {
branch string
timestamp time.Time
}
// FileCacheTimeout specifies the timeout for the file content cache - you might want to make this quite long, depending
// on your available memory.
var FileCacheTimeout = 5 * time.Minute
// FileCacheSizeLimit limits the maximum file size that will be cached, and is set to 1 MB by default.
var FileCacheSizeLimit = 1024 * 1024
// fileResponseCache stores responses from the Gitea server
// TODO: make this an MRU cache with a size limit
var fileResponseCache = mcache.New()
type fileResponse struct {
exists bool
mimeType string
body []byte
}
// getBranchTimestamp finds the default branch (if branch is "") and returns the last modification time of the branch
// (or nil if the branch doesn't exist)
func getBranchTimestamp(owner, repo, branch, giteaRoot, giteaApiToken string) *branchTimestamp {
if result, ok := branchTimestampCache.Get(owner + "/" + repo + "/" + branch); ok {
if result == nil {
return nil
}
return result.(*branchTimestamp)
}
result := &branchTimestamp{}
result.branch = branch
if branch == "" {
// Get default branch
var body = make([]byte, 0)
// TODO: use header for API key?
status, body, err := fasthttp.GetTimeout(body, giteaRoot+"/api/v1/repos/"+owner+"/"+repo+"?access_token="+giteaApiToken, 5*time.Second)
if err != nil || status != 200 {
_ = branchTimestampCache.Set(owner+"/"+repo+"/"+branch, nil, DefaultBranchCacheTimeout)
return nil
}
result.branch = fastjson.GetString(body, "default_branch")
}
var body = make([]byte, 0)
status, body, err := fasthttp.GetTimeout(body, giteaRoot+"/api/v1/repos/"+owner+"/"+repo+"/branches/"+branch+"?access_token="+giteaApiToken, 5*time.Second)
if err != nil || status != 200 {
return nil
}
result.timestamp, _ = time.Parse(time.RFC3339, fastjson.GetString(body, "commit", "timestamp"))
_ = branchTimestampCache.Set(owner+"/"+repo+"/"+branch, result, BranchExistanceCacheTimeout)
return result
}
var upstreamClient = fasthttp.Client{
ReadTimeout: 10 * time.Second,
MaxConnDuration: 60 * time.Second,
MaxConnWaitTimeout: 1000 * time.Millisecond,
MaxConnsPerHost: 128 * 16, // TODO: adjust bottlenecks for best performance with Gitea!
}
// upstreamIndexPages lists pages that may be considered as index pages for directories.
var upstreamIndexPages = []string{
"index.html",
}
// upstream requests a file from the Gitea API at GiteaRoot and writes it to the request context.
func upstream(ctx *fasthttp.RequestCtx, targetOwner, targetRepo, targetBranch, targetPath, giteaRoot, giteaApiToken string, options *upstreamOptions) (final bool) {
log := log.With().Strs("upstream", []string{targetOwner, targetRepo, targetBranch, targetPath}).Logger()
if options.ForbiddenMimeTypes == nil {
options.ForbiddenMimeTypes = map[string]struct{}{}
}
// Check if the branch exists and when it was modified
if options.BranchTimestamp == (time.Time{}) {
branch := getBranchTimestamp(targetOwner, targetRepo, targetBranch, giteaRoot, giteaApiToken)
if branch == nil {
returnErrorPage(ctx, fasthttp.StatusFailedDependency)
return true
}
targetBranch = branch.branch
options.BranchTimestamp = branch.timestamp
}
if targetOwner == "" || targetRepo == "" || targetBranch == "" {
returnErrorPage(ctx, fasthttp.StatusBadRequest)
return true
}
// Check if the browser has a cached version
if ifModifiedSince, err := time.Parse(time.RFC1123, string(ctx.Request.Header.Peek("If-Modified-Since"))); err == nil {
if !ifModifiedSince.Before(options.BranchTimestamp) {
ctx.Response.SetStatusCode(fasthttp.StatusNotModified)
return true
}
}
log.Debug().Msg("preparations")
// Make a GET request to the upstream URL
uri := targetOwner + "/" + targetRepo + "/raw/" + targetBranch + "/" + targetPath
var req *fasthttp.Request
var res *fasthttp.Response
var cachedResponse fileResponse
var err error
if cachedValue, ok := fileResponseCache.Get(uri + "?timestamp=" + strconv.FormatInt(options.BranchTimestamp.Unix(), 10)); ok && len(cachedValue.(fileResponse).body) > 0 {
cachedResponse = cachedValue.(fileResponse)
} else {
req = fasthttp.AcquireRequest()
req.SetRequestURI(giteaRoot + "/api/v1/repos/" + uri + "?access_token=" + giteaApiToken)
res = fasthttp.AcquireResponse()
res.SetBodyStream(&strings.Reader{}, -1)
err = upstreamClient.Do(req, res)
}
log.Debug().Msg("acquisition")
// Handle errors
if (res == nil && !cachedResponse.exists) || (res != nil && res.StatusCode() == fasthttp.StatusNotFound) {
if options.TryIndexPages {
// copy the options struct & try if an index page exists
optionsForIndexPages := *options
optionsForIndexPages.TryIndexPages = false
optionsForIndexPages.AppendTrailingSlash = true
for _, indexPage := range upstreamIndexPages {
if upstream(ctx, targetOwner, targetRepo, targetBranch, strings.TrimSuffix(targetPath, "/")+"/"+indexPage, giteaRoot, giteaApiToken, &optionsForIndexPages) {
_ = fileResponseCache.Set(uri+"?timestamp="+strconv.FormatInt(options.BranchTimestamp.Unix(), 10), fileResponse{
exists: false,
}, FileCacheTimeout)
return true
}
}
// compatibility fix for GitHub Pages (/example → /example.html)
optionsForIndexPages.AppendTrailingSlash = false
optionsForIndexPages.RedirectIfExists = string(ctx.Request.URI().Path()) + ".html"
if upstream(ctx, targetOwner, targetRepo, targetBranch, targetPath+".html", giteaRoot, giteaApiToken, &optionsForIndexPages) {
_ = fileResponseCache.Set(uri+"?timestamp="+strconv.FormatInt(options.BranchTimestamp.Unix(), 10), fileResponse{
exists: false,
}, FileCacheTimeout)
return true
}
}
ctx.Response.SetStatusCode(fasthttp.StatusNotFound)
if res != nil {
// Update cache if the request is fresh
_ = fileResponseCache.Set(uri+"?timestamp="+strconv.FormatInt(options.BranchTimestamp.Unix(), 10), fileResponse{
exists: false,
}, FileCacheTimeout)
}
return false
}
if res != nil && (err != nil || res.StatusCode() != fasthttp.StatusOK) {
fmt.Printf("Couldn't fetch contents from \"%s\": %s (status code %d)\n", req.RequestURI(), err, res.StatusCode())
returnErrorPage(ctx, fasthttp.StatusInternalServerError)
return true
}
// Append trailing slash if missing (for index files), and redirect to fix filenames in general
// options.AppendTrailingSlash is only true when looking for index pages
if options.AppendTrailingSlash && !bytes.HasSuffix(ctx.Request.URI().Path(), []byte{'/'}) {
ctx.Redirect(string(ctx.Request.URI().Path())+"/", fasthttp.StatusTemporaryRedirect)
return true
}
if bytes.HasSuffix(ctx.Request.URI().Path(), []byte("/index.html")) {
ctx.Redirect(strings.TrimSuffix(string(ctx.Request.URI().Path()), "index.html"), fasthttp.StatusTemporaryRedirect)
return true
}
if options.RedirectIfExists != "" {
ctx.Redirect(options.RedirectIfExists, fasthttp.StatusTemporaryRedirect)
return true
}
log.Debug().Msg("error handling")
// Set the MIME type
mimeType := mime.TypeByExtension(path.Ext(targetPath))
mimeTypeSplit := strings.SplitN(mimeType, ";", 2)
if _, ok := options.ForbiddenMimeTypes[mimeTypeSplit[0]]; ok || mimeType == "" {
if options.DefaultMimeType != "" {
mimeType = options.DefaultMimeType
} else {
mimeType = "application/octet-stream"
}
}
ctx.Response.Header.SetContentType(mimeType)
// Everything's okay so far
ctx.Response.SetStatusCode(fasthttp.StatusOK)
ctx.Response.Header.SetLastModified(options.BranchTimestamp)
log.Debug().Msg("response preparations")
// Write the response body to the original request
var cacheBodyWriter bytes.Buffer
if res != nil {
if res.Header.ContentLength() > FileCacheSizeLimit {
err = res.BodyWriteTo(ctx.Response.BodyWriter())
} else {
// TODO: cache is half-empty if request is cancelled - does the ctx.Err() below do the trick?
err = res.BodyWriteTo(io.MultiWriter(ctx.Response.BodyWriter(), &cacheBodyWriter))
}
} else {
_, err = ctx.Write(cachedResponse.body)
}
if err != nil {
fmt.Printf("Couldn't write body for \"%s\": %s\n", req.RequestURI(), err)
returnErrorPage(ctx, fasthttp.StatusInternalServerError)
return true
}
log.Debug().Msg("response")
if res != nil && ctx.Err() == nil {
cachedResponse.exists = true
cachedResponse.mimeType = mimeType
cachedResponse.body = cacheBodyWriter.Bytes()
_ = fileResponseCache.Set(uri+"?timestamp="+strconv.FormatInt(options.BranchTimestamp.Unix(), 10), cachedResponse, FileCacheTimeout)
}
return true
}
// upstreamOptions provides various options for the upstream request.
type upstreamOptions struct {
DefaultMimeType string
ForbiddenMimeTypes map[string]struct{}
TryIndexPages bool
AppendTrailingSlash bool
RedirectIfExists string
BranchTimestamp time.Time
}

63
server/handler_test.go Normal file
View file

@ -0,0 +1,63 @@
package server
import (
"fmt"
"github.com/valyala/fasthttp"
"testing"
"time"
)
func TestHandlerPerformance(t *testing.T) {
testHandler := Handler(
[]byte("codeberg.page"),
[]byte("raw.codeberg.org"),
[]byte("https://codeberg.org"),
"https://docs.codeberg.org/pages/raw-content/",
"",
[][]byte{[]byte("/.well-known/acme-challenge/")},
[][]byte{[]byte("raw.codeberg.org"), []byte("fonts.codeberg.org"), []byte("design.codeberg.org")},
)
ctx := &fasthttp.RequestCtx{
Request: *fasthttp.AcquireRequest(),
Response: *fasthttp.AcquireResponse(),
}
ctx.Request.SetRequestURI("http://mondstern.codeberg.page/")
fmt.Printf("Start: %v\n", time.Now())
start := time.Now()
testHandler(ctx)
end := time.Now()
fmt.Printf("Done: %v\n", time.Now())
if ctx.Response.StatusCode() != 200 || len(ctx.Response.Body()) < 2048 {
t.Errorf("request failed with status code %d and body length %d", ctx.Response.StatusCode(), len(ctx.Response.Body()))
} else {
t.Logf("request took %d milliseconds", end.Sub(start).Milliseconds())
}
ctx.Response.Reset()
ctx.Response.ResetBody()
fmt.Printf("Start: %v\n", time.Now())
start = time.Now()
testHandler(ctx)
end = time.Now()
fmt.Printf("Done: %v\n", time.Now())
if ctx.Response.StatusCode() != 200 || len(ctx.Response.Body()) < 2048 {
t.Errorf("request failed with status code %d and body length %d", ctx.Response.StatusCode(), len(ctx.Response.Body()))
} else {
t.Logf("request took %d milliseconds", end.Sub(start).Milliseconds())
}
ctx.Response.Reset()
ctx.Response.ResetBody()
ctx.Request.SetRequestURI("http://example.momar.xyz/")
fmt.Printf("Start: %v\n", time.Now())
start = time.Now()
testHandler(ctx)
end = time.Now()
fmt.Printf("Done: %v\n", time.Now())
if ctx.Response.StatusCode() != 200 || len(ctx.Response.Body()) < 1 {
t.Errorf("request failed with status code %d and body length %d", ctx.Response.StatusCode(), len(ctx.Response.Body()))
} else {
t.Logf("request took %d milliseconds", end.Sub(start).Milliseconds())
}
}

67
server/helpers.go Normal file
View file

@ -0,0 +1,67 @@
package server
import (
"bytes"
"encoding/gob"
"os"
"github.com/akrylysov/pogreb"
)
// GetHSTSHeader returns a HSTS header with includeSubdomains & preload for MainDomainSuffix and RawDomain, or an empty
// string for custom domains.
func GetHSTSHeader(host, mainDomainSuffix, rawDomain []byte) string {
if bytes.HasSuffix(host, mainDomainSuffix) || bytes.Equal(host, rawDomain) {
return "max-age=63072000; includeSubdomains; preload"
} else {
return ""
}
}
func TrimHostPort(host []byte) []byte {
i := bytes.IndexByte(host, ':')
if i >= 0 {
return host[:i]
}
return host
}
func PogrebPut(db *pogreb.DB, name []byte, obj interface{}) {
var resGob bytes.Buffer
resEnc := gob.NewEncoder(&resGob)
err := resEnc.Encode(obj)
if err != nil {
panic(err)
}
err = db.Put(name, resGob.Bytes())
if err != nil {
panic(err)
}
}
func PogrebGet(db *pogreb.DB, name []byte, obj interface{}) bool {
resBytes, err := db.Get(name)
if err != nil {
panic(err)
}
if resBytes == nil {
return false
}
resGob := bytes.NewBuffer(resBytes)
resDec := gob.NewDecoder(resGob)
err = resDec.Decode(obj)
if err != nil {
panic(err)
}
return true
}
// EnvOr reads an environment variable and returns a default value if it's empty.
// TODO: to helpers.go or use CLI framework
func EnvOr(env string, or string) string {
if v := os.Getenv(env); v != "" {
return v
}
return or
}