Skip to content

Commit 4069033

Browse files
committed
create simple worker pool and add AsyncGetJobInfo
Signed-off-by: You-Cheng Lin (Owen) <[email protected]>
1 parent 97425fa commit 4069033

File tree

11 files changed

+122
-21
lines changed

11 files changed

+122
-21
lines changed

apiserver/pkg/server/ray_job_submission_service_server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type RayJobSubmissionServiceServer struct {
4141
// Create RayJobSubmissionServiceServer
4242
func NewRayJobSubmissionServiceServer(clusterServer *ClusterServer, options *RayJobSubmissionServiceServerOptions) *RayJobSubmissionServiceServer {
4343
zl := zerolog.New(os.Stdout).Level(zerolog.DebugLevel)
44-
return &RayJobSubmissionServiceServer{clusterServer: clusterServer, options: options, log: zerologr.New(&zl).WithName("jobsubmissionservice"), dashboardClientFunc: utils.GetRayDashboardClientFunc(nil, false)}
44+
return &RayJobSubmissionServiceServer{clusterServer: clusterServer, options: options, log: zerologr.New(&zl).WithName("jobsubmissionservice"), dashboardClientFunc: utils.GetRayDashboardClientFunc(nil, false, nil, nil)}
4545
}
4646

4747
// Submit Ray job

ray-operator/apis/config/v1alpha1/configuration_types.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package v1alpha1
22

33
import (
4+
"sync"
5+
46
corev1 "k8s.io/api/core/v1"
57
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
68
"sigs.k8s.io/controller-runtime/pkg/manager"
@@ -85,8 +87,8 @@ type Configuration struct {
8587
EnableMetrics bool `json:"enableMetrics,omitempty"`
8688
}
8789

88-
func (config Configuration) GetDashboardClient(mgr manager.Manager) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
89-
return utils.GetRayDashboardClientFunc(mgr, config.UseKubernetesProxy)
90+
func (config Configuration) GetDashboardClient(mgr manager.Manager, taskQueue chan func(), jobInfoMap *sync.Map) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
91+
return utils.GetRayDashboardClientFunc(mgr, config.UseKubernetesProxy, taskQueue, jobInfoMap)
9092
}
9193

9294
func (config Configuration) GetHttpProxyClient(mgr manager.Manager) func(hostIp, podNamespace, podName string, port int) utils.RayHttpProxyClientInterface {

ray-operator/controllers/ray/rayjob_controller.go

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"os"
77
"strconv"
88
"strings"
9+
"sync"
910
"time"
1011

1112
"github.com/go-logr/logr"
@@ -29,6 +30,7 @@ import (
2930
"github.com/ray-project/kuberay/ray-operator/controllers/ray/metrics"
3031
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
3132
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils/dashboardclient"
33+
utiltypes "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils/types"
3234
"github.com/ray-project/kuberay/ray-operator/pkg/features"
3335
)
3436

@@ -40,11 +42,12 @@ const (
4042
// RayJobReconciler reconciles a RayJob object
4143
type RayJobReconciler struct {
4244
client.Client
43-
Scheme *runtime.Scheme
44-
Recorder record.EventRecorder
45-
45+
Scheme *runtime.Scheme
46+
Recorder record.EventRecorder
47+
JobInfoMap *sync.Map
4648
dashboardClientFunc func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error)
4749
options RayJobReconcilerOptions
50+
workerPool *dashboardclient.WorkerPool
4851
}
4952

5053
type RayJobReconcilerOptions struct {
@@ -53,13 +56,18 @@ type RayJobReconcilerOptions struct {
5356

5457
// NewRayJobReconciler returns a new reconcile.Reconciler
5558
func NewRayJobReconciler(_ context.Context, mgr manager.Manager, options RayJobReconcilerOptions, provider utils.ClientProvider) *RayJobReconciler {
56-
dashboardClientFunc := provider.GetDashboardClient(mgr)
59+
taskQueue := make(chan func(), 1000)
60+
JobInfoMap := &sync.Map{}
61+
workerPool := dashboardclient.NewWorkerPool(taskQueue)
62+
dashboardClientFunc := provider.GetDashboardClient(mgr, taskQueue, JobInfoMap)
5763
return &RayJobReconciler{
5864
Client: mgr.GetClient(),
5965
Scheme: mgr.GetScheme(),
6066
Recorder: mgr.GetEventRecorderFor("rayjob-controller"),
67+
JobInfoMap: JobInfoMap,
6168
dashboardClientFunc: dashboardClientFunc,
6269
options: options,
70+
workerPool: workerPool,
6371
}
6472
}
6573

@@ -263,9 +271,12 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request)
263271
if err != nil {
264272
return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err
265273
}
266-
267-
jobInfo, err := rayDashboardClient.GetJobInfo(ctx, rayJobInstance.Status.JobId)
268-
if err != nil {
274+
var jobInfo *utiltypes.RayJobInfo
275+
if loadedJobInfo, ok := r.JobInfoMap.Load(rayJobInstance.Status.JobId); ok {
276+
logger.Info("Found jobInfo in map", "JobId", rayJobInstance.Status.JobId, "jobInfo", loadedJobInfo)
277+
jobInfo = loadedJobInfo.(*utiltypes.RayJobInfo)
278+
logger.Info("Casted jobInfo", "JobId", rayJobInstance.Status.JobId, "jobInfo", jobInfo)
279+
} else {
269280
// If the Ray job was not found, GetJobInfo returns a BadRequest error.
270281
if rayJobInstance.Spec.SubmissionMode == rayv1.HTTPMode && errors.IsBadRequest(err) {
271282
logger.Info("The Ray job was not found. Submit a Ray job via an HTTP request.", "JobId", rayJobInstance.Status.JobId)
@@ -275,10 +286,16 @@ func (r *RayJobReconciler) Reconcile(ctx context.Context, request ctrl.Request)
275286
}
276287
return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, nil
277288
}
278-
logger.Error(err, "Failed to get job info", "JobId", rayJobInstance.Status.JobId)
289+
logger.Info("Job info not found in map", "JobId", rayJobInstance.Status.JobId)
290+
rayDashboardClient.AsyncGetJobInfo(ctx, rayJobInstance.Status.JobId)
279291
return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err
280292
}
281293

294+
rayDashboardClient.AsyncGetJobInfo(ctx, rayJobInstance.Status.JobId)
295+
if jobInfo == nil {
296+
logger.Error(err, "Failed to get job info", "JobId", rayJobInstance.Status.JobId)
297+
return ctrl.Result{RequeueAfter: RayJobDefaultRequeueDuration}, err
298+
}
282299
// If the JobStatus is in a terminal status, such as SUCCEEDED, FAILED, or STOPPED, it is impossible for the Ray job
283300
// to transition to any other. Additionally, RayJob does not currently support retries. Hence, we can mark the RayJob
284301
// as "Complete" or "Failed" to avoid unnecessary reconciliation.

ray-operator/controllers/ray/rayservice_controller.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ type RayServiceReconciler struct {
6060

6161
// NewRayServiceReconciler returns a new reconcile.Reconciler
6262
func NewRayServiceReconciler(_ context.Context, mgr manager.Manager, provider utils.ClientProvider) *RayServiceReconciler {
63-
dashboardClientFunc := provider.GetDashboardClient(mgr)
63+
dashboardClientFunc := provider.GetDashboardClient(mgr, nil, nil)
6464
httpProxyClientFunc := provider.GetHttpProxyClient(mgr)
6565
return &RayServiceReconciler{
6666
Client: mgr.GetClient(),

ray-operator/controllers/ray/suite_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package ray
1818
import (
1919
"os"
2020
"path/filepath"
21+
"sync"
2122
"testing"
2223

2324
. "github.com/onsi/ginkgo/v2"
@@ -52,7 +53,7 @@ var (
5253

5354
type TestClientProvider struct{}
5455

55-
func (testProvider TestClientProvider) GetDashboardClient(_ manager.Manager) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
56+
func (testProvider TestClientProvider) GetDashboardClient(_ manager.Manager, _ chan func(), _ *sync.Map) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
5657
return func(_ *rayv1.RayCluster, _ string) (dashboardclient.RayDashboardClientInterface, error) {
5758
return fakeRayDashboardClient, nil
5859
}

ray-operator/controllers/ray/utils/dashboardclient/dashboard_httpclient.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net/http"
99
"strings"
10+
"sync"
1011

1112
"k8s.io/apimachinery/pkg/api/errors"
1213
"k8s.io/apimachinery/pkg/util/json"
@@ -25,12 +26,13 @@ var (
2526
)
2627

2728
type RayDashboardClientInterface interface {
28-
InitClient(client *http.Client, dashboardURL string)
29+
InitClient(client *http.Client, dashboardURL string, taskQueue chan func(), jobInfoMap *sync.Map)
2930
UpdateDeployments(ctx context.Context, configJson []byte) error
3031
// V2/multi-app Rest API
3132
GetServeDetails(ctx context.Context) (*utiltypes.ServeDetails, error)
3233
GetMultiApplicationStatus(context.Context) (map[string]*utiltypes.ServeApplicationStatus, error)
3334
GetJobInfo(ctx context.Context, jobId string) (*utiltypes.RayJobInfo, error)
35+
AsyncGetJobInfo(ctx context.Context, jobId string)
3436
ListJobs(ctx context.Context) (*[]utiltypes.RayJobInfo, error)
3537
SubmitJob(ctx context.Context, rayJob *rayv1.RayJob) (string, error)
3638
SubmitJobReq(ctx context.Context, request *utiltypes.RayJobRequest) (string, error)
@@ -41,12 +43,16 @@ type RayDashboardClientInterface interface {
4143

4244
type RayDashboardClient struct {
4345
client *http.Client
46+
taskQueue chan func()
47+
jobInfoMap *sync.Map
4448
dashboardURL string
4549
}
4650

47-
func (r *RayDashboardClient) InitClient(client *http.Client, dashboardURL string) {
51+
func (r *RayDashboardClient) InitClient(client *http.Client, dashboardURL string, taskQueue chan func(), jobInfoMap *sync.Map) {
4852
r.client = client
4953
r.dashboardURL = dashboardURL
54+
r.taskQueue = taskQueue
55+
r.jobInfoMap = jobInfoMap
5056
}
5157

5258
// UpdateDeployments update the deployments in the Ray cluster.
@@ -161,6 +167,19 @@ func (r *RayDashboardClient) GetJobInfo(ctx context.Context, jobId string) (*uti
161167
return &jobInfo, nil
162168
}
163169

170+
func (r *RayDashboardClient) AsyncGetJobInfo(ctx context.Context, jobId string) {
171+
r.taskQueue <- func() {
172+
jobInfo, err := r.GetJobInfo(ctx, jobId)
173+
if err != nil {
174+
fmt.Printf("AsyncGetJobInfo: error: %v\n", err)
175+
}
176+
fmt.Printf("AsyncGetJobInfo: jobInfo: %v\n", jobInfo)
177+
if jobInfo != nil {
178+
r.jobInfoMap.Store(jobId, jobInfo)
179+
}
180+
}
181+
}
182+
164183
func (r *RayDashboardClient) ListJobs(ctx context.Context) (*[]utiltypes.RayJobInfo, error) {
165184
req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.dashboardURL+JobPath, nil)
166185
if err != nil {
@@ -211,6 +230,7 @@ func (r *RayDashboardClient) SubmitJobReq(ctx context.Context, request *utiltype
211230
}
212231

213232
req.Header.Set("Content-Type", "application/json")
233+
214234
resp, err := r.client.Do(req)
215235
if err != nil {
216236
return
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package dashboardclient
2+
3+
import (
4+
"sync"
5+
)
6+
7+
type WorkerPool struct {
8+
taskQueue chan func()
9+
stop chan struct{}
10+
wg sync.WaitGroup
11+
workers int
12+
}
13+
14+
func NewWorkerPool(taskQueue chan func()) *WorkerPool {
15+
wp := &WorkerPool{
16+
taskQueue: taskQueue,
17+
workers: 10,
18+
stop: make(chan struct{}),
19+
}
20+
21+
// Start workers immediately
22+
wp.Start()
23+
return wp
24+
}
25+
26+
// Start launches worker goroutines to consume from queue
27+
func (wp *WorkerPool) Start() {
28+
for i := 0; i < wp.workers; i++ {
29+
wp.wg.Add(1)
30+
go wp.worker()
31+
}
32+
}
33+
34+
// worker consumes and executes tasks from the queue
35+
func (wp *WorkerPool) worker() {
36+
defer wp.wg.Done()
37+
38+
for {
39+
select {
40+
case <-wp.stop:
41+
return
42+
case task := <-wp.taskQueue:
43+
if task != nil {
44+
task() // Execute the job
45+
}
46+
}
47+
}
48+
}
49+
50+
// Stop shuts down all workers
51+
func (wp *WorkerPool) Stop() {
52+
close(wp.stop)
53+
wp.wg.Wait()
54+
}

ray-operator/controllers/ray/utils/fake_serve_httpclient.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"net/http"
7+
"sync"
78
"sync/atomic"
89

910
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
@@ -19,7 +20,7 @@ type FakeRayDashboardClient struct {
1920

2021
var _ dashboardclient.RayDashboardClientInterface = (*FakeRayDashboardClient)(nil)
2122

22-
func (r *FakeRayDashboardClient) InitClient(_ *http.Client, _ string) {
23+
func (r *FakeRayDashboardClient) InitClient(_ *http.Client, _ string, _ chan func(), _ *sync.Map) {
2324
}
2425

2526
func (r *FakeRayDashboardClient) UpdateDeployments(_ context.Context, _ []byte) error {
@@ -46,6 +47,9 @@ func (r *FakeRayDashboardClient) GetJobInfo(ctx context.Context, jobId string) (
4647
return &utiltypes.RayJobInfo{JobStatus: rayv1.JobStatusRunning}, nil
4748
}
4849

50+
func (r *FakeRayDashboardClient) AsyncGetJobInfo(_ context.Context, _ string) {
51+
}
52+
4953
func (r *FakeRayDashboardClient) ListJobs(ctx context.Context) (*[]utiltypes.RayJobInfo, error) {
5054
if mock := r.GetJobInfoMock.Load(); mock != nil {
5155
info, err := (*mock)(ctx, "job_id")

ray-operator/controllers/ray/utils/util.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"reflect"
1212
"strconv"
1313
"strings"
14+
"sync"
1415
"time"
1516
"unicode"
1617

@@ -641,7 +642,7 @@ func EnvVarByName(envName string, envVars []corev1.EnvVar) (corev1.EnvVar, bool)
641642
}
642643

643644
type ClientProvider interface {
644-
GetDashboardClient(mgr manager.Manager) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error)
645+
GetDashboardClient(mgr manager.Manager, taskQueue chan func(), jobInfoMap *sync.Map) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error)
645646
GetHttpProxyClient(mgr manager.Manager) func(hostIp, podNamespace, podName string, port int) RayHttpProxyClientInterface
646647
}
647648

@@ -758,7 +759,7 @@ func FetchHeadServiceURL(ctx context.Context, cli client.Client, rayCluster *ray
758759
return headServiceURL, nil
759760
}
760761

761-
func GetRayDashboardClientFunc(mgr manager.Manager, useKubernetesProxy bool) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
762+
func GetRayDashboardClientFunc(mgr manager.Manager, useKubernetesProxy bool, taskQueue chan func(), jobInfoMap *sync.Map) func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
762763
return func(rayCluster *rayv1.RayCluster, url string) (dashboardclient.RayDashboardClientInterface, error) {
763764
dashboardClient := &dashboardclient.RayDashboardClient{}
764765
if useKubernetesProxy {
@@ -777,13 +778,15 @@ func GetRayDashboardClientFunc(mgr manager.Manager, useKubernetesProxy bool) fun
777778
// configured to communicate with the Kubernetes API server.
778779
mgr.GetHTTPClient(),
779780
fmt.Sprintf("%s/api/v1/namespaces/%s/services/%s:dashboard/proxy", mgr.GetConfig().Host, rayCluster.Namespace, headSvcName),
781+
taskQueue,
782+
jobInfoMap,
780783
)
781784
return dashboardClient, nil
782785
}
783786

784787
dashboardClient.InitClient(&http.Client{
785788
Timeout: 2 * time.Second,
786-
}, "http://"+url)
789+
}, "http://"+url, taskQueue, jobInfoMap)
787790
return dashboardClient, nil
788791
}
789792
}

ray-operator/rayjob-submitter/cmd/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func main() {
6464
}
6565
rayDashboardClient := &dashboardclient.RayDashboardClient{}
6666
address = rayjobsubmitter.JobSubmissionURL(address)
67-
rayDashboardClient.InitClient(&http.Client{Timeout: time.Second * 10}, address)
67+
rayDashboardClient.InitClient(&http.Client{Timeout: time.Second * 10}, address, nil, nil)
6868
submissionId, err := rayDashboardClient.SubmitJobReq(context.Background(), &req)
6969
if err != nil {
7070
if strings.Contains(err.Error(), "Please use a different submission_id") {

0 commit comments

Comments
 (0)