diff --git a/cmd/main.go b/cmd/main.go index 537ccb5..759c9f2 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -30,11 +30,13 @@ import ( func main() { var ( + lifecycle = run.NewLifecycle() configFile = &internal.LocalConfigFile{} logging = internal.NewLogSystem(log.New(), &configFile.Config) - jwks = oidc.NewJWKSProvider() + tlsPool = internal.NewTLSConfigPool(lifecycle.Context()) + jwks = oidc.NewJWKSProvider(tlsPool) sessions = oidc.NewSessionStoreFactory(&configFile.Config) - envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, jwks, sessions) + envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, tlsPool, jwks, sessions) authzServer = server.New(&configFile.Config, envoyAuthz.Register) healthz = server.NewHealthServer(&configFile.Config) secrets = internal.NewSecretLoader(&configFile.Config) @@ -51,6 +53,7 @@ func main() { g := run.Group{Logger: internal.Logger(internal.Default)} g.Register( + lifecycle, // manage the lifecycle of the run.Services configFile, // load the configuration logging, // set up the logging system secrets, // load the secrets and update the configuration diff --git a/config/gen/go/v1/oidc/config.pb.go b/config/gen/go/v1/oidc/config.pb.go index a5063aa..51f21a7 100644 --- a/config/gen/go/v1/oidc/config.pb.go +++ b/config/gen/go/v1/oidc/config.pb.go @@ -27,6 +27,7 @@ import ( _ "github.com/envoyproxy/protoc-gen-validate/validate" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + durationpb "google.golang.org/protobuf/types/known/durationpb" structpb "google.golang.org/protobuf/types/known/structpb" ) @@ -317,6 +318,12 @@ type OIDCConfig struct { // *OIDCConfig_TrustedCertificateAuthority // *OIDCConfig_TrustedCertificateAuthorityFile TrustedCaConfig isOIDCConfig_TrustedCaConfig `protobuf_oneof:"trusted_ca_config"` + // The duration between refreshes of the trusted certificate authority if `trusted_certificate_authority_file` is set. + // Unset or 0 (the default) disables the refresh, useful is no rotation is expected. + // Is a String that ends in `s` to indicate seconds and is preceded by the number of seconds, + // with nanoseconds expressed as fractional seconds, e.g. `120.15s`. + // Optional. + TrustedCertificateAuthorityRefreshInterval *durationpb.Duration `protobuf:"bytes,22,opt,name=trusted_certificate_authority_refresh_interval,json=trustedCertificateAuthorityRefreshInterval,proto3" json:"trusted_certificate_authority_refresh_interval,omitempty"` // The Authservice makes two kinds of direct network connections directly to the OIDC Provider. // Both are POST requests to the configured `token_uri` of the OIDC Provider. // The first is to exchange the authorization code for tokens, and the other is to use the @@ -522,6 +529,13 @@ func (x *OIDCConfig) GetTrustedCertificateAuthorityFile() string { return "" } +func (x *OIDCConfig) GetTrustedCertificateAuthorityRefreshInterval() *durationpb.Duration { + if x != nil { + return x.TrustedCertificateAuthorityRefreshInterval + } + return nil +} + func (x *OIDCConfig) GetProxyUri() string { if x != nil { return x.ProxyUri @@ -754,7 +768,9 @@ var file_v1_oidc_config_proto_rawDesc = []byte{ 0x0a, 0x14, 0x76, 0x31, 0x2f, 0x6f, 0x69, 0x64, 0x63, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x1a, 0x61, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x6f, 0x69, - 0x64, 0x63, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x64, 0x63, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x17, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x4a, 0x0a, 0x0b, 0x54, 0x6f, 0x6b, @@ -771,7 +787,7 @@ var file_v1_oidc_config_proto_rawDesc = []byte{ 0x02, 0x10, 0x01, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x2a, 0x0a, 0x0c, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x5f, 0x75, 0x72, 0x69, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xfa, 0x42, 0x04, 0x72, 0x02, 0x10, 0x01, 0x52, 0x0b, 0x72, 0x65, 0x64, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x55, 0x72, 0x69, 0x22, 0x96, 0x0c, 0x0a, 0x0a, 0x4f, 0x49, 0x44, 0x43, 0x43, 0x6f, + 0x63, 0x74, 0x55, 0x72, 0x69, 0x22, 0x95, 0x0d, 0x0a, 0x0a, 0x4f, 0x49, 0x44, 0x43, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2b, 0x0a, 0x11, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x75, 0x72, 0x69, 0x18, 0x13, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x72, @@ -835,56 +851,64 @@ var file_v1_oidc_config_proto_rawDesc = []byte{ 0x72, 0x69, 0x74, 0x79, 0x5f, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x14, 0x20, 0x01, 0x28, 0x09, 0x48, 0x02, 0x52, 0x1f, 0x74, 0x72, 0x75, 0x73, 0x74, 0x65, 0x64, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x74, 0x79, 0x46, 0x69, - 0x6c, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x75, 0x72, 0x69, 0x18, - 0x0f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x55, 0x72, 0x69, 0x12, - 0x64, 0x0a, 0x1a, 0x72, 0x65, 0x64, 0x69, 0x73, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, - 0x5f, 0x73, 0x74, 0x6f, 0x72, 0x65, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x10, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x27, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, - 0x65, 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x6f, 0x69, 0x64, 0x63, - 0x2e, 0x52, 0x65, 0x64, 0x69, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x17, 0x72, 0x65, - 0x64, 0x69, 0x73, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x53, 0x74, 0x6f, 0x72, 0x65, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x49, 0x0a, 0x15, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x76, 0x65, - 0x72, 0x69, 0x66, 0x79, 0x5f, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x63, 0x65, 0x72, 0x74, 0x18, 0x12, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x12, 0x73, 0x6b, - 0x69, 0x70, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x50, 0x65, 0x65, 0x72, 0x43, 0x65, 0x72, 0x74, - 0x1a, 0xbc, 0x01, 0x0a, 0x11, 0x4a, 0x77, 0x6b, 0x73, 0x46, 0x65, 0x74, 0x63, 0x68, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x19, 0x0a, 0x08, 0x6a, 0x77, 0x6b, 0x73, 0x5f, 0x75, - 0x72, 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6a, 0x77, 0x6b, 0x73, 0x55, 0x72, - 0x69, 0x12, 0x3d, 0x0a, 0x1b, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x69, 0x63, 0x5f, 0x66, 0x65, - 0x74, 0x63, 0x68, 0x5f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x5f, 0x73, 0x65, 0x63, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x18, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x69, 0x63, - 0x46, 0x65, 0x74, 0x63, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x53, 0x65, 0x63, - 0x12, 0x4d, 0x0a, 0x15, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x5f, - 0x70, 0x65, 0x65, 0x72, 0x5f, 0x63, 0x65, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2e, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x42, 0x02, 0x18, 0x01, 0x52, 0x12, 0x73, 0x6b, 0x69, + 0x6c, 0x65, 0x12, 0x7d, 0x0a, 0x2e, 0x74, 0x72, 0x75, 0x73, 0x74, 0x65, 0x64, 0x5f, 0x63, 0x65, + 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x74, 0x79, 0x5f, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x69, 0x6e, 0x74, 0x65, + 0x72, 0x76, 0x61, 0x6c, 0x18, 0x16, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x2a, 0x74, 0x72, 0x75, 0x73, 0x74, 0x65, 0x64, 0x43, 0x65, + 0x72, 0x74, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, + 0x74, 0x79, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, + 0x6c, 0x12, 0x1b, 0x0a, 0x09, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x5f, 0x75, 0x72, 0x69, 0x18, 0x0f, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x55, 0x72, 0x69, 0x12, 0x64, + 0x0a, 0x1a, 0x72, 0x65, 0x64, 0x69, 0x73, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x5f, + 0x73, 0x74, 0x6f, 0x72, 0x65, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x10, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x27, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x6f, 0x69, 0x64, 0x63, 0x2e, + 0x52, 0x65, 0x64, 0x69, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x17, 0x72, 0x65, 0x64, + 0x69, 0x73, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x53, 0x74, 0x6f, 0x72, 0x65, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x49, 0x0a, 0x15, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x76, 0x65, 0x72, + 0x69, 0x66, 0x79, 0x5f, 0x70, 0x65, 0x65, 0x72, 0x5f, 0x63, 0x65, 0x72, 0x74, 0x18, 0x12, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x52, 0x12, 0x73, 0x6b, 0x69, 0x70, 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x50, 0x65, 0x65, 0x72, 0x43, 0x65, 0x72, 0x74, 0x1a, - 0x4c, 0x0a, 0x0f, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x52, 0x65, 0x66, 0x65, 0x72, 0x65, 0x6e, - 0x63, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, - 0x12, 0x1b, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, - 0xfa, 0x42, 0x04, 0x72, 0x02, 0x10, 0x01, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x42, 0x0d, 0x0a, - 0x0b, 0x6a, 0x77, 0x6b, 0x73, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x42, 0x1b, 0x0a, 0x14, - 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f, 0x63, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x12, 0x03, 0xf8, 0x42, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x74, 0x72, 0x75, - 0x73, 0x74, 0x65, 0x64, 0x5f, 0x63, 0x61, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x42, 0xf4, - 0x01, 0x0a, 0x1e, 0x63, 0x6f, 0x6d, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x6f, 0x69, 0x64, - 0x63, 0x42, 0x0b, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, - 0x5a, 0x39, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x74, 0x65, 0x74, - 0x72, 0x61, 0x74, 0x65, 0x69, 0x6f, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x2d, 0x67, 0x6f, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2f, 0x67, 0x65, 0x6e, - 0x2f, 0x67, 0x6f, 0x2f, 0x76, 0x31, 0x2f, 0x6f, 0x69, 0x64, 0x63, 0xa2, 0x02, 0x04, 0x41, 0x43, - 0x56, 0x4f, 0xaa, 0x02, 0x1a, 0x41, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x2e, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x56, 0x31, 0x2e, 0x4f, 0x69, 0x64, 0x63, 0xca, - 0x02, 0x1a, 0x41, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5c, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x5c, 0x56, 0x31, 0x5c, 0x4f, 0x69, 0x64, 0x63, 0xe2, 0x02, 0x26, 0x41, - 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5c, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x5c, 0x56, 0x31, 0x5c, 0x4f, 0x69, 0x64, 0x63, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, - 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x1d, 0x41, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, - 0x69, 0x63, 0x65, 0x3a, 0x3a, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x3a, 0x3a, 0x56, 0x31, 0x3a, - 0x3a, 0x4f, 0x69, 0x64, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0xbc, 0x01, 0x0a, 0x11, 0x4a, 0x77, 0x6b, 0x73, 0x46, 0x65, 0x74, 0x63, 0x68, 0x65, 0x72, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x19, 0x0a, 0x08, 0x6a, 0x77, 0x6b, 0x73, 0x5f, 0x75, 0x72, + 0x69, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6a, 0x77, 0x6b, 0x73, 0x55, 0x72, 0x69, + 0x12, 0x3d, 0x0a, 0x1b, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x69, 0x63, 0x5f, 0x66, 0x65, 0x74, + 0x63, 0x68, 0x5f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x5f, 0x73, 0x65, 0x63, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x18, 0x70, 0x65, 0x72, 0x69, 0x6f, 0x64, 0x69, 0x63, 0x46, + 0x65, 0x74, 0x63, 0x68, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x53, 0x65, 0x63, 0x12, + 0x4d, 0x0a, 0x15, 0x73, 0x6b, 0x69, 0x70, 0x5f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x79, 0x5f, 0x70, + 0x65, 0x65, 0x72, 0x5f, 0x63, 0x65, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x42, 0x02, 0x18, 0x01, 0x52, 0x12, 0x73, 0x6b, 0x69, 0x70, + 0x56, 0x65, 0x72, 0x69, 0x66, 0x79, 0x50, 0x65, 0x65, 0x72, 0x43, 0x65, 0x72, 0x74, 0x1a, 0x4c, + 0x0a, 0x0f, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x52, 0x65, 0x66, 0x65, 0x72, 0x65, 0x6e, 0x63, + 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, + 0x1b, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xfa, + 0x42, 0x04, 0x72, 0x02, 0x10, 0x01, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x42, 0x0d, 0x0a, 0x0b, + 0x6a, 0x77, 0x6b, 0x73, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x42, 0x1b, 0x0a, 0x14, 0x63, + 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x5f, 0x63, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x12, 0x03, 0xf8, 0x42, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x74, 0x72, 0x75, 0x73, + 0x74, 0x65, 0x64, 0x5f, 0x63, 0x61, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x42, 0xf4, 0x01, + 0x0a, 0x1e, 0x63, 0x6f, 0x6d, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x76, 0x31, 0x2e, 0x6f, 0x69, 0x64, 0x63, + 0x42, 0x0b, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, + 0x39, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x74, 0x65, 0x74, 0x72, + 0x61, 0x74, 0x65, 0x69, 0x6f, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, + 0x65, 0x2d, 0x67, 0x6f, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2f, 0x67, 0x65, 0x6e, 0x2f, + 0x67, 0x6f, 0x2f, 0x76, 0x31, 0x2f, 0x6f, 0x69, 0x64, 0x63, 0xa2, 0x02, 0x04, 0x41, 0x43, 0x56, + 0x4f, 0xaa, 0x02, 0x1a, 0x41, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x56, 0x31, 0x2e, 0x4f, 0x69, 0x64, 0x63, 0xca, 0x02, + 0x1a, 0x41, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5c, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x5c, 0x56, 0x31, 0x5c, 0x4f, 0x69, 0x64, 0x63, 0xe2, 0x02, 0x26, 0x41, 0x75, + 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5c, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x5c, 0x56, 0x31, 0x5c, 0x4f, 0x69, 0x64, 0x63, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, + 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x1d, 0x41, 0x75, 0x74, 0x68, 0x73, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x3a, 0x3a, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x3a, 0x3a, 0x56, 0x31, 0x3a, 0x3a, + 0x4f, 0x69, 0x64, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -907,7 +931,8 @@ var file_v1_oidc_config_proto_goTypes = []interface{}{ (*OIDCConfig)(nil), // 3: authservice.config.v1.oidc.OIDCConfig (*OIDCConfig_JwksFetcherConfig)(nil), // 4: authservice.config.v1.oidc.OIDCConfig.JwksFetcherConfig (*OIDCConfig_SecretReference)(nil), // 5: authservice.config.v1.oidc.OIDCConfig.SecretReference - (*structpb.Value)(nil), // 6: google.protobuf.Value + (*durationpb.Duration)(nil), // 6: google.protobuf.Duration + (*structpb.Value)(nil), // 7: google.protobuf.Value } var file_v1_oidc_config_proto_depIdxs = []int32{ 4, // 0: authservice.config.v1.oidc.OIDCConfig.jwks_fetcher:type_name -> authservice.config.v1.oidc.OIDCConfig.JwksFetcherConfig @@ -915,14 +940,15 @@ var file_v1_oidc_config_proto_depIdxs = []int32{ 0, // 2: authservice.config.v1.oidc.OIDCConfig.id_token:type_name -> authservice.config.v1.oidc.TokenConfig 0, // 3: authservice.config.v1.oidc.OIDCConfig.access_token:type_name -> authservice.config.v1.oidc.TokenConfig 2, // 4: authservice.config.v1.oidc.OIDCConfig.logout:type_name -> authservice.config.v1.oidc.LogoutConfig - 1, // 5: authservice.config.v1.oidc.OIDCConfig.redis_session_store_config:type_name -> authservice.config.v1.oidc.RedisConfig - 6, // 6: authservice.config.v1.oidc.OIDCConfig.skip_verify_peer_cert:type_name -> google.protobuf.Value - 6, // 7: authservice.config.v1.oidc.OIDCConfig.JwksFetcherConfig.skip_verify_peer_cert:type_name -> google.protobuf.Value - 8, // [8:8] is the sub-list for method output_type - 8, // [8:8] is the sub-list for method input_type - 8, // [8:8] is the sub-list for extension type_name - 8, // [8:8] is the sub-list for extension extendee - 0, // [0:8] is the sub-list for field type_name + 6, // 5: authservice.config.v1.oidc.OIDCConfig.trusted_certificate_authority_refresh_interval:type_name -> google.protobuf.Duration + 1, // 6: authservice.config.v1.oidc.OIDCConfig.redis_session_store_config:type_name -> authservice.config.v1.oidc.RedisConfig + 7, // 7: authservice.config.v1.oidc.OIDCConfig.skip_verify_peer_cert:type_name -> google.protobuf.Value + 7, // 8: authservice.config.v1.oidc.OIDCConfig.JwksFetcherConfig.skip_verify_peer_cert:type_name -> google.protobuf.Value + 9, // [9:9] is the sub-list for method output_type + 9, // [9:9] is the sub-list for method input_type + 9, // [9:9] is the sub-list for extension type_name + 9, // [9:9] is the sub-list for extension extendee + 0, // [0:9] is the sub-list for field type_name } func init() { file_v1_oidc_config_proto_init() } diff --git a/config/gen/go/v1/oidc/config.pb.validate.go b/config/gen/go/v1/oidc/config.pb.validate.go index 11b363c..46e6523 100644 --- a/config/gen/go/v1/oidc/config.pb.validate.go +++ b/config/gen/go/v1/oidc/config.pb.validate.go @@ -532,6 +532,35 @@ func (m *OIDCConfig) validate(all bool) error { // no validation rules for IdleSessionTimeout + if all { + switch v := interface{}(m.GetTrustedCertificateAuthorityRefreshInterval()).(type) { + case interface{ ValidateAll() error }: + if err := v.ValidateAll(); err != nil { + errors = append(errors, OIDCConfigValidationError{ + field: "TrustedCertificateAuthorityRefreshInterval", + reason: "embedded message failed validation", + cause: err, + }) + } + case interface{ Validate() error }: + if err := v.Validate(); err != nil { + errors = append(errors, OIDCConfigValidationError{ + field: "TrustedCertificateAuthorityRefreshInterval", + reason: "embedded message failed validation", + cause: err, + }) + } + } + } else if v, ok := interface{}(m.GetTrustedCertificateAuthorityRefreshInterval()).(interface{ Validate() error }); ok { + if err := v.Validate(); err != nil { + return OIDCConfigValidationError{ + field: "TrustedCertificateAuthorityRefreshInterval", + reason: "embedded message failed validation", + cause: err, + } + } + } + // no validation rules for ProxyUri if all { diff --git a/config/v1/oidc/config.proto b/config/v1/oidc/config.proto index da00d7f..1dc9af1 100644 --- a/config/v1/oidc/config.proto +++ b/config/v1/oidc/config.proto @@ -16,6 +16,7 @@ syntax = "proto3"; package authservice.config.v1.oidc; +import "google/protobuf/duration.proto"; import "google/protobuf/struct.proto"; import "validate/validate.proto"; @@ -222,6 +223,13 @@ message OIDCConfig { string trusted_certificate_authority_file = 20; } + // The duration between refreshes of the trusted certificate authority if `trusted_certificate_authority_file` is set. + // Unset or 0 (the default) disables the refresh, useful is no rotation is expected. + // Is a String that ends in `s` to indicate seconds and is preceded by the number of seconds, + // with nanoseconds expressed as fractional seconds, e.g. `120.15s`. + // Optional. + google.protobuf.Duration trusted_certificate_authority_refresh_interval = 22; + // The Authservice makes two kinds of direct network connections directly to the OIDC Provider. // Both are POST requests to the configured `token_uri` of the OIDC Provider. // The first is to exchange the authorization code for tokens, and the other is to use the diff --git a/e2e/keycloak/authz-config.json b/e2e/keycloak/authz-config.json index 8e5a48c..a999f9c 100644 --- a/e2e/keycloak/authz-config.json +++ b/e2e/keycloak/authz-config.json @@ -28,7 +28,8 @@ "redis_session_store_config": { "server_uri": "redis://redis:6379" }, - "trusted_certificate_authority_file": "/etc/authservice/certs/ca.crt" + "trusted_certificate_authority_file": "/etc/authservice/certs/ca.crt", + "trusted_certificate_authority_refresh_interval": "60.25s" } } ] diff --git a/internal/authz/oidc.go b/internal/authz/oidc.go index ceaea1d..eac94e0 100644 --- a/internal/authz/oidc.go +++ b/internal/authz/oidc.go @@ -53,6 +53,7 @@ var ( type oidcHandler struct { log telemetry.Logger config *oidcv1.OIDCConfig + tlsPool internal.TLSConfigPool jwks oidc.JWKSProvider sessions oidc.SessionStoreFactory sessionGen oidc.SessionGenerator @@ -61,11 +62,10 @@ type oidcHandler struct { } // NewOIDCHandler creates a new OIDC implementation of the Handler interface. -func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider, - sessions oidc.SessionStoreFactory, clock oidc.Clock, - sessionGen oidc.SessionGenerator) (Handler, error) { +func NewOIDCHandler(cfg *oidcv1.OIDCConfig, tlsPool internal.TLSConfigPool, jwks oidc.JWKSProvider, + sessions oidc.SessionStoreFactory, clock oidc.Clock, sessionGen oidc.SessionGenerator) (Handler, error) { - client, err := getHTTPClient(cfg) + client, err := getHTTPClient(cfg, tlsPool) if err != nil { return nil, err } @@ -77,6 +77,7 @@ func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider, return &oidcHandler{ log: internal.Logger(internal.Authz).With("type", "oidc"), config: cfg, + tlsPool: tlsPool, jwks: jwks, sessions: sessions, clock: clock, @@ -85,11 +86,11 @@ func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider, }, nil } -func getHTTPClient(cfg *oidcv1.OIDCConfig) (*http.Client, error) { +func getHTTPClient(cfg *oidcv1.OIDCConfig, tlsPool internal.TLSConfigPool) (*http.Client, error) { transport := http.DefaultTransport.(*http.Transport).Clone() var err error - if transport.TLSClientConfig, err = internal.LoadTLSConfig(cfg); err != nil { + if transport.TLSClientConfig, err = tlsPool.LoadTLSConfig(cfg); err != nil { return nil, err } diff --git a/internal/authz/oidc_test.go b/internal/authz/oidc_test.go index 42db58b..7c8f844 100644 --- a/internal/authz/oidc_test.go +++ b/internal/authz/oidc_test.go @@ -38,6 +38,7 @@ import ( "google.golang.org/grpc/test/bufconn" oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc" + "github.com/tetrateio/authservice-go/internal" inthttp "github.com/tetrateio/authservice-go/internal/http" "github.com/tetrateio/authservice-go/internal/oidc" ) @@ -201,7 +202,8 @@ func TestOIDCProcess(t *testing.T) { clock := oidc.Clock{} sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)} store := sessions.Get(basicOIDCConfig) - h, err := NewOIDCHandler(basicOIDCConfig, oidc.NewJWKSProvider(), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + tlsPool := internal.NewTLSConfigPool(context.Background()) + h, err := NewOIDCHandler(basicOIDCConfig, tlsPool, oidc.NewJWKSProvider(tlsPool), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) require.NoError(t, err) ctx := context.Background() @@ -877,6 +879,7 @@ func TestOIDCProcess(t *testing.T) { func TestOIDCProcessWithFailingSessionStore(t *testing.T) { store := &storeMock{delegate: oidc.NewMemoryStore(&oidc.Clock{}, time.Hour, time.Hour)} sessions := &mockSessionStoreFactory{store: store} + tlsPool := internal.NewTLSConfigPool(context.Background()) jwkPriv, jwkPub := newKeyPair(t) bytes, err := json.Marshal(newKeySet(jwkPub)) @@ -885,7 +888,8 @@ func TestOIDCProcessWithFailingSessionStore(t *testing.T) { Jwks: string(bytes), } - h, err := NewOIDCHandler(basicOIDCConfig, oidc.NewJWKSProvider(), sessions, oidc.Clock{}, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + h, err := NewOIDCHandler(basicOIDCConfig, tlsPool, oidc.NewJWKSProvider(tlsPool), + sessions, oidc.Clock{}, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) require.NoError(t, err) ctx := context.Background() @@ -1029,7 +1033,8 @@ func TestOIDCProcessWithFailingJWKSProvider(t *testing.T) { clock := oidc.Clock{} sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)} store := sessions.Get(basicOIDCConfig) - h, err := NewOIDCHandler(basicOIDCConfig, funcJWKSProvider, sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + tlsPool := internal.NewTLSConfigPool(context.Background()) + h, err := NewOIDCHandler(basicOIDCConfig, tlsPool, funcJWKSProvider, sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) require.NoError(t, err) idpServer := newServer() @@ -1235,10 +1240,11 @@ func TestEncodeTokensToHeaders(t *testing.T) { } sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&oidc.Clock{}, time.Hour, time.Hour)} + tlsPool := internal.NewTLSConfigPool(context.Background()) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h, err := NewOIDCHandler(tt.config, nil, sessions, oidc.Clock{}, nil) + h, err := NewOIDCHandler(tt.config, tlsPool, nil, sessions, oidc.Clock{}, nil) require.NoError(t, err) tokResp := &oidc.TokenResponse{ @@ -1307,10 +1313,11 @@ func TestAreTokensExpired(t *testing.T) { } sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&oidc.Clock{}, time.Hour, time.Hour)} + tlsPool := internal.NewTLSConfigPool(context.Background()) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - h, err := NewOIDCHandler(tt.config, nil, sessions, oidc.Clock{}, nil) + h, err := NewOIDCHandler(tt.config, tlsPool, nil, sessions, oidc.Clock{}, nil) require.NoError(t, err) tokResp := &oidc.TokenResponse{ @@ -1341,12 +1348,16 @@ func TestLoadWellKnownConfig(t *testing.T) { func TestLoadWellKnownConfigError(t *testing.T) { clock := oidc.Clock{} + tlsPool := internal.NewTLSConfigPool(context.Background()) sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)} - _, err := NewOIDCHandler(dynamicOIDCConfig, oidc.NewJWKSProvider(), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + _, err := NewOIDCHandler(dynamicOIDCConfig, tlsPool, oidc.NewJWKSProvider(tlsPool), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) require.Error(t, err) // Fail to retrieve the dynamic config since the test server is not running } func TestNewOIDCHandler(t *testing.T) { + clock := oidc.Clock{} + tlsPool := internal.NewTLSConfigPool(context.Background()) + sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)} tests := []struct { name string @@ -1359,9 +1370,8 @@ func TestNewOIDCHandler(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - clock := oidc.Clock{} - sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)} - _, err := NewOIDCHandler(tt.config, oidc.NewJWKSProvider(), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + + _, err := NewOIDCHandler(tt.config, tlsPool, oidc.NewJWKSProvider(tlsPool), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) if tt.wantErr { require.Error(t, err) } else { diff --git a/internal/file.go b/internal/file.go new file mode 100644 index 0000000..f183976 --- /dev/null +++ b/internal/file.go @@ -0,0 +1,164 @@ +// Copyright 2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "os" + "sync" + "time" + + "github.com/tetratelabs/telemetry" +) + +type ( + // FileWatcher watches multiple files for changes and calls a callback when the file changes. + // It is safe to call WatchFile concurrently. + // To stop watching the files, cancel the context passed to NewFileWatcher. + FileWatcher struct { + ctx context.Context + log telemetry.Logger + + mu sync.Mutex + watchers map[string]*watcher + } + + // watcher watches a file for changes and calls a callback when the file changes. + watcher struct { + ctx context.Context + cancel context.CancelFunc + + log telemetry.Logger + interval time.Duration + callback func([]byte) + reader Reader + data []byte + } + + // Reader is an interface to read the content of a file. + Reader interface { + // ID returns a unique identifier for the file. + ID() string + // Read reads the content of the file. + Read() ([]byte, error) + } +) + +// NewFileWatcher creates a new FileWatcher. +func NewFileWatcher(ctx context.Context) *FileWatcher { + return &FileWatcher{ + ctx: ctx, + log: Logger(Config), + watchers: map[string]*watcher{}, + } +} + +// WatchFile watches a file for changes and calls the callback when the file changes. +// It returns the content of the file and an error if the file cannot be read. +// The callback function is called with the new content of the file. +// If the file is already being watched, the previous watcher is stopped and the new one is started. +func (f *FileWatcher) WatchFile(reader Reader, interval time.Duration, callback func([]byte)) ([]byte, error) { + id := reader.ID() + + f.mu.Lock() + if old, ok := f.watchers[id]; ok { + // stop the current watcher + old.cancel() + } + f.mu.Unlock() + + log := f.log.With("file", id) + + // Load the file data + data, err := reader.Read() + if err != nil { + log.Error("error reading file", err) + return nil, err + } + + // Non-positive interval means no watching for file. + if interval <= 0 { + return data, nil + } + + // Create a new watcher + f.mu.Lock() + ctx, cancel := context.WithCancel(f.ctx) + w := &watcher{ + ctx: ctx, + cancel: cancel, + log: log, + interval: interval, + callback: callback, + reader: reader, + data: data, + } + f.watchers[id] = w + f.mu.Unlock() + + // Start watching the file + w.start() + return data, nil +} + +func (w *watcher) start() { + go func() { + w.log.Info("start file watcher") + + ticker := time.NewTicker(w.interval) + defer ticker.Stop() + for { + select { + case <-w.ctx.Done(): + w.log.Info("stop file watcher") + return + + case <-ticker.C: + data, err := w.reader.Read() + if err != nil { + w.log.Error("error reading file", err) + continue + } + if string(data) != string(w.data) { + w.log.Info("file changed, invoking callback") + w.data = data + go w.callback(data) + } + } + } + }() +} + +var _ Reader = (*FileReader)(nil) + +// FileReader is a Reader that reads the content of a file given its path. +type FileReader struct { + filePath string +} + +// NewFileReader creates a new FileReader. +func NewFileReader(filePath string) *FileReader { + return &FileReader{filePath: filePath} +} + +// ID returns the file path. +func (f *FileReader) ID() string { + return f.filePath +} + +// Read reads the content of the file. +func (f *FileReader) Read() ([]byte, error) { + return os.ReadFile(f.filePath) +} diff --git a/internal/file_test.go b/internal/file_test.go new file mode 100644 index 0000000..9f6a961 --- /dev/null +++ b/internal/file_test.go @@ -0,0 +1,318 @@ +// Copyright 2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFileWatcher_WatchFile(t *testing.T) { + const watcherInterval = 500 * time.Millisecond + + tests := []struct { + name string + fileReader *mockReader + genUpdates func(reader *mockReader) + interval time.Duration + + wantCallbacks int + want string + wantUpdates []string + wantErr bool + }{ + { + name: "no updates happening", + fileReader: newMockReader("test", "original", nil), + interval: watcherInterval, + wantCallbacks: 0, + want: "original", + }, + { + name: "all updates notified", + fileReader: newMockReader("test", "original", nil), + genUpdates: func(reader *mockReader) { + reader.setData([]byte("update 1")) + reader.waitForRead() + reader.setData([]byte("update 2")) + reader.waitForRead() + }, + interval: watcherInterval, + wantCallbacks: 2, + want: "original", + wantUpdates: []string{"update 1", "update 2"}, + }, + { + name: "no content changes don't notify", + fileReader: newMockReader("test", "original", nil), + genUpdates: func(reader *mockReader) { + reader.setData([]byte("update 1")) + reader.waitForRead() + reader.setData([]byte("update 2")) + reader.waitForRead() + reader.setData([]byte("update 2")) + reader.waitForRead() + reader.setData([]byte("update 2")) + reader.waitForRead() + }, + interval: watcherInterval, + wantCallbacks: 2, + want: "original", + wantUpdates: []string{"update 1", "update 2"}, + }, + { + name: "missed update due to slow interval", + fileReader: newMockReader("test", "original", nil), + genUpdates: func(reader *mockReader) { + reader.setData([]byte("update 1")) + // no waiting for the read to happen and performing next update + // reader.waitForRead() + reader.setData([]byte("update 2")) + reader.waitForRead() + }, + interval: watcherInterval, + wantCallbacks: 1, + want: "original", + wantUpdates: []string{"update 2"}, + }, + { + name: "error reading file at start", + fileReader: newMockReader("test", "original", errors.New("error reading file")), + interval: watcherInterval, + wantErr: true, + }, + { + name: "error reading file after start", + fileReader: newMockReader("test", "original", nil), + genUpdates: func(reader *mockReader) { + reader.setData([]byte("update 1")) + reader.waitForRead() + reader.err = errors.New("error reading file") + reader.waitForRead() + // stop error + reader.err = nil + // even if an error happens, next updates should be notified + reader.setData([]byte("update 2")) + reader.waitForRead() + }, + interval: watcherInterval, + wantCallbacks: 2, + want: "original", + wantUpdates: []string{"update 1", "update 2"}, + }, + { + name: "no interval", + fileReader: newMockReader("test", "original", nil), + genUpdates: func(reader *mockReader) { + reader.setData([]byte("update 1")) + reader.setData([]byte("update 2")) + }, + interval: 0, + want: "original", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + fw := NewFileWatcher(ctx) + + var gotUpdates []string + + wg := sync.WaitGroup{} + wg.Add(tt.wantCallbacks) + + got, err := fw.WatchFile(tt.fileReader, tt.interval, func(data []byte) { + defer wg.Done() + gotUpdates = append(gotUpdates, string(data)) + }) + if tt.wantErr { + require.Error(t, err) + return + } + + if tt.interval <= 0 { + // if no interval configured, the watcher shouldn't be registered + _, ok := fw.watchers[tt.fileReader.ID()] + require.False(t, ok) + } + + tt.fileReader.waitForRead() // Wait for the first read to happen, the one synchronous + require.Equal(t, tt.want, string(got)) + require.NoError(t, err) + + if tt.genUpdates != nil { + tt.genUpdates(tt.fileReader) + } + + // ensure no more updates are notified before verifying the results + cancel() + + wg.Wait() // Wait for all callbacks to be notified + require.Equal(t, tt.wantUpdates, gotUpdates) + }) + } + + // This test is to ensure that the file watcher can handle multiple files being watched and + // each callback is notified only for the file it's watching. + t.Run("multiple files watched", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + fw := NewFileWatcher(ctx) + + gotUpdates1 := make([]string, 0) + gotUpdates2 := make([]string, 0) + + wg1 := sync.WaitGroup{} + wg1.Add(2) // 2 callbacks to be notified + wg2 := sync.WaitGroup{} + wg2.Add(1) // 1 callback to be notified + + file1 := newMockReader("test1", "original1", nil) + got1, err := fw.WatchFile(file1, watcherInterval, func(data []byte) { + defer wg1.Done() + gotUpdates1 = append(gotUpdates1, string(data)) + }) + require.NoError(t, err) + file1.waitForRead() // Wait for the first read to happen + + file2 := newMockReader("test2", "original2", nil) + got2, err := fw.WatchFile(file2, watcherInterval, func(data []byte) { + defer wg2.Done() + gotUpdates2 = append(gotUpdates2, string(data)) + }) + require.NoError(t, err) + file2.waitForRead() // Wait for the first read to happen + + file1.setData([]byte("update 1-1")) + file1.waitForRead() + file2.setData([]byte("update 2-1")) + file2.waitForRead() + file1.setData([]byte("update 1-2")) + file1.waitForRead() + + // ensure no more updates are notified before verifying the results + cancel() + + wg1.Wait() // Wait for all callbacks to be notified + wg2.Wait() // Wait for all callbacks to be notified + + require.Equal(t, "original1", string(got1)) + require.Equal(t, "original2", string(got2)) + require.Equal(t, []string{"update 1-1", "update 1-2"}, gotUpdates1) + require.Equal(t, []string{"update 2-1"}, gotUpdates2) + }) + + // This test is to ensure that the callback is overridden when a new file is watched + // The first WatchFile sets a callback, that will only receive the first update happening at the WatchFile call. + // Then the second WatchFile sets a new callback, that will receive all updates happening after the WatchFile call. + t.Run("override file watcher overrides callback too", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + fw := NewFileWatcher(ctx) + + gotUpdates := make([]string, 0) + gotOverride := make([]string, 0) + + file1 := newMockReader("test1", "original", nil) + got, err := fw.WatchFile(file1, watcherInterval, func(data []byte) { + gotUpdates = append(gotUpdates, string(data)) + }) + require.NoError(t, err) + file1.waitForRead() // Wait for the first read to happen + + wg := sync.WaitGroup{} + wg.Add(2) // 2 callbacks to be notified + + file1.setData([]byte("override")) + gotOvrr, err := fw.WatchFile(file1, watcherInterval/2, func(data []byte) { + defer wg.Done() + gotOverride = append(gotOverride, string(data)) + }) + require.NoError(t, err) + file1.waitForRead() // Wait for the first read to happen again + + file1.setData([]byte("update 1")) + file1.waitForRead() + file1.setData([]byte("update 2")) + file1.waitForRead() + + // ensure no more updates are notified before verifying the results + cancel() + + wg.Wait() // Wait for all callbacks to be notified + + require.Equal(t, "original", string(got)) + require.Equal(t, "override", string(gotOvrr)) + require.Equal(t, []string{}, gotUpdates) + require.Equal(t, []string{"update 1", "update 2"}, gotOverride) + }) +} + +var _ Reader = (*mockReader)(nil) + +type mockReader struct { + id string + err error + + m sync.Mutex + fileData []byte + + // reads is used to signal that a read happened, it should be buffered to avoid deadlocks. + // It is used to know if a read happened but there's no need to block reads happening if no one is waiting for them. + reads chan struct{} +} + +func newMockReader(id, data string, err error) *mockReader { + return &mockReader{ + id: id, + fileData: []byte(data), + err: err, + reads: make(chan struct{}, 50), + } +} + +func (m *mockReader) ID() string { + return m.id +} + +func (m *mockReader) Read() ([]byte, error) { + // Notify that a read happened + defer func() { m.reads <- struct{}{} }() + + if m.err != nil { + return nil, m.err + } + + m.m.Lock() + defer m.m.Unlock() + return m.fileData, nil +} + +func (m *mockReader) setData(data []byte) { + m.m.Lock() + defer m.m.Unlock() + m.fileData = data +} + +func (m *mockReader) waitForRead() { + <-m.reads +} diff --git a/internal/oidc/jwks.go b/internal/oidc/jwks.go index ff5ffac..353738d 100644 --- a/internal/oidc/jwks.go +++ b/internal/oidc/jwks.go @@ -52,12 +52,14 @@ type DefaultJWKSProvider struct { log telemetry.Logger cache *jwk.AutoRefresh shutdown context.CancelFunc + tlsPool internal.TLSConfigPool } // NewJWKSProvider returns a new JWKSProvider. -func NewJWKSProvider() *DefaultJWKSProvider { +func NewJWKSProvider(tlsPool internal.TLSConfigPool) *DefaultJWKSProvider { return &DefaultJWKSProvider{ - log: internal.Logger(internal.JWKS), + log: internal.Logger(internal.JWKS), + tlsPool: tlsPool, } } @@ -109,7 +111,7 @@ func (j *DefaultJWKSProvider) fetchDynamic(ctx context.Context, config *oidcv1.O transport := http.DefaultTransport.(*http.Transport).Clone() var err error - if transport.TLSClientConfig, err = internal.LoadTLSConfig(config); err != nil { + if transport.TLSClientConfig, err = j.tlsPool.LoadTLSConfig(config); err != nil { return nil, fmt.Errorf("error loading TLS config: %w", err) } diff --git a/internal/oidc/jwks_test.go b/internal/oidc/jwks_test.go index b3dcb9b..f25c5b0 100644 --- a/internal/oidc/jwks_test.go +++ b/internal/oidc/jwks_test.go @@ -33,6 +33,7 @@ import ( "google.golang.org/protobuf/types/known/structpb" oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc" + "github.com/tetrateio/authservice-go/internal" ) // nolint: lll @@ -76,8 +77,10 @@ var ( ) func TestStaticJWKSProvider(t *testing.T) { + tlsPool := internal.NewTLSConfigPool(context.Background()) + t.Run("invalid", func(t *testing.T) { - cache := NewJWKSProvider() + cache := NewJWKSProvider(tlsPool) go func() { require.NoError(t, cache.Serve()) }() t.Cleanup(cache.GracefulStop) @@ -91,7 +94,7 @@ func TestStaticJWKSProvider(t *testing.T) { }) t.Run("single-key", func(t *testing.T) { - cache := NewJWKSProvider() + cache := NewJWKSProvider(tlsPool) go func() { require.NoError(t, cache.Serve()) }() t.Cleanup(cache.GracefulStop) @@ -112,7 +115,7 @@ func TestStaticJWKSProvider(t *testing.T) { }) t.Run("multiple-keys", func(t *testing.T) { - cache := NewJWKSProvider() + cache := NewJWKSProvider(tlsPool) go func() { require.NoError(t, cache.Serve()) }() t.Cleanup(cache.GracefulStop) @@ -144,8 +147,9 @@ func TestDynamicJWKSProvider(t *testing.T) { pub = newKey(t) jwks = newKeySet(pub) + tlsPool = internal.NewTLSConfigPool(context.Background()) newCache = func(t *testing.T) JWKSProvider { - cache := NewJWKSProvider() + cache := NewJWKSProvider(tlsPool) g := run.Group{Logger: telemetry.NoopLogger()} g.Register(cache) go func() { _ = g.Run() }() diff --git a/internal/server/authz.go b/internal/server/authz.go index 7603c21..5c139cd 100644 --- a/internal/server/authz.go +++ b/internal/server/authz.go @@ -59,15 +59,17 @@ var ( type ExtAuthZFilter struct { log telemetry.Logger cfg *configv1.Config + tlsPool internal.TLSConfigPool jwks oidc.JWKSProvider sessions oidc.SessionStoreFactory } // NewExtAuthZFilter creates a new ExtAuthZFilter. -func NewExtAuthZFilter(cfg *configv1.Config, jwks oidc.JWKSProvider, sessions oidc.SessionStoreFactory) *ExtAuthZFilter { +func NewExtAuthZFilter(cfg *configv1.Config, tlsPool internal.TLSConfigPool, jwks oidc.JWKSProvider, sessions oidc.SessionStoreFactory) *ExtAuthZFilter { return &ExtAuthZFilter{ log: internal.Logger(internal.Authz), cfg: cfg, + tlsPool: tlsPool, jwks: jwks, sessions: sessions, } @@ -120,7 +122,7 @@ func (e *ExtAuthZFilter) Check(ctx context.Context, req *envoy.CheckRequest) (re case *configv1.Filter_Mock: h = authz.NewMockHandler(ft.Mock) case *configv1.Filter_Oidc: - if h, err = authz.NewOIDCHandler(ft.Oidc, e.jwks, e.sessions, oidc.Clock{}, oidc.NewRandomGenerator()); err != nil { + if h, err = authz.NewOIDCHandler(ft.Oidc, e.tlsPool, e.jwks, e.sessions, oidc.Clock{}, oidc.NewRandomGenerator()); err != nil { return nil, err } } diff --git a/internal/server/authz_test.go b/internal/server/authz_test.go index 17120b0..9f87879 100644 --- a/internal/server/authz_test.go +++ b/internal/server/authz_test.go @@ -39,7 +39,7 @@ func TestUnmatchedRequests(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - e := NewExtAuthZFilter(&configv1.Config{AllowUnmatchedRequests: tt.allow}, nil, nil) + e := NewExtAuthZFilter(&configv1.Config{AllowUnmatchedRequests: tt.allow}, nil, nil, nil) got, err := e.Check(context.Background(), &envoy.CheckRequest{}) require.NoError(t, err) require.Equal(t, int32(tt.want), got.Status.Code) @@ -61,7 +61,7 @@ func TestFiltersMatch(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cfg := &configv1.Config{Chains: []*configv1.FilterChain{{Filters: tt.filters}}} - e := NewExtAuthZFilter(cfg, nil, nil) + e := NewExtAuthZFilter(cfg, nil, nil, nil) got, err := e.Check(context.Background(), &envoy.CheckRequest{}) require.NoError(t, err) @@ -91,7 +91,7 @@ func TestUseFirstMatchingChain(t *testing.T) { }, } - e := NewExtAuthZFilter(cfg, nil, nil) + e := NewExtAuthZFilter(cfg, nil, nil, nil) got, err := e.Check(context.Background(), header("match")) require.NoError(t, err) @@ -121,7 +121,7 @@ func TestMatch(t *testing.T) { } func TestGrpcNoChainsMatched(t *testing.T) { - e := NewExtAuthZFilter(&configv1.Config{}, nil, nil) + e := NewExtAuthZFilter(&configv1.Config{}, nil, nil, nil) s := NewTestServer(e.Register) go func() { require.NoError(t, s.Start()) }() t.Cleanup(s.Stop) @@ -274,7 +274,7 @@ func TestCheckTriggerRules(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - e := NewExtAuthZFilter(tt.config, nil, nil) + e := NewExtAuthZFilter(tt.config, nil, nil, nil) req := &envoy.CheckRequest{ Attributes: &envoy.AttributeContext{ Request: &envoy.AttributeContext_Request{ diff --git a/internal/tls.go b/internal/tls.go index 3dc7f36..1d71d28 100644 --- a/internal/tls.go +++ b/internal/tls.go @@ -15,27 +15,75 @@ package internal import ( + "bytes" + "context" "crypto/tls" "crypto/x509" + "encoding/hex" + "encoding/json" "errors" "fmt" - "os" + "hash/fnv" + "sync" + "github.com/tetratelabs/telemetry" + "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/structpb" ) -// TLSConfig is an interface for the TLS configuration of the AuthService. -type TLSConfig interface { - // GetTrustedCertificateAuthority returns the trusted certificate authority PEM. - GetTrustedCertificateAuthority() string - // GetTrustedCertificateAuthorityFile returns the path to the trusted certificate authority file. - GetTrustedCertificateAuthorityFile() string - // GetSkipVerifyPeerCert returns whether to skip verification of the peer certificate. - GetSkipVerifyPeerCert() *structpb.Value +type ( + // TLSConfig is an interface for the TLS configuration of the AuthService. + TLSConfig interface { + // GetTrustedCertificateAuthority returns the trusted certificate authority PEM. + GetTrustedCertificateAuthority() string + // GetTrustedCertificateAuthorityFile returns the path to the trusted certificate authority file. + GetTrustedCertificateAuthorityFile() string + // GetSkipVerifyPeerCert returns whether to skip verification of the peer certificate. + GetSkipVerifyPeerCert() *structpb.Value + GetTrustedCertificateAuthorityRefreshInterval() *durationpb.Duration + } + + // TLSConfigPool is an interface for a pool of TLS configurations. + TLSConfigPool interface { + // LoadTLSConfig loads a TLS configuration from the given TLSConfig. + LoadTLSConfig(config TLSConfig) (*tls.Config, error) + } + + // tlsConfigPool is a pool of TLS configurations. + // That reloads the trusted certificate authority when there are changes. + tlsConfigPool struct { + ctx context.Context + cancel context.CancelFunc + log telemetry.Logger + + mu sync.RWMutex + configs map[string]*tls.Config + caWatcher *FileWatcher + } +) + +// NewTLSConfigPool creates a new TLSConfigPool. +func NewTLSConfigPool(ctx context.Context) TLSConfigPool { + ctx, cancel := context.WithCancel(ctx) + return &tlsConfigPool{ + ctx: ctx, + cancel: cancel, + log: Logger(Config), + configs: make(map[string]*tls.Config), + caWatcher: NewFileWatcher(ctx), + } } // LoadTLSConfig loads a TLS configuration from the given TLSConfig. -func LoadTLSConfig(config TLSConfig) (*tls.Config, error) { +func (p *tlsConfigPool) LoadTLSConfig(config TLSConfig) (*tls.Config, error) { + encConfig := encodeConfig(config) + id := encConfig.hash() + if tlsConfig, ok := p.configs[id]; ok { + return tlsConfig, nil + } + + log := p.log.With("id", id) + log.Info("loading new TLS config", "config", encConfig.JSON()) tlsConfig := &tls.Config{} // Load the trusted CA PEM from the config @@ -43,14 +91,21 @@ func LoadTLSConfig(config TLSConfig) (*tls.Config, error) { switch { case config.GetTrustedCertificateAuthority() != "": ca = []byte(config.GetTrustedCertificateAuthority()) + case config.GetTrustedCertificateAuthorityFile() != "": var err error - ca, err = os.ReadFile(config.GetTrustedCertificateAuthorityFile()) + ca, err = p.caWatcher.WatchFile( + NewFileReader(config.GetTrustedCertificateAuthorityFile()), + config.GetTrustedCertificateAuthorityRefreshInterval().AsDuration(), + func(data []byte) { p.updateCA(id, data) }, + ) if err != nil { - return nil, fmt.Errorf("error reading trusted CA file: %w", err) + return nil, fmt.Errorf("error watching trusted CA file: %w", err) } + case config.GetSkipVerifyPeerCert() != nil: tlsConfig.InsecureSkipVerify = BoolStrValue(config.GetSkipVerifyPeerCert()) + default: // No CA or skip verification, return nil TLS config return nil, nil @@ -58,6 +113,10 @@ func LoadTLSConfig(config TLSConfig) (*tls.Config, error) { // Add the loaded CA to the TLS config if len(ca) != 0 { + if BoolStrValue(config.GetSkipVerifyPeerCert()) { + log.Info("`skip_verify_peer_cert` is set to true but there's also a trusted certificate authority, ignoring `skip_verify_peer_cert`") + } + certPool, err := x509.SystemCertPool() if err != nil { return nil, fmt.Errorf("error creating system cert pool: %w", err) @@ -70,5 +129,81 @@ func LoadTLSConfig(config TLSConfig) (*tls.Config, error) { tlsConfig.RootCAs = certPool } + // Save the TLS config to the pool + p.mu.Lock() + p.configs[id] = tlsConfig + p.mu.Unlock() return tlsConfig, nil } + +func (p *tlsConfigPool) updateCA(id string, caPem []byte) { + log := p.log.With("id", id) + + // Load the TLS config + p.mu.Lock() + tlsConfig, ok := p.configs[id] + if !ok { + log.Error("couldn't update TLS config", errors.New("config not found")) + p.mu.Unlock() + return + } + p.mu.Unlock() + + // Add the loaded CA to the TLS config + certPool, err := x509.SystemCertPool() + if err != nil { + log.Error("error creating system cert pool", err) + return + } + + if ok := certPool.AppendCertsFromPEM(caPem); !ok { + log.Error("could not load trusted certificate authority", errors.New("failed to append certificate in the cert pool")) + return + } + + // Update the TLS config + tlsConfig.RootCAs = certPool + log.Info("updated TLS config with new trusted certificate authority") + + p.mu.Lock() + p.configs[id] = tlsConfig + p.mu.Unlock() +} + +// tlsConfigEncoder is the internal representation of a TLSConfig. +// It handles some useful methods for the TLSConfig. +type tlsConfigEncoder struct { + SkipVerifyPeerCert bool `json:"skipVerifyPeerCert,omitempty"` + TrustedCA string `json:"trustedCertificateAuthority,omitempty"` + TrustedCAFile string `json:"trustedCertificateAuthorityFile,omitempty"` + TrustedCARefreshInterval string `json:"trustedCertificateAuthorityRefreshInterval,omitempty"` +} + +// encodeConfig converts a TLSConfig to an tlsConfigEncoder. +func encodeConfig(config TLSConfig) tlsConfigEncoder { + return tlsConfigEncoder{ + TrustedCA: config.GetTrustedCertificateAuthority(), + TrustedCAFile: config.GetTrustedCertificateAuthorityFile(), + TrustedCARefreshInterval: config.GetTrustedCertificateAuthorityRefreshInterval().AsDuration().String(), + SkipVerifyPeerCert: BoolStrValue(config.GetSkipVerifyPeerCert()), + } +} + +// hash returns the hash of the tls config. +func (c tlsConfigEncoder) hash() string { + buff := bytes.Buffer{} + _, _ = buff.WriteString(fmt.Sprintf("%t", c.SkipVerifyPeerCert)) + _, _ = buff.WriteString(c.TrustedCA) + _, _ = buff.WriteString(c.TrustedCAFile) + _, _ = buff.WriteString(c.TrustedCARefreshInterval) + hash := fnv.New64a() + _, _ = hash.Write(buff.Bytes()) + out := hash.Sum(make([]byte, 0, 15)) + return hex.EncodeToString(out) +} + +// JSON returns the JSON representation of the tls config. +func (c tlsConfigEncoder) JSON() string { + jsonBytes, _ := json.Marshal(c) + return string(jsonBytes) +} diff --git a/internal/tls_test.go b/internal/tls_test.go index 9d8a7a6..7668a85 100644 --- a/internal/tls_test.go +++ b/internal/tls_test.go @@ -15,31 +15,83 @@ package internal import ( + "context" + "crypto/x509" + "encoding/pem" "os" "testing" + "time" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/structpb" "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc" ) const ( - smallCAPem = `-----BEGIN CERTIFICATE----- -MIIB8TCCAZugAwIBAgIJANZ3fvnlU+1IMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV -BAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRAwDgYDVQQKDAdUZXRyYXRlMRQw -EgYDVQQLDAtFbmdpbmVlcmluZzESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI0MDIx -NjE1MzExOFoXDTI0MDIxNzE1MzExOFowXjELMAkGA1UEBhMCVVMxEzARBgNVBAgM -CkNhbGlmb3JuaWExEDAOBgNVBAoMB1RldHJhdGUxFDASBgNVBAsMC0VuZ2luZWVy -aW5nMRIwEAYDVQQDDAlsb2NhbGhvc3QwXDANBgkqhkiG9w0BAQEFAANLADBIAkEA -17tRxNJNLZVu2ntW/ehw5BneJFV+o7UmpCipv0zBtMtgJw2Z04fYiipaXgwg/sVL -wnyFgbhd0OgoIEg+ND38iQIDAQABozwwOjASBgNVHRMBAf8ECDAGAQH/AgEBMA4G -A1UdDwEB/wQEAwIC5DAUBgNVHREEDTALgglsb2NhbGhvc3QwDQYJKoZIhvcNAQEL -BQADQQAnQuyYJ6FbTuwtduT1ZCDcXMqTKcLb4ex3iaowflGubQuCX41yIprFScN4 -2P5SpEcFlILZiK6vRzyPmuWEQVVr + invalidCAPem = `` + firstCertDNSName = "testing" + secondCertDNSName = "other" + + firstCAPem = `-----BEGIN CERTIFICATE----- +MIICNjCCAeCgAwIBAgIUCUxfyLHNslm/jteqHDJdiYxVo+gwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExEDAOBgNVBAoM +B1RldHJhdGUxFDASBgNVBAsMC0VuZ2luZWVyaW5nMRAwDgYDVQQDDAd0ZXN0aW5n +MB4XDTI0MDIyMzEyMTYzMFoXDTI0MDIyNDEyMTYzMFowXDELMAkGA1UEBhMCVVMx +EzARBgNVBAgMCkNhbGlmb3JuaWExEDAOBgNVBAoMB1RldHJhdGUxFDASBgNVBAsM +C0VuZ2luZWVyaW5nMRAwDgYDVQQDDAd0ZXN0aW5nMFwwDQYJKoZIhvcNAQEBBQAD +SwAwSAJBAL5+wV2XPh0l6cwUS4CWqddSfKww6XD0YdKXjjKQZMNo6pZfRfmPIalk +ExNZF8rbCmpk3XJqmh9mpKKPFCNJEbECAwEAAaN6MHgwHQYDVR0OBBYEFD2aRQZN +sH7eVIv2CN+PiTYz2LV4MB8GA1UdIwQYMBaAFD2aRQZNsH7eVIv2CN+PiTYz2LV4 +MBIGA1UdEwEB/wQIMAYBAf8CAQEwDgYDVR0PAQH/BAQDAgLkMBIGA1UdEQQLMAmC +B3Rlc3RpbmcwDQYJKoZIhvcNAQELBQADQQCK9MOCDozutKvtEQ8piLVlkR5EmtWn +33SDPZXeCD4wLyULP8OFayar0rBLaGB33OeKOffQ8xiNF7MD4pOicFlU +-----END CERTIFICATE-----` + + firstCertPem = `-----BEGIN CERTIFICATE----- +MIICEjCCAbygAwIBAgIUZ92xILVsEMxFXr4DJLFpZXp1O5MwDQYJKoZIhvcNAQEL +BQAwXDELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExEDAOBgNVBAoM +B1RldHJhdGUxFDASBgNVBAsMC0VuZ2luZWVyaW5nMRAwDgYDVQQDDAd0ZXN0aW5n +MB4XDTI0MDIyMzEyMTYzMFoXDTI0MDIyNDEyMTYzMFowXDELMAkGA1UEBhMCVVMx +EzARBgNVBAgMCkNhbGlmb3JuaWExEDAOBgNVBAoMB1RldHJhdGUxFDASBgNVBAsM +C0VuZ2luZWVyaW5nMRAwDgYDVQQDDAd0ZXN0aW5nMFwwDQYJKoZIhvcNAQEBBQAD +SwAwSAJBAKg+Ife6c7EHqSp2jDZqBCj8dsUvUwR3pxbZdMZOHQ8JwCRLT58TFilb +HkuBMNBAG2wIBgz1yTUQD1qcCS54s8ECAwEAAaNWMFQwEgYDVR0RBAswCYIHdGVz +dGluZzAdBgNVHQ4EFgQUd/ybkK9CxV3CNd96WzNu5nbVsCgwHwYDVR0jBBgwFoAU +PZpFBk2wft5Ui/YI34+JNjPYtXgwDQYJKoZIhvcNAQELBQADQQCTPOpJQp6E6XBf +pf8oBmK4m5qM/qbReZJaRYJFaOJlvgHXkJOLW5SC++yyHLDIphn1WLDGec/Z1JYs +k3ElQddK +-----END CERTIFICATE-----` + + secondCAPem = `-----BEGIN CERTIFICATE----- +MIICMDCCAdqgAwIBAgIUae5YysWmjLQbR6SqzETNSz7EKPwwDQYJKoZIhvcNAQEL +BQAwWjELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExEDAOBgNVBAoM +B1RldHJhdGUxFDASBgNVBAsMC0VuZ2luZWVyaW5nMQ4wDAYDVQQDDAVvdGhlcjAe +Fw0yNDAyMjMxMjI0NDdaFw0yNDAyMjQxMjI0NDdaMFoxCzAJBgNVBAYTAlVTMRMw +EQYDVQQIDApDYWxpZm9ybmlhMRAwDgYDVQQKDAdUZXRyYXRlMRQwEgYDVQQLDAtF +bmdpbmVlcmluZzEOMAwGA1UEAwwFb3RoZXIwXDANBgkqhkiG9w0BAQEFAANLADBI +AkEAzGmlQyy0yq6dOLNctb1L5BiQQcfN94jBtzpWavsNt1cZai592Ej7CvQ1FBUj +poP+WUOlv1puhI/sjLK1+E/cRQIDAQABo3gwdjAdBgNVHQ4EFgQU5PTjWUjpv3Hq +0Gqh7+VKX5TMJ9kwHwYDVR0jBBgwFoAU5PTjWUjpv3Hq0Gqh7+VKX5TMJ9kwEgYD +VR0TAQH/BAgwBgEB/wIBATAOBgNVHQ8BAf8EBAMCAuQwEAYDVR0RBAkwB4IFb3Ro +ZXIwDQYJKoZIhvcNAQELBQADQQCxR+vBL0fn1MeQrmla6bDYNbAkdWSJPZASbmeJ +yUoadrfNxkMnlA94OTX0wYmQ4zwedyDWRzp4HgPOWOphe2U2 -----END CERTIFICATE-----` - invalidCAPem = `` + secondCertPem = `-----BEGIN CERTIFICATE----- +MIICDDCCAbagAwIBAgIUZ92xILVsEMxFXr4DJLFpZXp1O5QwDQYJKoZIhvcNAQEL +BQAwWjELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExEDAOBgNVBAoM +B1RldHJhdGUxFDASBgNVBAsMC0VuZ2luZWVyaW5nMQ4wDAYDVQQDDAVvdGhlcjAe +Fw0yNDAyMjMxMjI0NDdaFw0yNDAyMjQxMjI0NDdaMFoxCzAJBgNVBAYTAlVTMRMw +EQYDVQQIDApDYWxpZm9ybmlhMRAwDgYDVQQKDAdUZXRyYXRlMRQwEgYDVQQLDAtF +bmdpbmVlcmluZzEOMAwGA1UEAwwFb3RoZXIwXDANBgkqhkiG9w0BAQEFAANLADBI +AkEAnHRlTPKzGlS0xGUfgk6eQRcbc0eFlQ2QUKm55l5iBC9BP1sY5cO6jcf227l7 +sFdg+9vCBa+j5whebjlWlQ5iawIDAQABo1QwUjAQBgNVHREECTAHggVvdGhlcjAd +BgNVHQ4EFgQU7l+RGBV99RiMS6FKXStwgHuBNvwwHwYDVR0jBBgwFoAU5PTjWUjp +v3Hq0Gqh7+VKX5TMJ9kwDQYJKoZIhvcNAQELBQADQQAy/TQdxLQOxYfTUsXvhbdd +CKSSTT6gHjNgrA5r61drQqvG+69zVEuWybjPzK5uSMntof6I4XWpdfWd37d7WNyd +-----END CERTIFICATE-----` ) func TestLoadTLSConfig(t *testing.T) { @@ -48,7 +100,7 @@ func TestLoadTLSConfig(t *testing.T) { validFile = tmpDir + "/valid.pem" invalidFile = tmpDir + "/invalid.pem" ) - require.NoError(t, os.WriteFile(validFile, []byte(smallCAPem), 0644)) + require.NoError(t, os.WriteFile(validFile, []byte(firstCAPem), 0644)) require.NoError(t, os.WriteFile(invalidFile, []byte(invalidCAPem), 0644)) tests := []struct { @@ -72,7 +124,7 @@ func TestLoadTLSConfig(t *testing.T) { }, { name: "valid trusted CA string config", - config: &oidc.OIDCConfig{TrustedCaConfig: &oidc.OIDCConfig_TrustedCertificateAuthority{TrustedCertificateAuthority: smallCAPem}}, + config: &oidc.OIDCConfig{TrustedCaConfig: &oidc.OIDCConfig_TrustedCertificateAuthority{TrustedCertificateAuthority: firstCAPem}}, wantTLS: true, wantPool: true, }, @@ -97,11 +149,25 @@ func TestLoadTLSConfig(t *testing.T) { config: &oidc.OIDCConfig{TrustedCaConfig: &oidc.OIDCConfig_TrustedCertificateAuthorityFile{TrustedCertificateAuthorityFile: "non-existing.pem"}}, wantErr: true, }, + { + name: "valid trusted CA file and skip verify config", + config: &oidc.OIDCConfig{ + TrustedCaConfig: &oidc.OIDCConfig_TrustedCertificateAuthorityFile{TrustedCertificateAuthorityFile: validFile}, + SkipVerifyPeerCert: structpb.NewBoolValue(true), + }, + wantTLS: true, + wantSkip: false, // skip verify is ignored because there's a trusted CA + wantPool: true, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got, err := LoadTLSConfig(tc.config) + ctx, cancel := context.WithCancel(context.Background()) + pool := NewTLSConfigPool(ctx) + t.Cleanup(cancel) + + got, err := pool.LoadTLSConfig(tc.config) // Check for errors if tc.wantErr { @@ -122,3 +188,165 @@ func TestLoadTLSConfig(t *testing.T) { }) } } + +func TestTLSConfigPoolUpdates(t *testing.T) { + tmpDir := t.TempDir() + var caFile1 = tmpDir + "/ca1.pem" + require.NoError(t, os.WriteFile(caFile1, []byte(firstCAPem), 0644)) + + block, _ := pem.Decode([]byte(firstCertPem)) + cert1, err := x509.ParseCertificate(block.Bytes) + require.NoError(t, err) + + block, _ = pem.Decode([]byte(secondCertPem)) + cert2, err := x509.ParseCertificate(block.Bytes) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + pool := NewTLSConfigPool(ctx) + t.Cleanup(cancel) + + const ( + interval = 100 * time.Millisecond + intervalAndHalf = interval + interval/2 + ) + + config := &oidc.OIDCConfig{ + TrustedCaConfig: &oidc.OIDCConfig_TrustedCertificateAuthorityFile{TrustedCertificateAuthorityFile: caFile1}, + TrustedCertificateAuthorityRefreshInterval: durationpb.New(interval), + } + + // load the TLS config + gotTLS, err := pool.LoadTLSConfig(config) + require.NoError(t, err) + require.NotNil(t, gotTLS) + + // verify the got TLS config is valid + _, err = cert1.Verify(x509.VerifyOptions{Roots: gotTLS.RootCAs, DNSName: firstCertDNSName}) + require.NoError(t, err) + + // update the CA file content + require.NoError(t, os.WriteFile(caFile1, []byte(secondCAPem), 0644)) + time.Sleep(intervalAndHalf) + + // load the TLS config again + gotTLS, err = pool.LoadTLSConfig(config) + require.NoError(t, err) + + // verify the got TLS config is not valid anymore for the old CA, + // as we updated it with CA only valid for cert2. + _, err = cert1.Verify(x509.VerifyOptions{Roots: gotTLS.RootCAs, DNSName: firstCertDNSName}) + require.Error(t, err) + + // verify the got TLS config is valid for the new CA + _, err = cert2.Verify(x509.VerifyOptions{Roots: gotTLS.RootCAs, DNSName: secondCertDNSName}) + require.NoError(t, err) + + // update the CA file content to be invalid + require.NoError(t, os.WriteFile(caFile1, []byte(invalidCAPem), 0644)) + time.Sleep(intervalAndHalf) + + // load the TLS config again + gotTLS, err = pool.LoadTLSConfig(config) + require.NoError(t, err) + + // verify the config is not updated, so the old TLS config is still valid + _, err = cert2.Verify(x509.VerifyOptions{Roots: gotTLS.RootCAs, DNSName: secondCertDNSName}) + require.NoError(t, err) + + // remove the CA file + require.NoError(t, os.Remove(caFile1)) + time.Sleep(intervalAndHalf) + + // load the TLS config again + gotTLS, err = pool.LoadTLSConfig(config) + require.NoError(t, err) + + // verify the config is not modified, so the old TLS config is still valid + _, err = cert2.Verify(x509.VerifyOptions{Roots: gotTLS.RootCAs, DNSName: secondCertDNSName}) + require.NoError(t, err) + + // update the CA file content to be valid again and verify the new CA is loaded + require.NoError(t, os.WriteFile(caFile1, []byte(firstCAPem), 0644)) + time.Sleep(intervalAndHalf) + + // load the TLS config again + gotTLS, err = pool.LoadTLSConfig(config) + require.NoError(t, err) + + _, err = cert1.Verify(x509.VerifyOptions{Roots: gotTLS.RootCAs, DNSName: firstCertDNSName}) + require.NoError(t, err) +} + +func TestTLSConfigPoolWithMultipleConfigs(t *testing.T) { + tmpDir := t.TempDir() + var ( + caFile1 = tmpDir + "/ca1.pem" + caFile2 = tmpDir + "/ca2.pem" + ) + require.NoError(t, os.WriteFile(caFile1, []byte(firstCAPem), 0644)) + require.NoError(t, os.WriteFile(caFile2, []byte(secondCAPem), 0644)) + + block, _ := pem.Decode([]byte(firstCertPem)) + cert1, err := x509.ParseCertificate(block.Bytes) + require.NoError(t, err) + + block, _ = pem.Decode([]byte(secondCertPem)) + cert2, err := x509.ParseCertificate(block.Bytes) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + pool := NewTLSConfigPool(ctx) + t.Cleanup(cancel) + + const ( + config1Interval = 100 * time.Millisecond + config2Interval = 200 * time.Millisecond + ) + var intervalAndHalf = func(interval time.Duration) time.Duration { return interval + interval/2 } + + config1 := &oidc.OIDCConfig{ + TrustedCaConfig: &oidc.OIDCConfig_TrustedCertificateAuthorityFile{TrustedCertificateAuthorityFile: caFile1}, + TrustedCertificateAuthorityRefreshInterval: durationpb.New(config1Interval), + } + config2 := &oidc.OIDCConfig{ + TrustedCaConfig: &oidc.OIDCConfig_TrustedCertificateAuthorityFile{TrustedCertificateAuthorityFile: caFile2}, + TrustedCertificateAuthorityRefreshInterval: durationpb.New(config2Interval), + } + + // load the TLS config for config1 + gotTLS1, err := pool.LoadTLSConfig(config1) + require.NoError(t, err) + + // load the TLS config for config2 + gotTLS2, err := pool.LoadTLSConfig(config2) + require.NoError(t, err) + + // verify the got TLS config for config1 is valid + _, err = cert1.Verify(x509.VerifyOptions{Roots: gotTLS1.RootCAs, DNSName: firstCertDNSName}) + require.NoError(t, err) + + // verify the got TLS config for config2 is valid + _, err = cert2.Verify(x509.VerifyOptions{Roots: gotTLS2.RootCAs, DNSName: secondCertDNSName}) + require.NoError(t, err) + + // update the second file to contain the first CA + require.NoError(t, os.WriteFile(caFile2, []byte(firstCAPem), 0644)) + time.Sleep(intervalAndHalf(config2Interval)) + + // load the TLS config for config2 again + gotTLS2, err = pool.LoadTLSConfig(config2) + require.NoError(t, err) + + // verify the got TLS config for config2 is valid for the first CA and not for the second + _, err = cert1.Verify(x509.VerifyOptions{Roots: gotTLS2.RootCAs, DNSName: firstCertDNSName}) + require.NoError(t, err) + _, err = cert2.Verify(x509.VerifyOptions{Roots: gotTLS2.RootCAs, DNSName: secondCertDNSName}) + require.Error(t, err) + + // verify the got TLS config for config1 is still valid + gotTLS1, err = pool.LoadTLSConfig(config1) + require.NoError(t, err) + _, err = cert1.Verify(x509.VerifyOptions{Roots: gotTLS1.RootCAs, DNSName: firstCertDNSName}) + require.NoError(t, err) +}