26 changes: 13 additions & 13 deletions scripts/generate_and_eval_single_sample.py
@@ -9,8 +9,8 @@

from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.prompt_constructor_multilang import get_prompt_for_backend
from src.prompt_constructor_multilang import get_prompt_for_language
from src.loader import get_hardware_architecture
from src.utils import (
create_inference_server_from_presets,
extract_first_code,
@@ -45,9 +45,14 @@ def __init__(self):
# Evaluation
# local (requires a GPU), modal (cloud GPU) coming soon
self.eval_mode = "local"

self.option = "few_shot"
self.hardware_type = "GPU"
self.hardware_name = "L40S"

# Construct this from mapping from architecture name to torch cuda arch list in the future
# you can either specify SM version or just use the name
self.gpu_arch = ["Ada"]
self.arch = get_hardware_architecture(hardware_type=self.hardware_type, hardware_name=self.hardware_name)

# Inference config
self.server_type = "deepseek"
@@ -91,8 +96,8 @@ def main(config: EvalConfig):
elif config.dataset_src == "local":
curr_level_dataset = construct_kernelbench_dataset(config.level)

if config.gpu_arch:
set_gpu_arch(config.gpu_arch) # otherwise build for all architectures
if config.hardware_type == "GPU" and config.arch:
set_gpu_arch(config.arch) # otherwise build for all architectures

if config.log:
os.makedirs(config.logdir, exist_ok=True)
@@ -146,14 +151,9 @@ def main(config: EvalConfig):
)

# Use appropriate prompt constructor based on backend
if config.backend == "cuda":
custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
elif config.backend in ["triton", "cute"]: # removed "tilelang"
custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
else:
raise ValueError(
f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'."
)
custom_prompt = get_prompt_for_language(ref_arch_src, language=config.backend,
option=config.option,
hardware_name=config.hardware_name, hardware_type=config.hardware_type)

if config.log_prompt:
with open(
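For reference, a minimal sketch of the unified call path introduced above; the problem path and config values are hypothetical examples, not defaults from this PR, and the keyword arguments mirror the call shown in the diff:

    from src.loader import get_hardware_architecture
    from src.prompt_constructor_multilang import get_prompt_for_language
    from src.utils import read_file

    # Hypothetical reference problem; any KernelBench problem source works here.
    ref_arch_src = read_file("KernelBench/level1/some_problem.py")

    # Resolve the architecture from hardware_specs.toml instead of hard-coding gpu_arch.
    arch = get_hardware_architecture(hardware_type="GPU", hardware_name="L40S")  # e.g. "Ada"

    # One constructor now covers cuda/triton/cute; no per-backend branching needed.
    prompt = get_prompt_for_language(
        ref_arch_src,
        language="triton",
        option="few_shot",   # or another option defined in prompts.toml, e.g. "hardware_info"
        hardware_type="GPU",
        hardware_name="L40S",
    )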
10 changes: 2 additions & 8 deletions scripts/generate_and_eval_single_sample_modal.py
@@ -15,8 +15,7 @@

#from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.prompt_constructor_multilang import get_prompt_for_backend
from src.prompt_constructor_multilang import get_prompt_for_language
from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets

app = modal.App("eval_single_sample")
@@ -192,12 +191,7 @@ def main(config: EvalConfig):


# Use appropriate prompt constructor based on backend
if config.backend == "cuda":
custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
elif config.backend in ["triton", "cute"]: # removed "tilelang"
custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
else:
raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'.")
custom_prompt = get_prompt_for_language(ref_arch_src, language=config.backend, option="few_shot")

if config.log_prompt:
with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f:
14 changes: 2 additions & 12 deletions scripts/generate_samples.py
@@ -10,8 +10,7 @@

from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
from src.prompt_constructor_multilang import get_prompt_for_backend
from src.prompt_constructor_multilang import get_prompt_for_language
from src.utils import (
create_inference_server_from_presets,
extract_first_code,
@@ -120,16 +119,7 @@ def generate_sample_single(
), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})"

# Construct Prompt
if config.backend == "cuda":
custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(
ref_arch_src
)
elif config.backend in ["triton", "cute"]: # removed "tilelang"
custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
else:
raise ValueError(
f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'."
)
custom_cuda_prompt = get_prompt_for_language(ref_arch_src, language=config.backend, option="few_shot")
if config.log_prompt:
prompt_path = os.path.join(
run_dir,
226 changes: 226 additions & 0 deletions src/loader.py
@@ -0,0 +1,226 @@
# src/loader.py
import os
import runpy
import tomli  # pip install tomli
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from .utils import read_file  # your existing util

REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

def _abs_path(rel: str) -> str:
    if os.path.isabs(rel):
        return rel
    return os.path.join(REPO_TOP_PATH, rel)

PROMPTS_TOML = _abs_path("src/prompts/prompts.toml")
HARDWARE_SPECS_TOML = _abs_path("src/prompts/hardware/hardware_specs.toml")

@dataclass
class TomlConfig:
    data: Dict[str, Any]

    @classmethod
    def from_toml(cls, path: str) -> "TomlConfig":
        """Load a TOML file and return a TomlConfig instance."""
        with open(path, "rb") as f:
            data = tomli.load(f)
        return cls(data)

    def compose_blocks(self, keys: List[str]) -> str:
        text_parts = []
        for key in keys:
            node: Any = self.data
            for part in key.split("."):
                if part not in node:
                    raise KeyError(f"compose key not found: {key}")
                node = node[part]
            if not isinstance(node, str):
                raise TypeError(f"compose key must resolve to string: {key}")
            text_parts.append(node.strip() + "\n")
        return "\n".join(text_parts).strip() + "\n"

prompt_cfg = TomlConfig.from_toml(PROMPTS_TOML)
hardware_cfg = TomlConfig.from_toml(HARDWARE_SPECS_TOML)

def _hardware_context_from_path(hardware_type: str, hardware_name: str) -> Dict[str, str]:
    """
    Build the hardware prompt context from the hardware specs TOML file.

    The TOML is expected to be structured as [hardware_type.hardware_name]
    (i.e. { type: { name: { ... } } }), plus optional [hardware_type.definitions]
    and [hardware_type.best_practices] tables.
    """
    hw_section = hardware_cfg.data.get(hardware_type)
    if not isinstance(hw_section, dict):
        raise KeyError(f"Hardware type '{hardware_type}' not found in specs TOML")

    curr = None
    if isinstance(hw_section.get(hardware_name), dict):
        curr = hw_section[hardware_name]
    if curr is None:
        raise KeyError(
            f"Hardware '{hardware_name}' not found under type '{hardware_type}' in {HARDWARE_SPECS_TOML}"
        )

    # definitions
    definitions = {}
    if isinstance(hw_section.get("definitions"), dict):
        definitions.update(hw_section.get("definitions"))

    # best practices
    best_list: List[str] = []
    if isinstance(hw_section.get("best_practices"), dict):
        best_list.extend(hw_section.get("best_practices", {}).get("items", []))

    # Derive architecture name from common keys
    hardware_architecture = curr.get("architecture") or "Unknown"

    # Build human-readable bullets for specs/definitions/best practices
    specs_bullets = "\n".join(
        [f"- {k}: {v}" for k, v in curr.items() if k != "architecture"]
    )
    defs_bullets = "\n".join([f"- {k}: {v}" for k, v in definitions.items()])
    best_bullets = "\n".join([f"- {x}" for x in best_list])

    return {
        "hardware_type": hardware_type,
        "hardware_name": hardware_name,
        "hardware_architecture": hardware_architecture,
        "hardware_specs_bullets": specs_bullets,
        "hardware_definitions_bullets": defs_bullets,
        "hardware_best_practices_bullets": best_bullets,
    }

def get_hardware_architecture(
    hardware_type: str, hardware_name: str
) -> Optional[str]:
    """
    Convenience helper: return the architecture string for a given hardware.

    Args:
        hardware_type: Hardware type (e.g., 'GPU', 'TT').
        hardware_name: Hardware name (e.g., 'A100', 'L4', 'Grayskull').

    Returns:
        The architecture name as a string (e.g., 'Ampere', 'Ada', 'Grayskull'),
        or None if the spec does not declare an architecture.

    Raises:
        KeyError / ValueError propagated from the underlying loader if the
        hardware entry isn't found or the file is invalid. This keeps the
        behavior explicit for callers.
    """

    # Use the existing loader to build the context and extract architecture
    ctx = _hardware_context_from_path(hardware_type, hardware_name)
    arch = ctx.get("hardware_architecture")
    return arch if arch != "Unknown" else None


def render_prompt_by_option(
    *,
    language: str,
    option: str,
    context: Dict[str, str],
    hardware_type: Optional[str] = None,
    hardware_name: Optional[str] = None,
) -> str:
    """
    Render a prompt using the languages.<language> and options.<option> structure
    defined in prompts.toml.

    Args:
        language: The kernel language (triton, cuda, cute)
        option: The prompt option (basic, few_shot, hardware_info, fix_compile, fix_correctness)
        context: Variables to fill in the prompt template
        hardware_type: Hardware type (e.g., "GPU", "Tenstorrent"); required when the option needs hardware info
        hardware_name: Hardware name (e.g., "A100", "H100"); required when the option needs hardware info
    """
    # Get language-specific content
    try:
        lang_data = prompt_cfg.data["languages"][language]
    except KeyError:
        raise KeyError(f"Unknown language: {language}")

    # Get option configuration
    try:
        option_data = prompt_cfg.data["options"][option]
    except KeyError:
        raise KeyError(f"Unknown option: {option}")

    # Get shared templates
    shared = prompt_cfg.data.get("shared", {})
    language_display = lang_data.get("language_display", language.upper())

    # Fill in shared templates with language-specific terms
    problem_statement = shared.get("problem_statement", "").format(
        language_display=language_display
    )
    instruction = shared.get("instruction", "").format(
        language_display=language_display
    )

    # Add language-specific content to context
    context = {
        **context,
        "language": (
            language.upper() if language in ["cuda", "cute"] else language.capitalize()
        ),
        "language_display": language_display,
        "problem_statement": problem_statement,
        "instruction": instruction,
    }

    # Load example files if requested
    if option_data.get("requires_example"):
        # Use language-specific example arch, or fall back to shared one
        ex_arch_path = _abs_path(
            lang_data.get("few_shot_example_arch")
            or shared.get("few_shot_example_arch")
        )
        ex_new_path = _abs_path(lang_data["few_shot_new_arch"])
        context.update(
            {
                "example_arch_src": read_file(ex_arch_path),
                "example_new_arch_src": read_file(ex_new_path),
            }
        )

    # Load hardware details if requested
    if option_data.get("requires_hardware"):
        if not (hardware_type and hardware_name):
            raise ValueError(
                f"Option '{option}' requires hardware info; provide hardware_type and hardware_name"
            )
        context.update(
            **_hardware_context_from_path(hardware_type, hardware_name),
        )

    # Build the prompt from components
    prompt_parts = []
    for component in option_data["components"]:
        if component == "problem_statement":
            # Use the already-formatted problem_statement from context
            prompt_parts.append(context["problem_statement"])
        elif component == "instruction":
            # Use the already-formatted instruction from context
            prompt_parts.append(context["instruction"])
        elif component.startswith("hardware_"):
            # Hardware components from templates.hardware
            template_key = f"templates.hardware.{component}"
            prompt_parts.append(prompt_cfg.compose_blocks([template_key]))
        else:
            # Other components from templates.common
            template_key = f"templates.common.{component}"
            prompt_parts.append(prompt_cfg.compose_blocks([template_key]))

    prompt_text = "\n".join(prompt_parts).strip() + "\n"

    try:
        return prompt_text.format(**context).strip() + "\n"
    except KeyError as e:
        raise KeyError(
            f"Missing placeholder in context: {e.args[0]}. Available: {list(context.keys())}"
        ) from e
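For illustration, a minimal sketch of how the loader is meant to be driven. The TOML section and key names below are assumptions inferred from the code above and must match whatever src/prompts/hardware/hardware_specs.toml and prompts.toml actually define; the context keys are likewise hypothetical:

    from src.loader import get_hardware_architecture, render_prompt_by_option

    # Assumed shape of src/prompts/hardware/hardware_specs.toml:
    #
    #   [GPU.L40S]
    #   architecture = "Ada"
    #   memory = "48 GB GDDR6"
    #
    #   [GPU.definitions]
    #   SM = "Streaming Multiprocessor"
    #
    #   [GPU.best_practices]
    #   items = ["Prefer coalesced global memory access"]

    # Resolves [GPU.L40S].architecture; returns None if no architecture key is set.
    arch = get_hardware_architecture(hardware_type="GPU", hardware_name="L40S")

    # Renders a hardware-aware prompt; the context keys must match the placeholders
    # referenced by the templates in prompts.toml.
    prompt = render_prompt_by_option(
        language="cuda",
        option="hardware_info",
        context={"ref_arch_src": "<reference PyTorch module source>"},
        hardware_type="GPU",
        hardware_name="L40S",
    )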