Skip to content

Commit 2c1b62f

Browse files
authored
feat: preserve oidc params in oauth flow (#772)
1 parent 646e24d commit 2c1b62f

3 files changed

Lines changed: 105 additions & 59 deletions

File tree

frontend/src/pages/login-page.tsx

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,14 @@ export const LoginPage = () => {
7676
isPending: oauthIsPending,
7777
variables: oauthVariables,
7878
} = useMutation({
79-
mutationFn: (provider: string) =>
80-
axios.get(
81-
`/api/oauth/url/${provider}${props.redirect_uri ? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}` : ""}`,
82-
),
79+
mutationFn: (provider: string) => {
80+
const params = isOidc
81+
? `?${compiledOIDCParams}`
82+
: props.redirect_uri
83+
? `?redirect_uri=${encodeURIComponent(props.redirect_uri)}`
84+
: "";
85+
return axios.get(`/api/oauth/url/${provider}${params}`);
86+
},
8387
mutationKey: ["oauth"],
8488
onSuccess: (data) => {
8589
toast.info(t("loginOauthSuccessTitle"), {

internal/controller/oauth_controller.go

Lines changed: 67 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,29 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
6262
return
6363
}
6464

65-
sessionId, session, err := controller.auth.NewOAuthSession(req.Provider)
65+
var reqParams service.OAuthURLParams
66+
67+
err = c.BindQuery(&reqParams)
68+
69+
if err != nil {
70+
tlog.App.Error().Err(err).Msg("Failed to bind query parameters")
71+
c.JSON(400, gin.H{
72+
"status": 400,
73+
"message": "Bad Request",
74+
})
75+
return
76+
}
77+
78+
if !controller.isOidcRequest(reqParams) {
79+
isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain)
80+
81+
if !isRedirectSafe {
82+
tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring")
83+
reqParams.RedirectURI = ""
84+
}
85+
}
86+
87+
sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams)
6688

6789
if err != nil {
6890
tlog.App.Error().Err(err).Msg("Failed to create OAuth session")
@@ -85,20 +107,6 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) {
85107
}
86108

87109
c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
88-
c.SetCookie(controller.config.CSRFCookieName, session.State, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
89-
90-
redirectURI := c.Query("redirect_uri")
91-
isRedirectSafe := utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain)
92-
93-
if !isRedirectSafe {
94-
tlog.App.Warn().Str("redirect_uri", redirectURI).Msg("Unsafe redirect URI detected, ignoring")
95-
redirectURI = ""
96-
}
97-
98-
if redirectURI != "" && isRedirectSafe {
99-
tlog.App.Debug().Msg("Setting redirect URI cookie")
100-
c.SetCookie(controller.config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
101-
}
102110

103111
c.JSON(200, gin.H{
104112
"status": 200,
@@ -129,19 +137,23 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
129137
}
130138

131139
c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
132-
defer controller.auth.EndOAuthSession(sessionIdCookie)
133140

134-
state := c.Query("state")
135-
csrfCookie, err := c.Cookie(controller.config.CSRFCookieName)
141+
oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie)
136142

137-
if err != nil || state != csrfCookie {
138-
tlog.App.Warn().Err(err).Msg("CSRF token mismatch or cookie missing")
139-
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
143+
if err != nil {
144+
tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session")
140145
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
141146
return
142147
}
143148

144-
c.SetCookie(controller.config.CSRFCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
149+
defer controller.auth.EndOAuthSession(sessionIdCookie)
150+
151+
state := c.Query("state")
152+
if state != oauthPendingSession.State {
153+
tlog.App.Warn().Err(err).Msg("CSRF token mismatch")
154+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
155+
return
156+
}
145157

146158
code := c.Query("code")
147159
_, err = controller.auth.GetOAuthToken(sessionIdCookie, code)
@@ -198,16 +210,16 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
198210
username = strings.Replace(user.Email, "@", "_", 1)
199211
}
200212

201-
service, err := controller.auth.GetOAuthService(sessionIdCookie)
213+
svc, err := controller.auth.GetOAuthService(sessionIdCookie)
202214

203215
if err != nil {
204216
tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session")
205217
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
206218
return
207219
}
208220

209-
if service.ID() != req.Provider {
210-
tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", service.ID(), req.Provider)
221+
if svc.ID() != req.Provider {
222+
tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider)
211223
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
212224
return
213225
}
@@ -216,9 +228,9 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
216228
Username: username,
217229
Name: name,
218230
Email: user.Email,
219-
Provider: service.ID(),
231+
Provider: svc.ID(),
220232
OAuthGroups: utils.CoalesceToString(user.Groups),
221-
OAuthName: service.Name(),
233+
OAuthName: svc.Name(),
222234
OAuthSub: user.Sub,
223235
}
224236

@@ -234,24 +246,39 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
234246

235247
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
236248

237-
redirectURI, err := c.Cookie(controller.config.RedirectCookieName)
238-
239-
if err != nil || !utils.IsRedirectSafe(redirectURI, controller.config.CookieDomain) {
240-
tlog.App.Debug().Msg("No redirect URI cookie found, redirecting to app root")
241-
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
249+
if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
250+
tlog.App.Debug().Msg("OIDC request, redirecting to authorize page")
251+
queries, err := query.Values(oauthPendingSession.CallbackParams)
252+
if err != nil {
253+
tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
254+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
255+
return
256+
}
257+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode()))
242258
return
243259
}
244260

245-
queries, err := query.Values(config.RedirectQuery{
246-
RedirectURI: redirectURI,
247-
})
261+
if oauthPendingSession.CallbackParams.RedirectURI != "" {
262+
queries, err := query.Values(config.RedirectQuery{
263+
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
264+
})
248265

249-
if err != nil {
250-
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
251-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
266+
if err != nil {
267+
tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query")
268+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
269+
return
270+
}
271+
272+
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
252273
return
253274
}
254275

255-
c.SetCookie(controller.config.RedirectCookieName, "", -1, "/", fmt.Sprintf(".%s", controller.config.CookieDomain), controller.config.SecureCookie, true)
256-
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
276+
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
277+
}
278+
279+
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
280+
return params.Scope != "" &&
281+
params.ResponseType != "" &&
282+
params.ClientID != "" &&
283+
params.RedirectURI != ""
257284
}

internal/service/auth_service.go

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,26 @@ const MaxOAuthPendingSessions = 256
2828
const OAuthCleanupCount = 16
2929
const MaxLoginAttemptRecords = 256
3030

31+
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
32+
// parameters and pass them to the authorize page if needed
33+
type OAuthURLParams struct {
34+
Scope string `form:"scope" url:"scope"`
35+
ResponseType string `form:"response_type" url:"response_type"`
36+
ClientID string `form:"client_id" url:"client_id"`
37+
RedirectURI string `form:"redirect_uri" url:"redirect_uri"`
38+
State string `form:"state" url:"state"`
39+
Nonce string `form:"nonce" url:"nonce"`
40+
CodeChallenge string `form:"code_challenge" url:"code_challenge"`
41+
CodeChallengeMethod string `form:"code_challenge_method" url:"code_challenge_method"`
42+
}
43+
3144
type OAuthPendingSession struct {
32-
State string
33-
Verifier string
34-
Token *oauth2.Token
35-
Service *OAuthServiceImpl
36-
ExpiresAt time.Time
45+
State string
46+
Verifier string
47+
Token *oauth2.Token
48+
Service *OAuthServiceImpl
49+
ExpiresAt time.Time
50+
CallbackParams OAuthURLParams
3751
}
3852

3953
type LdapGroupsCache struct {
@@ -598,7 +612,7 @@ func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
598612
return false
599613
}
600614

601-
func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendingSession, error) {
615+
func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthURLParams) (string, OAuthPendingSession, error) {
602616
auth.ensureOAuthSessionLimit()
603617

604618
service, ok := auth.oauthBroker.GetService(serviceName)
@@ -617,10 +631,11 @@ func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendi
617631
verifier := service.NewRandom()
618632

619633
session := OAuthPendingSession{
620-
State: state,
621-
Verifier: verifier,
622-
Service: &service,
623-
ExpiresAt: time.Now().Add(1 * time.Hour),
634+
State: state,
635+
Verifier: verifier,
636+
Service: &service,
637+
ExpiresAt: time.Now().Add(1 * time.Hour),
638+
CallbackParams: params,
624639
}
625640

626641
auth.oauthMutex.Lock()
@@ -631,7 +646,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string) (string, OAuthPendi
631646
}
632647

633648
func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
634-
session, err := auth.getOAuthPendingSession(sessionId)
649+
session, err := auth.GetOAuthPendingSession(sessionId)
635650

636651
if err != nil {
637652
return "", err
@@ -641,7 +656,7 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) {
641656
}
642657

643658
func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) {
644-
session, err := auth.getOAuthPendingSession(sessionId)
659+
session, err := auth.GetOAuthPendingSession(sessionId)
645660

646661
if err != nil {
647662
return nil, err
@@ -661,7 +676,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
661676
}
662677

663678
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
664-
session, err := auth.getOAuthPendingSession(sessionId)
679+
session, err := auth.GetOAuthPendingSession(sessionId)
665680

666681
if err != nil {
667682
return config.Claims{}, err
@@ -681,7 +696,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, erro
681696
}
682697

683698
func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) {
684-
session, err := auth.getOAuthPendingSession(sessionId)
699+
session, err := auth.GetOAuthPendingSession(sessionId)
685700

686701
if err != nil {
687702
return nil, err
@@ -715,7 +730,7 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() {
715730
}
716731
}
717732

718-
func (auth *AuthService) getOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
733+
func (auth *AuthService) GetOAuthPendingSession(sessionId string) (*OAuthPendingSession, error) {
719734
auth.ensureOAuthSessionLimit()
720735

721736
auth.oauthMutex.RLock()

0 commit comments

Comments
 (0)