Commit 044db0e
Author: Sokserey Sun

Preliminary TOML file change for prompt constructor; will add tilelang.

1 parent e395e91 commit 044db0e

File tree: 6 files changed (+476 −546 lines)

scripts/generate_and_eval_single_sample.py (1 addition, 3 deletions)

@@ -146,9 +146,7 @@ def main(config: EvalConfig):
     )
 
     # Use appropriate prompt constructor based on backend
-    if config.backend == "cuda":
-        custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
-    elif config.backend in ["triton", "cute"]:  # removed "tilelang"
+    if config.backend in ["cuda", "triton", "cute"]:
         custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
     else:
         raise ValueError(
scripts/generate_and_eval_single_sample_modal.py (1 addition, 3 deletions)

@@ -192,9 +192,7 @@ def main(config: EvalConfig):
 
 
     # Use appropriate prompt constructor based on backend
-    if config.backend == "cuda":
-        custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
-    elif config.backend in ["triton", "cute"]:  # removed "tilelang"
+    if config.backend in ["cuda", "triton", "cute"]:
         custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
     else:
         raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'.")

scripts/generate_samples.py (1 addition, 5 deletions)

@@ -120,11 +120,7 @@ def generate_sample_single(
     ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})"
 
     # Construct Prompt
-    if config.backend == "cuda":
-        custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(
-            ref_arch_src
-        )
-    elif config.backend in ["triton", "cute"]:  # removed "tilelang"
+    if config.backend in ["cuda", "triton", "cute"]:
         custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
     else:
         raise ValueError(
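
All three scripts now route every supported backend, including "cuda", through get_prompt_for_backend. A rough, hypothetical sketch of how that dispatch could sit on top of the new TOML loader added below in src/loader.py; the real function lives in the repo's prompt constructor module, and the TOML path and template name here are invented:

# Hypothetical sketch -- not the repo's actual get_prompt_for_backend.
from src.loader import render_prompt

def get_prompt_for_backend(ref_arch_src: str, backend: str) -> str:
    if backend not in ("cuda", "triton", "cute"):
        raise ValueError(f"Unsupported backend: {backend}")
    return render_prompt(
        prompts_toml="src/prompts/prompts.toml",  # assumed location
        backend=backend,
        template="basic",  # assumed template name
        context={"ref_arch_src": ref_arch_src},
    )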

src/loader.py (new file, 117 additions)
# src/loader.py
import os
import runpy
import tomli  # pip install tomli
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from .utils import read_file  # your existing util

REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


def _abs_path(rel: str) -> str:
    if os.path.isabs(rel):
        return rel
    return os.path.join(REPO_TOP_PATH, rel)


@dataclass
class PromptConfig:
    data: Dict[str, Any]

    @classmethod
    def from_toml(cls, path: str) -> "PromptConfig":
        with open(path, "rb") as f:
            data = tomli.load(f)
        return cls(data)

    def compose_blocks(self, keys: List[str]) -> str:
        text_parts = []
        for key in keys:
            node: Any = self.data
            for part in key.split("."):
                if part not in node:
                    raise KeyError(f"compose key not found: {key}")
                node = node[part]
            if not isinstance(node, str):
                raise TypeError(f"compose key must resolve to string: {key}")
            text_parts.append(node.strip() + "\n")
        return "\n".join(text_parts).strip() + "\n"

    def get_template_node(self, backend: str, template: str) -> Dict[str, Any]:
        try:
            return self.data["backends"][backend]["templates"][template]
        except KeyError as e:
            raise KeyError(f"Unknown backend/template: {backend}/{template}") from e


def _gpu_context_from_py(py_path: str, gpu_name: str) -> Dict[str, str]:
    """
    Load GPU_* dicts from a Python file (no exec of raw strings; use runpy).
    Expected globals:
      - GPU_SPEC_INFO: dict[str, dict]
      - GPU_DEFINITIONS: dict[str, str]
      - GPU_BEST_PRACTICES: list[str] OR {"list": [...]} for compatibility
    """
    mod = runpy.run_path(py_path)
    spec_info = mod.get("GPU_SPEC_INFO", {})
    definitions = mod.get("GPU_DEFINITIONS", {})
    best = mod.get("GPU_BEST_PRACTICES", [])

    if not spec_info or not definitions or best is None:
        raise ValueError(
            "GPU_SPEC_INFO / GPU_DEFINITIONS / GPU_BEST_PRACTICES missing in gpu specs .py"
        )

    if isinstance(best, dict) and "list" in best:
        best = best["list"]

    if gpu_name not in spec_info:
        raise KeyError(f"GPU name {gpu_name} not found in GPU_SPEC_INFO")

    curr = spec_info[gpu_name]
    gpu_architecture = curr.get("GPU Architecture", "Unknown")
    specs_bullets = "\n".join(
        [f"- We have {v} of {k}." for k, v in curr.items() if k != "GPU Architecture"]
    )
    defs_bullets = "\n".join([f"- {k}: {v}" for k, v in definitions.items()])
    best_bullets = "\n".join([f"- {x}" for x in (best or [])])

    return {
        "gpu_name": gpu_name,
        "gpu_architecture": gpu_architecture,
        "gpu_specs_bullets": specs_bullets,
        "gpu_definitions_bullets": defs_bullets,
        "gpu_best_practices_bullets": best_bullets,
    }


def render_prompt(
    *,
    prompts_toml: str,
    backend: str,
    template: str,
    context: Dict[str, str],
    gpu_specs_py: Optional[str] = None,
    gpu_name: Optional[str] = None,
) -> str:
    cfg = PromptConfig.from_toml(prompts_toml)
    node = cfg.get_template_node(backend, template)

    # Load example files if requested
    if node.get("requires_example"):
        ex_arch_path = _abs_path(node["example_arch_path"])
        ex_new_path = _abs_path(node["example_new_arch_path"])
        context = {
            **context,
            "example_arch_src": read_file(ex_arch_path),
            "example_new_arch_src": read_file(ex_new_path),
        }

    # Load GPU details (from .py) if requested
    if node.get("requires_gpu"):
        if not (gpu_specs_py and gpu_name):
            raise ValueError("Template requires GPU info; provide gpu_specs_py and gpu_name")
        context = {**context, **_gpu_context_from_py(_abs_path(gpu_specs_py), gpu_name)}

    # Compose & fill
    compose_keys = node["compose"]
    prompt_text = cfg.compose_blocks(compose_keys)

    try:
        return prompt_text.format(**context).strip() + "\n"
    except KeyError as e:
        raise KeyError(f"Missing placeholder in context: {e.args[0]}") from e

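An end-to-end sketch of driving render_prompt with a throwaway TOML file, assuming src/ is importable and src/utils.py provides read_file. The section names, template name, and placeholder below are invented to match the schema the loader reads (a backends.<backend>.templates.<template> table whose compose list holds dotted keys into the document):

# demo_render.py -- hypothetical driver for src.loader.render_prompt.
import os
import tempfile

from src.loader import render_prompt

TOML_TEXT = """
[shared]
intro = "You write correct, optimized GPU kernels."

[backends.triton.templates.basic]
compose = ["shared.intro", "backends.triton.templates.basic.task"]
task = "Rewrite this PyTorch architecture with a custom Triton kernel:\\n{ref_arch_src}"
"""

with tempfile.NamedTemporaryFile("w", suffix=".toml", delete=False) as f:
    f.write(TOML_TEXT)
    toml_path = f.name

try:
    prompt = render_prompt(
        prompts_toml=toml_path,
        backend="triton",
        template="basic",
        context={"ref_arch_src": "class Model(nn.Module): ..."},
    )
    print(prompt)
finally:
    os.remove(toml_path)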