Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ TINYAUTH_OAUTH_PROVIDERS_name_CLIENTID=
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRET=
# Path to the file containing the OAuth client secret.
TINYAUTH_OAUTH_PROVIDERS_name_CLIENTSECRETFILE=
# Comma-separated list of allowed OAuth domains for this provider.
TINYAUTH_OAUTH_PROVIDERS_name_WHITELIST=
# Path to the OAuth whitelist file for this provider.
TINYAUTH_OAUTH_PROVIDERS_name_WHITELISTFILE=
# OAuth scopes.
TINYAUTH_OAUTH_PROVIDERS_name_SCOPES=
# OAuth redirect URL.
Expand Down
7 changes: 7 additions & 0 deletions internal/bootstrap/app_bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthProviders = app.config.OAuth.Providers

for id, provider := range app.runtime.OAuthProviders {
providerWhitelist, err := utils.GetStringList(provider.Whitelist, provider.WhitelistFile)
if err != nil {
return fmt.Errorf("failed to load oauth whitelist for provider %s: %w", id, err)
}

provider.Whitelist = providerWhitelist

secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile)
provider.ClientSecret = secret
provider.ClientSecretFile = ""
Expand Down
32 changes: 16 additions & 16 deletions internal/controller/oauth_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,23 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return
}

if !controller.auth.IsEmailWhitelisted(user.Email) {
svc, err := controller.auth.GetOAuthService(sessionIdCookie)

if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}

if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}

if !controller.auth.IsEmailWhitelisted(svc.ID(), user.Email) {
controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access")
controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted")
controller.log.AuditLoginFailure(user.Email, svc.ID(), c.ClientIP(), "email not whitelisted")

queries, err := query.Values(UnauthorizedQuery{
Username: user.Email,
Expand Down Expand Up @@ -226,20 +240,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
username = strings.Replace(user.Email, "@", "_", 1)
}

svc, err := controller.auth.GetOAuthService(sessionIdCookie)

if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}

if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}

sessionCookie := repository.Session{
Username: username,
Name: name,
Expand Down
2 changes: 1 addition & 1 deletion internal/middleware/context_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string, ip stri
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
}

if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
if !m.auth.IsEmailWhitelisted(userContext.OAuth.ID, userContext.OAuth.Email) {
m.auth.DeleteSession(ctx, uuid)
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
}
Expand Down
2 changes: 2 additions & 0 deletions internal/model/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
ClientSecretFile string `description:"Path to the file containing the OAuth client secret." yaml:"clientSecretFile"`
Whitelist []string `description:"Comma-separated list of allowed OAuth domains for this provider." yaml:"whitelist"`
WhitelistFile string `description:"Path to the OAuth whitelist file for this provider." yaml:"whitelistFile"`
Scopes []string `description:"OAuth scopes." yaml:"scopes"`
RedirectURL string `description:"OAuth redirect URL." yaml:"redirectUrl"`
AuthURL string `description:"OAuth authorization URL." yaml:"authUrl"`
Expand Down
11 changes: 8 additions & 3 deletions internal/service/auth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,15 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
}
}

func (auth *AuthService) IsEmailWhitelisted(email string) bool {
match, err := utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email)
func (auth *AuthService) IsEmailWhitelisted(provider string, email string) bool {
whitelist := auth.runtime.OAuthWhitelist
if providerConfig, ok := auth.runtime.OAuthProviders[provider]; ok && len(providerConfig.Whitelist) > 0 {
whitelist = providerConfig.Whitelist
}

match, err := utils.CheckFilter(strings.Join(whitelist, ","), email)
if err != nil {
auth.log.App.Warn().Err(err).Str("email", email).Msg("Invalid email filter pattern")
auth.log.App.Warn().Err(err).Str("provider", provider).Str("email", email).Msg("Invalid email filter pattern")
return false
}
return match
Expand Down
39 changes: 39 additions & 0 deletions internal/service/auth_service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package service

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
)

func TestIsEmailWhitelistedUsesProviderSpecificList(t *testing.T) {
log := logger.NewLogger().WithTestConfig()
log.Init()

auth := &AuthService{
log: log,
runtime: model.RuntimeConfig{
OAuthWhitelist: []string{"global@example.com"},
OAuthProviders: map[string]model.OAuthServiceConfig{
"github": {
Whitelist: []string{"github@example.com"},
},
"pocketid": {
Whitelist: []string{"pocket@example.com"},
},
"gitlab": {
Whitelist: []string{},
},
},
},
}

assert.True(t, auth.IsEmailWhitelisted("github", "github@example.com"))
assert.False(t, auth.IsEmailWhitelisted("github", "pocket@example.com"))
assert.True(t, auth.IsEmailWhitelisted("pocketid", "pocket@example.com"))
assert.True(t, auth.IsEmailWhitelisted("google", "global@example.com"))
assert.True(t, auth.IsEmailWhitelisted("gitlab", "global@example.com"))
assert.False(t, auth.IsEmailWhitelisted("gitlab", "unknown@example.com"))
}
Loading