diff --git a/.github/workflows/acceptance-test.yaml b/.github/workflows/acceptance-test.yaml index c8fe3b5d..1fa142f8 100644 --- a/.github/workflows/acceptance-test.yaml +++ b/.github/workflows/acceptance-test.yaml @@ -45,6 +45,10 @@ jobs: id: set-git-refname run: echo "git_refname=$(echo "${{ github.ref }}" | sed -r 's@refs/(heads|pull|tags)/@@g')" >> $GITHUB_OUTPUT + - name: Build TCP metadata exchange filter + run: | + make tcp-metadata-exchange-filter + - name: Deploy to kind cluster (istio-operator) run: | test/deploy-kind.sh diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2664f6a7..692225ae 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -44,6 +44,10 @@ jobs: restore-keys: | build-deps-v2 + - name: Build TCP metadata exchange filter + run: | + make tcp-metadata-exchange-filter + - name: Run unit tests run: make test diff --git a/Cargo.toml b/Cargo.toml index a68e0353..e8e974a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,3 +2,13 @@ members = [ "./experimental/wasm-tcp-metadata", ] + +[profile.release] +# do not include debug symbols +debug = false +# link-time optimalization +lto = 'thin' # this works much better for wasm3 than 'true' +# optimize for binary size for wasm s is better than z +opt-level = "s" +# do not unwind the stack when panicking +#panic = "abort" diff --git a/Makefile b/Makefile index c0a442ea..48ff1202 100644 --- a/Makefile +++ b/Makefile @@ -52,3 +52,9 @@ lint-fix-all: ${REPO_ROOT}/bin/golangci-lint ## lint --fix the whole repo .PHONY: mod-download-all mod-download-all: ## go mod download all go modules ./scripts/for_all_go_modules.sh -- go mod download all + +.PHONY: tcp-metadata-exchange-filter +tcp-metadata-exchange-filter: ## build the tcp-metadata-exchange-filter + rustup target add wasm32-unknown-unknown + cargo build --target wasm32-unknown-unknown --release + cp target/wasm32-unknown-unknown/release/wasm_tcp_metadata.wasm pkg/istio/filters/tcp-metadata-exchange-filter.wasm diff --git a/examples/grpc/client/main.go b/examples/grpc/client/main.go index 2077d8dc..8f63e55b 100644 --- a/examples/grpc/client/main.go +++ b/examples/grpc/client/main.go @@ -20,6 +20,7 @@ import ( "flag" "log" "os" + "sync" "time" "google.golang.org/grpc" @@ -27,6 +28,8 @@ import ( "github.com/cisco-open/nasp/examples/grpc/pb" "github.com/cisco-open/nasp/pkg/istio" + "github.com/cisco-open/nasp/pkg/network" + "github.com/cisco-open/nasp/pkg/util" ) var heimdallURL string @@ -38,6 +41,8 @@ func init() { } func main() { + logger := klog.Background() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -51,36 +56,50 @@ func main() { iih, err := istio.NewIstioIntegrationHandler(&istioHandlerConfig, klog.TODO()) if err != nil { - log.Fatal(err) + panic(err) } grpcDialOptions, err := iih.GetGRPCDialOptions() if err != nil { - log.Fatal(err) + panic(err) } - client, err := grpc.Dial( + client, err := grpc.DialContext( + ctx, "localhost:8082", grpcDialOptions..., ) if err != nil { - log.Fatal(err) + panic(err) } - iih.Run(ctx) + if err := iih.Run(ctx); err != nil { + panic(err) + } func() { defer cancel() defer client.Close() + wg := sync.WaitGroup{} for i := 0; i < 10; i++ { - reply, err := pb.NewGreeterClient(client).SayHello(ctx, &pb.HelloRequest{Name: "world"}) - if err != nil { - log.Fatal(err) - } - - log.Println(reply.Message) + wg.Add(1) + go func() { + defer wg.Done() + ctx = network.NewConnectionStateHolderToContext(ctx) + reply, err := pb.NewGreeterClient(client).SayHello(ctx, &pb.HelloRequest{Name: "world"}) + if err != nil { + log.Fatal(err) + } + + if s, ok := network.ConnectionStateFromContext(ctx); ok { + util.PrintConnectionState(s, logger) + } + + logger.Info(reply.Message) + }() } + wg.Wait() }() time.Sleep(time.Millisecond * 100) diff --git a/examples/grpc/main.go b/examples/grpc/main.go index ff958eb0..3d832ba4 100644 --- a/examples/grpc/main.go +++ b/examples/grpc/main.go @@ -21,13 +21,14 @@ import ( "log" "os" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" "k8s.io/klog/v2" "github.com/cisco-open/nasp/examples/grpc/pb" "github.com/cisco-open/nasp/pkg/istio" - - "google.golang.org/grpc" - "google.golang.org/grpc/reflection" + "github.com/cisco-open/nasp/pkg/network" + "github.com/cisco-open/nasp/pkg/util" ) var heimdallURL string @@ -43,7 +44,12 @@ type greeterServer struct { } func (gs *greeterServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { + if s, ok := network.ConnectionStateFromContext(ctx); ok { + util.PrintConnectionState(s, klog.Background()) + } + log.Printf("Received: %v", in.GetName()) + return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil } @@ -64,7 +70,9 @@ func main() { panic(err) } - iih.Run(ctx) + if err := iih.Run(ctx); err != nil { + panic(err) + } grpcServer := grpc.NewServer() reflection.Register(grpcServer) @@ -72,7 +80,7 @@ func main() { //////// standard HTTP library version with TLS - err = iih.ListenAndServe(context.Background(), ":8082", grpcServer) + err = iih.ListenAndServe(ctx, ":8082", grpcServer) if err != nil { log.Fatalf("failed to serve: %v", err) } diff --git a/examples/http/main.go b/examples/http/main.go index 761b6728..45e6b292 100644 --- a/examples/http/main.go +++ b/examples/http/main.go @@ -18,7 +18,6 @@ import ( "context" "errors" "flag" - "fmt" "io/ioutil" "net/http" "os" @@ -32,6 +31,7 @@ import ( "github.com/cisco-open/nasp/pkg/istio" "github.com/cisco-open/nasp/pkg/network" + "github.com/cisco-open/nasp/pkg/util" ) var mode string @@ -75,32 +75,16 @@ func sendHTTPRequest(url string, transport http.RoundTripper, logger logr.Logger if dumpClientResponse { buff, _ := ioutil.ReadAll(response.Body) - fmt.Printf("%s\n", string(buff)) + os.Stdout.Write(buff) } - if conn, ok := network.WrappedConnectionFromContext(response.Request.Context()); ok { - printConnectionInfo(conn, logger) + if state, ok := network.ConnectionStateFromContext(response.Request.Context()); ok { + util.PrintConnectionState(state, logger) } return nil } -func printConnectionInfo(connection network.Connection, logger logr.Logger) { - localAddr := connection.LocalAddr().String() - remoteAddr := connection.RemoteAddr().String() - var localSpiffeID, remoteSpiffeID string - - if cert := connection.GetLocalCertificate(); cert != nil { - localSpiffeID = cert.GetFirstURI() - } - - if cert := connection.GetPeerCertificate(); cert != nil { - remoteSpiffeID = cert.GetFirstURI() - } - - logger.Info("connection info", "localAddr", localAddr, "localSpiffeID", localSpiffeID, "remoteAddr", remoteAddr, "remoteSpiffeID", remoteSpiffeID, "ttfb", connection.GetTimeToFirstByte().Format(time.RFC3339Nano)) -} - func main() { logger := klog.TODO() ctx, cancel := context.WithCancel(context.Background()) @@ -128,9 +112,11 @@ func main() { iih.Run(ctx) - // make idle timeout minimal to test least request increment/decrement + // use client side connection pooling t := http.DefaultTransport.(*http.Transport) - t.IdleConnTimeout = time.Nanosecond * 1 + t.MaxIdleConns = 50 + t.MaxConnsPerHost = 50 + t.MaxIdleConnsPerHost = 50 transport, err := iih.GetHTTPTransport(t) if err != nil { @@ -153,10 +139,10 @@ func main() { i++ } - time.Sleep(sleepBeforeClientExit) - wg.Wait() + time.Sleep(sleepBeforeClientExit) + if len(clientErrors) > 0 { os.Exit(2) } @@ -174,6 +160,11 @@ func main() { logger.Error(err, "could not send http request") } } + + if state, ok := network.ConnectionStateFromContext(c.Request.Context()); ok { + util.PrintConnectionState(state, logger) + } + c.Data(http.StatusOK, "text/html", []byte("Hello world!")) }) err = iih.ListenAndServe(context.Background(), ":8080", r.Handler()) diff --git a/examples/tcp/main.go b/examples/tcp/main.go index 5f83feb0..f9515509 100644 --- a/examples/tcp/main.go +++ b/examples/tcp/main.go @@ -27,6 +27,8 @@ import ( "k8s.io/klog/v2" "github.com/cisco-open/nasp/pkg/istio" + "github.com/cisco-open/nasp/pkg/network" + "github.com/cisco-open/nasp/pkg/util" ) var mode string @@ -96,11 +98,16 @@ func server() { panic(err) } go func(conn net.Conn) { - defer conn.Close() + defer func() { + if s, ok := network.ConnectionStateFromNetConn(conn); ok { + util.PrintConnectionState(s, klog.Background()) + } + conn.Close() + }() reader := bufio.NewReader(conn) for { // read client request data - bytes, err := reader.ReadBytes(byte('!')) + bytes, err := reader.ReadBytes(byte('\n')) if err != nil { if err != io.EOF { fmt.Println("failed to read data, err:", err) @@ -138,7 +145,12 @@ func client() { panic(err) } - defer conn.Close() + defer func() { + if s, ok := network.ConnectionStateFromNetConn(conn); ok { + util.PrintConnectionState(s, klog.Background()) + } + conn.Close() + }() for i := 0; i < 5; i++ { if err := sendReceive(conn, fmt.Sprintf("name %d", i)); err != nil { diff --git a/experimental/spring/spring-boot-nasp-example/pom.xml b/experimental/spring/spring-boot-nasp-example/pom.xml index 0036137b..8a0264dc 100644 --- a/experimental/spring/spring-boot-nasp-example/pom.xml +++ b/experimental/spring/spring-boot-nasp-example/pom.xml @@ -35,7 +35,7 @@ com.google.cloud.tools jib-maven-plugin - 3.3.1 + 3.3.2 diff --git a/experimental/wasm-tcp-metadata/Cargo.toml b/experimental/wasm-tcp-metadata/Cargo.toml index 7818f8e5..a6159d33 100644 --- a/experimental/wasm-tcp-metadata/Cargo.toml +++ b/experimental/wasm-tcp-metadata/Cargo.toml @@ -8,16 +8,6 @@ edition = "2021" [lib] crate-type = ["cdylib"] -[profile.release] -# do not include debug symbols -debug = false -# link-time optimalization -lto = true -# optimize for binary size for wasm s is better than z -opt-level = "s" -# do not unwind the stack when panicking -panic = "abort" - [dependencies] proxy-wasm = "0.2.1" log = "0.4" diff --git a/go.mod b/go.mod index 8251973b..7868b9a2 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/cncf/xds/go v0.0.0-20220520190051-1e77728a1eaa github.com/envoyproxy/go-control-plane v0.10.3-0.20220719090109-b024c36d9935 github.com/gin-contrib/gzip v0.0.6 - github.com/gin-gonic/gin v1.8.1 + github.com/gin-gonic/gin v1.9.0 github.com/go-logr/logr v1.2.3 github.com/gohobby/deepcopy v1.0.1 github.com/golang/protobuf v1.5.2 @@ -56,8 +56,10 @@ require ( github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20220418222510-f25a4f6275ed // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 // indirect + github.com/bytedance/sonic v1.8.0 // indirect github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.0-20210816181553-5444fa50b93d // indirect github.com/docker/docker v20.10.24+incompatible // indirect @@ -69,10 +71,10 @@ require ( github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/jsonreference v0.19.6 // indirect github.com/go-openapi/swag v0.21.1 // indirect - github.com/go-playground/locales v0.14.0 // indirect - github.com/go-playground/universal-translator v0.18.0 // indirect - github.com/go-playground/validator/v10 v10.10.0 // indirect - github.com/goccy/go-json v0.9.7 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.11.2 // indirect + github.com/goccy/go-json v0.10.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect @@ -90,6 +92,7 @@ require ( github.com/inconshreveable/mousetrap v1.0.1 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect github.com/lestrrat-go/blackmagic v1.0.0 // indirect @@ -98,7 +101,7 @@ require ( github.com/lestrrat-go/jwx v1.2.25 // indirect github.com/lestrrat-go/option v1.0.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/miekg/dns v1.1.50 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect @@ -109,7 +112,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/natefinch/lumberjack v2.0.0+incompatible // indirect github.com/nxadm/tail v1.4.8 // indirect - github.com/pelletier/go-toml/v2 v2.0.1 // indirect + github.com/pelletier/go-toml/v2 v2.0.6 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.2.0 // indirect @@ -120,13 +123,15 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect github.com/tetratelabs/wazero v1.0.0-pre.8 // indirect - github.com/ugorji/go/codec v1.2.7 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.9 // indirect github.com/wasmerio/wasmer-go v1.0.4 // indirect go.opencensus.io v0.23.0 // indirect go.opentelemetry.io/proto/otlp v0.18.0 // indirect go.uber.org/multierr v1.9.0 // indirect go.uber.org/zap v1.21.0 // indirect - golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/crypto v0.5.0 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/oauth2 v0.0.0-20220622183110-fd043fe589d2 // indirect golang.org/x/sys v0.5.0 // indirect diff --git a/go.sum b/go.sum index b4d346a8..55e85659 100644 --- a/go.sum +++ b/go.sum @@ -124,6 +124,9 @@ github.com/blend/go-sdk v1.20220411.3 h1:GFV4/FQX5UzXLPwWV03gP811pj7B8J2sbuq+GJQ github.com/blend/go-sdk v1.20220411.3/go.mod h1:7lnH8fTi6U4i1fArEXRyOIY2E1X4MALg09qsQqY1+ak= github.com/bytecodealliance/wasmtime-go/v3 v3.0.2 h1:3uZCA/BLTIu+DqCfguByNMJa2HVHpXvjfy0Dy7g6fuA= github.com/bytecodealliance/wasmtime-go/v3 v3.0.2/go.mod h1:RnUjnIXxEJcL6BgCvNyzCCRzZcxCgsZCi+RNlvYor5Q= +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.8.0 h1:ea0Xadu+sHlu7x5O3gKhRpQ1IKiMrSiHttPF0ybECuA= +github.com/bytedance/sonic v1.8.0/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/cenkalti/backoff/v4 v4.1.3 h1:cFAlzYUlVYDysBEH2T5hyJZMh3+5+WCBvSnK6Q8UtC4= github.com/cenkalti/backoff/v4 v4.1.3/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -136,6 +139,9 @@ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chai2010/gettext-go v0.0.0-20160711120539-c6fed771bfd5 h1:7aWHqerlJ41y6FOsEUvknqgXnGmJyJSbjhAWq5pO4F8= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -215,8 +221,9 @@ github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d github.com/gin-contrib/gzip v0.0.6/go.mod h1:QOJlmV2xmayAjkNS2Y8NQsMneuRShOU/kjovCXNuzzk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= -github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= +github.com/gin-gonic/gin v1.9.0 h1:OjyFBKICoexlu99ctXNR2gg+c5pKrKMuyjgARg9qeY8= +github.com/gin-gonic/gin v1.9.0/go.mod h1:W1Me9+hsUSyj3CePGrd1/QrKJMSJ1Tu/0hFEH89961k= github.com/go-errors/errors v1.0.1 h1:LUHzmkK3GUKUrL/1gfBUxAHzcev3apQlezX/+O7ma6w= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -249,18 +256,22 @@ github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh github.com/go-openapi/swag v0.19.14/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= github.com/go-openapi/swag v0.21.1 h1:wm0rhTb5z7qpJRHBdPOMuY4QjVUMbF6/kwoYeRAOrKU= github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= -github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= -github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= -github.com/go-playground/validator/v10 v10.10.0 h1:I7mrTYv78z8k8VXa/qJlOlEXn/nBh+BF8dHX5nt/dr0= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= +github.com/go-playground/validator/v10 v10.11.2 h1:q3SHpufmypg+erIExEKUmsgmhDTyhcJ38oeKGACXohU= +github.com/go-playground/validator/v10 v10.11.2/go.mod h1:NieE624vt4SCTJtD87arVLvdmjPAeV8BQlHtMnw9D7s= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= -github.com/goccy/go-json v0.9.7 h1:IcB+Aqpx/iMHu5Yooh7jEzJk1JZ7Pjtmys2ukPr7EeM= github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/goccy/go-json v0.10.0 h1:mXKd9Qw4NuzShiRlOXKews24ufknHO7gx30lsDyokKA= +github.com/goccy/go-json v0.10.0/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= @@ -423,6 +434,8 @@ github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8 github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= @@ -457,8 +470,9 @@ github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= @@ -514,8 +528,9 @@ github.com/openshift/api v0.0.0-20200713203337-b2494ecb17dd h1:MV2FH/cm1wqoVCIL9 github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw= github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= -github.com/pelletier/go-toml/v2 v2.0.1 h1:8e3L2cCQzLFi2CR4g7vGFuFxX7Jl1kKX8gW+iV0GUKU= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= +github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= +github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek= github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI= github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= @@ -603,9 +618,12 @@ github.com/tetratelabs/wabin v0.0.0-20220927005300-3b0fbf39a46a h1:P0R3+CTAT7daT github.com/tetratelabs/wazero v1.0.0-pre.8 h1:Ir82PWj79WCppH+9ny73eGY2qv+oCnE3VwMY92cBSyI= github.com/tetratelabs/wazero v1.0.0-pre.8/go.mod h1:u8wrFmpdrykiFK0DFPiFm5a4+0RzsdmXYVtijBKqUVo= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6M= -github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= +github.com/ugorji/go/codec v1.2.9 h1:rmenucSohSTiyL09Y+l2OCk+FrMxGMzho2+tjr5ticU= +github.com/ugorji/go/codec v1.2.9/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/wasmerio/wasmer-go v1.0.4 h1:MnqHoOGfiQ8MMq2RF6wyCeebKOe84G88h5yv+vmxJgs= github.com/wasmerio/wasmer-go v1.0.4/go.mod h1:0gzVdSfg6pysA6QVp6iVRPTagC6Wq9pOE8J86WKb2Fk= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= @@ -670,6 +688,8 @@ go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -682,8 +702,8 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220427172511-eb4f295cb31f/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e h1:T8NU3HyQ8ClP4SEE+KbFlg6n0NhuTsN4MyznaarGsZM= -golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -900,6 +920,7 @@ golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220610221304-9f5ed59c137d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= @@ -1262,6 +1283,7 @@ k8s.io/utils v0.0.0-20210802155522-efc7438f0176/go.mod h1:jPW/WVKK9YHAvNhRxK0md/ k8s.io/utils v0.0.0-20220210201930-3a6ce19ff2f9 h1:HNSDgDCrr/6Ly3WEGKZftiE7IY19Vz2GdbOCyI4qqhc= k8s.io/utils v0.0.0-20220210201930-3a6ce19ff2f9/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.0.30/go.mod h1:fEO7lRTdivWO2qYVCVG7dEADOMo/MLDCVr8So2g88Uw= diff --git a/pkg/istio/discovery/connection_wrappers.go b/pkg/istio/discovery/connection_wrappers.go index 62a9b105..5cf185ed 100644 --- a/pkg/istio/discovery/connection_wrappers.go +++ b/pkg/istio/discovery/connection_wrappers.go @@ -17,6 +17,7 @@ package discovery import ( "context" "net" + "sync" "github.com/cisco-open/nasp/pkg/network" ) @@ -61,11 +62,15 @@ type ConnectionCloseWrapper = network.ConnectionCloseWrapper type discoveryClientCloser struct { discoveryClient DiscoveryClient parentCloseWrapper ConnectionCloseWrapper + + once sync.Once } func (d *xdsDiscoveryClient) NewConnectionCloseWrapper() ConnectionCloseWrapper { return &discoveryClientCloser{ discoveryClient: d, + + once: sync.Once{}, } } @@ -82,14 +87,16 @@ func (c *discoveryClientCloser) AfterClose(conn net.Conn) error { } func (c *discoveryClientCloser) BeforeClose(conn net.Conn) error { - address := conn.RemoteAddr().String() - if res, ok := conn.(interface { - GetOriginalAddress() string - }); ok { - address = res.GetOriginalAddress() - } - - c.discoveryClient.DecrementActiveRequestsCount(address) + c.once.Do(func() { + address := conn.RemoteAddr().String() + if res, ok := conn.(interface { + GetOriginalAddress() string + }); ok { + address = res.GetOriginalAddress() + } + + c.discoveryClient.DecrementActiveRequestsCount(address) + }) if c.parentCloseWrapper != nil { return c.parentCloseWrapper.BeforeClose(conn) diff --git a/pkg/istio/istio.go b/pkg/istio/istio.go index ca0449bc..fac8b08d 100644 --- a/pkg/istio/istio.go +++ b/pkg/istio/istio.go @@ -43,13 +43,13 @@ import ( itcp "github.com/cisco-open/nasp/pkg/istio/tcp" k8slabels "github.com/cisco-open/nasp/pkg/k8s/labels" "github.com/cisco-open/nasp/pkg/network" + "github.com/cisco-open/nasp/pkg/network/listener" "github.com/cisco-open/nasp/pkg/proxywasm" "github.com/cisco-open/nasp/pkg/proxywasm/api" pwgrpc "github.com/cisco-open/nasp/pkg/proxywasm/grpc" pwhttp "github.com/cisco-open/nasp/pkg/proxywasm/http" "github.com/cisco-open/nasp/pkg/proxywasm/middleware" "github.com/cisco-open/nasp/pkg/proxywasm/tcp" - unifiedtls "github.com/cisco-open/nasp/pkg/tls" ) var DefaultIstioIntegrationHandlerConfig = IstioIntegrationHandlerConfig{ @@ -294,8 +294,9 @@ func (h *istioIntegrationHandler) GetHTTPTransport(transport http.RoundTripper) return nil, errors.Wrap(err, "could not get stream handler") } - tp := NewIstioHTTPRequestTransport(transport, h.caClient, h.discoveryClient, h.logger.WithName("http-transport")) - httpTransport := pwhttp.NewHTTPTransport(tp, streamHandler, h.logger) + logger := h.logger.WithName("http-transport") + tp := NewIstioHTTPRequestTransport(transport, h.caClient, h.discoveryClient, logger) + httpTransport := pwhttp.NewHTTPTransport(tp, streamHandler, logger) httpTransport.AddMiddleware(middleware.NewEnvoyHTTPHandlerMiddleware()) httpTransport.AddMiddleware(NewIstioHTTPHandlerMiddleware()) @@ -375,12 +376,12 @@ func (h *istioIntegrationHandler) ListenAndServe(ctx context.Context, listenAddr if lp.UseTLS() { if lp.Permissive() { - ul.SetTLSMode(unifiedtls.TLSModePermissive) + ul.SetTLSMode(listener.TLSModePermissive) } else { - ul.SetTLSMode(unifiedtls.TLSModeStrict) + ul.SetTLSMode(listener.TLSModeStrict) } } else { - ul.SetTLSMode(unifiedtls.TLSModeDisabled) + ul.SetTLSMode(listener.TLSModeDisabled) } if lp.IsClientCertificateRequired() { ul.SetTLSClientAuthMode(tls.RequireAnyClientCert) @@ -551,7 +552,11 @@ func (h *istioIntegrationHandler) GetTCPListener(l net.Listener) (net.Listener, }, } - l = unifiedtls.NewUnifiedListener(network.NewWrappedListener(l), network.WrapTLSConfig(tlsConfig), unifiedtls.TLSModeStrict) + l = listener.NewUnifiedListener(l, network.WrapTLSConfig(tlsConfig), listener.TLSModePermissive, + listener.UnifiedListenerWithTLSConnectionCreator(network.CreateTLSServerConn), + listener.UnifiedListenerWithConnectionWrapper(func(c net.Conn) net.Conn { + return network.WrapConnection(c) + })) return tcp.WrapListener(l, streamHandler), nil } diff --git a/pkg/istio/tcp/dialer.go b/pkg/istio/tcp/dialer.go index 040fa776..674a6737 100644 --- a/pkg/istio/tcp/dialer.go +++ b/pkg/istio/tcp/dialer.go @@ -49,15 +49,9 @@ func NewTCPDialer(streamHandler api.StreamHandler, tlsConfig *tls.Config, discov } func (d *tcpDialer) DialContext(ctx context.Context, _net string, address string) (net.Conn, error) { - ctx = network.NewConnectionToContext(ctx) - tlsConfig := d.tlsConfig.Clone() - prop, err := d.discoveryClient.GetTCPClientPropertiesByHost(context.Background(), address) - if err != nil { - return nil, err - } - + prop, _ := d.discoveryClient.GetTCPClientPropertiesByHost(context.Background(), address) if prop != nil { if !prop.UseTLS() { tlsConfig = nil @@ -74,7 +68,7 @@ func (d *tcpDialer) DialContext(ctx context.Context, _net string, address string useTLS := tlsConfig != nil opts := []network.DialerOption{ - network.DialerWithWrappedConnectionOptions(network.WrappedConnectionWithCloserWrapper(d.discoveryClient.NewConnectionCloseWrapper())), + network.DialerWithConnectionOptions(network.ConnectionWithCloserWrapper(d.discoveryClient.NewConnectionCloseWrapper())), network.DialerWithDialerWrapper(d.discoveryClient.NewDialWrapper()), } @@ -117,7 +111,7 @@ func (d *tcpDialer) DialContext(ctx context.Context, _net string, address string d.connectionPoolRegistry.AddPool(address, p) } - cp, err = d.connectionPoolRegistry.GetPool(address) + cp, err := d.connectionPoolRegistry.GetPool(address) if err != nil { return nil, err } diff --git a/pkg/istio/transport.go b/pkg/istio/transport.go index 5daea205..c749febf 100644 --- a/pkg/istio/transport.go +++ b/pkg/istio/transport.go @@ -15,7 +15,6 @@ package istio import ( - "context" "crypto/tls" "crypto/x509" "net/http" @@ -48,7 +47,7 @@ func NewIstioHTTPRequestTransport(transport http.RoundTripper, caClient ca.Clien func (t *istioHTTPRequestTransport) RoundTrip(req *http.Request) (*http.Response, error) { var tlsConfig *tls.Config - if prop, _ := t.discoveryClient.GetHTTPClientPropertiesByHost(context.Background(), req.URL.Host); prop != nil { //nolint:nestif + if prop, _ := t.discoveryClient.GetHTTPClientPropertiesByHost(req.Context(), req.URL.Host); prop != nil { //nolint:nestif t.logger.V(3).Info("discovered overrides", "overrides", prop) if endpointAddr, err := prop.Address(); err != nil { return nil, err @@ -73,7 +72,7 @@ func (t *istioHTTPRequestTransport) RoundTrip(req *http.Request) (*http.Response } opts := []network.DialerOption{ - network.DialerWithWrappedConnectionOptions(network.WrappedConnectionWithCloserWrapper(t.discoveryClient.NewConnectionCloseWrapper())), + network.DialerWithConnectionOptions(network.ConnectionWithCloserWrapper(t.discoveryClient.NewConnectionCloseWrapper())), network.DialerWithDialerWrapper(t.discoveryClient.NewDialWrapper()), } diff --git a/pkg/network/cert.go b/pkg/network/cert.go index fb6d4eef..01660042 100644 --- a/pkg/network/cert.go +++ b/pkg/network/cert.go @@ -20,6 +20,7 @@ import ( "crypto/x509" "errors" "fmt" + "strings" ) var ErrInvalidCertificate = errors.New("invalid tls certificate") @@ -31,6 +32,7 @@ type Certificate interface { GetFirstURI() string GetFirstDNSName() string GetFirstIP() string + String() string } type certificate struct { @@ -91,3 +93,33 @@ func (c *certificate) GetFirstIP() string { return c.IPAddresses[0].String() } + +func (c *certificate) String() string { + p := make([]string, 0) + + if s := c.GetSubject(); s != "" { + p = append(p, fmt.Sprintf("Subject=%s", s)) + } + + if u := c.URIs; len(u) > 0 { + uris := []string{} + for _, uri := range u { + uris = append(uris, uri.String()) + } + p = append(p, fmt.Sprintf("URIs=%s", strings.Join(uris, ","))) + } + + if d := c.DNSNames; len(d) > 0 { + p = append(p, fmt.Sprintf("DNSNames=%s", strings.Join(d, ","))) + } + + if i := c.IPAddresses; len(i) > 0 { + ips := []string{} + for _, ip := range i { + ips = append(ips, ip.String()) + } + p = append(p, fmt.Sprintf("IPs=%s", strings.Join(ips, ","))) + } + + return strings.Join(p, " ") +} diff --git a/pkg/network/conn.go b/pkg/network/conn.go deleted file mode 100644 index e0f02439..00000000 --- a/pkg/network/conn.go +++ /dev/null @@ -1,267 +0,0 @@ -// Copyright (c) 2022 Cisco and/or its affiliates. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package network - -import ( - "context" - "crypto/tls" - "crypto/x509" - "errors" - "net" - "reflect" - "time" -) - -type Connection interface { //nolint:interfacebloat - net.Conn - - NetConn() net.Conn - SetNetConn(net.Conn) - GetTimeToFirstByte() time.Time - GetLocalCertificate() Certificate - GetPeerCertificate() Certificate - SetLocalCertificate(cert *tls.Certificate) - GetConnectionState() *tls.ConnectionState - SetConnWithTLSConnectionState(ConnWithTLSConnectionState) - SetOptions(opts ...WrappedConnectionOption) - GetOriginalAddress() string -} - -type ConnectionCloseWrapper interface { - AddParentCloseWrapper(ConnectionCloseWrapper) - BeforeClose(net.Conn) error - AfterClose(net.Conn) error -} - -type ConnWithTLSConnectionState interface { - ConnectionState() tls.ConnectionState -} - -type wrappedListener struct { - net.Listener -} - -func NewWrappedListener(l net.Listener) net.Listener { - return &wrappedListener{ - Listener: l, - } -} - -func (l *wrappedListener) Accept() (net.Conn, error) { - c, err := l.Listener.Accept() - if err != nil { - return nil, err - } - - return NewWrappedConnection(c), nil -} - -type wrappedConn struct { - net.Conn - - connWithTLSConnectionState ConnWithTLSConnectionState - timeToFirstByte time.Time - localCertificate *x509.Certificate - closeWrapper ConnectionCloseWrapper - beforeCloseRun bool - originalAddress string -} - -type WrappedConnectionOption func(*wrappedConn) - -func WrappedConnectionWithCloserWrapper(closeWrapper ConnectionCloseWrapper) WrappedConnectionOption { - return func(wc *wrappedConn) { - if wc.closeWrapper != nil && !reflect.DeepEqual(wc.closeWrapper, closeWrapper) { - closeWrapper.AddParentCloseWrapper(wc.closeWrapper) - } - wc.closeWrapper = closeWrapper - } -} - -func WrappedConnectionWithOriginalAddress(address string) WrappedConnectionOption { - return func(wc *wrappedConn) { - wc.originalAddress = address - } -} - -func NewWrappedConnection(conn net.Conn, opts ...WrappedConnectionOption) Connection { - wc := &wrappedConn{ - Conn: conn, - } - - wc.SetOptions(opts...) - - return wc -} - -func NewConnectionToContext(ctx context.Context) context.Context { - return WrappedConnectionToContext(ctx, NewWrappedConnection(NewNilConn())) -} - -func (c *wrappedConn) NetConn() net.Conn { - return c.Conn -} - -func (c *wrappedConn) SetNetConn(conn net.Conn) { - c.Conn = conn -} - -func (c *wrappedConn) Read(b []byte) (n int, err error) { - if c.timeToFirstByte.IsZero() { - c.timeToFirstByte = time.Now() - } - - return c.Conn.Read(b) -} - -func (c *wrappedConn) Write(b []byte) (n int, err error) { - if c.timeToFirstByte.IsZero() { - c.timeToFirstByte = time.Now() - } - - return c.Conn.Write(b) -} - -func (c *wrappedConn) SetOptions(opts ...WrappedConnectionOption) { - for _, opt := range opts { - opt(c) - } -} - -func (c *wrappedConn) GetOriginalAddress() string { - return c.originalAddress -} - -func (c *wrappedConn) Close() error { - if c.closeWrapper != nil && !c.beforeCloseRun { - if err := c.closeWrapper.BeforeClose(c); err != nil { - return err - } - - c.beforeCloseRun = true - } - - if err := c.Conn.Close(); err != nil { - return err - } - - if c.closeWrapper != nil { - if err := c.closeWrapper.AfterClose(c); err != nil { - return err - } - } - - return nil -} - -func (c *wrappedConn) GetTimeToFirstByte() time.Time { - return c.timeToFirstByte -} - -func (c *wrappedConn) GetLocalCertificate() Certificate { - if c.localCertificate == nil { - return nil - } - - return &certificate{ - Certificate: c.localCertificate, - } -} - -func (c *wrappedConn) GetPeerCertificate() Certificate { - cs := c.GetConnectionState() - if cs == nil { - return nil - } - - if len(cs.PeerCertificates) < 1 { - return nil - } - - return &certificate{ - Certificate: cs.PeerCertificates[0], - } -} - -func (c *wrappedConn) SetLocalCertificate(cert *tls.Certificate) { - if cert, err := x509.ParseCertificate(cert.Certificate[0]); err == nil { - c.localCertificate = cert - } -} - -func (c *wrappedConn) SetConnWithTLSConnectionState(conn ConnWithTLSConnectionState) { - c.connWithTLSConnectionState = conn -} - -func (c *wrappedConn) GetConnectionState() *tls.ConnectionState { - if c.connWithTLSConnectionState != nil { - cs := c.connWithTLSConnectionState.ConnectionState() - return &cs - } - - return nil -} - -var ( - ErrNilRead = errors.New("cannot read from nilConn") - ErrNilWrite = errors.New("cannot write to nilConn") -) - -func NewNilConn() net.Conn { - return &nilConn{} -} - -type nilConn struct{} -type nilAddr struct{} - -func (a *nilAddr) String() string { - return "0.0.0.0:0" -} - -func (a *nilAddr) Network() string { - return "tcp" -} - -func (c *nilConn) Read(b []byte) (n int, err error) { - return 0, ErrNilRead -} - -func (c *nilConn) Write(b []byte) (n int, err error) { - return 0, ErrNilWrite -} - -func (c *nilConn) Close() error { - return nil -} - -func (c *nilConn) LocalAddr() net.Addr { - return &nilAddr{} -} - -func (c *nilConn) RemoteAddr() net.Addr { - return &nilAddr{} -} - -func (c *nilConn) SetDeadline(t time.Time) error { - return nil -} - -func (c *nilConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *nilConn) SetWriteDeadline(t time.Time) error { - return nil -} diff --git a/pkg/network/connection.go b/pkg/network/connection.go index 129493a5..8f213bd7 100644 --- a/pkg/network/connection.go +++ b/pkg/network/connection.go @@ -15,111 +15,187 @@ package network import ( - "context" - "fmt" + "crypto/tls" "net" + "reflect" + "sync" + "time" - "github.com/go-logr/logr" + "github.com/google/uuid" + + "github.com/cisco-open/nasp/pkg/network/listener" ) -type contextKey struct { - name string +type connection struct { + net.Conn + + id string + connectionState ConnectionState + + closeWrapper ConnectionCloseWrapper + + tlsConnectionStateSetter sync.Once + tlsConn *tls.Conn + + originalAddress string } -var connectionContextKey = contextKey{"network.connection"} -var ConnectionTrackerLogger logr.Logger = logr.Discard() +type ConnectionCloseWrapper interface { + AddParentCloseWrapper(ConnectionCloseWrapper) + BeforeClose(net.Conn) error + AfterClose(net.Conn) error +} -type ConnectionHolder interface { - SetConn(net.Conn) - NetConn() net.Conn +type ConnWithTLSConnectionState interface { + ConnectionState() tls.ConnectionState } -type connectionHolder struct { - net.Conn +type ConnectionOption func(*connection) + +// Make sure the connection implements ConnectionStateGetter interface +var _ ConnectionStateGetter = &connection{} + +func ConnectionWithCloserWrapper(closeWrapper ConnectionCloseWrapper) ConnectionOption { + return func(c *connection) { + if c.closeWrapper != nil && !reflect.DeepEqual(c.closeWrapper, closeWrapper) { + closeWrapper.AddParentCloseWrapper(c.closeWrapper) + } + + c.closeWrapper = closeWrapper + } } -func (h *connectionHolder) SetConn(conn net.Conn) { - h.Conn = conn +func ConnectionWithState(state ConnectionState) ConnectionOption { + return func(c *connection) { + c.connectionState = state + } } -func (h *connectionHolder) NetConn() net.Conn { - return h.Conn +func ConnectionWithOriginalAddress(address string) ConnectionOption { + return func(c *connection) { + c.originalAddress = address + } } -type connectionTracker struct { - logger logr.Logger +func WrapConnection(conn net.Conn, opts ...ConnectionOption) net.Conn { + if _, ok := conn.(ConnectionStateGetter); ok { + return conn + } + + c := &connection{ + Conn: conn, + } + + c.setOptions(opts...) + + if c.connectionState == nil { + c.connectionState = NewConnectionState() + } + + c.connectionState.SetLocalAddr(conn.LocalAddr()) + c.connectionState.SetRemoteAddr(conn.RemoteAddr()) - connection Connection - connWithTLSConnectionState ConnWithTLSConnectionState + c.id = uuid.NewString() + + return c } -func WrappedConnectionToContext(ctx context.Context, conn net.Conn) context.Context { - return context.WithValue(ctx, connectionContextKey, &connectionHolder{conn}) +func (c *connection) ID() string { + return c.id } -func ConnectionHolderFromContext(ctx context.Context) (ConnectionHolder, bool) { - if c, ok := ctx.Value(connectionContextKey).(ConnectionHolder); ok { - return c, true +func (c *connection) SetTLSConn(conn *tls.Conn) { + if c.tlsConn == nil { + c.tlsConn = conn } +} + +func (c *connection) GetConnectionState() ConnectionState { + return c.connectionState +} - return nil, false +func (c *connection) NetConn() net.Conn { + return c.Conn } -func SetConnectionToContextConnectionHolder(ctx context.Context, conn net.Conn) bool { - if h, ok := ConnectionHolderFromContext(ctx); ok { - h.SetConn(conn) +func (c *connection) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) - return true - } + c.connectionState.SetTimeToFirstByte(time.Now()) + + c.tlsConnectionStateSetter.Do(c.setTLSConnectionState) - return false + return } -func WrappedConnectionFromContext(ctx context.Context) (Connection, bool) { - if c, ok := ctx.Value(connectionContextKey).(ConnectionHolder); ok { - return WrappedConnectionFromNetConn(c.NetConn()) - } +func (c *connection) Write(b []byte) (n int, err error) { + n, err = c.Conn.Write(b) + + c.connectionState.SetTimeToFirstByte(time.Now()) + + c.tlsConnectionStateSetter.Do(c.setTLSConnectionState) - return nil, false + return } -func WrappedConnectionFromNetConn(conn net.Conn) (Connection, bool) { - t := &connectionTracker{ - logger: ConnectionTrackerLogger, +func (c *connection) Close() error { + if c.closeWrapper != nil { + if err := c.closeWrapper.BeforeClose(c); err != nil { + return err + } } - connection := t.getConnection(conn) + if err := c.Conn.Close(); err != nil { + return err + } - if connection != nil { - connection.SetConnWithTLSConnectionState(t.connWithTLSConnectionState) + if c.closeWrapper != nil { + if err := c.closeWrapper.AfterClose(c); err != nil { + return err + } } - return connection, connection != nil + return nil } -func (t *connectionTracker) getConnection(c net.Conn) Connection { - t.logger.Info("check connection", "type", fmt.Sprintf("%T", c)) +func (c *connection) setOptions(opts ...ConnectionOption) { + for _, opt := range opts { + opt(c) + } +} - if c, ok := c.(ConnWithTLSConnectionState); ok { - t.connWithTLSConnectionState = c - if t.connection != nil { - return t.connection - } +func (c *connection) setTLSConnectionState() { + if c.connectionState.GetTLSConnectionState().HandshakeComplete { + return + } + + if c.tlsConn != nil { + c.connectionState.SetTLSConnectionStateAsync(func() tls.ConnectionState { + cs := c.tlsConn.ConnectionState() + c.tlsConn = nil + + return cs + }) + + return } - if conn, ok := c.(Connection); ok && t.connection == nil { - t.connection = conn + t := &connectionStateTracker{ + logger: ConnectionStateTrackerLogger, } - if conn, ok := c.(interface { - NetConn() net.Conn - }); ok { - return t.getConnection(conn.NetConn()) - } else if conn, ok := c.(interface { - GetConn() net.Conn - }); ok { - return t.getConnection(conn.GetConn()) + tlsConn := t.getTLSConnection(c) + if tlsConn != nil { + if csc, ok := tlsConn.(interface { + ConnectionState() tls.ConnectionState + }); ok { + c.connectionState.SetTLSConnectionState(csc.ConnectionState()) + } } +} - return t.connection +func WrappedConnectionListener(l net.Listener) net.Listener { + return listener.NewListenerWithConnectionWrapper(l, func(c net.Conn) net.Conn { + return WrapConnection(c) + }) } diff --git a/pkg/network/connection_state.go b/pkg/network/connection_state.go new file mode 100644 index 00000000..5338f927 --- /dev/null +++ b/pkg/network/connection_state.go @@ -0,0 +1,276 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net" + "net/http" + "sync" + "time" + + "github.com/google/uuid" +) + +type contextKey struct { + name string +} + +var connectionStateContextKey = contextKey{"network.connection.state"} + +type ConnectionState interface { //nolint:interfacebloat + GetTimeToFirstByte() time.Time + SetTimeToFirstByte(time.Time) + ResetTimeToFirstByte() + GetLocalCertificate() Certificate + SetLocalCertificate(cert tls.Certificate) + GetPeerCertificate() Certificate + GetTLSConnectionState() tls.ConnectionState + SetTLSConnectionState(tls.ConnectionState) + SetTLSConnectionStateAsync(func() tls.ConnectionState) + LocalAddr() net.Addr + SetLocalAddr(net.Addr) + RemoteAddr() net.Addr + SetRemoteAddr(net.Addr) + GetOriginalAddress() string + ID() string +} + +type ConnectionStateGetter interface { + GetConnectionState() ConnectionState +} + +type ConnectionStateHolder interface { + Get() ConnectionState + Set(ConnectionState) + ID() string +} + +type connectionStateHolder struct { + connectionState ConnectionState + id string +} + +func (h *connectionStateHolder) ID() string { + return h.id +} + +func (h *connectionStateHolder) Get() ConnectionState { + return h.connectionState +} + +func (h *connectionStateHolder) Set(stat ConnectionState) { + h.connectionState = stat +} + +func ConnectionStateToContextFromNetConn(ctx context.Context, conn net.Conn) context.Context { + if state, ok := ConnectionStateFromNetConn(conn); ok { + state.SetTimeToFirstByte(time.Time{}) + + return ConnectionStateToContext(ctx, state) + } + + return ctx +} + +// ConnectionStateToContext creates a new context object with the provided +// ConnectionState object added to it as a value using the +// connectionStateContextKey constant as the key. +func ConnectionStateToContext(ctx context.Context, state ConnectionState) context.Context { + return context.WithValue(ctx, connectionStateContextKey, &connectionStateHolder{ + connectionState: state, + id: uuid.NewString(), + }) +} + +// ConnectionStateFromContext extracts the ConnectionState object from the +// provided context and returns it, along with a boolean indicating whether the +// extraction was successful. If the ConnectionState object is not found in the +// context or is nil, the function returns false. +func ConnectionStateFromContext(ctx context.Context) (ConnectionState, bool) { + if holder, ok := ctx.Value(connectionStateContextKey).(ConnectionStateHolder); ok && holder.Get() != nil { + return holder.Get(), true + } + + return nil, false +} + +// NewConnectionStateHolderToContext creates a new connectionStateHolder object +// with a new UUID string as the ID and adds it to the provided context using +// the connectionStateContextKey constant as the key. +func NewConnectionStateHolderToContext(ctx context.Context) context.Context { + return ConnectionStateHolderToContext(ctx, &connectionStateHolder{id: uuid.NewString()}) +} + +func ConnectionStateHolderToContext(ctx context.Context, holder ConnectionStateHolder) context.Context { + return context.WithValue(ctx, connectionStateContextKey, holder) +} + +func ConnectionStateHolderFromContext(ctx context.Context) (ConnectionStateHolder, bool) { + if holder, ok := ctx.Value(connectionStateContextKey).(ConnectionStateHolder); ok { + return holder, true + } + + return nil, false +} + +func ConnectionStateFromHTTPRequest(req *http.Request) (ConnectionState, bool) { + if state, ok := ConnectionStateFromContext(req.Context()); ok { + if req.TLS != nil { + state.SetTLSConnectionState(*req.TLS) + } + + return state, true + } + + return nil, false +} + +func NewConnectionState() ConnectionState { + return NewConnectionStateWithID(uuid.NewString()) +} + +func NewConnectionStateWithID(id string) ConnectionState { + s := &connectionState{ + id: id, + connectionState: tls.ConnectionState{ + HandshakeComplete: false, + }, + } + + return s +} + +type connectionState struct { + timeToFirstByte time.Time + localCertificate *x509.Certificate + connectionState tls.ConnectionState + localAddr net.Addr + remoteAddr net.Addr + originalAddress string + + id string + + mux sync.RWMutex + mux2 sync.Mutex +} + +func (s *connectionState) ID() string { + return s.id +} + +func (s *connectionState) GetOriginalAddress() string { + return s.originalAddress +} + +func (s *connectionState) GetTimeToFirstByte() time.Time { + s.mux.RLock() + defer s.mux.RUnlock() + + return s.timeToFirstByte +} + +func (s *connectionState) ResetTimeToFirstByte() { + s.mux.Lock() + defer s.mux.Unlock() + + s.timeToFirstByte = time.Time{} +} + +func (s *connectionState) SetTimeToFirstByte(t time.Time) { + s.mux.Lock() + defer s.mux.Unlock() + + if s.timeToFirstByte.IsZero() { + s.timeToFirstByte = t + } +} + +func (s *connectionState) GetLocalCertificate() Certificate { + s.mux.RLock() + defer s.mux.RUnlock() + + if s.localCertificate == nil { + return nil + } + + return &certificate{ + Certificate: s.localCertificate, + } +} + +func (s *connectionState) SetLocalCertificate(cert tls.Certificate) { + s.mux.Lock() + defer s.mux.Unlock() + + if cert, err := x509.ParseCertificate(cert.Certificate[0]); err == nil { + s.localCertificate = cert + } +} + +func (s *connectionState) GetPeerCertificate() Certificate { + cs := s.GetTLSConnectionState() + if len(cs.PeerCertificates) < 1 { + return nil + } + + return &certificate{ + Certificate: cs.PeerCertificates[0], + } +} + +func (s *connectionState) GetTLSConnectionState() tls.ConnectionState { + s.mux.RLock() + s.mux2.Lock() + defer func() { + s.mux.RUnlock() + s.mux2.Unlock() + }() + + return s.connectionState +} + +func (s *connectionState) SetTLSConnectionStateAsync(f func() tls.ConnectionState) { + s.mux2.Lock() + go func() { + defer s.mux2.Unlock() + s.connectionState = f() + }() +} + +func (s *connectionState) SetTLSConnectionState(state tls.ConnectionState) { + s.mux.Lock() + defer s.mux.Unlock() + + s.connectionState = state +} + +func (s *connectionState) LocalAddr() net.Addr { + return s.localAddr +} + +func (s *connectionState) SetLocalAddr(addr net.Addr) { + s.localAddr = addr +} + +func (s *connectionState) RemoteAddr() net.Addr { + return s.remoteAddr +} + +func (s *connectionState) SetRemoteAddr(addr net.Addr) { + s.remoteAddr = addr +} diff --git a/pkg/network/connection_state_tracker.go b/pkg/network/connection_state_tracker.go new file mode 100644 index 00000000..30215e76 --- /dev/null +++ b/pkg/network/connection_state_tracker.go @@ -0,0 +1,92 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network + +import ( + "fmt" + "net" + + "github.com/go-logr/logr" +) + +var ConnectionStateTrackerLogger logr.Logger = logr.Discard() + +type connectionStateTracker struct { + logger logr.Logger + + connectionState ConnectionState + connWithTLSConnectionState ConnWithTLSConnectionState +} + +func ConnectionStateFromNetConn(conn net.Conn) (ConnectionState, bool) { + t := &connectionStateTracker{ + logger: ConnectionStateTrackerLogger, + } + + state := t.getConnectionState(conn) + + if state != nil && t.connWithTLSConnectionState != nil { + state.SetTLSConnectionState(t.connWithTLSConnectionState.ConnectionState()) + } + + return state, state != nil +} + +func (t *connectionStateTracker) getTLSConnection(c net.Conn) net.Conn { + t.logger.Info("check connection", "type", fmt.Sprintf("%T", c)) + + if _, ok := c.(ConnWithTLSConnectionState); ok { + return c + } + + if conn, ok := c.(interface { + NetConn() net.Conn + }); ok { + return t.getTLSConnection(conn.NetConn()) + } else if conn, ok := c.(interface { + GetConn() net.Conn + }); ok { + return t.getTLSConnection(conn.GetConn()) + } + + return nil +} + +func (t *connectionStateTracker) getConnectionState(c net.Conn) ConnectionState { + t.logger.Info("check connection", "type", fmt.Sprintf("%T", c)) + + if c, ok := c.(ConnWithTLSConnectionState); ok { + t.connWithTLSConnectionState = c + if t.connectionState != nil { + return t.connectionState + } + } + + if conn, ok := c.(ConnectionStateGetter); ok && t.connectionState == nil { + t.connectionState = conn.GetConnectionState() + } + + if conn, ok := c.(interface { + NetConn() net.Conn + }); ok { + return t.getConnectionState(conn.NetConn()) + } else if conn, ok := c.(interface { + GetConn() net.Conn + }); ok { + return t.getConnectionState(conn.GetConn()) + } + + return t.connectionState +} diff --git a/pkg/network/dialer.go b/pkg/network/dialer.go index d68a04e3..04c0c508 100644 --- a/pkg/network/dialer.go +++ b/pkg/network/dialer.go @@ -20,14 +20,16 @@ import ( "crypto/x509" "net" "reflect" + + "github.com/google/uuid" ) type dialer struct { netDialer *net.Dialer tlsConfig *tls.Config - dialWrapper DialWrapper - wrappedConnectionOptions []WrappedConnectionOption + dialWrapper DialWrapper + connectionOptions []ConnectionOption } type ConnectionDialer interface { @@ -52,9 +54,9 @@ func DialerWithDialerWrapper(w DialWrapper) DialerOption { } } -func DialerWithWrappedConnectionOptions(opts ...WrappedConnectionOption) DialerOption { +func DialerWithConnectionOptions(opts ...ConnectionOption) DialerOption { return func(d *dialer) { - d.wrappedConnectionOptions = opts + d.connectionOptions = opts } } @@ -129,26 +131,27 @@ func (d *dialer) dialContext(dialer connectionDialer, ctx context.Context, netwo } } - c, err := dialer.DialContext(ctx, network, addr) + var state ConnectionState + if s, ok := ConnectionStateFromContext(ctx); ok { + state = s + } else { + state = NewConnectionStateWithID("connection-" + uuid.NewString()) + } + + innerCtx := ConnectionStateToContext(ctx, state) + c, err := dialer.DialContext(innerCtx, network, addr) if err != nil { return nil, err } - wrapConnection := func() net.Conn { - opts := d.wrappedConnectionOptions - opts = append(opts, WrappedConnectionWithOriginalAddress(addr)) - - if wc, ok := WrappedConnectionFromContext(ctx); ok { - wc.SetNetConn(c) - wc.SetOptions(opts...) - - return wc - } - - return NewWrappedConnection(c, opts...) + if s, ok := ConnectionStateHolderFromContext(ctx); ok && s.Get() == nil { + s.Set(state) } - wc := wrapConnection() + opts := d.connectionOptions + opts = append(opts, ConnectionWithOriginalAddress(addr), ConnectionWithState(state)) + + wc := WrapConnection(c, opts...) if d.dialWrapper != nil { if err := d.dialWrapper.AfterDial(ctx, wc, network, addr); err != nil { diff --git a/pkg/network/http.go b/pkg/network/http.go index 3523d0bb..ae9c9312 100644 --- a/pkg/network/http.go +++ b/pkg/network/http.go @@ -21,11 +21,11 @@ import ( "net/http" "net/http/httptrace" - ltls "github.com/cisco-open/nasp/pkg/tls" + "github.com/cisco-open/nasp/pkg/network/listener" ) type Server interface { - GetUnifiedListener() ltls.UnifiedListener + GetUnifiedListener() listener.UnifiedListener Serve(l net.Listener) error ServeTLS(l net.Listener, certFile, keyFile string) error ServeWithTLSConfig(l net.Listener, config *tls.Config) error @@ -35,7 +35,7 @@ type Server interface { type server struct { *http.Server - unifiedListener ltls.UnifiedListener + unifiedListener listener.UnifiedListener } func WrapHTTPTransport(rt http.RoundTripper, dialer ConnectionDialer) http.RoundTripper { @@ -52,7 +52,14 @@ func WrapHTTPTransport(rt http.RoundTripper, dialer ConnectionDialer) http.Round } func WrapHTTPServer(srv *http.Server) Server { - srv.ConnContext = WrappedConnectionToContext + srv.ConnContext = ConnectionStateToContextFromNetConn + srv.ConnState = func(conn net.Conn, cstate http.ConnState) { + if state, ok := ConnectionStateFromNetConn(conn); ok { + if cstate != http.StateActive { + state.ResetTimeToFirstByte() + } + } + } return &server{ Server: srv, @@ -63,18 +70,28 @@ func (s *server) GetHTTPServer() *http.Server { return s.Server } -func (s *server) GetUnifiedListener() ltls.UnifiedListener { +func (s *server) GetUnifiedListener() listener.UnifiedListener { return s.unifiedListener } func (s *server) Serve(l net.Listener) error { - s.unifiedListener = ltls.NewUnifiedListener(NewWrappedListener(l), nil, ltls.TLSModeDisabled) + s.unifiedListener = listener.NewUnifiedListener(l, nil, listener.TLSModeDisabled, + listener.UnifiedListenerWithTLSConnectionCreator(CreateTLSServerConn), + listener.UnifiedListenerWithConnectionWrapper(func(c net.Conn) net.Conn { + return WrapConnection(c) + }), + ) return s.Server.Serve(s.unifiedListener) } func (s *server) ServeWithTLSConfig(l net.Listener, config *tls.Config) error { - s.unifiedListener = ltls.NewUnifiedListener(NewWrappedListener(l), WrapTLSConfig(config), ltls.TLSModePermissive) + s.unifiedListener = listener.NewUnifiedListener(l, WrapTLSConfig(config), listener.TLSModePermissive, + listener.UnifiedListenerWithTLSConnectionCreator(CreateTLSServerConn), + listener.UnifiedListenerWithConnectionWrapper(func(c net.Conn) net.Conn { + return WrapConnection(c) + }), + ) return s.Server.Serve(s.unifiedListener) } @@ -104,12 +121,21 @@ type transport struct { } func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { - ctx := NewConnectionToContext(req.Context()) - req = req.WithContext(httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + ctx := NewConnectionStateHolderToContext(req.Context()) + + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ GotConn: func(i httptrace.GotConnInfo) { - SetConnectionToContextConnectionHolder(ctx, i.Conn) + if connWithState, ok := i.Conn.(ConnectionStateGetter); ok { + if h, ok := ConnectionStateHolderFromContext(ctx); ok { + state := connWithState.GetConnectionState() + state.ResetTimeToFirstByte() + h.Set(state) + } + } }, - })) + }) + + req = req.WithContext(ctx) return t.RoundTripper.RoundTrip(req) } diff --git a/pkg/network/listener/listener.go b/pkg/network/listener/multi_listener.go similarity index 97% rename from pkg/network/listener/listener.go rename to pkg/network/listener/multi_listener.go index 87a46d6b..0c633e7c 100644 --- a/pkg/network/listener/listener.go +++ b/pkg/network/listener/multi_listener.go @@ -64,7 +64,7 @@ func MultiListenerWithLogger(logger logr.Logger) MultiListenerOption { } } -func NewListener(l net.Listener, options ...MultiListenerOption) MultiListener { +func NewMultiListener(l net.Listener, options ...MultiListenerOption) MultiListener { lstn := &multipleListeners{ listeners: sync.Map{}, diff --git a/pkg/tls/unified_tls.go b/pkg/network/listener/unified_tls_listener.go similarity index 57% rename from pkg/tls/unified_tls.go rename to pkg/network/listener/unified_tls_listener.go index 700863e4..85fe5bd6 100644 --- a/pkg/tls/unified_tls.go +++ b/pkg/network/listener/unified_tls_listener.go @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -package tls +package listener import ( "bufio" "crypto/tls" "encoding/binary" + "errors" + "io" "net" ) @@ -58,6 +60,14 @@ func (c *unifiedConn) NetConn() net.Conn { return c.Conn } +func (c *unifiedConn) SetTLSConn(conn *tls.Conn) { + if c, ok := c.Conn.(interface { + SetTLSConn(conn *tls.Conn) + }); ok { + c.SetTLSConn(conn) + } +} + type UnifiedListener interface { net.Listener @@ -70,15 +80,42 @@ type unifiedListener struct { tlsConfig *tls.Config mode TLSMode + + tlsConnCreator func(net.Conn, *tls.Config) *tls.Conn + connectionWrapper ConnectionWrapper } -func NewUnifiedListener(listener net.Listener, tlsConfig *tls.Config, mode TLSMode) UnifiedListener { - return &unifiedListener{ +type UnifiedListenerOption func(*unifiedListener) + +func UnifiedListenerWithTLSConnectionCreator(f func(net.Conn, *tls.Config) *tls.Conn) UnifiedListenerOption { + return func(l *unifiedListener) { + l.tlsConnCreator = f + } +} + +func UnifiedListenerWithConnectionWrapper(wrapper ConnectionWrapper) UnifiedListenerOption { + return func(l *unifiedListener) { + l.connectionWrapper = wrapper + } +} + +func NewUnifiedListener(listener net.Listener, tlsConfig *tls.Config, mode TLSMode, opts ...UnifiedListenerOption) UnifiedListener { + l := &unifiedListener{ Listener: listener, tlsConfig: tlsConfig, mode: mode, } + + for _, o := range opts { + o(l) + } + + if l.tlsConnCreator == nil { + l.tlsConnCreator = tls.Server + } + + return l } func (l *unifiedListener) SetTLSMode(mode TLSMode) { @@ -95,30 +132,44 @@ func (l *unifiedListener) Accept() (net.Conn, error) { return nil, err } - if l.mode == TLSModeDisabled { - return c, nil - } - - if l.mode == TLSModeStrict { - return tls.Server(c, l.tlsConfig), nil + if l.mode == TLSModePermissive { + // buffer reads on our conn + var conn net.Conn + conn = &unifiedConn{ + Conn: c, + buf: bufio.NewReader(c), + } + + // inspect the first few bytes + hdr, err := conn.(*unifiedConn).buf.Peek(3) + // close connection in case of EOF without reporting error + if errors.Is(err, io.EOF) { + conn.Close() + return conn, nil + } + if err != nil { + conn.Close() + return nil, err + } + + if l.connectionWrapper != nil { + conn = l.connectionWrapper(conn) + } + + if isTLSHandhsake(hdr) { + return l.tlsConnCreator(conn, l.tlsConfig), nil + } + + return conn, nil } - // buffer reads on our conn - conn := &unifiedConn{ - Conn: c, - buf: bufio.NewReader(c), + if l.connectionWrapper != nil { + c = l.connectionWrapper(c) } - // inspect the first few bytes - hdr, err := conn.buf.Peek(3) - if err != nil { - conn.Close() - return nil, err - } - - if isTLSHandhsake(hdr) { - return tls.Server(conn, l.tlsConfig), nil + if l.mode == TLSModeDisabled { + return c, nil } - return conn, nil + return l.tlsConnCreator(c, l.tlsConfig), nil } diff --git a/pkg/network/listener/wrapped_listener.go b/pkg/network/listener/wrapped_listener.go new file mode 100644 index 00000000..0958cb7e --- /dev/null +++ b/pkg/network/listener/wrapped_listener.go @@ -0,0 +1,42 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package listener + +import "net" + +type listenerWithConnectionWrapper struct { + net.Listener + + wrapper ConnectionWrapper +} + +type ConnectionWrapper func(net.Conn) net.Conn + +func NewListenerWithConnectionWrapper(l net.Listener, wrapper ConnectionWrapper) net.Listener { + return &listenerWithConnectionWrapper{ + Listener: l, + + wrapper: wrapper, + } +} + +func (l *listenerWithConnectionWrapper) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + + return l.wrapper(c), nil +} diff --git a/pkg/network/network_test.go b/pkg/network/network_test.go index 307f3766..27c4c65a 100644 --- a/pkg/network/network_test.go +++ b/pkg/network/network_test.go @@ -26,6 +26,8 @@ import ( "math/rand" "net" "net/http" + "net/textproto" + "sync" "testing" "time" @@ -34,7 +36,8 @@ import ( "github.com/cisco-open/nasp/pkg/ca" "github.com/cisco-open/nasp/pkg/ca/selfsigned" "github.com/cisco-open/nasp/pkg/network" - ltls "github.com/cisco-open/nasp/pkg/tls" + "github.com/cisco-open/nasp/pkg/network/listener" + "github.com/cisco-open/nasp/pkg/network/pool" ) type NetworkTestSuite struct { @@ -140,7 +143,7 @@ func (s *NetworkTestSuite) wrappedTCPServer(ctx context.Context, handler func(ne return err } - l = network.NewWrappedListener(l) + l = network.WrappedConnectionListener(l) s.wrappedTCPServerRunning = true @@ -192,7 +195,12 @@ func (s *NetworkTestSuite) wrappedTLSServer(ctx context.Context, handler func(ne InsecureSkipVerify: true, } - l = ltls.NewUnifiedListener(network.NewWrappedListener(l), network.WrapTLSConfig(tlsConfig), ltls.TLSModeStrict) + l = listener.NewUnifiedListener(l, network.WrapTLSConfig(tlsConfig), listener.TLSModePermissive, + listener.UnifiedListenerWithTLSConnectionCreator(network.CreateTLSServerConn), + listener.UnifiedListenerWithConnectionWrapper(func(c net.Conn) net.Conn { + return network.WrapConnection(c) + }), + ) s.wrappedTLSServerRunning = true @@ -332,14 +340,21 @@ func (s *NetworkTestSuite) TestWrappedHTTPServerWithSimpleClient() { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - if connection, ok := network.WrappedConnectionFromContext(r.Context()); ok { - s.Require().True(ok) - s.Require().NotNil(connection) - s.Equal(s.wrappedHTTPServerAddr, connection.LocalAddr().String()) - s.WithinRange(connection.GetTimeToFirstByte(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) - } + startTime := time.Now() + + connectionState, ok := network.ConnectionStateFromContext(r.Context()) + s.Require().True(ok) + s.Require().NotNil(connectionState) + s.Equal(s.wrappedHTTPServerAddr, connectionState.LocalAddr().String()) + connectionState.ResetTimeToFirstByte() + _, err := w.Write(body) s.Require().Nil(err) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + s.WithinRange(connectionState.GetTimeToFirstByte(), startTime, time.Now().Add(time.Second)) }) go func() { @@ -353,15 +368,23 @@ func (s *NetworkTestSuite) TestWrappedHTTPServerWithSimpleClient() { c := http.Client{} - resp, err := c.Get("http://" + s.wrappedHTTPServerAddr) - s.Require().Nil(err) - defer resp.Body.Close() - s.Require().NotNil(resp) - s.Equal(resp.StatusCode, 200) + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := c.Get("http://" + s.wrappedHTTPServerAddr) + s.Require().Nil(err) + defer resp.Body.Close() + s.Require().NotNil(resp) + s.Equal(resp.StatusCode, 200) - b, err := io.ReadAll(resp.Body) - s.Require().Nil(err) - s.Equal(b, body) + b, err := io.ReadAll(resp.Body) + s.Require().Nil(err) + s.Equal(b, body) + }() + } + wg.Wait() } func (s *NetworkTestSuite) TestWrappedHTTPServerWithWrappedClient() { @@ -372,14 +395,21 @@ func (s *NetworkTestSuite) TestWrappedHTTPServerWithWrappedClient() { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + + connectionState, ok := network.ConnectionStateFromContext(r.Context()) + s.Require().True(ok) + s.Require().NotNil(connectionState) + s.Equal(s.wrappedHTTPServerAddr, connectionState.LocalAddr().String()) + connectionState.ResetTimeToFirstByte() + _, err := w.Write(body) s.Require().Nil(err) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } - connection, ok := network.WrappedConnectionFromContext(r.Context()) - s.Require().True(ok) - s.Require().NotNil(connection) - s.Equal(s.wrappedHTTPServerAddr, connection.LocalAddr().String()) - s.WithinRange(connection.GetTimeToFirstByte(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) + s.WithinRange(connectionState.GetTimeToFirstByte(), startTime, time.Now().Add(time.Second)) }) go func() { @@ -395,21 +425,31 @@ func (s *NetworkTestSuite) TestWrappedHTTPServerWithWrappedClient() { Transport: network.WrapHTTPTransport(http.DefaultTransport.(*http.Transport).Clone(), network.NewDialer()), } - resp, err := c.Get("http://" + s.wrappedHTTPServerAddr) - s.Require().Nil(err) - defer resp.Body.Close() - s.Require().NotNil(resp) - s.Equal(resp.StatusCode, 200) + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + startTime := time.Now() - b, err := io.ReadAll(resp.Body) - s.Require().Nil(err) - s.Equal(b, body) + defer wg.Done() + resp, err := c.Get("http://" + s.wrappedHTTPServerAddr) + s.Require().Nil(err) + defer resp.Body.Close() + s.Require().NotNil(resp) + s.Equal(resp.StatusCode, 200) + + b, err := io.ReadAll(resp.Body) + s.Require().Nil(err) + s.Equal(b, body) - connection, ok := network.WrappedConnectionFromContext(resp.Request.Context()) - s.Require().True(ok) - s.Require().NotNil(connection) - s.Equal(s.wrappedHTTPServerAddr, connection.RemoteAddr().String()) - s.WithinRange(connection.GetTimeToFirstByte(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) + connectionState, ok := network.ConnectionStateFromContext(resp.Request.Context()) + s.Require().True(ok) + s.Require().NotNil(connectionState) + s.Equal(s.wrappedHTTPServerAddr, connectionState.RemoteAddr().String()) + s.WithinRange(connectionState.GetTimeToFirstByte(), startTime, time.Now().Add(time.Second)) + }() + } + wg.Wait() } func (s *NetworkTestSuite) TestWrappedHTTPSServerWithWrappedClient() { @@ -419,22 +459,48 @@ func (s *NetworkTestSuite) TestWrappedHTTPSServerWithWrappedClient() { body := []byte("hello world") uri := "spiffe://acme.corp/https-wrapped-client" + type Stats struct { + Count int + HandshakeCompleteCount int + Addresses map[string]int + } + + serverStats := &Stats{ + Addresses: make(map[string]int), + } + clientStats := &Stats{ + Addresses: make(map[string]int), + } + mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + + connectionState, ok := network.ConnectionStateFromContext(r.Context()) + s.Require().True(ok) + s.Require().NotNil(connectionState) + s.Equal(s.wrappedHTTPSServerAddr, connectionState.LocalAddr().String()) + connectionState.ResetTimeToFirstByte() + _, err := w.Write(body) s.Require().Nil(err) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } - connection, ok := network.WrappedConnectionFromContext(r.Context()) - s.Require().True(ok) - s.Require().NotNil(connection) - s.Equal(s.wrappedHTTPSServerAddr, connection.LocalAddr().String()) - s.WithinRange(connection.GetTimeToFirstByte(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) + s.WithinRange(connectionState.GetTimeToFirstByte(), startTime, time.Now().Add(time.Second)) - s.Require().NotNil(connection.GetPeerCertificate()) - s.Equal(uri, connection.GetPeerCertificate().GetFirstURI()) + s.Require().NotNil(connectionState.GetPeerCertificate()) + s.Equal(uri, connectionState.GetPeerCertificate().GetFirstURI()) - s.Require().NotNil(connection.GetLocalCertificate()) - s.Equal(s.wrappedHTTPSServerCertURI, connection.GetLocalCertificate().GetFirstURI()) + s.Require().NotNil(connectionState.GetLocalCertificate()) + s.Equal(s.wrappedHTTPSServerCertURI, connectionState.GetLocalCertificate().GetFirstURI()) + + serverStats.Count++ + if connectionState.GetTLSConnectionState().HandshakeComplete { + serverStats.HandshakeCompleteCount++ + } + serverStats.Addresses[connectionState.RemoteAddr().String()]++ }) go func() { @@ -459,58 +525,117 @@ func (s *NetworkTestSuite) TestWrappedHTTPSServerWithWrappedClient() { InsecureSkipVerify: true, }) - c := http.Client{ //nolint:forcetypeassert - Transport: network.WrapHTTPTransport(http.DefaultTransport.(*http.Transport).Clone(), d), + var t *http.Transport + if tp, ok := http.DefaultTransport.(*http.Transport); ok { + t = tp.Clone() } + s.Require().NotNil(t) + t.MaxIdleConns = 10 + t.MaxConnsPerHost = 10 + t.MaxIdleConnsPerHost = 10 - resp, err := c.Get("https://" + s.wrappedHTTPSServerAddr + "/") - s.Require().Nil(err) - defer resp.Body.Close() - s.Require().NotNil(resp) - s.Equal(200, resp.StatusCode) + c := http.Client{ + Transport: network.WrapHTTPTransport(t, d), + } - b, err := io.ReadAll(resp.Body) - s.Require().Nil(err) - s.Equal(body, b) + wg := sync.WaitGroup{} + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + startTime := time.Now() + + defer wg.Done() + resp, err := c.Get("https://" + s.wrappedHTTPSServerAddr + "/") + s.Require().Nil(err) + defer resp.Body.Close() + s.Require().NotNil(resp) + s.Equal(200, resp.StatusCode) + + b, err := io.ReadAll(resp.Body) + s.Require().Nil(err) + s.Equal(body, b) + + connectionState, ok := network.ConnectionStateFromContext(resp.Request.Context()) + s.Require().True(ok) + s.Require().NotNil(connectionState) + s.Equal(s.wrappedHTTPSServerAddr, connectionState.RemoteAddr().String()) + s.WithinRange(connectionState.GetTimeToFirstByte(), startTime, time.Now().Add(time.Second)) - connection, ok := network.WrappedConnectionFromContext(resp.Request.Context()) - s.Require().True(ok) - s.Require().NotNil(connection) - s.Equal(s.wrappedHTTPSServerAddr, connection.RemoteAddr().String()) - s.WithinRange(connection.GetTimeToFirstByte(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) + s.Require().NotNil(connectionState.GetLocalCertificate()) + s.Equal(uri, connectionState.GetLocalCertificate().GetFirstURI()) - s.Require().NotNil(connection.GetLocalCertificate()) - s.Equal(uri, connection.GetLocalCertificate().GetFirstURI()) + s.Require().NotNil(connectionState.GetPeerCertificate()) + s.Equal(s.wrappedHTTPSServerCertURI, connectionState.GetPeerCertificate().GetFirstURI()) + + clientStats.Count++ + if connectionState.GetTLSConnectionState().HandshakeComplete { + clientStats.HandshakeCompleteCount++ + } + clientStats.Addresses[connectionState.LocalAddr().String()]++ + }() + } + wg.Wait() - s.Require().NotNil(connection.GetPeerCertificate()) - s.Equal(s.wrappedHTTPSServerCertURI, connection.GetPeerCertificate().GetFirstURI()) + s.Require().Equal(serverStats.Count, clientStats.Count) + s.Require().Equal(serverStats.HandshakeCompleteCount, clientStats.HandshakeCompleteCount) + s.Require().Equal(serverStats.Addresses, clientStats.Addresses) } -func (s *NetworkTestSuite) TestWrappedTLSServer() { +func (s *NetworkTestSuite) TestWrappedTLSServer() { //nolint:funlen ctx, cancelContext := context.WithCancel(context.Background()) uri := "spiffe://acme.corp/tls-client" + type Stats struct { + Count int + HandshakeCompleteCount int + Addresses map[string]int + } + + serverMux := sync.Mutex{} + clientMux := sync.Mutex{} + + serverStats := &Stats{ + Addresses: make(map[string]int), + } + clientStats := &Stats{ + Addresses: make(map[string]int), + } + go func() { err := s.wrappedTLSServer(ctx, func(conn net.Conn) { - b := make([]byte, 1) - n, err := conn.Read(b) - s.Require().Nil(err) - s.Equal(1, n) + for { + startTime := time.Now() - connection, ok := network.WrappedConnectionFromNetConn(conn) - s.Require().True(ok) - s.Require().NotNil(connection) + connectionState, ok := network.ConnectionStateFromNetConn(conn) + s.Require().True(ok) + s.Require().NotNil(connectionState) + connectionState.ResetTimeToFirstByte() + + r := textproto.NewReader(bufio.NewReader(conn)) + str, err := r.ReadLine() + s.Require().Nil(err) + s.Equal("hello", str) - s.Equal(s.wrappedTLSServerAddr, connection.LocalAddr().String()) - s.WithinRange(connection.GetTimeToFirstByte(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) + s.Equal(s.wrappedTLSServerAddr, connectionState.LocalAddr().String()) + s.WithinRange(connectionState.GetTimeToFirstByte(), startTime, time.Now().Add(time.Second)) - s.Require().NotNil(connection.GetPeerCertificate()) - s.Equal(uri, connection.GetPeerCertificate().GetFirstURI()) + s.Require().NotNil(connectionState.GetPeerCertificate()) + s.Equal(uri, connectionState.GetPeerCertificate().GetFirstURI()) - s.Require().NotNil(connection.GetLocalCertificate()) - s.Equal(s.wrappedTLServerCertURI, connection.GetLocalCertificate().GetFirstURI()) + s.Require().NotNil(connectionState.GetLocalCertificate()) + s.Equal(s.wrappedTLServerCertURI, connectionState.GetLocalCertificate().GetFirstURI()) + + serverStats.Count++ + if connectionState.GetTLSConnectionState().HandshakeComplete { + serverStats.HandshakeCompleteCount++ + } + serverMux.Lock() + serverStats.Addresses[connectionState.RemoteAddr().String()]++ + serverMux.Unlock() + } }) + s.Require().Nil(err) }() @@ -531,32 +656,71 @@ func (s *NetworkTestSuite) TestWrappedTLSServer() { InsecureSkipVerify: true, }) - ctx = network.NewConnectionToContext(ctx) - conn, err := d.DialTLSContext(ctx, "tcp", s.wrappedTLSServerAddr) - s.Require().Nil(err) + p, err := pool.NewChannelPool(func() (net.Conn, error) { + conn, err := d.DialContext(ctx, "tcp", s.wrappedTLSServerAddr) + if connectionState, ok := network.ConnectionStateFromNetConn(conn); ok { + connectionState.ResetTimeToFirstByte() + } - _, err = conn.Write([]byte("hello\n")) + return conn, err + }, pool.ChannelPoolWithInitialCap(5), pool.ChannelPoolWithMaxCap(5)) s.Require().Nil(err) - connection, ok := network.WrappedConnectionFromNetConn(conn) - s.Require().True(ok) - s.Require().NotNil(connection) + wg := sync.WaitGroup{} + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + startTime := time.Now() + defer wg.Done() + + conn, err := p.Get() + defer func() { + _ = conn.Close() + }() + s.Require().Nil(err) + + connectionState, ok := network.ConnectionStateFromNetConn(conn) + s.Require().True(ok) + s.Require().NotNil(connectionState) + connectionState.ResetTimeToFirstByte() - s.Equal(s.wrappedTLSServerAddr, connection.RemoteAddr().String()) - s.WithinRange(connection.GetTimeToFirstByte(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) + _, err = conn.Write([]byte("hello\n")) + s.Require().Nil(err) + + s.Equal(s.wrappedTLSServerAddr, connectionState.RemoteAddr().String()) + s.WithinRange(connectionState.GetTimeToFirstByte(), startTime, time.Now().Add(time.Second)) - s.Require().NotNil(connection.GetPeerCertificate()) - s.Equal(s.wrappedTLServerCertURI, connection.GetPeerCertificate().GetFirstURI()) + s.Require().NotNil(connectionState.GetPeerCertificate()) + s.Equal(s.wrappedTLServerCertURI, connectionState.GetPeerCertificate().GetFirstURI()) - s.Require().NotNil(connection.GetLocalCertificate()) - s.Equal(uri, connection.GetLocalCertificate().GetFirstURI()) + s.Require().NotNil(connectionState.GetLocalCertificate()) + s.Equal(uri, connectionState.GetLocalCertificate().GetFirstURI()) + + clientStats.Count++ + if connectionState.GetTLSConnectionState().HandshakeComplete { + clientStats.HandshakeCompleteCount++ + } + clientMux.Lock() + clientStats.Addresses[connectionState.LocalAddr().String()]++ + clientMux.Unlock() + }() + + if i%5 == 0 { + time.Sleep(time.Millisecond * 100) + } + } + wg.Wait() cancelContext() - // wait for tls server to shut down to give time to server connection handler to act + // wait for tls server to shut down to give time to server connectionState handler to act s.Eventually(func() bool { return !s.wrappedTLSServerRunning }, time.Second*5, time.Millisecond*10, "wrapped tls server didn't shut down") + + s.Require().Equal(serverStats.Count, clientStats.Count) + s.Require().Equal(serverStats.HandshakeCompleteCount, clientStats.HandshakeCompleteCount) + s.Require().Equal(serverStats.Addresses, clientStats.Addresses) } func (s *NetworkTestSuite) TestWrappedTCPServer() { @@ -564,18 +728,20 @@ func (s *NetworkTestSuite) TestWrappedTCPServer() { go func() { err := s.wrappedTCPServer(ctx, func(conn net.Conn) { + startTime := time.Now() + b := make([]byte, 1) n, err := conn.Read(b) s.Require().Nil(err) s.Equal(1, n) - connection, ok := network.WrappedConnectionFromNetConn(conn) + connectionState, ok := network.ConnectionStateFromNetConn(conn) s.Require().True(ok) - s.Require().NotNil(connection) + s.Require().NotNil(connectionState) - s.Equal(s.wrappedTCPServerAddr, connection.LocalAddr().String()) - s.WithinRange(connection.GetTimeToFirstByte(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) + s.Equal(s.wrappedTCPServerAddr, connectionState.LocalAddr().String()) + s.WithinRange(connectionState.GetTimeToFirstByte(), startTime, time.Now().Add(time.Second)) }) s.Require().Nil(err) }() @@ -584,23 +750,33 @@ func (s *NetworkTestSuite) TestWrappedTCPServer() { return s.wrappedTCPServerRunning }, time.Second*5, time.Millisecond*10, "wrapped tcp server didn't come up") - ctx = network.NewConnectionToContext(ctx) - conn, err := network.NewDialer().DialContext(ctx, "tcp", s.wrappedTCPServerAddr) - s.Require().Nil(err) + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + startTime := time.Now() - _, err = conn.Write([]byte("hello\n")) - s.Require().Nil(err) + defer wg.Done() + + conn, err := network.NewDialer().DialContext(ctx, "tcp", s.wrappedTCPServerAddr) + s.Require().Nil(err) + + _, err = conn.Write([]byte("hello\n")) + s.Require().Nil(err) - connection, ok := network.WrappedConnectionFromNetConn(conn) - s.Require().True(ok) - s.Require().NotNil(connection) + connectionState, ok := network.ConnectionStateFromNetConn(conn) + s.Require().True(ok) + s.Require().NotNil(connectionState) - s.Equal(s.wrappedTCPServerAddr, connection.RemoteAddr().String()) - s.WithinRange(connection.GetTimeToFirstByte(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) + s.Equal(s.wrappedTCPServerAddr, connectionState.RemoteAddr().String()) + s.WithinRange(connectionState.GetTimeToFirstByte(), startTime, time.Now().Add(time.Second)) + }() + } + wg.Wait() cancelContext() - // wait for tcp server to shut down to give time to server connection handler to act + // wait for tcp server to shut down to give time to server connectionState handler to act s.Eventually(func() bool { return !s.wrappedTCPServerRunning }, time.Second*5, time.Millisecond*10, "wrapped tcp server didn't shut down") @@ -619,25 +795,35 @@ func (s *NetworkTestSuite) TestSimpleTCPClient() { return s.simpleTCPServerRunning }, time.Second*5, time.Millisecond*10, "simple tcp server didn't come up") - ctx = network.NewConnectionToContext(ctx) - conn, err := network.NewDialer().DialContext(ctx, "tcp", s.simpleTCPServerAddr) - s.Require().Nil(err) + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + startTime := time.Now() - _, err = conn.Write([]byte("hello\n")) - s.Require().Nil(err) + defer wg.Done() - reply := make([]byte, 1024) - n, err := conn.Read(reply) - s.Require().Nil(err) + conn, err := network.NewDialer().DialContext(ctx, "tcp", s.simpleTCPServerAddr) + s.Require().Nil(err) - s.Greater(n, 0) + _, err = conn.Write([]byte("hello\n")) + s.Require().Nil(err) + + reply := make([]byte, 1024) + n, err := conn.Read(reply) + s.Require().Nil(err) - connection, ok := network.WrappedConnectionFromNetConn(conn) - s.Require().True(ok) - s.Require().NotNil(connection) + s.Greater(n, 0) - s.Equal(s.simpleTCPServerAddr, connection.RemoteAddr().String()) - s.WithinRange(connection.GetTimeToFirstByte(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) + connectionState, ok := network.ConnectionStateFromNetConn(conn) + s.Require().True(ok) + s.Require().NotNil(connectionState) + + s.Equal(s.simpleTCPServerAddr, connectionState.RemoteAddr().String()) + s.WithinRange(connectionState.GetTimeToFirstByte(), startTime, time.Now().Add(time.Second)) + }() + } + wg.Wait() } func TestNetworkTestSuite(t *testing.T) { diff --git a/pkg/network/pool/conn.go b/pkg/network/pool/conn.go index ab10b852..a72b7f58 100644 --- a/pkg/network/pool/conn.go +++ b/pkg/network/pool/conn.go @@ -42,6 +42,10 @@ func (c *poolConnection) Close() error { return c.pool.Put(c.Conn) } +func (c *poolConnection) NetConn() net.Conn { + return c.Conn +} + func (c *poolConnection) Read(b []byte) (n int, err error) { n, err = c.Conn.Read(b) if err != nil { diff --git a/pkg/network/tls.go b/pkg/network/tls.go index a20789a8..3375c069 100644 --- a/pkg/network/tls.go +++ b/pkg/network/tls.go @@ -16,9 +16,30 @@ package network import ( "crypto/tls" + "net" "strings" ) +func CreateTLSConn(conn net.Conn, tlsConfig *tls.Config, f func(net.Conn, *tls.Config) *tls.Conn) *tls.Conn { + tlsConn := f(conn, tlsConfig) + + if co, ok := conn.(interface { + SetTLSConn(*tls.Conn) + }); ok { + co.SetTLSConn(tlsConn) + } + + return tlsConn +} + +func CreateTLSServerConn(conn net.Conn, tlsConfig *tls.Config) *tls.Conn { + return CreateTLSConn(conn, tlsConfig, tls.Server) +} + +func CreateTLSClientConn(conn net.Conn, tlsConfig *tls.Config) *tls.Conn { + return CreateTLSConn(conn, tlsConfig, tls.Client) +} + func WrapTLSConfig(config *tls.Config) *tls.Config { //nolint:gocognit if config == nil { return nil @@ -82,10 +103,10 @@ func WrapTLSConfig(config *tls.Config) *tls.Config { //nolint:gocognit } if cert != nil { - if c, ok := WrappedConnectionFromContext(chi.Context()); ok { - c.SetLocalCertificate(cert) - } else if c, ok := WrappedConnectionFromNetConn(chi.Conn); ok { - c.SetLocalCertificate(cert) + if s, ok := ConnectionStateFromContext(chi.Context()); ok { + s.SetLocalCertificate(*cert) + } else if s, ok := ConnectionStateFromNetConn(chi.Conn); ok { + s.SetLocalCertificate(*cert) } } @@ -126,8 +147,8 @@ func WrapTLSConfig(config *tls.Config) *tls.Config { //nolint:gocognit } if cert != nil { - if c, ok := WrappedConnectionFromContext(cri.Context()); ok { - c.SetLocalCertificate(cert) + if s, ok := ConnectionStateHolderFromContext(cri.Context()); ok { + s.Get().SetLocalCertificate(*cert) } } diff --git a/pkg/proxywasm/api/stream.go b/pkg/proxywasm/api/stream.go index 1c9120fb..969466d7 100644 --- a/pkg/proxywasm/api/stream.go +++ b/pkg/proxywasm/api/stream.go @@ -61,10 +61,10 @@ type HTTPResponse interface { Body() io.ReadCloser SetBody(io.ReadCloser) - Connection() network.Connection - ContentLength() int64 StatusCode() int + + ConnectionState() network.ConnectionState } type HTTPRequest interface { @@ -78,5 +78,5 @@ type HTTPRequest interface { Host() string Method() string - Connection() network.Connection + ConnectionState() network.ConnectionState } diff --git a/pkg/proxywasm/grpc/dialer.go b/pkg/proxywasm/grpc/dialer.go index faa38d8d..b1f720bc 100644 --- a/pkg/proxywasm/grpc/dialer.go +++ b/pkg/proxywasm/grpc/dialer.go @@ -43,7 +43,7 @@ type GRPCDialer struct { discoveryClient discovery.DiscoveryClient logger logr.Logger - connection network.Connection + connectionState network.ConnectionState } func NewGRPCDialer(caClient ca.Client, streamHandler api.StreamHandler, discoveryClient discovery.DiscoveryClient, logger logr.Logger) *GRPCDialer { @@ -73,18 +73,20 @@ func (g *GRPCDialer) Dial(ctx context.Context, addr string) (net.Conn, error) { tlsConfig.ServerName = prop.ServerName() } - ctx = network.NewConnectionToContext(ctx) + ctx = network.NewConnectionStateHolderToContext(ctx) + opts := []network.DialerOption{ - network.DialerWithWrappedConnectionOptions(network.WrappedConnectionWithCloserWrapper(g.discoveryClient.NewConnectionCloseWrapper())), + network.DialerWithConnectionOptions(network.ConnectionWithCloserWrapper(g.discoveryClient.NewConnectionCloseWrapper())), network.DialerWithDialerWrapper(g.discoveryClient.NewDialWrapper()), } + conn, err := network.NewDialerWithTLSConfig(tlsConfig, opts...).DialTLSContext(ctx, "tcp", addr) if err != nil { return nil, err } - if connection, ok := network.WrappedConnectionFromContext(ctx); ok { - g.connection = connection + if state, ok := network.ConnectionStateFromContext(ctx); ok { + g.connectionState = state } return conn, nil @@ -93,9 +95,9 @@ func (g *GRPCDialer) Dial(ctx context.Context, addr string) (net.Conn, error) { func (g *GRPCDialer) RequestInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { g.logger.Info("intercepted request", "target", cc.Target(), "method", method, "req", req) - if g.connection != nil { - g.discoveryClient.IncrementActiveRequestsCount(g.connection.GetOriginalAddress()) - defer g.discoveryClient.DecrementActiveRequestsCount(g.connection.GetOriginalAddress()) + if g.discoveryClient != nil { + g.discoveryClient.IncrementActiveRequestsCount(cc.Target()) + defer g.discoveryClient.DecrementActiveRequestsCount(cc.Target()) } stream, err := g.streamHandler.NewStream(api.ListenerDirectionOutbound) @@ -116,7 +118,7 @@ func (g *GRPCDialer) RequestInterceptor(ctx context.Context, method string, req, var responseHeaders, responseTrailers metadata.MD opts = append(opts, grpc.Header(&responseHeaders), grpc.Trailer(&responseTrailers)) - wrappedRequest := WrapGRPCRequest(fmt.Sprintf("https://%s/%s", cc.Target(), strings.TrimLeft(method, "/")), headers, g.connection) + wrappedRequest := WrapGRPCRequest(fmt.Sprintf("https://%s/%s", cc.Target(), strings.TrimLeft(method, "/")), headers, g.connectionState) g.BeforeRequest(wrappedRequest, stream) @@ -133,7 +135,11 @@ func (g *GRPCDialer) RequestInterceptor(ctx context.Context, method string, req, return err } - wrappedResponse := WrapGRPCResponse(status.Code(err), responseHeaders, responseTrailers, g.connection) + if h, ok := network.ConnectionStateHolderFromContext(ctx); ok { + h.Set(g.connectionState) + } + + wrappedResponse := WrapGRPCResponse(status.Code(err), responseHeaders, responseTrailers, g.connectionState) stream.Set("grpc.status", status.Code(err)) g.BeforeResponse(wrappedResponse, stream) diff --git a/pkg/proxywasm/grpc/wrappers.go b/pkg/proxywasm/grpc/wrappers.go index 9707791d..e0dff595 100644 --- a/pkg/proxywasm/grpc/wrappers.go +++ b/pkg/proxywasm/grpc/wrappers.go @@ -26,24 +26,24 @@ import ( ) type GRPCRequest struct { - url *url.URL - header metadata.MD - connection network.Connection + url *url.URL + header metadata.MD + connectionState network.ConnectionState } type GRPCResponse struct { - statusCode codes.Code - header metadata.MD - trailer metadata.MD - connection network.Connection + statusCode codes.Code + header metadata.MD + trailer metadata.MD + connectionState network.ConnectionState } -func WrapGRPCResponse(statusCode codes.Code, header metadata.MD, trailer metadata.MD, connection network.Connection) api.HTTPResponse { +func WrapGRPCResponse(statusCode codes.Code, header metadata.MD, trailer metadata.MD, connectionState network.ConnectionState) api.HTTPResponse { return &GRPCResponse{ - statusCode: statusCode, - header: header, - trailer: trailer, - connection: connection, + statusCode: statusCode, + header: header, + trailer: trailer, + connectionState: connectionState, } } @@ -66,8 +66,8 @@ func (r *GRPCResponse) ContentLength() int64 { return 0 } -func (r *GRPCResponse) Connection() network.Connection { - return r.connection +func (r *GRPCResponse) ConnectionState() network.ConnectionState { + return r.connectionState } func (r *GRPCResponse) StatusCode() int { @@ -111,11 +111,11 @@ func (r *GRPCResponse) StatusCode() int { } } -func WrapGRPCRequest(rawURL string, header metadata.MD, conn network.Connection) api.HTTPRequest { +func WrapGRPCRequest(rawURL string, header metadata.MD, connectionState network.ConnectionState) api.HTTPRequest { r := &GRPCRequest{ - url: &url.URL{}, - header: header, - connection: conn, + url: &url.URL{}, + header: header, + connectionState: connectionState, } if u, err := url.Parse(rawURL); err != nil { @@ -156,8 +156,8 @@ func (r *GRPCRequest) Method() string { return "" } -func (r *GRPCRequest) Connection() network.Connection { - return r.connection +func (r *GRPCRequest) ConnectionState() network.ConnectionState { + return r.connectionState } type grpcHeaderMap struct { diff --git a/pkg/proxywasm/http/wrappers.go b/pkg/proxywasm/http/wrappers.go index b8abc353..f0f7f410 100644 --- a/pkg/proxywasm/http/wrappers.go +++ b/pkg/proxywasm/http/wrappers.go @@ -75,9 +75,9 @@ func (r *HTTPRequest) Method() string { return r.Request.Method } -func (r *HTTPRequest) Connection() network.Connection { - if connection, ok := network.WrappedConnectionFromContext(r.Request.Context()); ok { - return connection +func (r *HTTPRequest) ConnectionState() network.ConnectionState { + if connectionState, ok := network.ConnectionStateFromContext(r.Request.Context()); ok { + return connectionState } return nil @@ -111,9 +111,9 @@ func (r *HTTPResponse) StatusCode() int { return r.Response.StatusCode } -func (r *HTTPResponse) Connection() network.Connection { - if connection, ok := network.WrappedConnectionFromContext(r.Response.Request.Context()); ok { - return connection +func (r *HTTPResponse) ConnectionState() network.ConnectionState { + if connectionState, ok := network.ConnectionStateFromContext(r.Response.Request.Context()); ok { + return connectionState } return nil diff --git a/pkg/proxywasm/middleware/envoy_middleware.go b/pkg/proxywasm/middleware/envoy_middleware.go index cc14b613..d9541eca 100644 --- a/pkg/proxywasm/middleware/envoy_middleware.go +++ b/pkg/proxywasm/middleware/envoy_middleware.go @@ -95,9 +95,9 @@ func (m *envoyHttpHandlerMiddleware) setXForwardedHeaders(req api.HTTPRequest, s } func (m *envoyHttpHandlerMiddleware) beforeInboundRequest(req api.HTTPRequest, stream api.Stream) { - if connection := req.Connection(); connection != nil { - m.setRequestProperties(req, stream, connection.GetTimeToFirstByte()) - SetEnvoyConnectionInfo(connection, stream) + if connectionState := req.ConnectionState(); connectionState != nil { + m.setRequestProperties(req, stream, connectionState.GetTimeToFirstByte()) + SetEnvoyConnectionInfo(connectionState, stream) m.setXForwardedHeaders(req, stream) } } @@ -127,8 +127,8 @@ func (m *envoyHttpHandlerMiddleware) afterOutboundRequest(req api.HTTPRequest, s } func (m *envoyHttpHandlerMiddleware) beforeOutboundResponse(resp api.HTTPResponse, stream api.Stream) { - if connection := resp.Connection(); connection != nil { - SetEnvoyConnectionInfo(connection, stream) + if connectionState := resp.ConnectionState(); connectionState != nil { + SetEnvoyConnectionInfo(connectionState, stream) } } @@ -172,7 +172,7 @@ func (m *envoyHttpHandlerMiddleware) setResponseInfo(resp api.HTTPResponse, stre } } -func SetEnvoyConnectionInfo(conn network.Connection, stream api.Stream) { +func SetEnvoyConnectionInfo(connectionState network.ConnectionState, stream api.Stream) { remoteKey := "source" connectionKey := "connection" if stream.Direction() == api.ListenerDirectionOutbound { @@ -181,21 +181,21 @@ func SetEnvoyConnectionInfo(conn network.Connection, stream api.Stream) { } connection := map[string]interface{}{} - connection["mtls"] = conn.GetPeerCertificate() != nil - if peerCert := conn.GetPeerCertificate(); peerCert != nil { + connection["mtls"] = connectionState.GetPeerCertificate() != nil + if peerCert := connectionState.GetPeerCertificate(); peerCert != nil { connection["subject_peer_certificate"] = peerCert.GetSubject() connection["dns_san_peer_certificate"] = peerCert.GetFirstDNSName() connection["uri_san_peer_certificate"] = peerCert.GetFirstURI() connection["sha256_peer_certificate_digest"] = peerCert.GetSHA256Digest() } - if localCert := conn.GetLocalCertificate(); localCert != nil { + if localCert := connectionState.GetLocalCertificate(); localCert != nil { connection["subject_local_certificate"] = localCert.GetSubject() connection["dns_san_local_certificate"] = localCert.GetFirstDNSName() connection["uri_san_local_certificate"] = localCert.GetFirstURI() } - if cs := conn.GetConnectionState(); cs != nil { + if cs := connectionState.GetTLSConnectionState(); cs.HandshakeComplete { connection["requested_server_name"] = cs.ServerName connection["tls_version"] = cs.Version connection["negotiated_protocol"] = cs.NegotiatedProtocol @@ -204,7 +204,7 @@ func SetEnvoyConnectionInfo(conn network.Connection, stream api.Stream) { stream.Logger().V(3).Info("set connection key", "key", connectionKey, "value", connection) stream.Set(connectionKey, connection) - if ip, port, err := net.SplitHostPort(conn.RemoteAddr().String()); err == nil { + if ip, port, err := net.SplitHostPort(connectionState.RemoteAddr().String()); err == nil { stream.Set(remoteKey+".address", ip) if port, err := strconv.Atoi(port); err == nil { stream.Set(remoteKey+".port", port) @@ -215,7 +215,7 @@ func SetEnvoyConnectionInfo(conn network.Connection, stream api.Stream) { } } - if ip, port, err := net.SplitHostPort(conn.LocalAddr().String()); err == nil { + if ip, port, err := net.SplitHostPort(connectionState.LocalAddr().String()); err == nil { stream.Set("destination.address", ip) if port, err := strconv.Atoi(port); err == nil { stream.Set("destination.port", port) diff --git a/pkg/proxywasm/tcp/conn.go b/pkg/proxywasm/tcp/conn.go index 26596682..aa834d45 100644 --- a/pkg/proxywasm/tcp/conn.go +++ b/pkg/proxywasm/tcp/conn.go @@ -176,8 +176,8 @@ func (c *wrappedConn) Write(b []byte) (int, error) { } if !c.connectionInfoSet { - if connection, ok := network.WrappedConnectionFromNetConn(c); ok { - middleware.SetEnvoyConnectionInfo(connection, c.stream) + if connectionState, ok := network.ConnectionStateFromNetConn(c); ok { + middleware.SetEnvoyConnectionInfo(connectionState, c.stream) c.connectionInfoSet = true } } diff --git a/pkg/util/util.go b/pkg/util/util.go new file mode 100644 index 00000000..52a86850 --- /dev/null +++ b/pkg/util/util.go @@ -0,0 +1,39 @@ +// Copyright (c) 2022 Cisco and/or its affiliates. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "time" + + "github.com/go-logr/logr" + + "github.com/cisco-open/nasp/pkg/network" +) + +func PrintConnectionState(connectionState network.ConnectionState, logger logr.Logger) { + localAddr := connectionState.LocalAddr().String() + remoteAddr := connectionState.RemoteAddr().String() + var localCertID, remoteCertID string + + if cert := connectionState.GetLocalCertificate(); cert != nil { + localCertID = cert.String() + } + + if cert := connectionState.GetPeerCertificate(); cert != nil { + remoteCertID = cert.String() + } + + logger.Info("connection info", "localAddr", localAddr, "localCertID", localCertID, "remoteAddr", remoteAddr, "remoteCertID", remoteCertID, "ttfb", connectionState.GetTimeToFirstByte().Format(time.RFC3339Nano)) +}