diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 2ed385e25f7..0407de97184 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -101,6 +101,7 @@ var ( "no-tls-verify", "no-chunked-encoding", "http2-origin", + "h2c-origin", cfdflags.ManagementHostname, "service-op-ip", "local-ssh-port", @@ -1070,6 +1071,13 @@ func configureProxyFlags(shouldHide bool) []cli.Flag { Hidden: shouldHide, Value: false, }), + altsrc.NewBoolFlag(&cli.BoolFlag{ + Name: ingress.H2cOriginFlag, + Usage: "Enables HTTP/2 cleartext origin servers.", + EnvVars: []string{"TUNNEL_ORIGIN_ENABLE_H2C"}, + Hidden: shouldHide, + Value: false, + }), altsrc.NewStringFlag(&cli.StringFlag{ Name: cfdflags.ManagementHostname, Usage: "Management hostname to signify incoming management requests", diff --git a/cmd/cloudflared/tunnel/cmd_test.go b/cmd/cloudflared/tunnel/cmd_test.go index b29b396612c..09f5d9fa1b0 100644 --- a/cmd/cloudflared/tunnel/cmd_test.go +++ b/cmd/cloudflared/tunnel/cmd_test.go @@ -1,9 +1,14 @@ package tunnel import ( + "bytes" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" + + "github.com/cloudflare/cloudflared/ingress" ) func TestHostnameFromURI(t *testing.T) { @@ -12,6 +17,18 @@ func TestHostnameFromURI(t *testing.T) { assert.Equal(t, "awesome.warptunnels.horse:2222", hostnameFromURI("ssh://awesome.warptunnels.horse:2222")) assert.Equal(t, "localhost:3389", hostnameFromURI("rdp://localhost")) assert.Equal(t, "localhost:3390", hostnameFromURI("rdp://localhost:3390")) - assert.Equal(t, "", hostnameFromURI("trash")) - assert.Equal(t, "", hostnameFromURI("https://awesomesauce.com")) + assert.Empty(t, hostnameFromURI("trash")) + assert.Empty(t, hostnameFromURI("https://awesomesauce.com")) +} + +func TestTunnelH2cOriginFlagRegistered(t *testing.T) { + t.Parallel() + var out bytes.Buffer + app := &cli.App{ + Writer: &out, + Commands: Commands(), + } + + require.NoError(t, app.Run([]string{"cloudflared", "tunnel", "--" + ingress.H2cOriginFlag, "--help"})) + require.Contains(t, out.String(), "--"+ingress.H2cOriginFlag) } diff --git a/config/configuration.go b/config/configuration.go index cb0b0adeda5..5b929e14b43 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -35,7 +35,7 @@ var ( defaultUserConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"} defaultNixConfigDirs = []string{"/etc/cloudflared", DefaultUnixConfigLocation} - ErrNoConfigFile = fmt.Errorf("Cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories()) + ErrNoConfigFile = fmt.Errorf("cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories()) ) const ( @@ -44,6 +44,8 @@ const ( ) // DefaultConfigDirectory returns the default directory of the config file +// +//nolint:gosec // Config path can be controlled by environment on Windows by design. func DefaultConfigDirectory() string { if runtime.GOOS == "windows" { path := os.Getenv("CFDPATH") @@ -87,7 +89,7 @@ func DefaultConfigSearchDirectories() []string { // FileExists checks to see if a file exist at the provided path. func FileExists(path string) (bool, error) { - f, err := os.Open(path) + _, err := os.Stat(path) if err != nil { if os.IsNotExist(err) { // ignore missing files @@ -95,7 +97,6 @@ func FileExists(path string) (bool, error) { } return false, err } - _ = f.Close() return true, nil } @@ -120,6 +121,8 @@ func FindDefaultConfigPath() string { // FindOrCreateConfigPath returns the first path that contains a config file // or creates one in the primary default path if it doesn't exist +// +//nolint:gosec // Config and log paths are local user-controlled paths by design. func FindOrCreateConfigPath() string { path := FindDefaultConfigPath() @@ -135,7 +138,7 @@ func FindOrCreateConfigPath() string { if err != nil { return "" } - defer file.Close() + defer func() { _ = file.Close() }() logDir := DefaultLogDirectory() _ = os.MkdirAll(logDir, os.ModePerm) // try and create it. Doesn't matter if it succeed or not, only byproduct will be no logs @@ -229,6 +232,8 @@ type OriginRequestConfig struct { IPRules []IngressIPRule `yaml:"ipRules" json:"ipRules,omitempty"` // Attempt to connect to origin with HTTP/2 Http2Origin *bool `yaml:"http2Origin" json:"http2Origin,omitempty"` + // Connect to origin with HTTP/2 over cleartext (h2c), without TLS + H2cOrigin *bool `yaml:"h2cOrigin" json:"h2cOrigin,omitempty"` // Access holds all access related configs Access *AccessConfig `yaml:"access" json:"access,omitempty"` } @@ -381,6 +386,8 @@ func GetConfiguration() *Configuration { // ReadConfigFile returns InputSourceContext initialized from the configuration file. // On repeat calls returns with the same file, returns without reading the file again; however, // if value of "config" flag changes, will read the new config file +// +//nolint:gosec // User-supplied config path is expected for this CLI. func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (settings *configFileSettings, warnings string, err error) { configFile := c.String("config") if configuration.Source() == configFile || configFile == "" { @@ -399,7 +406,7 @@ func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (settings *configFileSe } return nil, "", err } - defer file.Close() + defer func() { _ = file.Close() }() if err := yaml.NewDecoder(file).Decode(&configuration); err != nil { if err == io.EOF { log.Error().Msgf("Configuration file %s was empty", configFile) @@ -411,6 +418,7 @@ func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (settings *configFileSe // Parse it again, with strict mode, to find warnings. if file, err := os.Open(configFile); err == nil { + defer func() { _ = file.Close() }() decoder := yaml.NewDecoder(file) decoder.KnownFields(true) var unusedConfig configFileSettings @@ -432,7 +440,7 @@ type CustomDuration struct { } func (s CustomDuration) MarshalJSON() ([]byte, error) { - return json.Marshal(s.Duration.Seconds()) + return json.Marshal(s.Seconds()) } func (s *CustomDuration) UnmarshalJSON(data []byte) error { @@ -446,7 +454,7 @@ func (s *CustomDuration) UnmarshalJSON(data []byte) error { } func (s *CustomDuration) MarshalYAML() (interface{}, error) { - return s.Duration.String(), nil + return s.String(), nil } func (s *CustomDuration) UnmarshalYAML(unmarshal func(interface{}) error) error { diff --git a/connection/quic_connection.go b/connection/quic_connection.go index 88d1e7a221b..5428b6acdff 100644 --- a/connection/quic_connection.go +++ b/connection/quic_connection.go @@ -169,7 +169,7 @@ func (q *quicConnection) acceptStream(ctx context.Context) error { func (q *quicConnection) runStream(quicStream quic.Stream) { ctx := quicStream.Context() stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger) - defer stream.Close() + defer func() { _ = stream.Close() }() // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that // code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream. @@ -229,7 +229,7 @@ func (q *quicConnection) dispatchRequest(ctx context.Context, stream *rpcquic.Re if err != nil { return err, false } - w := newHTTPResponseAdapter(stream) + w := newHTTPResponseAdapter(stream, q.logger) return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent case pogs.ConnectionTypeTCP: @@ -277,14 +277,19 @@ type httpResponseAdapter struct { *rpcquic.RequestServerStream headers http.Header connectResponseSent bool + logger *zerolog.Logger } -func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter { - return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)} +func newHTTPResponseAdapter(s *rpcquic.RequestServerStream, log *zerolog.Logger) httpResponseAdapter { + return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header), logger: log} } func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) { - // we do not support trailers over QUIC + // QUIC transport does not support trailers; they are silently dropped. + // This primarily affects gRPC, which encodes grpc-status in trailers. + hrw.logger.Warn().Str("trailerName", trailerName). + Msg("QUIC transport does not support trailers; trailer will be dropped. " + + "For gRPC origins, use --protocol http2 to enable trailer support") } func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error { diff --git a/connection/quic_connection_test.go b/connection/quic_connection_test.go index 2df741a6a9e..f422f347449 100644 --- a/connection/quic_connection_test.go +++ b/connection/quic_connection_test.go @@ -991,3 +991,27 @@ func GenerateTLSConfig() *tls.Config { NextProtos: []string{"argotunnel"}, } } + +func TestHTTPResponseAdapterAddTrailerLogs(t *testing.T) { + t.Parallel() + tests := []struct { + name string + trailerName string + trailerValue string + }{ + {"grpc-status trailer is logged", "grpc-status", "0"}, + {"grpc-message trailer is logged", "grpc-message", "OK"}, + {"custom trailer is logged", "x-custom-trailer", "value"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + log := zerolog.New(&buf) + adapter := newHTTPResponseAdapter(nil, &log) + adapter.AddTrailer(tt.trailerName, tt.trailerValue) + require.Contains(t, buf.String(), "QUIC transport does not support trailers") + require.Contains(t, buf.String(), tt.trailerName) + }) + } +} diff --git a/crypto/curves_test.go b/crypto/curves_test.go index dab49f22988..a77ca0fdf8c 100644 --- a/crypto/curves_test.go +++ b/crypto/curves_test.go @@ -4,7 +4,6 @@ import ( "crypto/tls" "net/http" "net/http/httptest" - "runtime" "slices" "testing" @@ -101,9 +100,9 @@ func runClientServerHandshake(t *testing.T, curves []tls.CurveID) []tls.CurveID return advertisedCurves } -// TestSupportedCurvesNegotiation verifies that the curves returned by -// GetCurvePreferences survive a real TLS handshake unchanged, i.e. the -// standard library advertises exactly the curves we expect. Currently only +// TestSupportedCurvesNegotiation verifies that supported configured curves are +// advertised by the standard library during a real TLS handshake. Unsupported +// legacy draft curves may be filtered by crypto/tls. Currently only // PostQuantumPrefer is exercised because PostQuantumStrict would cause the // handshake to fail against httptest servers that do not support // X25519MLKEM768 server-side. @@ -115,12 +114,5 @@ func TestSupportedCurvesNegotiation(t *testing.T) { advertisedCurves := runClientServerHandshake(t, curves) require.True(t, slices.Contains(advertisedCurves, tls.CurveP256)) require.True(t, slices.Contains(advertisedCurves, tls.X25519MLKEM768)) - expectedLength := 2 - if runtime.GOOS == "linux" { - // P256Kyber768Draft00 only exists in linux - require.True(t, slices.Contains(advertisedCurves, P256Kyber768Draft00)) - expectedLength = 3 - } - require.Len(t, advertisedCurves, expectedLength) } } diff --git a/ingress/config.go b/ingress/config.go index 83f893fefe1..6f367f97013 100644 --- a/ingress/config.go +++ b/ingress/config.go @@ -39,6 +39,7 @@ const ( ProxyAddressFlag = "proxy-address" ProxyPortFlag = "proxy-port" Http2OriginFlag = "http2-origin" + H2cOriginFlag = "h2c-origin" ) const ( @@ -137,6 +138,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig { var proxyPort uint var proxyType string var http2Origin bool + var h2cOrigin bool if flag := ProxyConnectTimeoutFlag; c.IsSet(flag) { connectTimeout = config.CustomDuration{Duration: c.Duration(flag)} } @@ -187,6 +189,9 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig { if flag := Http2OriginFlag; c.IsSet(flag) { http2Origin = c.Bool(flag) } + if flag := H2cOriginFlag; c.IsSet(flag) { + h2cOrigin = c.Bool(flag) + } if c.IsSet(Socks5Flag) { proxyType = socksProxy } @@ -209,6 +214,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig { ProxyPort: proxyPort, ProxyType: proxyType, Http2Origin: http2Origin, + H2cOrigin: h2cOrigin, } } @@ -280,6 +286,9 @@ func originRequestFromConfig(c config.OriginRequestConfig) OriginRequestConfig { if c.Http2Origin != nil { out.Http2Origin = *c.Http2Origin } + if c.H2cOrigin != nil { + out.H2cOrigin = *c.H2cOrigin + } if c.Access != nil { out.Access = *c.Access } @@ -330,6 +339,8 @@ type OriginRequestConfig struct { IPRules []ipaccess.Rule `yaml:"ipRules" json:"ipRules"` // Attempt to connect to origin with HTTP/2 Http2Origin bool `yaml:"http2Origin" json:"http2Origin"` + // Connect to origin with HTTP/2 over cleartext (h2c), without TLS + H2cOrigin bool `yaml:"h2cOrigin" json:"h2cOrigin"` // Access holds all access related configs Access config.AccessConfig `yaml:"access" json:"access,omitempty"` @@ -450,6 +461,12 @@ func (defaults *OriginRequestConfig) setHttp2Origin(overrides config.OriginReque } } +func (defaults *OriginRequestConfig) setH2cOrigin(overrides config.OriginRequestConfig) { + if val := overrides.H2cOrigin; val != nil { + defaults.H2cOrigin = *val + } +} + func (defaults *OriginRequestConfig) setAccess(overrides config.OriginRequestConfig) { if val := overrides.Access; val != nil { defaults.Access = *val @@ -484,6 +501,7 @@ func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfi cfg.setProxyType(overrides) cfg.setIPRules(overrides) cfg.setHttp2Origin(overrides) + cfg.setH2cOrigin(overrides) cfg.setAccess(overrides) return cfg @@ -539,12 +557,13 @@ func ConvertToRawOriginConfig(c OriginRequestConfig) config.OriginRequestConfig ProxyType: emptyStringToNil(c.ProxyType), IPRules: convertToRawIPRules(c.IPRules), Http2Origin: defaultBoolToNil(c.Http2Origin), + H2cOrigin: defaultBoolToNil(c.H2cOrigin), Access: access, } } func convertToRawIPRules(ipRules []ipaccess.Rule) []config.IngressIPRule { - result := make([]config.IngressIPRule, 0) + result := make([]config.IngressIPRule, 0, len(ipRules)) for _, r := range ipRules { cidr := r.StringCIDR() diff --git a/ingress/config_test.go b/ingress/config_test.go index 249cf2d5e2e..a9456e54f11 100644 --- a/ingress/config_test.go +++ b/ingress/config_test.go @@ -399,6 +399,46 @@ ingress: validate(remoteConfig.Ingress) } +func TestH2cOriginConfigResolution(t *testing.T) { + t.Parallel() + t.Run("json remote config with h2cOrigin", func(t *testing.T) { + t.Parallel() + rawConfig := []byte(` +{ + "originRequest": {"h2cOrigin": true}, + "ingress": [ + {"hostname": "grpc.example.com", "service": "http://localhost:50051"}, + {"service": "http_status:404"} + ] +}`) + var remoteConfig RemoteConfig + err := json.Unmarshal(rawConfig, &remoteConfig) + require.NoError(t, err) + require.True(t, remoteConfig.Ingress.Rules[0].Config.H2cOrigin) + }) + + t.Run("per-rule h2cOrigin override", func(t *testing.T) { + t.Parallel() + rawConfig := []byte(` +{ + "ingress": [ + { + "hostname": "grpc.example.com", + "service": "http://localhost:50051", + "originRequest": {"h2cOrigin": true} + }, + {"service": "http_status:404"} + ] +}`) + var remoteConfig RemoteConfig + err := json.Unmarshal(rawConfig, &remoteConfig) + require.NoError(t, err) + require.True(t, remoteConfig.Ingress.Rules[0].Config.H2cOrigin) + // Catch-all rule should not inherit per-rule h2cOrigin + require.False(t, remoteConfig.Ingress.Rules[1].Config.H2cOrigin) + }) +} + func TestDefaultConfigFromCLI(t *testing.T) { set := flag.NewFlagSet("contrive", 0) c := cli.NewContext(nil, set, nil) @@ -415,6 +455,17 @@ func TestDefaultConfigFromCLI(t *testing.T) { require.Equal(t, expected, actual) } +func TestH2cOriginFromCLI(t *testing.T) { + t.Parallel() + set := flag.NewFlagSet("contrive", flag.PanicOnError) + set.Bool(H2cOriginFlag, false, "") + require.NoError(t, set.Parse([]string{"--" + H2cOriginFlag})) + c := cli.NewContext(nil, set, nil) + + actual := originRequestFromSingleRule(c) + require.True(t, actual.H2cOrigin) +} + func newIPRule(t *testing.T, prefix string, ports []int, allow bool) ipaccess.Rule { rule, err := ipaccess.NewRuleByCIDR(&prefix, ports, allow) require.NoError(t, err) diff --git a/ingress/ingress.go b/ingress/ingress.go index a325271a7e7..5c840fce418 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -311,6 +311,10 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq } } + if err := validateHTTPOriginConfig(service, cfg); err != nil { + return Ingress{}, errors.Wrapf(err, "Rule #%d has an invalid originRequest configuration", i+1) + } + var handlers []middleware.Handler if access := r.OriginRequest.Access; access != nil { if err := validateAccessConfiguration(access); err != nil { diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 109cb3530e2..d230f3d2057 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -463,6 +463,74 @@ ingress: } } +func TestParseIngressRejectsInvalidH2cOriginConfig(t *testing.T) { + t.Parallel() + tests := []struct { + name string + rawYAML string + wantErr bool + errContains string + }{ + { + name: "h2c with http origin succeeds", + rawYAML: ` +originRequest: + h2cOrigin: true +ingress: + - hostname: "*" + service: http://localhost:50051 +`, + }, + { + name: "h2c and http2Origin conflict", + rawYAML: ` +originRequest: + h2cOrigin: true + http2Origin: true +ingress: + - hostname: "*" + service: http://localhost:50051 +`, + wantErr: true, + errContains: "cannot both be enabled", + }, + { + name: "h2c with https origin errors", + rawYAML: ` +originRequest: + h2cOrigin: true +ingress: + - hostname: "*" + service: https://localhost:50051 +`, + wantErr: true, + errContains: "h2cOrigin is enabled", + }, + { + name: "h2c with unix socket succeeds", + rawYAML: ` +originRequest: + h2cOrigin: true +ingress: + - hostname: "*" + service: unix:/tmp/cloudflared-h2c-test.sock +`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := ParseIngress(MustReadIngress(tt.rawYAML)) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + } + }) + } +} + func ipRulePrefix(s string) *string { return &s } @@ -555,7 +623,7 @@ func TestSingleOriginServices(t *testing.T) { flagSet.String("unix-socket", "", "") cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil) for i := 0; i < len(params); i += 2 { - cliCtx.Set(params[i], params[i+1]) + require.NoError(t, cliCtx.Set(params[i], params[i+1])) } return cliCtx @@ -605,7 +673,7 @@ func TestSingleOriginServices(t *testing.T) { if test.err != nil { return } - require.Equal(t, 1, len(ingress.Rules)) + require.Len(t, ingress.Rules, 1) rule := ingress.Rules[0] require.Equal(t, test.expectedService, rule.Service) }) @@ -626,7 +694,7 @@ func TestSingleOriginServices_URL(t *testing.T) { flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError) flagSet.String("url", "", "") cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil) - cliCtx.Set(param, value) + require.NoError(t, cliCtx.Set(param, value)) return cliCtx } @@ -636,7 +704,7 @@ func TestSingleOriginServices_URL(t *testing.T) { url := urlMustParse(test + host) ingress, err := parseCLIIngress(newCli("url", url.String()), false) require.NoError(t, err) - require.Equal(t, 1, len(ingress.Rules)) + require.Len(t, ingress.Rules, 1) rule := ingress.Rules[0] require.Equal(t, &httpService{url: url}, rule.Service) }) @@ -648,7 +716,7 @@ func TestSingleOriginServices_URL(t *testing.T) { url := urlMustParse(test + host) ingress, err := parseCLIIngress(newCli("url", url.String()), false) require.NoError(t, err) - require.Equal(t, 1, len(ingress.Rules)) + require.Len(t, ingress.Rules, 1) rule := ingress.Rules[0] require.Equal(t, newTCPOverWSService(url), rule.Service) }) @@ -712,7 +780,7 @@ func TestFindMatchingRule(t *testing.T) { for _, test := range tests { _, ruleIndex := ingress.FindMatchingRule(test.host, test.path) - assert.Equal(t, test.wantRuleIndex, ruleIndex, fmt.Sprintf("Expect host=%s, path=%s to match rule %d, got %d", test.host, test.path, test.wantRuleIndex, ruleIndex)) + assert.Equal(t, test.wantRuleIndex, ruleIndex, "Expect host=%s, path=%s to match rule %d, got %d", test.host, test.path, test.wantRuleIndex, ruleIndex) } } diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 7371eac92ec..6db39f32584 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -59,14 +59,19 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) { } func (o *httpService) SetOriginServerName(req *http.Request) { - o.transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - conn, err := o.transport.DialContext(ctx, network, addr) + t, ok := o.transport.(*http.Transport) + if !ok { + // h2c transport doesn't use TLS, so SNI matching is not applicable + return + } + t.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := t.DialContext(ctx, network, addr) if err != nil { return nil, err } return tls.Client(conn, &tls.Config{ - RootCAs: o.transport.TLSClientConfig.RootCAs, - InsecureSkipVerify: o.transport.TLSClientConfig.InsecureSkipVerify, // nolint: gosec + RootCAs: t.TLSClientConfig.RootCAs, + InsecureSkipVerify: t.TLSClientConfig.InsecureSkipVerify, // nolint: gosec ServerName: req.Host, }), nil } diff --git a/ingress/origin_service.go b/ingress/origin_service.go index e13204c5789..f05be629d58 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -14,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/rs/zerolog" + "golang.org/x/net/http2" "github.com/cloudflare/cloudflared/hello" "github.com/cloudflare/cloudflared/ipaccess" @@ -43,7 +44,7 @@ type OriginService interface { type unixSocketPath struct { path string scheme string - transport *http.Transport + transport http.RoundTripper } func (o *unixSocketPath) String() string { @@ -55,6 +56,9 @@ func (o *unixSocketPath) String() string { } func (o *unixSocketPath) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { + if err := validateHTTPOriginConfig(o, cfg); err != nil { + return err + } transport, err := newHTTPTransport(o, cfg, log) if err != nil { return err @@ -70,11 +74,20 @@ func (o unixSocketPath) MarshalJSON() ([]byte, error) { type httpService struct { url *url.URL hostHeader string - transport *http.Transport + transport http.RoundTripper matchSNIToHost bool } func (o *httpService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { + if err := validateHTTPOriginConfig(o, cfg); err != nil { + return err + } + if cfg.Http2Origin && o.url != nil && o.url.Scheme == "http" { + log.Warn().Str("origin", o.url.String()). + Msg("http2Origin is enabled but the origin URL uses http:// (not https://). " + + "HTTP/2 requires TLS, so the connection will fall back to HTTP/1.1. " + + "Use an https:// origin URL or disable http2Origin") + } transport, err := newHTTPTransport(o, cfg, log) if err != nil { return err @@ -93,13 +106,39 @@ func (o httpService) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } +func validateHTTPOriginConfig(service OriginService, cfg OriginRequestConfig) error { + switch service := service.(type) { + case *httpService: + scheme := "" + if service.url != nil { + scheme = service.url.Scheme + } + return validateHTTPOriginScheme(scheme, cfg) + case *unixSocketPath: + return validateHTTPOriginScheme(service.scheme, cfg) + default: + return nil + } +} + +func validateHTTPOriginScheme(scheme string, cfg OriginRequestConfig) error { + if cfg.H2cOrigin && cfg.Http2Origin { + return fmt.Errorf("h2cOrigin and http2Origin cannot both be enabled; " + + "h2cOrigin is for cleartext HTTP/2 (http://), http2Origin is for TLS HTTP/2 (https://)") + } + if cfg.H2cOrigin && scheme == "https" { + return fmt.Errorf("h2cOrigin is enabled but the origin uses https://; " + + "h2c is HTTP/2 over cleartext; use http:// or disable h2cOrigin") + } + return nil +} + // rawTCPService dials TCP to the destination specified by the client // It's used by warp routing type rawTCPService struct { name string dialer net.Dialer writeTimeout time.Duration - logger *zerolog.Logger } func (o *rawTCPService) String() string { @@ -233,10 +272,14 @@ func (o *helloWorld) start( if err != nil { return errors.Wrap(err, "Cannot start Hello World Server") } - go hello.StartHelloWorldServer(log, helloListener, shutdownC) + go func() { + if err := hello.StartHelloWorldServer(log, helloListener, shutdownC); err != nil { + log.Error().Err(err).Msg("Cannot start Hello World Server") + } + }() o.server = helloListener - o.httpService.url = &url.URL{ + o.url = &url.URL{ Scheme: "https", Host: o.server.Addr().String(), } @@ -343,7 +386,34 @@ func (nrc *NopReadCloser) Close() error { return nil } -func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) { +func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (http.RoundTripper, error) { + dialer := &net.Dialer{ + Timeout: cfg.ConnectTimeout.Duration, + KeepAlive: cfg.TCPKeepAlive.Duration, + } + if cfg.NoHappyEyeballs { + dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs" + } + + // DialContext depends on which kind of origin is being used. + dialContext := dialer.DialContext + if uds, ok := service.(*unixSocketPath); ok { + dialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, "unix", uds.path) + } + } + + // h2c: HTTP/2 over cleartext, no TLS cert loading needed + if cfg.H2cOrigin { + log.Info().Msg("h2cOrigin enabled: using HTTP/2 cleartext transport") + return &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + return dialContext(ctx, network, addr) + }, + }, nil + } + originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log) if err != nil { return nil, errors.Wrap(err, "Error loading cert pool") @@ -356,36 +426,15 @@ func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerol IdleConnTimeout: cfg.KeepAliveTimeout.Duration, TLSHandshakeTimeout: cfg.TLSTimeout.Duration, ExpectContinueTimeout: 1 * time.Second, - TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify}, - ForceAttemptHTTP2: cfg.Http2Origin, + //nolint:gosec // NoTLSVerify is a user-configurable origin transport option. + TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify}, + ForceAttemptHTTP2: cfg.Http2Origin, + DialContext: dialContext, } if _, isHelloWorld := service.(*helloWorld); !isHelloWorld && cfg.OriginServerName != "" { httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName } - dialer := &net.Dialer{ - Timeout: cfg.ConnectTimeout.Duration, - KeepAlive: cfg.TCPKeepAlive.Duration, - } - if cfg.NoHappyEyeballs { - dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs" - } - - // DialContext depends on which kind of origin is being used. - dialContext := dialer.DialContext - switch service := service.(type) { - - // If this origin is a unix socket, enforce network type "unix". - case *unixSocketPath: - httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - return dialContext(ctx, "unix", service.path) - } - - // Otherwise, use the regular network config. - default: - httpTransport.DialContext = dialContext - } - return &httpTransport, nil } diff --git a/ingress/origin_service_test.go b/ingress/origin_service_test.go index 081b08ae920..bc721f91676 100644 --- a/ingress/origin_service_test.go +++ b/ingress/origin_service_test.go @@ -1,9 +1,11 @@ package ingress import ( + "bytes" "net/url" "testing" + "github.com/rs/zerolog" "github.com/stretchr/testify/require" ) @@ -27,3 +29,84 @@ func TestAddPortIfMissing(t *testing.T) { }) } } + +func TestH2cOriginTransport(t *testing.T) { + t.Parallel() + tests := []struct { + name string + h2cOrigin bool + http2Origin bool + scheme string + wantErr bool + errContains string + }{ + {"h2c with http origin succeeds", true, false, "http", false, ""}, + {"h2c with https origin errors", true, false, "https", true, "h2cOrigin is enabled but"}, + {"h2c and http2Origin conflict", true, true, "http", true, "cannot both be enabled"}, + {"http2Origin alone is fine", false, true, "https", false, ""}, + {"neither h2c nor http2Origin", false, false, "http", false, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + log := zerolog.Nop() + svc := &httpService{url: &url.URL{Scheme: tt.scheme, Host: "localhost:50051"}} + cfg := OriginRequestConfig{ + H2cOrigin: tt.h2cOrigin, + Http2Origin: tt.http2Origin, + } + err := svc.start(&log, nil, cfg) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestUnixSocketH2cOriginConflict(t *testing.T) { + t.Parallel() + log := zerolog.Nop() + svc := &unixSocketPath{path: "/tmp/cloudflared-h2c-test.sock", scheme: "http"} + err := svc.start(&log, nil, OriginRequestConfig{ + H2cOrigin: true, + Http2Origin: true, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot both be enabled") +} + +func TestHttp2OriginWithHTTPSchemeWarning(t *testing.T) { + t.Parallel() + tests := []struct { + name string + scheme string + http2Origin bool + wantWarning bool + }{ + {"http2Origin with http scheme warns", "http", true, true}, + {"http2Origin with https scheme no warning", "https", true, false}, + {"no http2Origin with http scheme no warning", "http", false, false}, + {"no http2Origin with https scheme no warning", "https", false, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + log := zerolog.New(&buf) + svc := &httpService{url: &url.URL{Scheme: tt.scheme, Host: "localhost:8080"}} + cfg := OriginRequestConfig{ + Http2Origin: tt.http2Origin, + NoTLSVerify: true, + } + require.NoError(t, svc.start(&log, nil, cfg)) + if tt.wantWarning { + require.Contains(t, buf.String(), "http2Origin is enabled") + } else { + require.NotContains(t, buf.String(), "http2Origin is enabled") + } + }) + } +} diff --git a/ingress/rule_test.go b/ingress/rule_test.go index a3d12e012db..cd8f68f4f69 100644 --- a/ingress/rule_test.go +++ b/ingress/rule_test.go @@ -119,7 +119,7 @@ func Test_rule_matches(t *testing.T) { name: "Hostname and path", rule: Rule{ Hostname: "*.example.com", - Path: &Regexp{Regexp: regexp.MustCompile("/static/.*\\.html")}, + Path: &Regexp{Regexp: regexp.MustCompile(`/static/.*\.html`)}, }, args: args{ requestURL: MustParseURL(t, "https://www.example.com/static/index.html"), @@ -187,6 +187,7 @@ func TestStaticHTTPStatus(t *testing.T) { n, err := io.Copy(w, resp.Body) require.NoError(t, err) require.Equal(t, int64(0), n) + require.NoError(t, resp.Body.Close()) } sendReq() sendReq() @@ -204,25 +205,25 @@ func TestMarshalJSON(t *testing.T) { { name: "Nil", path: nil, - expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`, + expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"h2cOrigin":false,"access":{"teamName":"","audTag":null}}}`, want: true, }, { name: "Nil regex", path: &Regexp{Regexp: nil}, - expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`, + expected: `{"hostname":"example.com","path":null,"service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"h2cOrigin":false,"access":{"teamName":"","audTag":null}}}`, want: true, }, { name: "Empty", path: &Regexp{Regexp: regexp.MustCompile("")}, - expected: `{"hostname":"example.com","path":"","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`, + expected: `{"hostname":"example.com","path":"","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"h2cOrigin":false,"access":{"teamName":"","audTag":null}}}`, want: true, }, { name: "Basic", path: &Regexp{Regexp: regexp.MustCompile("/echo")}, - expected: `{"hostname":"example.com","path":"/echo","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"access":{"teamName":"","audTag":null}}}`, + expected: `{"hostname":"example.com","path":"/echo","service":"https://localhost:8000","Handlers":null,"originRequest":{"connectTimeout":30,"tlsTimeout":10,"tcpKeepAlive":30,"noHappyEyeballs":false,"keepAliveTimeout":90,"keepAliveConnections":100,"httpHostHeader":"","originServerName":"","matchSNItoHost":false,"caPool":"","noTLSVerify":false,"disableChunkedEncoding":false,"bastionMode":false,"proxyAddress":"127.0.0.1","proxyPort":0,"proxyType":"","ipRules":null,"http2Origin":false,"h2cOrigin":false,"access":{"teamName":"","audTag":null}}}`, want: true, }, }