
Commit be5b81a

no-jira: add remote offline batch inference with vllm example
1 parent 973d320 commit be5b81a

File tree

3 files changed: +293 -0 lines changed
Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Remote Offline Batch Inference with Ray Data & vLLM Example\n",
"\n",
"This notebook presumes:\n",
"- You are working on OpenShift AI\n",
"- You have been given a Ray Cluster URL on which to run workloads\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from codeflare_sdk import RayJobClient\n",
"\n",
"# Setup Authentication Configuration\n",
"auth_token = \"XXXX\"\n",
"header = {\"Authorization\": f\"Bearer {auth_token}\"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Gather the dashboard URL (provided by the creator of the RayCluster)\n",
"ray_dashboard = \"XXXX\" # Replace with the Ray dashboard URL\n",
"\n",
"# Initialize the RayJobClient\n",
"client = RayJobClient(address=ray_dashboard, headers=header, verify=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Simple Example Explanation\n",
"\n",
"With the RayJobClient instantiated, let's run some batch inference. The following code is stored in `simple_batch_inf.py` and is used as the entrypoint for the RayJob.\n",
"\n",
"What this processor configuration does:\n",
"- Sets up a vLLM engine with your model\n",
"- Configures GPU-friendly settings (half-precision dtype, reduced context length)\n",
"- Defines batch processing parameters (8 requests per batch, 2 GPU workers)\n",
"\n",
"```python\n",
"import ray\n",
"from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig\n",
"\n",
"processor_config = vLLMEngineProcessorConfig(\n",
"    model_source=\"replace-me\",\n",
"    engine_kwargs=dict(\n",
"        enable_lora=False,\n",
"        dtype=\"half\",\n",
"        max_model_len=1024,\n",
"    ),\n",
"    batch_size=8,\n",
"    concurrency=2,\n",
")\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the config defined, we can instantiate the processor. This enables batch inference by processing multiple requests through the vLLM engine, with two key steps:\n",
"- **Preprocess**: Converts each row into a structured chat format with system instructions and user queries, preparing the input for the LLM\n",
"- **Postprocess**: Extracts only the generated text from the model response, cleaning up the output\n",
"\n",
"The processor defines the pipeline that will be applied to each row in the dataset, enabling efficient batch processing through Ray Data's distributed execution framework.\n",
"\n",
"```python\n",
"processor = build_llm_processor(\n",
"    processor_config,\n",
"    preprocess=lambda row: dict(\n",
"        messages=[\n",
"            {\n",
"                \"role\": \"system\",\n",
"                \"content\": \"You are a calculator. Please only output the answer \"\n",
"                \"of the given equation.\",\n",
"            },\n",
"            {\"role\": \"user\", \"content\": f\"{row['id']} ** 3 = ?\"},\n",
"        ],\n",
"        sampling_params=dict(\n",
"            temperature=0.3,\n",
"            max_tokens=20,\n",
"            detokenize=False,\n",
"        ),\n",
"    ),\n",
"    postprocess=lambda row: {\n",
"        \"resp\": row[\"generated_text\"],\n",
"    },\n",
")\n",
"```"
]
},
105+
{
106+
"cell_type": "markdown",
107+
"metadata": {},
108+
"source": [
109+
"Now we can run the batch inference pipeline on our data, it will:\n",
110+
"- In the background, the processor will download the model into memory where vLLM serves it locally (on Ray Cluster) for use in inference\n",
111+
"- Generate a sample Ray Dataset with 32 rows (0-31) to process\n",
112+
"- Run the LLM processor on the dataset, triggering the preprocessing, inference, and postprocessing steps\n",
113+
"- Execute the lazy pipeline and loads results into memory\n",
114+
"- Iterate through all outputs and print each response \n",
115+
"\n",
116+
"```python\n",
117+
"ds = ray.data.range(30)\n",
118+
"ds = processor(ds)\n",
119+
"ds = ds.materialize()\n",
120+
"\n",
121+
"for out in ds.take_all():\n",
122+
" print(out)\n",
123+
" print(\"==========\")\n",
124+
"```\n",
125+
"\n",
126+
"### Job Submission\n",
127+
"\n",
128+
"Now we can submit this job against the Ray Cluster using the `RayJobClient` from earlier "
129+
]
130+
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-06-23 16:56:53,008\tINFO dashboard_sdk.py:338 -- Uploading package gcs://_ray_pkg_d3badb03645503e8.zip.\n",
"2025-06-23 16:56:53,010\tINFO packaging.py:576 -- Creating a file package for local module './'.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"raysubmit_AJhmqzWsvHu6SqZD successfully submitted\n"
]
}
],
"source": [
"entrypoint_command = \"python simple_batch_inf.py\"\n",
"\n",
"submission_id = client.submit_job(\n",
"    entrypoint=entrypoint_command,\n",
"    runtime_env={\"working_dir\": \"./\", \"pip\": \"requirements.txt\"},\n",
")\n",
"\n",
"print(submission_id + \" successfully submitted\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<JobStatus.PENDING: 'PENDING'>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get the job's status\n",
"client.get_job_status(submission_id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'2025-06-23 15:47:22,272\\tINFO job_manager.py:531 -- Runtime env is setting up.\\nINFO 06-23 15:53:36 [__init__.py:244] Automatically detected platform cuda.\\n2025-06-23 15:53:54,307\\tINFO worker.py:1554 -- Using address 10.128.2.45:6379 set in the environment variable RAY_ADDRESS\\n2025-06-23 15:53:54,308\\tINFO worker.py:1694 -- Connecting to existing Ray cluster at address: 10.128.2.45:6379...\\n2025-06-23 15:53:54,406\\tINFO worker.py:1879 -- Connected to Ray cluster. View the dashboard at \\x1b[1m\\x1b[32mhttp://10.128.2.45:8265 \\x1b[39m\\x1b[22m\\nNo cloud storage mirror configured\\n2025-06-23 15:53:57,501\\tWARNING util.py:589 -- The argument ``compute`` is deprecated in Ray 2.9. Please specify argument ``concurrency`` instead. For more information, see https://docs.ray.io/en/master/data/transforming-data.html#stateful-transforms.\\n2025-06-23 15:53:58,095\\tINFO logging.py:290 -- Registered dataset logger for dataset dataset_33_0\\n2025-06-23 15:53:59,702\\tINFO streaming_executor.py:117 -- Starting execution of Dataset dataset_33_0. Full logs are in /tmp/ray/session_2025-06-23_10-53-41_019757_1/logs/ray-data\\n2025-06-23 15:53:59,702\\tINFO streaming_executor.py:118 -- Execution plan of Dataset dataset_33_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadRange->Map(_preprocess)] -> ActorPoolMapOperator[MapBatches(ChatTemplateUDF)] -> ActorPoolMapOperator[MapBatches(TokenizeUDF)] -> ActorPoolMapOperator[MapBatches(vLLMEngineStageUDF)] -> ActorPoolMapOperator[MapBatches(DetokenizeUDF)] -> TaskPoolMapOperator[Map(_postprocess)]\\n\\nRunning 0: 0.00 row [00:00, ? row/s]\\n \\n\\x1b[33m(raylet)\\x1b[0m [2025-06-23 15:54:00,800 E 829 829] (raylet) node_manager.cc:3287: 2 Workers (tasks / actors) killed due to memory pressure (OOM), 0 Workers crashed due to other reasons at node (ID: b72a45799ac9496bf52347fb9f9ef218722683d7bd8dd14702e821f0, IP: 10.128.2.45) over the last time period. To see more information about the Workers killed on this node, use `ray logs raylet.out -ip 10.128.2.45`\\n\\nRunning 0: 0.00 row [00:01, ? row/s]\\n \\n\\x1b[33m(raylet)\\x1b[0m \\n\\nRunning 0: 0.00 row [00:01, ? row/s]\\n \\n\\x1b[33m(raylet)\\x1b[0m Refer to the documentation on how to address the out of memory issue: https://docs.ray.io/en/latest/ray-core/scheduling/ray-oom-prevention.html. Consider provisioning more memory on this node or reducing task parallelism by requesting more CPUs per task. To adjust the kill threshold, set the environment variable `RAY_memory_usage_threshold` when starting Ray. To disable worker killing, set the environment variable `RAY_memory_monitor_refresh_ms` to zero.\\n\\nRunning 0: 0.00 row [00:01, ? row/s]\\n \\n\\x1b[33m(raylet)\\x1b[0m \\n\\nRunning 0: 0.00 row [01:01, ? row/s]\\n \\n\\x1b[33m(raylet)\\x1b[0m [2025-06-23 15:55:00,824 E 829 829] (raylet) node_manager.cc:3287: 1 Workers (tasks / actors) killed due to memory pressure (OOM), 0 Workers crashed due to other reasons at node (ID: b72a45799ac9496bf52347fb9f9ef218722683d7bd8dd14702e821f0, IP: 10.128.2.45) over the last time period. To see more information about the Workers killed on this node, use `ray logs raylet.out -ip 10.128.2.45`\\n\\nRunning 0: 0.00 row [01:01, ? row/s]\\n \\n\\x1b[33m(raylet)\\x1b[0m Refer to the documentation on how to address the out of memory issue: https://docs.ray.io/en/latest/ray-core/scheduling/ray-oom-prevention.html. Consider provisioning more memory on this node or reducing task parallelism by requesting more CPUs per task. 
To adjust the kill threshold, set the environment variable `RAY_memory_usage_threshold` when starting Ray. To disable worker killing, set the environment variable `RAY_memory_monitor_refresh_ms` to zero.\\n\\nRunning 0: 0.00 row [01:01, ? row/s]'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get the job's logs\n",
"client.get_job_logs(submission_id)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
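The notebook checks the job status once (returning PENDING in the output above) and then fetches the logs. For longer-running jobs it can help to poll until the job reaches a terminal state before pulling logs. A minimal sketch, assuming the same `client` and `submission_id` created in the notebook; the `wait_for_job` helper, the 30-minute timeout, and the 10-second poll interval are illustrative choices, not part of this commit:

```python
import time

from ray.job_submission import JobStatus


def wait_for_job(client, submission_id, timeout_s=1800, poll_s=10):
    """Poll a submitted Ray job until it reaches a terminal state or times out."""
    terminal = {JobStatus.SUCCEEDED, JobStatus.FAILED, JobStatus.STOPPED}
    deadline = time.time() + timeout_s
    status = client.get_job_status(submission_id)
    while status not in terminal:
        if time.time() > deadline:
            raise TimeoutError(f"Job {submission_id} is still {status} after {timeout_s}s")
        time.sleep(poll_s)
        status = client.get_job_status(submission_id)
    return status


# Usage with the client and submission_id from the notebook:
# final_status = wait_for_job(client, submission_id)
# print(final_status)
# print(client.get_job_logs(submission_id))
```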
Lines changed: 4 additions & 0 deletions (requirements.txt)
@@ -0,0 +1,4 @@
vllm
transformers
triton>=2.0.0
torch>=2.0.0
Lines changed: 62 additions & 0 deletions (simple_batch_inf.py)
@@ -0,0 +1,62 @@
import ray
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig


# 1. Construct a vLLM processor config.
processor_config = vLLMEngineProcessorConfig(
    # The base model.
    model_source="unsloth/Llama-3.2-1B-Instruct",
    # vLLM engine config.
    engine_kwargs=dict(
        enable_lora=False,
        # Older GPUs (e.g. T4) don't support bfloat16. You should remove
        # this line if you're using newer GPUs.
        dtype="half",
        # Reduce the model length to fit small GPUs. You should remove
        # this line if you're using large GPUs.
        max_model_len=1024,
    ),
    # The batch size used in Ray Data.
    batch_size=8,
    # Use one GPU in this example.
    concurrency=1,
    # If you save the LoRA adapter in S3, you can set the following path.
    # dynamic_lora_loading_path="s3://your-lora-bucket/",
)

# 2. Construct a processor using the processor config.
processor = build_llm_processor(
    processor_config,
    preprocess=lambda row: dict(
        # No LoRA model specification is needed in this example.
        messages=[
            {
                "role": "system",
                "content": "You are a calculator. Please only output the answer "
                "of the given equation.",
            },
            {"role": "user", "content": f"{row['id']} ** 3 = ?"},
        ],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=20,
            detokenize=False,
        ),
    ),
    postprocess=lambda row: {
        "resp": row["generated_text"],
    },
)

# 3. Synthesize a dataset with 32 rows.
ds = ray.data.range(32)
# 4. Apply the processor to the dataset. Note that this line won't kick off
# anything because the processor executes lazily.
ds = processor(ds)
# Materialization kicks off the pipeline execution.
ds = ds.materialize()

# 5. Print all outputs.
for out in ds.take_all():
    print(out)
    print("==========")

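The captured log output in the notebook shows the raylet killing workers due to memory pressure (OOM) while the pipeline runs, and the log itself suggests provisioning more memory or reducing task parallelism. A minimal, untested sketch of what a lower-pressure variant of the processor config in `simple_batch_inf.py` could look like; it uses only parameters already present in the committed file, with smaller illustrative values:

```python
from ray.data.llm import vLLMEngineProcessorConfig

# Hypothetical lower-memory variant of the committed config: smaller Ray Data
# batches and a shorter context length reduce how much data is held in memory
# at once. The specific values here are guesses, not tested settings.
low_memory_config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.2-1B-Instruct",
    engine_kwargs=dict(
        enable_lora=False,
        dtype="half",
        max_model_len=512,  # shorter than the committed 1024
    ),
    batch_size=4,  # half of the committed batch_size=8
    concurrency=1,
)
```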