|
46 | 46 | from .benchmarking import interleaved_bench |
47 | 47 | from .config_generation import ConfigGeneration |
48 | 48 | from .config_generation import FlatConfig |
| 49 | +from .logger import SUPPRESSED_TRITON_CODE_MSG |
49 | 50 | from .logger import LambdaLogger |
50 | 51 | from .logger import classify_triton_exception |
51 | 52 | from .logger import format_triton_compile_failure |
| 53 | +from .logger import log_generated_triton_code_debug |
52 | 54 | from .logger import match_unrecoverable_runtime_error |
53 | 55 | from .progress_bar import iter_with_progress |
54 | 56 |
|
@@ -163,11 +165,16 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]: |
163 | 165 | decorator = self.kernel.format_kernel_decorator( |
164 | 166 | baseline_config, self.settings |
165 | 167 | ) |
166 | | - triton_code = self.kernel.to_triton_code(baseline_config) |
| 168 | + log_generated_triton_code_debug( |
| 169 | + self.log, |
| 170 | + self.kernel, |
| 171 | + baseline_config, |
| 172 | + prefix=f"Generated Triton code for {decorator}:", |
| 173 | + ) |
167 | 174 | raise exc.InvalidConfig( |
168 | 175 | "Default config failed while computing baseline.\n" |
169 | 176 | f"Default config: {decorator}\n" |
170 | | - f"\nGenerated Triton code:\n{triton_code}\n" |
| 177 | + f"{SUPPRESSED_TRITON_CODE_MSG}\n" |
171 | 178 | ) from e |
172 | 179 | original_args_flat, _ = tree_flatten(self._original_args) |
173 | 180 | new_args_flat, _ = tree_flatten(new_args) |
@@ -326,16 +333,35 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: |
326 | 333 | if self.settings.autotune_ignore_errors: |
327 | 334 | pass |
328 | 335 | elif action == "raise": |
| 336 | + decorator = self.kernel.format_kernel_decorator(config, self.settings) |
| 337 | + log_generated_triton_code_debug( |
| 338 | + self.log, |
| 339 | + self.kernel, |
| 340 | + config, |
| 341 | + prefix=f"Generated Triton code for {decorator}:", |
| 342 | + ) |
329 | 343 | raise exc.TritonError( |
330 | 344 | error=f"{type(e).__qualname__}: {e}", |
331 | | - decorator=self.kernel.format_kernel_decorator( |
332 | | - config, self.settings |
333 | | - ), |
334 | | - code=self.kernel.to_triton_code(config), |
| 345 | + decorator=decorator, |
| 346 | + code=SUPPRESSED_TRITON_CODE_MSG, |
335 | 347 | ) from e |
336 | 348 | elif action == "warn": |
| 349 | + decorator = self.kernel.format_kernel_decorator(config, self.settings) |
| 350 | + log_generated_triton_code_debug( |
| 351 | + self.log, |
| 352 | + self.kernel, |
| 353 | + config, |
| 354 | + prefix=f"Generated Triton code for {decorator}:", |
| 355 | + ) |
337 | 356 | self.log.warning(format_triton_compile_failure(config, e, self.kernel)) |
338 | 357 | else: |
| 358 | + decorator = self.kernel.format_kernel_decorator(config, self.settings) |
| 359 | + log_generated_triton_code_debug( |
| 360 | + self.log, |
| 361 | + self.kernel, |
| 362 | + config, |
| 363 | + prefix=f"Generated Triton code for {decorator}:", |
| 364 | + ) |
339 | 365 | self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}") |
340 | 366 | return inf |
341 | 367 |
|
@@ -1143,15 +1169,31 @@ def _consume_result(self, *, raise_on_raise: bool) -> None: |
1143 | 1169 | if classification == "raise": |
1144 | 1170 | if raise_on_raise: |
1145 | 1171 | self._remote_error_handled = True |
| 1172 | + decorator = self.search.kernel.format_kernel_decorator( |
| 1173 | + self.config, self.search.settings |
| 1174 | + ) |
| 1175 | + log_generated_triton_code_debug( |
| 1176 | + self.search.log, |
| 1177 | + self.search.kernel, |
| 1178 | + self.config, |
| 1179 | + prefix=f"Generated Triton code for {decorator}:", |
| 1180 | + ) |
1146 | 1181 | raise exc.TritonError( |
1147 | 1182 | error=f"{type(exc_obj).__qualname__}: {exc_obj}", |
1148 | | - decorator=self.search.kernel.format_kernel_decorator( |
1149 | | - self.config, self.search.settings |
1150 | | - ), |
1151 | | - code=self.search.kernel.to_triton_code(self.config), |
| 1183 | + decorator=decorator, |
| 1184 | + code=SUPPRESSED_TRITON_CODE_MSG, |
1152 | 1185 | ) from exc_obj |
1153 | 1186 | return |
1154 | 1187 |
|
| 1188 | + decorator = self.search.kernel.format_kernel_decorator( |
| 1189 | + self.config, self.search.settings |
| 1190 | + ) |
| 1191 | + log_generated_triton_code_debug( |
| 1192 | + self.search.log, |
| 1193 | + self.search.kernel, |
| 1194 | + self.config, |
| 1195 | + prefix=f"Generated Triton code for {decorator}:", |
| 1196 | + ) |
1155 | 1197 | formatted = format_triton_compile_failure( |
1156 | 1198 | self.config, exc_obj, self.search.kernel |
1157 | 1199 | ) |
@@ -1265,10 +1307,16 @@ def extract_launcher( |
1265 | 1307 | return None |
1266 | 1308 | return precompiler |
1267 | 1309 | except Exception: |
| 1310 | + log_generated_triton_code_debug( |
| 1311 | + log, |
| 1312 | + kernel, |
| 1313 | + config, |
| 1314 | + prefix=f"Generated Triton code for {decorator}:", |
| 1315 | + ) |
1268 | 1316 | log.warning( |
1269 | | - "Helion autotuner precompile error for %s\n\nGenerated Triton code:\n%s", |
| 1317 | + "Helion autotuner precompile error for %s. %s", |
1270 | 1318 | decorator, |
1271 | | - kernel.to_triton_code(config), |
| 1319 | + SUPPRESSED_TRITON_CODE_MSG, |
1272 | 1320 | exc_info=True, |
1273 | 1321 | ) |
1274 | 1322 | raise |
|
0 commit comments