Skip to content

Commit

Permalink
Merge pull request #3 from cloudfoundry-community/main
Browse files Browse the repository at this point in the history
Performance and auth improvements
  • Loading branch information
wayneeseguin authored Jun 14, 2024
2 parents b0a7655 + 064a24f commit 810e867
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 61 deletions.
92 changes: 89 additions & 3 deletions cf/auth_service.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cf

import (
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"os"
Expand All @@ -11,66 +13,105 @@ import (
"golang.org/x/oauth2"
)

type Logger interface {
Info(tag, message string)
Error(tag, message string)
}

type AuthService struct {
client *cf.Client
logger Logger
}

func NewAuthService(client *cf.Client) *AuthService {
func NewAuthService(client *cf.Client, logger Logger) *AuthService {
return &AuthService{
client: client,
logger: logger,
}
}

func (service *AuthService) Verify(auth string) error {
tag := "AuthService.Verify"
service.logger.Info(tag, "Starting verification process")
username, err := getUsername(auth)
if err != nil {
service.logger.Error(tag, fmt.Sprintf("Error getting username: %v", err))
return err
}
service.logger.Info(tag, fmt.Sprintf("Username obtained: %s", username))

user, err := service.getUser(username)
if err != nil {
service.logger.Error(tag, fmt.Sprintf("Error getting user: %v", err))
return err
}
// Debugging = noisy
// service.logger.Info(tag, fmt.Sprintf("User obtained: %v", user))

roles, err := service.getUserRoles(user)
if err != nil {
service.logger.Error(tag, fmt.Sprintf("Error getting user roles: %v", err))
return err
}
// Debugging = noisy
// service.logger.Info(tag, fmt.Sprintf("User roles obtained: %v", roles))

tokenScopes, err := getTokenScopes(auth, service.logger)
if err != nil {
service.logger.Error(tag, fmt.Sprintf("Error getting token scopes: %v", err))
return err
}
service.logger.Info(tag, fmt.Sprintf("Token scopes obtained: %v", tokenScopes))

// Check all the roles, but return good early if we find one that works.

// Check token scopes for cloud_controller.admin
for _, scope := range tokenScopes {
if scope == "cloud_controller.admin" {
service.logger.Info(tag, "User has cloud_controller.admin scope")
return nil
}
}

// Check CF roles for space_manager or space_developer
for _, role := range roles {
// NOTE: we should definitely be checking space IDs, too, but that's tomorrow
// guy's problem.
if role.Type == "space_manager" || role.Type == "space_developer" {
service.logger.Info(tag, fmt.Sprintf("User has role: %s", role.Type))
return nil
}
}

service.logger.Error(tag, "User does not have sufficient permissions")
return fmt.Errorf("insufficient permissions")
}

func (service *AuthService) getUser(username string) (cf.User, error) {
tag := "AuthService.getUser"
query := url.Values{}
query.Add("username", username)

users, err := service.client.ListUsersByQuery(query)
if err != nil {
service.logger.Error(tag, fmt.Sprintf("Error listing users by query: %v", err))
return cf.User{}, err
}

user := users.GetUserByUsername(username)
if len(user.Guid) == 0 {
service.logger.Error(tag, "No such user found")
return cf.User{}, fmt.Errorf("no such user")
}

return user, nil
}

func (service *AuthService) getUserRoles(user cf.User) ([]cf.V3Role, error) {
tag := "AuthService.getUserRoles"
roleQuery := url.Values{}
roleQuery.Add("user_guids", user.Guid)
roles, err := service.client.ListV3RolesByQuery(roleQuery)
if err != nil {
service.logger.Error(tag, fmt.Sprintf("Error listing V3 roles by query: %v", err))
return nil, err
}

Expand Down Expand Up @@ -125,3 +166,48 @@ func getBearer(auth string) (string, error) {

return parts[bearerLoc+1], nil
}

// JWTClaims represents the claims in the JWT token
type JWTClaims struct {
Scope []string `json:"scope"`
}

// DecodeJWT decodes the JWT token and extracts the claims
func DecodeJWT(token string) (*JWTClaims, error) {
parts := strings.Split(token, ".")
if len(parts) < 3 {
return nil, fmt.Errorf("invalid token format")
}

payload := parts[1]
payloadDecoded, err := base64.RawURLEncoding.DecodeString(payload)
if err != nil {
return nil, fmt.Errorf("failed to decode payload: %v", err)
}

var claims JWTClaims
err = json.Unmarshal(payloadDecoded, &claims)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal claims: %v", err)
}

return &claims, nil
}

func getTokenScopes(auth string, logger Logger) ([]string, error) {
tag := "AuthService.getTokenScopes"
bearer, err := getBearer(auth)
if err != nil {
logger.Error(tag, fmt.Sprintf("Error getting bearer token: %v", err))
return nil, err
}

claims, err := DecodeJWT(bearer)
if err != nil {
logger.Error(tag, fmt.Sprintf("Error decoding JWT token: %v", err))
return nil, err
}

logger.Info(tag, fmt.Sprintf("Scopes found in token: %v", claims.Scope))
return claims.Scope, nil
}
42 changes: 38 additions & 4 deletions cmd/scheduler/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func main() {

log.Info(tag, "got the cf client set up")

auth := cf.NewAuthService(cfclient)
auth := cf.NewAuthService(cfclient, log)
jobs := postgres.NewJobService(db)
calls := postgres.NewCallService(db)
info := cf.NewInfoService(cfclient)
Expand Down Expand Up @@ -168,6 +168,38 @@ func main() {
}
}

// Schedule cleanup tasks
retentionPeriod := 5 * 30 * 24 * time.Hour // 30 days
ticker := time.NewTicker(24 * time.Hour) // run cleanup every day
quit := make(chan struct{})
signalChan := make(chan os.Signal, 1) // New channel for OS signals

go func() {
for {
select {
case <-ticker.C:
log.Info(tag, "Starting cleanup of old executions")

deletedExecutions, err := executions.CleanupOldExecutions(retentionPeriod)
if err != nil {
log.Error(tag, fmt.Sprintf("Error cleaning up old executions: %s", err.Error()))
} else {
if len(deletedExecutions) == 0 {
log.Info(tag, "No executions were deleted.")
} else {
for _, execution := range deletedExecutions {
log.Info(tag, fmt.Sprintf("Deleted execution: %s, Last Updated: %s", execution["guid"], execution["execution_end_time"]))
}
}
}

case <-quit:
ticker.Stop()
return
}
}
}()

server := http.Server(fmt.Sprintf("0.0.0.0:%d", port), services)

go func() {
Expand All @@ -178,10 +210,12 @@ func main() {

log.Info(tag, fmt.Sprintf("listening for connections on %s", server.Addr))

quit := make(chan os.Signal)
signal.Notify(quit, os.Interrupt)
signal.Notify(signalChan, os.Interrupt) // Use signalChan for signal notification

<-signalChan // Wait for signal

<-quit
// Stop the cleanup ticker
close(quit)

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
Expand Down
2 changes: 1 addition & 1 deletion core/jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type JobService interface {
Delete(*Job) error
Named(string) (*Job, error)
Persist(*Job) (*Job, error)
InSpace(string) []*Job
InSpace(string) ([]*Job, error)
Success(*Job) (*Job, error)
Fail(*Job) (*Job, error)
}
19 changes: 16 additions & 3 deletions cron/cron_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ type CronService struct {

func NewCronService(log core.LogService) *CronService {
return &CronService{
cron.New(),
log,
make(map[string]cron.EntryID)}
Cron: cron.New(),
log: log,
mapping: make(map[string]cron.EntryID),
}
}

func (service *CronService) Add(runnable core.Runnable) error {
Expand Down Expand Up @@ -60,6 +61,7 @@ func (service *CronService) Add(runnable core.Runnable) error {
}

service.mapping[schedule.GUID] = id
service.logMappingSize("Added job to cron service")

return nil
}
Expand All @@ -71,6 +73,8 @@ func (service *CronService) Delete(runnable core.Runnable) error {
}

service.Remove(id)
delete(service.mapping, runnable.Schedule().GUID)
service.logMappingSize("Deleted job from cron service")

return nil
}
Expand All @@ -84,3 +88,12 @@ func (service *CronService) Validate(expression string) error {

return err
}

func (service *CronService) MappingSize() int {
return len(service.mapping)
}

func (service *CronService) logMappingSize(action string) {
size := service.MappingSize()
service.log.Info("cron-service", fmt.Sprintf("%s: current mapping size is %d", action, size))
}
Loading

0 comments on commit 810e867

Please sign in to comment.