diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index ff71e4bc..15519d91 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -9,8 +9,8 @@ from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref -from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template -from src.prompt_constructor_multilang import get_prompt_for_backend +from src.prompt_constructor_multilang import get_prompt_for_language +from src.loader import get_hardware_architecture from src.utils import ( create_inference_server_from_presets, extract_first_code, @@ -45,9 +45,14 @@ def __init__(self): # Evaluation # local (requires a GPU), modal (cloud GPU) coming soon self.eval_mode = "local" + + self.option = "few_shot" + self.hardware_type = "GPU" + self.hardware_name = "L40S" + # Construct this from mapping from architecture name to torch cuda arch list in the future # you can either specify SM version or just use the name - self.gpu_arch = ["Ada"] + self.arch = get_hardware_architecture(hardware_type=self.hardware_type, hardware_name=self.hardware_name) # Inference config self.server_type = "deepseek" @@ -91,8 +96,8 @@ def main(config: EvalConfig): elif config.dataset_src == "local": curr_level_dataset = construct_kernelbench_dataset(config.level) - if config.gpu_arch: - set_gpu_arch(config.gpu_arch) # otherwise build for all architectures + if config.hardware_type == "GPU" and config.arch: + set_gpu_arch(config.arch) # otherwise build for all architectures if config.log: os.makedirs(config.logdir, exist_ok=True) @@ -146,14 +151,9 @@ def main(config: EvalConfig): ) # Use appropriate prompt constructor based on backend - if config.backend == "cuda": - custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) - elif config.backend in ["triton", "cute"]: # removed "tilelang" - custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend) - else: - raise ValueError( - f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'." - ) + custom_prompt = get_prompt_for_language(ref_arch_src, language=config.backend, + option=config.option, + hardware_name=config.hardware_name, hardware_type=config.hardware_type) if config.log_prompt: with open( diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index e9e0866a..82b7b545 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -15,8 +15,7 @@ #from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref -from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template -from src.prompt_constructor_multilang import get_prompt_for_backend +from src.prompt_constructor_multilang import get_prompt_for_language from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets app = modal.App("eval_single_sample") @@ -192,12 +191,7 @@ def main(config: EvalConfig): # Use appropriate prompt constructor based on backend - if config.backend == "cuda": - custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src) - elif config.backend in ["triton", "cute"]: # removed "tilelang" - custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend) - else: - raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'.") + custom_prompt = get_prompt_for_language(ref_arch_src, language=config.backend, option="few_shot") if config.log_prompt: with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f: diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 5ee217cf..cb6b8aa9 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -10,8 +10,7 @@ from src.dataset import construct_kernelbench_dataset from src.eval import eval_kernel_against_ref -from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template -from src.prompt_constructor_multilang import get_prompt_for_backend +from src.prompt_constructor_multilang import get_prompt_for_language from src.utils import ( create_inference_server_from_presets, extract_first_code, @@ -120,16 +119,7 @@ def generate_sample_single( ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" # Construct Prompt - if config.backend == "cuda": - custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template( - ref_arch_src - ) - elif config.backend in ["triton", "cute"]: # removed "tilelang" - custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend) - else: - raise ValueError( - f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', or 'cute'." - ) + custom_cuda_prompt = get_prompt_for_language(ref_arch_src, language=config.backend, option="few_shot") if config.log_prompt: prompt_path = os.path.join( run_dir, diff --git a/src/loader.py b/src/loader.py new file mode 100644 index 00000000..28f7ace4 --- /dev/null +++ b/src/loader.py @@ -0,0 +1,226 @@ +# src/loader.py +import os +import runpy +import tomli # pip install tomli +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from .utils import read_file # your existing util + +REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + +def _abs_path(rel: str) -> str: + if os.path.isabs(rel): + return rel + return os.path.join(REPO_TOP_PATH, rel) + +PROMPTS_TOML = _abs_path("src/prompts/prompts.toml") +HARDWARE_SPECS_TOML = _abs_path("src/prompts/hardware/hardware_specs.toml") + +@dataclass +class TomlConfig: + data: Dict[str, Any] + + @classmethod + def from_toml(cls, path: str) -> "TomlConfig": + """Load a TOML file and return a TomlConfig instance.""" + with open(path, "rb") as f: + data = tomli.load(f) + return cls(data) + + def compose_blocks(self, keys: List[str]) -> str: + text_parts = [] + for key in keys: + node: Any = self.data + for part in key.split("."): + if part not in node: + raise KeyError(f"compose key not found: {key}") + node = node[part] + if not isinstance(node, str): + raise TypeError(f"compose key must resolve to string: {key}") + text_parts.append(node.strip() + "\n") + return "\n".join(text_parts).strip() + "\n" + +prompt_cfg = TomlConfig.from_toml(PROMPTS_TOML) +hardware_cfg = TomlConfig.from_toml(HARDWARE_SPECS_TOML) + +def _hardware_context_from_path(hardware_type: str, hardware_name: str) -> Dict[str, str]: + """ + Load hardware spec dicts from a TOML file (.toml). + - HARDWARE_SPEC_INFO = { type: { name: { ... } } } + + For TOML files we expect structure like [hardware_type.hardware_name], plus optional + [hardware_type.definitions], [hardware_type.best_practices], etc. + """ + hw_section = hardware_cfg.data.get(hardware_type) + if not isinstance(hw_section, dict): + raise KeyError(f"Hardware type '{hardware_type}' not found in specs TOML") + + curr = None + if isinstance(hw_section.get(hardware_name), dict): + curr = hw_section[hardware_name] + if curr is None: + raise KeyError( + f"Hardware '{hardware_name}' not found under type '{hardware_type}' in {HARDWARE_SPECS_TOML}" + ) + + # definitions + definitions = {} + if isinstance(hw_section.get("definitions"), dict): + definitions.update(hw_section.get("definitions")) + + # best practices + best_list: List[str] = [] + if isinstance(hw_section.get("best_practices"), dict): + best_list.extend(hw_section.get("best_practices", {}).get("items", [])) + + # Derive architecture name from common keys + hardware_architecture = curr.get("architecture") or "Unknown" + + # Build human-readable bullets for specs/definitions/best practices + specs_bullets = "\n".join( + [f"- {k}: {v}" for k, v in curr.items() if k != "architecture"] + ) + defs_bullets = "\n".join([f"- {k}: {v}" for k, v in definitions.items()]) + best_bullets = "\n".join([f"- {x}" for x in best_list]) + + return { + "hardware_type": hardware_type, + "hardware_name": hardware_name, + "hardware_architecture": hardware_architecture, + "hardware_specs_bullets": specs_bullets, + "hardware_definitions_bullets": defs_bullets, + "hardware_best_practices_bullets": best_bullets, + } + +def get_hardware_architecture( + hardware_type: str, hardware_name: str +) -> str: + """ + Convenience helper: return the architecture string for a given hardware. + + Args: + hardware_type: Hardware type (e.g., 'GPU', 'TT'). + hardware_name: Hardware name (e.g., 'A100', 'L4', 'Grayskull'). + + Returns: + The architecture name as a string (e.g., 'Ampere', 'Ada', 'Grayskull'). + + Raises: + KeyError / ValueError propagated from the underlying loader if the + hardware entry isn't found or the file is invalid. This keeps the + behavior explicit for callers. + """ + + # Use the existing loader to build the context and extract architecture + ctx = _hardware_context_from_path(hardware_type, hardware_name) + arch = ctx.get("hardware_architecture") + return arch if arch != "Unknown" else None + + +def render_prompt_by_option( + *, + language: str, + option: str, + context: Dict[str, str], + hardware_type: Optional[str] = None, + hardware_name: Optional[str] = None, +) -> str: + """ + New function that uses languages.X and options.Y structure from prompts.toml + + Args: + prompts_toml: Path to the prompts.toml file + language: The kernel language (triton, cuda, cute) + option: The prompt option (basic, few_shot, hardware_info, fix_compile, fix_correctness) + context: Variables to fill in the prompt template + hardware_specs_py: Optional path to hardware specs file (.py or .toml) + hardware_type: Hardware type (e.g., "GPU", "Tenstorrent") + hardware_name: Hardware name (e.g., "A100", "H100") + """ + # Get language-specific content + try: + lang_data = prompt_cfg.data["languages"][language] + except KeyError: + raise KeyError(f"Unknown language: {language}") + + # Get option configuration + try: + option_data = prompt_cfg.data["options"][option] + except KeyError: + raise KeyError(f"Unknown option: {option}") + + # Get shared templates + shared = prompt_cfg.data.get("shared", {}) + language_display = lang_data.get("language_display", language.upper()) + + # Fill in shared templates with language-specific terms + problem_statement = shared.get("problem_statement", "").format( + language_display=language_display + ) + instruction = shared.get("instruction", "").format( + language_display=language_display + ) + + # Add language-specific content to context + context = { + **context, + "language": ( + language.upper() if language in ["cuda", "cute"] else language.capitalize() + ), + "language_display": language_display, + "problem_statement": problem_statement, + "instruction": instruction, + } + + # Load example files if requested + if option_data.get("requires_example"): + # Use language-specific example arch, or fall back to shared one + ex_arch_path = _abs_path( + lang_data.get("few_shot_example_arch") + or shared.get("few_shot_example_arch") + ) + ex_new_path = _abs_path(lang_data["few_shot_new_arch"]) + context.update( + { + "example_arch_src": read_file(ex_arch_path), + "example_new_arch_src": read_file(ex_new_path), + } + ) + + # Load hardware details if requested + if option_data.get("requires_hardware"): + if not (hardware_type and hardware_name): + raise ValueError( + f"Option '{option}' requires hardware info; provide hardware_type, and hardware_name" + ) + context.update( + **_hardware_context_from_path(hardware_type, hardware_name), + ) + + # Build the prompt from components + prompt_parts = [] + for component in option_data["components"]: + if component == "problem_statement": + # Use the already-formatted problem_statement from context + prompt_parts.append(context["problem_statement"]) + elif component == "instruction": + # Use the already-formatted instruction from context + prompt_parts.append(context["instruction"]) + elif component.startswith("hardware_"): + # Hardware components from templates.hardware + template_key = f"templates.hardware.{component}" + prompt_parts.append(prompt_cfg.compose_blocks([template_key])) + else: + # Other components from templates.common + template_key = f"templates.common.{component}" + prompt_parts.append(prompt_cfg.compose_blocks([template_key])) + + prompt_text = "\n".join(prompt_parts).strip() + "\n" + + try: + return prompt_text.format(**context).strip() + "\n" + except KeyError as e: + raise KeyError( + f"Missing placeholder in context: {e.args[0]}. Available: {list(context.keys())}" + ) from e diff --git a/src/prompt_constructor_multilang.py b/src/prompt_constructor_multilang.py index 39d16243..6b3ed136 100644 --- a/src/prompt_constructor_multilang.py +++ b/src/prompt_constructor_multilang.py @@ -1,553 +1,73 @@ +# src/prompt_constructor_multilang.py (new option-based prompt constructor) import os -from .utils import read_file +from .loader import render_prompt_by_option -""" -Multi-Language Prompt Constructor - -Supports: Triton, CuTe (TileLang currently disabled/commented out) - -Design principles: -- To evaluate base model performance on KernelBench, we use the simplest prompt possible to guide model output to generated desired output format. -- However, we do not do extensive prompt engineering or few-shot examples in the LLM to steer behaviour. -""" - -REPO_TOP_PATH = os.path.abspath( - os.path.join( - os.path.dirname(__file__), - "..", - ) -) -KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench") - - -def get_arch_definition_from_file(arch_path): - arch_src = read_file(arch_path) - return get_arch_definition(arch_src) - - -def get_arch_definition(arch_src): - """ - Construct torch definition from original torch nn.Module definition - """ - prompt = f"Here is a pytorch defintion of a neural network architecture in the file model.py: ```{arch_src}```\n" - return prompt - - -################################################################################ -# Triton Backend -################################################################################ - -TRITON_PROBLEM_STATEMENT = """You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups. \n - You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" - -TRITON_PROBLEM_INSTRUCTION = """ -Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - -TRITON_PROBLEM_STATEMENT_CLEANED = """You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" - -TRITON_PROBLEM_INSTRUCTION_CLEANED = """ -Optimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - - -def prompt_generate_custom_triton( - arc_src: str, example_arch_src: str, example_new_arch_src: str -) -> str: - prompt = TRITON_PROBLEM_STATEMENT - - assert ( - "@triton.jit" in example_new_arch_src - ), "Example new arch must contain Triton kernel" - - if example_arch_src != "" and example_new_arch_src != "": - prompt += f""" - Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is: \n - ``` \n - {example_arch_src} - ``` \n - The example new arch with custom Triton kernels looks like this: \n - ``` - {example_new_arch_src} - ``` \n - """ - - prompt += f""" - You are given the following architecture: \n - ``` - {arc_src} - ``` - """ - prompt += TRITON_PROBLEM_INSTRUCTION - return prompt - - -def prompt_generate_custom_triton_fewshot_and_template( - ref_arch_src: str, shots: list -) -> str: - raise NotImplementedError("This function has not been implemented yet") - - -def prompt_generate_ex_with_CoT_template_triton(ref_arch_src: str, cot_example: str) -> str: - raise NotImplementedError("This function has not been implemented yet") - - -def prompt_generate_custom_triton_from_prompt_template(ref_arch_src: str) -> str: +def get_prompt_for_language(ref_arch_src: str, + language: str = "triton", + option: str = "few_shot", + hardware_name: str = None, + hardware_type: str = None) -> str: """ - Using prompt example (an element-wise addition) for prompt templates - The most basic form of example just to show LLM the task and the expected output format - """ - arch = ref_arch_src - - # path to prompt template, show an example of Model (torch specifications) and ModelNew (torch + custom Triton kernels) - example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add_triton.py" - ) - - if not os.path.exists(example_arch_path): - raise FileNotFoundError( - f"Example architecture file not found: {example_arch_path}" - ) - if not os.path.exists(example_new_arch_path): - raise FileNotFoundError( - f"Example new architecture file not found: {example_new_arch_path}" - ) - - example_arch = read_file(example_arch_path) - example_new_arch = read_file(example_new_arch_path) - - return prompt_generate_custom_triton(arch, example_arch, example_new_arch) - - -def prompt_generate_prompt_with_hardware_info_from_template_triton( - ref_arch_src: str, gpu_name: str -) -> str: - """ - Similar to prompt_generate_custom_triton_from_prompt_template, - but with hardware information for the given GPU - """ - arch = ref_arch_src - - example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add_triton.py" - ) - gpu_spec_file_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/hardware/gpu_specs.py" - ) - - example_arch = read_file(example_arch_path) - example_new_arch = read_file(example_new_arch_path) - gpu_spec_info = read_file(gpu_spec_file_path) - - return prompt_generate_prompt_with_hardware_info_triton( - ref_arch_src=arch, - gpu_name=gpu_name, - example_arch_src=example_arch, - example_new_arch_src=example_new_arch, - gpu_spec_info_src=gpu_spec_info, + Generate a prompt for a specific language and option. + + Args: + ref_arch_src: The reference architecture source code + language: The kernel language (triton, cuda, cute) + option: The prompt option (basic, few_shot, hardware_info) + """ + return render_prompt_by_option( + language=language.lower(), + option=option, + context={"ref_arch_src": ref_arch_src}, + hardware_type=hardware_type, + hardware_name=hardware_name, ) - -def prompt_generate_prompt_with_hardware_info_triton( - ref_arch_src: str, - gpu_name: str, - example_arch_src: str, - example_new_arch_src: str, - gpu_spec_info_src: str, -) -> str: +def prompt_fix_compile(language: str, ref_arch_src: str, custom_kernel: str, metadata: str) -> str: """ - Generate a prompt with hardware information for the given GPU - gpu_spec_info_src: str of the gpu spec src file - """ - local_dict = {} - exec(gpu_spec_info_src, {}, local_dict) - - GPU_SPEC_INFO = local_dict.get("GPU_SPEC_INFO") - GPU_DEFINITIONS = local_dict.get("GPU_DEFINITIONS") - GPU_BEST_PRACTICES = local_dict.get("GPU_BEST_PRACTICES") - - if not GPU_SPEC_INFO or not GPU_DEFINITIONS or not GPU_BEST_PRACTICES: - raise ValueError( - "GPU_SPEC_INFO or GPU_DEFINITIONS or GPU_BEST_PRACTICES not found in gpu_spec_info_src" - ) - - assert gpu_name in GPU_SPEC_INFO, f"GPU name {gpu_name} not found in GPU_SPEC_INFO" - - prompt = TRITON_PROBLEM_STATEMENT - - if example_arch_src != "" and example_new_arch_src != "": - prompt += f""" - Here's an example to show you the syntax of inline embedding custom Triton kernels in torch: The example given architecture is: \n - ``` \n - {example_arch_src} - ``` \n - The example new arch with custom Triton kernels looks like this: - ``` - {example_new_arch_src} - ``` \n - """ - - curr_gpu_spec_info = GPU_SPEC_INFO[gpu_name] - gpu_architecture = curr_gpu_spec_info.get("GPU Architecture") - prompt += f""" - Here is some information about the underlying hardware that you should keep in mind. \n\n -The GPU that will run the kernel is NVIDIA {gpu_name}, {gpu_architecture} architecture.\n\n""" - - for key, value in curr_gpu_spec_info.items(): - if key == "GPU Architecture": - continue - prompt += f"""- We have {value} of {key}.\n""" - - prompt += f"""\n\n -Here are some concepts about the GPU architecture that could be helpful: \n\n""" - for key, value in GPU_DEFINITIONS.items(): - prompt += f"""- {key}: {value}\n""" - - prompt += f"""\n\n -Here are some best practices for writing Triton kernels on GPU: \n\n""" - for best_practice in GPU_BEST_PRACTICES: - prompt += f"""- {best_practice}\n""" - - prompt += f""" - You are given the following architecture: \n - ``` - {ref_arch_src} - ``` - """ - - prompt += TRITON_PROBLEM_INSTRUCTION - return prompt - - -def prompt_fix_compile_triton(ref_arch_src, custom_kernel, metadata): - prompt = TRITON_PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed to compile: - ``` - {custom_kernel} - ``` - Here's the metadata of the compilation error: - ``` - {metadata} - ``` + Generate a prompt to fix compilation errors. - Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - - -def prompt_fix_correctness_triton(ref_arch_src, custom_kernel, metadata): - prompt = TRITON_PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed correctness: - ``` - {custom_kernel} - ``` - Here's the metadata of the correctness error: - ``` - {metadata} - ``` - Please consider how your custom Triton kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - - -################################################################################ -# TileLang Backend - COMMENTED OUT (not working currently) -################################################################################ - -# TILELANG_PROBLEM_STATEMENT = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups. \n -# You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -# """ -# -# TILELANG_PROBLEM_INSTRUCTION = """ -# Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -# """ -# -# TILELANG_PROBLEM_STATEMENT_CLEANED = """You write custom TileLang kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom TileLang kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -# """ -# -# TILELANG_PROBLEM_INSTRUCTION_CLEANED = """ -# Optimize the architecture named Model with custom TileLang kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -# """ -# -# -# def prompt_generate_custom_tilelang( -# arc_src: str, example_arch_src: str, example_new_arch_src: str -# ) -> str: -# prompt = TILELANG_PROBLEM_STATEMENT -# -# if example_arch_src != "" and example_new_arch_src != "": -# prompt += f""" -# Here's an example to show you the syntax of inline embedding custom TileLang kernels in torch: The example given architecture is: \n -# ``` \n -# {example_arch_src} -# ``` \n -# The example new arch with custom TileLang kernels looks like this: \n -# ``` -# {example_new_arch_src} -# ``` \n -# """ -# -# prompt += f""" -# You are given the following architecture: \n -# ``` -# {arc_src} -# ``` -# """ -# prompt += TILELANG_PROBLEM_INSTRUCTION -# return prompt -# -# -# def prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src: str) -> str: -# """ -# Using prompt example for TileLang -# Note: You'll need to create a TileLang example file similar to the Triton one -# """ -# arch = ref_arch_src -# -# # TODO: Create model_new_ex_add_tilelang.py example file -# example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") -# example_new_arch_path = os.path.join( -# REPO_TOP_PATH, f"src/prompts/model_new_ex_add_tilelang.py" -# ) -# -# if not os.path.exists(example_arch_path): -# raise FileNotFoundError( -# f"Example architecture file not found: {example_arch_path}" -# ) -# if not os.path.exists(example_new_arch_path): -# # For now, use a basic template without examples if file doesn't exist -# return prompt_generate_custom_tilelang(arch, "", "") -# -# example_arch = read_file(example_arch_path) -# example_new_arch = read_file(example_new_arch_path) -# -# return prompt_generate_custom_tilelang(arch, example_arch, example_new_arch) -# -# -# def prompt_fix_compile_tilelang(ref_arch_src, custom_kernel, metadata): -# prompt = TILELANG_PROBLEM_STATEMENT -# prompt += f""" -# With the following architecture: -# ``` -# {ref_arch_src} -# ``` -# You generated the following solution and it failed to compile: -# ``` -# {custom_kernel} -# ``` -# Here's the metadata of the compilation error: -# ``` -# {metadata} -# ``` -# -# Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. -# """ -# return prompt -# -# -# def prompt_fix_correctness_tilelang(ref_arch_src, custom_kernel, metadata): -# prompt = TILELANG_PROBLEM_STATEMENT -# prompt += f""" -# With the following architecture: -# ``` -# {ref_arch_src} -# ``` -# You generated the following solution and it failed correctness: -# ``` -# {custom_kernel} -# ``` -# Here's the metadata of the correctness error: -# ``` -# {metadata} -# ``` -# Please consider how your custom TileLang kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. -# """ -# return prompt - - -################################################################################ -# CuTe Backend -################################################################################ - -CUTE_PROBLEM_STATEMENT = """You write custom CuTe (CUTLASS) kernels to replace the pytorch operators in the given architecture to get speedups. \n - You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom CuTe kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" - -CUTE_PROBLEM_INSTRUCTION = """ -Optimize the architecture named Model with custom CuTe (CUTLASS) kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - -CUTE_PROBLEM_STATEMENT_CLEANED = """You write custom CuTe (CUTLASS) kernels to replace the pytorch operators in the given architecture to get speedups.\n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom CuTe kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n -""" - -CUTE_PROBLEM_INSTRUCTION_CLEANED = """ -Optimize the architecture named Model with custom CuTe (CUTLASS) kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n -""" - - -def prompt_generate_custom_cute( - arc_src: str, example_arch_src: str, example_new_arch_src: str -) -> str: - prompt = CUTE_PROBLEM_STATEMENT - - if example_arch_src != "" and example_new_arch_src != "": - prompt += f""" - Here's an example to show you the syntax of inline embedding custom CuTe (CUTLASS) kernels in torch: The example given architecture is: \n - ``` \n - {example_arch_src} - ``` \n - The example new arch with custom CuTe kernels looks like this: \n - ``` - {example_new_arch_src} - ``` \n - """ - - prompt += f""" - You are given the following architecture: \n - ``` - {arc_src} - ``` - """ - prompt += CUTE_PROBLEM_INSTRUCTION - return prompt - - -def prompt_generate_custom_cute_from_prompt_template(ref_arch_src: str) -> str: - """ - Using prompt example for CuTe - Note: You'll need to create a CuTe example file - """ - arch = ref_arch_src - - # TODO: Create model_new_ex_add_cute.py example file - example_arch_path = os.path.join(REPO_TOP_PATH, f"src/prompts/model_ex_add.py") - example_new_arch_path = os.path.join( - REPO_TOP_PATH, f"src/prompts/model_new_ex_add_cute.py" + Args: + language: The kernel language (triton, cuda, cute) + ref_arch_src: The reference architecture source code + custom_kernel: The custom kernel code that failed + metadata: Compilation error metadata + """ + return render_prompt_by_option( + language=language.lower(), + option="fix_compile", + context={ + "ref_arch_src": ref_arch_src, + "custom_kernel": custom_kernel, + "metadata": metadata, + "failure_type": "to compile", + }, ) - if not os.path.exists(example_arch_path): - raise FileNotFoundError( - f"Example architecture file not found: {example_arch_path}" - ) - if not os.path.exists(example_new_arch_path): - # For now, use a basic template without examples if file doesn't exist - return prompt_generate_custom_cute(arch, "", "") - - example_arch = read_file(example_arch_path) - example_new_arch = read_file(example_new_arch_path) - - return prompt_generate_custom_cute(arch, example_arch, example_new_arch) - - -def prompt_fix_compile_cute(ref_arch_src, custom_kernel, metadata): - prompt = CUTE_PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed to compile: - ``` - {custom_kernel} - ``` - Here's the metadata of the compilation error: - ``` - {metadata} - ``` - - Please fix the compilation error in the new model code. Please output the corrected code in codeblocks. - """ - return prompt - - -def prompt_fix_correctness_cute(ref_arch_src, custom_kernel, metadata): - prompt = CUTE_PROBLEM_STATEMENT - prompt += f""" - With the following architecture: - ``` - {ref_arch_src} - ``` - You generated the following solution and it failed correctness: - ``` - {custom_kernel} - ``` - Here's the metadata of the correctness error: - ``` - {metadata} - ``` - Please consider how your custom CuTe kernels are implemented, how it is different from the reference implementation, and fix the correctness error in the new model code. Please output the corrected code in codeblocks. +def prompt_fix_correctness(language: str, ref_arch_src: str, custom_kernel: str, metadata: str) -> str: """ - return prompt - - -################################################################################ -# Unified API -################################################################################ - -def get_prompt_for_backend(ref_arch_src: str, backend: str = "triton") -> str: - """ - Unified API to get prompt for any supported backend + Generate a prompt to fix correctness errors. Args: - ref_arch_src: Reference architecture source code - backend: One of 'triton', 'cute' (tilelang removed - not working) - - Returns: - Prompt string for the specified backend - """ - backend_lower = backend.lower() - - if backend_lower == "triton": - return prompt_generate_custom_triton_from_prompt_template(ref_arch_src) - # elif backend_lower == "tilelang": - # return prompt_generate_custom_tilelang_from_prompt_template(ref_arch_src) - elif backend_lower == "cute": - return prompt_generate_custom_cute_from_prompt_template(ref_arch_src) - else: - raise ValueError( - f"Unsupported backend: {backend}. Must be one of: 'triton', 'cute'" - ) - - -################################################################################ -# Main (for testing) -################################################################################ - -def main(): - gpu_name = "L40S" - backend = "triton" # Change this to test different backends - - ref_arch_src = read_file(os.path.join(KERNEL_BENCH_PATH, f"level1/19_ReLU.py")) - assert len(ref_arch_src) > 0, "ref_arch_src is empty" - - prompt = get_prompt_for_backend(ref_arch_src, backend) - print(f"\n{'='*80}\n{backend.upper()} PROMPT:\n{'='*80}\n") - print(prompt) - - # Write prompt to temp file - temp_file_path = os.path.join(REPO_TOP_PATH, "scratch", f"prompt_{backend}_draft.txt") - os.makedirs(os.path.dirname(temp_file_path), exist_ok=True) - with open(temp_file_path, "w") as f: - f.write(prompt) - print(f"\nPrompt written to: {temp_file_path}") - - -if __name__ == "__main__": - main() - - + language: The kernel language (triton, cuda, cute) + ref_arch_src: The reference architecture source code + custom_kernel: The custom kernel code that failed + metadata: Correctness error metadata + """ + return render_prompt_by_option( + language=language.lower(), + option="fix_correctness", + context={ + "ref_arch_src": ref_arch_src, + "custom_kernel": custom_kernel, + "metadata": metadata, + "failure_type": "correctness", + }, + ) +__all__ = [ + "get_prompt_for_language", + "get_prompt_with_hardware", + "prompt_fix_compile", + "prompt_fix_correctness", +] diff --git a/src/prompts/hardware/hardware_specs.toml b/src/prompts/hardware/hardware_specs.toml new file mode 100644 index 00000000..fe43c6c4 --- /dev/null +++ b/src/prompts/hardware/hardware_specs.toml @@ -0,0 +1,198 @@ +# ------------------------------------------------------------------------- +# Hardware Types: GPU, Tenstorrent, etc. +# ------------------------------------------------------------------------- +[GPU] +description = "NVIDIA GPU specifications and best practices" + +[GPU.L40S] +architecture = "Ada" +GPU_Memory = "48GB GDDR6 with ECC" +Memory_Bandwidth = "864 GB/s" +RT_Core_Performance_TFLOPS = "212" +FP32_TFLOPS = "91.6" +TF32_Tensor_Core_TFLOPS = "183.2 (366 with sparsity)" +FP16_Tensor_Core_TFLOPS = "362.05 (733 with sparsity)" +FP8_Tensor_Core_TFLOPS = "733 (1466 with sparsity)" +Peak_INT8_Tensor_TOPS = "733 (1466 with sparsity)" +Peak_INT4_Tensor_TOPS = "733 (1466 with sparsity)" +Register_File_Size = "64K 32-bit registers per SM" +Maximum_registers_per_thread = "255" +Maximum_thread_blocks_per_SM = "24" +Shared_memory_capacity_per_SM = "100 KB" +Maximum_shared_memory_per_thread_block = "99 KB" + +[GPU.H100] +architecture = "Hopper" +GPU_Memory = "80GB" +Memory_Bandwidth = "3.35 TB/s" +FP64_TFLOPS = "34" +FP64_Tensor_Core_TFLOPS = "67" +FP32_TFLOPS = "67" +TF32_Tensor_Core_TFLOPS = "989 with sparsity" +BFLOAT16_Tensor_Core_TFLOPS = "1979 with sparsity" +FP16_Tensor_Core_TFLOPS = "1979 with sparsity" +FP8_Tensor_Core_TFLOPS = "3958 with sparsity" +INT8_Tensor_Core_TOPS = "3958 with sparsity" +Register_File_Size = "64K 32-bit registers per SM" +Maximum_registers_per_thread = "255" +Maximum_thread_blocks_per_SM = "32" +Shared_memory_capacity_per_SM = "228 KB" +Maximum_shared_memory_per_thread_block = "227 KB" + +[GPU.A100] +architecture = "Ampere" +GPU_Memory = "40GB" +Memory_Bandwidth = "1935 GB/s" +FP64_TFLOPS = "9.7" +FP64_Tensor_Core_TFLOPS = "19.5" +FP32_TFLOPS = "19.5" +TF32_Tensor_Core_TFLOPS = "156 (312 with sparsity)" +BFLOAT16_Tensor_Core_TFLOPS = "312 (624 with sparsity)" +FP16_Tensor_Core_TFLOPS = "312 (624 with sparsity)" +INT8_Tensor_Core_TOPS = "624 (1248 with sparsity)" +Register_File_Size = "64K 32-bit registers per SM" +Maximum_registers_per_thread = "255" +Maximum_thread_blocks_per_SM = "32" +Shared_memory_capacity_per_SM = "164 KB" +Maximum_shared_memory_per_thread_block = "163 KB" + +[GPU.A100-80GB] +architecture = "Ampere" +GPU_Memory = "80GB" +Memory_Bandwidth = "1935 GB/s" +FP64_TFLOPS = "9.7" +FP64_Tensor_Core_TFLOPS = "19.5" +FP32_TFLOPS = "19.5" +TF32_Tensor_Core_TFLOPS = "156 (312 with sparsity)" +BFLOAT16_Tensor_Core_TFLOPS = "312 (624 with sparsity)" +FP16_Tensor_Core_TFLOPS = "312 (624 with sparsity)" +INT8_Tensor_Core_TOPS = "624 (1248 with sparsity)" +Register_File_Size = "64K 32-bit registers per SM" +Maximum_registers_per_thread = "255" +Maximum_thread_blocks_per_SM = "32" +Shared_memory_capacity_per_SM = "164 KB" +Maximum_shared_memory_per_thread_block = "163 KB" + +[GPU.L4] +architecture = "Ada" +GPU_Memory = "24GB" +Memory_Bandwidth = "300 GB/s" +FP32_TFLOPS = "30.3" +TF32_Tensor_Core_TFLOPS = "120 with sparsity" +BFLOAT16_Tensor_Core_TFLOPS = "242 with sparsity" +FP8_Tensor_Core_TFLOPS = "485 with sparsity" +INT8_Tensor_Core_TOPS = "485 with sparsity" +Register_File_Size = "64K 32-bit registers per SM" +Maximum_registers_per_thread = "255" +Maximum_thread_blocks_per_SM = "24" +Shared_memory_capacity_per_SM = "100 KB" +Maximum_shared_memory_per_thread_block = "99 KB" + +[GPU.T4] +architecture = "Turing" +GPU_Memory = "16 GB GDDR6" +Memory_Bandwidth = "300 GB/s" +Single_Precision_TFLOPS = "8.1" +Mixed_Precision_FP16_FP32_TFLOPS = "65" +INT8_TOPS = "130" +INT4_TOPS = "260" +Register_File_Size = "64K 32-bit registers per SM" +Maximum_registers_per_thread = "255" +Maximum_thread_blocks_per_SM = "16" +Shared_memory_capacity_per_SM = "64 KB" + +[GPU.A10G] +architecture = "Ampere" +GPU_Memory = "24GB GDDR6" +Memory_Bandwidth = "600 GB/s" +FP32_TFLOPS = "31.2" +TF32_Tensor_Core_TFLOPS = "62.5 (125 with sparsity)" +BFLOAT16_Tensor_Core_TFLOPS = "125 (250 with sparsity)" +FP16_Tensor_Core_TFLOPS = "125 (250 with sparsity)" +INT8_Tensor_Core_TOPS = "250 (500 with sparsity)" +INT4_Tensor_Core_TOPS = "500 (1000 with sparsity)" +Register_File_Size = "64K 32-bit registers per SM" +Maximum_registers_per_thread = "255" +Maximum_thread_blocks_per_SM = "32" +Shared_memory_capacity_per_SM = "164 KB" +Maximum_shared_memory_per_thread_block = "163 KB" + +# ------------------------------------------------------------------------- +# GPU-specific Definitions and Best Practices +# ------------------------------------------------------------------------- + +[GPU.definitions] +Thread = "A thread is a single execution unit that can run a single instruction at a time." +Thread_Block = "A thread block is a group of threads that can cooperate with each other." +Warp = "A warp is a group of threads that are scheduled together and execute in parallel." +Shared_Memory = "Shared memory is a memory space that can be accessed by all threads in a thread block." +Register = "A register is a small memory space that can be accessed by a single thread." +Memory_Hierarchy = "Memory hierarchy is a pyramid of memory types with different speeds and sizes." +Memory_Bandwidth = "Memory bandwidth is the rate at which data can be read from or stored into memory." +Cache = "Cache is a small memory space that stores frequently accessed data." +HBM = "HBM is a high-bandwidth memory technology that uses 3D-stacked DRAM." + +[GPU.best_practices] +items = [ + "Find ways to parallelize sequential code.", + "Minimize data transfers between the host and the device.", + "Adjust kernel launch configuration to maximize device utilization.", + "Ensure that global memory accesses are coalesced.", + "Minimize redundant accesses to global memory whenever possible.", + "Avoid long sequences of diverged execution by threads within the same warp.", + "Use specialized instructions based on the specific GPU architecture", +] + +# ------------------------------------------------------------------------- +# Tenstorrent Hardware Type (example) +# ------------------------------------------------------------------------- + +[TT] +description = "Tenstorrent accelerator specifications and best practices" + +[TT.Wormhole] +architecture = "Wormhole" +compute_units = "144 Tensix cores" +memory_capacity = "16GB LPDDR4x" +memory_bandwidth = "100 GB/s" +mesh_size = "12x12 mesh" +peak_bfloat16_tops = "150 TOPS" +peak_int8_tops = "300 TOPS" +on_chip_network_bandwidth = "6 TB/s" + +[TT.Blackhole] +architecture = "Blackhole" +compute_units = "288 Tensix cores" +memory_capacity = "32GB LPDDR5" +memory_bandwidth = "200 GB/s" +mesh_size = "18x16 mesh" +peak_bfloat16_tops = "600 TOPS" +peak_int8_tops = "1200 TOPS" +on_chip_network_bandwidth = "12 TB/s" + +[TT.Blackhole_Pro] +architecture = "Blackhole Pro" +compute_units = "512 Tensix cores" +memory_capacity = "64GB LPDDR5" +memory_bandwidth = "400 GB/s" +mesh_size = "24x24 mesh" +peak_bfloat16_tops = "900 TOPS" +peak_int8_tops = "1800 TOPS" +on_chip_network_bandwidth = "20 TB/s" + +# ------------------------------------------------------------------------- +# TT-specific Definitions and Best Practices +# ------------------------------------------------------------------------- + +[TT.definitions] +Thread = "A thread is a compute unit in the Tenstorrent mesh architecture." +Tensor_Core = "A compute unit optimized for tensor operations." +Mesh_Grid = "A 2D grid of compute units in Tenstorrent architecture." + +[TT.best_practices] +items = [ + "Optimize for tensor-parallel execution.", + "Maximize data locality in mesh architecture.", + "Use host-device data transfer efficiently.", + "Leverage mesh interconnect bandwidth.", +] diff --git a/src/prompts/model_new_ex_add_pykernel.py b/src/prompts/model_new_ex_add_pykernel.py new file mode 100644 index 00000000..ea139fec --- /dev/null +++ b/src/prompts/model_new_ex_add_pykernel.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +from pykernel import ( + PyKernelOp, + reader_thread, + writer_thread, + compute_thread, + CircularBuffer, +) + +from math import ceil + +import ttnn +import torch +import torch.nn as nn + + +class VecAddMulticorePyKernelOp(PyKernelOp): + def __init__(self, max_core_ranges=None): + super().__init__() + self.max_core_ranges = max_core_ranges + + # KERNEL DEFINITIONS + @compute_thread() + def add_multicore( + cb_in0: CircularBuffer, + cb_in1: CircularBuffer, + cb_out: CircularBuffer, + num_tiles, + start_tile_id, + ): + binary_op_init_common(cb_in0, cb_in1, cb_out) + add_tiles_init(cb_in0, cb_in1) + + end_tile_id = start_tile_id + num_tiles + dst_reg = 0 + + for i in range(start_tile_id, end_tile_id, 1): + cb_wait_front(cb_in0, 1) + cb_wait_front(cb_in1, 1) + tile_regs_acquire() + add_tiles(cb_in0, cb_in1, 0, 0, dst_reg) + tile_regs_commit() + + cb_reserve_back(cb_out, 1) + tile_regs_wait() + pack_tile(dst_reg, cb_out, 0) + tile_regs_release() + + cb_push_back(cb_out, 1) + cb_pop_front(cb_in0, 1) + cb_pop_front(cb_in1, 1) + tile_regs_release() + return + + @writer_thread() + def writer_multicore( + cb_out: CircularBuffer, + dst_addr, + num_tiles, + start_id, + ): + onetile = 1 + + tile_bytes = get_tile_size(cb_out) + tensor_accessor_args = TensorAccessorArgs(1, 0) + s0 = TensorAccessor(tensor_accessor_args, dst_addr, tile_bytes) + + end_id = start_id + num_tiles + for i in range(start_id, end_id, onetile): + cb_wait_front(cb_out, onetile) + l1_read_addr = get_read_ptr(cb_out) + noc_async_write_tile(i, s0, l1_read_addr) + noc_async_write_barrier() + cb_pop_front(cb_out, onetile) + return + + @reader_thread() + def reader_binary_interleaved( + cb_in0: CircularBuffer, + cb_in1: CircularBuffer, + src_addr0, + src_addr1, + num_tiles, + start_id, + ): + onetile = 1 + + tile_bytes0 = get_tile_size(cb_in0) + tensor_accessor_args = TensorAccessorArgs(2, 0) + s0 = TensorAccessor(tensor_accessor_args, src_addr0, tile_bytes0) + + tile_bytes1 = get_tile_size(cb_in1) + tensor_accessor_args = TensorAccessorArgs(2, 0) + s1 = TensorAccessor(tensor_accessor_args, src_addr1, tile_bytes1) + + end_id = start_id + num_tiles + for i in range(start_id, end_id, onetile): + cb_reserve_back(cb_in0, onetile) + cb_reserve_back(cb_in1, onetile) + + src0_write_addr = get_write_ptr(cb_in0) + src1_write_addr = get_write_ptr(cb_in1) + + noc_async_read_tile(i, s0, src0_write_addr) + noc_async_read_tile(i, s1, src1_write_addr) + + noc_async_read_barrier() + cb_push_back(cb_in0, onetile) + cb_push_back(cb_in1, onetile) + return + + def define_core_ranges(self, tensors, options): + core_0 = ttnn.CoreCoord(0, 0) + if self.max_core_ranges is None: + core_1 = ttnn.CoreCoord(1, 1) + else: + core_1 = self.max_core_ranges + return ttnn.CoreRangeSet([ttnn.CoreRange(core_0, core_1)]) + + def invoke( + self, # super() has invoke signature (*tensors, **options) + a_tensor, + b_tensor, + out_tensor, # Tensor Definitions are positional args + ): + cb_in0 = self.create_cb(a_tensor, 0) + cb_in1 = self.create_cb(b_tensor, 1) + cb_out = self.create_cb(out_tensor, 2) + start_id = 0 + + self.set_tensor_accessor_config(a_tensor) + + num_tiles = ceil( + max(map(lambda t: t.volume(), [a_tensor, b_tensor, out_tensor])) / 1024 + ) + + num_cores = self.get_core_ranges().num_cores() + num_tiles_per_core = num_tiles / num_cores + + if num_tiles_per_core % 1 != 0: + # uneven distro of work, just break down and cry. + raise Exception(":sad_hamster:") + + num_tiles_per_core = int(num_tiles_per_core) + + # Define the multicore runtime arguments + start_id_multicore = [] + + # Go row-wise + bb = self.get_core_ranges().bounding_box() + for i in range(bb.start.x, bb.end.x + 1): + start_id_multicore.append([]) + for j in range(bb.start.y, bb.end.y + 1): + # Set for each core + start_id_multicore[-1].append([start_id]) + start_id += 1 + + kernels = [ + self.create_kernel( + VecAddMulticorePyKernelOp.add_multicore, + cb_in0, + cb_in1, + cb_out, + num_tiles_per_core, + start_id_multicore, + ), + self.create_kernel( + VecAddMulticorePyKernelOp.writer_multicore, + cb_out, + out_tensor.buffer_address(), + num_tiles_per_core, + start_id_multicore, + ), + self.create_kernel( + VecAddMulticorePyKernelOp.reader_binary_interleaved, + cb_in0, + cb_in1, + a_tensor.buffer_address(), + b_tensor.buffer_address(), + num_tiles_per_core, + start_id_multicore, + ), + ] + + return self.create_program(kernels, [cb_in0, cb_in1, cb_out]) + + +class ModelNew(nn.Module): + def __init__(self): + super().__init__() + self.vecadd_op = VecAddMulticorePyKernelOp() + + def forward(self, A: ttnn.Tensor, B: ttnn.Tensor, device: ttnn.MeshDevice) -> ttnn.Tensor: + shape = list(A.shape) + + output_tensor = ttnn.allocate_tensor_on_device( + ttnn.Shape(shape), + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + ) + + output = self.vecadd_op(A, B, output_tensor) + return output \ No newline at end of file diff --git a/src/prompts/prompts.toml b/src/prompts/prompts.toml new file mode 100644 index 00000000..6b4f3c7d --- /dev/null +++ b/src/prompts/prompts.toml @@ -0,0 +1,165 @@ +[meta] +version = "1.0" +default_language = "cuda" + +# ------------------------------------------------------------------------- +# Shared Templates: Used by all languages with placeholders +# ------------------------------------------------------------------------- +[shared] +problem_statement = """ +You write custom {language_display} to replace the pytorch operators in the given architecture to get speedups. + +You have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom {language_display} and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination. +""" + +instruction = """ +Optimize the architecture named Model with custom {language_display}! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! +""" + +# Shared example architecture (same for all languages) +few_shot_example_arch = "src/prompts/model_ex_add.py" + +# ------------------------------------------------------------------------- +# Languages: Language-specific configuration (minimal, just what varies) +# ------------------------------------------------------------------------- +[languages.triton] +language_display = "Triton kernels" +few_shot_new_arch = "src/prompts/model_new_ex_add_triton.py" + +[languages.cuda] +language_display = "CUDA operators" +few_shot_new_arch = "src/prompts/model_new_ex_add.py" + +[languages.cute] +language_display = "CuTe (CUTLASS) kernels" +few_shot_new_arch = "src/prompts/model_new_ex_add_cute.py" + +[languages.pykernel] +language_display = "Tenstorrent PyKernel kernels" +few_shot_new_arch = "src/prompts/model_new_ex_add_pykernel.py" + +# ------------------------------------------------------------------------- +# Options: Different prompt construction modes +# ------------------------------------------------------------------------- +[options.basic] +# Basic prompt: problem statement + architecture + instruction +description = "Minimal prompt with just problem statement and architecture" +components = ["problem_statement", "arch_block", "instruction"] + +[options.few_shot] +# With few-shot examples +description = "Includes few-shot examples to demonstrate the task" +components = ["problem_statement", "few_shot_block", "arch_block", "instruction"] +requires_example = true + +[options.hardware_info] +# Hardware-aware prompt +description = "Includes hardware specifications and best practices" +components = ["problem_statement", "few_shot_block", "hardware_header", "hardware_specs", "hardware_definitions", "hardware_best_practices", "arch_block", "instruction"] +requires_hardware = true +requires_example = true + +[options.fix_compile] +# For fixing compilation errors +description = "Prompt for fixing compilation errors" +components = ["problem_statement", "arch_with_context", "failed_kernel", "compile_metadata", "fix_footer"] + +[options.fix_correctness] +# For fixing correctness errors +description = "Prompt for fixing correctness errors" +components = ["problem_statement", "arch_with_context", "failed_kernel", "correctness_metadata", "fix_footer"] + +# ------------------------------------------------------------------------- +# Templates: Reusable text blocks with placeholders +# ------------------------------------------------------------------------- +[templates.common] + +# --- Architecture Presentation --- +# Used in prompts to present the reference architecture that needs optimization +arch_block = """ +You are given the following architecture: + + +{ref_arch_src} + +""" + +# Used in fix prompts to reference the architecture with contextual phrasing +arch_with_context = """ +With the following architecture: + + +{ref_arch_src} + +""" + +# --- Few-Shot Learning --- +# Shows an example of input architecture and its optimized version +few_shot_block = """ +Here's an example to show you the syntax of inline embedding custom {language_display} in torch: The example given architecture is: + +{example_arch_src} + + +The example new arch with custom {language_display} looks like this: + + +{example_new_arch_src} + +""" + +# --- Error Fix Templates --- +# Presents a kernel that failed (used in fix_compile and fix_correctness options) +failed_kernel = """ +You generated the following solution and it failed {failure_type}: + + +{custom_kernel} + +""" + +compile_metadata = """ +Here's the metadata of the compilation error: + + +{metadata} + +""" + +correctness_metadata = """ +Here's the metadata of the correctness error: + + +{metadata} + +""" + +fix_footer = """ +Please fix the {failure_type} in the new model code. Please output the corrected code in codeblocks. +""" + +# ------------------------------------------------------------------------- +# Hardware Templates: GPU-specific information blocks +# ------------------------------------------------------------------------- +[templates.hardware] +hardware_header = """ +Here is some information about the underlying hardware that you should keep in mind. +""" + +hardware_specs = """ +The {hardware_type} that will run the kernel is {hardware_name}, {hardware_architecture} architecture. + +{hardware_specs_bullets} +""" + +hardware_definitions = """ +Here are some concepts about the hardware architecture that could be helpful: + +{hardware_definitions_bullets} +""" + +hardware_best_practices = """ +Here are some best practices for writing kernels on this hardware: + +{hardware_best_practices_bullets} +"""