Skip to content

Commit

Permalink
pkg/cache/upstream: Take in a URL not a Hostname (#79)
Browse files Browse the repository at this point in the history
Setting up an upstream cache with a hostname instead of a URL was a shortsighted decision that led only to problems and workarounds; Take in a URL instead.
  • Loading branch information
kalbasit authored Dec 12, 2024
1 parent 03c5dfa commit 9045e70
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 130 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ spec:
- serve
- --cache-hostname=nix-cache.yournetwork.local # TODO: Replace with your own hostname
- --cache-data-path=/storage
- --upstream-cache=cache.nixos.org
- --upstream-cache=nix-community.cachix.org
- --upstream-cache=https://cache.nixos.org
- --upstream-cache=https://nix-community.cachix.org
- --upstream-public-key=cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY=
- --upstream-public-key=nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=
ports:
Expand Down Expand Up @@ -189,7 +189,7 @@ ncps can be configured using the following flags:
- `--cache-lru-schedule`: The cron spec for cleaning the store to keep it under `--cache-max-size`. Refer to https://pkg.go.dev/github.com/robfig/cron/v3#hdr-Usage for documentation (Environment variable: `$CACHE_LRU_SCHEDULE`)
- `--cache-lru-schedule-timezone`: The name of the timezone to use for the cron schedule (default: "Local"). (Environment variable: `$CACHE_LRU_SCHEDULE_TZ`)
- `--server-addr`: The address and port the server listens on (default: ":8501"). (Environment variable: `$SERVER_ADDR`)
- `--upstream-cache`: The **hostname** of an upstream binary cache (e.g., `cache.nixos.org`). **Do not include the scheme (https://).** This flag can be used multiple times to specify multiple upstream caches, for example: `--upstream-cache cache.nixos.org --upstream-cache nix-community.cachix.org`. (Environment variable: `$UPSTREAM_CACHES`)
- `--upstream-cache`: The URL of an upstream binary cache (e.g., `https://cache.nixos.org`). This flag can be used multiple times to specify multiple upstream caches. (Environment variable: `$UPSTREAM_CACHES`)
- `--upstream-public-key`: The public key of an upstream cache in the format `host:public-key`. This flag is used to verify the signatures of store paths downloaded from upstream caches. This flag can be used multiple times, once for each upstream cache. (Environment variable: `$UPSTREAM_PUBLIC_KEYS`)

## Nix Configuration
Expand Down
14 changes: 10 additions & 4 deletions cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
"time"

Expand Down Expand Up @@ -86,7 +87,7 @@ func serveCommand(logger log15.Logger) *cli.Command {
},
&cli.StringSliceFlag{
Name: "upstream-cache",
Usage: "Set to host for each upstream cache",
Usage: "Set to URL (with scheme) for each upstream cache",
Sources: cli.EnvVars("UPSTREAM_CACHES"),
Required: true,
},
Expand Down Expand Up @@ -136,18 +137,23 @@ func getUpstreamCaches(_ context.Context, logger log15.Logger, cmd *cli.Command)

ucs := make([]upstream.Cache, 0, len(ucSlice))

for _, host := range ucSlice {
for _, us := range ucSlice {
var pubKeys []string

rx := regexp.MustCompile(fmt.Sprintf(`^%s-[0-9]+:[A-Za-z0-9+/=]+$`, regexp.QuoteMeta(host)))
u, err := url.Parse(us)
if err != nil {
return nil, fmt.Errorf("error parsing --upstream-cache=%q: %w", us, err)
}

rx := regexp.MustCompile(fmt.Sprintf(`^%s-[0-9]+:[A-Za-z0-9+/=]+$`, regexp.QuoteMeta(u.Host)))

for _, pubKey := range cmd.StringSlice("upstream-public-key") {
if rx.MatchString(pubKey) {
pubKeys = append(pubKeys, pubKey)
}
}

uc, err := upstream.New(logger, host, pubKeys)
uc, err := upstream.New(logger, u, pubKeys)
if err != nil {
return nil, fmt.Errorf("error creating a new upstream cache: %w", err)
}
Expand Down
22 changes: 6 additions & 16 deletions pkg/cache/cache_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"database/sql"
"math/rand/v2"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -50,10 +49,7 @@ func TestAddUpstreamCaches(t *testing.T) {
for _, idx := range randomOrder {
ts := testServers[idx]

u, err := url.Parse(ts.URL)
require.NoError(t, err)

uc, err := upstream.New(logger, u.Host, nil)
uc, err := upstream.New(logger, testhelper.MustParseURL(t, ts.URL), nil)
require.NoError(t, err)

ucs = append(ucs, uc)
Expand Down Expand Up @@ -100,10 +96,7 @@ func TestAddUpstreamCaches(t *testing.T) {
for _, idx := range randomOrder {
ts := testServers[idx]

u, err := url.Parse(ts.URL)
require.NoError(t, err)

uc, err := upstream.New(logger, u.Host, nil)
uc, err := upstream.New(logger, testhelper.MustParseURL(t, ts.URL), nil)
require.NoError(t, err)

ucs = append(ucs, uc)
Expand Down Expand Up @@ -147,10 +140,7 @@ func TestRunLRU(t *testing.T) {
ts := testdata.HTTPTestServer(t, 40)
defer ts.Close()

tu, err := url.Parse(ts.URL)
require.NoError(t, err)

uc, err := upstream.New(logger, tu.Host, testdata.PublicKeys())
uc, err := upstream.New(logger, testhelper.MustParseURL(t, ts.URL), nil)
require.NoError(t, err)

c.AddUpstreamCaches(uc)
Expand All @@ -175,12 +165,12 @@ func TestRunLRU(t *testing.T) {

var sizePulled int64

for _, nar := range allEntries {
for i, nar := range allEntries {
_, err := c.GetNarInfo(context.Background(), nar.NarInfoHash)
require.NoError(t, err)
require.NoErrorf(t, err, "unable to get narinfo for idx %d", i)

size, _, err := c.GetNar(context.Background(), nar.NarHash, "xz")
require.NoError(t, err)
require.NoError(t, err, "unable to get nar for idx %d", i)

sizePulled += size
}
Expand Down
11 changes: 2 additions & 9 deletions pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"io"
"net/url"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -171,14 +170,11 @@ func TestGetNarInfo(t *testing.T) {
ts := testdata.HTTPTestServer(t, 40)
defer ts.Close()

tu, err := url.Parse(ts.URL)
require.NoError(t, err)

dir, err := os.MkdirTemp("", "cache-path-")
require.NoError(t, err)
defer os.RemoveAll(dir) // clean up

uc, err := upstream.New(logger, tu.Host, testdata.PublicKeys())
uc, err := upstream.New(logger, testhelper.MustParseURL(t, ts.URL), testdata.PublicKeys())
require.NoError(t, err)

dbFile := filepath.Join(dir, "var", "ncps", "db", "db.sqlite")
Expand Down Expand Up @@ -660,14 +656,11 @@ func TestGetNar(t *testing.T) {
ts := testdata.HTTPTestServer(t, 40)
defer ts.Close()

tu, err := url.Parse(ts.URL)
require.NoError(t, err)

dir, err := os.MkdirTemp("", "cache-path-")
require.NoError(t, err)
defer os.RemoveAll(dir) // clean up

uc, err := upstream.New(logger, tu.Host, testdata.PublicKeys())
uc, err := upstream.New(logger, testhelper.MustParseURL(t, ts.URL), testdata.PublicKeys())
require.NoError(t, err)

dbFile := filepath.Join(dir, "var", "ncps", "db", "db.sqlite")
Expand Down
64 changes: 23 additions & 41 deletions pkg/cache/upstream/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/inconshreveable/log15/v3"
Expand All @@ -19,17 +19,14 @@ import (
)

var (
// ErrHostnameRequired is returned if the given hostName to New is not given.
ErrHostnameRequired = errors.New("hostName is required")
// ErrURLRequired is returned if the given URL to New is not given.
ErrURLRequired = errors.New("the URL is required")

// ErrHostnameMustNotContainScheme is returned if the given hostName to New contained a scheme.
ErrHostnameMustNotContainScheme = errors.New("hostName must not contain scheme")
// ErrURLMustContainScheme is returned if the given URL to New did not contain a scheme.
ErrURLMustContainScheme = errors.New("the URL must contain scheme")

// ErrHostnameNotValid is returned if the given hostName to New is not valid.
ErrHostnameNotValid = errors.New("hostName is not valid")

// ErrHostnameMustNotContainPath is returned if the given hostName to New contained a path.
ErrHostnameMustNotContainPath = errors.New("hostName must not contain a path")
// ErrInvalidURL is returned if the given hostName to New is not valid.
ErrInvalidURL = errors.New("the URL is not valid")

// ErrNotFound is returned if the nar or narinfo were not found.
ErrNotFound = errors.New("not found")
Expand All @@ -43,16 +40,16 @@ var (

// Cache represents the upstream cache service.
type Cache struct {
hostName string
url *url.URL
logger log15.Logger
priority uint64
publicKeys []signature.PublicKey
}

func New(logger log15.Logger, hostName string, pubKeys []string) (Cache, error) {
c := Cache{logger: logger, hostName: hostName}
func New(logger log15.Logger, u *url.URL, pubKeys []string) (Cache, error) {
c := Cache{logger: logger, url: u}

if err := c.validateHostname(hostName); err != nil {
if err := c.validateURL(u); err != nil {
return c, err
}

Expand All @@ -67,7 +64,7 @@ func New(logger log15.Logger, hostName string, pubKeys []string) (Cache, error)

priority, err := c.parsePriority()
if err != nil {
return c, fmt.Errorf("error parsing the priority for %q: %w", hostName, err)
return c, fmt.Errorf("error parsing the priority for %q: %w", u, err)
}

c.priority = priority
Expand All @@ -76,11 +73,11 @@ func New(logger log15.Logger, hostName string, pubKeys []string) (Cache, error)
}

// GetHostname returns the hostname.
func (c Cache) GetHostname() string { return c.hostName }
func (c Cache) GetHostname() string { return c.url.Hostname() }

// GetNarInfo returns a parsed NarInfo from the cache server.
func (c Cache) GetNarInfo(ctx context.Context, hash string) (*narinfo.NarInfo, error) {
r, err := http.NewRequestWithContext(ctx, "GET", c.getHostnameWithScheme()+helper.NarInfoURLPath(hash), nil)
r, err := http.NewRequestWithContext(ctx, "GET", c.url.JoinPath(helper.NarInfoURLPath(hash)).String(), nil)
if err != nil {
return nil, fmt.Errorf("error creating a new request: %w", err)
}
Expand Down Expand Up @@ -128,7 +125,7 @@ func (c Cache) GetNarInfo(ctx context.Context, hash string) (*narinfo.NarInfo, e
func (c Cache) GetNar(ctx context.Context, hash, compression string) (int64, io.ReadCloser, error) {
log := c.logger.New("hash", hash, "compression", compression)

r, err := http.NewRequestWithContext(ctx, "GET", c.getHostnameWithScheme()+helper.NarURLPath(hash, compression), nil)
r, err := http.NewRequestWithContext(ctx, "GET", c.url.JoinPath(helper.NarURLPath(hash, compression)).String(), nil)
if err != nil {
return 0, nil, fmt.Errorf("error creating a new request: %w", err)
}
Expand Down Expand Up @@ -168,23 +165,14 @@ func (c Cache) GetNar(ctx context.Context, hash, compression string) (int64, io.
// GetPriority returns the priority of this upstream cache.
func (c Cache) GetPriority() uint64 { return c.priority }

func (c Cache) getHostnameWithScheme() string {
scheme := "https"
if strings.HasPrefix(c.hostName, "127.0.0.1") {
scheme = "http"
}

return scheme + "://" + c.hostName
}

func (c Cache) parsePriority() (uint64, error) {
// TODO: Should probably pass context around and have things like logger in the context
ctx := context.Background()

ctx, cancelFn := context.WithTimeout(ctx, 3*time.Second)
defer cancelFn()

r, err := http.NewRequestWithContext(ctx, "GET", c.getHostnameWithScheme()+"/nix-cache-info", nil)
r, err := http.NewRequestWithContext(ctx, "GET", c.url.JoinPath("/nix-cache-info").String(), nil)
if err != nil {
return 0, fmt.Errorf("error creating a new request: %w", err)
}
Expand All @@ -210,23 +198,17 @@ func (c Cache) parsePriority() (uint64, error) {
return nci.Priority, nil
}

func (c Cache) validateHostname(hostName string) error {
if hostName == "" {
c.logger.Error("given hostname is empty", "hostName", hostName)

return ErrHostnameRequired
}

if strings.Contains(hostName, "://") {
c.logger.Error("hostname should not contain a scheme", "hostName", hostName)
func (c Cache) validateURL(u *url.URL) error {
if u == nil {
c.logger.Error("given url is nil", "url", u)

return ErrHostnameMustNotContainScheme
return ErrURLRequired
}

if strings.Contains(hostName, "/") {
c.logger.Error("hostname should not contain a path", "hostName", hostName)
if u.Scheme == "" {
c.logger.Error("hostname should not contain a scheme", "url", u)

return ErrHostnameMustNotContainPath
return ErrURLMustContainScheme
}

return nil
Expand Down
Loading

0 comments on commit 9045e70

Please sign in to comment.