
Commit 3e91550

merged updated main
2 parents 1f054c7 + df143a8 commit 3e91550

13 files changed: +771 / -680 lines

.env.example

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# API Keys for LLM Providers
+# Copy this file to .env and fill in your actual API keys
+# DO NOT commit your .env file with real keys!
+
+# OpenAI (for GPT models and o1/o3 reasoning models)
+OPENAI_API_KEY=sk-...
+
+# Anthropic (for Claude models)
+ANTHROPIC_API_KEY=sk-ant-api03-...
+
+# Google Gemini
+GEMINI_API_KEY=...
+
+# DeepSeek
+DEEPSEEK_API_KEY=sk-...
+
+# Together AI
+TOGETHER_API_KEY=...
+
+# Fireworks AI
+FIREWORKS_AI_API_KEY=...
+
+# Local Server Deployment (SGLang, vLLM, Tokasaurus)
+SGLANG_API_KEY=...

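For orientation, here is a minimal sketch of how these keys are typically loaded at runtime, assuming the `dotenv` package listed in `requirements.txt`; the repo's own scripts may load them differently:

```python
# Minimal sketch, assuming python-dotenv; not necessarily how the repo loads keys.
import os
from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from a local .env file into the environment
if os.environ.get("OPENAI_API_KEY") is None:
    raise RuntimeError("OPENAI_API_KEY not set; copy .env.example to .env and fill it in")
```
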
README.md

Lines changed: 19 additions & 11 deletions
@@ -1,16 +1,19 @@
 # KernelBench: Can LLMs Write Efficient GPU Kernels? [ICML '25]
-[arXiv](https://arxiv.org/html/2502.10517v1) | [blog post](https://scalingintelligence.stanford.edu/blogs/kernelbench/) | [HuggingFace Dataset](https://huggingface.co/datasets/ScalingIntelligence/KernelBench) |
+A benchmark for evaluating LLMs' ability to generate efficient GPU kernels
+
+[arXiv](https://arxiv.org/html/2502.10517v1) | [blog post](https://scalingintelligence.stanford.edu/blogs/kernelbench/) | [HuggingFace Dataset](https://huggingface.co/datasets/ScalingIntelligence/KernelBench)
+
+<img src="./assets/figures/KernelBenchMascot.png" width="200">

 ## Versions
-The huggingface dataset is updated to v0.1.
-- [v0.1](https://github.com/ScalingIntelligence/KernelBench/tree/v0.1) - Latest version (also main branch)
+The latest stable version is on the `main` branch. We continue to update and improve the repo.
+- [v0.1](https://github.com/ScalingIntelligence/KernelBench/tree/v0.1) - See [blog](https://scalingintelligence.stanford.edu/blogs/kernelbenchv01/)
 - [v0](https://github.com/ScalingIntelligence/KernelBench/tree/v0) - Original Release

-A benchmark for evaluating LLMs' ability to generate efficient GPU kernels

-<img src="./assets/figures/KernelBenchMascot.png" width="200">
+The Huggingface [dataset](https://huggingface.co/datasets/ScalingIntelligence/KernelBench) is updated to v0.1.

-<!-- See [blog post](https://scalingintelligence.stanford.edu/blogs/kernelbench/) and [arXiv paper](https://arxiv.org/html/2502.10517v1) for more details. -->
+This repo provides core functionality for KernelBench and an easy-to-use set of scripts for evaluation. It is not intended to provide complex agentic scaffolds that solve this task; we recommend cloning and modifying this repo for your experiment, or using it as a git submodule.

 ## 👋 Task Description
 We structure the problem as asking an LLM to transpile operators described in PyTorch into CUDA kernels, at whatever level of granularity it desires.
@@ -26,7 +29,7 @@ We construct KernelBench to have 4 Levels of categories:
 - **Level 4 🤗**: Level Hugging Face
 Optimize whole model architectures from HuggingFace

-We are actively extending KernelBench to other DSLs beyond `cuda` as well.
+We are actively extending KernelBench to other DSLs beyond `cuda` as well (see below).

 ## ⚖️ Evaluation
 #### Methodology
@@ -36,7 +39,7 @@ To evaluate model-generated kernels, we need to check if they:

 Check out `src/eval.py` for details on how we implement correctness checks and timing.

-We provide a convenient script `scripts/run_and_check.py` to evaluate one single sample source code against a reference source code, check correctness and compute speedup. You can use this to evaluate a model-generated kernel.
+We provide a convenient script `scripts/run_and_check.py` to evaluate a single sample's source code against a reference source code, check correctness, and compute speedup. You can use this to evaluate a kernel either locally or remotely by setting `eval_mode=local` or `eval_mode=modal`.

 #### Overall Benchmark Metric

@@ -80,7 +83,7 @@ pip install -r requirements.txt
 pip install -e .
 ```

-To call LLM API providers, set your `{INFERENCE_SERVER_PROVIDER}_API_KEY` API key.
+We use `litellm` for API calls. Please set your keys by creating a `.env` file following our `.env.example`.

 Running and profiling kernels require a GPU.
 If you don't have a GPU available locally, you can set up [Modal](https://modal.com/). Set up your modal token after creating an account by running `modal token new`. Then, use the `generate_and_eval_single_sample_modal.py` script.
@@ -98,7 +101,12 @@ python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" lev
 # add .verbose_logging for more visibility
 ```

-We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support (`cuda`, `triton`, `cute`).
+**What you might need to modify**
+* **`gpu_arch`** - Depending on your GPU, you might need to adjust the `gpu_arch` argument to reflect your hardware.
+* **`precision`** - You can specify the tensor precision with `precision=fp32`. Currently all of our reported results are `fp32`, but we have added support for `fp16` & `bf16`.
+* **`backend`** - We also support other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support the DSLs `cuda`, `triton`, `cute`, and `tilelang`.
+
+Check the config fields for a comprehensive set of options.

 ### Run on all problems

@@ -122,7 +130,7 @@ If you are using a different hardware, you can generate the baseline time with `s
 We provide some reference baseline times for a variety of NVIDIA GPUs across generations in `results/timing`, but we recommend generating your own baseline times for more accurate results (cluster power and software versions all affect timing). See `results/timing/README.md` for more details.

 ### Multi-Turn Framework
-We have also releaed the test-time framework [Caesar](https://github.com/simonguozirui/caesar) that are used in the multi-turn / iterative refinement experiments in our paper. You can use or modify this framework for high-throughput test-time scaling (both sequential and parallel) targeting KernelBench problems.
+We have also released the test-time framework [Caesar](https://github.com/ScalingIntelligence/caesar), which is used in the multi-turn / iterative refinement experiments in our paper. You can use or modify this framework for high-throughput test-time scaling (both sequential and parallel) targeting KernelBench problems.

 ## 🛣️ Upcoming Roadmap
 Check out our [roadmap](https://github.com/ScalingIntelligence/KernelBench/issues/74) for what we plan to add as features. We welcome community contributions in these directions.

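As context for the `gpu_arch` knob mentioned in the README diff above, a helper like `set_gpu_arch` usually just pins the CUDA architectures that torch extensions compile for. The sketch below is an assumption for illustration; the repo's actual implementation in `src/utils.py` may differ:

```python
# Assumed sketch of a gpu_arch helper; the repo's set_gpu_arch may differ.
import os

def set_gpu_arch(arch_list: list[str]) -> None:
    # torch.utils.cpp_extension honors TORCH_CUDA_ARCH_LIST when building kernels,
    # so listing only your GPU's architecture (e.g. ["Ada"] or ["Hopper"]) keeps
    # compilation targeted at the hardware you actually have.
    os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(arch_list)
```
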
requirements.txt

Lines changed: 7 additions & 6 deletions
@@ -1,12 +1,15 @@
 # Frameworks
-torch==2.5.0
+# we use latest PyTorch stable release
+torch==2.9.0
+
 # we shall upgrade torch for blackwell when it is stable
 transformers
 datasets
 modal

 # DSLs
 nvidia-cutlass-dsl
+tilelang

 # helper
 tqdm
@@ -20,9 +23,7 @@ einops
 dotenv
 numpy

-# to deprecate with litellm
-google-generativeai
-together
-openai
-anthropic
+# use litellm for cloud providers and openai for local
+openai
+litellm[proxy]

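For context on the switch to `litellm`: a single provider-agnostic completion call looks roughly like the sketch below. The model string and prompt are placeholders, and the repo's actual inference helpers in `src/` may wrap this differently:

```python
# Rough sketch of a litellm call; model name and prompt are placeholders.
import litellm

response = litellm.completion(
    model="openai/gpt-4o",  # any provider/model string supported by litellm
    messages=[{"role": "user", "content": "Write a CUDA kernel for elementwise add."}],
    temperature=0.0,
    max_tokens=1024,
)
print(response.choices[0].message.content)  # API keys are read from the environment
```
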
results/timing/README.md

Lines changed: 3 additions & 1 deletion
@@ -6,7 +6,9 @@ This folder contains a set of baseline timing results for the KernelBench proble
 Since KernelBench measures the speedup between Runtime(reference architecture) and Runtime(LLM-generated architecture), it is important to measure the baseline reference module runtime.

 We have provided a set of baseline results for the KernelBench problems on a variety of hardware as well as various PyTorch configurations.
-All baseline are ran with PyTorch `2.5.0+cu124` and CUDA `12.4`.
+All (current) baselines are run with PyTorch `2.5.0+cu124` and CUDA `12.4`.
+
+Note: we will update these soon with PyTorch `2.9.0` and CUDA `12.8`.

 For timing, we measure wall clock time. We warm up 3 times and collect runtime statistics for 100 trials.

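As a rough illustration of the protocol stated above (3 warmups, 100 timed trials, wall clock time), a hedged sketch follows; the benchmark's actual timing code lives in `src/eval.py` and may differ in synchronization details and the statistics it reports:

```python
# Illustrative sketch of the stated protocol; not the repo's exact timing code.
import time
import torch

def time_callable(fn, warmup: int = 3, trials: int = 100) -> float:
    for _ in range(warmup):          # warm up to exclude compilation/caching effects
        fn()
    torch.cuda.synchronize()

    times_ms = []
    for _ in range(trials):
        torch.cuda.synchronize()
        start = time.perf_counter()  # wall clock
        fn()
        torch.cuda.synchronize()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return sum(times_ms) / len(times_ms)  # mean runtime in milliseconds
```
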
scripts/eval_from_generations.py

Lines changed: 32 additions & 28 deletions
@@ -55,7 +55,7 @@
 app = modal.App("eval_from_generations_modal")
 gpu_arch_mapping = {"L40S": ["Ada"], "H100": ["Hopper"], "A100": ["Ampere"], "L4": ["Ada"], "T4": ["Turing"], "A10G": ["Ampere"]}

-cuda_version = "12.4.0" # should be no greater than host CUDA version
+cuda_version = "12.8.0" # should be no greater than host CUDA version
 flavor = "devel" # includes full CUDA toolkit
 operating_sys = "ubuntu22.04"
 tag = f"{cuda_version}-{flavor}-{operating_sys}"
@@ -67,23 +67,7 @@
 "g++-10",
 "clang"
 )
-.pip_install(
-"anthropic",
-"numpy",
-"openai",
-"packaging",
-"pydra_config",
-"torch==2.5.0",
-"tqdm",
-"datasets",
-"transformers",
-"google-generativeai",
-"together",
-"pytest",
-"ninja",
-"utils",
-"python-dotenv",
-)
+.pip_install_from_requirements(os.path.join(REPO_TOP_DIR, "requirements.txt"))
 .add_local_dir(
 KERNEL_BENCH_PATH,
 remote_path="/root/KernelBench"
@@ -145,6 +129,10 @@ def __init__(self):

 # Backend to use for kernel implementation (cuda or triton)
 self.backend = "cuda"
+
+# Precision for computation: "fp32", "fp16", "bf16"
+self.precision = "fp32"
+
 # Number of samples per problem to evaluate for pass@k analysis
 self.num_samples_per_problem = 1 # Default to 1 sample per problem

@@ -165,17 +153,18 @@ class WorkArgs:
 # Modal Evaluation Class
 # GPU must be specified here for all instances
 # Retries are configured at the class level to handle GPU attachment failures
-# @modal.concurrent: Each container handles exactly ONE evaluation at a time - prevents memory leaks
+# scaledown_window=5 kills idle containers after 5 seconds
+# Combined with 10s sleep between batches, this prevents container reuse and GPU corruption spread
 @app.cls(
-image=image,
+image=image,
 gpu="A10G",
+scaledown_window=5, # Kill idle containers after 5 seconds
 retries=modal.Retries(
 max_retries=3,
 backoff_coefficient=2.0,
 initial_delay=1.0,
 )
 )
-@modal.concurrent(max_inputs=1) # One input per container - prevents GPU memory leaks
 class ModalEvaluator:

 @modal.method()
@@ -188,11 +177,13 @@ def evaluate_single_sample_modal(
 num_perf_trials: int = 100,
 measure_performance: bool = True,
 verbose: bool = False,
+backend: str = "cuda",
+precision: str = "fp32",
 ):
 """
 Evaluate a single sample on Modal GPU with automatic retries for GPU attachment failures
 """
-from src.eval import eval_kernel_against_ref
+from src.eval import eval_kernel_against_ref, get_torch_dtype_from_string
 from src.utils import set_gpu_arch
 import torch
 import time
@@ -225,12 +216,14 @@ def evaluate_single_sample_modal(
 num_perf_trials=num_perf_trials,
 build_dir=None, # Modal doesn't need persistent build dir
 device=torch.device("cuda:0"), # Modal has one GPU per container
+backend=backend,
+precision=get_torch_dtype_from_string(precision),
 )
-
-# Force cleanup and exit to prevent container reuse and memory leaks
+
+# Cleanup GPU cache before returning
 torch.cuda.empty_cache()
-
-return result # Never reached, but needed for type checking
+
+return result


 def fetch_ref_arch_from_problem_id(
@@ -321,6 +314,7 @@ def evaluate_single_sample(
 build_dir=build_dir,
 device=device,
 backend=configs.backend,
+precision=eval.get_torch_dtype_from_string(configs.precision),
 )
 return eval_result
 except Exception as e:
@@ -477,7 +471,8 @@ def batch_eval_modal(
 evaluator_cls = ModalEvaluator.with_options(gpu=config.gpu) if config.gpu != "A10G" else ModalEvaluator

 # Spawn all tasks in parallel
-# Each spawn creates a NEW container instance with a GPU
+# Modal assigns these to available containers (may reuse warm containers from previous batches)
+# To prevent GPU corruption spread, we sleep between batches to ensure containers scale down
 futures = []
 for item in work_items:
 if item is None:
@@ -491,6 +486,8 @@
 num_perf_trials=config.num_perf_trials,
 measure_performance=config.measure_performance,
 verbose=config.verbose,
+backend=config.backend,
+precision=config.precision,
 )
 futures.append(future)

@@ -531,7 +528,14 @@

 print("-" * 128)
 print(f"[Modal Batch] Evaluation took {end_time - start_time:.2f} seconds")
-
+
+# Wait for containers to scale down before next batch
+# This prevents container reuse and GPU corruption from spreading between batches
+if len(total_work) > 0: # Only sleep if there are more batches
+scaledown_wait = 10 # Wait 10 seconds (2x the scaledown_window) to ensure containers are killed
+print(f"[Modal] Waiting {scaledown_wait}s for containers to scale down before next batch...")
+time.sleep(scaledown_wait)
+
 pbar.update(len(curr_work_batch))

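The hunks above import `get_torch_dtype_from_string` from `src.eval` without showing its body. A plausible minimal version, assumed here only to make the precision plumbing concrete (the repo's helper may handle more cases):

```python
# Assumed sketch of the helper referenced above; not the repo's actual implementation.
import torch

def get_torch_dtype_from_string(precision: str) -> torch.dtype:
    # Map the config-level precision strings to torch dtypes.
    mapping = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
    if precision not in mapping:
        raise ValueError(f"Unsupported precision: {precision}")
    return mapping[precision]
```
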
scripts/generate_and_eval_single_sample.py

Lines changed: 35 additions & 8 deletions
@@ -18,7 +18,7 @@
 read_file,
 set_gpu_arch,
 )
-
+from src.eval import get_torch_dtype_from_string
 """
 Generate and evaluate a single sample
 Easiest way to get started, to test a single problem for experimentation or debugging
@@ -48,12 +48,18 @@ def __init__(self):
 # Construct this from mapping from architecture name to torch cuda arch list in the future
 # you can either specify SM version or just use the name
 self.gpu_arch = ["Ada"]
+self.precision = "fp32" # options ["fp32", "fp16", "bf16"]

 # Inference config
-self.server_type = "deepseek"
-self.model_name = "deepseek-coder"
-self.max_tokens = 4096
-self.temperature = 0.0
+self.server_type = None
+self.model_name = None
+self.max_tokens = None
+self.temperature = None
+
+# Reasoning model specific parameters
+self.is_reasoning_model = False # set to True for o1, o3, Gemini 2.5 thinking, etc.
+self.reasoning_effort = None # for o1/o3: "low", "medium", "high"
+self.budget_tokens = 0 # for Claude extended thinking mode

 # Logging
 self.logdir = os.path.join(REPO_TOP_DIR, "results/eval_logs")
@@ -81,6 +87,21 @@ def main(config: EvalConfig):
 """
 Keep it simple: Generate and evaluate a single sample
 """
+from src.utils import SERVER_PRESETS
+
+if config.server_type and config.server_type in SERVER_PRESETS:
+preset = SERVER_PRESETS[config.server_type]
+if config.model_name is None or config.model_name == "None":
+config.model_name = preset.get("model_name", "None")
+if config.max_tokens is None or config.max_tokens == "None":
+config.max_tokens = preset.get("max_tokens", "None")
+if config.temperature is None or config.temperature == "None":
+config.temperature = preset.get("temperature", "None")
+
+# Convert string boolean to actual boolean for reasoning model flag
+if isinstance(config.is_reasoning_model, str):
+config.is_reasoning_model = config.is_reasoning_model.lower() in ['true', '1', 'yes']
+
 print(f"Starting Eval with config: {config}")

 # Configurations
@@ -143,14 +164,19 @@ def main(config: EvalConfig):
 max_tokens=config.max_tokens,
 verbose=config.verbose,
 time_generation=True,
+is_reasoning_model=config.is_reasoning_model,
+reasoning_effort=config.reasoning_effort,
+budget_tokens=config.budget_tokens,
 )

 # Use appropriate prompt constructor based on backend
-if config.backend in ["cuda", "triton", "cute"]:
-custom_prompt = get_prompt_for_language(ref_arch_src, language=config.backend, option="few_shot")
+if config.backend == "cuda":
+custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
+elif config.backend in ["triton", "tilelang", "cute"]:
+custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
 else:
 raise ValueError(
-f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'."
+f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'."
 )

 if config.log_prompt:
@@ -194,6 +220,7 @@ def main(config: EvalConfig):
 num_correct_trials=5,
 num_perf_trials=100,
 backend=config.backend,
+precision=get_torch_dtype_from_string(config.precision),
 )

 print(

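The preset-fallback logic above pulls defaults from `SERVER_PRESETS` in `src.utils`, which this diff does not show. Purely as a hypothetical illustration of the dict shape that code expects (keys `model_name`, `max_tokens`, `temperature`; the repo's real presets will differ):

```python
# Hypothetical illustration only; the real SERVER_PRESETS lives in src/utils.py.
SERVER_PRESETS = {
    "anthropic": {  # example server_type key
        "model_name": "claude-sonnet-example",  # placeholder model string
        "max_tokens": 4096,
        "temperature": 0.0,
    },
}
# With server_type="anthropic" and model_name/max_tokens/temperature left unset,
# the fallback fills each field via preset.get(...), as in the hunk above.
```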