Skip to content

Commit

Permalink
Update tests to match current implementation, change how additional p…
Browse files Browse the repository at this point in the history
…arameters are handled

Signed-off-by: Adam Nych <[email protected]>
  • Loading branch information
a-nych committed Nov 7, 2024
1 parent f3d0250 commit 553beb0
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 158 deletions.
97 changes: 43 additions & 54 deletions connector/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"golang.org/x/exp/slices"
"log/slog"
"net/http"
"net/url"
Expand Down Expand Up @@ -106,8 +107,7 @@ type Config struct {
FilterGroupClaims FilterGroupClaims `json:"filterGroupClaims"`
} `json:"claimModifications"`

// Add additional authorization request parameters to acceess IdP specific features.
// Take care not to override standard OICD authorization requests parameters.
// AdditionalAuthRequestParams allows to add additional authorization request parameters to access IdP specific features.
AdditionalAuthRequestParams map[string]string `json:"additionalAuthRequestParams"`
}

Expand Down Expand Up @@ -306,23 +306,23 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
ctx, // Pass our ctx with customized http.Client
&oidc.Config{ClientID: clientID},
),
logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)),
cancel: cancel,
httpClient: httpClient,
insecureSkipEmailVerified: c.InsecureSkipEmailVerified,
insecureEnableGroups: c.InsecureEnableGroups,
allowedGroups: c.AllowedGroups,
acrValues: c.AcrValues,
getUserInfo: c.GetUserInfo,
promptType: promptType,
userIDKey: c.UserIDKey,
userNameKey: c.UserNameKey,
overrideClaimMapping: c.OverrideClaimMapping,
preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey,
emailKey: c.ClaimMapping.EmailKey,
groupsKey: c.ClaimMapping.GroupsKey,
newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims,
groupsFilter: groupsFilter,
logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)),
cancel: cancel,
httpClient: httpClient,
insecureSkipEmailVerified: c.InsecureSkipEmailVerified,
insecureEnableGroups: c.InsecureEnableGroups,
allowedGroups: c.AllowedGroups,
acrValues: c.AcrValues,
getUserInfo: c.GetUserInfo,
promptType: promptType,
userIDKey: c.UserIDKey,
userNameKey: c.UserNameKey,
overrideClaimMapping: c.OverrideClaimMapping,
preferredUsernameKey: c.ClaimMapping.PreferredUsernameKey,
emailKey: c.ClaimMapping.EmailKey,
groupsKey: c.ClaimMapping.GroupsKey,
newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims,
groupsFilter: groupsFilter,
additionalAuthRequestParams: c.AdditionalAuthRequestParams,
}, nil
}
Expand All @@ -333,27 +333,27 @@ var (
)

type oidcConnector struct {
provider *oidc.Provider
redirectURI string
oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier
cancel context.CancelFunc
logger *slog.Logger
httpClient *http.Client
insecureSkipEmailVerified bool
insecureEnableGroups bool
allowedGroups []string
acrValues []string
getUserInfo bool
promptType string
userIDKey string
userNameKey string
overrideClaimMapping bool
preferredUsernameKey string
emailKey string
groupsKey string
newGroupFromClaims []NewGroupFromClaims
groupsFilter *regexp.Regexp
provider *oidc.Provider
redirectURI string
oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier
cancel context.CancelFunc
logger *slog.Logger
httpClient *http.Client
insecureSkipEmailVerified bool
insecureEnableGroups bool
allowedGroups []string
acrValues []string
getUserInfo bool
promptType string
userIDKey string
userNameKey string
overrideClaimMapping bool
preferredUsernameKey string
emailKey string
groupsKey string
newGroupFromClaims []NewGroupFromClaims
groupsFilter *regexp.Regexp
additionalAuthRequestParams map[string]string
}

Expand All @@ -378,14 +378,12 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", c.promptType))
}

