Skip to content
This repository was archived by the owner on Apr 22, 2024. It is now read-only.

Decouple CA load from request lifecycle #50

Merged
merged 4 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ import (

func main() {
var (
lifecycle = run.NewLifecycle()
configFile = &internal.LocalConfigFile{}
logging = internal.NewLogSystem(log.New(), &configFile.Config)
jwks = oidc.NewJWKSProvider()
tlsPool = internal.NewTLSConfigPool(lifecycle.Context())
jwks = oidc.NewJWKSProvider(tlsPool)
sessions = oidc.NewSessionStoreFactory(&configFile.Config)
envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, jwks, sessions)
envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, tlsPool, jwks, sessions)
authzServer = server.New(&configFile.Config, envoyAuthz.Register)
healthz = server.NewHealthServer(&configFile.Config)
secrets = internal.NewSecretLoader(&configFile.Config)
Expand All @@ -51,6 +53,7 @@ func main() {
g := run.Group{Logger: internal.Logger(internal.Default)}

g.Register(
lifecycle, // manage the lifecycle of the run.Services
configFile, // load the configuration
logging, // set up the logging system
secrets, // load the secrets and update the configuration
Expand Down
146 changes: 86 additions & 60 deletions config/gen/go/v1/oidc/config.pb.go

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions config/gen/go/v1/oidc/config.pb.validate.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions config/v1/oidc/config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ syntax = "proto3";

package authservice.config.v1.oidc;

import "google/protobuf/duration.proto";
import "google/protobuf/struct.proto";
import "validate/validate.proto";

Expand Down Expand Up @@ -222,6 +223,13 @@ message OIDCConfig {
string trusted_certificate_authority_file = 20;
}

// The duration between refreshes of the trusted certificate authority if `trusted_certificate_authority_file` is set.
// Unset or 0 (the default) disables the refresh, useful is no rotation is expected.
// Is a String that ends in `s` to indicate seconds and is preceded by the number of seconds,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it supports also minutes or hours?
Also, let's not mention nanos; no one will use such precision when configuring this interval

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what I thought, but this is not golang time.Duration, it's the protobuf one it says only seconds and nanos are supported: https://protobuf.dev/reference/protobuf/google.protobuf/#duration

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is what happens when you marshal the duration into a JSON, that it always marshals it as seconds, but when unmarshaling you can use normal Golang Duration string values. Can you try it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I already tried it:

proto: (line 32:63): invalid google.protobuf.Duration value "60m"

But as the documentation says, it only accepts seconds with decimals for nanos.

Something that could make sense is to turn this field into a string representing the golang time.Duration format and validate that manually when we're loading the config.

// with nanoseconds expressed as fractional seconds, e.g. `120.15s`.
// Optional.
google.protobuf.Duration trusted_certificate_authority_refresh_interval = 22;

// The Authservice makes two kinds of direct network connections directly to the OIDC Provider.
// Both are POST requests to the configured `token_uri` of the OIDC Provider.
// The first is to exchange the authorization code for tokens, and the other is to use the
Expand Down
3 changes: 2 additions & 1 deletion e2e/keycloak/authz-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
"redis_session_store_config": {
"server_uri": "redis://redis:6379"
},
"trusted_certificate_authority_file": "/etc/authservice/certs/ca.crt"
"trusted_certificate_authority_file": "/etc/authservice/certs/ca.crt",
"trusted_certificate_authority_refresh_interval": "60.25s"
}
}
]
Expand Down
13 changes: 7 additions & 6 deletions internal/authz/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ var (
type oidcHandler struct {
log telemetry.Logger
config *oidcv1.OIDCConfig
tlsPool internal.TLSConfigPool
jwks oidc.JWKSProvider
sessions oidc.SessionStoreFactory
sessionGen oidc.SessionGenerator
Expand All @@ -61,11 +62,10 @@ type oidcHandler struct {
}

// NewOIDCHandler creates a new OIDC implementation of the Handler interface.
func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider,
sessions oidc.SessionStoreFactory, clock oidc.Clock,
sessionGen oidc.SessionGenerator) (Handler, error) {
func NewOIDCHandler(cfg *oidcv1.OIDCConfig, tlsPool internal.TLSConfigPool, jwks oidc.JWKSProvider,
sessions oidc.SessionStoreFactory, clock oidc.Clock, sessionGen oidc.SessionGenerator) (Handler, error) {

client, err := getHTTPClient(cfg)
client, err := getHTTPClient(cfg, tlsPool)
if err != nil {
return nil, err
}
Expand All @@ -77,6 +77,7 @@ func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider,
return &oidcHandler{
log: internal.Logger(internal.Authz).With("type", "oidc"),
config: cfg,
tlsPool: tlsPool,
jwks: jwks,
sessions: sessions,
clock: clock,
Expand All @@ -85,11 +86,11 @@ func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider,
}, nil
}

func getHTTPClient(cfg *oidcv1.OIDCConfig) (*http.Client, error) {
func getHTTPClient(cfg *oidcv1.OIDCConfig, tlsPool internal.TLSConfigPool) (*http.Client, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()

var err error
if transport.TLSClientConfig, err = internal.LoadTLSConfig(cfg); err != nil {
if transport.TLSClientConfig, err = tlsPool.LoadTLSConfig(cfg); err != nil {
return nil, err
}

Expand Down
28 changes: 19 additions & 9 deletions internal/authz/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"google.golang.org/grpc/test/bufconn"

oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc"
"github.com/tetrateio/authservice-go/internal"
inthttp "github.com/tetrateio/authservice-go/internal/http"
"github.com/tetrateio/authservice-go/internal/oidc"
)
Expand Down Expand Up @@ -201,7 +202,8 @@ func TestOIDCProcess(t *testing.T) {
clock := oidc.Clock{}
sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)}
store := sessions.Get(basicOIDCConfig)
h, err := NewOIDCHandler(basicOIDCConfig, oidc.NewJWKSProvider(), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState))
tlsPool := internal.NewTLSConfigPool(context.Background())
h, err := NewOIDCHandler(basicOIDCConfig, tlsPool, oidc.NewJWKSProvider(tlsPool), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState))
require.NoError(t, err)

ctx := context.Background()
Expand Down Expand Up @@ -877,6 +879,7 @@ func TestOIDCProcess(t *testing.T) {
func TestOIDCProcessWithFailingSessionStore(t *testing.T) {
store := &storeMock{delegate: oidc.NewMemoryStore(&oidc.Clock{}, time.Hour, time.Hour)}
sessions := &mockSessionStoreFactory{store: store}
tlsPool := internal.NewTLSConfigPool(context.Background())

jwkPriv, jwkPub := newKeyPair(t)
bytes, err := json.Marshal(newKeySet(jwkPub))
Expand All @@ -885,7 +888,8 @@ func TestOIDCProcessWithFailingSessionStore(t *testing.T) {
Jwks: string(bytes),
}

h, err := NewOIDCHandler(basicOIDCConfig, oidc.NewJWKSProvider(), sessions, oidc.Clock{}, oidc.NewStaticGenerator(newSessionID, newNonce, newState))
h, err := NewOIDCHandler(basicOIDCConfig, tlsPool, oidc.NewJWKSProvider(tlsPool),
sessions, oidc.Clock{}, oidc.NewStaticGenerator(newSessionID, newNonce, newState))
require.NoError(t, err)

ctx := context.Background()
Expand Down Expand Up @@ -1029,7 +1033,8 @@ func TestOIDCProcessWithFailingJWKSProvider(t *testing.T) {
clock := oidc.Clock{}
sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)}
store := sessions.Get(basicOIDCConfig)
h, err := NewOIDCHandler(basicOIDCConfig, funcJWKSProvider, sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState))
tlsPool := internal.NewTLSConfigPool(context.Background())
h, err := NewOIDCHandler(basicOIDCConfig, tlsPool, funcJWKSProvider, sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState))
require.NoError(t, err)

