Merge branch 'main' into issue115

This commit is contained in:
6543 2023-01-04 05:39:56 +00:00
commit e3b10685f0
19 changed files with 172 additions and 85 deletions

20
.golangci.yml Normal file
View file

@ -0,0 +1,20 @@
linters-settings:
gocritic:
enabled-tags:
- diagnostic
- experimental
- opinionated
- performance
- style
disabled-checks:
- importShadow
- ifElseChain
- hugeParam
linters:
enable:
- unconvert
- gocritic
run:
timeout: 5m

View file

@ -102,7 +102,7 @@ pipeline:
registry: codeberg.org registry: codeberg.org
dockerfile: Dockerfile dockerfile: Dockerfile
repo: codeberg.org/codeberg/pages-server repo: codeberg.org/codeberg/pages-server
tag: [ latest, "${CI_COMMIT_TAG}" ] tags: [ latest, "${CI_COMMIT_TAG}" ]
username: username:
from_secret: bot_user from_secret: bot_user
password: password:

View file

@ -38,10 +38,10 @@ tool-gofumpt:
fi fi
test: test:
go test -race codeberg.org/codeberg/pages/server/... go test -race codeberg.org/codeberg/pages/server/... codeberg.org/codeberg/pages/html/
test-run TEST: test-run TEST:
go test -race -run "^{{TEST}}$" codeberg.org/codeberg/pages/server/... go test -race -run "^{{TEST}}$" codeberg.org/codeberg/pages/server/... codeberg.org/codeberg/pages/html/
integration: integration:
go test -race -tags integration codeberg.org/codeberg/pages/integration/... go test -race -tags integration codeberg.org/codeberg/pages/integration/...

View file

@ -66,7 +66,7 @@ func Serve(ctx *cli.Context) error {
} }
allowedCorsDomains := AllowedCorsDomains allowedCorsDomains := AllowedCorsDomains
if len(rawDomain) != 0 { if rawDomain != "" {
allowedCorsDomains = append(allowedCorsDomains, rawDomain) allowedCorsDomains = append(allowedCorsDomains, rawDomain)
} }

View file

