Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cmd/cloudflared/tunnel/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ var (
"no-tls-verify",
"no-chunked-encoding",
"http2-origin",
"h2c-origin",
cfdflags.ManagementHostname,
"service-op-ip",
"local-ssh-port",
Expand Down Expand Up @@ -1070,6 +1071,13 @@ func configureProxyFlags(shouldHide bool) []cli.Flag {
Hidden: shouldHide,
Value: false,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: ingress.H2cOriginFlag,
Usage: "Enables HTTP/2 cleartext origin servers.",
EnvVars: []string{"TUNNEL_ORIGIN_ENABLE_H2C"},
Hidden: shouldHide,
Value: false,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: cfdflags.ManagementHostname,
Usage: "Management hostname to signify incoming management requests",
Expand Down
21 changes: 19 additions & 2 deletions cmd/cloudflared/tunnel/cmd_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package tunnel

import (
"bytes"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"

"github.com/cloudflare/cloudflared/ingress"
)

func TestHostnameFromURI(t *testing.T) {
Expand All @@ -12,6 +17,18 @@ func TestHostnameFromURI(t *testing.T) {
assert.Equal(t, "awesome.warptunnels.horse:2222", hostnameFromURI("ssh://awesome.warptunnels.horse:2222"))
assert.Equal(t, "localhost:3389", hostnameFromURI("rdp://localhost"))
assert.Equal(t, "localhost:3390", hostnameFromURI("rdp://localhost:3390"))
assert.Equal(t, "", hostnameFromURI("trash"))
assert.Equal(t, "", hostnameFromURI("https://awesomesauce.com"))
assert.Empty(t, hostnameFromURI("trash"))
assert.Empty(t, hostnameFromURI("https://awesomesauce.com"))
}

func TestTunnelH2cOriginFlagRegistered(t *testing.T) {
t.Parallel()
var out bytes.Buffer
app := &cli.App{
Writer: &out,
Commands: Commands(),
}

require.NoError(t, app.Run([]string{"cloudflared", "tunnel", "--" + ingress.H2cOriginFlag, "--help"}))
require.Contains(t, out.String(), "--"+ingress.H2cOriginFlag)
}
22 changes: 15 additions & 7 deletions config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ var (
defaultUserConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"}
defaultNixConfigDirs = []string{"/etc/cloudflared", DefaultUnixConfigLocation}

ErrNoConfigFile = fmt.Errorf("Cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories())
ErrNoConfigFile = fmt.Errorf("cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories())
)

const (
Expand All @@ -44,6 +44,8 @@ const (
)

// DefaultConfigDirectory returns the default directory of the config file
//
//nolint:gosec // Config path can be controlled by environment on Windows by design.
func DefaultConfigDirectory() string {
if runtime.GOOS == "windows" {
path := os.Getenv("CFDPATH")
Expand Down Expand Up @@ -87,15 +89,14 @@ func DefaultConfigSearchDirectories() []string {

// FileExists checks to see if a file exist at the provided path.
func FileExists(path string) (bool, error) {
f, err := os.Open(path)
_, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
// ignore missing files
return false, nil
}
return false, err
}
_ = f.Close()
return true, nil
}

Expand All @@ -120,6 +121,8 @@ func FindDefaultConfigPath() string {

// FindOrCreateConfigPath returns the first path that contains a config file
// or creates one in the primary default path if it doesn't exist
//
//nolint:gosec // Config and log paths are local user-controlled paths by design.
func FindOrCreateConfigPath() string {
path := FindDefaultConfigPath()

Expand All @@ -135,7 +138,7 @@ func FindOrCreateConfigPath() string {
if err != nil {
return ""
}
defer file.Close()
defer func() { _ = file.Close() }()

logDir := DefaultLogDirectory()
_ = os.MkdirAll(logDir, os.ModePerm) // try and create it. Doesn't matter if it succeed or not, only byproduct will be no logs
Expand Down Expand Up @@ -229,6 +232,8 @@ type OriginRequestConfig struct {
IPRules []IngressIPRule `yaml:"ipRules" json:"ipRules,omitempty"`
// Attempt to connect to origin with HTTP/2
Http2Origin *bool `yaml:"http2Origin" json:"http2Origin,omitempty"`
// Connect to origin with HTTP/2 over cleartext (h2c), without TLS
H2cOrigin *bool `yaml:"h2cOrigin" json:"h2cOrigin,omitempty"`
// Access holds all access related configs
Access *AccessConfig `yaml:"access" json:"access,omitempty"`
}
Expand Down Expand Up @@ -381,6 +386,8 @@ func GetConfiguration() *Configuration {
// ReadConfigFile returns InputSourceContext initialized from the configuration file.
// On repeat calls returns with the same file, returns without reading the file again; however,
// if value of "config" flag changes, will read the new config file
//
//nolint:gosec // User-supplied config path is expected for this CLI.
func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (settings *configFileSettings, warnings string, err error) {
configFile := c.String("config")
if configuration.Source() == configFile || configFile == "" {
Expand All @@ -399,7 +406,7 @@ func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (settings *configFileSe
}
return nil, "", err
}
defer file.Close()
defer func() { _ = file.Close() }()
if err := yaml.NewDecoder(file).Decode(&configuration); err != nil {
if err == io.EOF {
log.Error().Msgf("Configuration file %s was empty", configFile)
Expand All @@ -411,6 +418,7 @@ func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (settings *configFileSe

// Parse it again, with strict mode, to find warnings.
if file, err := os.Open(configFile); err == nil {
defer func() { _ = file.Close() }()
decoder := yaml.NewDecoder(file)
decoder.KnownFields(true)
var unusedConfig configFileSettings
Expand All @@ -432,7 +440,7 @@ type CustomDuration struct {
}

func (s CustomDuration) MarshalJSON() ([]byte, error) {
return json.Marshal(s.Duration.Seconds())
return json.Marshal(s.Seconds())
}

func (s *CustomDuration) UnmarshalJSON(data []byte) error {
Expand All @@ -446,7 +454,7 @@ func (s *CustomDuration) UnmarshalJSON(data []byte) error {
}

func (s *CustomDuration) MarshalYAML() (interface{}, error) {
return s.Duration.String(), nil
return s.String(), nil
}

func (s *CustomDuration) UnmarshalYAML(unmarshal func(interface{}) error) error {
Expand Down
15 changes: 10 additions & 5 deletions connection/quic_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (q *quicConnection) acceptStream(ctx context.Context) error {
func (q *quicConnection) runStream(quicStream quic.Stream) {
ctx := quicStream.Context()
stream := cfdquic.NewSafeStreamCloser(quicStream, q.streamWriteTimeout, q.logger)
defer stream.Close()
defer func() { _ = stream.Close() }()

// we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that
// code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream.
Expand Down Expand Up @@ -229,7 +229,7 @@ func (q *quicConnection) dispatchRequest(ctx context.Context, stream *rpcquic.Re
if err != nil {
return err, false
}
w := newHTTPResponseAdapter(stream)
w := newHTTPResponseAdapter(stream, q.logger)
return originProxy.ProxyHTTP(&w, tracedReq, request.Type == pogs.ConnectionTypeWebsocket), w.connectResponseSent

case pogs.ConnectionTypeTCP:
Expand Down Expand Up @@ -277,14 +277,19 @@ type httpResponseAdapter struct {
*rpcquic.RequestServerStream
headers http.Header
connectResponseSent bool
logger *zerolog.Logger
}

func newHTTPResponseAdapter(s *rpcquic.RequestServerStream) httpResponseAdapter {
return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header)}
func newHTTPResponseAdapter(s *rpcquic.RequestServerStream, log *zerolog.Logger) httpResponseAdapter {
return httpResponseAdapter{RequestServerStream: s, headers: make(http.Header), logger: log}
}

func (hrw *httpResponseAdapter) AddTrailer(trailerName, trailerValue string) {
// we do not support trailers over QUIC
// QUIC transport does not support trailers; they are silently dropped.
// This primarily affects gRPC, which encodes grpc-status in trailers.
hrw.logger.Warn().Str("trailerName", trailerName).
Msg("QUIC transport does not support trailers; trailer will be dropped. " +
"For gRPC origins, use --protocol http2 to enable trailer support")
}

func (hrw *httpResponseAdapter) WriteRespHeaders(status int, header http.Header) error {
Expand Down
24 changes: 24 additions & 0 deletions connection/quic_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -991,3 +991,27 @@ func GenerateTLSConfig() *tls.Config {
NextProtos: []string{"argotunnel"},
}
}

func TestHTTPResponseAdapterAddTrailerLogs(t *testing.T) {
t.Parallel()
tests := []struct {
name string
trailerName string
trailerValue string
}{
{"grpc-status trailer is logged", "grpc-status", "0"},
{"grpc-message trailer is logged", "grpc-message", "OK"},
{"custom trailer is logged", "x-custom-trailer", "value"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var buf bytes.Buffer
log := zerolog.New(&buf)
adapter := newHTTPResponseAdapter(nil, &log)
adapter.AddTrailer(tt.trailerName, tt.trailerValue)
require.Contains(t, buf.String(), "QUIC transport does not support trailers")
require.Contains(t, buf.String(), tt.trailerName)
})
}
}
14 changes: 3 additions & 11 deletions crypto/curves_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"crypto/tls"
"net/http"
"net/http/httptest"
"runtime"
"slices"
"testing"

Expand Down Expand Up @@ -101,9 +100,9 @@ func runClientServerHandshake(t *testing.T, curves []tls.CurveID) []tls.CurveID
return advertisedCurves
}

// TestSupportedCurvesNegotiation verifies that the curves returned by
// GetCurvePreferences survive a real TLS handshake unchanged, i.e. the
// standard library advertises exactly the curves we expect. Currently only
// TestSupportedCurvesNegotiation verifies that supported configured curves are
// advertised by the standard library during a real TLS handshake. Unsupported
// legacy draft curves may be filtered by crypto/tls. Currently only
// PostQuantumPrefer is exercised because PostQuantumStrict would cause the
// handshake to fail against httptest servers that do not support
// X25519MLKEM768 server-side.
Expand All @@ -115,12 +114,5 @@ func TestSupportedCurvesNegotiation(t *testing.T) {
advertisedCurves := runClientServerHandshake(t, curves)
require.True(t, slices.Contains(advertisedCurves, tls.CurveP256))
require.True(t, slices.Contains(advertisedCurves, tls.X25519MLKEM768))
expectedLength := 2
if runtime.GOOS == "linux" {
// P256Kyber768Draft00 only exists in linux
require.True(t, slices.Contains(advertisedCurves, P256Kyber768Draft00))
expectedLength = 3
}
require.Len(t, advertisedCurves, expectedLength)
}
}
21 changes: 20 additions & 1 deletion ingress/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const (
ProxyAddressFlag = "proxy-address"
ProxyPortFlag = "proxy-port"
Http2OriginFlag = "http2-origin"
H2cOriginFlag = "h2c-origin"
)

const (
Expand Down Expand Up @@ -137,6 +138,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
var proxyPort uint
var proxyType string
var http2Origin bool
var h2cOrigin bool
if flag := ProxyConnectTimeoutFlag; c.IsSet(flag) {
connectTimeout = config.CustomDuration{Duration: c.Duration(flag)}
}
Expand Down Expand Up @@ -187,6 +189,9 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
if flag := Http2OriginFlag; c.IsSet(flag) {
http2Origin = c.Bool(flag)
}
if flag := H2cOriginFlag; c.IsSet(flag) {
h2cOrigin = c.Bool(flag)
}
if c.IsSet(Socks5Flag) {
proxyType = socksProxy
}
Expand All @@ -209,6 +214,7 @@ func originRequestFromSingleRule(c *cli.Context) OriginRequestConfig {
ProxyPort: proxyPort,
ProxyType: proxyType,
Http2Origin: http2Origin,
H2cOrigin: h2cOrigin,
}
}

Expand Down Expand Up @@ -280,6 +286,9 @@ func originRequestFromConfig(c config.OriginRequestConfig) OriginRequestConfig {
if c.Http2Origin != nil {
out.Http2Origin = *c.Http2Origin
}
if c.H2cOrigin != nil {
out.H2cOrigin = *c.H2cOrigin
}
if c.Access != nil {
out.Access = *c.Access
}
Expand Down Expand Up @@ -330,6 +339,8 @@ type OriginRequestConfig struct {
IPRules []ipaccess.Rule `yaml:"ipRules" json:"ipRules"`
// Attempt to connect to origin with HTTP/2
Http2Origin bool `yaml:"http2Origin" json:"http2Origin"`
// Connect to origin with HTTP/2 over cleartext (h2c), without TLS
H2cOrigin bool `yaml:"h2cOrigin" json:"h2cOrigin"`

// Access holds all access related configs
Access config.AccessConfig `yaml:"access" json:"access,omitempty"`
Expand Down Expand Up @@ -450,6 +461,12 @@ func (defaults *OriginRequestConfig) setHttp2Origin(overrides config.OriginReque
}
}

func (defaults *OriginRequestConfig) setH2cOrigin(overrides config.OriginRequestConfig) {
if val := overrides.H2cOrigin; val != nil {
defaults.H2cOrigin = *val
}
}

func (defaults *OriginRequestConfig) setAccess(overrides config.OriginRequestConfig) {
if val := overrides.Access; val != nil {
defaults.Access = *val
Expand Down Expand Up @@ -484,6 +501,7 @@ func setConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfi
cfg.setProxyType(overrides)
cfg.setIPRules(overrides)
cfg.setHttp2Origin(overrides)
cfg.setH2cOrigin(overrides)
cfg.setAccess(overrides)

return cfg
Expand Down Expand Up @@ -539,12 +557,13 @@ func ConvertToRawOriginConfig(c OriginRequestConfig) config.OriginRequestConfig
ProxyType: emptyStringToNil(c.ProxyType),
IPRules: convertToRawIPRules(c.IPRules),
Http2Origin: defaultBoolToNil(c.Http2Origin),
H2cOrigin: defaultBoolToNil(c.H2cOrigin),
Access: access,
}
}

func convertToRawIPRules(ipRules []ipaccess.Rule) []config.IngressIPRule {
result := make([]config.IngressIPRule, 0)
result := make([]config.IngressIPRule, 0, len(ipRules))
for _, r := range ipRules {
cidr := r.StringCIDR()

Expand Down
Loading
Loading