diff --git a/internal/authz/oidc.go b/internal/authz/oidc.go index f418f07..ceaea1d 100644 --- a/internal/authz/oidc.go +++ b/internal/authz/oidc.go @@ -731,18 +731,26 @@ func (o *oidcHandler) encodeTokensToHeaders(tokens *oidc.TokenResponse) map[stri headers := make(map[string]string) // Always add the ID token to the headers - headers[o.config.GetIdToken().GetHeader()] = o.config.IdToken.GetPreamble() + " " + oidc.EncodeToken(tokens.IDToken) + headers[o.config.GetIdToken().GetHeader()] = encodeHeaderValue(o.config.IdToken.GetPreamble(), tokens.IDToken) if o.config.GetAccessToken() == nil || tokens.AccessToken == "" { return headers } // If there is an access token and config enables it, add it to the headers - headers[o.config.GetAccessToken().GetHeader()] = o.config.GetAccessToken().GetPreamble() + " " + oidc.EncodeToken(tokens.AccessToken) + headers[o.config.GetAccessToken().GetHeader()] = encodeHeaderValue(o.config.GetAccessToken().GetPreamble(), tokens.AccessToken) return headers } +// encodeHeaderValue encodes the value with the given preamble, if any +func encodeHeaderValue(preamble string, value string) string { + if preamble != "" { + return preamble + " " + value + } + return value +} + // areRequiredTokensExpired checks if the required tokens are expired. func (o *oidcHandler) areRequiredTokensExpired(tokens *oidc.TokenResponse) (bool, error) { idToken, err := tokens.ParseIDToken() diff --git a/internal/authz/oidc_test.go b/internal/authz/oidc_test.go index 17b8d0b..6311a72 100644 --- a/internal/authz/oidc_test.go +++ b/internal/authz/oidc_test.go @@ -18,7 +18,6 @@ import ( "context" "crypto/rand" "crypto/rsa" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -1152,10 +1151,8 @@ func TestMatchesLogoutPath(t *testing.T) { func TestEncodeTokensToHeaders(t *testing.T) { const ( - idToken = "id-token" - accessToken = "access-token" - idTokenB64 = "aWQtdG9rZW4=" - accessTokenB64 = "YWNjZXNzLXRva2Vu" + idToken = "id-token" + accessToken = "access-token" ) tests := []struct { @@ -1171,7 +1168,7 @@ func TestEncodeTokensToHeaders(t *testing.T) { }, idToken: idToken, accessToken: "", want: map[string]string{ - "Authorization": "Bearer " + idTokenB64, + "Authorization": "Bearer " + idToken, }, }, { @@ -1182,8 +1179,8 @@ func TestEncodeTokensToHeaders(t *testing.T) { }, idToken: idToken, accessToken: accessToken, want: map[string]string{ - "Authorization": "Bearer " + idTokenB64, - "X-Access-Token": "Bearer " + accessTokenB64, + "Authorization": "Bearer " + idToken, + "X-Access-Token": "Bearer " + accessToken, }, }, { @@ -1194,8 +1191,8 @@ func TestEncodeTokensToHeaders(t *testing.T) { }, idToken: idToken, accessToken: accessToken, want: map[string]string{ - "X-Id-Token": "Other " + idTokenB64, - "X-Access-Token-Other": "Other " + accessTokenB64, + "X-Id-Token": "Other " + idToken, + "X-Access-Token-Other": "Other " + accessToken, }, }, { @@ -1206,7 +1203,7 @@ func TestEncodeTokensToHeaders(t *testing.T) { }, idToken: idToken, accessToken: "", want: map[string]string{ - "Authorization": "Bearer " + idTokenB64, + "Authorization": "Bearer " + idToken, }, }, { @@ -1216,7 +1213,19 @@ func TestEncodeTokensToHeaders(t *testing.T) { }, idToken: idToken, accessToken: accessToken, want: map[string]string{ - "Authorization": "Bearer " + idTokenB64, + "Authorization": "Bearer " + idToken, + }, + }, + { + name: "config with out preamble", + config: &oidcv1.OIDCConfig{ + IdToken: &oidcv1.TokenConfig{Header: "X-ID-Token"}, + AccessToken: &oidcv1.TokenConfig{Header: "X-Access-Token"}, + }, + idToken: idToken, accessToken: accessToken, + want: map[string]string{ + "X-ID-Token": idToken, + "X-Access-Token": accessToken, }, }, } @@ -1487,9 +1496,9 @@ func requireTokensInResponse(t *testing.T, resp *envoy.OkHttpResponse, cfg *oidc wantIDToken, wantAccessToken string ) - wantIDToken = cfg.GetIdToken().GetPreamble() + " " + base64.URLEncoding.EncodeToString([]byte(idToken)) + wantIDToken = encodeHeaderValue(cfg.GetIdToken().GetPreamble(), idToken) if cfg.GetAccessToken() != nil { - wantAccessToken = cfg.GetAccessToken().GetPreamble() + " " + base64.URLEncoding.EncodeToString([]byte(accessToken)) + wantAccessToken = encodeHeaderValue(cfg.GetAccessToken().GetPreamble(), accessToken) } for _, header := range resp.GetHeaders() { diff --git a/internal/oidc/token.go b/internal/oidc/token.go index a1e405a..ce25cae 100644 --- a/internal/oidc/token.go +++ b/internal/oidc/token.go @@ -15,7 +15,6 @@ package oidc import ( - "encoding/base64" "time" "github.com/lestrrat-go/jwx/jwt" @@ -36,8 +35,3 @@ func (t *TokenResponse) ParseIDToken() (jwt.Token, error) { return ParseToken(t. func ParseToken(token string) (jwt.Token, error) { return jwt.Parse([]byte(token), jwt.WithValidate(false)) } - -// EncodeToken returns the base64 encoded string representation of the token. Compatible with HTTP headers. -func EncodeToken(token string) string { - return base64.URLEncoding.EncodeToString([]byte(token)) -} diff --git a/internal/oidc/token_test.go b/internal/oidc/token_test.go index 71c6d09..ad4d304 100644 --- a/internal/oidc/token_test.go +++ b/internal/oidc/token_test.go @@ -50,7 +50,3 @@ func newToken() string { signed, _ := jwt.Sign(token, jwa.HS256, []byte("key")) return string(signed) } - -func TestEncodeToken(t *testing.T) { - require.Equal(t, "dGVzdA==", EncodeToken("test")) -}