- # src/prompts/prompt_constructor.py (public facade; keep old imports working)
+ # src/prompt_constructor_multilang.py (new option-based prompt constructor)
import os
- from .loader import render_prompt, _abs_path
+ from .loader import render_prompt_by_option, _abs_path

REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
PROMPTS_TOML = _abs_path("src/prompts/prompts.toml")
GPU_SPECS_PY = "src/prompts/hardware/gpu_specs.py"  # still a Python file

- def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str:
-     return render_prompt(
+ def get_prompt_for_language(ref_arch_src: str, language: str = "triton", option: str = "few_shot") -> str:
+     """
+     Generate a prompt for a specific language and option.
+
+     Args:
+         ref_arch_src: The reference architecture source code
+         language: The kernel language (triton, cuda, cute)
+         option: The prompt option (basic, few_shot, hardware_info)
+     """
+     return render_prompt_by_option(
        prompts_toml=PROMPTS_TOML,
-         backend=backend.lower(),
-         template="default",
+         language=language.lower(),
+         option=option,
        context={"ref_arch_src": ref_arch_src},
    )

- def get_prompt_with_hardware(ref_arch_src: str, backend: str, gpu_name: str) -> str:
-     return render_prompt(
+ def get_prompt_with_hardware(ref_arch_src: str, language: str, gpu_name: str) -> str:
+     """
+     Generate a hardware-aware prompt for a specific language.
+
+     Args:
+         ref_arch_src: The reference architecture source code
+         language: The kernel language (triton, cuda, cute)
+         gpu_name: The name of the GPU (e.g., "A100", "H100")
+     """
+     return render_prompt_by_option(
        prompts_toml=PROMPTS_TOML,
-         backend=backend.lower(),
-         template="with_hardware",
+         language=language.lower(),
+         option="hardware_info",
        context={"ref_arch_src": ref_arch_src},
-         gpu_specs_py=GPU_SPECS_PY,  # <-- python file, not TOML
+         gpu_specs_py=GPU_SPECS_PY,
        gpu_name=gpu_name,
    )

- def prompt_fix_compile(backend: str, ref_arch_src: str, custom_kernel: str, metadata: str) -> str:
-     return render_prompt(
+ def prompt_fix_compile(language: str, ref_arch_src: str, custom_kernel: str, metadata: str) -> str:
+     """
+     Generate a prompt to fix compilation errors.
+
+     Args:
+         language: The kernel language (triton, cuda, cute)
+         ref_arch_src: The reference architecture source code
+         custom_kernel: The custom kernel code that failed
+         metadata: Compilation error metadata
+     """
+     return render_prompt_by_option(
        prompts_toml=PROMPTS_TOML,
-         backend=backend.lower(),
-         template="fix_compile",
+         language=language.lower(),
+         option="fix_compile",
        context={
            "ref_arch_src": ref_arch_src,
            "custom_kernel": custom_kernel,
@@ -37,11 +62,20 @@ def prompt_fix_compile(backend: str, ref_arch_src: str, custom_kernel: str, meta
        },
    )

- def prompt_fix_correctness(backend: str, ref_arch_src: str, custom_kernel: str, metadata: str) -> str:
-     return render_prompt(
+ def prompt_fix_correctness(language: str, ref_arch_src: str, custom_kernel: str, metadata: str) -> str:
+     """
+     Generate a prompt to fix correctness errors.
+
+     Args:
+         language: The kernel language (triton, cuda, cute)
+         ref_arch_src: The reference architecture source code
+         custom_kernel: The custom kernel code that failed
+         metadata: Correctness error metadata
+     """
+     return render_prompt_by_option(
        prompts_toml=PROMPTS_TOML,
-         backend=backend.lower(),
-         template="fix_correctness",
+         language=language.lower(),
+         option="fix_correctness",
        context={
            "ref_arch_src": ref_arch_src,
            "custom_kernel": custom_kernel,
@@ -50,34 +84,9 @@ def prompt_fix_correctness(backend: str, ref_arch_src: str, custom_kernel: str,
        },
    )

- # Optional legacy convenience wrappers (if callers use backend-specific names)
- def prompt_fix_compile_triton(ref_arch_src, custom_kernel, metadata):
-     return prompt_fix_compile("triton", ref_arch_src, custom_kernel, metadata)
-
- def prompt_fix_correctness_triton(ref_arch_src, custom_kernel, metadata):
-     return prompt_fix_correctness("triton", ref_arch_src, custom_kernel, metadata)
-
- def prompt_fix_compile_cute(ref_arch_src, custom_kernel, metadata):
-     return prompt_fix_compile("cute", ref_arch_src, custom_kernel, metadata)
-
- def prompt_fix_correctness_cute(ref_arch_src, custom_kernel, metadata):
-     return prompt_fix_correctness("cute", ref_arch_src, custom_kernel, metadata)
-
- def prompt_fix_compile_cuda(ref_arch_src, custom_kernel, metadata):
-     return prompt_fix_compile("cuda", ref_arch_src, custom_kernel, metadata)
-
- def prompt_fix_correctness_cuda(ref_arch_src, custom_kernel, metadata):
-     return prompt_fix_correctness("cuda", ref_arch_src, custom_kernel, metadata)
-
__all__ = [
-     "get_prompt_for_backend",
+     "get_prompt_for_language",
    "get_prompt_with_hardware",
    "prompt_fix_compile",
    "prompt_fix_correctness",
-     "prompt_fix_compile_triton",
-     "prompt_fix_correctness_triton",
-     "prompt_fix_compile_cute",
-     "prompt_fix_correctness_cute",
-     "prompt_fix_compile_cuda",
-     "prompt_fix_correctness_cuda",
]
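
For reference, a minimal usage sketch of the new option-based API follows. The import path is an assumption (the relative ".loader" import suggests the module lives in the src/prompts package), and the placeholder strings stand in for a real reference architecture source, generated kernel, and compiler log; none of this is part of the diff above.

# Usage sketch only; import path and placeholder values are assumptions.
from src.prompts.prompt_constructor_multilang import (
    get_prompt_for_language,
    get_prompt_with_hardware,
    prompt_fix_compile,
)

ref_arch_src = "..."    # source of the reference architecture module (placeholder)
custom_kernel = "..."   # kernel code from a previous model attempt (placeholder)
compile_log = "..."     # captured compiler error output (placeholder)

# Few-shot Triton prompt (these are the defaults in the new signature).
prompt = get_prompt_for_language(ref_arch_src, language="triton", option="few_shot")

# Hardware-aware CUDA prompt; internally selects the "hardware_info" option and
# pulls GPU details via src/prompts/hardware/gpu_specs.py.
hw_prompt = get_prompt_with_hardware(ref_arch_src, language="cuda", gpu_name="H100")

# Repair prompt after a failed build, passing the compiler output as metadata.
fix_prompt = prompt_fix_compile("cuda", ref_arch_src, custom_kernel, compile_log)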