
Commit e395e91

Authored by PaliC, AffectionateCurry, nathanjpaek, and simonguozirui

Add Triton + CuTe Backend, enable more DSL support (#35)

* triton_backend_v2
* fix eval bugs
* fix issues
* revert eval
* remove traceback
* remove cot
* improve eval
* looked over PR and added future support for other languages
* updated requirements
* added back requirements.txt
* add cute one-shot addition example
* remove unnecessary files and redo requirements
* let's see if that fixes it
* fix config in file suggested by soksoerey
* move Natalia's old file into the change log

Co-authored-by: AffectionateCurry <[email protected]>
Co-authored-by: nathanjpaek <[email protected]>
Co-authored-by: Simon Guo <[email protected]>
1 parent 94ab208 commit e395e91

12 files changed: +1120 −135 lines
File renamed without changes.

KernelBench/test.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

README.md

Lines changed: 7 additions & 17 deletions

@@ -26,6 +26,8 @@ We construct KernelBench to have 4 Levels of categories:
 - **Level 4 🤗**: Level Hugging Face
 Optimize whole model architectures from HuggingFace
 
+We are actively extending KernelBench to other DSLs beyond `cuda` as well.
+
 ## ⚖️ Evaluation
 #### Methodology
 To evaluate model-generated kernels, we need to check if they:
@@ -47,6 +49,7 @@ Some examples to illustrate this metric that filters based on speedups:
 
 You can increase speedup threshold `p` to make the task more challenging.
 
+
 #### Compute Overall Benchmark Performance
 
 We provide a script `scripts/greedy_analysis.py` to compute the overall benchmark performance.
@@ -95,6 +98,8 @@ python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" lev
 # add .verbose_logging for more visibility
 ```
 
+We also support GPU programming languages beyond `cuda`: simply specify one, for example `backend=triton`. The currently supported backends are `cuda`, `triton`, and `cute`.
+
 ### Run on all problems
 
 ```
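To make the new flag concrete, here is a minimal sketch of launching the single-sample script with the Triton backend from Python. This is an illustration, not part of the diff: the `level` and `problem_id` values are placeholders, and other overrides the README mentions (such as inference-server settings) are omitted.

```
# Sketch: invoke the single-sample script with the backend override added in this commit.
# Assumes it runs from the repository root; level/problem_id are placeholder values.
import subprocess

subprocess.run(
    [
        "python3",
        "scripts/generate_and_eval_single_sample.py",
        "dataset_src=huggingface",
        "level=1",
        "problem_id=1",
        "backend=triton",  # or "cuda" / "cute"
    ],
    check=True,
)
```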
@@ -120,25 +125,10 @@ We provide some reference baseline times a variety of NVIDIA GPUs across generat
 We have also released the test-time framework [Caesar](https://github.com/simonguozirui/caesar) that is used in the multi-turn / iterative refinement experiments in our paper. You can use or modify this framework for high-throughput test-time scaling (both sequential and parallel) targeting KernelBench problems.
 
 ## 🛣️ Upcoming Roadmap
-- [ ] Triton Variant (To be merged)
-- [ ] Easy to use CoLab Notebook Example
-- [ ] Push button flow on Modal / Cloud Provider
-- [ ] Integrate with more frameworks, such as [ThunderKittens](https://github.com/HazyResearch/ThunderKittens)
-- [ ] Add backward pass
-- [ ] Integrate with toolchains such as NCU
-See Issues for the ongoing roadmap and directions.
-
-
+Check out our [roadmap](https://github.com/ScalingIntelligence/KernelBench/issues/74) for what we plan to add as features. We welcome community contributions in these directions.
 
 ## 🔍 Known Usage
-- [NVIDIA](https://developer.nvidia.com/blog/automating-gpu-kernel-generation-with-deepseek-r1-and-inference-time-scaling/) - Automating GPU Kernel Generation with DeepSeek-R1 and Inference Time Scaling
-- [METR](https://metr.org/blog/2025-02-14-measuring-automated-kernel-engineering/) - Measuring Automated Kernel Engineering
-- [Sakana AI](https://sakana.ai/ai-cuda-engineer/) - AI Cuda Engineer
-- [Project Popcorn](https://www.youtube.com/watch?v=mdDVkBeFy9A) - Triton Support for KernelBench, Data Scaling + SFT'd Kernel LLM
-- [Kevin](https://cognition.ai/blog/kevin-32b) - Kevin-32B: Multi-Turn RL for Writing CUDA Kernels
-- [Simple Test-Time Search](https://scalingintelligence.stanford.edu/blogs/fastkernels/) - by @anneouyang
-
-If you are using KernelBench, we love to hear more about it!
+Since release, we have gotten a lot of interest from researchers, research labs, and companies that use KernelBench to explore this direction. We have documented [known usage](https://docs.google.com/document/d/e/2PACX-1vTjS-UMH1HB5n_PENq2k-3YRfXIXkqKIKeNC2zcWMyLPdl4Jrwvdk4dNDVSsM8ybKrCxZB7GJq1slZF/pub) of KernelBench and related efforts towards automated kernel generation. If you are using KernelBench, we would love to hear more about it!
 
 ## 🪪 License
 MIT. Check `LICENSE.md` for more details.

requirements.txt

Lines changed: 22 additions & 11 deletions

@@ -1,17 +1,28 @@
-anthropic
+# Frameworks
+torch==2.5.0
+# we shall upgrade torch for blackwell when it is stable
+transformers
+datasets
 modal
-numpy
-openai
+
+# DSLs
+nvidia-cutlass-dsl
+
+# helper
+tqdm
 packaging
 pydra_config
-torch==2.5.0
-tqdm
-datasets
-transformers
-google-generativeai
-together
 pytest
 ninja
-archon-ai
+
+# Numerics
 einops
-dotenv
+dotenv
+numpy
+
+# to deprecate with litellm
+google-generativeai
+together
+openai
+anthropic
+
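As a quick sanity check that the reorganized dependency set resolves in a given environment, here is a small stdlib-only sketch. It is an illustration, not part of the diff; the package names are copied from the file above, and nothing beyond the `torch==2.5.0` pin is assumed.

```
# Sketch: report installed versions of the dependencies listed above.
# Uses only the standard library; missing packages are flagged rather than raising.
from importlib.metadata import version, PackageNotFoundError

PACKAGES = [
    "torch", "transformers", "datasets", "modal", "nvidia-cutlass-dsl",
    "tqdm", "packaging", "pydra_config", "pytest", "ninja",
    "einops", "dotenv", "numpy",
]

for pkg in PACKAGES:
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```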

scripts/eval_from_generations.py

Lines changed: 14 additions & 4 deletions

@@ -3,6 +3,8 @@
 import os
 import shutil
 import time
+from dataclasses import dataclass
+
 from collections import defaultdict
 from dataclasses import dataclass
 
@@ -12,15 +14,19 @@
 
 from datasets import load_dataset
 from pydra import Config, REQUIRED
+
+# Import only what we need
 from src import compile, eval, utils
 
 from src.dataset import construct_kernelbench_dataset
 from src.eval import (
     build_compile_cache,
+    get_error_name,
     check_metadata_serializable_all_types,
     eval_kernel_against_ref,
     KernelExecResult,
 )
+
 from src.utils import read_file, set_gpu_arch
 from tqdm import tqdm
 
@@ -137,6 +143,8 @@ def __init__(self):
         # number of GPUs to do batch evaluation
         self.num_gpu_devices = 1
 
+        # Backend to use for kernel implementation (cuda or triton)
+        self.backend = "cuda"
         # Number of samples per problem to evaluate for pass@k analysis
         self.num_samples_per_problem = 1 # Default to 1 sample per problem
 
@@ -312,6 +320,7 @@ def evaluate_single_sample(
             num_perf_trials=configs.num_perf_trials,
             build_dir=build_dir,
             device=device,
+            backend=configs.backend,
         )
         return eval_result
     except Exception as e:
@@ -322,6 +331,7 @@
             # NOTE: count this as compilation failure as it is not runnable code
             metadata = {
                 "cuda_error": f"CUDA Error: {str(e)}",
+                "cuda_error_name": get_error_name(e),
                 "hardware": torch.cuda.get_device_name(device=device),
                 "device": str(device),
             } # log this for debugging as this usually signifies illegal memory access
@@ -332,6 +342,7 @@
         else:
             metadata = {
                 "other_error": f"error: {str(e)}",
+                "other_error_name": get_error_name(e),
                 "hardware": torch.cuda.get_device_name(device=device),
                 "device": str(device),
             } # for debugging
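The two hunks above attach `get_error_name(e)` to the eval metadata. The helper is imported from `src.eval` and its implementation is not shown in this diff; the following is a plausible minimal sketch, included purely as an assumption for illustration.

```
# Hypothetical sketch of an error-name helper like the imported get_error_name.
# Assumption: it returns the exception's class name so eval metadata can be
# grouped by error type (e.g. "RuntimeError", "TimeoutError").
def get_error_name(e: Exception) -> str:
    return type(e).__name__

# Usage mirroring the metadata built above.
try:
    raise RuntimeError("CUDA error: an illegal memory access was encountered")
except Exception as e:
    metadata = {
        "cuda_error": f"CUDA Error: {str(e)}",
        "cuda_error_name": get_error_name(e),  # -> "RuntimeError"
    }
    print(metadata)
```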
@@ -387,10 +398,9 @@ def cuda_single_eval_wrapper(curr_work: WorkArgs, configs: dict, dataset, run_di
         pool.terminate()
         pool.join()
         raise
-    except mp.TimeoutError:
+    except mp.TimeoutError as e:
         print(
-            f"[WARNING] Evaluation TIMED OUT for Problem ID: {curr_work.problem_id},"
-            f" Sample ID: {curr_work.sample_id}"
+            f"[WARNING] Evaluation TIMED OUT for Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}\nException: {e}"
         )
 
     print(
@@ -691,7 +701,7 @@ def add_to_eval_results_file(
     os.makedirs(os.path.dirname(eval_file_path), exist_ok=True)
 
     with open(eval_file_path, "w") as f:
-        json.dump(eval_results, f)
+        json.dump(eval_results, f, indent=4)
 
 
 def single_eval_example(

scripts/generate_and_eval_single_sample.py

Lines changed: 91 additions & 36 deletions

@@ -3,13 +3,21 @@
 import os, sys
 import torch
 import json
+import modal
 
 from datasets import load_dataset
 
 from src.dataset import construct_kernelbench_dataset
 from src.eval import eval_kernel_against_ref
 from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
-from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets
+from src.prompt_constructor_multilang import get_prompt_for_backend
+from src.utils import (
+    create_inference_server_from_presets,
+    extract_first_code,
+    query_server,
+    read_file,
+    set_gpu_arch,
+)
 
 """
 Generate and evaluate a single sample
@@ -20,15 +28,15 @@
 
 torch.set_printoptions(precision=4, threshold=10)
 
+
 class EvalConfig(Config):
     def __init__(self):
-
-        self.dataset_src = REQUIRED # either huggingface or local
+
+        self.dataset_src = REQUIRED  # either huggingface or local
 
         # name of dataset name on Hugging Face
         self.dataset_name = "ScalingIntelligence/KernelBench"
 
-
         # Problem Specification
         self.level = REQUIRED
         # NOTE: this is the logical index (problem id the problem_name)\
@@ -56,6 +64,8 @@ def __init__(self):
         self.log_generated_kernel = False
         self.log_eval_result = False
 
+        self.backend = "cuda"
+
     def verbose_logging(self):
         self.log = True
         self.log_prompt = True
@@ -86,24 +96,31 @@ def main(config: EvalConfig):
 
     if config.log:
         os.makedirs(config.logdir, exist_ok=True)
-
+
     # Problem Checks
     num_problems = len(curr_level_dataset)
     print(f"Number of problems in Level {config.level}: {num_problems}")
-    print(f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}")
-
-    assert config.problem_id <= num_problems, f"Problem ID {config.problem_id} out of range for Level {config.level}"
+    print(
+        f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}"
+    )
 
+    assert (
+        config.problem_id <= num_problems
+    ), f"Problem ID {config.problem_id} out of range for Level {config.level}"
 
     # 1. Fetch Problem
     if config.dataset_src == "huggingface":
 
-        curr_problem_row = curr_level_dataset.filter(lambda x: x["problem_id"] == config.problem_id)
+        curr_problem_row = curr_level_dataset.filter(
+            lambda x: x["problem_id"] == config.problem_id
+        )
         ref_arch_src = curr_problem_row["code"][0]
        problem_name = curr_problem_row["name"][0]
 
     elif config.dataset_src == "local":
-        problem_idx_in_dataset = config.problem_id - 1 # due to dataset list being 0-indexed locally
+        problem_idx_in_dataset = (
+            config.problem_id - 1
+        ) # due to dataset list being 0-indexed locally
         ref_arch_path = curr_level_dataset[problem_idx_in_dataset]
 
         problem_name = os.path.basename(ref_arch_path)
@@ -112,52 +129,90 @@ def main(config: EvalConfig):
 
     # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py")
     problem_number = int(problem_name.split("_")[0])
-    assert problem_number == config.problem_id, f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})"
-
-
+    assert (
+        problem_number == config.problem_id
+    ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})"
+
     # 2. Generate Sample
     # Create inference function with config parameters
     # We provide some presets in utils but you can also pass in your own, see query_server for more details
-    inference_server = create_inference_server_from_presets(server_type=config.server_type,
-        model_name=config.model_name,
-        temperature=config.temperature,
-        max_tokens=config.max_tokens,
-        verbose=config.verbose,
-        time_generation=True)
-
+    inference_server = create_inference_server_from_presets(
+        server_type=config.server_type,
+        model_name=config.model_name,
+        temperature=config.temperature,
+        max_tokens=config.max_tokens,
+        verbose=config.verbose,
+        time_generation=True,
+    )
 
+    # Use appropriate prompt constructor based on backend
+    if config.backend == "cuda":
+        custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
+    elif config.backend in ["triton", "cute"]: # removed "tilelang"
+        custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
+    else:
+        raise ValueError(
+            f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'."
+        )
 
-    custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
     if config.log_prompt:
-        with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f:
-            f.write(custom_cuda_prompt)
+        with open(
+            os.path.join(
+                config.logdir,
+                f"prompt_level_{config.level}_problem_{config.problem_id}.txt",
+            ),
+            "w",
+        ) as f:
+            f.write(custom_prompt)
 
     # Query server with constructed prompt
-    custom_cuda = inference_server(custom_cuda_prompt)
-    custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"])
-    # check LLM is able to generate custom CUDA code
-    assert custom_cuda is not None, "Custom CUDA code generation failed"
-
+    custom_kernel = inference_server(custom_prompt)
+    custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"])
+
+    # check LLM is able to generate custom kernel code
+    assert (
+        custom_kernel is not None
+    ), f"Custom {config.backend} kernel code generation failed"
+
     # this should be optional
     if config.log:
-        with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f:
-            f.write(custom_cuda)
+        with open(
+            os.path.join(
+                config.logdir,
+                f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py",
+            ),
+            "w",
+        ) as f:
+            f.write(custom_kernel)
 
     # 3. Evaluate Kernel
     # NOTE: no need to wrap around process here as only a single sample
     # see batch eval for examples of process isolation
     kernel_exec_result = eval_kernel_against_ref(
-        ref_arch_src, custom_cuda, verbose=config.verbose, measure_performance=True, num_correct_trials=5, num_perf_trials=100
+        ref_arch_src,
+        custom_kernel,
+        verbose=config.verbose,
+        measure_performance=True,
+        num_correct_trials=5,
+        num_perf_trials=100,
+        backend=config.backend,
+    )
+
+    print(
+        f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}"
    )
-
-    print(f"Evaluation result for level {config.level} problem {config.problem_id}:\n{kernel_exec_result}")
 
     if config.log:
-        with open(os.path.join(config.logdir, f"eval_result_level_{config.level}_problem_{config.problem_id}.txt"), "a") as f:
+        with open(
+            os.path.join(
+                config.logdir,
+                f"eval_result_level_{config.level}_problem_{config.problem_id}.txt",
+            ),
+            "a",
+        ) as f:
            f.write(f"Problem Name: {problem_name}\n")
            f.write(str(kernel_exec_result))
 
 
 if __name__ == "__main__":
-    main()
-
+    main()
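Putting the pieces of this diff together, generation and evaluation are now backend-aware end to end. The following is a condensed sketch of that flow, not new API: every function name and argument comes from the diff above, while the `generate_and_eval` wrapper itself and its defaults are illustrative assumptions.

```
# Condensed sketch of the backend-aware flow introduced in this file.
# ref_arch_src and inference_server are assumed to be set up as in main().
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.prompt_constructor_multilang import get_prompt_for_backend
from src.utils import extract_first_code


def generate_and_eval(ref_arch_src: str, inference_server, backend: str = "triton"):
    # 1. Pick the prompt constructor for the requested backend
    if backend == "cuda":
        prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
    elif backend in ["triton", "cute"]:
        prompt = get_prompt_for_backend(ref_arch_src, backend)
    else:
        raise ValueError(f"Unsupported backend: {backend}")

    # 2. Query the LLM and extract the first code block from its response
    custom_kernel = extract_first_code(inference_server(prompt), ["python", "cpp"])
    assert custom_kernel is not None, f"Custom {backend} kernel code generation failed"

    # 3. Evaluate the generated kernel against the reference architecture
    return eval_kernel_against_ref(
        ref_arch_src,
        custom_kernel,
        verbose=False,
        measure_performance=True,
        num_correct_trials=5,
        num_perf_trials=100,
        backend=backend,
    )
```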
