Commit 1f054c7
Author: Sokserey Sun
Parent: 044db0e

Cleaned up the toml file and added logic to use the toml

6 files changed: +202 -271 lines

scripts/generate_and_eval_single_sample.py
Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@
 from src.dataset import construct_kernelbench_dataset
 from src.eval import eval_kernel_against_ref
 from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
-from src.prompt_constructor_multilang import get_prompt_for_backend
+from src.prompt_constructor_multilang import get_prompt_for_language
 from src.utils import (
     create_inference_server_from_presets,
     extract_first_code,
@@ -147,7 +147,7 @@ def main(config: EvalConfig):
 
     # Use appropriate prompt constructor based on backend
     if config.backend in ["cuda", "triton", "cute"]:
-        custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
+        custom_prompt = get_prompt_for_language(ref_arch_src, language=config.backend, option="few_shot")
     else:
         raise ValueError(
             f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'."

scripts/generate_and_eval_single_sample_modal.py
Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@
 #from src.dataset import construct_kernelbench_dataset
 from src.eval import eval_kernel_against_ref
 from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
-from src.prompt_constructor_multilang import get_prompt_for_backend
+from src.prompt_constructor_multilang import get_prompt_for_language
 from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets
 
 app = modal.App("eval_single_sample")
@@ -193,7 +193,7 @@ def main(config: EvalConfig):
 
     # Use appropriate prompt constructor based on backend
     if config.backend in ["cuda", "triton", "cute"]:
-        custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
+        custom_prompt = get_prompt_for_language(ref_arch_src, language=config.backend, option="few_shot")
     else:
         raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'.")

scripts/generate_samples.py
Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@
 from src.dataset import construct_kernelbench_dataset
 from src.eval import eval_kernel_against_ref
 from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
-from src.prompt_constructor_multilang import get_prompt_for_backend
+from src.prompt_constructor_multilang import get_prompt_for_language
 from src.utils import (
     create_inference_server_from_presets,
     extract_first_code,
@@ -121,7 +121,7 @@ def generate_sample_single(
 
     # Construct Prompt
     if config.backend in ["cuda", "triton", "cute"]:
-        custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
+        custom_cuda_prompt = get_prompt_for_language(ref_arch_src, language=config.backend, option="few_shot")
     else:
         raise ValueError(
             f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'."

src/loader.py
Lines changed: 76 additions & 24 deletions

@@ -37,12 +37,6 @@ def compose_blocks(self, keys: List[str]) -> str:
             text_parts.append(node.strip() + "\n")
         return "\n".join(text_parts).strip() + "\n"
 
-    def get_template_node(self, backend: str, template: str) -> Dict[str, Any]:
-        try:
-            return self.data["backends"][backend]["templates"][template]
-        except KeyError as e:
-            raise KeyError(f"Unknown backend/template: {backend}/{template}") from e
-
 def _gpu_context_from_py(py_path: str, gpu_name: str) -> Dict[str, str]:
     """
     Load GPU_* dicts from a Python file (no exec of raw strings; use runpy).
@@ -79,39 +73,97 @@ def _gpu_context_from_py(py_path: str, gpu_name: str) -> Dict[str, str]:
         "gpu_best_practices_bullets": best_bullets,
     }
 
-def render_prompt(
+def render_prompt_by_option(
     *,
     prompts_toml: str,
-    backend: str,
-    template: str,
+    language: str,
+    option: str,
     context: Dict[str, str],
     gpu_specs_py: Optional[str] = None,
     gpu_name: Optional[str] = None,
 ) -> str:
+    """
+    New function that uses languages.X and options.Y structure
+
+    Args:
+        prompts_toml: Path to the prompts.toml file
+        language: The kernel language (triton, cuda, cute)
+        option: The prompt option (basic, few_shot, hardware_info, fix_compile, fix_correctness)
+        context: Variables to fill in the prompt template
+        gpu_specs_py: Optional path to GPU specs Python file
+        gpu_name: Optional GPU name (required if option requires_gpu)
+    """
     cfg = PromptConfig.from_toml(prompts_toml)
-    node = cfg.get_template_node(backend, template)
-
+
+    # Get language-specific content
+    try:
+        lang_data = cfg.data["languages"][language]
+    except KeyError:
+        raise KeyError(f"Unknown language: {language}")
+
+    # Get option configuration
+    try:
+        option_data = cfg.data["options"][option]
+    except KeyError:
+        raise KeyError(f"Unknown option: {option}")
+
+    # Get shared templates
+    shared = cfg.data.get("shared", {})
+    language_display = lang_data.get("language_display", language.upper())
+
+    # Fill in shared templates with language-specific terms
+    problem_statement = shared.get("problem_statement", "").format(language_display=language_display)
+    instruction = shared.get("instruction", "").format(language_display=language_display)
+
+    # Add language-specific content to context
+    context = {
+        **context,
+        "language": language.upper() if language in ["cuda", "cute"] else language.capitalize(),
+        "language_display": language_display,
+        "problem_statement": problem_statement,
+        "instruction": instruction,
+    }
+
     # Load example files if requested
-    if node.get("requires_example"):
-        ex_arch_path = _abs_path(node["example_arch_path"])
-        ex_new_path = _abs_path(node["example_new_arch_path"])
+    if option_data.get("requires_example"):
+        # Use language-specific example arch, or fall back to shared one
+        ex_arch_path = _abs_path(
+            lang_data.get("few_shot_example_arch") or shared.get("few_shot_example_arch")
+        )
+        ex_new_path = _abs_path(lang_data["few_shot_new_arch"])
         context = {
             **context,
             "example_arch_src": read_file(ex_arch_path),
             "example_new_arch_src": read_file(ex_new_path),
         }
-
-    # Load GPU details (from .py) if requested
-    if node.get("requires_gpu"):
+
+    # Load GPU details if requested
+    if option_data.get("requires_gpu"):
         if not (gpu_specs_py and gpu_name):
-            raise ValueError("Template requires GPU info; provide gpu_specs_py and gpu_name")
+            raise ValueError(f"Option '{option}' requires GPU info; provide gpu_specs_py and gpu_name")
         context = {**context, **_gpu_context_from_py(_abs_path(gpu_specs_py), gpu_name)}
-
-    # Compose & fill
-    compose_keys = node["compose"]
-    prompt_text = cfg.compose_blocks(compose_keys)
-
+
+    # Build the prompt from components
+    prompt_parts = []
+    for component in option_data["components"]:
+        if component == "problem_statement":
+            # Use the already-formatted problem_statement from context
+            prompt_parts.append(context["problem_statement"])
+        elif component == "instruction":
+            # Use the already-formatted instruction from context
+            prompt_parts.append(context["instruction"])
+        elif component.startswith("hardware_"):
+            # Hardware components from templates.hardware
+            template_key = f"templates.hardware.{component}"
+            prompt_parts.append(cfg.compose_blocks([template_key]))
+        else:
+            # Other components from templates.common
+            template_key = f"templates.common.{component}"
+            prompt_parts.append(cfg.compose_blocks([template_key]))
+
+    prompt_text = "\n".join(prompt_parts).strip() + "\n"
+
     try:
         return prompt_text.format(**context).strip() + "\n"
     except KeyError as e:
-        raise KeyError(f"Missing placeholder in context: {e.args[0]}") from e
+        raise KeyError(f"Missing placeholder in context: {e.args[0]}. Available: {list(context.keys())}") from e
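
For reference, the lookups in render_prompt_by_option imply a prompts.toml organized around [languages.*], [options.*], [shared], and [templates.*] tables. The sketch below is a hypothetical minimal file consistent with those lookups; the table names mirror the code above, but every path and template string is invented for illustration. It is parsed here with Python's stdlib tomllib (3.11+) just to show the shape the loader expects:

import tomllib  # stdlib TOML parser, Python 3.11+

# Hypothetical minimal prompts.toml; values are placeholders, not the repo's.
SAMPLE_TOML = """
[shared]
problem_statement = "You write custom {language_display} kernels to speed up PyTorch models."
instruction = "Answer with the complete new {language_display} model code in one code block."
few_shot_example_arch = "src/prompts/few_shot/model_ex_add.py"

[languages.triton]
language_display = "Triton"
few_shot_new_arch = "src/prompts/few_shot/model_new_ex_add_triton.py"

[options.few_shot]
requires_example = true
components = ["problem_statement", "few_shot_example", "instruction"]

[options.hardware_info]
requires_gpu = true
components = ["problem_statement", "hardware_overview", "instruction"]

[templates.common]
few_shot_example = "Example:\\n{example_arch_src}\\nbecomes:\\n{example_new_arch_src}\\nTask:\\n{ref_arch_src}"

[templates.hardware]
hardware_overview = "Target GPU: {gpu_name}."
"""

data = tomllib.loads(SAMPLE_TOML)
# These are the tables render_prompt_by_option reads:
assert data["languages"]["triton"]["language_display"] == "Triton"
assert data["options"]["few_shot"]["requires_example"] is True
print(data["options"]["few_shot"]["components"])
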
src/prompt_constructor_multilang.py
Lines changed: 54 additions & 45 deletions

@@ -1,34 +1,59 @@
-# src/prompts/prompt_constructor.py (public facade; keep old imports working)
+# src/prompt_constructor_multilang.py (new option-based prompt constructor)
 import os
-from .loader import render_prompt, _abs_path
+from .loader import render_prompt_by_option, _abs_path
 
 REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
 PROMPTS_TOML = _abs_path("src/prompts/prompts.toml")
 GPU_SPECS_PY = "src/prompts/hardware/gpu_specs.py"  # still a Python file
 
-def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str:
-    return render_prompt(
+def get_prompt_for_language(ref_arch_src: str, language: str = "triton", option: str = "few_shot") -> str:
+    """
+    Generate a prompt for a specific language and option.
+
+    Args:
+        ref_arch_src: The reference architecture source code
+        language: The kernel language (triton, cuda, cute)
+        option: The prompt option (basic, few_shot, hardware_info)
+    """
+    return render_prompt_by_option(
         prompts_toml=PROMPTS_TOML,
-        backend=backend.lower(),
-        template="default",
+        language=language.lower(),
+        option=option,
         context={"ref_arch_src": ref_arch_src},
     )
 
-def get_prompt_with_hardware(ref_arch_src: str, backend: str, gpu_name: str) -> str:
-    return render_prompt(
+def get_prompt_with_hardware(ref_arch_src: str, language: str, gpu_name: str) -> str:
+    """
+    Generate a hardware-aware prompt for a specific language.
+
+    Args:
+        ref_arch_src: The reference architecture source code
+        language: The kernel language (triton, cuda, cute)
+        gpu_name: The name of the GPU (e.g., "A100", "H100")
+    """
+    return render_prompt_by_option(
         prompts_toml=PROMPTS_TOML,
-        backend=backend.lower(),
-        template="with_hardware",
+        language=language.lower(),
+        option="hardware_info",
         context={"ref_arch_src": ref_arch_src},
-        gpu_specs_py=GPU_SPECS_PY,  # <-- python file, not TOML
+        gpu_specs_py=GPU_SPECS_PY,
         gpu_name=gpu_name,
     )
 
-def prompt_fix_compile(backend: str, ref_arch_src: str, custom_kernel: str, metadata: str) -> str:
-    return render_prompt(
+def prompt_fix_compile(language: str, ref_arch_src: str, custom_kernel: str, metadata: str) -> str:
+    """
+    Generate a prompt to fix compilation errors.
+
+    Args:
+        language: The kernel language (triton, cuda, cute)
+        ref_arch_src: The reference architecture source code
+        custom_kernel: The custom kernel code that failed
+        metadata: Compilation error metadata
+    """
+    return render_prompt_by_option(
         prompts_toml=PROMPTS_TOML,
-        backend=backend.lower(),
-        template="fix_compile",
+        language=language.lower(),
+        option="fix_compile",
         context={
             "ref_arch_src": ref_arch_src,
             "custom_kernel": custom_kernel,
@@ -37,11 +62,20 @@ def prompt_fix_compile(backend: str, ref_arch_src: str, custom_kernel: str, meta
         },
     )
 
-def prompt_fix_correctness(backend: str, ref_arch_src: str, custom_kernel: str, metadata: str) -> str:
-    return render_prompt(
+def prompt_fix_correctness(language: str, ref_arch_src: str, custom_kernel: str, metadata: str) -> str:
+    """
+    Generate a prompt to fix correctness errors.
+
+    Args:
+        language: The kernel language (triton, cuda, cute)
+        ref_arch_src: The reference architecture source code
+        custom_kernel: The custom kernel code that failed
+        metadata: Correctness error metadata
+    """
+    return render_prompt_by_option(
         prompts_toml=PROMPTS_TOML,
-        backend=backend.lower(),
-        template="fix_correctness",
+        language=language.lower(),
+        option="fix_correctness",
         context={
             "ref_arch_src": ref_arch_src,
             "custom_kernel": custom_kernel,
@@ -50,34 +84,9 @@ def prompt_fix_correctness(backend: str, ref_arch_src: str, custom_kernel: str,
         },
     )
 
-# Optional legacy convenience wrappers (if callers use backend-specific names)
-def prompt_fix_compile_triton(ref_arch_src, custom_kernel, metadata):
-    return prompt_fix_compile("triton", ref_arch_src, custom_kernel, metadata)
-
-def prompt_fix_correctness_triton(ref_arch_src, custom_kernel, metadata):
-    return prompt_fix_correctness("triton", ref_arch_src, custom_kernel, metadata)
-
-def prompt_fix_compile_cute(ref_arch_src, custom_kernel, metadata):
-    return prompt_fix_compile("cute", ref_arch_src, custom_kernel, metadata)
-
-def prompt_fix_correctness_cute(ref_arch_src, custom_kernel, metadata):
-    return prompt_fix_correctness("cute", ref_arch_src, custom_kernel, metadata)
-
-def prompt_fix_compile_cuda(ref_arch_src, custom_kernel, metadata):
-    return prompt_fix_compile("cuda", ref_arch_src, custom_kernel, metadata)
-
-def prompt_fix_correctness_cuda(ref_arch_src, custom_kernel, metadata):
-    return prompt_fix_correctness("cuda", ref_arch_src, custom_kernel, metadata)
-
 __all__ = [
-    "get_prompt_for_backend",
+    "get_prompt_for_language",
     "get_prompt_with_hardware",
     "prompt_fix_compile",
     "prompt_fix_correctness",
-    "prompt_fix_compile_triton",
-    "prompt_fix_correctness_triton",
-    "prompt_fix_compile_cute",
-    "prompt_fix_correctness_cute",
-    "prompt_fix_compile_cuda",
-    "prompt_fix_correctness_cuda",
 ]
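
With the rename, callers use one facade function per task. A minimal usage sketch based on the signatures in this diff; the ref_arch_src string here is a placeholder, not a real KernelBench architecture:

from src.prompt_constructor_multilang import (
    get_prompt_for_language,
    get_prompt_with_hardware,
)

# Placeholder reference architecture; real callers read this from the dataset.
ref_arch_src = "import torch\nclass Model(torch.nn.Module): ..."

# Default few-shot prompt -- the call the updated scripts now make:
prompt = get_prompt_for_language(ref_arch_src, language="triton", option="few_shot")

# Hardware-aware prompt; needs specs for the named GPU in gpu_specs.py:
hw_prompt = get_prompt_with_hardware(ref_arch_src, language="cuda", gpu_name="H100")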

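get_prompt_with_hardware additionally depends on src/prompts/hardware/gpu_specs.py, which _gpu_context_from_py loads via runpy to pick up GPU_* dicts. A guess at that file's shape, purely illustrative; the actual dict names, GPUs, and fields in the repo may differ:

# Hypothetical sketch of src/prompts/hardware/gpu_specs.py. The loader only
# promises that it defines GPU_* dicts loadable via runpy; everything below
# is an assumption for illustration.
GPU_SPEC_INFO = {
    "H100": {"memory_bandwidth_gb_s": 3350, "shared_memory_per_sm_kb": 228},
    "A100": {"memory_bandwidth_gb_s": 2039, "shared_memory_per_sm_kb": 164},
}

GPU_BEST_PRACTICES = [
    "Coalesce global memory accesses",
    "Avoid shared-memory bank conflicts",
    "Keep occupancy high enough to hide memory latency",
]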