- message += " - domain not specified in <code>.domains</code> file"
- case http.StatusFailedDependency:
- message += " - target repo/branch doesn't exist or is private"
- }
-
- return message
-}
-
-// TODO: use template engine
-func errorBody(statusCode int) string {
- return strings.ReplaceAll(NotFoundPage,
- "%status%",
- strconv.Itoa(statusCode)+" "+errorMessage(statusCode))
-}
diff --git a/html/error.html b/html/error.html
deleted file mode 100644
index f1975f7..0000000
--- a/html/error.html
+++ /dev/null
@@ -1,38 +0,0 @@
-
-
-
-
-
-
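+// ReturnErrorPage renders the error template with the given status code and the sanitized message.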
+func ReturnErrorPage(ctx *context.Context, msg string, statusCode int) {
+ ctx.RespWriter.Header().Set("Content-Type", "text/html; charset=utf-8")
+ ctx.RespWriter.WriteHeader(statusCode)
+
+ templateContext := TemplateContext{
+ StatusCode: statusCode,
+ StatusText: http.StatusText(statusCode),
+ Message: sanitizer.Sanitize(msg),
+ }
+
+ err := errorTemplate.Execute(ctx.RespWriter, templateContext)
+ if err != nil {
+ log.Err(err).Str("message", msg).Int("status", statusCode).Msg("could not write response")
+ }
+}
+
+func createBlueMondayPolicy() *bluemonday.Policy {
+ p := bluemonday.NewPolicy()
+
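+ // strict policy: allow only bare <code> elements; all other tags and attributes are stripped from error messages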
+ p.AllowElements("code")
+
+ return p
+}
+
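+// loadCustomTemplateOrDefault reads an operator-provided template from custom/error.html and falls back to the default error page if the file does not exist.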
+func loadCustomTemplateOrDefault() string {
+ contents, err := os.ReadFile("custom/error.html")
+ if err != nil {
+ if !os.IsNotExist(err) {
+ wd, wdErr := os.Getwd()
+ if wdErr != nil {
+ log.Err(err).Msg("could not load custom error page 'custom/error.html'")
+ } else {
+ log.Err(err).Msgf("could not load custom error page '%v'", path.Join(wd, "custom/error.html"))
+ }
+ }
+ return errorPage
+ }
+ return string(contents)
+}
diff --git a/html/html_test.go b/html/html_test.go
new file mode 100644
index 0000000..b395bb2
--- /dev/null
+++ b/html/html_test.go
@@ -0,0 +1,54 @@
+package html
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestSanitizerSimpleString(t *testing.T) {
+ str := "simple text message without any html elements"
+
+ assert.Equal(t, str, sanitizer.Sanitize(str))
+}
+
+func TestSanitizerStringWithCodeTag(t *testing.T) {
+ str := "simple text message with html
tag"
+
+ assert.Equal(t, str, sanitizer.Sanitize(str))
+}
+
+func TestSanitizerStringWithCodeTagWithAttribute(t *testing.T) {
+ str := "simple text message with html
tag"
+ expected := "simple text message with html
tag"
+
+ assert.Equal(t, expected, sanitizer.Sanitize(str))
+}
+
+func TestSanitizerStringWithATag(t *testing.T) {
+ str := "simple text message with a link to another page"
+ expected := "simple text message with a link to another page"
+
+ assert.Equal(t, expected, sanitizer.Sanitize(str))
+}
+
+func TestSanitizerStringWithATagAndHref(t *testing.T) {
+ str := "simple text message with a link to another page"
+ expected := "simple text message with a link to another page"
+
+ assert.Equal(t, expected, sanitizer.Sanitize(str))
+}
+
+func TestSanitizerStringWithImgTag(t *testing.T) {
+ str := "simple text message with a
"
+ expected := "simple text message with a "
+
+ assert.Equal(t, expected, sanitizer.Sanitize(str))
+}
+
+func TestSanitizerStringWithImgTagAndOnerrorAttribute(t *testing.T) {
+ str := "simple text message with a
"
+ expected := "simple text message with a "
+
+ assert.Equal(t, expected, sanitizer.Sanitize(str))
+}
diff --git a/html/templates/error.html b/html/templates/error.html
new file mode 100644
index 0000000..ccaa682
--- /dev/null
+++ b/html/templates/error.html
@@ -0,0 +1,58 @@
+
+
+
+
+
+ {{.StatusText}}
+
+
+
+
+
+
+
+
+ {{.StatusText}} (Error {{.StatusCode}})!
+
+
Sorry, but this page couldn't be served:
+ "{{.Message}}"
+
+ The page you tried to reach is hosted on Codeberg Pages, which might currently be experiencing technical
+ difficulties. If that is the case, it could take a little while until this page is available again.
+
+
+ Otherwise, this page might also be unavailable due to a configuration error. If you are the owner of this
+ website, please make sure to check the
+ troubleshooting section in the Docs!
+
+
+
+
+ Static pages made easy -
+ Codeberg Pages
+
+
+
diff --git a/integration/get_test.go b/integration/get_test.go
index 81d8488..cfb7188 100644
--- a/integration/get_test.go
+++ b/integration/get_test.go
@@ -20,7 +20,9 @@ func TestGetRedirect(t *testing.T) {
log.Println("=== TestGetRedirect ===")
// test custom domain redirect
resp, err := getTestHTTPSClient().Get("https://calciumdibromid.localhost.mock.directory:4430")
- assert.NoError(t, err)
+ if !assert.NoError(t, err) {
+ t.FailNow()
+ }
if !assert.EqualValues(t, http.StatusTemporaryRedirect, resp.StatusCode) {
t.FailNow()
}
@@ -31,7 +33,7 @@ func TestGetRedirect(t *testing.T) {
func TestGetContent(t *testing.T) {
log.Println("=== TestGetContent ===")
// test get image
- resp, err := getTestHTTPSClient().Get("https://magiclike.localhost.mock.directory:4430/images/827679288a.jpg")
+ resp, err := getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/images/827679288a.jpg")
assert.NoError(t, err)
if !assert.EqualValues(t, http.StatusOK, resp.StatusCode) {
t.FailNow()
@@ -42,7 +44,7 @@ func TestGetContent(t *testing.T) {
assert.Len(t, resp.Header.Get("ETag"), 42)
// specify branch
- resp, err = getTestHTTPSClient().Get("https://momar.localhost.mock.directory:4430/pag/@master/")
+ resp, err = getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/pag/@master/")
assert.NoError(t, err)
if !assert.NotNil(t, resp) {
t.FailNow()
@@ -53,7 +55,7 @@ func TestGetContent(t *testing.T) {
assert.Len(t, resp.Header.Get("ETag"), 44)
// access branch name contains '/'
- resp, err = getTestHTTPSClient().Get("https://blumia.localhost.mock.directory:4430/pages-server-integration-tests/@docs~main/")
+ resp, err = getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/blumia/@docs~main/")
assert.NoError(t, err)
if !assert.EqualValues(t, http.StatusOK, resp.StatusCode) {
t.FailNow()
@@ -62,7 +64,7 @@ func TestGetContent(t *testing.T) {
assert.True(t, getSize(resp.Body) > 100)
assert.Len(t, resp.Header.Get("ETag"), 44)
- // TODO: test get of non cachable content (content size > fileCacheSizeLimit)
+ // TODO: test get of non cacheable content (content size > fileCacheSizeLimit)
}
func TestCustomDomain(t *testing.T) {
@@ -78,10 +80,67 @@ func TestCustomDomain(t *testing.T) {
assert.EqualValues(t, 106, getSize(resp.Body))
}
+func TestCustomDomainRedirects(t *testing.T) {
+ log.Println("=== TestCustomDomainRedirects ===")
+ // test redirect from default pages domain to custom domain
+ resp, err := getTestHTTPSClient().Get("https://6543.localhost.mock.directory:4430/test_pages-server_custom-mock-domain/@main/README.md")
+ assert.NoError(t, err)
+ if !assert.NotNil(t, resp) {
+ t.FailNow()
+ }
+ assert.EqualValues(t, http.StatusTemporaryRedirect, resp.StatusCode)
+ assert.EqualValues(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
+ // TODO: custom port is not evaluated (which only hurts tests & the dev env)
+ // assert.EqualValues(t, "https://mock-pages.codeberg-test.org:4430/@main/README.md", resp.Header.Get("Location"))
+ assert.EqualValues(t, "https://mock-pages.codeberg-test.org/@main/README.md", resp.Header.Get("Location"))
+ assert.EqualValues(t, `https://codeberg.org/6543/test_pages-server_custom-mock-domain/src/branch/main/README.md; rel="canonical"; rel="canonical"`, resp.Header.Get("Link"))
+
+ // test redirect from a custom domain to the primary custom domain (www.example.com -> example.com)
+ // regression test for https://codeberg.org/Codeberg/pages-server/issues/153
+ resp, err = getTestHTTPSClient().Get("https://mock-pages-redirect.codeberg-test.org:4430/README.md")
+ assert.NoError(t, err)
+ if !assert.NotNil(t, resp) {
+ t.FailNow()
+ }
+ assert.EqualValues(t, http.StatusTemporaryRedirect, resp.StatusCode)
+ assert.EqualValues(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
+ // TODO: custom port is not evaluated (which only hurts tests & the dev env)
+ // assert.EqualValues(t, "https://mock-pages.codeberg-test.org:4430/README.md", resp.Header.Get("Location"))
+ assert.EqualValues(t, "https://mock-pages.codeberg-test.org/README.md", resp.Header.Get("Location"))
+}
+
+func TestRawCustomDomain(t *testing.T) {
+ log.Println("=== TestRawCustomDomain ===")
+ // test raw domain response for custom domain branch
+ resp, err := getTestHTTPSClient().Get("https://raw.localhost.mock.directory:4430/cb_pages_tests/raw-test/example") // need cb_pages_tests fork
+ assert.NoError(t, err)
+ if !assert.NotNil(t, resp) {
+ t.FailNow()
+ }
+ assert.EqualValues(t, http.StatusOK, resp.StatusCode)
+ assert.EqualValues(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
+ assert.EqualValues(t, "76", resp.Header.Get("Content-Length"))
+ assert.EqualValues(t, 76, getSize(resp.Body))
+}
+
+func TestRawIndex(t *testing.T) {
+ log.Println("=== TestRawIndex ===")
+ // test raw domain response for index.html
+ resp, err := getTestHTTPSClient().Get("https://raw.localhost.mock.directory:4430/cb_pages_tests/raw-test/@branch-test/index.html") // need cb_pages_tests fork
+ assert.NoError(t, err)
+ if !assert.NotNil(t, resp) {
+ t.FailNow()
+ }
+ assert.EqualValues(t, http.StatusOK, resp.StatusCode)
+ assert.EqualValues(t, "text/plain; charset=utf-8", resp.Header.Get("Content-Type"))
+ assert.EqualValues(t, "597", resp.Header.Get("Content-Length"))
+ assert.EqualValues(t, 597, getSize(resp.Body))
+}
+
func TestGetNotFound(t *testing.T) {
log.Println("=== TestGetNotFound ===")
// test custom not found pages
- resp, err := getTestHTTPSClient().Get("https://crystal.localhost.mock.directory:4430/pages-404-demo/blah")
+ resp, err := getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/pages-404-demo/blah")
assert.NoError(t, err)
if !assert.NotNil(t, resp) {
t.FailNow()
@@ -92,10 +151,51 @@ func TestGetNotFound(t *testing.T) {
assert.EqualValues(t, 37, getSize(resp.Body))
}
+func TestRedirect(t *testing.T) {
+ log.Println("=== TestRedirect ===")
+ // test redirects
+ resp, err := getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/some_redirects/redirect")
+ assert.NoError(t, err)
+ if !assert.NotNil(t, resp) {
+ t.FailNow()
+ }
+ assert.EqualValues(t, http.StatusMovedPermanently, resp.StatusCode)
+ assert.EqualValues(t, "https://example.com/", resp.Header.Get("Location"))
+}
+
+func TestSPARedirect(t *testing.T) {
+ log.Println("=== TestSPARedirect ===")
+ // test SPA redirects
+ url := "https://cb_pages_tests.localhost.mock.directory:4430/some_redirects/app/aqdjw"
+ resp, err := getTestHTTPSClient().Get(url)
+ assert.NoError(t, err)
+ if !assert.NotNil(t, resp) {
+ t.FailNow()
+ }
+ assert.EqualValues(t, http.StatusOK, resp.StatusCode)
+ assert.EqualValues(t, url, resp.Request.URL.String())
+ assert.EqualValues(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
+ assert.EqualValues(t, "258", resp.Header.Get("Content-Length"))
+ assert.EqualValues(t, 258, getSize(resp.Body))
+}
+
+func TestSplatRedirect(t *testing.T) {
+ log.Println("=== TestSplatRedirect ===")
+ // test splat redirects
+ resp, err := getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/some_redirects/articles/qfopefe")
+ assert.NoError(t, err)
+ if !assert.NotNil(t, resp) {
+ t.FailNow()
+ }
+ assert.EqualValues(t, http.StatusMovedPermanently, resp.StatusCode)
+ assert.EqualValues(t, "/posts/qfopefe", resp.Header.Get("Location"))
+}
+
func TestFollowSymlink(t *testing.T) {
log.Printf("=== TestFollowSymlink ===\n")
- resp, err := getTestHTTPSClient().Get("https://6543.localhost.mock.directory:4430/tests_for_pages-server/@main/link")
+ // file symlink
+ resp, err := getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/tests_for_pages-server/@main/link")
assert.NoError(t, err)
if !assert.NotNil(t, resp) {
t.FailNow()
@@ -106,12 +206,22 @@ func TestFollowSymlink(t *testing.T) {
body := getBytes(resp.Body)
assert.EqualValues(t, 4, len(body))
assert.EqualValues(t, "abc\n", string(body))
+
+ // relative file links (../index.html file in this case)
+ resp, err = getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/tests_for_pages-server/@main/dir_aim/some/")
+ assert.NoError(t, err)
+ if !assert.NotNil(t, resp) {
+ t.FailNow()
+ }
+ assert.EqualValues(t, http.StatusOK, resp.StatusCode)
+ assert.EqualValues(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
+ assert.EqualValues(t, "an index\n", string(getBytes(resp.Body)))
}
func TestLFSSupport(t *testing.T) {
log.Printf("=== TestLFSSupport ===\n")
- resp, err := getTestHTTPSClient().Get("https://6543.localhost.mock.directory:4430/tests_for_pages-server/@main/lfs.txt")
+ resp, err := getTestHTTPSClient().Get("https://cb_pages_tests.localhost.mock.directory:4430/tests_for_pages-server/@main/lfs.txt")
assert.NoError(t, err)
if !assert.NotNil(t, resp) {
t.FailNow()
@@ -134,6 +244,18 @@ func TestGetOptions(t *testing.T) {
assert.EqualValues(t, "GET, HEAD, OPTIONS", resp.Header.Get("Allow"))
}
+func TestHttpRedirect(t *testing.T) {
+ log.Println("=== TestHttpRedirect ===")
+ resp, err := getTestHTTPSClient().Get("http://mock-pages.codeberg-test.org:8880/README.md")
+ assert.NoError(t, err)
+ if !assert.NotNil(t, resp) {
+ t.FailNow()
+ }
+ assert.EqualValues(t, http.StatusMovedPermanently, resp.StatusCode)
+ assert.EqualValues(t, "text/html; charset=utf-8", resp.Header.Get("Content-Type"))
+ assert.EqualValues(t, "https://mock-pages.codeberg-test.org:4430/README.md", resp.Header.Get("Location"))
+}
+
func getTestHTTPSClient() *http.Client {
cookieJar, _ := cookiejar.New(nil)
return &http.Client{
diff --git a/integration/main_test.go b/integration/main_test.go
index 406b33a..86fd9d3 100644
--- a/integration/main_test.go
+++ b/integration/main_test.go
@@ -10,9 +10,10 @@ import (
"testing"
"time"
- "codeberg.org/codeberg/pages/cmd"
-
"github.com/urfave/cli/v2"
+
+ cmd "codeberg.org/codeberg/pages/cli"
+ "codeberg.org/codeberg/pages/server"
)
func TestMain(m *testing.M) {
@@ -23,7 +24,7 @@ func TestMain(m *testing.M) {
}
defer func() {
serverCancel()
- log.Println("=== TestMain: Server STOPED ===")
+ log.Println("=== TestMain: Server STOPPED ===")
}()
time.Sleep(10 * time.Second)
@@ -32,19 +33,25 @@ func TestMain(m *testing.M) {
}
func startServer(ctx context.Context) error {
- args := []string{
- "--verbose",
- "--acme-accept-terms", "true",
- }
+ args := []string{"integration"}
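+ // args only carries the program name; all server configuration for the integration run is passed via the environment variables below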
setEnvIfNotSet("ACME_API", "https://acme.mock.directory")
setEnvIfNotSet("PAGES_DOMAIN", "localhost.mock.directory")
setEnvIfNotSet("RAW_DOMAIN", "raw.localhost.mock.directory")
+ setEnvIfNotSet("PAGES_BRANCHES", "pages,main,master")
setEnvIfNotSet("PORT", "4430")
+ setEnvIfNotSet("HTTP_PORT", "8880")
+ setEnvIfNotSet("ENABLE_HTTP_SERVER", "true")
+ setEnvIfNotSet("DB_TYPE", "sqlite3")
+ setEnvIfNotSet("GITEA_ROOT", "https://codeberg.org")
+ setEnvIfNotSet("LOG_LEVEL", "trace")
+ setEnvIfNotSet("ENABLE_LFS_SUPPORT", "true")
+ setEnvIfNotSet("ENABLE_SYMLINK_SUPPORT", "true")
+ setEnvIfNotSet("ACME_ACCOUNT_CONFIG", "integration/acme-account.json")
app := cli.NewApp()
app.Name = "pages-server"
- app.Action = cmd.Serve
- app.Flags = cmd.ServeFlags
+ app.Action = server.Serve
+ app.Flags = cmd.ServerFlags
go func() {
if err := app.RunContext(ctx, args); err != nil {
diff --git a/main.go b/main.go
index 2836b86..87e21f3 100644
--- a/main.go
+++ b/main.go
@@ -1,31 +1,21 @@
package main
import (
- "fmt"
"os"
_ "github.com/joho/godotenv/autoload"
- "github.com/urfave/cli/v2"
+ "github.com/rs/zerolog/log"
- "codeberg.org/codeberg/pages/cmd"
+ "codeberg.org/codeberg/pages/cli"
+ "codeberg.org/codeberg/pages/server"
)
-// can be changed with -X on compile
-var version = "dev"
-
func main() {
- app := cli.NewApp()
- app.Name = "pages-server"
- app.Version = version
- app.Usage = "pages server"
- app.Action = cmd.Serve
- app.Flags = cmd.ServeFlags
- app.Commands = []*cli.Command{
- cmd.Certs,
- }
+ app := cli.CreatePagesApp()
+ app.Action = server.Serve
if err := app.Run(os.Args); err != nil {
- _, _ = fmt.Fprintln(os.Stderr, err)
+ log.Error().Err(err).Msg("A fatal error occurred")
os.Exit(1)
}
}
diff --git a/renovate.json b/renovate.json
new file mode 100644
index 0000000..9dd1cd7
--- /dev/null
+++ b/renovate.json
@@ -0,0 +1,27 @@
+{
+ "$schema": "https://docs.renovatebot.com/renovate-schema.json",
+ "extends": [
+ "config:recommended",
+ ":maintainLockFilesWeekly",
+ ":enablePreCommit",
+ "schedule:automergeDaily",
+ "schedule:weekends"
+ ],
+ "automergeType": "branch",
+ "automergeMajor": false,
+ "automerge": true,
+ "prConcurrentLimit": 5,
+ "labels": ["dependencies"],
+ "packageRules": [
+ {
+ "matchManagers": ["gomod", "dockerfile"]
+ },
+ {
+ "groupName": "golang deps non-major",
+ "matchManagers": ["gomod"],
+ "matchUpdateTypes": ["minor", "patch"],
+ "extends": ["schedule:daily"]
+ }
+ ],
+ "postUpdateOptions": ["gomodTidy", "gomodUpdateImportPaths"]
+}
diff --git a/server/acme/client.go b/server/acme/client.go
new file mode 100644
index 0000000..d5c83d0
--- /dev/null
+++ b/server/acme/client.go
@@ -0,0 +1,26 @@
+package acme
+
+import (
+ "errors"
+ "fmt"
+
+ "codeberg.org/codeberg/pages/config"
+ "codeberg.org/codeberg/pages/server/cache"
+ "codeberg.org/codeberg/pages/server/certificates"
+)
+
+var ErrAcmeMissConfig = errors.New("ACME client has wrong config")
+
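+// CreateAcmeClient validates the ACME-related configuration and builds the lego-based client used to obtain certificates.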
+func CreateAcmeClient(cfg config.ACMEConfig, enableHTTPServer bool, challengeCache cache.ICache) (*certificates.AcmeClient, error) {
+ // check config
+ if (!cfg.AcceptTerms || (cfg.DNSProvider == "" && !cfg.NoDNS01)) && cfg.APIEndpoint != "https://acme.mock.directory" {
+ return nil, fmt.Errorf("%w: you must set $ACME_ACCEPT_TERMS and $DNS_PROVIDER or $NO_DNS_01, unless $ACME_API is set to https://acme.mock.directory", ErrAcmeMissConfig)
+ }
+ if cfg.EAB_HMAC != "" && cfg.EAB_KID == "" {
+ return nil, fmt.Errorf("%w: ACME_EAB_HMAC also needs ACME_EAB_KID to be set", ErrAcmeMissConfig)
+ } else if cfg.EAB_HMAC == "" && cfg.EAB_KID != "" {
+ return nil, fmt.Errorf("%w: ACME_EAB_KID also needs ACME_EAB_HMAC to be set", ErrAcmeMissConfig)
+ }
+
+ return certificates.NewAcmeClient(cfg, enableHTTPServer, challengeCache)
+}
diff --git a/server/cache/interface.go b/server/cache/interface.go
index 2952b29..b3412cc 100644
--- a/server/cache/interface.go
+++ b/server/cache/interface.go
@@ -2,7 +2,8 @@ package cache
import "time"
-type SetGetKey interface {
+// ICache is an interface that defines how the pages server interacts with the cache.
+type ICache interface {
Set(key string, value interface{}, ttl time.Duration) error
Get(key string) (interface{}, bool)
Remove(key string)
diff --git a/server/cache/setup.go b/server/cache/memory.go
similarity index 69%
rename from server/cache/setup.go
rename to server/cache/memory.go
index a5928b0..093696f 100644
--- a/server/cache/setup.go
+++ b/server/cache/memory.go
@@ -2,6 +2,6 @@ package cache
import "github.com/OrlovEvgeny/go-mcache"
-func NewKeyValueCache() SetGetKey {
+func NewInMemoryCache() ICache {
return mcache.New()
}
diff --git a/server/certificates/acme_client.go b/server/certificates/acme_client.go
new file mode 100644
index 0000000..f42fd8f
--- /dev/null
+++ b/server/certificates/acme_client.go
@@ -0,0 +1,93 @@
+package certificates
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/go-acme/lego/v4/lego"
+ "github.com/go-acme/lego/v4/providers/dns"
+ "github.com/reugn/equalizer"
+ "github.com/rs/zerolog/log"
+
+ "codeberg.org/codeberg/pages/config"
+ "codeberg.org/codeberg/pages/server/cache"
+)
+
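+// AcmeClient bundles the lego clients (TLS-ALPN-01/HTTP-01 and DNS-01) with the rate limiters used when obtaining certificates.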
+type AcmeClient struct {
+ legoClient *lego.Client
+ dnsChallengerLegoClient *lego.Client
+
+ obtainLocks sync.Map
+
+ acmeUseRateLimits bool
+
+ // limiter
+ acmeClientOrderLimit *equalizer.TokenBucket
+ acmeClientRequestLimit *equalizer.TokenBucket
+ acmeClientFailLimit *equalizer.TokenBucket
+ acmeClientCertificateLimitPerUser map[string]*equalizer.TokenBucket
+}
+
+func NewAcmeClient(cfg config.ACMEConfig, enableHTTPServer bool, challengeCache cache.ICache) (*AcmeClient, error) {
+ acmeConfig, err := setupAcmeConfig(cfg)
+ if err != nil {
+ return nil, err
+ }
+
+ acmeClient, err := lego.NewClient(acmeConfig)
+ if err != nil {
+ log.Fatal().Err(err).Msg("Can't create ACME client, continuing with mock certs only")
+ } else {
+ err = acmeClient.Challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{challengeCache})
+ if err != nil {
+ log.Error().Err(err).Msg("Can't create TLS-ALPN-01 provider")
+ }
+ if enableHTTPServer {
+ err = acmeClient.Challenge.SetHTTP01Provider(AcmeHTTPChallengeProvider{challengeCache})
+ if err != nil {
+ log.Error().Err(err).Msg("Can't create HTTP-01 provider")
+ }
+ }
+ }
+
+ mainDomainAcmeClient, err := lego.NewClient(acmeConfig)
+ if err != nil {
+ log.Error().Err(err).Msg("Can't create ACME client, continuing with mock certs only")
+ } else {
+ if cfg.DNSProvider == "" {
+ // using mock wildcard certs
+ mainDomainAcmeClient = nil
+ } else {
+ // use DNS-Challenge https://go-acme.github.io/lego/dns/
+ provider, err := dns.NewDNSChallengeProviderByName(cfg.DNSProvider)
+ if err != nil {
+ return nil, fmt.Errorf("can not create DNS Challenge provider: %w", err)
+ }
+ if err := mainDomainAcmeClient.Challenge.SetDNS01Provider(provider); err != nil {
+ return nil, fmt.Errorf("can not create DNS-01 provider: %w", err)
+ }
+ }
+ }
+
+ return &AcmeClient{
+ legoClient: acmeClient,
+ dnsChallengerLegoClient: mainDomainAcmeClient,
+
+ acmeUseRateLimits: cfg.UseRateLimits,
+
+ obtainLocks: sync.Map{},
+
+ // limiter
+
+ // rate limit is 300 / 3 hours, we want 200 / 2 hours but to refill more often, so that's 25 new domains every 15 minutes
+ // TODO: when this is used a lot, we probably have to think of a somewhat better solution?
+ acmeClientOrderLimit: equalizer.NewTokenBucket(25, 15*time.Minute),
+ // rate limit is 20 / second, we want 5 / second (especially as one cert takes at least two requests)
+ acmeClientRequestLimit: equalizer.NewTokenBucket(5, 1*time.Second),
+ // rate limit is 5 / hour https://letsencrypt.org/docs/failed-validation-limit/
+ acmeClientFailLimit: equalizer.NewTokenBucket(5, 1*time.Hour),
+ // checkUserLimit() uses this to also rate-limit per user
+ acmeClientCertificateLimitPerUser: map[string]*equalizer.TokenBucket{},
+ }, nil
+}
diff --git a/server/certificates/acme_config.go b/server/certificates/acme_config.go
new file mode 100644
index 0000000..2b5151d
--- /dev/null
+++ b/server/certificates/acme_config.go
@@ -0,0 +1,110 @@
+package certificates
+
+import (
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "encoding/json"
+ "fmt"
+ "os"
+
+ "codeberg.org/codeberg/pages/config"
+ "github.com/go-acme/lego/v4/certcrypto"
+ "github.com/go-acme/lego/v4/lego"
+ "github.com/go-acme/lego/v4/registration"
+ "github.com/rs/zerolog/log"
+)
+
+const challengePath = "/.well-known/acme-challenge/"
+
+func setupAcmeConfig(cfg config.ACMEConfig) (*lego.Config, error) {
+ var myAcmeAccount AcmeAccount
+ var myAcmeConfig *lego.Config
+
+ if cfg.AccountConfigFile == "" {
+ return nil, fmt.Errorf("invalid acme config file: '%s'", cfg.AccountConfigFile)
+ }
+
+ if account, err := os.ReadFile(cfg.AccountConfigFile); err == nil {
+ log.Info().Msgf("found existing acme account config file '%s'", cfg.AccountConfigFile)
+ if err := json.Unmarshal(account, &myAcmeAccount); err != nil {
+ return nil, err
+ }
+
+ myAcmeAccount.Key, err = certcrypto.ParsePEMPrivateKey([]byte(myAcmeAccount.KeyPEM))
+ if err != nil {
+ return nil, err
+ }
+
+ myAcmeConfig = lego.NewConfig(&myAcmeAccount)
+ myAcmeConfig.CADirURL = cfg.APIEndpoint
+ myAcmeConfig.Certificate.KeyType = certcrypto.RSA2048
+
+ // Validate Config
+ _, err := lego.NewClient(myAcmeConfig)
+ if err != nil {
+ log.Info().Err(err).Msg("config validation failed, you might just delete the config file and let it recreate")
+ return nil, fmt.Errorf("acme config validation failed: %w", err)
+ }
+
+ return myAcmeConfig, nil
+ } else if !os.IsNotExist(err) {
+ return nil, err
+ }
+
+ log.Info().Msgf("no existing acme account config found, try to create a new one")
+
+ privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ return nil, err
+ }
+ myAcmeAccount = AcmeAccount{
+ Email: cfg.Email,
+ Key: privateKey,
+ KeyPEM: string(certcrypto.PEMEncode(privateKey)),
+ }
+ myAcmeConfig = lego.NewConfig(&myAcmeAccount)
+ myAcmeConfig.CADirURL = cfg.APIEndpoint
+ myAcmeConfig.Certificate.KeyType = certcrypto.RSA2048
+ tempClient, err := lego.NewClient(myAcmeConfig)
+ if err != nil {
+ log.Error().Err(err).Msg("Can't create ACME client, continuing with mock certs only")
+ } else {
+ // accept terms & log in to EAB
+ if cfg.EAB_KID == "" || cfg.EAB_HMAC == "" {
+ reg, err := tempClient.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: cfg.AcceptTerms})
+ if err != nil {
+ log.Error().Err(err).Msg("Can't register ACME account, continuing with mock certs only")
+ } else {
+ myAcmeAccount.Registration = reg
+ }
+ } else {
+ reg, err := tempClient.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{
+ TermsOfServiceAgreed: cfg.AcceptTerms,
+ Kid: cfg.EAB_KID,
+ HmacEncoded: cfg.EAB_HMAC,
+ })
+ if err != nil {
+ log.Error().Err(err).Msg("Can't register ACME account, continuing with mock certs only")
+ } else {
+ myAcmeAccount.Registration = reg
+ }
+ }
+
+ if myAcmeAccount.Registration != nil {
+ acmeAccountJSON, err := json.Marshal(myAcmeAccount)
+ if err != nil {
+ log.Error().Err(err).Msg("json.Marshalfailed, waiting for manual restart to avoid rate limits")
+ select {}
+ }
+ log.Info().Msgf("new acme account created. write to config file '%s'", cfg.AccountConfigFile)
+ err = os.WriteFile(cfg.AccountConfigFile, acmeAccountJSON, 0o600)
+ if err != nil {
+ log.Error().Err(err).Msg("os.WriteFile failed, waiting for manual restart to avoid rate limits")
+ select {}
+ }
+ }
+ }
+
+ return myAcmeConfig, nil
+}
diff --git a/server/certificates/cached_challengers.go b/server/certificates/cached_challengers.go
new file mode 100644
index 0000000..39439fb
--- /dev/null
+++ b/server/certificates/cached_challengers.go
@@ -0,0 +1,83 @@
+package certificates
+
+import (
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/go-acme/lego/v4/challenge"
+ "github.com/rs/zerolog/log"
+
+ "codeberg.org/codeberg/pages/server/cache"
+ "codeberg.org/codeberg/pages/server/context"
+)
+
+type AcmeTLSChallengeProvider struct {
+ challengeCache cache.ICache
+}
+
+// make sure AcmeTLSChallengeProvider matches the challenge.Provider interface
+var _ challenge.Provider = AcmeTLSChallengeProvider{}
+
+func (a AcmeTLSChallengeProvider) Present(domain, _, keyAuth string) error {
+ return a.challengeCache.Set(domain, keyAuth, 1*time.Hour)
+}
+
+func (a AcmeTLSChallengeProvider) CleanUp(domain, _, _ string) error {
+ a.challengeCache.Remove(domain)
+ return nil
+}
+
+type AcmeHTTPChallengeProvider struct {
+ challengeCache cache.ICache
+}
+
+// make sure AcmeHTTPChallengeProvider matches the challenge.Provider interface
+var _ challenge.Provider = AcmeHTTPChallengeProvider{}
+
+func (a AcmeHTTPChallengeProvider) Present(domain, token, keyAuth string) error {
+ return a.challengeCache.Set(domain+"/"+token, keyAuth, 1*time.Hour)
+}
+
+func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error {
+ a.challengeCache.Remove(domain + "/" + token)
+ return nil
+}
+
+func SetupHTTPACMEChallengeServer(challengeCache cache.ICache, sslPort uint) http.HandlerFunc {
+ // handle custom-ssl-ports to be added on https redirects
+ portPart := ""
+ if sslPort != 443 {
+ portPart = fmt.Sprintf(":%d", sslPort)
+ }
+
+ return func(w http.ResponseWriter, req *http.Request) {
+ ctx := context.New(w, req)
+ domain := ctx.TrimHostPort()
+
+ // it's an acme request
+ if strings.HasPrefix(ctx.Path(), challengePath) {
+ challenge, ok := challengeCache.Get(domain + "/" + strings.TrimPrefix(ctx.Path(), challengePath))
+ if !ok || challenge == nil {
+ log.Info().Msgf("HTTP-ACME challenge for '%s' failed: token not found", domain)
+ ctx.String("no challenge for this token", http.StatusNotFound)
+ }
+ log.Info().Msgf("HTTP-ACME challenge for '%s' succeeded", domain)
+ ctx.String(challenge.(string))
+ return
+ }
+
+ // it's a normal http request that needs to be redirected
+ u, err := url.Parse(fmt.Sprintf("https://%s%s%s", domain, portPart, ctx.Path()))
+ if err != nil {
+ log.Error().Err(err).Msg("could not craft http to https redirect")
+ ctx.String("", http.StatusInternalServerError)
+ }
+
+ newURL := u.String()
+ log.Debug().Msgf("redirect http to https: %s", newURL)
+ ctx.Redirect(newURL, http.StatusMovedPermanently)
+ }
+}
diff --git a/server/certificates/certificates.go b/server/certificates/certificates.go
index 8af4be5..aeb619f 100644
--- a/server/certificates/certificates.go
+++ b/server/certificates/certificates.go
@@ -1,67 +1,72 @@
package certificates
import (
- "bytes"
"context"
- "crypto/ecdsa"
- "crypto/elliptic"
- "crypto/rand"
"crypto/tls"
"crypto/x509"
- "encoding/gob"
- "encoding/json"
"errors"
"fmt"
- "os"
"strconv"
"strings"
- "sync"
"time"
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
- "github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/challenge/tlsalpn01"
"github.com/go-acme/lego/v4/lego"
- "github.com/go-acme/lego/v4/providers/dns"
- "github.com/go-acme/lego/v4/registration"
+ "github.com/hashicorp/golang-lru/v2/expirable"
"github.com/reugn/equalizer"
+ "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"codeberg.org/codeberg/pages/server/cache"
+ psContext "codeberg.org/codeberg/pages/server/context"
"codeberg.org/codeberg/pages/server/database"
dnsutils "codeberg.org/codeberg/pages/server/dns"
"codeberg.org/codeberg/pages/server/gitea"
"codeberg.org/codeberg/pages/server/upstream"
)
+var ErrUserRateLimitExceeded = errors.New("rate limit exceeded: 10 certificates per user per 24 hours")
+
// TLSConfig returns the configuration for generating, serving and cleaning up Let's Encrypt certificates.
func TLSConfig(mainDomainSuffix string,
giteaClient *gitea.Client,
- dnsProvider string,
- acmeUseRateLimits bool,
- keyCache, challengeCache, dnsLookupCache, canonicalDomainCache cache.SetGetKey,
+ acmeClient *AcmeClient,
+ firstDefaultBranch string,
+ challengeCache, canonicalDomainCache cache.ICache,
certDB database.CertDB,
+ noDNS01 bool,
+ rawDomain string,
) *tls.Config {
+ // each cert stays at most 24h in the cache; certs are renewed 7 days before they expire
+ keyCache := expirable.NewLRU[string, *tls.Certificate](32, nil, 24*time.Hour)
+
return &tls.Config{
// check DNS name & get certificate from Let's Encrypt
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
- sni := strings.ToLower(strings.TrimSpace(info.ServerName))
- if len(sni) < 1 {
- return nil, errors.New("missing sni")
+ ctx := psContext.New(nil, nil)
+ log := log.With().Str("ReqId", ctx.ReqId).Logger()
+
+ domain := strings.ToLower(strings.TrimSpace(info.ServerName))
+ log.Debug().Str("domain", domain).Msg("start: get tls certificate")
+ if len(domain) < 1 {
+ return nil, errors.New("missing domain info via SNI (RFC 4366, Section 3.1)")
}
+ // the TLS handshake may actually be an ACME TLS-ALPN-01 challenge
if info.SupportedProtos != nil {
for _, proto := range info.SupportedProtos {
if proto != tlsalpn01.ACMETLS1Protocol {
continue
}
+ log.Info().Msgf("Detect ACME-TLS1 challenge for '%s'", domain)
- challenge, ok := challengeCache.Get(sni)
+ challenge, ok := challengeCache.Get(domain)
if !ok {
return nil, errors.New("no challenge for this domain")
}
- cert, err := tlsalpn01.ChallengeCert(sni, challenge.(string))
+ cert, err := tlsalpn01.ChallengeCert(domain, challenge.(string))
if err != nil {
return nil, err
}
@@ -70,54 +75,77 @@ func TLSConfig(mainDomainSuffix string,
}
targetOwner := ""
- if strings.HasSuffix(sni, mainDomainSuffix) || strings.EqualFold(sni, mainDomainSuffix[1:]) {
- // deliver default certificate for the main domain (*.codeberg.page)
- sni = mainDomainSuffix
+ mayObtainCert := true
+
+ if strings.HasSuffix(domain, mainDomainSuffix) || strings.EqualFold(domain, mainDomainSuffix[1:]) {
+ if noDNS01 {
+ // Limit the domains allowed to request a certificate to pages-server domains
+ // and domains for an existing user or org
+ if !strings.EqualFold(domain, mainDomainSuffix[1:]) && !strings.EqualFold(domain, rawDomain) {
+ targetOwner := strings.TrimSuffix(domain, mainDomainSuffix)
+ owner_exist, err := giteaClient.GiteaCheckIfOwnerExists(targetOwner)
+ mayObtainCert = owner_exist
+ if err != nil {
+ log.Error().Err(err).Msgf("Failed to check '%s' existence on the forge: %s", targetOwner, err)
+ mayObtainCert = false
+ }
+ }
+ } else {
+ // deliver default certificate for the main domain (*.codeberg.page)
+ domain = mainDomainSuffix
+ }
} else {
var targetRepo, targetBranch string
- targetOwner, targetRepo, targetBranch = dnsutils.GetTargetFromDNS(sni, mainDomainSuffix, dnsLookupCache)
+ targetOwner, targetRepo, targetBranch = dnsutils.GetTargetFromDNS(domain, mainDomainSuffix, firstDefaultBranch)
if targetOwner == "" {
// DNS not set up, return main certificate to redirect to the docs
- sni = mainDomainSuffix
+ domain = mainDomainSuffix
} else {
targetOpt := &upstream.Options{
TargetOwner: targetOwner,
TargetRepo: targetRepo,
TargetBranch: targetBranch,
}
- _, valid := targetOpt.CheckCanonicalDomain(giteaClient, sni, mainDomainSuffix, canonicalDomainCache)
+ _, valid := targetOpt.CheckCanonicalDomain(ctx, giteaClient, domain, mainDomainSuffix, canonicalDomainCache)
if !valid {
- sni = mainDomainSuffix
+ // We shouldn't obtain a certificate when we cannot check if the
+ // repository has specified this domain in the `.domains` file.
+ mayObtainCert = false
}
}
}
- if tlsCertificate, ok := keyCache.Get(sni); ok {
+ if tlsCertificate, ok := keyCache.Get(domain); ok {
// we can use an existing certificate object
- return tlsCertificate.(*tls.Certificate), nil
+ return tlsCertificate, nil
}
- var tlsCertificate tls.Certificate
+ var tlsCertificate *tls.Certificate
var err error
- var ok bool
- if tlsCertificate, ok = retrieveCertFromDB(sni, mainDomainSuffix, dnsProvider, acmeUseRateLimits, certDB); !ok {
- // request a new certificate
- if strings.EqualFold(sni, mainDomainSuffix) {
+ if tlsCertificate, err = acmeClient.retrieveCertFromDB(log, domain, mainDomainSuffix, false, certDB); err != nil {
+ if !errors.Is(err, database.ErrNotFound) {
+ return nil, err
+ }
+ // we could not find a cert in db, request a new certificate
+
+ // first check if we are allowed to obtain a cert for this domain
+ if strings.EqualFold(domain, mainDomainSuffix) {
return nil, errors.New("won't request certificate for main domain, something really bad has happened")
}
+ if !mayObtainCert {
+ return nil, fmt.Errorf("won't request certificate for %q", domain)
+ }
- tlsCertificate, err = obtainCert(acmeClient, []string{sni}, nil, targetOwner, dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB)
+ tlsCertificate, err = acmeClient.obtainCert(log, acmeClient.legoClient, []string{domain}, nil, targetOwner, false, mainDomainSuffix, certDB)
if err != nil {
return nil, err
}
}
- if err := keyCache.Set(sni, &tlsCertificate, 15*time.Minute); err != nil {
- return nil, err
- }
- return &tlsCertificate, nil
+ keyCache.Add(domain, tlsCertificate)
+
+ return tlsCertificate, nil
},
- PreferServerCipherSuites: true,
NextProtos: []string{
"h2",
"http/1.1",
@@ -138,159 +166,115 @@ func TLSConfig(mainDomainSuffix string,
}
}
-func checkUserLimit(user string) error {
- userLimit, ok := acmeClientCertificateLimitPerUser[user]
+func (c *AcmeClient) checkUserLimit(user string) error {
+ userLimit, ok := c.acmeClientCertificateLimitPerUser[user]
if !ok {
- // Each Codeberg user can only add 10 new domains per day.
+ // Each user can only add 10 new domains per day.
userLimit = equalizer.NewTokenBucket(10, time.Hour*24)
- acmeClientCertificateLimitPerUser[user] = userLimit
+ c.acmeClientCertificateLimitPerUser[user] = userLimit
}
if !userLimit.Ask() {
- return errors.New("rate limit exceeded: 10 certificates per user per 24 hours")
+ return fmt.Errorf("user '%s' error: %w", user, ErrUserRateLimitExceeded)
}
return nil
}
-var (
- acmeClient, mainDomainAcmeClient *lego.Client
- acmeClientCertificateLimitPerUser = map[string]*equalizer.TokenBucket{}
-)
-
-// rate limit is 300 / 3 hours, we want 200 / 2 hours but to refill more often, so that's 25 new domains every 15 minutes
-// TODO: when this is used a lot, we probably have to think of a somewhat better solution?
-var acmeClientOrderLimit = equalizer.NewTokenBucket(25, 15*time.Minute)
-
-// rate limit is 20 / second, we want 5 / second (especially as one cert takes at least two requests)
-var acmeClientRequestLimit = equalizer.NewTokenBucket(5, 1*time.Second)
-
-type AcmeTLSChallengeProvider struct {
- challengeCache cache.SetGetKey
-}
-
-// make sure AcmeTLSChallengeProvider match Provider interface
-var _ challenge.Provider = AcmeTLSChallengeProvider{}
-
-func (a AcmeTLSChallengeProvider) Present(domain, _, keyAuth string) error {
- return a.challengeCache.Set(domain, keyAuth, 1*time.Hour)
-}
-
-func (a AcmeTLSChallengeProvider) CleanUp(domain, _, _ string) error {
- a.challengeCache.Remove(domain)
- return nil
-}
-
-type AcmeHTTPChallengeProvider struct {
- challengeCache cache.SetGetKey
-}
-
-// make sure AcmeHTTPChallengeProvider match Provider interface
-var _ challenge.Provider = AcmeHTTPChallengeProvider{}
-
-func (a AcmeHTTPChallengeProvider) Present(domain, token, keyAuth string) error {
- return a.challengeCache.Set(domain+"/"+token, keyAuth, 1*time.Hour)
-}
-
-func (a AcmeHTTPChallengeProvider) CleanUp(domain, token, _ string) error {
- a.challengeCache.Remove(domain + "/" + token)
- return nil
-}
-
-func retrieveCertFromDB(sni, mainDomainSuffix, dnsProvider string, acmeUseRateLimits bool, certDB database.CertDB) (tls.Certificate, bool) {
+func (c *AcmeClient) retrieveCertFromDB(log zerolog.Logger, sni, mainDomainSuffix string, useDnsProvider bool, certDB database.CertDB) (*tls.Certificate, error) {
// parse certificate from database
res, err := certDB.Get(sni)
if err != nil {
- panic(err) // TODO: no panic
- }
- if res == nil {
- return tls.Certificate{}, false
+ return nil, err
+ } else if res == nil {
+ return nil, database.ErrNotFound
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil {
- panic(err)
+ return nil, err
}
// TODO: document & put into own function
if !strings.EqualFold(sni, mainDomainSuffix) {
- tlsCertificate.Leaf, err = x509.ParseCertificate(tlsCertificate.Certificate[0])
+ tlsCertificate.Leaf, err = leaf(&tlsCertificate)
if err != nil {
- panic(err)
+ return nil, err
}
-
// renew certificates 7 days before they expire
if tlsCertificate.Leaf.NotAfter.Before(time.Now().Add(7 * 24 * time.Hour)) {
- // TODO: add ValidUntil to custom res struct
- if res.CSR != nil && len(res.CSR) > 0 {
+ // TODO: use ValidTill of custom cert struct
+ if len(res.CSR) > 0 {
// CSR stores the time when the renewal shall be tried again
nextTryUnix, err := strconv.ParseInt(string(res.CSR), 10, 64)
if err == nil && time.Now().Before(time.Unix(nextTryUnix, 0)) {
- return tlsCertificate, true
+ return &tlsCertificate, nil
}
}
+ // TODO: make a queue ?
go (func() {
res.CSR = nil // acme client doesn't like CSR to be set
- tlsCertificate, err = obtainCert(acmeClient, []string{sni}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB)
- if err != nil {
+ if _, err := c.obtainCert(log, c.legoClient, []string{sni}, res, "", useDnsProvider, mainDomainSuffix, certDB); err != nil {
log.Error().Msgf("Couldn't renew certificate for %s: %v", sni, err)
}
})()
}
}
- return tlsCertificate, true
+ return &tlsCertificate, nil
}
-var obtainLocks = sync.Map{}
-
-func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Resource, user, dnsProvider, mainDomainSuffix string, acmeUseRateLimits bool, keyDatabase database.CertDB) (tls.Certificate, error) {
+func (c *AcmeClient) obtainCert(log zerolog.Logger, acmeClient *lego.Client, domains []string, renew *certificate.Resource, user string, useDnsProvider bool, mainDomainSuffix string, keyDatabase database.CertDB) (*tls.Certificate, error) {
name := strings.TrimPrefix(domains[0], "*")
- if dnsProvider == "" && len(domains[0]) > 0 && domains[0][0] == '*' {
- domains = domains[1:]
- }
// lock to avoid simultaneous requests
- _, working := obtainLocks.LoadOrStore(name, struct{}{})
+ _, working := c.obtainLocks.LoadOrStore(name, struct{}{})
if working {
for working {
time.Sleep(100 * time.Millisecond)
- _, working = obtainLocks.Load(name)
+ _, working = c.obtainLocks.Load(name)
}
- cert, ok := retrieveCertFromDB(name, mainDomainSuffix, dnsProvider, acmeUseRateLimits, keyDatabase)
- if !ok {
- return tls.Certificate{}, errors.New("certificate failed in synchronous request")
+ cert, err := c.retrieveCertFromDB(log, name, mainDomainSuffix, useDnsProvider, keyDatabase)
+ if err != nil {
+ return nil, fmt.Errorf("certificate failed in synchronous request: %w", err)
}
return cert, nil
}
- defer obtainLocks.Delete(name)
+ defer c.obtainLocks.Delete(name)
if acmeClient == nil {
- return mockCert(domains[0], "ACME client uninitialized. This is a server error, please report!", mainDomainSuffix, keyDatabase), nil
+ if useDnsProvider {
+ return mockCert(domains[0], "DNS ACME client is not defined", mainDomainSuffix, keyDatabase)
+ } else {
+ return mockCert(domains[0], "ACME client uninitialized. This is a server error, please report!", mainDomainSuffix, keyDatabase)
+ }
}
// request actual cert
var res *certificate.Resource
var err error
if renew != nil && renew.CertURL != "" {
- if acmeUseRateLimits {
- acmeClientRequestLimit.Take()
+ if c.acmeUseRateLimits {
+ c.acmeClientRequestLimit.Take()
}
log.Debug().Msgf("Renewing certificate for: %v", domains)
res, err = acmeClient.Certificate.Renew(*renew, true, false, "")
if err != nil {
log.Error().Err(err).Msgf("Couldn't renew certificate for %v, trying to request a new one", domains)
+ if c.acmeUseRateLimits {
+ c.acmeClientFailLimit.Take()
+ }
res = nil
}
}
if res == nil {
if user != "" {
- if err := checkUserLimit(user); err != nil {
- return tls.Certificate{}, err
+ if err := c.checkUserLimit(user); err != nil {
+ return nil, err
}
}
- if acmeUseRateLimits {
- acmeClientOrderLimit.Take()
- acmeClientRequestLimit.Take()
+ if c.acmeUseRateLimits {
+ c.acmeClientOrderLimit.Take()
+ c.acmeClientRequestLimit.Take()
}
log.Debug().Msgf("Re-requesting new certificate for %v", domains)
res, err = acmeClient.Certificate.Obtain(certificate.ObtainRequest{
@@ -298,163 +282,59 @@ func obtainCert(acmeClient *lego.Client, domains []string, renew *certificate.Re
Bundle: true,
MustStaple: false,
})
+ if c.acmeUseRateLimits && err != nil {
+ c.acmeClientFailLimit.Take()
+ }
}
if err != nil {
log.Error().Err(err).Msgf("Couldn't obtain again a certificate or %v", domains)
if renew != nil && renew.CertURL != "" {
tlsCertificate, err := tls.X509KeyPair(renew.Certificate, renew.PrivateKey)
- if err == nil && tlsCertificate.Leaf.NotAfter.After(time.Now()) {
+ if err != nil {
+ mockC, err2 := mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase)
+ if err2 != nil {
+ return nil, errors.Join(err, err2)
+ }
+ return mockC, err
+ }
+ leaf, err := leaf(&tlsCertificate)
+ if err == nil && leaf.NotAfter.After(time.Now()) {
+ tlsCertificate.Leaf = leaf
// avoid sending a mock cert instead of a still valid cert, instead abuse CSR field to store time to try again at
renew.CSR = []byte(strconv.FormatInt(time.Now().Add(6*time.Hour).Unix(), 10))
if err := keyDatabase.Put(name, renew); err != nil {
- return mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase), err
+ mockC, err2 := mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase)
+ if err2 != nil {
+ return nil, errors.Join(err, err2)
+ }
+ return mockC, err
}
- return tlsCertificate, nil
+ return &tlsCertificate, nil
}
}
- return mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase), err
+ return mockCert(domains[0], err.Error(), mainDomainSuffix, keyDatabase)
}
log.Debug().Msgf("Obtained certificate for %v", domains)
if err := keyDatabase.Put(name, res); err != nil {
- return tls.Certificate{}, err
+ return nil, err
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil {
- return tls.Certificate{}, err
- }
- return tlsCertificate, nil
-}
-
-func SetupAcmeConfig(acmeAPI, acmeMail, acmeEabHmac, acmeEabKID string, acmeAcceptTerms bool) (*lego.Config, error) {
- const configFile = "acme-account.json"
- var myAcmeAccount AcmeAccount
- var myAcmeConfig *lego.Config
-
- if account, err := os.ReadFile(configFile); err == nil {
- if err := json.Unmarshal(account, &myAcmeAccount); err != nil {
- return nil, err
- }
- myAcmeAccount.Key, err = certcrypto.ParsePEMPrivateKey([]byte(myAcmeAccount.KeyPEM))
- if err != nil {
- return nil, err
- }
- myAcmeConfig = lego.NewConfig(&myAcmeAccount)
- myAcmeConfig.CADirURL = acmeAPI
- myAcmeConfig.Certificate.KeyType = certcrypto.RSA2048
-
- // Validate Config
- _, err := lego.NewClient(myAcmeConfig)
- if err != nil {
- // TODO: should we fail hard instead?
- log.Error().Err(err).Msg("Can't create ACME client, continuing with mock certs only")
- }
- return myAcmeConfig, nil
- } else if !os.IsNotExist(err) {
return nil, err
}
-
- privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
- if err != nil {
- return nil, err
- }
- myAcmeAccount = AcmeAccount{
- Email: acmeMail,
- Key: privateKey,
- KeyPEM: string(certcrypto.PEMEncode(privateKey)),
- }
- myAcmeConfig = lego.NewConfig(&myAcmeAccount)
- myAcmeConfig.CADirURL = acmeAPI
- myAcmeConfig.Certificate.KeyType = certcrypto.RSA2048
- tempClient, err := lego.NewClient(myAcmeConfig)
- if err != nil {
- log.Error().Err(err).Msg("Can't create ACME client, continuing with mock certs only")
- } else {
- // accept terms & log in to EAB
- if acmeEabKID == "" || acmeEabHmac == "" {
- reg, err := tempClient.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: acmeAcceptTerms})
- if err != nil {
- log.Error().Err(err).Msg("Can't register ACME account, continuing with mock certs only")
- } else {
- myAcmeAccount.Registration = reg
- }
- } else {
- reg, err := tempClient.Registration.RegisterWithExternalAccountBinding(registration.RegisterEABOptions{
- TermsOfServiceAgreed: acmeAcceptTerms,
- Kid: acmeEabKID,
- HmacEncoded: acmeEabHmac,
- })
- if err != nil {
- log.Error().Err(err).Msg("Can't register ACME account, continuing with mock certs only")
- } else {
- myAcmeAccount.Registration = reg
- }
- }
-
- if myAcmeAccount.Registration != nil {
- acmeAccountJSON, err := json.Marshal(myAcmeAccount)
- if err != nil {
- log.Error().Err(err).Msg("json.Marshalfailed, waiting for manual restart to avoid rate limits")
- select {}
- }
- err = os.WriteFile(configFile, acmeAccountJSON, 0o600)
- if err != nil {
- log.Error().Err(err).Msg("os.WriteFile failed, waiting for manual restart to avoid rate limits")
- select {}
- }
- }
- }
-
- return myAcmeConfig, nil
+ return &tlsCertificate, nil
}
-func SetupCertificates(mainDomainSuffix, dnsProvider string, acmeConfig *lego.Config, acmeUseRateLimits, enableHTTPServer bool, challengeCache cache.SetGetKey, certDB database.CertDB) error {
+func SetupMainDomainCertificates(log zerolog.Logger, mainDomainSuffix string, acmeClient *AcmeClient, certDB database.CertDB) error {
// getting main cert before ACME account so that we can fail here without hitting rate limits
mainCertBytes, err := certDB.Get(mainDomainSuffix)
- if err != nil {
- return fmt.Errorf("cert database is not working")
- }
-
- acmeClient, err = lego.NewClient(acmeConfig)
- if err != nil {
- log.Fatal().Err(err).Msg("Can't create ACME client, continuing with mock certs only")
- } else {
- err = acmeClient.Challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{challengeCache})
- if err != nil {
- log.Error().Err(err).Msg("Can't create TLS-ALPN-01 provider")
- }
- if enableHTTPServer {
- err = acmeClient.Challenge.SetHTTP01Provider(AcmeHTTPChallengeProvider{challengeCache})
- if err != nil {
- log.Error().Err(err).Msg("Can't create HTTP-01 provider")
- }
- }
- }
-
- mainDomainAcmeClient, err = lego.NewClient(acmeConfig)
- if err != nil {
- log.Error().Err(err).Msg("Can't create ACME client, continuing with mock certs only")
- } else {
- if dnsProvider == "" {
- // using mock server, don't use wildcard certs
- err := mainDomainAcmeClient.Challenge.SetTLSALPN01Provider(AcmeTLSChallengeProvider{challengeCache})
- if err != nil {
- log.Error().Err(err).Msg("Can't create TLS-ALPN-01 provider")
- }
- } else {
- provider, err := dns.NewDNSChallengeProviderByName(dnsProvider)
- if err != nil {
- log.Error().Err(err).Msg("Can't create DNS Challenge provider")
- }
- err = mainDomainAcmeClient.Challenge.SetDNS01Provider(provider)
- if err != nil {
- log.Error().Err(err).Msg("Can't create DNS-01 provider")
- }
- }
+ if err != nil && !errors.Is(err, database.ErrNotFound) {
+ return fmt.Errorf("cert database is not working: %w", err)
}
if mainCertBytes == nil {
- _, err = obtainCert(mainDomainAcmeClient, []string{"*" + mainDomainSuffix, mainDomainSuffix[1:]}, nil, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB)
+ _, err = acmeClient.obtainCert(log, acmeClient.dnsChallengerLegoClient, []string{"*" + mainDomainSuffix, mainDomainSuffix[1:]}, nil, "", true, mainDomainSuffix, certDB)
if err != nil {
log.Error().Err(err).Msg("Couldn't renew main domain certificate, continuing with mock certs only")
}
@@ -463,43 +343,29 @@ func SetupCertificates(mainDomainSuffix, dnsProvider string, acmeConfig *lego.Co
return nil
}
-func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffix, dnsProvider string, acmeUseRateLimits bool, certDB database.CertDB) {
+func MaintainCertDB(log zerolog.Logger, ctx context.Context, interval time.Duration, acmeClient *AcmeClient, mainDomainSuffix string, certDB database.CertDB) {
for {
- // clean up expired certs
- now := time.Now()
+ // delete expired certs that will be invalid until next clean up
+ threshold := time.Now().Add(interval)
expiredCertCount := 0
- keyDatabaseIterator := certDB.Items()
- key, resBytes, err := keyDatabaseIterator.Next()
- for err == nil {
- if !strings.EqualFold(string(key), mainDomainSuffix) {
- resGob := bytes.NewBuffer(resBytes)
- resDec := gob.NewDecoder(resGob)
- res := &certificate.Resource{}
- err = resDec.Decode(res)
- if err != nil {
- panic(err)
- }
- tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
- if err != nil || tlsCertificates[0].NotAfter.Before(now) {
- err := certDB.Delete(string(key))
- if err != nil {
- log.Error().Err(err).Msgf("Deleting expired certificate for %q failed", string(key))
- } else {
- expiredCertCount++
+ certs, err := certDB.Items(0, 0)
+ if err != nil {
+ log.Error().Err(err).Msg("could not get certs from list")
+ } else {
+ for _, cert := range certs {
+ if !strings.EqualFold(cert.Domain, strings.TrimPrefix(mainDomainSuffix, ".")) {
+ if time.Unix(cert.ValidTill, 0).Before(threshold) {
+ err := certDB.Delete(cert.Domain)
+ if err != nil {
+ log.Error().Err(err).Msgf("Deleting expired certificate for %q failed", cert.Domain)
+ } else {
+ expiredCertCount++
+ }
}
}
}
- key, resBytes, err = keyDatabaseIterator.Next()
- }
- log.Debug().Msgf("Removed %d expired certificates from the database", expiredCertCount)
-
- // compact the database
- msg, err := certDB.Compact()
- if err != nil {
- log.Error().Err(err).Msg("Compacting key database failed")
- } else {
- log.Debug().Msgf("Compacted key database: %s", msg)
+ log.Debug().Msgf("Removed %d expired certificates from the database", expiredCertCount)
}
// update main cert
@@ -510,11 +376,12 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi
log.Error().Msgf("Couldn't renew certificate for main domain %q expected main domain cert to exist, but it's missing - seems like the database is corrupted", mainDomainSuffix)
} else {
tlsCertificates, err := certcrypto.ParsePEMBundle(res.Certificate)
-
- // renew main certificate 30 days before it expires
- if tlsCertificates[0].NotAfter.Before(time.Now().Add(30 * 24 * time.Hour)) {
+ if err != nil {
+ log.Error().Err(fmt.Errorf("could not parse cert for mainDomainSuffix: %w", err))
+ } else if tlsCertificates[0].NotAfter.Before(time.Now().Add(30 * 24 * time.Hour)) {
+ // renew main certificate 30 days before it expires
go (func() {
- _, err = obtainCert(mainDomainAcmeClient, []string{"*" + mainDomainSuffix, mainDomainSuffix[1:]}, res, "", dnsProvider, mainDomainSuffix, acmeUseRateLimits, certDB)
+ _, err = acmeClient.obtainCert(log, acmeClient.dnsChallengerLegoClient, []string{"*" + mainDomainSuffix, mainDomainSuffix[1:]}, res, "", true, mainDomainSuffix, certDB)
if err != nil {
log.Error().Err(err).Msg("Couldn't renew certificate for main domain")
}
@@ -529,3 +396,21 @@ func MaintainCertDB(ctx context.Context, interval time.Duration, mainDomainSuffi
}
}
}
+
+// leaf returns the parsed leaf certificate, either from c.Leaf or by parsing
+// the corresponding c.Certificate[0].
+// After successfully parsing the cert c.Leaf gets set to the parsed cert.
+func leaf(c *tls.Certificate) (*x509.Certificate, error) {
+ if c.Leaf != nil {
+ return c.Leaf, nil
+ }
+
+ leaf, err := x509.ParseCertificate(c.Certificate[0])
+ if err != nil {
+ return nil, fmt.Errorf("tlsCert - failed to parse leaf: %w", err)
+ }
+
+ c.Leaf = leaf
+
+ return leaf, err
+}
diff --git a/server/certificates/mock.go b/server/certificates/mock.go
index 0e87e6e..a28d0f4 100644
--- a/server/certificates/mock.go
+++ b/server/certificates/mock.go
@@ -13,14 +13,15 @@ import (
"github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
+ "github.com/rs/zerolog/log"
"codeberg.org/codeberg/pages/server/database"
)
-func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB) tls.Certificate {
+func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB) (*tls.Certificate, error) {
key, err := certcrypto.GeneratePrivateKey(certcrypto.RSA2048)
if err != nil {
- panic(err)
+ return nil, err
}
template := x509.Certificate{
@@ -52,7 +53,7 @@ func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB)
key,
)
if err != nil {
- panic(err)
+ return nil, err
}
out := &bytes.Buffer{}
@@ -61,7 +62,7 @@ func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB)
Type: "CERTIFICATE",
})
if err != nil {
- panic(err)
+ return nil, err
}
outBytes := out.Bytes()
res := &certificate.Resource{
@@ -75,12 +76,12 @@ func mockCert(domain, msg, mainDomainSuffix string, keyDatabase database.CertDB)
databaseName = mainDomainSuffix
}
if err := keyDatabase.Put(databaseName, res); err != nil {
- panic(err)
+ log.Error().Err(err).Msg("could not store mock cert in database")
}
tlsCertificate, err := tls.X509KeyPair(res.Certificate, res.PrivateKey)
if err != nil {
- panic(err)
+ return nil, err
}
- return tlsCertificate
+ return &tlsCertificate, nil
}
diff --git a/server/certificates/mock_test.go b/server/certificates/mock_test.go
index 1cbd1f6..644e8a9 100644
--- a/server/certificates/mock_test.go
+++ b/server/certificates/mock_test.go
@@ -3,14 +3,18 @@ package certificates
import (
"testing"
- "codeberg.org/codeberg/pages/server/database"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+
+ "codeberg.org/codeberg/pages/server/database"
)
func TestMockCert(t *testing.T) {
- db, err := database.NewTmpDB()
+ db := database.NewMockCertDB(t)
+ db.Mock.On("Put", mock.Anything, mock.Anything).Return(nil)
+
+ cert, err := mockCert("example.com", "some error msg", "codeberg.page", db)
assert.NoError(t, err)
- cert := mockCert("example.com", "some error msg", "codeberg.page", db)
if assert.NotEmpty(t, cert) {
assert.NotEmpty(t, cert.Certificate)
}
diff --git a/server/context/context.go b/server/context/context.go
index 481fee2..e695ab7 100644
--- a/server/context/context.go
+++ b/server/context/context.go
@@ -5,19 +5,29 @@ import (
"net/http"
"codeberg.org/codeberg/pages/server/utils"
+ "github.com/hashicorp/go-uuid"
+ "github.com/rs/zerolog/log"
)
type Context struct {
RespWriter http.ResponseWriter
Req *http.Request
StatusCode int
+ ReqId string
}
func New(w http.ResponseWriter, r *http.Request) *Context {
+ req_uuid, err := uuid.GenerateUUID()
+ if err != nil {
+ log.Error().Err(err).Msg("Failed to generate request id, assigning error value")
+ req_uuid = "ERROR"
+ }
+
return &Context{
RespWriter: w,
Req: r,
StatusCode: http.StatusOK,
+ ReqId: req_uuid,
}
}
@@ -48,11 +58,9 @@ func (c *Context) Redirect(uri string, statusCode int) {
http.Redirect(c.RespWriter, c.Req, uri, statusCode)
}
-// Path returns requested path.
-//
-// The returned bytes are valid until your request handler returns.
+// Path returns the cleaned requested path.
func (c *Context) Path() string {
- return c.Req.URL.Path
+ return utils.CleanPath(c.Req.URL.Path)
}
func (c *Context) Host() string {
diff --git a/server/database/interface.go b/server/database/interface.go
index 3ba3efc..7fdbae7 100644
--- a/server/database/interface.go
+++ b/server/database/interface.go
@@ -1,15 +1,78 @@
package database
import (
- "github.com/akrylysov/pogreb"
+ "fmt"
+
+ "github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/certificate"
+ "github.com/rs/zerolog/log"
)
+//go:generate go install github.com/vektra/mockery/v2@latest
+//go:generate mockery --name CertDB --output . --filename mock.go --inpackage --case underscore
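+// Running 'go generate ./...' in this package regenerates mock.go from the CertDB interface below.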
+
type CertDB interface {
Close() error
Put(name string, cert *certificate.Resource) error
Get(name string) (*certificate.Resource, error)
Delete(key string) error
- Compact() (string, error)
- Items() *pogreb.ItemIterator
+ Items(page, pageSize int) ([]*Cert, error)
+}
+
+type Cert struct {
+ Domain string `xorm:"pk NOT NULL UNIQUE 'domain'"`
+ Created int64 `xorm:"created NOT NULL DEFAULT 0 'created'"`
+ Updated int64 `xorm:"updated NOT NULL DEFAULT 0 'updated'"`
+ ValidTill int64 `xorm:" NOT NULL DEFAULT 0 'valid_till'"`
+ // certificate.Resource
+ CertURL string `xorm:"'cert_url'"`
+ CertStableURL string `xorm:"'cert_stable_url'"`
+ PrivateKey []byte `xorm:"'private_key'"`
+ Certificate []byte `xorm:"'certificate'"`
+ IssuerCertificate []byte `xorm:"'issuer_certificate'"`
+}
+
+func (c Cert) Raw() *certificate.Resource {
+ return &certificate.Resource{
+ Domain: c.Domain,
+ CertURL: c.CertURL,
+ CertStableURL: c.CertStableURL,
+ PrivateKey: c.PrivateKey,
+ Certificate: c.Certificate,
+ IssuerCertificate: c.IssuerCertificate,
+ }
+}
+
+func toCert(name string, c *certificate.Resource) (*Cert, error) {
+ tlsCertificates, err := certcrypto.ParsePEMBundle(c.Certificate)
+ if err != nil {
+ return nil, err
+ }
+ if len(tlsCertificates) == 0 || tlsCertificates[0] == nil {
+ err := fmt.Errorf("parsed cert resource has no cert")
+ log.Error().Err(err).Str("domain", c.Domain).Msgf("cert: %v", c)
+ return nil, err
+ }
+ validTill := tlsCertificates[0].NotAfter.Unix()
+
+ // handle wildcard certs
+ if name[:1] == "." {
+ name = "*" + name
+ }
+ if name != c.Domain {
+ err := fmt.Errorf("domain key '%s' and cert domain '%s' not equal", name, c.Domain)
+		log.Error().Err(err).Msg("toCert conversion discovered a mismatch")
+ // TODO: fail hard: return nil, err
+ }
+
+ return &Cert{
+ Domain: c.Domain,
+ ValidTill: validTill,
+
+ CertURL: c.CertURL,
+ CertStableURL: c.CertStableURL,
+ PrivateKey: c.PrivateKey,
+ Certificate: c.Certificate,
+ IssuerCertificate: c.IssuerCertificate,
+ }, nil
}
diff --git a/server/database/mock.go b/server/database/mock.go
index dfe2316..e7e2c38 100644
--- a/server/database/mock.go
+++ b/server/database/mock.go
@@ -1,55 +1,122 @@
+// Code generated by mockery v2.20.0. DO NOT EDIT.
+
package database
import (
- "fmt"
- "time"
-
- "github.com/OrlovEvgeny/go-mcache"
- "github.com/akrylysov/pogreb"
- "github.com/go-acme/lego/v4/certificate"
+ certificate "github.com/go-acme/lego/v4/certificate"
+ mock "github.com/stretchr/testify/mock"
)
-var _ CertDB = tmpDB{}
-
-type tmpDB struct {
- intern *mcache.CacheDriver
- ttl time.Duration
+// MockCertDB is an autogenerated mock type for the CertDB type
+type MockCertDB struct {
+ mock.Mock
}
-func (p tmpDB) Close() error {
- _ = p.intern.Close()
- return nil
-}
+// Close provides a mock function with given fields:
+func (_m *MockCertDB) Close() error {
+ ret := _m.Called()
-func (p tmpDB) Put(name string, cert *certificate.Resource) error {
- return p.intern.Set(name, cert, p.ttl)
-}
-
-func (p tmpDB) Get(name string) (*certificate.Resource, error) {
- cert, has := p.intern.Get(name)
- if !has {
- return nil, fmt.Errorf("cert for %q not found", name)
+ var r0 error
+ if rf, ok := ret.Get(0).(func() error); ok {
+ r0 = rf()
+ } else {
+ r0 = ret.Error(0)
}
- return cert.(*certificate.Resource), nil
+
+ return r0
}
-func (p tmpDB) Delete(key string) error {
- p.intern.Remove(key)
- return nil
+// Delete provides a mock function with given fields: key
+func (_m *MockCertDB) Delete(key string) error {
+ ret := _m.Called(key)
+
+ var r0 error
+ if rf, ok := ret.Get(0).(func(string) error); ok {
+ r0 = rf(key)
+ } else {
+ r0 = ret.Error(0)
+ }
+
+ return r0
}
-func (p tmpDB) Compact() (string, error) {
- p.intern.Truncate()
- return "Truncate done", nil
+// Get provides a mock function with given fields: name
+func (_m *MockCertDB) Get(name string) (*certificate.Resource, error) {
+ ret := _m.Called(name)
+
+ var r0 *certificate.Resource
+ var r1 error
+ if rf, ok := ret.Get(0).(func(string) (*certificate.Resource, error)); ok {
+ return rf(name)
+ }
+ if rf, ok := ret.Get(0).(func(string) *certificate.Resource); ok {
+ r0 = rf(name)
+ } else {
+ if ret.Get(0) != nil {
+ r0 = ret.Get(0).(*certificate.Resource)
+ }
+ }
+
+ if rf, ok := ret.Get(1).(func(string) error); ok {
+ r1 = rf(name)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
}
-func (p tmpDB) Items() *pogreb.ItemIterator {
- panic("ItemIterator not implemented for tmpDB")
+// Items provides a mock function with given fields: page, pageSize
+func (_m *MockCertDB) Items(page int, pageSize int) ([]*Cert, error) {
+ ret := _m.Called(page, pageSize)
+
+ var r0 []*Cert
+ var r1 error
+ if rf, ok := ret.Get(0).(func(int, int) ([]*Cert, error)); ok {
+ return rf(page, pageSize)
+ }
+ if rf, ok := ret.Get(0).(func(int, int) []*Cert); ok {
+ r0 = rf(page, pageSize)
+ } else {
+ if ret.Get(0) != nil {
+ r0 = ret.Get(0).([]*Cert)
+ }
+ }
+
+ if rf, ok := ret.Get(1).(func(int, int) error); ok {
+ r1 = rf(page, pageSize)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
}
-func NewTmpDB() (CertDB, error) {
- return &tmpDB{
- intern: mcache.New(),
- ttl: time.Minute,
- }, nil
+// Put provides a mock function with given fields: name, cert
+func (_m *MockCertDB) Put(name string, cert *certificate.Resource) error {
+ ret := _m.Called(name, cert)
+
+ var r0 error
+ if rf, ok := ret.Get(0).(func(string, *certificate.Resource) error); ok {
+ r0 = rf(name, cert)
+ } else {
+ r0 = ret.Error(0)
+ }
+
+ return r0
+}
+
+type mockConstructorTestingTNewMockCertDB interface {
+ mock.TestingT
+ Cleanup(func())
+}
+
+// NewMockCertDB creates a new instance of MockCertDB. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+func NewMockCertDB(t mockConstructorTestingTNewMockCertDB) *MockCertDB {
+ mock := &MockCertDB{}
+ mock.Mock.Test(t)
+
+ t.Cleanup(func() { mock.AssertExpectations(t) })
+
+ return mock
}
diff --git a/server/database/setup.go b/server/database/setup.go
deleted file mode 100644
index 097c63e..0000000
--- a/server/database/setup.go
+++ /dev/null
@@ -1,109 +0,0 @@
-package database
-
-import (
- "bytes"
- "context"
- "encoding/gob"
- "fmt"
- "time"
-
- "github.com/akrylysov/pogreb"
- "github.com/akrylysov/pogreb/fs"
- "github.com/go-acme/lego/v4/certificate"
- "github.com/rs/zerolog/log"
-)
-
-var _ CertDB = aDB{}
-
-type aDB struct {
- ctx context.Context
- cancel context.CancelFunc
- intern *pogreb.DB
- syncInterval time.Duration
-}
-
-func (p aDB) Close() error {
- p.cancel()
- return p.intern.Sync()
-}
-
-func (p aDB) Put(name string, cert *certificate.Resource) error {
- var resGob bytes.Buffer
- if err := gob.NewEncoder(&resGob).Encode(cert); err != nil {
- return err
- }
- return p.intern.Put([]byte(name), resGob.Bytes())
-}
-
-func (p aDB) Get(name string) (*certificate.Resource, error) {
- cert := &certificate.Resource{}
- resBytes, err := p.intern.Get([]byte(name))
- if err != nil {
- return nil, err
- }
- if resBytes == nil {
- return nil, nil
- }
- if err := gob.NewDecoder(bytes.NewBuffer(resBytes)).Decode(cert); err != nil {
- return nil, err
- }
- return cert, nil
-}
-
-func (p aDB) Delete(key string) error {
- return p.intern.Delete([]byte(key))
-}
-
-func (p aDB) Compact() (string, error) {
- result, err := p.intern.Compact()
- if err != nil {
- return "", err
- }
- return fmt.Sprintf("%+v", result), nil
-}
-
-func (p aDB) Items() *pogreb.ItemIterator {
- return p.intern.Items()
-}
-
-var _ CertDB = &aDB{}
-
-func (p aDB) sync() {
- for {
- err := p.intern.Sync()
- if err != nil {
- log.Error().Err(err).Msg("Syncing cert database failed")
- }
- select {
- case <-p.ctx.Done():
- return
- case <-time.After(p.syncInterval):
- }
- }
-}
-
-func New(path string) (CertDB, error) {
- if path == "" {
- return nil, fmt.Errorf("path not set")
- }
- db, err := pogreb.Open(path, &pogreb.Options{
- BackgroundSyncInterval: 30 * time.Second,
- BackgroundCompactionInterval: 6 * time.Hour,
- FileSystem: fs.OSMMap,
- })
- if err != nil {
- return nil, err
- }
-
- ctx, cancel := context.WithCancel(context.Background())
- result := &aDB{
- ctx: ctx,
- cancel: cancel,
- intern: db,
- syncInterval: 5 * time.Minute,
- }
-
- go result.sync()
-
- return result, nil
-}
diff --git a/server/database/xorm.go b/server/database/xorm.go
new file mode 100644
index 0000000..63fa39e
--- /dev/null
+++ b/server/database/xorm.go
@@ -0,0 +1,138 @@
+package database
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/rs/zerolog/log"
+
+ "github.com/go-acme/lego/v4/certificate"
+ "xorm.io/xorm"
+
+ // register sql driver
+ _ "github.com/go-sql-driver/mysql"
+ _ "github.com/lib/pq"
+ _ "github.com/mattn/go-sqlite3"
+)
+
+var _ CertDB = xDB{}
+
+var ErrNotFound = errors.New("entry not found")
+
+type xDB struct {
+ engine *xorm.Engine
+}
+
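+// NewXormDB opens the given SQL database (sqlite3, mysql or postgres) and syncs the Cert table schema.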
+func NewXormDB(dbType, dbConn string) (CertDB, error) {
+ if !supportedDriver(dbType) {
+ return nil, fmt.Errorf("not supported db type '%s'", dbType)
+ }
+ if dbConn == "" {
+ return nil, fmt.Errorf("no db connection provided")
+ }
+
+ e, err := xorm.NewEngine(dbType, dbConn)
+ if err != nil {
+ return nil, err
+ }
+
+ if err := e.Sync2(new(Cert)); err != nil {
+ return nil, fmt.Errorf("could not sync db model :%w", err)
+ }
+
+ return &xDB{
+ engine: e,
+ }, nil
+}
+
+func (x xDB) Close() error {
+ return x.engine.Close()
+}
+
+func (x xDB) Put(domain string, cert *certificate.Resource) error {
+ log.Trace().Str("domain", cert.Domain).Msg("inserting cert to db")
+
+ c, err := toCert(domain, cert)
+ if err != nil {
+ return err
+ }
+
+ sess := x.engine.NewSession()
+ if err := sess.Begin(); err != nil {
+ return err
+ }
+ defer sess.Close()
+
+ if exist, _ := sess.ID(c.Domain).Exist(new(Cert)); exist {
+ if _, err := sess.ID(c.Domain).Update(c); err != nil {
+ return err
+ }
+ } else {
+ if _, err = sess.Insert(c); err != nil {
+ return err
+ }
+ }
+
+ return sess.Commit()
+}
+
+func (x xDB) Get(domain string) (*certificate.Resource, error) {
+ // handle wildcard certs
+ if domain[:1] == "." {
+ domain = "*" + domain
+ }
+
+ cert := new(Cert)
+ log.Trace().Str("domain", domain).Msg("get cert from db")
+ if found, err := x.engine.ID(domain).Get(cert); err != nil {
+ return nil, err
+ } else if !found {
+ return nil, fmt.Errorf("%w: name='%s'", ErrNotFound, domain)
+ }
+ return cert.Raw(), nil
+}
+
+func (x xDB) Delete(domain string) error {
+ // handle wildcard certs
+ if domain[:1] == "." {
+ domain = "*" + domain
+ }
+
+ log.Trace().Str("domain", domain).Msg("delete cert from db")
+ _, err := x.engine.ID(domain).Delete(new(Cert))
+ return err
+}
+
+// Items returns all certs from the db; if pageSize is 0 no limit is applied
+func (x xDB) Items(page, pageSize int) ([]*Cert, error) {
+ // paginated return
+ if pageSize > 0 {
+ certs := make([]*Cert, 0, pageSize)
+		// normalize invalid page numbers to the first page
+		if page < 1 {
+			page = 1
+		}
+ err := x.engine.Limit(pageSize, (page-1)*pageSize).Find(&certs)
+ return certs, err
+ }
+
+ // return all
+ certs := make([]*Cert, 0, 64)
+ err := x.engine.Find(&certs)
+ return certs, err
+}
+
+// Supported database drivers
+const (
+ DriverSqlite = "sqlite3"
+ DriverMysql = "mysql"
+ DriverPostgres = "postgres"
+)
+
+func supportedDriver(driver string) bool {
+ switch driver {
+ case DriverMysql, DriverPostgres, DriverSqlite:
+ return true
+ default:
+ return false
+ }
+}
diff --git a/server/database/xorm_test.go b/server/database/xorm_test.go
new file mode 100644
index 0000000..50d8a7f
--- /dev/null
+++ b/server/database/xorm_test.go
@@ -0,0 +1,92 @@
+package database
+
+import (
+ "errors"
+ "testing"
+
+ "github.com/go-acme/lego/v4/certificate"
+ "github.com/stretchr/testify/assert"
+ "xorm.io/xorm"
+)
+
+func newTestDB(t *testing.T) *xDB {
+ e, err := xorm.NewEngine("sqlite3", ":memory:")
+ assert.NoError(t, err)
+ assert.NoError(t, e.Sync2(new(Cert)))
+ return &xDB{engine: e}
+}
+
+func TestSanitizeWildcardCerts(t *testing.T) {
+ certDB := newTestDB(t)
+
+ _, err := certDB.Get(".not.found")
+ assert.True(t, errors.Is(err, ErrNotFound))
+
+	// TODO: a mismatch between the cert key and the cert domain does not fail hard yet
+ // https://codeberg.org/Codeberg/pages-server/src/commit/d8595cee882e53d7f44f1ddc4ef8a1f7b8f31d8d/server/database/interface.go#L64
+ //
+ // assert.Error(t, certDB.Put(".wildcard.de", &certificate.Resource{
+ // Domain: "*.localhost.mock.directory",
+ // Certificate: localhost_mock_directory_certificate,
+ // }))
+
+ // insert new wildcard cert
+ assert.NoError(t, certDB.Put(".wildcard.de", &certificate.Resource{
+ Domain: "*.wildcard.de",
+ Certificate: localhost_mock_directory_certificate,
+ }))
+
+ // update existing cert
+ assert.NoError(t, certDB.Put(".wildcard.de", &certificate.Resource{
+ Domain: "*.wildcard.de",
+ Certificate: localhost_mock_directory_certificate,
+ }))
+
+ c1, err := certDB.Get(".wildcard.de")
+ assert.NoError(t, err)
+ c2, err := certDB.Get("*.wildcard.de")
+ assert.NoError(t, err)
+ assert.EqualValues(t, c1, c2)
+}
+
+var localhost_mock_directory_certificate = []byte(`-----BEGIN CERTIFICATE-----
+MIIDczCCAlugAwIBAgIIJyBaXHmLk6gwDQYJKoZIhvcNAQELBQAwKDEmMCQGA1UE
+AxMdUGViYmxlIEludGVybWVkaWF0ZSBDQSA0OWE0ZmIwHhcNMjMwMjEwMDEwOTA2
+WhcNMjgwMjEwMDEwOTA2WjAjMSEwHwYDVQQDExhsb2NhbGhvc3QubW9jay5kaXJl
+Y3RvcnkwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDIU/CjzS7t62Gj
+neEMqvP7sn99ULT7AEUzEfWL05fWG2z714qcUg1hXkZLgdVDgmsCpplyddip7+2t
+ZH/9rLPLMqJphzvOL4CF6jDLbeifETtKyjnt9vUZFnnNWcP3tu8lo8iYSl08qsUI
+Pp/hiEriAQzCDjTbR5m9xUPNPYqxzcS4ALzmmCX9Qfc4CuuhMkdv2G4TT7rylWrA
+SCSRPnGjeA7pCByfNrO/uXbxmzl3sMO3k5sqgMkx1QIHEN412V8+vtx88mt2sM6k
+xjzGZWWKXlRq+oufIKX9KPplhsCjMH6E3VNAzgOPYDqXagtUcGmLWghURltO8Mt2
+zwM6OgjjAgMBAAGjgaUwgaIwDgYDVR0PAQH/BAQDAgWgMB0GA1UdJQQWMBQGCCsG
+AQUFBwMBBggrBgEFBQcDAjAMBgNVHRMBAf8EAjAAMB0GA1UdDgQWBBSMQvlJ1755
+sarf8i1KNqj7s5o/aDAfBgNVHSMEGDAWgBTcZcxJMhWdP7MecHCCpNkFURC/YzAj
+BgNVHREEHDAaghhsb2NhbGhvc3QubW9jay5kaXJlY3RvcnkwDQYJKoZIhvcNAQEL
+BQADggEBACcd7TT28OWwzQN2PcH0aG38JX5Wp2iOS/unDCfWjNAztXHW7nBDMxza
+VtyebkJfccexpuVuOsjOX+bww0vtEYIvKX3/GbkhogksBrNkE0sJZtMnZWMR33wa
+YxAy/kJBTmLi02r8fX9ZhwjldStHKBav4USuP7DXZjrgX7LFQhR4LIDrPaYqQRZ8
+ltC3mM9LDQ9rQyIFP5cSBMO3RUAm4I8JyLoOdb/9G2uxjHr7r6eG1g8DmLYSKBsQ
+mWGQDOYgR3cGltDe2yMxM++yHY+b1uhxGOWMrDA1+1k7yI19LL8Ifi2FMovDfu/X
+JxYk1NNNtdctwaYJFenmGQvDaIq1KgE=
+-----END CERTIFICATE-----
+-----BEGIN CERTIFICATE-----
+MIIDUDCCAjigAwIBAgIIKBJ7IIA6W1swDQYJKoZIhvcNAQELBQAwIDEeMBwGA1UE
+AxMVUGViYmxlIFJvb3QgQ0EgNTdmZjE2MCAXDTIzMDIwOTA1MzMxMloYDzIwNTMw
+MjA5MDUzMzEyWjAoMSYwJAYDVQQDEx1QZWJibGUgSW50ZXJtZWRpYXRlIENBIDQ5
+YTRmYjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANOvlqRx8SXQFWo2
+gFCiXxls53eENcyr8+meFyjgnS853eEvplaPxoa2MREKd+ZYxM8EMMfj2XGvR3UI
+aqR5QyLQ9ihuRqvQo4fG91usBHgH+vDbGPdMX8gDmm9HgnmtOVhSKJU+M2jfE1SW
+UuWB9xOa3LMreTXbTNfZEMoXf+GcWZMbx5WPgEga3DvfmV+RsfNvB55eD7YAyZgF
+ZnQ3Dskmnxxlkz0EGgd7rqhFHHNB9jARlL22gITADwoWZidlr3ciM9DISymRKQ0c
+mRN15fQjNWdtuREgJlpXecbYQMGhdTOmFrqdHkveD1o63rGSC4z+s/APV6xIbcRp
+aNpO7L8CAwEAAaOBgzCBgDAOBgNVHQ8BAf8EBAMCAoQwHQYDVR0lBBYwFAYIKwYB
+BQUHAwEGCCsGAQUFBwMCMA8GA1UdEwEB/wQFMAMBAf8wHQYDVR0OBBYEFNxlzEky
+FZ0/sx5wcIKk2QVREL9jMB8GA1UdIwQYMBaAFOqfkm9rebIz4z0SDIKW5edLg5JM
+MA0GCSqGSIb3DQEBCwUAA4IBAQBRG9AHEnyj2fKzVDDbQaKHjAF5jh0gwyHoIeRK
+FkP9mQNSWxhvPWI0tK/E49LopzmVuzSbDd5kZsaii73rAs6f6Rf9W5veo3AFSEad
+stM+Zv0f2vWB38nuvkoCRLXMX+QUeuL65rKxdEpyArBju4L3/PqAZRgMLcrH+ak8
+nvw5RdAq+Km/ZWyJgGikK6cfMmh91YALCDFnoWUWrCjkBaBFKrG59ONV9f0IQX07
+aNfFXFCF5l466xw9dHjw5iaFib10cpY3iq4kyPYIMs6uaewkCtxWKKjiozM4g4w3
+HqwyUyZ52WUJOJ/6G9DJLDtN3fgGR+IAp8BhYd5CqOscnt3h
+-----END CERTIFICATE-----`)
diff --git a/server/dns/dns.go b/server/dns/dns.go
index 818e29a..e29e42c 100644
--- a/server/dns/dns.go
+++ b/server/dns/dns.go
@@ -5,20 +5,26 @@ import (
"strings"
"time"
- "codeberg.org/codeberg/pages/server/cache"
+ "github.com/hashicorp/golang-lru/v2/expirable"
)
-// lookupCacheTimeout specifies the timeout for the DNS lookup cache.
-var lookupCacheTimeout = 15 * time.Minute
+const (
+ lookupCacheValidity = 30 * time.Second
+ defaultPagesRepo = "pages"
+)
+
+// TODO(#316): refactor to not use global variables
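+// lookupCache keeps up to 4096 resolved domains; each entry expires after lookupCacheValidity.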
+var lookupCache *expirable.LRU[string, string] = expirable.NewLRU[string, string](4096, nil, lookupCacheValidity)
// GetTargetFromDNS searches for CNAME or TXT entries on the request domain ending with MainDomainSuffix.
// If everything is fine, it returns the target data.
-func GetTargetFromDNS(domain, mainDomainSuffix string, dnsLookupCache cache.SetGetKey) (targetOwner, targetRepo, targetBranch string) {
+func GetTargetFromDNS(domain, mainDomainSuffix, firstDefaultBranch string) (targetOwner, targetRepo, targetBranch string) {
// Get CNAME or TXT
var cname string
var err error
- if cachedName, ok := dnsLookupCache.Get(domain); ok {
- cname = cachedName.(string)
+
+ if entry, ok := lookupCache.Get(domain); ok {
+ cname = entry
} else {
cname, err = net.LookupCNAME(domain)
cname = strings.TrimSuffix(cname, ".")
@@ -28,7 +34,7 @@ func GetTargetFromDNS(domain, mainDomainSuffix string, dnsLookupCache cache.SetG
names, err := net.LookupTXT(domain)
if err == nil {
for _, name := range names {
- name = strings.TrimSuffix(name, ".")
+ name = strings.TrimSuffix(strings.TrimSpace(name), ".")
if strings.HasSuffix(name, mainDomainSuffix) {
cname = name
break
@@ -36,7 +42,7 @@ func GetTargetFromDNS(domain, mainDomainSuffix string, dnsLookupCache cache.SetG
}
}
}
- _ = dnsLookupCache.Set(domain, cname, lookupCacheTimeout)
+ _ = lookupCache.Add(domain, cname)
}
if cname == "" {
return
@@ -50,10 +56,10 @@ func GetTargetFromDNS(domain, mainDomainSuffix string, dnsLookupCache cache.SetG
targetBranch = cnameParts[len(cnameParts)-3]
}
if targetRepo == "" {
- targetRepo = "pages"
+ targetRepo = defaultPagesRepo
}
- if targetBranch == "" && targetRepo != "pages" {
- targetBranch = "pages"
+ if targetBranch == "" && targetRepo != defaultPagesRepo {
+ targetBranch = firstDefaultBranch
}
// if targetBranch is still empty, the caller must find the default branch
return
diff --git a/server/gitea/cache.go b/server/gitea/cache.go
index 85cbcde..03f40a9 100644
--- a/server/gitea/cache.go
+++ b/server/gitea/cache.go
@@ -2,14 +2,17 @@ package gitea
import (
"bytes"
+ "encoding/json"
"fmt"
"io"
"net/http"
"time"
+ "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"codeberg.org/codeberg/pages/server/cache"
+ "codeberg.org/codeberg/pages/server/context"
)
const (
@@ -26,23 +29,27 @@ const (
// TODO: move as option into cache interface
fileCacheTimeout = 5 * time.Minute
+	// ownerExistenceCacheTimeout specifies how long the existence of an owner (user or org) is cached
+ ownerExistenceCacheTimeout = 5 * time.Minute
+
// fileCacheSizeLimit limits the maximum file size that will be cached, and is set to 1 MB by default.
fileCacheSizeLimit = int64(1000 * 1000)
)
type FileResponse struct {
- Exists bool
- IsSymlink bool
- ETag string
- MimeType string
- Body []byte
+ Exists bool `json:"exists"`
+ IsSymlink bool `json:"isSymlink"`
+ ETag string `json:"eTag"`
+ MimeType string `json:"mimeType"` // uncompressed MIME type
+ RawMime string `json:"rawMime"` // raw MIME type (if compressed, type of compression)
+ Body []byte `json:"-"` // saved separately
}
func (f FileResponse) IsEmpty() bool {
- return len(f.Body) != 0
+ return len(f.Body) == 0
}
-func (f FileResponse) createHttpResponse(cacheKey string) (header http.Header, statusCode int) {
+func (f FileResponse) createHttpResponse(cacheKey string, decompress bool) (header http.Header, statusCode int) {
header = make(http.Header)
if f.Exists {
@@ -55,7 +62,13 @@ func (f FileResponse) createHttpResponse(cacheKey string) (header http.Header, s
header.Set(giteaObjectTypeHeader, objTypeSymlink)
}
header.Set(ETagHeader, f.ETag)
- header.Set(ContentTypeHeader, f.MimeType)
+
+ if decompress {
+ header.Set(ContentTypeHeader, f.MimeType)
+ } else {
+ header.Set(ContentTypeHeader, f.RawMime)
+ }
+
header.Set(ContentLengthHeader, fmt.Sprintf("%d", len(f.Body)))
header.Set(PagesCacheIndicatorHeader, "true")
@@ -64,42 +77,67 @@ func (f FileResponse) createHttpResponse(cacheKey string) (header http.Header, s
}
type BranchTimestamp struct {
- Branch string
- Timestamp time.Time
- notFound bool
+ NotFound bool `json:"notFound"`
+ Branch string `json:"branch,omitempty"`
+ Timestamp time.Time `json:"timestamp,omitempty"`
}
type writeCacheReader struct {
originalReader io.ReadCloser
buffer *bytes.Buffer
- rileResponse *FileResponse
+ fileResponse *FileResponse
cacheKey string
- cache cache.SetGetKey
+ cache cache.ICache
hasError bool
+ doNotCache bool
+ complete bool
+ log zerolog.Logger
}
func (t *writeCacheReader) Read(p []byte) (n int, err error) {
+ t.log.Trace().Msgf("[cache] read %q", t.cacheKey)
n, err = t.originalReader.Read(p)
- if err != nil {
- log.Trace().Err(err).Msgf("[cache] original reader for %q has returned an error", t.cacheKey)
+ if err == io.EOF {
+ t.complete = true
+ }
+ if err != nil && err != io.EOF {
+ t.log.Trace().Err(err).Msgf("[cache] original reader for %q has returned an error", t.cacheKey)
t.hasError = true
} else if n > 0 {
- _, _ = t.buffer.Write(p[:n])
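+		// stop caching this response once the buffered body would exceed fileCacheSizeLimit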
+ if t.buffer.Len()+n > int(fileCacheSizeLimit) {
+ t.doNotCache = true
+ t.buffer.Reset()
+ } else {
+ _, _ = t.buffer.Write(p[:n])
+ }
}
return
}
func (t *writeCacheReader) Close() error {
- if !t.hasError {
- fc := *t.rileResponse
- fc.Body = t.buffer.Bytes()
- _ = t.cache.Set(t.cacheKey, fc, fileCacheTimeout)
+ doWrite := !t.hasError && !t.doNotCache && t.complete
+ fc := *t.fileResponse
+ fc.Body = t.buffer.Bytes()
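+	// metadata (JSON, without the body) and the raw body bytes are cached under separate "|Metadata" and "|Body" keys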
+ if doWrite {
+ jsonToCache, err := json.Marshal(fc)
+ if err != nil {
+ t.log.Trace().Err(err).Msgf("[cache] marshaling json for %q has returned an error", t.cacheKey+"|Metadata")
+ }
+ err = t.cache.Set(t.cacheKey+"|Metadata", jsonToCache, fileCacheTimeout)
+ if err != nil {
+ t.log.Trace().Err(err).Msgf("[cache] writer for %q has returned an error", t.cacheKey+"|Metadata")
+ }
+ err = t.cache.Set(t.cacheKey+"|Body", fc.Body, fileCacheTimeout)
+ if err != nil {
+ t.log.Trace().Err(err).Msgf("[cache] writer for %q has returned an error", t.cacheKey+"|Body")
+ }
}
- log.Trace().Msgf("cacheReader for %q saved=%t closed", t.cacheKey, !t.hasError)
+ t.log.Trace().Msgf("cacheReader for %q saved=%t closed", t.cacheKey, doWrite)
return t.originalReader.Close()
}
-func (f FileResponse) CreateCacheReader(r io.ReadCloser, cache cache.SetGetKey, cacheKey string) io.ReadCloser {
+func (f FileResponse) CreateCacheReader(ctx *context.Context, r io.ReadCloser, cache cache.ICache, cacheKey string) io.ReadCloser {
+ log := log.With().Str("ReqId", ctx.ReqId).Logger()
if r == nil || cache == nil || cacheKey == "" {
log.Error().Msg("could not create CacheReader")
return nil
@@ -108,8 +146,9 @@ func (f FileResponse) CreateCacheReader(r io.ReadCloser, cache cache.SetGetKey,
return &writeCacheReader{
originalReader: r,
buffer: bytes.NewBuffer(make([]byte, 0)),
- rileResponse: &f,
+ fileResponse: &f,
cache: cache,
cacheKey: cacheKey,
+ log: log,
}
}
diff --git a/server/gitea/client.go b/server/gitea/client.go
index 51647ba..5633bf2 100644
--- a/server/gitea/client.go
+++ b/server/gitea/client.go
@@ -2,6 +2,7 @@ package gitea
import (
"bytes"
+ "encoding/json"
"errors"
"fmt"
"io"
@@ -16,16 +17,20 @@ import (
"code.gitea.io/sdk/gitea"
"github.com/rs/zerolog/log"
+ "codeberg.org/codeberg/pages/config"
"codeberg.org/codeberg/pages/server/cache"
+ "codeberg.org/codeberg/pages/server/context"
+ "codeberg.org/codeberg/pages/server/version"
)
var ErrorNotFound = errors.New("not found")
const (
- // cache key prefixe
+ // cache key prefixes
branchTimestampCacheKeyPrefix = "branchTime"
defaultBranchCacheKeyPrefix = "defaultBranch"
rawContentCacheKeyPrefix = "rawContent"
+ ownerExistenceKeyPrefix = "ownerExist"
// pages server
PagesCacheIndicatorHeader = "X-Pages-Cache"
@@ -36,14 +41,16 @@ const (
objTypeSymlink = "symlink"
// std
- ETagHeader = "ETag"
- ContentTypeHeader = "Content-Type"
- ContentLengthHeader = "Content-Length"
+ ETagHeader = "ETag"
+ ContentTypeHeader = "Content-Type"
+ ContentLengthHeader = "Content-Length"
+ ContentEncodingHeader = "Content-Encoding"
)
type Client struct {
sdkClient *gitea.Client
- responseCache cache.SetGetKey
+ sdkFileClient *gitea.Client
+ responseCache cache.ICache
giteaRoot string
@@ -54,37 +61,50 @@ type Client struct {
defaultMimeType string
}
-func NewClient(giteaRoot, giteaAPIToken string, respCache cache.SetGetKey, followSymlinks, supportLFS bool) (*Client, error) {
- rootURL, err := url.Parse(giteaRoot)
+func NewClient(cfg config.ForgeConfig, respCache cache.ICache) (*Client, error) {
+	// url.Parse accepts almost anything as valid, so use ParseRequestURI for stricter validation
+ rootURL, err := url.ParseRequestURI(cfg.Root)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("invalid forgejo/gitea root url: %w", err)
}
- giteaRoot = strings.Trim(rootURL.String(), "/")
+ giteaRoot := strings.TrimSuffix(rootURL.String(), "/")
- stdClient := http.Client{Timeout: 10 * time.Second}
-
- // TODO: pass down
- var (
- forbiddenMimeTypes map[string]bool
- defaultMimeType string
- )
-
- if forbiddenMimeTypes == nil {
- forbiddenMimeTypes = make(map[string]bool)
+ forbiddenMimeTypes := make(map[string]bool, len(cfg.ForbiddenMimeTypes))
+ for _, mimeType := range cfg.ForbiddenMimeTypes {
+ forbiddenMimeTypes[mimeType] = true
}
+
+ defaultMimeType := cfg.DefaultMimeType
if defaultMimeType == "" {
defaultMimeType = "application/octet-stream"
}
- sdk, err := gitea.NewClient(giteaRoot, gitea.SetHTTPClient(&stdClient), gitea.SetToken(giteaAPIToken))
+ sdkClient, err := gitea.NewClient(
+ giteaRoot,
+ gitea.SetHTTPClient(&http.Client{Timeout: 10 * time.Second}),
+ gitea.SetToken(cfg.Token),
+ gitea.SetUserAgent("pages-server/"+version.Version),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ sdkFileClient, err := gitea.NewClient(
+ giteaRoot,
+ gitea.SetHTTPClient(&http.Client{Timeout: 1 * time.Hour}),
+ gitea.SetToken(cfg.Token),
+ gitea.SetUserAgent("pages-server/"+version.Version),
+	)
+	if err != nil {
+		return nil, err
+	}
+
return &Client{
- sdkClient: sdk,
+ sdkClient: sdkClient,
+ sdkFileClient: sdkFileClient,
responseCache: respCache,
giteaRoot: giteaRoot,
- followSymlinks: followSymlinks,
- supportLFS: supportLFS,
+ followSymlinks: cfg.FollowSymlinks,
+ supportLFS: cfg.LFSEnabled,
forbiddenMimeTypes: forbiddenMimeTypes,
defaultMimeType: defaultMimeType,
@@ -95,8 +115,8 @@ func (client *Client) ContentWebLink(targetOwner, targetRepo, branch, resource s
return path.Join(client.giteaRoot, targetOwner, targetRepo, "src/branch", branch, resource)
}
-func (client *Client) GiteaRawContent(targetOwner, targetRepo, ref, resource string) ([]byte, error) {
- reader, _, _, err := client.ServeRawContent(targetOwner, targetRepo, ref, resource)
+func (client *Client) GiteaRawContent(ctx *context.Context, targetOwner, targetRepo, ref, resource string) ([]byte, error) {
+ reader, _, _, err := client.ServeRawContent(ctx, targetOwner, targetRepo, ref, resource, false)
if err != nil {
return nil, err
}
@@ -104,31 +124,47 @@ func (client *Client) GiteaRawContent(targetOwner, targetRepo, ref, resource str
return io.ReadAll(reader)
}
-func (client *Client) ServeRawContent(targetOwner, targetRepo, ref, resource string) (io.ReadCloser, http.Header, int, error) {
+func (client *Client) ServeRawContent(ctx *context.Context, targetOwner, targetRepo, ref, resource string, decompress bool) (io.ReadCloser, http.Header, int, error) {
cacheKey := fmt.Sprintf("%s/%s/%s|%s|%s", rawContentCacheKeyPrefix, targetOwner, targetRepo, ref, resource)
- log := log.With().Str("cache_key", cacheKey).Logger()
-
+ log := log.With().Str("ReqId", ctx.ReqId).Str("cache_key", cacheKey).Logger()
+ log.Trace().Msg("try file in cache")
// handle if cache entry exist
- if cache, ok := client.responseCache.Get(cacheKey); ok {
- cache := cache.(FileResponse)
- cachedHeader, cachedStatusCode := cache.createHttpResponse(cacheKey)
- // TODO: check against some timestamp missmatch?!?
+ if cacheMetadata, ok := client.responseCache.Get(cacheKey + "|Metadata"); ok {
+ var cache FileResponse
+ err := json.Unmarshal(cacheMetadata.([]byte), &cache)
+ if err != nil {
+ log.Error().Err(err).Msgf("[cache] failed to unmarshal metadata for: %s", cacheKey)
+ return nil, nil, http.StatusNotFound, err
+ }
+
+ if !cache.Exists {
+ return nil, nil, http.StatusNotFound, ErrorNotFound
+ }
+
+ body, ok := client.responseCache.Get(cacheKey + "|Body")
+ if !ok {
+ log.Error().Msgf("[cache] failed to get body for: %s", cacheKey)
+ return nil, nil, http.StatusNotFound, ErrorNotFound
+ }
+ cache.Body = body.([]byte)
+
+ cachedHeader, cachedStatusCode := cache.createHttpResponse(cacheKey, decompress)
if cache.Exists {
if cache.IsSymlink {
linkDest := string(cache.Body)
log.Debug().Msgf("[cache] follow symlink from %q to %q", resource, linkDest)
- return client.ServeRawContent(targetOwner, targetRepo, ref, linkDest)
+ return client.ServeRawContent(ctx, targetOwner, targetRepo, ref, linkDest, decompress)
} else {
- log.Debug().Msg("[cache] return bytes")
+ log.Debug().Msgf("[cache] return %d bytes", len(cache.Body))
return io.NopCloser(bytes.NewReader(cache.Body)), cachedHeader, cachedStatusCode, nil
}
} else {
- return nil, cachedHeader, cachedStatusCode, ErrorNotFound
+ return nil, nil, http.StatusNotFound, ErrorNotFound
}
}
-
+ log.Trace().Msg("file not in cache")
// not in cache, open reader via gitea api
- reader, resp, err := client.sdkClient.GetFileReader(targetOwner, targetRepo, ref, resource, client.supportLFS)
+ reader, resp, err := client.sdkFileClient.GetFileReader(targetOwner, targetRepo, ref, resource, client.supportLFS)
if resp != nil {
switch resp.StatusCode {
case http.StatusOK:
@@ -145,42 +181,58 @@ func (client *Client) ServeRawContent(targetOwner, targetRepo, ref, resource str
}
linkDest := strings.TrimSpace(string(linkDestBytes))
+ // handle relative links
+ // we first remove the link from the path, and make a relative join (resolve parent paths like "/../" too)
+			// drop the file name from the resource path and join the link target relative to its directory (parent segments like "../" are resolved as well)
+
// we store symlink not content to reduce duplicates in cache
- if err := client.responseCache.Set(cacheKey, FileResponse{
+ fileResponse := FileResponse{
Exists: true,
IsSymlink: true,
Body: []byte(linkDest),
ETag: resp.Header.Get(ETagHeader),
- }, fileCacheTimeout); err != nil {
+ }
+ log.Trace().Msgf("file response has %d bytes", len(fileResponse.Body))
+ jsonToCache, err := json.Marshal(fileResponse)
+ if err != nil {
+ log.Error().Err(err).Msgf("[cache] marshaling json metadata for %q has returned an error", cacheKey)
+ }
+ if err := client.responseCache.Set(cacheKey+"|Metadata", jsonToCache, fileCacheTimeout); err != nil {
+ log.Error().Err(err).Msg("[cache] error on cache write")
+ }
+ if err := client.responseCache.Set(cacheKey+"|Body", fileResponse.Body, fileCacheTimeout); err != nil {
log.Error().Err(err).Msg("[cache] error on cache write")
}
log.Debug().Msgf("follow symlink from %q to %q", resource, linkDest)
- return client.ServeRawContent(targetOwner, targetRepo, ref, linkDest)
+ return client.ServeRawContent(ctx, targetOwner, targetRepo, ref, linkDest, decompress)
}
}
// now we are sure it's content so set the MIME type
- mimeType := client.getMimeTypeByExtension(resource)
+ mimeType, rawType := client.getMimeTypeByExtension(resource)
resp.Response.Header.Set(ContentTypeHeader, mimeType)
-
- if !shouldRespBeSavedToCache(resp.Response) {
- return reader, resp.Response.Header, resp.StatusCode, err
+ if decompress {
+ resp.Response.Header.Set(ContentTypeHeader, mimeType)
+ } else {
+ resp.Response.Header.Set(ContentTypeHeader, rawType)
}
- // now we write to cache and respond at the sime time
+ // now we write to cache and respond at the same time
fileResp := FileResponse{
Exists: true,
ETag: resp.Header.Get(ETagHeader),
MimeType: mimeType,
+ RawMime: rawType,
}
- return fileResp.CreateCacheReader(reader, client.responseCache, cacheKey), resp.Response.Header, resp.StatusCode, nil
+ return fileResp.CreateCacheReader(ctx, reader, client.responseCache, cacheKey), resp.Response.Header, resp.StatusCode, nil
case http.StatusNotFound:
- if err := client.responseCache.Set(cacheKey, FileResponse{
- Exists: false,
- ETag: resp.Header.Get(ETagHeader),
- }, fileCacheTimeout); err != nil {
+ jsonToCache, err := json.Marshal(FileResponse{ETag: resp.Header.Get(ETagHeader)})
+ if err != nil {
+ log.Error().Err(err).Msgf("[cache] marshaling json metadata for %q has returned an error", cacheKey)
+ }
+ if err := client.responseCache.Set(cacheKey+"|Metadata", jsonToCache, fileCacheTimeout); err != nil {
log.Error().Err(err).Msg("[cache] error on cache write")
}
@@ -195,21 +247,36 @@ func (client *Client) ServeRawContent(targetOwner, targetRepo, ref, resource str
func (client *Client) GiteaGetRepoBranchTimestamp(repoOwner, repoName, branchName string) (*BranchTimestamp, error) {
cacheKey := fmt.Sprintf("%s/%s/%s/%s", branchTimestampCacheKeyPrefix, repoOwner, repoName, branchName)
- if stamp, ok := client.responseCache.Get(cacheKey); ok && stamp != nil {
- branchTimeStamp := stamp.(*BranchTimestamp)
- if branchTimeStamp.notFound {
- log.Trace().Msgf("[cache] use branch %q not found", branchName)
+ if stampRaw, ok := client.responseCache.Get(cacheKey); ok {
+ var stamp BranchTimestamp
+ err := json.Unmarshal(stampRaw.([]byte), &stamp)
+ if err != nil {
+ log.Error().Err(err).Bytes("stamp", stampRaw.([]byte)).Msgf("[cache] failed to unmarshal timestamp for: %s", cacheKey)
return &BranchTimestamp{}, ErrorNotFound
}
- log.Trace().Msgf("[cache] use branch %q exist", branchName)
- return branchTimeStamp, nil
+
+ if stamp.NotFound {
+ log.Trace().Msgf("[cache] branch %q does not exist", branchName)
+
+ return &BranchTimestamp{}, ErrorNotFound
+ } else {
+			log.Trace().Msgf("[cache] branch %q exists", branchName)
+ // This comes from the refactoring of the caching library.
+ // The branch as reported by the API was stored in the cache, and I'm not sure if there are
+ // situations where it differs from the name in the request, hence this is left here.
+ return &stamp, nil
+ }
}
branch, resp, err := client.sdkClient.GetRepoBranch(repoOwner, repoName, branchName)
if err != nil {
if resp != nil && resp.StatusCode == http.StatusNotFound {
log.Trace().Msgf("[cache] set cache branch %q not found", branchName)
- if err := client.responseCache.Set(cacheKey, &BranchTimestamp{Branch: branchName, notFound: true}, branchExistenceCacheTimeout); err != nil {
+ jsonToCache, err := json.Marshal(BranchTimestamp{NotFound: true})
+ if err != nil {
+ log.Error().Err(err).Msgf("[cache] marshaling empty timestamp for '%s' has returned an error", cacheKey)
+ }
+ if err := client.responseCache.Set(cacheKey, jsonToCache, branchExistenceCacheTimeout); err != nil {
log.Error().Err(err).Msg("[cache] error on cache write")
}
return &BranchTimestamp{}, ErrorNotFound
@@ -226,7 +293,11 @@ func (client *Client) GiteaGetRepoBranchTimestamp(repoOwner, repoName, branchNam
}
log.Trace().Msgf("set cache branch [%s] exist", branchName)
- if err := client.responseCache.Set(cacheKey, stamp, branchExistenceCacheTimeout); err != nil {
+ jsonToCache, err := json.Marshal(stamp)
+ if err != nil {
+ log.Error().Err(err).Msgf("[cache] marshaling timestamp for %q has returned an error", cacheKey)
+ }
+ if err := client.responseCache.Set(cacheKey, jsonToCache, branchExistenceCacheTimeout); err != nil {
log.Error().Err(err).Msg("[cache] error on cache write")
}
return stamp, nil
@@ -235,8 +306,8 @@ func (client *Client) GiteaGetRepoBranchTimestamp(repoOwner, repoName, branchNam
func (client *Client) GiteaGetRepoDefaultBranch(repoOwner, repoName string) (string, error) {
cacheKey := fmt.Sprintf("%s/%s/%s", defaultBranchCacheKeyPrefix, repoOwner, repoName)
- if branch, ok := client.responseCache.Get(cacheKey); ok && branch != nil {
- return branch.(string), nil
+ if branch, ok := client.responseCache.Get(cacheKey); ok {
+ return string(branch.([]byte)), nil
}
repo, resp, err := client.sdkClient.GetRepo(repoOwner, repoName)
@@ -248,37 +319,68 @@ func (client *Client) GiteaGetRepoDefaultBranch(repoOwner, repoName string) (str
}
branch := repo.DefaultBranch
- if err := client.responseCache.Set(cacheKey, branch, defaultBranchCacheTimeout); err != nil {
+ if err := client.responseCache.Set(cacheKey, []byte(branch), defaultBranchCacheTimeout); err != nil {
log.Error().Err(err).Msg("[cache] error on cache write")
}
return branch, nil
}
-func (client *Client) getMimeTypeByExtension(resource string) string {
- mimeType := mime.TypeByExtension(path.Ext(resource))
+func (client *Client) GiteaCheckIfOwnerExists(owner string) (bool, error) {
+ cacheKey := fmt.Sprintf("%s/%s", ownerExistenceKeyPrefix, owner)
+
+ if existRaw, ok := client.responseCache.Get(cacheKey); ok && existRaw != nil {
+ exist, err := strconv.ParseBool(existRaw.(string))
+ return exist, err
+ }
+
+ _, resp, err := client.sdkClient.GetUserInfo(owner)
+ if resp.StatusCode == http.StatusOK && err == nil {
+ if err := client.responseCache.Set(cacheKey, []byte("true"), ownerExistenceCacheTimeout); err != nil {
+ log.Error().Err(err).Msg("[cache] error on cache write")
+ }
+ return true, nil
+ } else if resp.StatusCode != http.StatusNotFound {
+ return false, err
+ }
+
+ _, resp, err = client.sdkClient.GetOrg(owner)
+ if resp.StatusCode == http.StatusOK && err == nil {
+ if err := client.responseCache.Set(cacheKey, []byte("true"), ownerExistenceCacheTimeout); err != nil {
+ log.Error().Err(err).Msg("[cache] error on cache write")
+ }
+ return true, nil
+ } else if resp.StatusCode != http.StatusNotFound {
+ return false, err
+ }
+ if err := client.responseCache.Set(cacheKey, []byte("false"), ownerExistenceCacheTimeout); err != nil {
+ log.Error().Err(err).Msg("[cache] error on cache write")
+ }
+ return false, nil
+}
+
+func (client *Client) extToMime(ext string) string {
+ mimeType := mime.TypeByExtension(path.Ext(ext))
mimeTypeSplit := strings.SplitN(mimeType, ";", 2)
if client.forbiddenMimeTypes[mimeTypeSplit[0]] || mimeType == "" {
mimeType = client.defaultMimeType
}
- log.Trace().Msgf("probe mime of %q is %q", resource, mimeType)
+ log.Trace().Msgf("probe mime of extension '%q' is '%q'", ext, mimeType)
+
return mimeType
}
-func shouldRespBeSavedToCache(resp *http.Response) bool {
- if resp == nil {
- return false
+func (client *Client) getMimeTypeByExtension(resource string) (mimeType, rawType string) {
+ rawExt := path.Ext(resource)
+ innerExt := rawExt
+ switch rawExt {
+ case ".gz", ".br", ".zst":
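+		// e.g. for "app.js.gz": rawExt is ".gz" and innerExt becomes ".js"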
+ innerExt = path.Ext(resource[:len(resource)-len(rawExt)])
}
-
- contentLengthRaw := resp.Header.Get(ContentLengthHeader)
- if contentLengthRaw == "" {
- return false
+ rawType = client.extToMime(rawExt)
+ mimeType = rawType
+ if innerExt != rawExt {
+ mimeType = client.extToMime(innerExt)
}
-
- contentLeng, err := strconv.ParseInt(contentLengthRaw, 10, 64)
- if err != nil {
- log.Error().Err(err).Msg("could not parse content length")
- }
-
- // if content to big or could not be determined we not cache it
- return contentLeng > 0 && contentLeng < fileCacheSizeLimit
+ log.Trace().Msgf("probe mime of %q is (%q / raw %q)", resource, mimeType, rawType)
+ return mimeType, rawType
}
diff --git a/server/handler/handler.go b/server/handler/handler.go
index 78301e9..437697a 100644
--- a/server/handler/handler.go
+++ b/server/handler/handler.go
@@ -6,32 +6,31 @@ import (
"github.com/rs/zerolog/log"
+ "codeberg.org/codeberg/pages/config"
"codeberg.org/codeberg/pages/html"
"codeberg.org/codeberg/pages/server/cache"
"codeberg.org/codeberg/pages/server/context"
"codeberg.org/codeberg/pages/server/gitea"
- "codeberg.org/codeberg/pages/server/version"
)
const (
headerAccessControlAllowOrigin = "Access-Control-Allow-Origin"
headerAccessControlAllowMethods = "Access-Control-Allow-Methods"
defaultPagesRepo = "pages"
- defaultPagesBranch = "pages"
)
// Handler handles a single HTTP request to the web server.
-func Handler(mainDomainSuffix, rawDomain string,
+func Handler(
+ cfg config.ServerConfig,
giteaClient *gitea.Client,
- rawInfoPage string,
- blacklistedPaths, allowedCorsDomains []string,
- dnsLookupCache, canonicalDomainCache cache.SetGetKey,
+ canonicalDomainCache, redirectsCache cache.ICache,
) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
- log := log.With().Strs("Handler", []string{req.Host, req.RequestURI}).Logger()
ctx := context.New(w, req)
+ log := log.With().Str("ReqId", ctx.ReqId).Strs("Handler", []string{req.Host, req.RequestURI}).Logger()
+ log.Debug().Msg("\n----------------------------------------------------------")
- ctx.RespWriter.Header().Set("Server", "CodebergPages/"+version.Version)
+ ctx.RespWriter.Header().Set("Server", "pages-server")
// Force new default from specification (since November 2020) - see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy#strict-origin-when-cross-origin
ctx.RespWriter.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
@@ -41,8 +40,8 @@ func Handler(mainDomainSuffix, rawDomain string,
trimmedHost := ctx.TrimHostPort()
- // Add HSTS for RawDomain and MainDomainSuffix
- if hsts := getHSTSHeader(trimmedHost, mainDomainSuffix, rawDomain); hsts != "" {
+ // Add HSTS for RawDomain and MainDomain
+ if hsts := getHSTSHeader(trimmedHost, cfg.MainDomain, cfg.RawDomain); hsts != "" {
ctx.RespWriter.Header().Set("Strict-Transport-Security", hsts)
}
@@ -64,16 +63,16 @@ func Handler(mainDomainSuffix, rawDomain string,
}
// Block blacklisted paths (like ACME challenges)
- for _, blacklistedPath := range blacklistedPaths {
+ for _, blacklistedPath := range cfg.BlacklistedPaths {
if strings.HasPrefix(ctx.Path(), blacklistedPath) {
- html.ReturnErrorPage(ctx, "requested blacklisted path", http.StatusForbidden)
+ html.ReturnErrorPage(ctx, "requested path is blacklisted", http.StatusForbidden)
return
}
}
// Allow CORS for specified domains
allowCors := false
- for _, allowedCorsDomain := range allowedCorsDomains {
+ for _, allowedCorsDomain := range cfg.AllowedCorsDomains {
if strings.EqualFold(trimmedHost, allowedCorsDomain) {
allowCors = true
break
@@ -87,27 +86,29 @@ func Handler(mainDomainSuffix, rawDomain string,
// Prepare request information to Gitea
pathElements := strings.Split(strings.Trim(ctx.Path(), "/"), "/")
- if rawDomain != "" && strings.EqualFold(trimmedHost, rawDomain) {
- log.Debug().Msg("raw domain request detecded")
+ if cfg.RawDomain != "" && strings.EqualFold(trimmedHost, cfg.RawDomain) {
+ log.Debug().Msg("raw domain request detected")
handleRaw(log, ctx, giteaClient,
- mainDomainSuffix, rawInfoPage,
+ cfg.MainDomain,
trimmedHost,
pathElements,
- canonicalDomainCache)
- } else if strings.HasSuffix(trimmedHost, mainDomainSuffix) {
- log.Debug().Msg("subdomain request detecded")
+ canonicalDomainCache, redirectsCache)
+ } else if strings.HasSuffix(trimmedHost, cfg.MainDomain) {
+ log.Debug().Msg("subdomain request detected")
handleSubDomain(log, ctx, giteaClient,
- mainDomainSuffix,
+ cfg.MainDomain,
+ cfg.PagesBranches,
trimmedHost,
pathElements,
- canonicalDomainCache)
+ canonicalDomainCache, redirectsCache)
} else {
- log.Debug().Msg("custom domain request detecded")
+ log.Debug().Msg("custom domain request detected")
handleCustomDomain(log, ctx, giteaClient,
- mainDomainSuffix,
+ cfg.MainDomain,
trimmedHost,
pathElements,
- dnsLookupCache, canonicalDomainCache)
+ cfg.PagesBranches[0],
+ canonicalDomainCache, redirectsCache)
}
}
}
diff --git a/server/handler/handler_custom_domain.go b/server/handler/handler_custom_domain.go
index 2f98085..8a5f9d7 100644
--- a/server/handler/handler_custom_domain.go
+++ b/server/handler/handler_custom_domain.go
@@ -18,10 +18,11 @@ func handleCustomDomain(log zerolog.Logger, ctx *context.Context, giteaClient *g
mainDomainSuffix string,
trimmedHost string,
pathElements []string,
- dnsLookupCache, canonicalDomainCache cache.SetGetKey,
+ firstDefaultBranch string,
+ canonicalDomainCache, redirectsCache cache.ICache,
) {
// Serve pages from custom domains
- targetOwner, targetRepo, targetBranch := dns.GetTargetFromDNS(trimmedHost, mainDomainSuffix, dnsLookupCache)
+ targetOwner, targetRepo, targetBranch := dns.GetTargetFromDNS(trimmedHost, mainDomainSuffix, firstDefaultBranch)
if targetOwner == "" {
html.ReturnErrorPage(ctx,
"could not obtain repo owner from custom domain",
@@ -46,15 +47,15 @@ func handleCustomDomain(log zerolog.Logger, ctx *context.Context, giteaClient *g
TargetBranch: targetBranch,
TargetPath: path.Join(pathParts...),
}, canonicalLink); works {
- canonicalDomain, valid := targetOpt.CheckCanonicalDomain(giteaClient, trimmedHost, mainDomainSuffix, canonicalDomainCache)
+ canonicalDomain, valid := targetOpt.CheckCanonicalDomain(ctx, giteaClient, trimmedHost, mainDomainSuffix, canonicalDomainCache)
if !valid {
			html.ReturnErrorPage(ctx, "domain not specified in <code>.domains</code> file", http.StatusMisdirectedRequest)
return
} else if canonicalDomain != trimmedHost {
// only redirect if the target is also a codeberg page!
- targetOwner, _, _ = dns.GetTargetFromDNS(strings.SplitN(canonicalDomain, "/", 2)[0], mainDomainSuffix, dnsLookupCache)
+ targetOwner, _, _ = dns.GetTargetFromDNS(strings.SplitN(canonicalDomain, "/", 2)[0], mainDomainSuffix, firstDefaultBranch)
if targetOwner != "" {
- ctx.Redirect("https://"+canonicalDomain+targetOpt.TargetPath, http.StatusTemporaryRedirect)
+ ctx.Redirect("https://"+canonicalDomain+"/"+targetOpt.TargetPath, http.StatusTemporaryRedirect)
return
}
@@ -62,8 +63,8 @@ func handleCustomDomain(log zerolog.Logger, ctx *context.Context, giteaClient *g
return
}
- log.Debug().Msg("tryBranch, now trying upstream 7")
- tryUpstream(ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache)
+ log.Debug().Str("url", trimmedHost).Msg("tryBranch, now trying upstream")
+ tryUpstream(log, ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache, redirectsCache)
return
}
diff --git a/server/handler/handler_raw_domain.go b/server/handler/handler_raw_domain.go
index 5e974da..bbbf7da 100644
--- a/server/handler/handler_raw_domain.go
+++ b/server/handler/handler_raw_domain.go
@@ -16,17 +16,21 @@ import (
)
func handleRaw(log zerolog.Logger, ctx *context.Context, giteaClient *gitea.Client,
- mainDomainSuffix, rawInfoPage string,
+ mainDomainSuffix string,
trimmedHost string,
pathElements []string,
- canonicalDomainCache cache.SetGetKey,
+ canonicalDomainCache, redirectsCache cache.ICache,
) {
// Serve raw content from RawDomain
log.Debug().Msg("raw domain")
if len(pathElements) < 2 {
- // https://{RawDomain}/{owner}/{repo}[/@{branch}]/{path} is required
- ctx.Redirect(rawInfoPage, http.StatusTemporaryRedirect)
+ html.ReturnErrorPage(
+ ctx,
+			"a url in the form of <code>https://{domain}/{owner}/{repo}[/@{branch}]/{path}</code> is required",
+ http.StatusBadRequest,
+ )
+
return
}
@@ -41,7 +45,7 @@ func handleRaw(log zerolog.Logger, ctx *context.Context, giteaClient *gitea.Clie
TargetPath: path.Join(pathElements[3:]...),
}, true); works {
log.Trace().Msg("tryUpstream: serve raw domain with specified branch")
- tryUpstream(ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache)
+ tryUpstream(log, ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache, redirectsCache)
return
}
log.Debug().Msg("missing branch info")
@@ -58,10 +62,10 @@ func handleRaw(log zerolog.Logger, ctx *context.Context, giteaClient *gitea.Clie
TargetPath: path.Join(pathElements[2:]...),
}, true); works {
log.Trace().Msg("tryUpstream: serve raw domain with default branch")
- tryUpstream(ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache)
+ tryUpstream(log, ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache, redirectsCache)
} else {
html.ReturnErrorPage(ctx,
- fmt.Sprintf("raw domain could not find repo '%s/%s' or repo is empty", targetOpt.TargetOwner, targetOpt.TargetRepo),
+			fmt.Sprintf("raw domain could not find repo <code>%s/%s</code> or repo is empty", targetOpt.TargetOwner, targetOpt.TargetRepo),
http.StatusNotFound)
}
}
diff --git a/server/handler/handler_sub_domain.go b/server/handler/handler_sub_domain.go
index 2a75e9f..e335019 100644
--- a/server/handler/handler_sub_domain.go
+++ b/server/handler/handler_sub_domain.go
@@ -7,6 +7,7 @@ import (
"strings"
"github.com/rs/zerolog"
+ "golang.org/x/exp/slices"
"codeberg.org/codeberg/pages/html"
"codeberg.org/codeberg/pages/server/cache"
@@ -17,9 +18,10 @@ import (
func handleSubDomain(log zerolog.Logger, ctx *context.Context, giteaClient *gitea.Client,
mainDomainSuffix string,
+ defaultPagesBranches []string,
trimmedHost string,
pathElements []string,
- canonicalDomainCache cache.SetGetKey,
+ canonicalDomainCache, redirectsCache cache.ICache,
) {
// Serve pages from subdomains of MainDomainSuffix
log.Debug().Msg("main domain suffix")
@@ -51,11 +53,13 @@ func handleSubDomain(log zerolog.Logger, ctx *context.Context, giteaClient *gite
TargetPath: path.Join(pathElements[2:]...),
}, true); works {
log.Trace().Msg("tryUpstream: serve with specified repo and branch")
- tryUpstream(ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache)
+ tryUpstream(log, ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache, redirectsCache)
} else {
- html.ReturnErrorPage(ctx,
- fmt.Sprintf("explizite set branch %q do not exist at '%s/%s'", targetOpt.TargetBranch, targetOpt.TargetOwner, targetOpt.TargetRepo),
- http.StatusFailedDependency)
+ html.ReturnErrorPage(
+ ctx,
+ formatSetBranchNotFoundMessage(pathElements[1][1:], targetOwner, pathElements[0]),
+ http.StatusFailedDependency,
+ )
}
return
}
@@ -63,38 +67,66 @@ func handleSubDomain(log zerolog.Logger, ctx *context.Context, giteaClient *gite
// Check if the first directory is a branch for the defaultPagesRepo
// example.codeberg.page/@main/index.html
if strings.HasPrefix(pathElements[0], "@") {
+ targetBranch := pathElements[0][1:]
+
+ // if the default pages branch can be determined exactly, it does not need to be set
+ if len(defaultPagesBranches) == 1 && slices.Contains(defaultPagesBranches, targetBranch) {
+ // example.codeberg.org/@pages/... redirects to example.codeberg.org/...
+ ctx.Redirect("/"+strings.Join(pathElements[1:], "/"), http.StatusTemporaryRedirect)
+ return
+ }
+
log.Debug().Msg("main domain preparations, now trying with specified branch")
if targetOpt, works := tryBranch(log, ctx, giteaClient, &upstream.Options{
TryIndexPages: true,
TargetOwner: targetOwner,
TargetRepo: defaultPagesRepo,
- TargetBranch: pathElements[0][1:],
+ TargetBranch: targetBranch,
TargetPath: path.Join(pathElements[1:]...),
}, true); works {
log.Trace().Msg("tryUpstream: serve default pages repo with specified branch")
- tryUpstream(ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache)
+ tryUpstream(log, ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache, redirectsCache)
} else {
- html.ReturnErrorPage(ctx,
- fmt.Sprintf("explizite set branch %q do not exist at '%s/%s'", targetOpt.TargetBranch, targetOpt.TargetOwner, targetOpt.TargetRepo),
- http.StatusFailedDependency)
+ html.ReturnErrorPage(
+ ctx,
+ formatSetBranchNotFoundMessage(targetBranch, targetOwner, defaultPagesRepo),
+ http.StatusFailedDependency,
+ )
}
return
}
- // Check if the first directory is a repo with a defaultPagesRepo branch
- // example.codeberg.page/myrepo/index.html
- // example.codeberg.page/pages/... is not allowed here.
- log.Debug().Msg("main domain preparations, now trying with specified repo")
- if pathElements[0] != defaultPagesRepo {
+ for _, defaultPagesBranch := range defaultPagesBranches {
+ // Check if the first directory is a repo with a default pages branch
+ // example.codeberg.page/myrepo/index.html
+		// example.codeberg.page/{PAGES_BRANCHES}/... is not allowed here.
+ log.Debug().Msg("main domain preparations, now trying with specified repo")
+ if pathElements[0] != defaultPagesBranch {
+ if targetOpt, works := tryBranch(log, ctx, giteaClient, &upstream.Options{
+ TryIndexPages: true,
+ TargetOwner: targetOwner,
+ TargetRepo: pathElements[0],
+ TargetBranch: defaultPagesBranch,
+ TargetPath: path.Join(pathElements[1:]...),
+ }, false); works {
+ log.Debug().Msg("tryBranch, now trying upstream 5")
+ tryUpstream(log, ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache, redirectsCache)
+ return
+ }
+ }
+
+	// Try to use the defaultPagesRepo on a default pages branch
+ // example.codeberg.page/index.html
+ log.Debug().Msg("main domain preparations, now trying with default repo")
if targetOpt, works := tryBranch(log, ctx, giteaClient, &upstream.Options{
TryIndexPages: true,
TargetOwner: targetOwner,
- TargetRepo: pathElements[0],
+ TargetRepo: defaultPagesRepo,
TargetBranch: defaultPagesBranch,
- TargetPath: path.Join(pathElements[1:]...),
+ TargetPath: path.Join(pathElements...),
}, false); works {
- log.Debug().Msg("tryBranch, now trying upstream 5")
- tryUpstream(ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache)
+ log.Debug().Msg("tryBranch, now trying upstream 6")
+ tryUpstream(log, ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache, redirectsCache)
return
}
}
@@ -109,12 +141,16 @@ func handleSubDomain(log zerolog.Logger, ctx *context.Context, giteaClient *gite
TargetPath: path.Join(pathElements...),
}, false); works {
log.Debug().Msg("tryBranch, now trying upstream 6")
- tryUpstream(ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache)
+ tryUpstream(log, ctx, giteaClient, mainDomainSuffix, trimmedHost, targetOpt, canonicalDomainCache, redirectsCache)
return
}
// Couldn't find a valid repo/branch
html.ReturnErrorPage(ctx,
- fmt.Sprintf("could not find a valid repository[%s]", targetRepo),
+		fmt.Sprintf("could not find a valid repository or branch for repository: <code>%s</code>", targetRepo),
http.StatusNotFound)
}
+
+func formatSetBranchNotFoundMessage(branch, owner, repo string) string {
+	return fmt.Sprintf("explicitly set branch <code>%q</code> does not exist at <code>%s/%s</code>", branch, owner, repo)
+}
diff --git a/server/handler/handler_test.go b/server/handler/handler_test.go
index 626564a..765b3b1 100644
--- a/server/handler/handler_test.go
+++ b/server/handler/handler_test.go
@@ -1,30 +1,39 @@
package handler
import (
+ "net/http"
"net/http/httptest"
"testing"
"time"
+ "codeberg.org/codeberg/pages/config"
"codeberg.org/codeberg/pages/server/cache"
"codeberg.org/codeberg/pages/server/gitea"
"github.com/rs/zerolog/log"
)
func TestHandlerPerformance(t *testing.T) {
- giteaClient, _ := gitea.NewClient("https://codeberg.org", "", cache.NewKeyValueCache(), false, false)
- testHandler := Handler(
- "codeberg.page", "raw.codeberg.org",
- giteaClient,
- "https://docs.codeberg.org/pages/raw-content/",
- []string{"/.well-known/acme-challenge/"},
- []string{"raw.codeberg.org", "fonts.codeberg.org", "design.codeberg.org"},
- cache.NewKeyValueCache(),
- cache.NewKeyValueCache(),
- )
+ cfg := config.ForgeConfig{
+ Root: "https://codeberg.org",
+ Token: "",
+ LFSEnabled: false,
+ FollowSymlinks: false,
+ }
+ giteaClient, _ := gitea.NewClient(cfg, cache.NewInMemoryCache())
+ serverCfg := config.ServerConfig{
+ MainDomain: "codeberg.page",
+ RawDomain: "raw.codeberg.page",
+ BlacklistedPaths: []string{
+ "/.well-known/acme-challenge/",
+ },
+ AllowedCorsDomains: []string{"raw.codeberg.org", "fonts.codeberg.org", "design.codeberg.org"},
+ PagesBranches: []string{"pages"},
+ }
+ testHandler := Handler(serverCfg, giteaClient, cache.NewInMemoryCache(), cache.NewInMemoryCache())
testCase := func(uri string, status int) {
t.Run(uri, func(t *testing.T) {
- req := httptest.NewRequest("GET", uri, nil)
+ req := httptest.NewRequest("GET", uri, http.NoBody)
w := httptest.NewRecorder()
log.Printf("Start: %v\n", time.Now())
diff --git a/server/handler/try.go b/server/handler/try.go
index 5a09b91..e5fc49b 100644
--- a/server/handler/try.go
+++ b/server/handler/try.go
@@ -1,6 +1,7 @@
package handler
import (
+ "fmt"
"net/http"
"strings"
@@ -14,14 +15,15 @@ import (
)
// tryUpstream forwards the target request to the Gitea API, and shows an error page on failure.
-func tryUpstream(ctx *context.Context, giteaClient *gitea.Client,
+func tryUpstream(log zerolog.Logger, ctx *context.Context, giteaClient *gitea.Client,
mainDomainSuffix, trimmedHost string,
options *upstream.Options,
- canonicalDomainCache cache.SetGetKey,
+ canonicalDomainCache cache.ICache,
+ redirectsCache cache.ICache,
) {
// check if a canonical domain exists on a request on MainDomain
- if strings.HasSuffix(trimmedHost, mainDomainSuffix) {
- canonicalDomain, _ := options.CheckCanonicalDomain(giteaClient, "", mainDomainSuffix, canonicalDomainCache)
+ if strings.HasSuffix(trimmedHost, mainDomainSuffix) && !options.ServeRaw {
+ canonicalDomain, _ := options.CheckCanonicalDomain(ctx, giteaClient, "", mainDomainSuffix, canonicalDomainCache)
if !strings.HasSuffix(strings.SplitN(canonicalDomain, "/", 2)[0], mainDomainSuffix) {
canonicalPath := ctx.Req.RequestURI
if options.TargetRepo != defaultPagesRepo {
@@ -30,7 +32,12 @@ func tryUpstream(ctx *context.Context, giteaClient *gitea.Client,
canonicalPath = "/" + path[2]
}
}
- ctx.Redirect("https://"+canonicalDomain+canonicalPath, http.StatusTemporaryRedirect)
+
+ redirect_to := "https://" + canonicalDomain + canonicalPath
+
+ log.Debug().Str("to", redirect_to).Msg("redirecting")
+
+ ctx.Redirect(redirect_to, http.StatusTemporaryRedirect)
return
}
}
@@ -39,8 +46,9 @@ func tryUpstream(ctx *context.Context, giteaClient *gitea.Client,
options.Host = trimmedHost
// Try to request the file from the Gitea API
- if !options.Upstream(ctx, giteaClient) {
- html.ReturnErrorPage(ctx, "", ctx.StatusCode)
+ log.Debug().Msg("requesting from upstream")
+ if !options.Upstream(ctx, giteaClient, redirectsCache) {
+ html.ReturnErrorPage(ctx, fmt.Sprintf("Forge returned %d %s", ctx.StatusCode, http.StatusText(ctx.StatusCode)), ctx.StatusCode)
}
}
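
For illustration of the canonical-domain redirect above (owner, repository and domains are hypothetical): with a main domain suffix of ".codeberg.page", a request for https://alice.codeberg.page/blog/post/ whose repository lists blog.example.com first in its .domains file is answered with a 307 Temporary Redirect to https://blog.example.com/post/; the repository segment is dropped because the target repository is not the default pages repository, and raw-content requests (ServeRaw) skip this check entirely.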
diff --git a/server/profiling.go b/server/profiling.go
new file mode 100644
index 0000000..7d20926
--- /dev/null
+++ b/server/profiling.go
@@ -0,0 +1,21 @@
+package server
+
+import (
+ "net/http"
+ _ "net/http/pprof"
+
+ "github.com/rs/zerolog/log"
+)
+
+func StartProfilingServer(listeningAddress string) {
+ server := &http.Server{
+ Addr: listeningAddress,
+ Handler: http.DefaultServeMux,
+ }
+
+ log.Info().Msgf("Starting debug server on %s", listeningAddress)
+
+ go func() {
+ log.Fatal().Err(server.ListenAndServe()).Msg("Failed to start debug server")
+ }()
+}
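
Because net/http/pprof is imported for its side effects, the debug server above exposes the standard profiling endpoints on http.DefaultServeMux (for example /debug/pprof/heap and /debug/pprof/profile). Assuming the configured profiling address is reachable, a profile can then be inspected with the stock tooling, e.g. `go tool pprof http://localhost:6060/debug/pprof/heap` (the address here is hypothetical and depends on the profiling-address flag).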
diff --git a/server/setup.go b/server/setup.go
deleted file mode 100644
index 282e692..0000000
--- a/server/setup.go
+++ /dev/null
@@ -1,27 +0,0 @@
-package server
-
-import (
- "net/http"
- "strings"
-
- "codeberg.org/codeberg/pages/server/cache"
- "codeberg.org/codeberg/pages/server/context"
- "codeberg.org/codeberg/pages/server/utils"
-)
-
-func SetupHTTPACMEChallengeServer(challengeCache cache.SetGetKey) http.HandlerFunc {
- challengePath := "/.well-known/acme-challenge/"
-
- return func(w http.ResponseWriter, req *http.Request) {
- ctx := context.New(w, req)
- if strings.HasPrefix(ctx.Path(), challengePath) {
- challenge, ok := challengeCache.Get(utils.TrimHostPort(ctx.Host()) + "/" + strings.TrimPrefix(ctx.Path(), challengePath))
- if !ok || challenge == nil {
- ctx.String("no challenge for this token", http.StatusNotFound)
- }
- ctx.String(challenge.(string))
- } else {
- ctx.Redirect("https://"+ctx.Host()+ctx.Path(), http.StatusMovedPermanently)
- }
- }
-}
diff --git a/server/startup.go b/server/startup.go
new file mode 100644
index 0000000..4ae26c1
--- /dev/null
+++ b/server/startup.go
@@ -0,0 +1,145 @@
+package server
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "net"
+ "net/http"
+ "os"
+ "strings"
+ "time"
+
+ "github.com/pires/go-proxyproto"
+ "github.com/rs/zerolog"
+ "github.com/rs/zerolog/log"
+ "github.com/urfave/cli/v2"
+
+ cmd "codeberg.org/codeberg/pages/cli"
+ "codeberg.org/codeberg/pages/config"
+ "codeberg.org/codeberg/pages/server/acme"
+ "codeberg.org/codeberg/pages/server/cache"
+ "codeberg.org/codeberg/pages/server/certificates"
+ "codeberg.org/codeberg/pages/server/gitea"
+ "codeberg.org/codeberg/pages/server/handler"
+)
+
+// Serve sets up and starts the web server.
+func Serve(ctx *cli.Context) error {
+ // initialize logger with Trace, overridden later with actual level
+ log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}).With().Timestamp().Caller().Logger().Level(zerolog.TraceLevel)
+
+ cfg, err := config.ReadConfig(ctx)
+ if err != nil {
+ log.Error().Err(err).Msg("could not read config")
+ }
+
+ config.MergeConfig(ctx, cfg)
+
+ // Initialize the logger.
+ logLevel, err := zerolog.ParseLevel(cfg.LogLevel)
+ if err != nil {
+ return err
+ }
+ fmt.Printf("Setting log level to: %s\n", logLevel)
+ log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}).With().Timestamp().Caller().Logger().Level(logLevel)
+
+ listeningSSLAddress := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
+ listeningHTTPAddress := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.HttpPort)
+
+ if cfg.Server.RawDomain != "" {
+ cfg.Server.AllowedCorsDomains = append(cfg.Server.AllowedCorsDomains, cfg.Server.RawDomain)
+ }
+
+ // Make sure MainDomain has a leading dot
+ if !strings.HasPrefix(cfg.Server.MainDomain, ".") {
+ // TODO make this better
+ cfg.Server.MainDomain = "." + cfg.Server.MainDomain
+ }
+
+ if len(cfg.Server.PagesBranches) == 0 {
+ return fmt.Errorf("no default branches set (PAGES_BRANCHES)")
+ }
+
+ // Init ssl cert database
+ certDB, closeFn, err := cmd.OpenCertDB(ctx)
+ if err != nil {
+ return err
+ }
+ defer closeFn()
+
+ challengeCache := cache.NewInMemoryCache()
+ // canonicalDomainCache stores canonical domains
+ canonicalDomainCache := cache.NewInMemoryCache()
+ // redirectsCache stores redirects in _redirects files
+ redirectsCache := cache.NewInMemoryCache()
+ // clientResponseCache stores responses from the Gitea server
+ clientResponseCache := cache.NewInMemoryCache()
+
+ giteaClient, err := gitea.NewClient(cfg.Forge, clientResponseCache)
+ if err != nil {
+ return fmt.Errorf("could not create new gitea client: %v", err)
+ }
+
+ acmeClient, err := acme.CreateAcmeClient(cfg.ACME, cfg.Server.HttpServerEnabled, challengeCache)
+ if err != nil {
+ return err
+ }
+
+ if err := certificates.SetupMainDomainCertificates(log.Logger, cfg.Server.MainDomain, acmeClient, certDB); err != nil {
+ return err
+ }
+
+ // Create listener for SSL connections
+ log.Info().Msgf("Create TCP listener for SSL on %s", listeningSSLAddress)
+ listener, err := net.Listen("tcp", listeningSSLAddress)
+ if err != nil {
+ return fmt.Errorf("couldn't create listener: %v", err)
+ }
+
+ if cfg.Server.UseProxyProtocol {
+ listener = &proxyproto.Listener{Listener: listener}
+ }
+ // Setup listener for SSL connections
+ listener = tls.NewListener(listener, certificates.TLSConfig(
+ cfg.Server.MainDomain,
+ giteaClient,
+ acmeClient,
+ cfg.Server.PagesBranches[0],
+ challengeCache, canonicalDomainCache,
+ certDB,
+ cfg.ACME.NoDNS01,
+ cfg.Server.RawDomain,
+ ))
+
+ interval := 12 * time.Hour
+ certMaintainCtx, cancelCertMaintain := context.WithCancel(context.Background())
+ defer cancelCertMaintain()
+ go certificates.MaintainCertDB(log.Logger, certMaintainCtx, interval, acmeClient, cfg.Server.MainDomain, certDB)
+
+ if cfg.Server.HttpServerEnabled {
+ // Create handler for http->https redirect and http acme challenges
+ httpHandler := certificates.SetupHTTPACMEChallengeServer(challengeCache, uint(cfg.Server.Port))
+
+ // Create listener for http and start listening
+ go func() {
+ log.Info().Msgf("Start HTTP server listening on %s", listeningHTTPAddress)
+ err := http.ListenAndServe(listeningHTTPAddress, httpHandler)
+ if err != nil {
+ log.Error().Err(err).Msg("Couldn't start HTTP server")
+ }
+ }()
+ }
+
+ if ctx.IsSet("enable-profiling") {
+ StartProfilingServer(ctx.String("profiling-address"))
+ }
+
+ // Create ssl handler based on settings
+ sslHandler := handler.Handler(cfg.Server, giteaClient, canonicalDomainCache, redirectsCache)
+
+ // Start the ssl listener
+ log.Info().Msgf("Start SSL server using TCP listener on %s", listener.Addr())
+
+ return http.Serve(listener, sslHandler)
+}
diff --git a/server/upstream/domains.go b/server/upstream/domains.go
index 0e29673..f68a02b 100644
--- a/server/upstream/domains.go
+++ b/server/upstream/domains.go
@@ -1,12 +1,14 @@
package upstream
import (
+ "errors"
"strings"
"time"
"github.com/rs/zerolog/log"
"codeberg.org/codeberg/pages/server/cache"
+ "codeberg.org/codeberg/pages/server/context"
"codeberg.org/codeberg/pages/server/gitea"
)
@@ -16,45 +18,54 @@ var canonicalDomainCacheTimeout = 15 * time.Minute
const canonicalDomainConfig = ".domains"
// CheckCanonicalDomain returns the canonical domain specified in the repo (using the `.domains` file).
-func (o *Options) CheckCanonicalDomain(giteaClient *gitea.Client, actualDomain, mainDomainSuffix string, canonicalDomainCache cache.SetGetKey) (string, bool) {
- var (
- domains []string
- valid bool
- )
+func (o *Options) CheckCanonicalDomain(ctx *context.Context, giteaClient *gitea.Client, actualDomain, mainDomainSuffix string, canonicalDomainCache cache.ICache) (domain string, valid bool) {
+ // Check if this request is cached.
if cachedValue, ok := canonicalDomainCache.Get(o.TargetOwner + "/" + o.TargetRepo + "/" + o.TargetBranch); ok {
- domains = cachedValue.([]string)
+ domains := cachedValue.([]string)
for _, domain := range domains {
if domain == actualDomain {
valid = true
break
}
}
- } else {
- body, err := giteaClient.GiteaRawContent(o.TargetOwner, o.TargetRepo, o.TargetBranch, canonicalDomainConfig)
- if err == nil {
- for _, domain := range strings.Split(string(body), "\n") {
- domain = strings.ToLower(domain)
- domain = strings.TrimSpace(domain)
- domain = strings.TrimPrefix(domain, "http://")
- domain = strings.TrimPrefix(domain, "https://")
- if len(domain) > 0 && !strings.HasPrefix(domain, "#") && !strings.ContainsAny(domain, "\t /") && strings.ContainsRune(domain, '.') {
- domains = append(domains, domain)
- }
- if domain == actualDomain {
- valid = true
- }
- }
- } else {
- log.Info().Err(err).Msgf("could not read %s of %s/%s", canonicalDomainConfig, o.TargetOwner, o.TargetRepo)
+ return domains[0], valid
+ }
+
+ body, err := giteaClient.GiteaRawContent(ctx, o.TargetOwner, o.TargetRepo, o.TargetBranch, canonicalDomainConfig)
+ if err != nil && !errors.Is(err, gitea.ErrorNotFound) {
+ log.Error().Err(err).Msgf("could not read %s of %s/%s", canonicalDomainConfig, o.TargetOwner, o.TargetRepo)
+ }
+
+ var domains []string
+ for _, domain := range strings.Split(string(body), "\n") {
+ domain = strings.ToLower(domain)
+ domain = strings.TrimSpace(domain)
+ domain = strings.TrimPrefix(domain, "http://")
+ domain = strings.TrimPrefix(domain, "https://")
+ if domain != "" && !strings.HasPrefix(domain, "#") && !strings.ContainsAny(domain, "\t /") && strings.ContainsRune(domain, '.') {
+ domains = append(domains, domain)
}
- domains = append(domains, o.TargetOwner+mainDomainSuffix)
- if domains[len(domains)-1] == actualDomain {
+ if domain == actualDomain {
valid = true
}
- if o.TargetRepo != "" && o.TargetRepo != "pages" {
- domains[len(domains)-1] += "/" + o.TargetRepo
- }
- _ = canonicalDomainCache.Set(o.TargetOwner+"/"+o.TargetRepo+"/"+o.TargetBranch, domains, canonicalDomainCacheTimeout)
}
+
+ // Add [owner].[pages-domain] as valid domain.
+ domains = append(domains, o.TargetOwner+mainDomainSuffix)
+ if domains[len(domains)-1] == actualDomain {
+ valid = true
+ }
+
+ // If the target repository isn't called pages, add `/[repository]` to the
+ // previous valid domain.
+ if o.TargetRepo != "" && o.TargetRepo != "pages" {
+ domains[len(domains)-1] += "/" + o.TargetRepo
+ }
+
+ // Add result to cache.
+ _ = canonicalDomainCache.Set(o.TargetOwner+"/"+o.TargetRepo+"/"+o.TargetBranch, domains, canonicalDomainCacheTimeout)
+
+ // Return the first domain from the list and return if any of the domains
+ // matched the requested domain.
return domains[0], valid
}
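
As a sketch of what CheckCanonicalDomain derives (owner, repository and domains here are hypothetical): for owner "alice" and repository "blog", a .domains file containing

# canonical domain first
https://Blog.Example.COM
www.example.com

is parsed into ["blog.example.com", "www.example.com", "alice.codeberg.page/blog"]: entries are lower-cased, trimmed and stripped of their scheme, comment lines and entries without a dot are skipped, "alice.codeberg.page" is always appended, and "/blog" is added to that last entry because the repository is not named "pages". The first entry is returned as the canonical domain, and valid reports whether any entry matched the requested domain.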
diff --git a/server/upstream/header.go b/server/upstream/header.go
index 9575a3f..3a218a1 100644
--- a/server/upstream/header.go
+++ b/server/upstream/header.go
@@ -24,5 +24,8 @@ func (o *Options) setHeader(ctx *context.Context, header http.Header) {
} else {
ctx.RespWriter.Header().Set(gitea.ContentTypeHeader, mime)
}
- ctx.RespWriter.Header().Set(headerLastModified, o.BranchTimestamp.In(time.UTC).Format(time.RFC1123))
+ if encoding := header.Get(gitea.ContentEncodingHeader); encoding != "" && encoding != "identity" {
+ ctx.RespWriter.Header().Set(gitea.ContentEncodingHeader, encoding)
+ }
+ ctx.RespWriter.Header().Set(headerLastModified, o.BranchTimestamp.In(time.UTC).Format(http.TimeFormat))
}
diff --git a/server/upstream/helper.go b/server/upstream/helper.go
index a84d4f0..ac0ab3f 100644
--- a/server/upstream/helper.go
+++ b/server/upstream/helper.go
@@ -17,17 +17,17 @@ func (o *Options) GetBranchTimestamp(giteaClient *gitea.Client) (bool, error) {
// Get default branch
defaultBranch, err := giteaClient.GiteaGetRepoDefaultBranch(o.TargetOwner, o.TargetRepo)
if err != nil {
- log.Err(err).Msg("Could't fetch default branch from repository")
+ log.Err(err).Msg("Couldn't fetch default branch from repository")
return false, err
}
- log.Debug().Msgf("Succesfully fetched default branch %q from Gitea", defaultBranch)
+ log.Debug().Msgf("Successfully fetched default branch %q from Gitea", defaultBranch)
o.TargetBranch = defaultBranch
}
timestamp, err := giteaClient.GiteaGetRepoBranchTimestamp(o.TargetOwner, o.TargetRepo, o.TargetBranch)
if err != nil {
if !errors.Is(err, gitea.ErrorNotFound) {
- log.Error().Err(err).Msg("Could not get latest commit's timestamp from branch")
+ log.Error().Err(err).Msg("Could not get latest commit timestamp from branch")
}
return false, err
}
@@ -36,7 +36,7 @@ func (o *Options) GetBranchTimestamp(giteaClient *gitea.Client) (bool, error) {
return false, fmt.Errorf("empty response")
}
- log.Debug().Msgf("Succesfully fetched latest commit's timestamp from branch: %#v", timestamp)
+ log.Debug().Msgf("Successfully fetched latest commit timestamp from branch: %#v", timestamp)
o.BranchTimestamp = timestamp.Timestamp
o.TargetBranch = timestamp.Branch
return true, nil
diff --git a/server/upstream/redirects.go b/server/upstream/redirects.go
new file mode 100644
index 0000000..b0762d5
--- /dev/null
+++ b/server/upstream/redirects.go
@@ -0,0 +1,108 @@
+package upstream
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ "codeberg.org/codeberg/pages/server/cache"
+ "codeberg.org/codeberg/pages/server/context"
+ "codeberg.org/codeberg/pages/server/gitea"
+ "github.com/rs/zerolog/log"
+)
+
+type Redirect struct {
+ From string
+ To string
+ StatusCode int
+}
+
+// rewriteURL returns the destination URL and true if r matches reqURL.
+func (r *Redirect) rewriteURL(reqURL string) (dstURL string, ok bool) {
+ // check if from url matches request url
+ if strings.TrimSuffix(r.From, "/") == strings.TrimSuffix(reqURL, "/") {
+ return r.To, true
+ }
+ // handle wildcard redirects
+ if strings.HasSuffix(r.From, "/*") {
+ trimmedFromURL := strings.TrimSuffix(r.From, "/*")
+ if reqURL == trimmedFromURL || strings.HasPrefix(reqURL, trimmedFromURL+"/") {
+ if strings.Contains(r.To, ":splat") {
+ matched := strings.TrimPrefix(reqURL, trimmedFromURL)
+ matched = strings.TrimPrefix(matched, "/")
+ return strings.ReplaceAll(r.To, ":splat", matched), true
+ }
+ return r.To, true
+ }
+ }
+ return "", false
+}
+
+// redirectsCacheTimeout specifies the timeout for the redirects cache.
+var redirectsCacheTimeout = 10 * time.Minute
+
+const redirectsConfig = "_redirects"
+
+// getRedirects returns redirects specified in the _redirects file.
+func (o *Options) getRedirects(ctx *context.Context, giteaClient *gitea.Client, redirectsCache cache.ICache) []Redirect {
+ var redirects []Redirect
+ cacheKey := o.TargetOwner + "/" + o.TargetRepo + "/" + o.TargetBranch
+
+ // Check for cached redirects
+ if cachedValue, ok := redirectsCache.Get(cacheKey); ok {
+ redirects = cachedValue.([]Redirect)
+ } else {
+ // Get _redirects file and parse
+ body, err := giteaClient.GiteaRawContent(ctx, o.TargetOwner, o.TargetRepo, o.TargetBranch, redirectsConfig)
+ if err == nil {
+ for _, line := range strings.Split(string(body), "\n") {
+ redirectArr := strings.Fields(line)
+
+ // Ignore comments and invalid lines
+ if strings.HasPrefix(line, "#") || len(redirectArr) < 2 {
+ continue
+ }
+
+ // Get redirect status code
+ statusCode := 301
+ if len(redirectArr) == 3 {
+ statusCode, err = strconv.Atoi(redirectArr[2])
+ if err != nil {
+ log.Info().Err(err).Msgf("could not read %s of %s/%s", redirectsConfig, o.TargetOwner, o.TargetRepo)
+ }
+ }
+
+ redirects = append(redirects, Redirect{
+ From: redirectArr[0],
+ To: redirectArr[1],
+ StatusCode: statusCode,
+ })
+ }
+ }
+ _ = redirectsCache.Set(cacheKey, redirects, redirectsCacheTimeout)
+ }
+ return redirects
+}
+
+func (o *Options) matchRedirects(ctx *context.Context, giteaClient *gitea.Client, redirects []Redirect, redirectsCache cache.ICache) (final bool) {
+ reqURL := ctx.Req.RequestURI
+ // remove repo and branch from request url
+ reqURL = strings.TrimPrefix(reqURL, "/"+o.TargetRepo)
+ reqURL = strings.TrimPrefix(reqURL, "/@"+o.TargetBranch)
+
+ for _, redirect := range redirects {
+ if dstURL, ok := redirect.rewriteURL(reqURL); ok {
+ if o.TargetPath == dstURL { // recursion base case, rewrite directly when paths are the same
+ return true
+ } else if redirect.StatusCode == 200 { // do rewrite if status code is 200
+ o.TargetPath = dstURL
+ o.Upstream(ctx, giteaClient, redirectsCache)
+ } else {
+ ctx.Redirect(dstURL, redirect.StatusCode)
+ }
+ return true
+ }
+ }
+
+ return false
+}
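
Going by the parser above, a _redirects file is a list of whitespace-separated "from to [status]" rules; lines starting with "#" and lines with fewer than two fields are ignored, and a missing status defaults to 301. A minimal sketch (paths hypothetical):

# serve the app shell for client-side routes
/app/*        /index.html   200
# moved content, keep the rest of the path
/articles/*   /posts/:splat 302
/old-page     /new-page

A status of 200 rewrites the target path and serves it in place; any other status is answered as an HTTP redirect to the destination URL.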
diff --git a/server/upstream/redirects_test.go b/server/upstream/redirects_test.go
new file mode 100644
index 0000000..6118a70
--- /dev/null
+++ b/server/upstream/redirects_test.go
@@ -0,0 +1,36 @@
+package upstream
+
+import (
+ "testing"
+)
+
+func TestRedirect_rewriteURL(t *testing.T) {
+ for _, tc := range []struct {
+ redirect Redirect
+ reqURL string
+ wantDstURL string
+ wantOk bool
+ }{
+ {Redirect{"/", "/dst", 200}, "/", "/dst", true},
+ {Redirect{"/", "/dst", 200}, "/foo", "", false},
+ {Redirect{"/src", "/dst", 200}, "/src", "/dst", true},
+ {Redirect{"/src", "/dst", 200}, "/foo", "", false},
+ {Redirect{"/src", "/dst", 200}, "/src/foo", "", false},
+ {Redirect{"/*", "/dst", 200}, "/", "/dst", true},
+ {Redirect{"/*", "/dst", 200}, "/src", "/dst", true},
+ {Redirect{"/src/*", "/dst/:splat", 200}, "/src", "/dst/", true},
+ {Redirect{"/src/*", "/dst/:splat", 200}, "/src/", "/dst/", true},
+ {Redirect{"/src/*", "/dst/:splat", 200}, "/src/foo", "/dst/foo", true},
+ {Redirect{"/src/*", "/dst/:splat", 200}, "/src/foo/bar", "/dst/foo/bar", true},
+ {Redirect{"/src/*", "/dst/:splatsuffix", 200}, "/src/foo", "/dst/foosuffix", true},
+ {Redirect{"/src/*", "/dst:splat", 200}, "/src/foo", "/dstfoo", true},
+ {Redirect{"/src/*", "/dst", 200}, "/srcfoo", "", false},
+ // This is the example from FEATURES.md:
+ {Redirect{"/articles/*", "/posts/:splat", 302}, "/articles/2022/10/12/post-1/", "/posts/2022/10/12/post-1/", true},
+ } {
+ if dstURL, ok := tc.redirect.rewriteURL(tc.reqURL); dstURL != tc.wantDstURL || ok != tc.wantOk {
+ t.Errorf("%#v.rewriteURL(%q) = %q, %v; want %q, %v",
+ tc.redirect, tc.reqURL, dstURL, ok, tc.wantDstURL, tc.wantOk)
+ }
+ }
+}
diff --git a/server/upstream/upstream.go b/server/upstream/upstream.go
index 7c3c848..9aac271 100644
--- a/server/upstream/upstream.go
+++ b/server/upstream/upstream.go
@@ -1,16 +1,20 @@
package upstream
import (
+ "cmp"
"errors"
"fmt"
"io"
"net/http"
+ "slices"
+ "strconv"
"strings"
"time"
"github.com/rs/zerolog/log"
"codeberg.org/codeberg/pages/html"
+ "codeberg.org/codeberg/pages/server/cache"
"codeberg.org/codeberg/pages/server/context"
"codeberg.org/codeberg/pages/server/gitea"
)
@@ -18,6 +22,8 @@ import (
const (
headerLastModified = "Last-Modified"
headerIfModifiedSince = "If-Modified-Since"
+ headerAcceptEncoding = "Accept-Encoding"
+ headerContentEncoding = "Content-Encoding"
rawMime = "text/plain; charset=utf-8"
)
@@ -51,12 +57,80 @@ type Options struct {
ServeRaw bool
}
+// allowed encodings
+var allowedEncodings = map[string]string{
+ "gzip": ".gz",
+ "br": ".br",
+ "zstd": ".zst",
+ "identity": "",
+}
+
+// AcceptEncodings parses the Accept-Encoding header into the list of encodings, ordered by preference, that the client accepts and the server supports.
+func AcceptEncodings(header string) []string {
+ log.Trace().Msgf("got accept-encoding: %s", header)
+ encodings := []string{}
+ globQuality := 0.0
+ qualities := make(map[string]float64)
+
+ for _, encoding := range strings.Split(header, ",") {
+ name, quality_str, has_quality := strings.Cut(encoding, ";q=")
+ quality := 1.0
+
+ if has_quality {
+ var err error
+ quality, err = strconv.ParseFloat(quality_str, 64)
+ if err != nil || quality < 0 {
+ continue
+ }
+ }
+
+ name = strings.TrimSpace(name)
+
+ if name == "*" {
+ globQuality = quality
+ } else {
+ _, allowed := allowedEncodings[name]
+ if allowed {
+ qualities[name] = quality
+ if quality > 0 {
+ encodings = append(encodings, name)
+ }
+ }
+ }
+ }
+
+ if globQuality > 0 {
+ for encoding := range allowedEncodings {
+ _, exists := qualities[encoding]
+ if !exists {
+ encodings = append(encodings, encoding)
+ qualities[encoding] = globQuality
+ }
+ }
+ } else {
+ _, exists := qualities["identity"]
+ if !exists {
+ encodings = append(encodings, "identity")
+ qualities["identity"] = -1
+ }
+ }
+
+ slices.SortStableFunc(encodings, func(x, y string) int {
+ // sort in reverse order; big quality comes first
+ return cmp.Compare(qualities[y], qualities[x])
+ })
+ log.Trace().Msgf("decided encoding order: %v", encodings)
+ return encodings
+}
+
// Upstream requests a file from the Gitea API at GiteaRoot and writes it to the request context.
-func (o *Options) Upstream(ctx *context.Context, giteaClient *gitea.Client) (final bool) {
- log := log.With().Strs("upstream", []string{o.TargetOwner, o.TargetRepo, o.TargetBranch, o.TargetPath}).Logger()
+func (o *Options) Upstream(ctx *context.Context, giteaClient *gitea.Client, redirectsCache cache.ICache) bool {
+ log := log.With().Str("ReqId", ctx.ReqId).Strs("upstream", []string{o.TargetOwner, o.TargetRepo, o.TargetBranch, o.TargetPath}).Logger()
+
+ log.Debug().Msg("Start")
if o.TargetOwner == "" || o.TargetRepo == "" {
- html.ReturnErrorPage(ctx, "either repo owner or name info is missing", http.StatusBadRequest)
+ html.ReturnErrorPage(ctx, "forge client: either repo owner or name info is missing", http.StatusBadRequest)
return true
}
@@ -66,7 +140,7 @@ func (o *Options) Upstream(ctx *context.Context, giteaClient *gitea.Client) (fin
// handle 404
if err != nil && errors.Is(err, gitea.ErrorNotFound) || !branchExist {
html.ReturnErrorPage(ctx,
- fmt.Sprintf("branch %q for '%s/%s' not found", o.TargetBranch, o.TargetOwner, o.TargetRepo),
+			fmt.Sprintf("branch <code>%q</code> for <code>%s/%s</code> not found", o.TargetBranch, o.TargetOwner, o.TargetRepo),
http.StatusNotFound)
return true
}
@@ -74,7 +148,7 @@ func (o *Options) Upstream(ctx *context.Context, giteaClient *gitea.Client) (fin
// handle unexpected errors
if err != nil {
html.ReturnErrorPage(ctx,
- fmt.Sprintf("could not get timestamp of branch %q: %v", o.TargetBranch, err),
+			fmt.Sprintf("could not get timestamp of branch <code>%q</code>: '%v'", o.TargetBranch, err),
http.StatusFailedDependency)
return true
}
@@ -94,48 +168,90 @@ func (o *Options) Upstream(ctx *context.Context, giteaClient *gitea.Client) (fin
log.Debug().Msg("Preparing")
- reader, header, statusCode, err := giteaClient.ServeRawContent(o.TargetOwner, o.TargetRepo, o.TargetBranch, o.TargetPath)
- if reader != nil {
- defer reader.Close()
+ var reader io.ReadCloser
+ var header http.Header
+ var statusCode int
+ var err error
+
+	// pick the first non-404 response among the accepted encodings; directory requests (empty path or trailing slash) skip straight to the not-found/index handling below
+ if o.TargetPath == "" || strings.HasSuffix(o.TargetPath, "/") {
+ err = gitea.ErrorNotFound
+ } else {
+ for _, encoding := range AcceptEncodings(ctx.Req.Header.Get(headerAcceptEncoding)) {
+ log.Trace().Msgf("try %s encoding", encoding)
+
+ // add extension for encoding
+ path := o.TargetPath + allowedEncodings[encoding]
+ reader, header, statusCode, err = giteaClient.ServeRawContent(ctx, o.TargetOwner, o.TargetRepo, o.TargetBranch, path, true)
+ if statusCode == http.StatusNotFound {
+ continue
+ }
+ if err != nil {
+ break
+ }
+ log.Debug().Msgf("using %s encoding", encoding)
+ if encoding != "identity" {
+ header.Set(headerContentEncoding, encoding)
+ }
+ break
+ }
+ if reader != nil {
+ defer reader.Close()
+ }
}
log.Debug().Msg("Aquisting")
// Handle not found error
if err != nil && errors.Is(err, gitea.ErrorNotFound) {
+ log.Debug().Msg("Handling not found error")
+ // Get and match redirects
+ redirects := o.getRedirects(ctx, giteaClient, redirectsCache)
+ if o.matchRedirects(ctx, giteaClient, redirects, redirectsCache) {
+ log.Trace().Msg("redirect")
+ return true
+ }
+
if o.TryIndexPages {
+ log.Trace().Msg("try index page")
// copy the o struct & try if an index page exists
optionsForIndexPages := *o
optionsForIndexPages.TryIndexPages = false
optionsForIndexPages.appendTrailingSlash = true
for _, indexPage := range upstreamIndexPages {
optionsForIndexPages.TargetPath = strings.TrimSuffix(o.TargetPath, "/") + "/" + indexPage
- if optionsForIndexPages.Upstream(ctx, giteaClient) {
+ if optionsForIndexPages.Upstream(ctx, giteaClient, redirectsCache) {
return true
}
}
+ log.Trace().Msg("try html file with path name")
// compatibility fix for GitHub Pages (/example → /example.html)
optionsForIndexPages.appendTrailingSlash = false
optionsForIndexPages.redirectIfExists = strings.TrimSuffix(ctx.Path(), "/") + ".html"
optionsForIndexPages.TargetPath = o.TargetPath + ".html"
- if optionsForIndexPages.Upstream(ctx, giteaClient) {
+ if optionsForIndexPages.Upstream(ctx, giteaClient, redirectsCache) {
return true
}
}
+ log.Debug().Msg("not found")
+
ctx.StatusCode = http.StatusNotFound
if o.TryIndexPages {
+ log.Trace().Msg("try not found page")
// copy the o struct & try if a not found page exists
optionsForNotFoundPages := *o
optionsForNotFoundPages.TryIndexPages = false
optionsForNotFoundPages.appendTrailingSlash = false
for _, notFoundPage := range upstreamNotFoundPages {
optionsForNotFoundPages.TargetPath = "/" + notFoundPage
- if optionsForNotFoundPages.Upstream(ctx, giteaClient) {
+ if optionsForNotFoundPages.Upstream(ctx, giteaClient, redirectsCache) {
return true
}
}
+ log.Trace().Msg("not found page missing")
}
+
return false
}
@@ -145,16 +261,16 @@ func (o *Options) Upstream(ctx *context.Context, giteaClient *gitea.Client) (fin
var msg string
if err != nil {
- msg = "gitea client returned unexpected error"
+ msg = "forge client: returned unexpected error"
log.Error().Err(err).Msg(msg)
- msg = fmt.Sprintf("%s: %v", msg, err)
+ msg = fmt.Sprintf("%s: '%v'", msg, err)
}
if reader == nil {
- msg = "gitea client returned no reader"
+ msg = "forge client: returned no reader"
log.Error().Msg(msg)
}
if statusCode != http.StatusOK {
- msg = fmt.Sprintf("Couldn't fetch contents (status code %d)", statusCode)
+		msg = fmt.Sprintf("forge client: couldn't fetch contents: <code>%d - %s</code>", statusCode, http.StatusText(statusCode))
log.Error().Msg(msg)
}
@@ -165,10 +281,12 @@ func (o *Options) Upstream(ctx *context.Context, giteaClient *gitea.Client) (fin
// Append trailing slash if missing (for index files), and redirect to fix filenames in general
// o.appendTrailingSlash is only true when looking for index pages
if o.appendTrailingSlash && !strings.HasSuffix(ctx.Path(), "/") {
+ log.Trace().Msg("append trailing slash and redirect")
ctx.Redirect(ctx.Path()+"/", http.StatusTemporaryRedirect)
return true
}
- if strings.HasSuffix(ctx.Path(), "/index.html") {
+ if strings.HasSuffix(ctx.Path(), "/index.html") && !o.ServeRaw {
+ log.Trace().Msg("remove index.html from path and redirect")
ctx.Redirect(strings.TrimSuffix(ctx.Path(), "index.html"), http.StatusTemporaryRedirect)
return true
}
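
A minimal sketch (not part of the patch) of the negotiation above: AcceptEncodings orders the client's acceptable encodings by quality, keeps "identity" as a fallback, and the upstream lookup then appends the matching extension (.gz, .br, .zst) to the target path for each candidate until something other than 404 comes back.

package main

import (
	"fmt"

	"codeberg.org/codeberg/pages/server/upstream"
)

func main() {
	// Without a wildcard in the header, "identity" is appended last as a fallback,
	// so for /index.html the upstream would try index.html.zst, index.html.br,
	// index.html.gz and finally index.html.
	fmt.Println(upstream.AcceptEncodings("zstd, br;q=0.9, gzip;q=0.8"))
	// Output: [zstd br gzip identity]
}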
diff --git a/server/utils/utils.go b/server/utils/utils.go
index 30f948d..91ed359 100644
--- a/server/utils/utils.go
+++ b/server/utils/utils.go
@@ -1,6 +1,8 @@
package utils
import (
+ "net/url"
+ "path"
"strings"
)
@@ -11,3 +13,15 @@ func TrimHostPort(host string) string {
}
return host
}
+
+func CleanPath(uriPath string) string {
+ unescapedPath, _ := url.PathUnescape(uriPath)
+ cleanedPath := path.Join("/", unescapedPath)
+
+ // If the path refers to a directory, add a trailing slash.
+ if !strings.HasSuffix(cleanedPath, "/") && (strings.HasSuffix(unescapedPath, "/") || strings.HasSuffix(unescapedPath, "/.") || strings.HasSuffix(unescapedPath, "/..")) {
+ cleanedPath += "/"
+ }
+
+ return cleanedPath
+}
diff --git a/server/utils/utils_test.go b/server/utils/utils_test.go
index 2532392..b8fcea9 100644
--- a/server/utils/utils_test.go
+++ b/server/utils/utils_test.go
@@ -11,3 +11,59 @@ func TestTrimHostPort(t *testing.T) {
assert.EqualValues(t, "", TrimHostPort(":"))
assert.EqualValues(t, "example.com", TrimHostPort("example.com:80"))
}
+
+// TestCleanPath is mostly copied from fasthttp, to keep the behaviour we had before migrating away from it.
+// Source (MIT licensed): https://github.com/valyala/fasthttp/blob/v1.48.0/uri_test.go#L154
+// Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
+func TestCleanPath(t *testing.T) {
+ // double slash
+ testURIPathNormalize(t, "/aa//bb", "/aa/bb")
+
+ // triple slash
+ testURIPathNormalize(t, "/x///y/", "/x/y/")
+
+ // multi slashes
+ testURIPathNormalize(t, "/abc//de///fg////", "/abc/de/fg/")
+
+ // encoded slashes
+ testURIPathNormalize(t, "/xxxx%2fyyy%2f%2F%2F", "/xxxx/yyy/")
+
+ // dotdot
+ testURIPathNormalize(t, "/aaa/..", "/")
+
+ // dotdot with trailing slash
+ testURIPathNormalize(t, "/xxx/yyy/../", "/xxx/")
+
+ // multi dotdots
+ testURIPathNormalize(t, "/aaa/bbb/ccc/../../ddd", "/aaa/ddd")
+
+ // dotdots separated by other data
+ testURIPathNormalize(t, "/a/b/../c/d/../e/..", "/a/c/")
+
+ // too many dotdots
+ testURIPathNormalize(t, "/aaa/../../../../xxx", "/xxx")
+ testURIPathNormalize(t, "/../../../../../..", "/")
+ testURIPathNormalize(t, "/../../../../../../", "/")
+
+ // encoded dotdots
+ testURIPathNormalize(t, "/aaa%2Fbbb%2F%2E.%2Fxxx", "/aaa/xxx")
+
+ // double slash with dotdots
+ testURIPathNormalize(t, "/aaa////..//b", "/b")
+
+ // fake dotdot
+ testURIPathNormalize(t, "/aaa/..bbb/ccc/..", "/aaa/..bbb/")
+
+ // single dot
+ testURIPathNormalize(t, "/a/./b/././c/./d.html", "/a/b/c/d.html")
+ testURIPathNormalize(t, "./foo/", "/foo/")
+ testURIPathNormalize(t, "./../.././../../aaa/bbb/../../../././../", "/")
+ testURIPathNormalize(t, "./a/./.././../b/./foo.html", "/b/foo.html")
+}
+
+func testURIPathNormalize(t *testing.T, requestURI, expectedPath string) {
+ cleanedPath := CleanPath(requestURI)
+ if cleanedPath != expectedPath {
+ t.Fatalf("Unexpected path %q. Expected %q. requestURI=%q", cleanedPath, expectedPath, requestURI)
+ }
+}