@ -1,6 +1,7 @@
package html package html
import ( import (
"html/template"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -14,16 +15,27 @@ func ReturnErrorPage(ctx *context.Context, msg string, statusCode int) {
ctx.RespWriter.Header().Set("Content-Type", "text/html; charset=utf-8") ctx.RespWriter.Header().Set("Content-Type", "text/html; charset=utf-8")
ctx.RespWriter.WriteHeader(statusCode) ctx.RespWriter.WriteHeader(statusCode)
if msg == "" { msg = generateResponse(msg, statusCode)
msg = errorBody(statusCode)
} else {
// TODO: use template engine
msg = strings.ReplaceAll(strings.ReplaceAll(ErrorPage, "%message%", msg), "%status%", http.StatusText(statusCode))
}
_, _ = ctx.RespWriter.Write([]byte(msg)) _, _ = ctx.RespWriter.Write([]byte(msg))
} }
// TODO: use template engine
func generateResponse(msg string, statusCode int) string {
if msg == "" {
msg = strings.ReplaceAll(NotFoundPage,
"%status%",
strconv.Itoa(statusCode)+" "+errorMessage(statusCode))
} else {
msg = strings.ReplaceAll(
strings.ReplaceAll(ErrorPage, "%message%", template.HTMLEscapeString(msg)),
"%status%",
http.StatusText(statusCode))
}
return msg
}
func errorMessage(statusCode int) string { func errorMessage(statusCode int) string {
message := http.StatusText(statusCode) message := http.StatusText(statusCode)
@ -36,10 +48,3 @@ func errorMessage(statusCode int) string {
return message return message
} }
// TODO: use template engine
func errorBody(statusCode int) string {
return strings.ReplaceAll(NotFoundPage,
"%status%",
strconv.Itoa(statusCode)+" "+errorMessage(statusCode))
}

38
html/error_test.go Normal file
View file

@ -0,0 +1,38 @@
package html
import (
"net/http"
"strings"
"testing"
)
func TestValidMessage(t *testing.T) {
testString := "requested blacklisted path"
statusCode := http.StatusForbidden
expected := strings.ReplaceAll(
strings.ReplaceAll(ErrorPage, "%message%", testString),
"%status%",
http.StatusText(statusCode))
actual := generateResponse(testString, statusCode)
if expected != actual {
t.Errorf("generated response did not match: expected: '%s', got: '%s'", expected, actual)
}
}
func TestMessageWithHtml(t *testing.T) {
testString := `abc<img src=1 onerror=alert("xss");`
escapedString := "abc&lt;img src=1 onerror=alert(&#34;xss&#34;);"
statusCode := http.StatusNotFound
expected := strings.ReplaceAll(
strings.ReplaceAll(ErrorPage, "%message%", escapedString),
"%status%",
http.StatusText(statusCode))
actual := generateResponse(testString, statusCode)
if expected != actual {
t.Errorf("generated response did not match: expected: '%s', got: '%s'", expected, actual)
}
}

View file

@ -31,7 +31,7 @@ func TestGetRedirect(t *testing.T) {
func TestGetContent(t *testing.T) { func TestGetContent(t *testing.T) {
log.Println("=== TestGetContent ===") log.Println("=== TestGetContent ===")
// test get image // test get image
resp, err := getTestHTTPSClient().Get("https://magiclike.localhost.mock.directory:4430/images/827679288a.jpg") resp, err := getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/images/827679288a.jpg")
assert.NoError(t, err) assert.NoError(t, err)
if !assert.EqualValues(t, http.StatusOK, resp.StatusCode) { if !assert.EqualValues(t, http.StatusOK, resp.StatusCode) {
t.FailNow() t.FailNow()
@ -42,7 +42,7 @@ func TestGetContent(t *testing.T) {
assert.Len(t, resp.Header.Get("ETag"), 42) assert.Len(t, resp.Header.Get("ETag"), 42)
// specify branch // specify branch
resp, err = getTestHTTPSClient().Get("https://momar.localhost.mock.directory:4430/pag/@master/") resp, err = getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/pag/@master/")
assert.NoError(t, err) assert.NoError(t, err)
if !assert.NotNil(t, resp) { if !assert.NotNil(t, resp) {
t.FailNow() t.FailNow()
@ -53,7 +53,7 @@ func TestGetContent(t *testing.T) {
assert.Len(t, resp.Header.Get("ETag"), 44) assert.Len(t, resp.Header.Get("ETag"), 44)
// access branch name contains '/' // access branch name contains '/'
resp, err = getTestHTTPSClient().Get("https://blumia.localhost.mock.directory:4430/pages-server-integration-tests/@docs~main/") resp, err = getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/blumia/@docs~main/")
assert.NoError(t, err) assert.NoError(t, err)
if !assert.EqualValues(t, http.StatusOK, resp.StatusCode) { if !assert.EqualValues(t, http.StatusOK, resp.StatusCode) {
t.FailNow() t.FailNow()
@ -81,7 +81,7 @@ func TestCustomDomain(t *testing.T) {
func TestGetNotFound(t *testing.T) { func TestGetNotFound(t *testing.T) {
log.Println("=== TestGetNotFound ===") log.Println("=== TestGetNotFound ===")
// test custom not found pages // test custom not found pages
resp, err := getTestHTTPSClient().Get("https://crystal.localhost.mock.directory:4430/pages-404-demo/blah") resp, err := getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/pages-404-demo/blah")
assert.NoError(t, err) assert.NoError(t, err)
if !assert.NotNil(t, resp) { if !assert.NotNil(t, resp) {
t.FailNow() t.FailNow()
@ -95,7 +95,7 @@ func TestGetNotFound(t *testing.T) {
func TestFollowSymlink(t *testing.T) { func TestFollowSymlink(t *testing.T) {
log.Printf("=== TestFollowSymlink ===\n") log.Printf("=== TestFollowSymlink ===\n")
resp, err := getTestHTTPSClient().Get("https://6543.localhost.mock.directory:4430/tests_for_pages-server/@main/link") resp, err := getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/tests_for_pages-server/@main/link")
assert.NoError(t, err) assert.NoError(t, err)
if !assert.NotNil(t, resp) { if !assert.NotNil(t, resp) {
t.FailNow() t.FailNow()
@ -111,7 +111,7 @@ func TestFollowSymlink(t *testing.T) {
func TestLFSSupport(t *testing.T) { func TestLFSSupport(t *testing.T) {
log.Printf("=== TestLFSSupport ===\n") log.Printf("=== TestLFSSupport ===\n")
resp, err := getTestHTTPSClient().Get("https://6543.localhost.mock.directory:4430/tests_for_pages-server/@main/lfs.txt") resp, err := getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/tests_for_pages-server/@main/lfs.txt")
assert.NoError(t, err) assert.NoError(t, err)
if !assert.NotNil(t, resp) { if !assert.NotNil(t, resp) {
t.FailNow() t.FailNow()
@ -124,7 +124,7 @@ func TestLFSSupport(t *testing.T) {
func TestGetOptions(t *testing.T) { func TestGetOptions(t *testing.T) {
log.Println("=== TestGetOptions ===") log.Println("=== TestGetOptions ===")
req, _ := http.NewRequest(http.MethodOptions, "https://mock-pages.codeberg-test.org:4430/README.md", nil) req, _ := http.NewRequest(http.MethodOptions, "https://mock-pages.codeberg-test.org:4430/README.md", http.NoBody)
resp, err := getTestHTTPSClient().Do(req) resp, err := getTestHTTPSClient().Do(req)
assert.NoError(t, err) assert.NoError(t, err)
if !assert.NotNil(t, resp) { if !assert.NotNil(t, resp) {

View file

@ -28,7 +28,7 @@ func TestMain(m *testing.M) {
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
os.Exit(m.Run()) m.Run()
} }
func startServer(ctx context.Context) error { func startServer(ctx context.Context) error {

View file

@ -54,7 +54,10 @@ func TLSConfig(mainDomainSuffix string,
if info.SupportedProtos != nil { if info.SupportedProtos != nil {
for _, proto := range info.SupportedProtos { for _, proto := range info.SupportedProtos {
if proto == tlsalpn01.ACMETLS1Protocol { if proto != tlsalpn01.ACMETLS1Protocol {
continue
}
challenge, ok := challengeCache.Get(sni) challenge, ok := challengeCache.Get(sni)
if !ok { if !ok {
return nil, errors.New("no challenge for this domain") return nil, errors.New("no challenge for this domain")
@ -66,7 +69,6 @@ func TLSConfig(mainDomainSuffix string,
return cert, nil return cert, nil
} }
} }
}
targetOwner := "" targetOwner := ""
if strings.HasSuffix(sni, mainDomainSuffix) || strings.EqualFold(sni, mainDomainSuffix[1:]) { if strings.HasSuffix(sni, mainDomainSuffix) || strings.EqualFold(sni, mainDomainSuffix[1:]) {
@ -162,6 +164,9 @@ var acmeClientOrderLimit = equalizer.NewTokenBucket(25, 15*time.Minute)
// rate limit is 20 / second, we want 5 / second (especially as one cert takes at least two requests) // rate limit is 20 / second, we want 5 / second (especially as one cert takes at least two requests)
var acmeClientRequestLimit = equalizer.NewTokenBucket(5, 1*time.Second) var acmeClientRequestLimit = equalizer.NewTokenBucket(5, 1*time.Second)
// rate limit is 5 / hour https://letsencrypt.org/docs/failed-validation-limit/
var acmeClientFailLimit = equalizer.NewTokenBucket(5, 1*time.Hour)
type AcmeTLSChallengeProvider struct { type AcmeTLSChallengeProvider struct {
challengeCache cache.SetGetKey challengeCache cache.SetGetKey
} }
@ -196,7 +201,7 @@ func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error {
func retrieveCertFromDB(sni, mainDomainSuffix, dnsProvider string, acmeUseRateLimits bool, certDB database.CertDB) (tls.Certificate, bool) { func retrieveCertFromDB(sni, mainDomainSuffix, dnsProvider string, acmeUseRateLimits bool, certDB database.CertDB) (tls.Certificate, bool) {
// parse certificate from database // parse certificate from database
res, err := certDB.Get(string(sni)) res, err := certDB.Get(sni)
if err != nil { if err != nil {
panic(err) // TODO: no panic panic(err) // TODO: no panic
} }
@ -217,7 +222,7 @@ func retrieveCertFromDB(sni, mainDomainSuffix, dnsProvider string, acmeUseRateLi
} }
// renew certificates 7 days before they expire // renew certificates 7 days before they expire
if !tlsCertificate.Leaf.NotAfter.After(time.Now().Add(7 * 24 * time.Hour)) { if tlsCertificate.Leaf.NotAfter.Before(time.Now().Add(7 * 24 * time.Hour)) {
// TODO: add ValidUntil to custom res struct // TODO: add ValidUntil to custom res struct
if res.CSR != nil && len(res.CSR) > 0 { if res.CSR != nil && len(res.CSR) > 0 {
// CSR stores the time when the renewal shall be tried again // CSR stores the time when the renewal shall be tried again
@ -228,9 +233,9 @@ func retrieveCertFromDB(sni, mainDomainSuffix, dnsProvider string, acmeUseRateLi
} }
go (func() { go (func() {
res.CSR = nil // acme client doesn't like CSR to be set res.CSR = nil // acme client doesn't like CSR to be set
tlsCertificate, err = obtainCert(acmeClient, []string{string(sni)}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB) tlsCertificate, err = obtainCert(acmeClient, []string{sni}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB)
if err != nil { if err != nil {
log.Error().Msgf("Couldn't renew certificate for %s: %v", string(sni), err) log.Error().Msgf("Couldn't renew certificate for %s: %v", sni, err)
} }
})() })()
} }
@ -263,7 +268,7 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
defer obtainLocks.Delete(name) defer obtainLocks.Delete(name)
if acmeClient == nil { if acmeClient == nil {
return mockCert(domains[0], "ACME client uninitialized. This is a server error, please report!", string(mainDomainSuffix), keyDatabase), nil return mockCert(domains[0], "ACME client uninitialized. This is a server error, please report!", mainDomainSuffix, keyDatabase), nil
} }
// request actual cert // request actual cert
@ -277,6 +282,9 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
res, err = acmeClient.Certificate.Renew(*renew, true, false, "") res, err = acmeClient.Certificate.Renew(*renew, true, false, "")
if err != nil { if err != nil {
log.Error().Err(err).Msgf("Couldn't renew certificate for %v, trying to request a new one", domains) log.Error().Err(err).Msgf("Couldn't renew certificate for %v, trying to request a new one", domains)
if acmeUseRateLimits {
acmeClientFailLimit.Take()
}
res = nil res = nil
} }
} }
@ -297,21 +305,28 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
Bundle: true, Bundle: true,
MustStaple: false, MustStaple: false,
}) })
if acmeUseRateLimits && err != nil {
acmeClientFailLimit.Take()
}
} }
if err != nil { if err != nil {
log.Error().Err(err).Msgf("Couldn't obtain again a certificate or %v", domains) log.Error().Err(err).Msgf("Couldn't obtain again a certificate or %v", domains)
if renew != nil && renew.CertURL != "" { if renew != nil && renew.CertURL != "" {
tlsCertificate, err := tls.X509KeyPair(renew.Certificate, renew.PrivateKey) tlsCertificate, err := tls.X509KeyPair(renew.Certificate, renew.PrivateKey)
if err == nil && tlsCertificate.Leaf.NotAfter.After(time.Now()) { if err != nil {
return mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase), err
}
leaf, err := leaf(&tlsCertificate)
if err == nil && leaf.NotAfter.After(time.Now()) {
// avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at // avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at
renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6*time.Hour).Unix(), 10)) renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6*time.Hour).Unix(), 10))
if err := keyDatabase.Put(name, renew); err != nil { if err := keyDatabase.Put(name, renew); err != nil {
return mockCert(domains[0], err.Error(), string(mainDomainSuffix), keyDatabase), err return mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase), err
} }
return tlsCertificate, nil return tlsCertificate, nil
} }
} }
return mockCert(domains[0], err.Error(), string(mainDomainSuffix), keyDatabase), err return mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase), err
} }
log.Debug().Msgf("Obtained certificate for %v", domains) log.Debug().Msgf("Obtained certificate for %v", domains)
@ -409,7 +424,7 @@ func SetupAcmeConfig(acmeAPI, acmeMail, acmeEabHmac, acmeEabKID string, acmeAcce
func SetupCertificates(mainDomainSuffix, dnsProvider string, acmeConfig *lego.Config, acmeUseRateLimits, enableHTTPServer bool, challengeCache cache.SetGetKey, certDB database.CertDB) error { func SetupCertificates(mainDomainSuffix, dnsProvider string, acmeConfig *lego.Config, acmeUseRateLimits, enableHTTPServer bool, challengeCache cache.SetGetKey, certDB database.CertDB) error {
// getting main cert before ACME account so that we can fail here without hitting rate limits // getting main cert before ACME account so that we can fail here without hitting rate limits
mainCertBytes, err := certDB.Get(string(mainDomainSuffix)) mainCertBytes, err := certDB.Get(mainDomainSuffix)
if err != nil { if err != nil {
return fmt.Errorf("cert database is not working") return fmt.Errorf("cert database is not working")
} }
@ -453,7 +468,7 @@ func SetupCertificates(mainDomainSuffix, dnsProvider string, acmeConfig *lego.Co
} }
if mainCertBytes == nil { if mainCertBytes == nil {
_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, nil, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB) _, err = obtainCert(mainDomainAcmeClient, []string{"*" + mainDomainSuffix, mainDomainSuffix[1:]}, nil, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Couldn't renew main domain certificate, continuing with mock certs only") log.Error().Err(err).Msg("Couldn't renew main domain certificate, continuing with mock certs only")
} }
@ -480,7 +495,7 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi
} }
tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate) tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
if err != nil || !tlsCertificates[0].NotAfter.After(now) { if err != nil || tlsCertificates[0].NotAfter.Before(now) {
err := certDB.Delete(string(key)) err := certDB.Delete(string(key))
if err != nil { if err != nil {
log.Error().Err(err).Msgf("Deleting expired certificate for %q failed", string(key)) log.Error().Err(err).Msgf("Deleting expired certificate for %q failed", string(key))
@ -502,18 +517,18 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi
} }
// update main cert // update main cert
res, err := certDB.Get(string(mainDomainSuffix)) res, err := certDB.Get(mainDomainSuffix)
if err != nil { if err != nil {
log.Error().Msgf("Couldn't get cert for domain %q", mainDomainSuffix) log.Error().Msgf("Couldn't get cert for domain %q", mainDomainSuffix)
} else if res == nil { } else if res == nil {
log.Error().Msgf("Couldn't renew certificate for main domain %q expected main domain cert to exist, but it's missing - seems like the database is corrupted", string(mainDomainSuffix)) log.Error().Msgf("Couldn't renew certificate for main domain %q expected main domain cert to exist, but it's missing - seems like the database is corrupted", mainDomainSuffix)
} else { } else {
tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate) tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
// renew main certificate 30 days before it expires // renew main certificate 30 days before it expires
if !tlsCertificates[0].NotAfter.After(time.Now().Add(30 * 24 * time.Hour)) { if tlsCertificates[0].NotAfter.Before(time.Now().Add(30 * 24 * time.Hour)) {
go (func() { go (func() {
_, err = obtainCert(mainDomainAcmeClient, []string{"*" + string(mainDomainSuffix), string(mainDomainSuffix[1:])}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB) _, err = obtainCert(mainDomainAcmeClient, []string{"*" + mainDomainSuffix, mainDomainSuffix[1:]}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Couldn't renew certificate for main domain") log.Error().Err(err).Msg("Couldn't renew certificate for main domain")
} }
@ -528,3 +543,12 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi
} }
} }
} }
// leaf returns the parsed leaf certificate, either from c.leaf or by parsing
// the corresponding c.Certificate[0].
func leaf(c *tls.Certificate) (*x509.Certificate, error) {
if c.Leaf != nil {
return c.Leaf, nil
}
return x509.ParseCertificate(c.Certificate[0])
}

View file

@ -44,7 +44,7 @@ func (p aDB) Get(name string) (*certificate.Resource, error) {
if resBytes == nil { if resBytes == nil {
return nil, nil return nil, nil
} }
if err = gob.NewDecoder(bytes.NewBuffer(resBytes)).Decode(cert); err != nil { if err := gob.NewDecoder(bytes.NewBuffer(resBytes)).Decode(cert); err != nil {
return nil, err return nil, err
} }
return cert, nil return cert, nil

View file

@ -42,9 +42,8 @@ func (f FileResponse) IsEmpty() bool {
return len(f.Body) != 0 return len(f.Body) != 0
} }
func (f FileResponse) createHttpResponse(cacheKey string) (http.Header, int) { func (f FileResponse) createHttpResponse(cacheKey string) (header http.Header, statusCode int) {
header := make(http.Header) header = make(http.Header)
var statusCode int
if f.Exists { if f.Exists {
statusCode = http.StatusOK statusCode = http.StatusOK

View file

@ -28,7 +28,7 @@ func Handler(mainDomainSuffix, rawDomain string,
dnsLookupCache, canonicalDomainCache cache.SetGetKey, dnsLookupCache, canonicalDomainCache cache.SetGetKey,
) http.HandlerFunc { ) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) { return func(w http.ResponseWriter, req *http.Request) {
log := log.With().Strs("Handler", []string{string(req.Host), req.RequestURI}).Logger() log := log.With().Strs("Handler", []string{req.Host, req.RequestURI}).Logger()
ctx := context.New(w, req) ctx := context.New(w, req)
ctx.RespWriter.Header().Set("Server", "CodebergPages/"+version.Version) ctx.RespWriter.Header().Set("Server", "CodebergPages/"+version.Version)

View file

@ -55,7 +55,7 @@ func handleCustomDomain(log zerolog.Logger, ctx *context.Context, giteaClient *g
// only redirect if the target is also a codeberg page! // only redirect if the target is also a codeberg page!
targetOwner, _, _ = dns.GetTargetFromDNS(strings.SplitN(canonicalDomain, "/", 2)[0], mainDomainSuffix, firstDefaultBranch, dnsLookupCache) targetOwner, _, _ = dns.GetTargetFromDNS(strings.SplitN(canonicalDomain, "/", 2)[0], mainDomainSuffix, firstDefaultBranch, dnsLookupCache)
if targetOwner != "" { if targetOwner != "" {
ctx.Redirect("https://"+canonicalDomain+string(targetOpt.TargetPath), http.StatusTemporaryRedirect) ctx.Redirect("https://"+canonicalDomain+targetOpt.TargetPath, http.StatusTemporaryRedirect)
return return
} }

View file

@ -30,7 +30,7 @@ func handleSubDomain(log zerolog.Logger, ctx *context.Context, giteaClient *gite
if targetOwner == "www" { if targetOwner == "www" {
// www.codeberg.page redirects to codeberg.page // TODO: rm hardcoded - use cname? // www.codeberg.page redirects to codeberg.page // TODO: rm hardcoded - use cname?
ctx.Redirect("https://"+string(mainDomainSuffix[1:])+string(ctx.Path()), http.StatusPermanentRedirect) ctx.Redirect("https://"+mainDomainSuffix[1:]+ctx.Path(), http.StatusPermanentRedirect)
return return
} }
@ -131,6 +131,6 @@ func handleSubDomain(log zerolog.Logger, ctx *context.Context, giteaClient *gite
// Couldn't find a valid repo/branch // Couldn't find a valid repo/branch
html.ReturnErrorPage(ctx, html.ReturnErrorPage(ctx,
fmt.Sprintf("couldn't find a valid repo[%s]", targetRepo), fmt.Sprintf("could not find a valid repository[%s]", targetRepo),
http.StatusFailedDependency) http.StatusNotFound)
} }

View file

@ -24,6 +24,7 @@ func TestHandlerPerformance(t *testing.T) {
) )
testCase := func(uri string, status int) { testCase := func(uri string, status int) {
t.Run(uri, func(t *testing.T) {
req := httptest.NewRequest("GET", uri, nil) req := httptest.NewRequest("GET", uri, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -40,10 +41,10 @@ func TestHandlerPerformance(t *testing.T) {
} else { } else {
t.Logf("request took %d milliseconds", end.Sub(start).Milliseconds()) t.Logf("request took %d milliseconds", end.Sub(start).Milliseconds())
} }
})
} }
testCase("https://mondstern.codeberg.page/", 424) // TODO: expect 200 testCase("https://mondstern.codeberg.page/", 404) // TODO: expect 200
testCase("https://mondstern.codeberg.page/", 424) // TODO: expect 200 testCase("https://codeberg.page/", 404) // TODO: expect 200
testCase("https://example.momar.xyz/", 424) // TODO: expect 200 testCase("https://example.momar.xyz/", 424)
testCase("https://codeberg.page/", 424) // TODO: expect 200
} }

View file

@ -21,8 +21,8 @@ func tryUpstream(ctx *context.Context, giteaClient *gitea.Client,
) { ) {
// check if a canonical domain exists on a request on MainDomain // check if a canonical domain exists on a request on MainDomain
if strings.HasSuffix(trimmedHost, mainDomainSuffix) { if strings.HasSuffix(trimmedHost, mainDomainSuffix) {
canonicalDomain, _ := options.CheckCanonicalDomain(giteaClient, "", string(mainDomainSuffix), canonicalDomainCache) canonicalDomain, _ := options.CheckCanonicalDomain(giteaClient, "", mainDomainSuffix, canonicalDomainCache)
if !strings.HasSuffix(strings.SplitN(canonicalDomain, "/", 2)[0], string(mainDomainSuffix)) { if !strings.HasSuffix(strings.SplitN(canonicalDomain, "/", 2)[0], mainDomainSuffix) {
canonicalPath := ctx.Req.RequestURI canonicalPath := ctx.Req.RequestURI
if options.TargetRepo != defaultPagesRepo { if options.TargetRepo != defaultPagesRepo {
path := strings.SplitN(canonicalPath, "/", 3) path := strings.SplitN(canonicalPath, "/", 3)
@ -35,8 +35,8 @@ func tryUpstream(ctx *context.Context, giteaClient *gitea.Client,
} }
} }
// add host for debugging // Add host for debugging.
options.Host = string(trimmedHost) options.Host = trimmedHost
// Try to request the file from the Gitea API // Try to request the file from the Gitea API
if !options.Upstream(ctx, giteaClient) { if !options.Upstream(ctx, giteaClient) {

View file

@ -15,13 +15,13 @@ func SetupHTTPACMEChallengeServer(challengeCache cache.SetGetKey) http.HandlerFu
return func(w http.ResponseWriter, req *http.Request) { return func(w http.ResponseWriter, req *http.Request) {
ctx := context.New(w, req) ctx := context.New(w, req)
if strings.HasPrefix(ctx.Path(), challengePath) { if strings.HasPrefix(ctx.Path(), challengePath) {
challenge, ok := challengeCache.Get(utils.TrimHostPort(ctx.Host()) + "/" + string(strings.TrimPrefix(ctx.Path(), challengePath))) challenge, ok := challengeCache.Get(utils.TrimHostPort(ctx.Host()) + "/" + strings.TrimPrefix(ctx.Path(), challengePath))
if !ok || challenge == nil { if !ok || challenge == nil {
ctx.String("no challenge for this token", http.StatusNotFound) ctx.String("no challenge for this token", http.StatusNotFound)
} }
ctx.String(challenge.(string)) ctx.String(challenge.(string))
} else { } else {
ctx.Redirect("https://"+string(ctx.Host())+string(ctx.Path()), http.StatusMovedPermanently) ctx.Redirect("https://"+ctx.Host()+ctx.Path(), http.StatusMovedPermanently)
} }
} }
} }

View file

@ -13,7 +13,7 @@ import (
func (o *Options) GetBranchTimestamp(giteaClient *gitea.Client) (bool, error) { func (o *Options) GetBranchTimestamp(giteaClient *gitea.Client) (bool, error) {
log := log.With().Strs("BranchInfo", []string{o.TargetOwner, o.TargetRepo, o.TargetBranch}).Logger() log := log.With().Strs("BranchInfo", []string{o.TargetOwner, o.TargetRepo, o.TargetBranch}).Logger()
if len(o.TargetBranch) == 0 { if o.TargetBranch == "" {
// Get default branch // Get default branch
defaultBranch, err := giteaClient.GiteaGetRepoDefaultBranch(o.TargetOwner, o.TargetRepo) defaultBranch, err := giteaClient.GiteaGetRepoDefaultBranch(o.TargetOwner, o.TargetRepo)
if err != nil { if err != nil {

View file

@ -82,8 +82,8 @@ func (o *Options) Upstream(ctx *context.Context, giteaClient *gitea.Client) (fin
// Check if the browser has a cached version // Check if the browser has a cached version
if ctx.Response() != nil { if ctx.Response() != nil {
if ifModifiedSince, err := time.Parse(time.RFC1123, string(ctx.Response().Header.Get(headerIfModifiedSince))); err == nil { if ifModifiedSince, err := time.Parse(time.RFC1123, ctx.Response().Header.Get(headerIfModifiedSince)); err == nil {
if !ifModifiedSince.Before(o.BranchTimestamp) { if ifModifiedSince.After(o.BranchTimestamp) {
ctx.RespWriter.WriteHeader(http.StatusNotModified) ctx.RespWriter.WriteHeader(http.StatusNotModified)
log.Trace().Msg("check response against last modified: valid") log.Trace().Msg("check response against last modified: valid")
return true return true