diff --git a/README.md b/README.md index e8b5018..b97d05f 100644 --- a/README.md +++ b/README.md @@ -136,8 +136,8 @@ spec: - serve - --cache-hostname=nix-cache.yournetwork.local # TODO: Replace with your own hostname - --cache-data-path=/storage - - --upstream-cache=cache.nixos.org - - --upstream-cache=nix-community.cachix.org + - --upstream-cache=https://cache.nixos.org + - --upstream-cache=https://nix-community.cachix.org - --upstream-public-key=cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY= - --upstream-public-key=nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs= ports: @@ -189,7 +189,7 @@ ncps can be configured using the following flags: - `--cache-lru-schedule`: The cron spec for cleaning the store to keep it under `--cache-max-size`. Refer to https://pkg.go.dev/github.com/robfig/cron/v3#hdr-Usage for documentation (Environment variable: `$CACHE_LRU_SCHEDULE`) - `--cache-lru-schedule-timezone`: The name of the timezone to use for the cron schedule (default: "Local"). (Environment variable: `$CACHE_LRU_SCHEDULE_TZ`) - `--server-addr`: The address and port the server listens on (default: ":8501"). (Environment variable: `$SERVER_ADDR`) -- `--upstream-cache`: The **hostname** of an upstream binary cache (e.g., `cache.nixos.org`). **Do not include the scheme (https://).** This flag can be used multiple times to specify multiple upstream caches, for example: `--upstream-cache cache.nixos.org --upstream-cache nix-community.cachix.org`. (Environment variable: `$UPSTREAM_CACHES`) +- `--upstream-cache`: The URL of an upstream binary cache (e.g., `https://cache.nixos.org`). This flag can be used multiple times to specify multiple upstream caches. (Environment variable: `$UPSTREAM_CACHES`) - `--upstream-public-key`: The public key of an upstream cache in the format `host:public-key`. This flag is used to verify the signatures of store paths downloaded from upstream caches. This flag can be used multiple times, once for each upstream cache. (Environment variable: `$UPSTREAM_PUBLIC_KEYS`) ## Nix Configuration diff --git a/cmd/serve.go b/cmd/serve.go index 3a1a61b..4487894 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "regexp" "time" @@ -86,7 +87,7 @@ func serveCommand(logger log15.Logger) *cli.Command { }, &cli.StringSliceFlag{ Name: "upstream-cache", - Usage: "Set to host for each upstream cache", + Usage: "Set to URL (with scheme) for each upstream cache", Sources: cli.EnvVars("UPSTREAM_CACHES"), Required: true, }, @@ -136,10 +137,15 @@ func getUpstreamCaches(_ context.Context, logger log15.Logger, cmd *cli.Command) ucs := make([]upstream.Cache, 0, len(ucSlice)) - for _, host := range ucSlice { + for _, us := range ucSlice { var pubKeys []string - rx := regexp.MustCompile(fmt.Sprintf(`^%s-[0-9]+:[A-Za-z0-9+/=]+$`, regexp.QuoteMeta(host))) + u, err := url.Parse(us) + if err != nil { + return nil, fmt.Errorf("error parsing --upstream-cache=%q: %w", us, err) + } + + rx := regexp.MustCompile(fmt.Sprintf(`^%s-[0-9]+:[A-Za-z0-9+/=]+$`, regexp.QuoteMeta(u.Host))) for _, pubKey := range cmd.StringSlice("upstream-public-key") { if rx.MatchString(pubKey) { @@ -147,7 +153,7 @@ func getUpstreamCaches(_ context.Context, logger log15.Logger, cmd *cli.Command) } } - uc, err := upstream.New(logger, host, pubKeys) + uc, err := upstream.New(logger, u, pubKeys) if err != nil { return nil, fmt.Errorf("error creating a new upstream cache: %w", err) } diff --git a/pkg/cache/cache_internal_test.go b/pkg/cache/cache_internal_test.go index 8ab1527..76eaab5 100644 --- a/pkg/cache/cache_internal_test.go +++ b/pkg/cache/cache_internal_test.go @@ -5,7 +5,6 @@ import ( "database/sql" "math/rand/v2" "net/http/httptest" - "net/url" "os" "path/filepath" "testing" @@ -50,10 +49,7 @@ func TestAddUpstreamCaches(t *testing.T) { for _, idx := range randomOrder { ts := testServers[idx] - u, err := url.Parse(ts.URL) - require.NoError(t, err) - - uc, err := upstream.New(logger, u.Host, nil) + uc, err := upstream.New(logger, testhelper.MustParseURL(t, ts.URL), nil) require.NoError(t, err) ucs = append(ucs, uc) @@ -100,10 +96,7 @@ func TestAddUpstreamCaches(t *testing.T) { for _, idx := range randomOrder { ts := testServers[idx] - u, err := url.Parse(ts.URL) - require.NoError(t, err) - - uc, err := upstream.New(logger, u.Host, nil) + uc, err := upstream.New(logger, testhelper.MustParseURL(t, ts.URL), nil) require.NoError(t, err) ucs = append(ucs, uc) @@ -147,10 +140,7 @@ func TestRunLRU(t *testing.T) { ts := testdata.HTTPTestServer(t, 40) defer ts.Close() - tu, err := url.Parse(ts.URL) - require.NoError(t, err) - - uc, err := upstream.New(logger, tu.Host, testdata.PublicKeys()) + uc, err := upstream.New(logger, testhelper.MustParseURL(t, ts.URL), nil) require.NoError(t, err) c.AddUpstreamCaches(uc) @@ -175,12 +165,12 @@ func TestRunLRU(t *testing.T) { var sizePulled int64 - for _, nar := range allEntries { + for i, nar := range allEntries { _, err := c.GetNarInfo(context.Background(), nar.NarInfoHash) - require.NoError(t, err) + require.NoErrorf(t, err, "unable to get narinfo for idx %d", i) size, _, err := c.GetNar(context.Background(), nar.NarHash, "xz") - require.NoError(t, err) + require.NoError(t, err, "unable to get nar for idx %d", i) sizePulled += size } diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index 772a6ba..0b12d21 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "io" - "net/url" "os" "path/filepath" "strings" @@ -171,14 +170,11 @@ func TestGetNarInfo(t *testing.T) { ts := testdata.HTTPTestServer(t, 40) defer ts.Close() - tu, err := url.Parse(ts.URL) - require.NoError(t, err) - dir, err := os.MkdirTemp("", "cache-path-") require.NoError(t, err) defer os.RemoveAll(dir) // clean up - uc, err := upstream.New(logger, tu.Host, testdata.PublicKeys()) + uc, err := upstream.New(logger, testhelper.MustParseURL(t, ts.URL), testdata.PublicKeys()) require.NoError(t, err) dbFile := filepath.Join(dir, "var", "ncps", "db", "db.sqlite") @@ -660,14 +656,11 @@ func TestGetNar(t *testing.T) { ts := testdata.HTTPTestServer(t, 40) defer ts.Close() - tu, err := url.Parse(ts.URL) - require.NoError(t, err) - dir, err := os.MkdirTemp("", "cache-path-") require.NoError(t, err) defer os.RemoveAll(dir) // clean up - uc, err := upstream.New(logger, tu.Host, testdata.PublicKeys()) + uc, err := upstream.New(logger, testhelper.MustParseURL(t, ts.URL), testdata.PublicKeys()) require.NoError(t, err) dbFile := filepath.Join(dir, "var", "ncps", "db", "db.sqlite") diff --git a/pkg/cache/upstream/cache.go b/pkg/cache/upstream/cache.go index d6d69b5..f8ed513 100644 --- a/pkg/cache/upstream/cache.go +++ b/pkg/cache/upstream/cache.go @@ -6,8 +6,8 @@ import ( "fmt" "io" "net/http" + "net/url" "strconv" - "strings" "time" "github.com/inconshreveable/log15/v3" @@ -19,17 +19,14 @@ import ( ) var ( - // ErrHostnameRequired is returned if the given hostName to New is not given. - ErrHostnameRequired = errors.New("hostName is required") + // ErrURLRequired is returned if the given URL to New is not given. + ErrURLRequired = errors.New("the URL is required") - // ErrHostnameMustNotContainScheme is returned if the given hostName to New contained a scheme. - ErrHostnameMustNotContainScheme = errors.New("hostName must not contain scheme") + // ErrURLMustContainScheme is returned if the given URL to New did not contain a scheme. + ErrURLMustContainScheme = errors.New("the URL must contain scheme") - // ErrHostnameNotValid is returned if the given hostName to New is not valid. - ErrHostnameNotValid = errors.New("hostName is not valid") - - // ErrHostnameMustNotContainPath is returned if the given hostName to New contained a path. - ErrHostnameMustNotContainPath = errors.New("hostName must not contain a path") + // ErrInvalidURL is returned if the given hostName to New is not valid. + ErrInvalidURL = errors.New("the URL is not valid") // ErrNotFound is returned if the nar or narinfo were not found. ErrNotFound = errors.New("not found") @@ -43,16 +40,16 @@ var ( // Cache represents the upstream cache service. type Cache struct { - hostName string + url *url.URL logger log15.Logger priority uint64 publicKeys []signature.PublicKey } -func New(logger log15.Logger, hostName string, pubKeys []string) (Cache, error) { - c := Cache{logger: logger, hostName: hostName} +func New(logger log15.Logger, u *url.URL, pubKeys []string) (Cache, error) { + c := Cache{logger: logger, url: u} - if err := c.validateHostname(hostName); err != nil { + if err := c.validateURL(u); err != nil { return c, err } @@ -67,7 +64,7 @@ func New(logger log15.Logger, hostName string, pubKeys []string) (Cache, error) priority, err := c.parsePriority() if err != nil { - return c, fmt.Errorf("error parsing the priority for %q: %w", hostName, err) + return c, fmt.Errorf("error parsing the priority for %q: %w", u, err) } c.priority = priority @@ -76,11 +73,11 @@ func New(logger log15.Logger, hostName string, pubKeys []string) (Cache, error) } // GetHostname returns the hostname. -func (c Cache) GetHostname() string { return c.hostName } +func (c Cache) GetHostname() string { return c.url.Hostname() } // GetNarInfo returns a parsed NarInfo from the cache server. func (c Cache) GetNarInfo(ctx context.Context, hash string) (*narinfo.NarInfo, error) { - r, err := http.NewRequestWithContext(ctx, "GET", c.getHostnameWithScheme()+helper.NarInfoURLPath(hash), nil) + r, err := http.NewRequestWithContext(ctx, "GET", c.url.JoinPath(helper.NarInfoURLPath(hash)).String(), nil) if err != nil { return nil, fmt.Errorf("error creating a new request: %w", err) } @@ -128,7 +125,7 @@ func (c Cache) GetNarInfo(ctx context.Context, hash string) (*narinfo.NarInfo, e func (c Cache) GetNar(ctx context.Context, hash, compression string) (int64, io.ReadCloser, error) { log := c.logger.New("hash", hash, "compression", compression) - r, err := http.NewRequestWithContext(ctx, "GET", c.getHostnameWithScheme()+helper.NarURLPath(hash, compression), nil) + r, err := http.NewRequestWithContext(ctx, "GET", c.url.JoinPath(helper.NarURLPath(hash, compression)).String(), nil) if err != nil { return 0, nil, fmt.Errorf("error creating a new request: %w", err) } @@ -168,15 +165,6 @@ func (c Cache) GetNar(ctx context.Context, hash, compression string) (int64, io. // GetPriority returns the priority of this upstream cache. func (c Cache) GetPriority() uint64 { return c.priority } -func (c Cache) getHostnameWithScheme() string { - scheme := "https" - if strings.HasPrefix(c.hostName, "127.0.0.1") { - scheme = "http" - } - - return scheme + "://" + c.hostName -} - func (c Cache) parsePriority() (uint64, error) { // TODO: Should probably pass context around and have things like logger in the context ctx := context.Background() @@ -184,7 +172,7 @@ func (c Cache) parsePriority() (uint64, error) { ctx, cancelFn := context.WithTimeout(ctx, 3*time.Second) defer cancelFn() - r, err := http.NewRequestWithContext(ctx, "GET", c.getHostnameWithScheme()+"/nix-cache-info", nil) + r, err := http.NewRequestWithContext(ctx, "GET", c.url.JoinPath("/nix-cache-info").String(), nil) if err != nil { return 0, fmt.Errorf("error creating a new request: %w", err) } @@ -210,23 +198,17 @@ func (c Cache) parsePriority() (uint64, error) { return nci.Priority, nil } -func (c Cache) validateHostname(hostName string) error { - if hostName == "" { - c.logger.Error("given hostname is empty", "hostName", hostName) - - return ErrHostnameRequired - } - - if strings.Contains(hostName, "://") { - c.logger.Error("hostname should not contain a scheme", "hostName", hostName) +func (c Cache) validateURL(u *url.URL) error { + if u == nil { + c.logger.Error("given url is nil", "url", u) - return ErrHostnameMustNotContainScheme + return ErrURLRequired } - if strings.Contains(hostName, "/") { - c.logger.Error("hostname should not contain a path", "hostName", hostName) + if u.Scheme == "" { + c.logger.Error("hostname should not contain a scheme", "url", u) - return ErrHostnameMustNotContainPath + return ErrURLMustContainScheme } return nil diff --git a/pkg/cache/upstream/cache_test.go b/pkg/cache/upstream/cache_test.go index 3dc7779..c1ff862 100644 --- a/pkg/cache/upstream/cache_test.go +++ b/pkg/cache/upstream/cache_test.go @@ -3,7 +3,6 @@ package upstream_test import ( "context" "io" - "net/url" "strings" "testing" @@ -13,6 +12,7 @@ import ( "github.com/kalbasit/ncps/pkg/cache/upstream" "github.com/kalbasit/ncps/testdata" + "github.com/kalbasit/ncps/testhelper" ) //nolint:gochecknoglobals @@ -30,22 +30,26 @@ func TestNew(t *testing.T) { t.Parallel() t.Run("hostname must not be empty", func(t *testing.T) { - _, err := upstream.New(logger, "", nil) - assert.ErrorIs(t, err, upstream.ErrHostnameRequired) + _, err := upstream.New(logger, nil, nil) + assert.ErrorIs(t, err, upstream.ErrURLRequired) }) t.Run("hostname must not contain scheme", func(t *testing.T) { - _, err := upstream.New(logger, "https://cache.nixos.org", nil) - assert.ErrorIs(t, err, upstream.ErrHostnameMustNotContainScheme) + _, err := upstream.New(logger, testhelper.MustParseURL(t, "cache.nixos.org"), nil) + assert.ErrorIs(t, err, upstream.ErrURLMustContainScheme) }) - t.Run("hostname must not contain a path", func(t *testing.T) { - _, err := upstream.New(logger, "cache.nixos.org/path/to", nil) - assert.ErrorIs(t, err, upstream.ErrHostnameMustNotContainPath) + t.Run("valid url with no path must not return no error", func(t *testing.T) { + _, err := upstream.New(logger, + testhelper.MustParseURL(t, "https://cache.nixos.org"), nil) + + assert.NoError(t, err) }) - t.Run("valid hostName must return no error", func(t *testing.T) { - _, err := upstream.New(logger, "cache.nixos.org", nil) + t.Run("valid url with only / must not return no error", func(t *testing.T) { + _, err := upstream.New(logger, + testhelper.MustParseURL(t, "https://cache.nixos.org/"), nil) + assert.NoError(t, err) }) }) @@ -54,14 +58,14 @@ func TestNew(t *testing.T) { t.Parallel() t.Run("invalid public keys", func(t *testing.T) { - _, err := upstream.New(logger, "cache.nixos.org", []string{"invalid"}) + _, err := upstream.New(logger, testhelper.MustParseURL(t, "https://cache.nixos.org"), []string{"invalid"}) assert.True(t, strings.HasPrefix(err.Error(), "error parsing the public key: public key is corrupt:")) }) t.Run("valid public keys", func(t *testing.T) { _, err := upstream.New( logger, - "cache.nixos.org", + testhelper.MustParseURL(t, "https://cache.nixos.org"), testdata.PublicKeys(), ) assert.NoError(t, err) @@ -73,7 +77,7 @@ func TestNew(t *testing.T) { c, err := upstream.New( logger, - "cache.nixos.org", + testhelper.MustParseURL(t, "https://cache.nixos.org"), testdata.PublicKeys(), ) require.NoError(t, err) @@ -90,21 +94,23 @@ func TestGetNarInfo(t *testing.T) { t.Parallel() var ( - c upstream.Cache - + c upstream.Cache err error ) + ts := testdata.HTTPTestServer(t, 40) + defer ts.Close() + if withKeys { c, err = upstream.New( logger, - "cache.nixos.org", + testhelper.MustParseURL(t, ts.URL), testdata.PublicKeys(), ) } else { c, err = upstream.New( logger, - "cache.nixos.org", + testhelper.MustParseURL(t, ts.URL), nil, ) } @@ -112,15 +118,11 @@ func TestGetNarInfo(t *testing.T) { require.NoError(t, err) t.Run("hash not found", func(t *testing.T) { - t.Parallel() - _, err := c.GetNarInfo(context.Background(), "abc123") assert.ErrorIs(t, err, upstream.ErrNotFound) }) t.Run("hash is found", func(t *testing.T) { - t.Parallel() - ni, err := c.GetNarInfo(context.Background(), testdata.Nar1.NarInfoHash) require.NoError(t, err) @@ -128,46 +130,16 @@ func TestGetNarInfo(t *testing.T) { }) t.Run("check has failed", func(t *testing.T) { - t.Parallel() - hash := "broken-" + testdata.Nar1.NarInfoHash - ts := testdata.HTTPTestServer(t, 40) - defer ts.Close() - - tu, err := url.Parse(ts.URL) - require.NoError(t, err) - - c, err := upstream.New( - logger, - tu.Host, - testdata.PublicKeys(), - ) - require.NoError(t, err) - _, err = c.GetNarInfo(context.Background(), hash) assert.ErrorContains(t, err, "error while checking the narInfo: invalid Reference[0]: notfound-path") }) for _, entry := range testdata.Entries { t.Run("check does not fail", func(t *testing.T) { - t.Parallel() - hash := entry.NarInfoHash - ts := testdata.HTTPTestServer(t, 40) - defer ts.Close() - - tu, err := url.Parse(ts.URL) - require.NoError(t, err) - - c, err := upstream.New( - logger, - tu.Host, - testdata.PublicKeys(), - ) - require.NoError(t, err) - _, err = c.GetNarInfo(context.Background(), hash) assert.NoError(t, err) }) @@ -185,7 +157,7 @@ func TestGetNarInfo(t *testing.T) { func TestGetNar(t *testing.T) { c, err := upstream.New( logger, - "cache.nixos.org", + testhelper.MustParseURL(t, "https://cache.nixos.org"), testdata.PublicKeys(), ) require.NoError(t, err) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index c9ac967..c15ce8b 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -5,7 +5,6 @@ import ( "io" "net/http" "net/http/httptest" - "net/url" "os" "path/filepath" "strings" @@ -38,10 +37,7 @@ func TestServeHTTP(t *testing.T) { hts := testdata.HTTPTestServer(t, 40) defer hts.Close() - htu, err := url.Parse(hts.URL) - require.NoError(t, err) - - uc, err := upstream.New(logger, htu.Host, testdata.PublicKeys()) + uc, err := upstream.New(logger, testhelper.MustParseURL(t, hts.URL), testdata.PublicKeys()) require.NoError(t, err) t.Run("DELETE requests", func(t *testing.T) { diff --git a/testhelper/url.go b/testhelper/url.go new file mode 100644 index 0000000..7c5c7c7 --- /dev/null +++ b/testhelper/url.go @@ -0,0 +1,18 @@ +package testhelper + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +// MustParseURL parses the url (string) and returns or fails the test. +func MustParseURL(t *testing.T, us string) *url.URL { + t.Helper() + + u, err := url.Parse(us) + require.NoError(t, err) + + return u +}