Skip to content

Commit 018c599

Browse files
Adding modal support to run_and_check.py (#83)
* adding modal support to run and check, debugging tensor size error * remove a few unnecessary modal dependencies, ready to merge --------- Co-authored-by: Simon Guo <[email protected]>
1 parent 2c3dbda commit 018c599

File tree

4 files changed

+203
-48
lines changed

4 files changed

+203
-48
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ To evaluate model-generated kernels, we need to check if they:
3939

4040
Check out `src/eval.py` for details on how we implement correctness check and timing.
4141

42-
We provide a convenient script `scripts/run_and_check.py` to evaluate one single sample source code against a reference source code, check correctness and compute speedup. You can use this to evaluate a model-generated kernel.
42+
We provide a convenient script `scripts/run_and_check.py` to evaluate one single sample source code against a reference source code, check correctness and compute speedup. You can use this to evaluate a kernel either locally or remotely by setting `eval_mode=local` or `eval_mode=modal`.
4343

4444
#### Overall Benchmark Metric
4545

scripts/eval_from_generations.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,13 @@
6868
"clang"
6969
)
7070
.pip_install(
71-
"anthropic",
7271
"numpy",
73-
"openai",
7472
"packaging",
7573
"pydra_config",
7674
"torch==2.5.0",
7775
"tqdm",
7876
"datasets",
7977
"transformers",
80-
"google-generativeai",
81-
"together",
8278
"pytest",
8379
"ninja",
8480
"utils",

scripts/generate_baseline_time_modal.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,14 @@
6868
"clang" # note i skip a step
6969
)
7070
.pip_install( # required to build flash-attn
71-
"anthropic",
71+
# Let's unify these dependencies somewhere
7272
"numpy",
73-
"openai",
7473
"packaging",
7574
"pydra_config",
7675
"torch==2.5.0",
7776
"tqdm",
7877
"datasets",
7978
"transformers",
80-
"google-generativeai",
81-
"together",
8279
"pytest",
8380
"ninja",
8481
"utils",

scripts/run_and_check.py

Lines changed: 201 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,57 @@
44
from pydra import REQUIRED, Config
55
import os
66
from datasets import load_dataset
7-
7+
import modal
88

99
from src import eval as kernel_eval
1010
from src import utils as kernel_utils
1111
from scripts.generate_baseline_time import measure_program_time
1212
from src.utils import read_file
1313

14+
# Modal setup
# Name under which this script registers its Modal app.
app = modal.App("run_and_check")

# Modal GPU type -> architecture list. Each value is a list because it is
# forwarded as-is as the `gpu_arch` argument of the remote methods.
gpu_arch_mapping = {
    "L40S": ["Ada"],
    "H100": ["Hopper"],
    "H200": ["Hopper"],
    "A100": ["Ampere"],
    "A100-80GB": ["Ampere"],
    "L4": ["Ada"],
    "T4": ["Turing"],
    "A10G": ["Ampere"],
}

# Repository root is one directory above the folder holding this script.
REPO_TOP_PATH = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench")

# Pieces of the nvidia/cuda registry tag; joined they read
# "12.4.0-devel-ubuntu22.04".
cuda_version = "12.4.0"
flavor = "devel"
operating_sys = "ubuntu22.04"
tag = "-".join((cuda_version, flavor, operating_sys))

# System packages installed into the container image.
_APT_PACKAGES = ("git", "gcc-10", "g++-10", "clang")

# Python packages installed into the container image.
_PIP_PACKAGES = (
    "numpy",
    "packaging",
    "pydra_config",
    "torch==2.5.0",
    "tqdm",
    "datasets",
    "transformers",
    "pytest",
    "ninja",
    "utils",
    "einops",
    "python-dotenv",
    "litellm[proxy]",
)

# Container image: CUDA base + toolchain + Python deps, then the local
# KernelBench directory and the `src` / `scripts` Python sources.
_base_image = modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
_image_with_deps = _base_image.apt_install(*_APT_PACKAGES).pip_install(*_PIP_PACKAGES)
image = (
    _image_with_deps
    .add_local_dir(KERNEL_BENCH_PATH, remote_path="/root/KernelBench")
    .add_local_python_source("src")
    .add_local_python_source("scripts")
)
57+
1458
"""
1559
Run a pair of KernelBench format (problem, solution) to check if solution is correct and compute speedup
1660
@@ -25,11 +69,17 @@
2569
2670
====================================================
2771
Usage:
28-
1. PyTorch reference is a local file
29-
python3 scripts/run_and_check.py ref_origin=local ref_arch_src_path=src/prompts/model_ex_add.py kernel_src_path=src/prompts/model_new_ex_add.py
72+
1. PyTorch reference is a local file (local eval)
73+
python3 scripts/run_and_check.py ref_origin=local ref_arch_src_path=src/prompts/model_ex_add.py kernel_src_path=src/prompts/model_new_ex_add.py eval_mode=local
74+
75+
2. PyTorch reference is a kernelbench problem (local eval)
76+
python3 scripts/run_and_check.py ref_origin=kernelbench level=<level> problem_id=<problem_id> kernel_src_path=<path to model-generated kernel> eval_mode=local
3077
31-
2. PyTorch refernece is a kernelbench problem
32-
python3 scripts/run_and_check.py ref_origin=kernelbench level=<level> problem_id=<problem_id> kernel_src_path=<path to model-generated kernel>
78+
3. PyTorch reference is a local file (modal eval on cloud GPU)
79+
python3 scripts/run_and_check.py ref_origin=local ref_arch_src_path=src/prompts/model_ex_add.py kernel_src_path=src/prompts/model_new_ex_add.py eval_mode=modal gpu=H100
80+
81+
4. PyTorch reference is a kernelbench problem (modal eval on cloud GPU)
82+
python3 scripts/run_and_check.py ref_origin=kernelbench level=<level> problem_id=<problem_id> kernel_src_path=<path to model-generated kernel> eval_mode=modal gpu=L40S
3383
====================================================
3484
3585
"""
@@ -51,6 +101,9 @@ def __init__(self):
51101
# Solution src definition
52102
self.kernel_src_path = ""
53103

104+
# Evaluation mode
105+
self.eval_mode = "local" # either "local" or "modal"
106+
self.gpu = "L40S" # GPU type for modal (L40S, H100, H200, A100, etc.)
54107

55108
# KernelBench Eval specific
56109
# number of trials to run for correctness
@@ -66,7 +119,7 @@ def __init__(self):
66119
self.clear_cache = False # TODO
67120

68121
# Replace with your NVIDIA GPU architecture, e.g. ["Hopper"]
69-
self.gpu_arch = ["Ada"]
122+
self.gpu_arch = ["Ada"]
70123
self.precision = "fp32"
71124
self.backend = "cuda"
72125

@@ -119,11 +172,70 @@ def evaluate_single_sample_src(ref_arch_src: str, kernel_src: str, configs: dict
119172
"hardware": torch.cuda.get_device_name(device=device),
120173
"device": str(device)
121174
}
122-
eval_result = kernel_eval.KernelExecResult(compiled=False, correctness=False,
175+
eval_result = kernel_eval.KernelExecResult(compiled=False, correctness=False,
123176
metadata=metadata)
124177
return eval_result
125178

126179

180+
# Modal evaluation class
181+
@app.cls(image=image, scaledown_window=5)
class EvalFunc:
    """Remote (Modal) counterparts of the local evaluation and timing steps.

    Both methods configure the GPU architecture via ``set_gpu_arch`` and
    always target the CUDA device ``"cuda:0"``.
    """

    @modal.method()
    def evaluate_single_sample_src_modal(self, ref_arch_src: str, kernel_src: str, configs: dict, gpu_arch: list):
        """Check a candidate kernel against its reference program on Modal.

        Args:
            ref_arch_src: Source code of the reference program.
            kernel_src: Source code of the kernel under evaluation.
            configs: Settings dict; must provide the keys
                "num_correct_trials", "num_perf_trials", "verbose",
                "measure_performance", "backend" and "precision".
            gpu_arch: Architecture list handed to ``set_gpu_arch``.

        Returns:
            The result object produced by ``eval_kernel_against_ref``.
        """
        # Function-scope imports: resolved when the method executes rather
        # than when the class is defined.
        from src.utils import set_gpu_arch
        from src.eval import eval_kernel_against_ref, get_torch_dtype_from_string

        set_gpu_arch(gpu_arch)
        cuda_device = torch.device("cuda:0")

        # Read the trial settings up front so a missing key fails before
        # any evaluation work starts.
        trial_keys = ("num_correct_trials", "num_perf_trials", "verbose", "measure_performance")
        n_correct, n_perf, is_verbose, with_perf = [configs[key] for key in trial_keys]

        return eval_kernel_against_ref(
            original_model_src=ref_arch_src,
            custom_model_src=kernel_src,
            measure_performance=with_perf,
            verbose=is_verbose,
            num_correct_trials=n_correct,
            num_perf_trials=n_perf,
            device=cuda_device,
            backend=configs["backend"],
            # The precision setting is passed through
            # get_torch_dtype_from_string before being forwarded.
            precision=get_torch_dtype_from_string(configs["precision"]),
        )

    @modal.method()
    def measure_program_time_modal(
        self,
        ref_arch_src: str,
        num_trials: int,
        use_torch_compile: bool,
        torch_compile_backend: str,
        torch_compile_options: str,
        gpu_arch: list
    ):
        """Time a reference program on Modal.

        Args:
            ref_arch_src: Source code of the reference program.
            num_trials: Forwarded as ``num_trials`` to ``measure_program_time``.
            use_torch_compile: Forwarded as ``use_torch_compile``.
            torch_compile_backend: Forwarded as ``torch_compile_backend``.
            torch_compile_options: Forwarded as ``torch_compile_options``.
            gpu_arch: Architecture list handed to ``set_gpu_arch``.

        Returns:
            The value returned by ``measure_program_time``; the program is
            always reported under the name "Reference Program".
        """
        # Function-scope imports: resolved when the method executes rather
        # than when the class is defined.
        from scripts.generate_baseline_time import measure_program_time
        from src.utils import set_gpu_arch

        set_gpu_arch(gpu_arch)
        cuda_device = torch.device("cuda:0")

        timing_result = measure_program_time(
            ref_arch_name="Reference Program",
            ref_arch_src=ref_arch_src,
            num_trials=num_trials,
            use_torch_compile=use_torch_compile,
            torch_compile_backend=torch_compile_backend,
            torch_compile_options=torch_compile_options,
            device=cuda_device,
        )
        return timing_result
237+
238+
127239
@pydra.main(base=ScriptConfig)
128240
def main(config: ScriptConfig):
129241

@@ -162,38 +274,88 @@ def main(config: ScriptConfig):
162274
kernel_src = read_file(config.kernel_src_path)
163275

164276
# Start Evaluation
165-
device = torch.device("cuda:0") # default device
166-
kernel_utils.set_gpu_arch(config.gpu_arch)
167-
168-
print("[INFO] Evaluating kernel against reference code")
169-
# Evaluate kernel against reference code
170-
kernel_eval_result = evaluate_single_sample_src(
171-
ref_arch_src=ref_arch_src,
172-
kernel_src=kernel_src,
173-
configs=config.to_dict(),
174-
device=device
175-
)
176-
kernel_exec_time = kernel_eval_result.runtime
177-
178-
# Measure baseline time
179-
print("[INFO] Measuring reference program time")
180-
# Default using PyTorch Eager here
181-
ref_time_eager_result = measure_program_time(ref_arch_name="Reference Program",
182-
ref_arch_src=ref_arch_src,
183-
num_trials=config.num_perf_trials,
184-
use_torch_compile=False,
185-
device=device)
186-
ref_exec_eager_time = ref_time_eager_result.get("mean", None)
187-
188-
# Measure Torch Compile time
189-
ref_time_compile_result = measure_program_time(ref_arch_name="Reference Program",
190-
ref_arch_src=ref_arch_src,
191-
num_trials=config.num_perf_trials,
192-
use_torch_compile=True,
193-
torch_compile_backend="inductor",
194-
torch_compile_options="default",
195-
device=device)
196-
ref_exec_compile_time = ref_time_compile_result.get("mean", None)
277+
assert config.eval_mode in ["local", "modal"], "eval_mode must be either 'local' or 'modal'"
278+
279+
if config.eval_mode == "local":
280+
# Local evaluation (existing code path)
281+
device = torch.device("cuda:0")
282+
kernel_utils.set_gpu_arch(config.gpu_arch)
283+
284+
print("[INFO] Evaluating kernel against reference code (LOCAL)")
285+
# Evaluate kernel against reference code
286+
kernel_eval_result = evaluate_single_sample_src(
287+
ref_arch_src=ref_arch_src,
288+
kernel_src=kernel_src,
289+
configs=config.to_dict(),
290+
device=device
291+
)
292+
kernel_exec_time = kernel_eval_result.runtime
293+
294+
# Measure baseline time
295+
print("[INFO] Measuring reference program time")
296+
# Default using PyTorch Eager here
297+
ref_time_eager_result = measure_program_time(ref_arch_name="Reference Program",
298+
ref_arch_src=ref_arch_src,
299+
num_trials=config.num_perf_trials,
300+
use_torch_compile=False,
301+
device=device)
302+
ref_exec_eager_time = ref_time_eager_result.get("mean", None)
303+
304+
# Measure Torch Compile time
305+
ref_time_compile_result = measure_program_time(ref_arch_name="Reference Program",
306+
ref_arch_src=ref_arch_src,
307+
num_trials=config.num_perf_trials,
308+
use_torch_compile=True,
309+
torch_compile_backend="inductor",
310+
torch_compile_options="default",
311+
device=device)
312+
ref_exec_compile_time = ref_time_compile_result.get("mean", None)
313+
314+
elif config.eval_mode == "modal":
315+
# Modal evaluation (remote execution)
316+
gpu_arch = gpu_arch_mapping.get(config.gpu, config.gpu_arch)
317+
print(f"[INFO] Using GPU: {config.gpu} with architecture: {gpu_arch}")
318+
319+
with app.run():
320+
print("[INFO] Evaluating kernel against reference code (MODAL)")
321+
# Evaluate kernel against reference code
322+
kernel_eval_result = EvalFunc.with_options(
323+
gpu=config.gpu
324+
)().evaluate_single_sample_src_modal.remote(
325+
ref_arch_src=ref_arch_src,
326+
kernel_src=kernel_src,
327+
configs=config.to_dict(),
328+
gpu_arch=gpu_arch
329+
)
330+
kernel_exec_time = kernel_eval_result.runtime
331+
332+
# Measure baseline time
333+
print("[INFO] Measuring reference program time (PyTorch Eager)")
334+
ref_time_eager_result = EvalFunc.with_options(
335+
gpu=config.gpu
336+
)().measure_program_time_modal.remote(
337+
ref_arch_src=ref_arch_src,
338+
num_trials=config.num_perf_trials,
339+
use_torch_compile=False,
340+
torch_compile_backend=None,
341+
torch_compile_options=None,
342+
gpu_arch=gpu_arch
343+
)
344+
ref_exec_eager_time = ref_time_eager_result.get("mean", None)
345+
346+
# Measure Torch Compile time
347+
print("[INFO] Measuring reference program time (torch.compile)")
348+
ref_time_compile_result = EvalFunc.with_options(
349+
gpu=config.gpu
350+
)().measure_program_time_modal.remote(
351+
ref_arch_src=ref_arch_src,
352+
num_trials=config.num_perf_trials,
353+
use_torch_compile=True,
354+
torch_compile_backend="inductor",
355+
torch_compile_options="default",
356+
gpu_arch=gpu_arch
357+
)
358+
ref_exec_compile_time = ref_time_compile_result.get("mean", None)
197359

198360
print("="*40)
199361
print(f"[Eval] Kernel eval result: {kernel_eval_result}")

0 commit comments

Comments
 (0)