Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 90 additions & 24 deletions lib/api/oidc/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ import (
"net/http"
"net/url"
"slices"
"sync"
"time"

"github.com/ether/etherpad-go/assets/login"
"github.com/ether/etherpad-go/lib/api/constants"
db "github.com/ether/etherpad-go/lib/db"
"github.com/ether/etherpad-go/lib/models/oidc"
"github.com/ether/etherpad-go/lib/security"
"github.com/ether/etherpad-go/lib/settings"
"github.com/ory/fosite"
"github.com/ory/fosite/compose"
Expand All @@ -27,10 +29,14 @@ import (
)

type Authenticator struct {
// provider is rebuilt whenever the SecretRotator rotates the global HMAC
// secret, so it is guarded by mu. Read it via currentProvider().
mu sync.RWMutex
provider fosite.OAuth2Provider
store *MemoryStore
privateKey *rsa.PrivateKey
retrievedSettings *settings.Settings
rotator *security.SecretRotator
}

func NewAuthenticator(retrievedSettings *settings.Settings, persistence db.DataStore) *Authenticator {
Expand Down Expand Up @@ -79,11 +85,6 @@ func NewAuthenticator(retrievedSettings *settings.Settings, persistence db.DataS
}
}

secret := []byte("some-cool-secret-that-is-32bytes")
config := &fosite.Config{
AccessTokenLifespan: time.Minute * 30,
GlobalSecret: secret,
}
privateKey, _ := rsa.GenerateKey(rand.Reader, 2048)
privateKey, err := loadOrCreatePrivateKey(persistence, privateKey)
if err != nil {
Expand All @@ -93,14 +94,75 @@ func NewAuthenticator(retrievedSettings *settings.Settings, persistence db.DataS
log.Fatalf("Error loading oidc store snapshot: %v", err)
}

var oauth2 = compose.ComposeAllEnabled(config, store, privateKey)

return &Authenticator{
provider: oauth2,
a := &Authenticator{
store: store,
privateKey: privateKey,
retrievedSettings: retrievedSettings,
}

// The fosite GlobalSecret (used to HMAC short-lived artifacts such as
// authorize codes) was previously hard-coded. It is now a randomly
// generated, database-persisted secret that rotates on the configured
// cookie key-rotation interval. Old secrets remain valid for verification
// for the session lifetime so in-flight artifacts keep working across a
// rotation. See lib/security/secretrotator.go.
interval := time.Duration(retrievedSettings.Cookie.KeyRotationInterval) * time.Millisecond
if interval <= 0 {
interval = 24 * time.Hour
}
lifetime := time.Duration(retrievedSettings.Cookie.SessionLifetime) * time.Millisecond
if lifetime <= 0 {
lifetime = interval
}
a.rotator = security.NewSecretRotator(persistence, "oidc_global_secret", interval, lifetime, nil, nil)
a.rotator.OnRotate(a.rebuildProvider)
if err := a.rotator.Start(); err != nil {
log.Fatalf("Error starting oidc secret rotator: %v", err)
}
// Start triggers the first update which fires OnRotate -> rebuildProvider,
// but guard against an empty provider just in case.
if a.currentProvider() == nil {
a.rebuildProvider()
}
return a
}

// rebuildProvider composes a fresh OAuth2 provider using the rotator's current
// secrets. A brand-new fosite.Config is built each time (never mutated in
// place) so that requests holding an older provider keep reading a consistent
// secret. Invoked on startup and on every rotation.
func (a *Authenticator) rebuildProvider() {
secrets := a.rotator.Secrets()
var global []byte
var rotated [][]byte
if len(secrets) > 0 {
global = secrets[0]
rotated = secrets[1:]
}
cfg := &fosite.Config{
AccessTokenLifespan: time.Minute * 30,
GlobalSecret: global,
RotatedGlobalSecrets: rotated,
}
prov := compose.ComposeAllEnabled(cfg, a.store, a.privateKey)
a.mu.Lock()
a.provider = prov
a.mu.Unlock()
}

// currentProvider returns the active OAuth2 provider. Capture it once per
// request so a concurrent rotation cannot swap the provider mid-handler.
func (a *Authenticator) currentProvider() fosite.OAuth2Provider {
a.mu.RLock()
defer a.mu.RUnlock()
return a.provider
}

// Stop halts the background secret rotation. Call during server shutdown.
func (a *Authenticator) Stop() {
if a.rotator != nil {
a.rotator.Stop()
}
}

func (a *Authenticator) ValidateAdminToken(tokenString string, adminClient *settings.SSOClient) (bool, error) {
Expand Down Expand Up @@ -179,26 +241,28 @@ func (a *Authenticator) JwksEndpoint(rw http.ResponseWriter, req *http.Request)

func (a *Authenticator) IntrospectionEndpoint(rw http.ResponseWriter, req *http.Request) {
ctx := req.Context()
provider := a.currentProvider()
mySessionData := a.newSession(nil, "")
ir, err := a.provider.NewIntrospectionRequest(ctx, req, mySessionData)
ir, err := provider.NewIntrospectionRequest(ctx, req, mySessionData)
if err != nil {
log.Printf("Error occurred in NewIntrospectionRequest: %+v", err)
a.provider.WriteIntrospectionError(ctx, rw, err)
provider.WriteIntrospectionError(ctx, rw, err)
return
}
a.provider.WriteIntrospectionResponse(ctx, rw, ir)
provider.WriteIntrospectionResponse(ctx, rw, ir)
}

func (a *Authenticator) TokenEndpoint(rw http.ResponseWriter, req *http.Request) {
ctx := req.Context()
provider := a.currentProvider()
clientId := req.Form.Get("client_id")

mySessionData := a.newSession(nil, clientId)

accessRequest, err := a.provider.NewAccessRequest(ctx, req, mySessionData)
accessRequest, err := provider.NewAccessRequest(ctx, req, mySessionData)
if err != nil {
log.Printf("Error occurred in NewAccessRequest: %+v", err)
a.provider.WriteAccessError(ctx, rw, accessRequest, err)
provider.WriteAccessError(ctx, rw, accessRequest, err)
return
}

Expand All @@ -208,25 +272,26 @@ func (a *Authenticator) TokenEndpoint(rw http.ResponseWriter, req *http.Request)
}
}

response, err := a.provider.NewAccessResponse(ctx, accessRequest)
response, err := provider.NewAccessResponse(ctx, accessRequest)
if err != nil {
log.Printf("Error occurred in NewAccessResponse: %+v", err)
a.provider.WriteAccessError(ctx, rw, accessRequest, err)
provider.WriteAccessError(ctx, rw, accessRequest, err)
return
}

a.provider.WriteAccessResponse(ctx, rw, accessRequest, response)
provider.WriteAccessResponse(ctx, rw, accessRequest, response)
}

func (a *Authenticator) RevokeEndpoint(rw http.ResponseWriter, req *http.Request) {
// This context will be passed to all methods.
ctx := req.Context()
provider := a.currentProvider()

// This will accept the token revocation request and validate various parameters.
err := a.provider.NewRevocationRequest(ctx, req)
err := provider.NewRevocationRequest(ctx, req)

// All done, send the response.
a.provider.WriteRevocationResponse(ctx, rw, err)
provider.WriteRevocationResponse(ctx, rw, err)
}

func (a *Authenticator) OicWellKnown(rw http.ResponseWriter, req *http.Request, retrievedSettings *settings.Settings) {
Expand Down Expand Up @@ -320,12 +385,13 @@ func renderLoginPage(rw http.ResponseWriter, req *http.Request, clients []settin

func (a *Authenticator) AuthEndpoint(rw http.ResponseWriter, req *http.Request, setupLogger *zap.SugaredLogger, retrievedSettings *settings.Settings) {
ctx := req.Context()
provider := a.currentProvider()
req.ParseForm()

ar, err := a.provider.NewAuthorizeRequest(ctx, req)
ar, err := provider.NewAuthorizeRequest(ctx, req)
if err != nil {
setupLogger.Error("Error occurred in NewAuthorizeRequest: ", err)
a.provider.WriteAuthorizeError(ctx, rw, ar, err)
provider.WriteAuthorizeError(ctx, rw, ar, err)
return
}
hydrateAuthorizeRequestForm(req, ar)
Expand All @@ -352,13 +418,13 @@ func (a *Authenticator) AuthEndpoint(rw http.ResponseWriter, req *http.Request,
}

mySessionData := a.newSession(&user, clientId)
response, err := a.provider.NewAuthorizeResponse(ctx, ar, mySessionData)
response, err := provider.NewAuthorizeResponse(ctx, ar, mySessionData)
if err != nil {
log.Printf("Error occurred in NewAuthorizeResponse: %+v", err)
a.provider.WriteAuthorizeError(ctx, rw, ar, err)
provider.WriteAuthorizeError(ctx, rw, ar, err)
return
}
a.provider.WriteAuthorizeResponse(ctx, rw, ar, response)
provider.WriteAuthorizeResponse(ctx, rw, ar, response)
}

func (a *Authenticator) newSession(user *MemoryUserRelation, clientId string) *openid.DefaultSession {
Expand Down
14 changes: 14 additions & 0 deletions lib/db/DataStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,19 @@ type OIDCMethods interface {
DeleteOIDCSession(signature string) error
}

// SecretMethods back the SecretRotator. Each published parameter set is
// addressed by a random id within a rotator namespace (prefix), allowing
// several rotators (and several Etherpad instances) to coexist in one table.
type SecretMethods interface {
// SaveSecretParams upserts one published parameter set.
SaveSecretParams(id string, prefix string, payload string) error
// ListSecretParams returns all parameter sets for the given prefix as a
// map of id -> payload.
ListSecretParams(prefix string) (map[string]string, error)
// DeleteSecretParams removes a single parameter set by id.
DeleteSecretParams(id string) error
}

type DataStore interface {
PadMethods
AuthorMethods
Expand All @@ -128,6 +141,7 @@ type DataStore interface {
ChatMethods
ServerMethods
OIDCMethods
SecretMethods
Close() error
Ping() error
}
29 changes: 29 additions & 0 deletions lib/db/MemoryDataStore.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type MemoryDataStore struct {
groupStore map[string]string
serverVersion *db.ServerVersion
oidcStorage map[string]string
secretParams map[string]memorySecretRow

// oidc
accessTokens map[string]fosite.Requester
Expand Down Expand Up @@ -671,6 +672,33 @@ func (m *MemoryDataStore) DeleteOIDCStorageValue(key string) error {
return nil
}

// ============== SECRET ROTATION ==============

type memorySecretRow struct {
prefix string
payload string
}

func (m *MemoryDataStore) SaveSecretParams(id string, prefix string, payload string) error {
m.secretParams[id] = memorySecretRow{prefix: prefix, payload: payload}
return nil
}

func (m *MemoryDataStore) ListSecretParams(prefix string) (map[string]string, error) {
result := make(map[string]string)
for id, row := range m.secretParams {
if row.prefix == prefix {
result[id] = row.payload
}
}
return result, nil
}

func (m *MemoryDataStore) DeleteSecretParams(id string) error {
delete(m.secretParams, id)
return nil
}

// ============== OAUTH TOKEN TABLE METHODS ==============

// Access tokens
Expand Down Expand Up @@ -892,6 +920,7 @@ func NewMemoryDataStore() *MemoryDataStore {
sessionStore: make(map[string]session2.Session),
groupStore: make(map[string]string),
oidcStorage: make(map[string]string),
secretParams: make(map[string]memorySecretRow),
accessTokens: make(map[string]fosite.Requester),
accessTokenRequestIDs: make(map[string]string),
refreshTokens: make(map[string]db.StoreRefreshToken),
Expand Down
53 changes: 53 additions & 0 deletions lib/db/MySQLDB.go
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,59 @@ func (d MysqlDB) DeleteOIDCStorageValue(key string) error {
return err
}

// ============== SECRET ROTATION TABLE METHODS ==============

func (d MysqlDB) SaveSecretParams(id string, prefix string, payload string) error {
resultedSQL, args, err := mysql.
Insert("secret_rotation").
Columns("id", "prefix", "payload").
Values(id, prefix, payload).
Suffix("ON DUPLICATE KEY UPDATE prefix = VALUES(prefix), payload = VALUES(payload)").
ToSql()
if err != nil {
return err
}
_, err = d.sqlDB.Exec(resultedSQL, args...)
return err
}

func (d MysqlDB) ListSecretParams(prefix string) (map[string]string, error) {
resultedSQL, args, err := mysql.
Select("id", "payload").
From("secret_rotation").
Where(sq.Eq{"prefix": prefix}).
ToSql()
if err != nil {
return nil, err
}
rows, err := d.sqlDB.Query(resultedSQL, args...)
if err != nil {
return nil, err
}
defer rows.Close()
result := make(map[string]string)
for rows.Next() {
var id, payload string
if err := rows.Scan(&id, &payload); err != nil {
return nil, err
}
result[id] = payload
}
return result, rows.Err()
}

func (d MysqlDB) DeleteSecretParams(id string) error {
resultedSQL, args, err := mysql.
Delete("secret_rotation").
Where(sq.Eq{"id": id}).
ToSql()
if err != nil {
return err
}
_, err = d.sqlDB.Exec(resultedSQL, args...)
return err
}

// ============== OAUTH TOKEN TABLE METHODS ==============

// Access tokens
Expand Down
Loading
Loading