idpServer := newServer()
Expand Down Expand Up @@ -1235,10 +1240,11 @@ func TestEncodeTokensToHeaders(t *testing.T) {
}

sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&oidc.Clock{}, time.Hour, time.Hour)}
tlsPool := internal.NewTLSConfigPool(context.Background())

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, err := NewOIDCHandler(tt.config, nil, sessions, oidc.Clock{}, nil)
h, err := NewOIDCHandler(tt.config, tlsPool, nil, sessions, oidc.Clock{}, nil)
require.NoError(t, err)

tokResp := &oidc.TokenResponse{
Expand Down Expand Up @@ -1307,10 +1313,11 @@ func TestAreTokensExpired(t *testing.T) {
}

sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&oidc.Clock{}, time.Hour, time.Hour)}
tlsPool := internal.NewTLSConfigPool(context.Background())

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, err := NewOIDCHandler(tt.config, nil, sessions, oidc.Clock{}, nil)
h, err := NewOIDCHandler(tt.config, tlsPool, nil, sessions, oidc.Clock{}, nil)
require.NoError(t, err)

tokResp := &oidc.TokenResponse{
Expand Down Expand Up @@ -1341,12 +1348,16 @@ func TestLoadWellKnownConfig(t *testing.T) {

func TestLoadWellKnownConfigError(t *testing.T) {
clock := oidc.Clock{}
tlsPool := internal.NewTLSConfigPool(context.Background())
sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)}
_, err := NewOIDCHandler(dynamicOIDCConfig, oidc.NewJWKSProvider(), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState))
_, err := NewOIDCHandler(dynamicOIDCConfig, tlsPool, oidc.NewJWKSProvider(tlsPool), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState))
require.Error(t, err) // Fail to retrieve the dynamic config since the test server is not running
}

func TestNewOIDCHandler(t *testing.T) {
clock := oidc.Clock{}
tlsPool := internal.NewTLSConfigPool(context.Background())
sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)}

tests := []struct {
name string
Expand All @@ -1359,9 +1370,8 @@ func TestNewOIDCHandler(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clock := oidc.Clock{}
sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)}
_, err := NewOIDCHandler(tt.config, oidc.NewJWKSProvider(), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState))

_, err := NewOIDCHandler(tt.config, tlsPool, oidc.NewJWKSProvider(tlsPool), sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState))
if tt.wantErr {
require.Error(t, err)
} else {
Expand Down
Loading