From d669d381c8ef5102e565b3ad6bc20fe35272b961 Mon Sep 17 00:00:00 2001 From: Zetao Zhuang Date: Wed, 15 Oct 2025 11:30:06 -0700 Subject: [PATCH 1/6] filtering TLS connections based on the subject name from Caller --- cns/configuration/cns_config.json | 3 +- cns/configuration/configuration.go | 1 + cns/configuration/configuration_test.go | 9 +- cns/service.go | 26 +++++- cns/service/main.go | 1 + cns/service_test.go | 110 +++++++++++++++--------- server/tls/tlscertificate_retriever.go | 1 + 7 files changed, 104 insertions(+), 47 deletions(-) diff --git a/cns/configuration/cns_config.json b/cns/configuration/cns_config.json index 81ef6c9b05..1b17fdc4ab 100644 --- a/cns/configuration/cns_config.json +++ b/cns/configuration/cns_config.json @@ -35,5 +35,6 @@ "AZRSettings": { "PopulateHomeAzCacheRetryIntervalSecs": 60 }, - "MinTLSVersion": "TLS 1.2" + "MinTLSVersion": "TLS 1.2", + "AllowedClientSubjectName": "" } diff --git a/cns/configuration/configuration.go b/cns/configuration/configuration.go index 9ec5f8664f..c7bbe0ee48 100644 --- a/cns/configuration/configuration.go +++ b/cns/configuration/configuration.go @@ -59,6 +59,7 @@ type CNSConfig struct { WireserverIP string GRPCSettings GRPCSettings MinTLSVersion string + AllowedClientSubjectName string } type TelemetrySettings struct { diff --git a/cns/configuration/configuration_test.go b/cns/configuration/configuration_test.go index 186c92c376..7aa230d788 100644 --- a/cns/configuration/configuration_test.go +++ b/cns/configuration/configuration_test.go @@ -222,7 +222,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "localhost", Port: 8080, }, - MinTLSVersion: "TLS 1.2", + MinTLSVersion: "TLS 1.2", + AllowedClientSubjectName: "", }, }, { @@ -253,7 +254,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "192.168.1.1", Port: 9090, }, - MinTLSVersion: "TLS 1.3", + MinTLSVersion: "TLS 1.3", + AllowedClientSubjectName: "example.com", }, want: CNSConfig{ ChannelMode: "Other", @@ -283,7 +285,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "192.168.1.1", Port: 9090, }, - MinTLSVersion: "TLS 1.3", + MinTLSVersion: "TLS 1.3", + AllowedClientSubjectName: "example.com", }, }, } diff --git a/cns/service.go b/cns/service.go index ab7a0be3c3..5e8403cc22 100644 --- a/cns/service.go +++ b/cns/service.go @@ -156,6 +156,25 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls. return nil, errors.Errorf("invalid tls settings: %+v", tlsSettings) } +// verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name. +func verifyPeerCertificate(rawCerts [][]byte, clientSubjectName string) error { + // no client subject name provided, skip verification + if clientSubjectName == "" { + return nil + } + + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return errors.Errorf("failed to parse certificate: %v", err) + } + + err = cert.VerifyHostname(clientSubjectName) + if err != nil { + return errors.Errorf("failed to verify client certificate hostname: %v", err) + } + return nil +} + func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) { tlsCertRetriever, err := localtls.GetTlsCertificateRetriever(tlsSettings) if err != nil { @@ -202,8 +221,10 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert tlsConfig.ClientCAs = rootCAs tlsConfig.RootCAs = rootCAs + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName) + } } - logger.Debugf("TLS configured successfully from file: %+v", tlsSettings) return tlsConfig, nil @@ -254,6 +275,9 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert tlsConfig.ClientCAs = rootCAs tlsConfig.RootCAs = rootCAs + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName) + } } logger.Debugf("TLS configured successfully from KV: %+v", tlsSettings) diff --git a/cns/service/main.go b/cns/service/main.go index 67f7872f44..e4f207f6e0 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -810,6 +810,7 @@ func main() { KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour, UseMTLS: cnsconfig.UseMTLS, MinTLSVersion: cnsconfig.MinTLSVersion, + AllowedClientSubjectName: cnsconfig.AllowedClientSubjectName, } } diff --git a/cns/service_test.go b/cns/service_test.go index d20c2ef11a..8da4333e00 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -12,6 +12,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "fmt" "math/big" "net/http" "os" @@ -133,57 +134,82 @@ func TestNewService(t *testing.T) { t.Run("NewServiceWithMutualTLS", func(t *testing.T) { testCertFilePath := createTestCertificate(t) - config.TLSSettings = serverTLS.TlsSettings{ - TLSPort: "10091", - TLSSubjectName: "localhost", - TLSCertificatePath: testCertFilePath, - UseMTLS: true, - MinTLSVersion: "TLS 1.2", + TLSSetting := serverTLS.TlsSettings{ + TLSPort: "10091", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + AllowedClientSubjectName: "example.com", } - svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) - require.NoError(t, err) - require.IsType(t, &Service{}, svc) + TLSSettingWithDisallowedClientSN := serverTLS.TlsSettings{ + TLSPort: "10092", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + AllowedClientSubjectName: "random.com", + } - svc.SetOption(acn.OptCnsURL, "") - svc.SetOption(acn.OptCnsPort, "") + runMutualTLSTest := func(tlsSettings serverTLS.TlsSettings, handshakeFailureExpected bool) { + config.TLSSettings = tlsSettings + svc, err := NewService(config.Name, config.Version, config.ChannelMode, config.Store) + require.NoError(t, err) + require.IsType(t, &Service{}, svc) - err = svc.Initialize(config) - t.Cleanup(func() { - svc.Uninitialize() - }) - require.NoError(t, err) + svc.SetOption(acn.OptCnsURL, "") + svc.SetOption(acn.OptCnsPort, "") - err = svc.StartListener(config) - require.NoError(t, err) + err = svc.Initialize(config) + require.NoError(t, err) - mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings) - require.NoError(t, err) + err = svc.StartListener(config) + require.NoError(t, err) - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: mTLSConfig, - }, - } + mTLSConfig, err := getTLSConfigFromFile(config.TLSSettings) + require.NoError(t, err) - // TLS listener - req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, "https://localhost:10091", http.NoBody) - require.NoError(t, err) - resp, err := client.Do(req) - t.Cleanup(func() { - resp.Body.Close() - }) - require.NoError(t, err) + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: mTLSConfig, + }, + } - // HTTP listener - httpClient := &http.Client{} - req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) - require.NoError(t, err) - resp, err = httpClient.Do(req) - t.Cleanup(func() { - resp.Body.Close() - }) - require.NoError(t, err) + tlsUrl := fmt.Sprintf("https://localhost:%s", tlsSettings.TLSPort) + // TLS listener + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsUrl, http.NoBody) + require.NoError(t, err) + resp, err := client.Do(req) + t.Cleanup(func() { + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + }) + if handshakeFailureExpected { + require.Error(t, err) + require.ErrorContains(t, err, "failed to verify client certificate hostname") + + } else { + require.NoError(t, err) + } + + // HTTP listener + httpClient := &http.Client{} + req, err = http.NewRequestWithContext(context.TODO(), http.MethodGet, "http://localhost:10090", http.NoBody) + require.NoError(t, err) + resp, err = httpClient.Do(req) + t.Cleanup(func() { + resp.Body.Close() + }) + require.NoError(t, err) + + // Cleanup + svc.Uninitialize() + + } + runMutualTLSTest(TLSSetting, false) + runMutualTLSTest(TLSSettingWithDisallowedClientSN, true) }) } diff --git a/server/tls/tlscertificate_retriever.go b/server/tls/tlscertificate_retriever.go index a22a7336b7..dbd65af6ca 100644 --- a/server/tls/tlscertificate_retriever.go +++ b/server/tls/tlscertificate_retriever.go @@ -15,6 +15,7 @@ type TlsSettings struct { KeyVaultCertificateRefreshInterval time.Duration UseMTLS bool MinTLSVersion string + AllowedClientSubjectName string } func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) { From fddda40063e1c1e0e95e6e430315aff2aa9df5b4 Mon Sep 17 00:00:00 2001 From: Zetao Zhuang Date: Wed, 15 Oct 2025 12:44:18 -0700 Subject: [PATCH 2/6] add validation to client rawCerts --- cns/service.go | 4 ++++ cns/service_test.go | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cns/service.go b/cns/service.go index 5e8403cc22..f28a8fa7b3 100644 --- a/cns/service.go +++ b/cns/service.go @@ -158,6 +158,10 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls. // verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name. func verifyPeerCertificate(rawCerts [][]byte, clientSubjectName string) error { + + if len(rawCerts) == 0 { + return errors.New("no client certificate provided") + } // no client subject name provided, skip verification if clientSubjectName == "" { return nil diff --git a/cns/service_test.go b/cns/service_test.go index 8da4333e00..58f1cd6b01 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -176,9 +176,9 @@ func TestNewService(t *testing.T) { }, } - tlsUrl := fmt.Sprintf("https://localhost:%s", tlsSettings.TLSPort) + tlsURL := fmt.Sprintf("https://localhost:%s", tlsSettings.TLSPort) // TLS listener - req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsUrl, http.NoBody) + req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody) require.NoError(t, err) resp, err := client.Do(req) t.Cleanup(func() { From eb0bbdd0c46afb8bfd01922662ec5b2571ad0a43 Mon Sep 17 00:00:00 2001 From: Zetao Zhuang Date: Wed, 15 Oct 2025 14:28:31 -0700 Subject: [PATCH 3/6] fix lint --- cns/service.go | 1 - cns/service_test.go | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/cns/service.go b/cns/service.go index f28a8fa7b3..d879ef0c4b 100644 --- a/cns/service.go +++ b/cns/service.go @@ -158,7 +158,6 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls. // verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name. func verifyPeerCertificate(rawCerts [][]byte, clientSubjectName string) error { - if len(rawCerts) == 0 { return errors.New("no client certificate provided") } diff --git a/cns/service_test.go b/cns/service_test.go index 58f1cd6b01..0b8b362d2d 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -12,7 +12,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "fmt" "math/big" "net/http" "os" @@ -176,7 +175,7 @@ func TestNewService(t *testing.T) { }, } - tlsURL := fmt.Sprintf("https://localhost:%s", tlsSettings.TLSPort) + tlsURL := "https://localhost:" + tlsSettings.TLSPort // TLS listener req, err := http.NewRequestWithContext(context.TODO(), http.MethodGet, tlsURL, http.NoBody) require.NoError(t, err) @@ -206,7 +205,6 @@ func TestNewService(t *testing.T) { // Cleanup svc.Uninitialize() - } runMutualTLSTest(TLSSetting, false) runMutualTLSTest(TLSSettingWithDisallowedClientSN, true) From 6133cf01c35994bf5b5a684e1c444597436ac8f5 Mon Sep 17 00:00:00 2001 From: Zetao Zhuang Date: Thu, 23 Oct 2025 14:13:08 -0700 Subject: [PATCH 4/6] update config name and error msgs --- cns/configuration/cns_config.json | 2 +- cns/configuration/configuration.go | 2 +- cns/configuration/configuration_test.go | 12 ++++++------ cns/service.go | 11 ++++++----- cns/service/main.go | 2 +- cns/service_test.go | 2 +- 6 files changed, 16 insertions(+), 15 deletions(-) diff --git a/cns/configuration/cns_config.json b/cns/configuration/cns_config.json index 1b17fdc4ab..967089d142 100644 --- a/cns/configuration/cns_config.json +++ b/cns/configuration/cns_config.json @@ -36,5 +36,5 @@ "PopulateHomeAzCacheRetryIntervalSecs": 60 }, "MinTLSVersion": "TLS 1.2", - "AllowedClientSubjectName": "" + "MtlsClientCertSubjectName": "" } diff --git a/cns/configuration/configuration.go b/cns/configuration/configuration.go index c7bbe0ee48..b5fc0e4114 100644 --- a/cns/configuration/configuration.go +++ b/cns/configuration/configuration.go @@ -59,7 +59,7 @@ type CNSConfig struct { WireserverIP string GRPCSettings GRPCSettings MinTLSVersion string - AllowedClientSubjectName string + MtlsClientCertSubjectName string } type TelemetrySettings struct { diff --git a/cns/configuration/configuration_test.go b/cns/configuration/configuration_test.go index 7aa230d788..ab3d93ebd1 100644 --- a/cns/configuration/configuration_test.go +++ b/cns/configuration/configuration_test.go @@ -222,8 +222,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "localhost", Port: 8080, }, - MinTLSVersion: "TLS 1.2", - AllowedClientSubjectName: "", + MinTLSVersion: "TLS 1.2", + MtlsClientCertSubjectName: "", }, }, { @@ -254,8 +254,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "192.168.1.1", Port: 9090, }, - MinTLSVersion: "TLS 1.3", - AllowedClientSubjectName: "example.com", + MinTLSVersion: "TLS 1.3", + MtlsClientCertSubjectName: "example.com", }, want: CNSConfig{ ChannelMode: "Other", @@ -285,8 +285,8 @@ func TestSetCNSConfigDefaults(t *testing.T) { IPAddress: "192.168.1.1", Port: 9090, }, - MinTLSVersion: "TLS 1.3", - AllowedClientSubjectName: "example.com", + MinTLSVersion: "TLS 1.3", + MtlsClientCertSubjectName: "example.com", }, }, } diff --git a/cns/service.go b/cns/service.go index d879ef0c4b..81fe62b157 100644 --- a/cns/service.go +++ b/cns/service.go @@ -158,22 +158,23 @@ func getTLSConfig(tlsSettings localtls.TlsSettings, errChan chan<- error) (*tls. // verifyPeerCertificate verifies the client certificate's subject name matches the expected subject name. func verifyPeerCertificate(rawCerts [][]byte, clientSubjectName string) error { - if len(rawCerts) == 0 { - return errors.New("no client certificate provided") - } // no client subject name provided, skip verification if clientSubjectName == "" { return nil } + if len(rawCerts) == 0 { + return errors.New("no client certificate provided during mTLS") + } + cert, err := x509.ParseCertificate(rawCerts[0]) if err != nil { - return errors.Errorf("failed to parse certificate: %v", err) + return errors.Errorf("Failed to parse client certificate during mTLS: %v", err) } err = cert.VerifyHostname(clientSubjectName) if err != nil { - return errors.Errorf("failed to verify client certificate hostname: %v", err) + return errors.Errorf("Failed to verify client certificate subject name during mTLS: %v", err) } return nil } diff --git a/cns/service/main.go b/cns/service/main.go index e4f207f6e0..35fedf9655 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -810,7 +810,7 @@ func main() { KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour, UseMTLS: cnsconfig.UseMTLS, MinTLSVersion: cnsconfig.MinTLSVersion, - AllowedClientSubjectName: cnsconfig.AllowedClientSubjectName, + AllowedClientSubjectName: cnsconfig.MtlsClientCertSubjectName, } } diff --git a/cns/service_test.go b/cns/service_test.go index 0b8b362d2d..668fe10073 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -187,7 +187,7 @@ func TestNewService(t *testing.T) { }) if handshakeFailureExpected { require.Error(t, err) - require.ErrorContains(t, err, "failed to verify client certificate hostname") + require.ErrorContains(t, err, "Failed to verify client certificate subject name during mTLS") } else { require.NoError(t, err) From 414b760d39fee8357328cbf7007e9295ec8364d1 Mon Sep 17 00:00:00 2001 From: Zetao Zhuang Date: Thu, 23 Oct 2025 14:19:04 -0700 Subject: [PATCH 5/6] update variable name in tlsconfig --- cns/service/main.go | 2 +- server/tls/tlscertificate_retriever.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cns/service/main.go b/cns/service/main.go index 35fedf9655..d7b9a526d5 100644 --- a/cns/service/main.go +++ b/cns/service/main.go @@ -810,7 +810,7 @@ func main() { KeyVaultCertificateRefreshInterval: time.Duration(cnsconfig.KeyVaultSettings.RefreshIntervalInHrs) * time.Hour, UseMTLS: cnsconfig.UseMTLS, MinTLSVersion: cnsconfig.MinTLSVersion, - AllowedClientSubjectName: cnsconfig.MtlsClientCertSubjectName, + MtlsClientCertSubjectName: cnsconfig.MtlsClientCertSubjectName, } } diff --git a/server/tls/tlscertificate_retriever.go b/server/tls/tlscertificate_retriever.go index dbd65af6ca..b6a0d11099 100644 --- a/server/tls/tlscertificate_retriever.go +++ b/server/tls/tlscertificate_retriever.go @@ -15,7 +15,7 @@ type TlsSettings struct { KeyVaultCertificateRefreshInterval time.Duration UseMTLS bool MinTLSVersion string - AllowedClientSubjectName string + MtlsClientCertSubjectName string } func GetTlsCertificateRetriever(settings TlsSettings) (TlsCertificateRetriever, error) { From abed7201e5187948b6a26a42c5e2d7818a76dfde Mon Sep 17 00:00:00 2001 From: Zetao Zhuang Date: Thu, 23 Oct 2025 14:36:04 -0700 Subject: [PATCH 6/6] renaming var in tlssetting --- cns/service.go | 4 ++-- cns/service_test.go | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/cns/service.go b/cns/service.go index 81fe62b157..6329e678fa 100644 --- a/cns/service.go +++ b/cns/service.go @@ -226,7 +226,7 @@ func getTLSConfigFromFile(tlsSettings localtls.TlsSettings) (*tls.Config, error) tlsConfig.ClientCAs = rootCAs tlsConfig.RootCAs = rootCAs tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName) + return verifyPeerCertificate(rawCerts, tlsSettings.MtlsClientCertSubjectName) } } logger.Debugf("TLS configured successfully from file: %+v", tlsSettings) @@ -280,7 +280,7 @@ func getTLSConfigFromKeyVault(tlsSettings localtls.TlsSettings, errChan chan<- e tlsConfig.ClientCAs = rootCAs tlsConfig.RootCAs = rootCAs tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - return verifyPeerCertificate(rawCerts, tlsSettings.AllowedClientSubjectName) + return verifyPeerCertificate(rawCerts, tlsSettings.MtlsClientCertSubjectName) } } diff --git a/cns/service_test.go b/cns/service_test.go index 668fe10073..71a48fadf3 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -134,21 +134,21 @@ func TestNewService(t *testing.T) { testCertFilePath := createTestCertificate(t) TLSSetting := serverTLS.TlsSettings{ - TLSPort: "10091", - TLSSubjectName: "localhost", - TLSCertificatePath: testCertFilePath, - UseMTLS: true, - MinTLSVersion: "TLS 1.2", - AllowedClientSubjectName: "example.com", + TLSPort: "10091", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + MtlsClientCertSubjectName: "example.com", } TLSSettingWithDisallowedClientSN := serverTLS.TlsSettings{ - TLSPort: "10092", - TLSSubjectName: "localhost", - TLSCertificatePath: testCertFilePath, - UseMTLS: true, - MinTLSVersion: "TLS 1.2", - AllowedClientSubjectName: "random.com", + TLSPort: "10092", + TLSSubjectName: "localhost", + TLSCertificatePath: testCertFilePath, + UseMTLS: true, + MinTLSVersion: "TLS 1.2", + MtlsClientCertSubjectName: "random.com", } runMutualTLSTest := func(tlsSettings serverTLS.TlsSettings, handshakeFailureExpected bool) {