Skip to content

Commit 006c55b

Browse files
feat(RHOAIENG-26480): Run RayJobs against existing RayClusters
1 parent 848b565 commit 006c55b

16 files changed

+942
-665
lines changed

poetry.lock

Lines changed: 701 additions & 661 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ cryptography = "43.0.3"
2929
executing = "1.2.0"
3030
pydantic = "< 2"
3131
ipywidgets = "8.1.2"
32+
odh-kuberay-client = {version = "0.0.0.dev40", source = "testpypi"}
33+
34+
[[tool.poetry.source]]
35+
name = "pypi"
36+
37+
[[tool.poetry.source]]
38+
name = "testpypi"
39+
url = "https://test.pypi.org/simple/"
3240

3341
[tool.poetry.group.docs]
3442
optional = true

src/codeflare_sdk/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
AWManager,
1111
AppWrapperStatus,
1212
RayJobClient,
13+
RayJob,
1314
)
1415

1516
from .common.widgets import view_clusters

src/codeflare_sdk/ray/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
RayJobClient,
55
)
66

7+
from .rayjobs import (
8+
RayJob,
9+
)
10+
711
from .cluster import (
812
Cluster,
913
ClusterConfiguration,

src/codeflare_sdk/ray/cluster/build_ray_cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
136136
"enableIngress": False,
137137
"rayStartParams": {
138138
"dashboard-host": "0.0.0.0",
139+
"dashboard-port": "8265",
139140
"block": "true",
140141
"num-gpus": str(head_gpu_count),
141142
"resources": head_resources,
@@ -245,6 +246,7 @@ def get_labels(cluster: "codeflare_sdk.ray.cluster.Cluster"):
245246
"""
246247
labels = {
247248
"controller-tools.k8s.io": "1.0",
249+
"ray.io/cluster": cluster.config.name, # Enforced label always present
248250
}
249251
if cluster.config.labels != {}:
250252
labels.update(cluster.config.labels)

src/codeflare_sdk/ray/cluster/cluster.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,12 @@
2020

2121
from time import sleep
2222
from typing import List, Optional, Tuple, Dict
23+
import copy
2324

24-
from ray.job_submission import JobSubmissionClient
25+
from ray.job_submission import JobSubmissionClient, JobStatus
26+
import time
27+
import uuid
28+
import warnings
2529

2630
from ...common.kubernetes_cluster.auth import (
2731
config_check,
@@ -57,7 +61,6 @@
5761
from kubernetes.client.rest import ApiException
5862

5963
from kubernetes.client.rest import ApiException
60-
import warnings
6164

6265
CF_SDK_FIELD_MANAGER = "codeflare-sdk"
6366

@@ -760,6 +763,7 @@ def get_cluster(
760763
head_extended_resource_requests=head_extended_resources,
761764
worker_extended_resource_requests=worker_extended_resources,
762765
)
766+
763767
# Ignore the warning here for the lack of a ClusterConfiguration
764768
with warnings.catch_warnings():
765769
warnings.filterwarnings(

src/codeflare_sdk/ray/cluster/test_cluster.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -758,5 +758,11 @@ def custom_side_effect(group, version, namespace, plural, **kwargs):
758758

759759
# Make sure to always keep this function last
760760
def test_cleanup():
761-
os.remove(f"{aw_dir}test-all-params.yaml")
762-
os.remove(f"{aw_dir}aw-all-params.yaml")
761+
# Remove files only if they exist
762+
test_file = f"{aw_dir}test-all-params.yaml"
763+
if os.path.exists(test_file):
764+
os.remove(test_file)
765+
766+
aw_file = f"{aw_dir}aw-all-params.yaml"
767+
if os.path.exists(aw_file):
768+
os.remove(aw_file)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .rayjob import RayJob
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""
2+
RayJob client for submitting and managing Ray jobs using the odh-kuberay-client.
3+
"""
4+
5+
import logging
6+
from typing import Dict, Any, Optional
7+
from odh_kuberay_client.kuberay_job_api import RayjobApi
8+
9+
# Set up logging
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class RayJob:
14+
"""
15+
A client for managing Ray jobs using the KubeRay operator.
16+
17+
This class provides a simplified interface for submitting and managing
18+
Ray jobs in a Kubernetes cluster with the KubeRay operator installed.
19+
"""
20+
21+
def __init__(
22+
self,
23+
job_name: str,
24+
cluster_name: str,
25+
namespace: str = "default",
26+
entrypoint: str = "None",
27+
runtime_env: Optional[Dict[str, Any]] = None,
28+
):
29+
"""
30+
Initialize a RayJob instance.
31+
32+
Args:
33+
name: The name for the Ray job
34+
namespace: The Kubernetes namespace to submit the job to (default: "default")
35+
cluster_name: The name of the Ray cluster to submit the job to
36+
**kwargs: Additional configuration options
37+
"""
38+
self.name = job_name
39+
self.namespace = namespace
40+
self.cluster_name = cluster_name
41+
self.entrypoint = entrypoint
42+
self.runtime_env = runtime_env
43+
44+
# Initialize the KubeRay job API client
45+
self._api = RayjobApi()
46+
47+
logger.info(f"Initialized RayJob: {self.name} in namespace: {self.namespace}")
48+
49+
def submit(
50+
self,
51+
) -> str:
52+
"""
53+
Submit the Ray job to the Kubernetes cluster.
54+
55+
Args:
56+
entrypoint: The Python script or command to run
57+
runtime_env: Ray runtime environment configuration (optional)
58+
59+
Returns:
60+
The job ID/name if submission was successful
61+
62+
Raises:
63+
RuntimeError: If the job has already been submitted or submission fails
64+
"""
65+
# Build the RayJob custom resource
66+
rayjob_cr = self._build_rayjob_cr(
67+
entrypoint=self.entrypoint,
68+
runtime_env=self.runtime_env,
69+
)
70+
71+
# Submit the job
72+
logger.info(
73+
f"Submitting RayJob {self.name} to RayCluster {self.cluster_name} in namespace {self.namespace}"
74+
)
75+
result = self._api.submit_job(k8s_namespace=self.namespace, job=rayjob_cr)
76+
77+
if result:
78+
logger.info(f"Successfully submitted RayJob {self.name}")
79+
return self.name
80+
else:
81+
raise RuntimeError(f"Failed to submit RayJob {self.name}")
82+
83+
def _build_rayjob_cr(
84+
self,
85+
entrypoint: str,
86+
runtime_env: Optional[Dict[str, Any]] = None,
87+
) -> Dict[str, Any]:
88+
"""
89+
Build the RayJob custom resource specification.
90+
91+
This creates a minimal RayJob CR that can be extended later.
92+
"""
93+
# Basic RayJob custom resource structure
94+
rayjob_cr = {
95+
"apiVersion": "ray.io/v1",
96+
"kind": "RayJob",
97+
"metadata": {
98+
"name": self.name,
99+
"namespace": self.namespace,
100+
},
101+
"spec": {
102+
"entrypoint": entrypoint,
103+
"clusterSelector": {"ray.io/cluster": self.cluster_name},
104+
},
105+
}
106+
107+
# Add runtime environment if specified
108+
if runtime_env:
109+
rayjob_cr["spec"]["runtimeEnvYAML"] = str(runtime_env)
110+
111+
return rayjob_cr
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2024 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from unittest.mock import MagicMock
17+
from codeflare_sdk.ray.rayjobs.rayjob import RayJob
18+
19+
20+
def test_rayjob_submit_success(mocker):
21+
"""Test successful RayJob submission."""
22+
# Mock kubernetes config loading
23+
mocker.patch("kubernetes.config.load_kube_config")
24+
25+
# Mock the RayjobApi class entirely
26+
mock_api_class = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi")
27+
mock_api_instance = MagicMock()
28+
mock_api_class.return_value = mock_api_instance
29+
30+
# Configure the mock to return success when submit is called
31+
mock_api_instance.submit.return_value = {"metadata": {"name": "test-rayjob"}}
32+
33+
# Create RayJob instance
34+
rayjob = RayJob(
35+
job_name="test-rayjob",
36+
cluster_name="test-ray-cluster",
37+
namespace="test-namespace",
38+
entrypoint="python -c 'print(\"hello world\")'",
39+
runtime_env={"pip": ["requests"]},
40+
)
41+
42+
# Submit the job
43+
job_id = rayjob.submit()
44+
45+
# Assertions
46+
assert job_id == "test-rayjob"
47+
48+
# Verify the API was called with correct parameters
49+
mock_api_instance.submit_job.assert_called_once()
50+
call_args = mock_api_instance.submit_job.call_args
51+
52+
# Check the namespace parameter
53+
assert call_args.kwargs["k8s_namespace"] == "test-namespace"
54+
55+
# Check the job custom resource
56+
job_cr = call_args.kwargs["job"]
57+
assert job_cr["metadata"]["name"] == "test-rayjob"
58+
assert job_cr["metadata"]["namespace"] == "test-namespace"
59+
assert job_cr["spec"]["entrypoint"] == "python -c 'print(\"hello world\")'"
60+
assert job_cr["spec"]["clusterSelector"]["ray.io/cluster"] == "test-ray-cluster"
61+
assert job_cr["spec"]["runtimeEnvYAML"] == "{'pip': ['requests']}"
62+
63+
64+
def test_rayjob_submit_failure(mocker):
65+
"""Test RayJob submission failure."""
66+
# Mock kubernetes config loading
67+
mocker.patch("kubernetes.config.load_kube_config")
68+
69+
# Mock the RayjobApi class entirely
70+
mock_api_class = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi")
71+
mock_api_instance = MagicMock()
72+
mock_api_class.return_value = mock_api_instance
73+
74+
# Configure the mock to return failure (False/None) when submit_job is called
75+
mock_api_instance.submit_job.return_value = None
76+
77+
# Create a RayJob instance
78+
rayjob = RayJob(
79+
job_name="test-rayjob",
80+
cluster_name="test-ray-cluster",
81+
namespace="default",
82+
entrypoint="python script.py",
83+
runtime_env={"pip": ["numpy"]},
84+
)
85+
86+
# Test that RuntimeError is raised on failure
87+
with pytest.raises(RuntimeError, match="Failed to submit RayJob test-rayjob"):
88+
rayjob.submit()

0 commit comments

Comments
 (0)