if len(c.additionalAuthRequestParams) > 0 {
for k, v := range c.additionalAuthRequestParams {
if contains(managedAuthParams, k) {
return "", fmt.Errorf("parameter '%s' is already managed by this connector", k)
}
for k, v := range c.additionalAuthRequestParams {
if !slices.Contains(managedAuthParams, k) {
opts = append(opts, oauth2.SetAuthURLParam(k, v))
}
}

return c.oauth2Config.AuthCodeURL(state, opts...), nil
}

Expand Down Expand Up @@ -653,12 +651,3 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I

return identity, nil
}

func contains(slice []string, value string) bool {
for _, v := range slice {
if v == value {
return true
}
}
return false
}
184 changes: 80 additions & 104 deletions connector/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package oidc

import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
Expand Down Expand Up @@ -773,6 +774,7 @@ func testLoginURL(t *testing.T, config Config, state string) (url.Values, error)
}
return values, nil
}

func TestLoginURLCustomerParam(t *testing.T) {
cfg := Config{
ClientID: "client",
Expand All @@ -783,14 +785,14 @@ func TestLoginURLCustomerParam(t *testing.T) {
}
values, err := testLoginURL(t, cfg, "1234")

assert.Nil(t, err)
assert.Len(t, values, 6)
assertParamValue(t, values, "organization", "myorg")
assertParamValue(t, values, "client_id", "client")
assertParamValue(t, values, "redirect_uri", "callback")
assertParamValue(t, values, "state", "1234")
assertParamValue(t, values, "response_type", "code")
assertParamValue(t, values, "scope", "openid profile email")
require.NoError(t, err)
require.Len(t, values, 6)
expectEquals(t, values.Get("organization"), "myorg")
expectEquals(t, values.Get("client_id"), "client")
expectEquals(t, values.Get("redirect_uri"), "callback")
expectEquals(t, values.Get("state"), "1234")
expectEquals(t, values.Get("response_type"), "code")
expectEquals(t, values.Get("scope"), "openid profile email")
}

func TestCustomLoginURLEmptyParams(t *testing.T) {
Expand All @@ -801,114 +803,88 @@ func TestCustomLoginURLEmptyParams(t *testing.T) {
}
values, err := testLoginURL(t, cfg, "1234")

assert.Nil(t, err)
assert.Len(t, values, 5)
assertParamValue(t, values, "client_id", "client")
assertParamValue(t, values, "redirect_uri", "callback")
assertParamValue(t, values, "state", "1234")
assertParamValue(t, values, "response_type", "code")
assertParamValue(t, values, "scope", "openid profile email")
require.NoError(t, err)
require.Len(t, values, 5)
expectEquals(t, values.Get("client_id"), "client")
expectEquals(t, values.Get("redirect_uri"), "callback")
expectEquals(t, values.Get("state"), "1234")
expectEquals(t, values.Get("response_type"), "code")
expectEquals(t, values.Get("scope"), "openid profile email")
}

func TestLoginURLClientIdError(t *testing.T) {
cfg := Config{
ClientID: "client",
RedirectURI: "callback",
AdditionalAuthRequestParams: map[string]string{
"client_id": "not-so-fast",
func TestLoginURLParameterProtection(t *testing.T) {
tests := []struct {
name string
paramToOverride string
expectedValue string
}{
{
name: "client_id cannot be overridden",
paramToOverride: "client_id",
expectedValue: "client", // Should remain the config value
},
}
_, err := testLoginURL(t, cfg, "1234")
assert.EqualErrorf(t, err, "parameter 'client_id' is already managed by this connector", "")
}

func TestLoginURLRedirectURIError(t *testing.T) {
cfg := Config{
ClientID: "client",
RedirectURI: "callback",
AdditionalAuthRequestParams: map[string]string{
"redirect_uri": "not-so-fast",
{
name: "redirect_uri cannot be overridden",
paramToOverride: "redirect_uri",
expectedValue: "callback", // Should remain the config value
},
}
_, err := testLoginURL(t, cfg, "1234")
assert.EqualErrorf(t, err, "parameter 'redirect_uri' is already managed by this connector", "")
}

func TestLoginURLStateError(t *testing.T) {
cfg := Config{
ClientID: "client",
RedirectURI: "callback",
AdditionalAuthRequestParams: map[string]string{
"state": "not-so-fast",
{
name: "state cannot be overridden",
paramToOverride: "state",
expectedValue: "1234", // Should remain the state parameter value
},
}
_, err := testLoginURL(t, cfg, "1234")
assert.EqualErrorf(t, err, "parameter 'state' is already managed by this connector", "")
}

func TestLoginURLHostedDomainsError(t *testing.T) {
cfg := Config{
ClientID: "client",
RedirectURI: "callback",
AdditionalAuthRequestParams: map[string]string{
"hd": "not-so-fast",
{
name: "response_type cannot be overridden",
paramToOverride: "response_type",
expectedValue: "code", // Should remain the default OAuth2 value
},
}
_, err := testLoginURL(t, cfg, "1234")
assert.EqualErrorf(t, err, "parameter 'hd' is already managed by this connector", "")
}

func TestLoginURLResponseTypeError(t *testing.T) {
cfg := Config{
ClientID: "client",
RedirectURI: "callback",
AdditionalAuthRequestParams: map[string]string{
"response_type": "not-so-fast",
{
name: "scope cannot be overridden",
paramToOverride: "scope",
expectedValue: "openid profile email", // Should remain the default scopes
},
}
_, err := testLoginURL(t, cfg, "1234")
assert.EqualErrorf(t, err, "parameter 'response_type' is already managed by this connector", "")
}

func TestLoginURLScopeTypeError(t *testing.T) {
cfg := Config{
ClientID: "client",
RedirectURI: "callback",
AdditionalAuthRequestParams: map[string]string{
"scope": "not-so-fast",
{
name: "prompt cannot be overridden",
paramToOverride: "prompt",
expectedValue: "", // Should not be set unless offline access is requested
},
}
_, err := testLoginURL(t, cfg, "1234")
assert.EqualErrorf(t, err, "parameter 'scope' is already managed by this connector", "")
}

func TestLoginURLPromptError(t *testing.T) {
cfg := Config{
ClientID: "client",
RedirectURI: "callback",
AdditionalAuthRequestParams: map[string]string{
"prompt": "not-so-fast",
{
name: "hd cannot be overridden",
paramToOverride: "hd",
expectedValue: "", // Should not be set as hosted domains are not configured
},
}
_, err := testLoginURL(t, cfg, "1234")
assert.EqualErrorf(t, err, "parameter 'prompt' is already managed by this connector", "")
}

func TestLoginURLAcrValuesError(t *testing.T) {
cfg := Config{
ClientID: "client",
RedirectURI: "callback",
AdditionalAuthRequestParams: map[string]string{
"acr_values": "not-so-fast",
{
name: "acr_values cannot be overridden",
paramToOverride: "acr_values",
expectedValue: "", // Should not be set as acr_values are not configured
},
}
_, err := testLoginURL(t, cfg, "1234")
assert.EqualErrorf(t, err, "parameter 'acr_values' is already managed by this connector", "")
}

func assertParamValue(t *testing.T, values url.Values, queryParam string, expectedValue string) {
assert.NotNil(t, values[queryParam])
assert.Equal(t, expectedValue, values[queryParam][0])
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cfg := Config{
ClientID: "client",
RedirectURI: "callback",
AdditionalAuthRequestParams: map[string]string{
tc.paramToOverride: "not-so-fast",
},
}

values, err := testLoginURL(t, cfg, "1234")
require.NoError(t, err)

// Check that the parameter contains the expected value, not the override attempt
gotValue := values.Get(tc.paramToOverride)
if tc.expectedValue == "" {
// If we expect no value, the parameter should not be present
require.Empty(t, gotValue, "parameter %s should not be present", tc.paramToOverride)
} else {
require.Equal(t, tc.expectedValue, gotValue,
"parameter %s should be %q but got %q",
tc.paramToOverride, tc.expectedValue, gotValue)
}
})
}
}

func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) {
Expand Down

0 comments on commit 553beb0

Please sign in to comment.