The Huggingface [dataset](https://huggingface.co/datasets/ScalingIntelligence/KernelBench) is updated to v0.1.

<!-- See [blog post](https://scalingintelligence.stanford.edu/blogs/kernelbench/) and [arXiv paper](https://arxiv.org/html/2502.10517v1) for more details. -->

This repo provides core functionality for KernelBench and an easy-to-use set of scripts for evaluation. It is not intended to provide complex agentic scaffolds that solve this task; we recommend cloning and modifying this repo for your experiment, or using it as a git submodule.
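As a rough sketch of those two workflows (the repository URL follows the project's GitHub organization; the submodule path is only illustrative):

```
# Option 1: clone the repo and modify it directly for your experiment
git clone https://github.com/ScalingIntelligence/KernelBench.git

# Option 2: pin KernelBench inside your own project as a git submodule
git submodule add https://github.com/ScalingIntelligence/KernelBench.git third_party/KernelBench
```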
## 👋 Task Description
We structure the problem for the LLM to transpile operators described in PyTorch into CUDA kernels, at whatever level of granularity it desires.

We construct KernelBench to have 4 Levels of categories:

**Level 4 🤗**: Level Hugging Face

Optimize whole model architectures from HuggingFace

We are actively extending KernelBench to other DSLs beyond `cuda` as well (see below).
## ⚖️ Evaluation
#### Methodology
To evaluate model-generated kernels, we need to check that they are correct and measure how fast they run.

Check out `src/eval.py` for details on how we implement the correctness check and timing.

We provide a convenient script `scripts/run_and_check.py` to evaluate a single sample source file against a reference source file: it checks correctness and computes the speedup. You can use it to evaluate a model-generated kernel either locally or remotely by setting `eval_mode=local` or `eval_mode=modal`.
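For example, a hedged sketch of an invocation (the path argument names `ref_arch_src_path` and `kernel_src_path` are illustrative placeholders and may differ from the script's actual config fields):

```
# check a model-generated kernel against its PyTorch reference locally;
# switch to eval_mode=modal to run the same check on a Modal GPU
python3 scripts/run_and_check.py eval_mode=local \
    ref_arch_src_path=path/to/reference_module.py \
    kernel_src_path=path/to/generated_kernel.py
```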
#### Overall Benchmark Metric
```
pip install -r requirements.txt
pip install -e .
```
We use `litellm` for API calls. Please set your keys by creating a `.env` following our `.env.example`.
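For example, a hypothetical `.env` (the exact key names depend on which providers you call through `litellm`; `.env.example` lists the supported ones):

```
# .env: only include the providers you actually use
OPENAI_API_KEY=sk-...
ANTHROPIC_API_KEY=sk-ant-...
TOGETHER_API_KEY=...
```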
Running and profiling kernels requires a GPU.
If you don't have a GPU available locally, you can set up [Modal](https://modal.com/). After creating an account, set up your Modal token by running `modal token new`. Then, use the `generate_and_eval_single_sample_modal.py` script.
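A minimal sketch of the Modal path, assuming the Modal script accepts the same dataset and problem config fields as its local counterpart (the `level` and `problem_id` values here are only examples):

```
pip install modal    # install the Modal client
modal token new      # authenticate after creating a Modal account
# generate and evaluate one problem on a remote Modal GPU
python3 scripts/generate_and_eval_single_sample_modal.py dataset_src="huggingface" level=1 problem_id=1
```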
```
# generate a kernel for one problem with an LLM and evaluate it locally
python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" level=...
# add .verbose_logging for more visibility
```
**What you might need to modify**

* **`gpu_arch`** - Depending on your GPU, you might need to adjust the `gpu_arch` argument to reflect your hardware.
* **`precision`** - You can specify the tensor precision with `precision=fp32`. Currently all of our reported results are `fp32`, but we have added support for `fp16` and `bf16`.
* **`backend`** - We also support other GPU programming languages beyond `cuda`; simply specify, for example, `backend=triton`. For now we support these DSLs: `cuda`, `triton`, `cute`, `tilelang`.

Check the config fields for a comprehensive set of options.
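For instance, a hedged example combining these fields (the `problem_id` field name, the problem choice, and the `gpu_arch` value syntax are assumptions; adjust them to your hardware and experiment):

```
# generate and evaluate one problem with a Triton kernel at bf16,
# targeting a specific GPU architecture
python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" level=1 problem_id=1 \
    backend=triton precision=bf16 gpu_arch='["Hopper"]'
```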
### Run on all problems
If you are using different hardware, you can generate the baseline time yourself with the provided timing script.

We provide reference baseline times for a variety of NVIDIA GPUs across generations in `results/timing`, but we recommend generating your own baseline times for more accurate results (cluster power limits, software versions, and other factors all affect timing). See `results/timing/README.md` for more details.
### Multi-Turn Framework
We have also released the test-time framework [Caesar](https://github.com/ScalingIntelligence/caesar), which is used in the multi-turn / iterative refinement experiments in our paper. You can use or modify this framework for high-throughput test-time scaling (both sequential and parallel) targeting KernelBench problems.
## 🛣️ Upcoming Roadmap
Check out our [roadmap](https://github.com/ScalingIntelligence/KernelBench/issues/74) for what we plan to add as features. We welcome community contributions in these directions.

---

**`results/timing/README.md`**

This folder contains a set of baseline timing results for the KernelBench problems.

Since KernelBench measures the speedup between Runtime(reference architecture) and Runtime(LLM-generated architecture), it is important to measure the baseline reference module runtime.

We have provided a set of baseline results for the KernelBench problems on a variety of hardware as well as various PyTorch configurations.

All (current) baselines were run with PyTorch `2.5.0+cu124` and CUDA `12.4`.

Note: we will update these soon with PyTorch `2.9.0` and CUDA `12.8`.
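As a quick sanity check (not part of the benchmark itself), you can print your locally installed versions to compare against the baseline configuration:

```
python3 -c "import torch; print(torch.__version__, torch.version.cuda)"
```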
For timing, we measure wall clock time. We warm up 3 times and collect runtime statistics for 100 trials.