Skip to content

Commit f6646e9

Browse files
authored
Log generated triton code at the DEBUG level rather than INFO (#986)
1 parent f805e5c commit f6646e9

File tree

3 files changed

+103
-15
lines changed

3 files changed

+103
-15
lines changed

.github/workflows/benchmark.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ jobs:
3737
benchmark:
3838
name: benchmark-${{ inputs.runtime-version }}-${{ inputs.kernels }}-py${{ inputs.python-version }}-${{ inputs.alias }}
3939

40+
env:
41+
HELION_AUTOTUNE_LOG_LEVEL: DEBUG
42+
4043
container:
4144
image: ${{ inputs.image }}
4245
options: ${{ inputs.container-options }}

helion/autotuner/base_search.py

Lines changed: 60 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@
4646
from .benchmarking import interleaved_bench
4747
from .config_generation import ConfigGeneration
4848
from .config_generation import FlatConfig
49+
from .logger import SUPPRESSED_TRITON_CODE_MSG
4950
from .logger import LambdaLogger
5051
from .logger import classify_triton_exception
5152
from .logger import format_triton_compile_failure
53+
from .logger import log_generated_triton_code_debug
5254
from .logger import match_unrecoverable_runtime_error
5355
from .progress_bar import iter_with_progress
5456

@@ -163,11 +165,16 @@ def _compute_baseline(self) -> tuple[object, bool, Sequence[object] | None]:
163165
decorator = self.kernel.format_kernel_decorator(
164166
baseline_config, self.settings
165167
)
166-
triton_code = self.kernel.to_triton_code(baseline_config)
168+
log_generated_triton_code_debug(
169+
self.log,
170+
self.kernel,
171+
baseline_config,
172+
prefix=f"Generated Triton code for {decorator}:",
173+
)
167174
raise exc.InvalidConfig(
168175
"Default config failed while computing baseline.\n"
169176
f"Default config: {decorator}\n"
170-
f"\nGenerated Triton code:\n{triton_code}\n"
177+
f"{SUPPRESSED_TRITON_CODE_MSG}\n"
171178
) from e
172179
original_args_flat, _ = tree_flatten(self._original_args)
173180
new_args_flat, _ = tree_flatten(new_args)
@@ -326,16 +333,35 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
326333
if self.settings.autotune_ignore_errors:
327334
pass
328335
elif action == "raise":
336+
decorator = self.kernel.format_kernel_decorator(config, self.settings)
337+
log_generated_triton_code_debug(
338+
self.log,
339+
self.kernel,
340+
config,
341+
prefix=f"Generated Triton code for {decorator}:",
342+
)
329343
raise exc.TritonError(
330344
error=f"{type(e).__qualname__}: {e}",
331-
decorator=self.kernel.format_kernel_decorator(
332-
config, self.settings
333-
),
334-
code=self.kernel.to_triton_code(config),
345+
decorator=decorator,
346+
code=SUPPRESSED_TRITON_CODE_MSG,
335347
) from e
336348
elif action == "warn":
349+
decorator = self.kernel.format_kernel_decorator(config, self.settings)
350+
log_generated_triton_code_debug(
351+
self.log,
352+
self.kernel,
353+
config,
354+
prefix=f"Generated Triton code for {decorator}:",
355+
)
337356
self.log.warning(format_triton_compile_failure(config, e, self.kernel))
338357
else:
358+
decorator = self.kernel.format_kernel_decorator(config, self.settings)
359+
log_generated_triton_code_debug(
360+
self.log,
361+
self.kernel,
362+
config,
363+
prefix=f"Generated Triton code for {decorator}:",
364+
)
339365
self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
340366
return inf
341367

@@ -1143,15 +1169,31 @@ def _consume_result(self, *, raise_on_raise: bool) -> None:
11431169
if classification == "raise":
11441170
if raise_on_raise:
11451171
self._remote_error_handled = True
1172+
decorator = self.search.kernel.format_kernel_decorator(
1173+
self.config, self.search.settings
1174+
)
1175+
log_generated_triton_code_debug(
1176+
self.search.log,
1177+
self.search.kernel,
1178+
self.config,
1179+
prefix=f"Generated Triton code for {decorator}:",
1180+
)
11461181
raise exc.TritonError(
11471182
error=f"{type(exc_obj).__qualname__}: {exc_obj}",
1148-
decorator=self.search.kernel.format_kernel_decorator(
1149-
self.config, self.search.settings
1150-
),
1151-
code=self.search.kernel.to_triton_code(self.config),
1183+
decorator=decorator,
1184+
code=SUPPRESSED_TRITON_CODE_MSG,
11521185
) from exc_obj
11531186
return
11541187

1188+
decorator = self.search.kernel.format_kernel_decorator(
1189+
self.config, self.search.settings
1190+
)
1191+
log_generated_triton_code_debug(
1192+
self.search.log,
1193+
self.search.kernel,
1194+
self.config,
1195+
prefix=f"Generated Triton code for {decorator}:",
1196+
)
11551197
formatted = format_triton_compile_failure(
11561198
self.config, exc_obj, self.search.kernel
11571199
)
@@ -1265,10 +1307,16 @@ def extract_launcher(
12651307
return None
12661308
return precompiler
12671309
except Exception:
1310+
log_generated_triton_code_debug(
1311+
log,
1312+
kernel,
1313+
config,
1314+
prefix=f"Generated Triton code for {decorator}:",
1315+
)
12681316
log.warning(
1269-
"Helion autotuner precompile error for %s\n\nGenerated Triton code:\n%s",
1317+
"Helion autotuner precompile error for %s. %s",
12701318
decorator,
1271-
kernel.to_triton_code(config),
1319+
SUPPRESSED_TRITON_CODE_MSG,
12721320
exc_info=True,
12731321
)
12741322
raise

helion/autotuner/logger.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,19 +93,56 @@ def _maybe_call(fn: Callable[[], str] | str) -> str:
9393
return fn
9494

9595

96+
SUPPRESSED_TRITON_CODE_MSG = (
97+
"Enable HELION_AUTOTUNE_LOG_LEVEL=DEBUG to log generated Triton code."
98+
)
99+
100+
101+
def log_generated_triton_code_debug(
102+
logger: logging.Logger | LambdaLogger,
103+
bound_kernel: BoundKernel,
104+
config: Config,
105+
*,
106+
prefix: str | None = None,
107+
) -> None:
108+
"""
109+
Emit the generated Triton code at debug level if the logger allows it.
110+
111+
Args:
112+
logger: Logger that should receive the message.
113+
bound_kernel: Kernel whose Triton code should be logged.
114+
config: Config used to generate the Triton code.
115+
prefix: Optional prefix for the log message.
116+
"""
117+
message_prefix = prefix or "Generated Triton code:"
118+
if isinstance(logger, LambdaLogger):
119+
logger.debug(lambda: _format_triton_code(bound_kernel, config, message_prefix))
120+
return
121+
if logger.isEnabledFor(logging.DEBUG):
122+
logger.debug(
123+
"%s\n%s",
124+
message_prefix,
125+
bound_kernel.to_triton_code(config),
126+
)
127+
128+
129+
def _format_triton_code(bound_kernel: BoundKernel, config: Config, prefix: str) -> str:
130+
code = bound_kernel.to_triton_code(config)
131+
return f"{prefix}\n{code}"
132+
133+
96134
def format_triton_compile_failure(
97135
config: Config, err: BaseException, bound_kernel: BoundKernel
98136
) -> str:
99137
kernel_decorator = bound_kernel.format_kernel_decorator(
100138
config, bound_kernel.settings
101139
)
102-
triton_code = bound_kernel.to_triton_code(config)
103140
return (
104141
"Triton compile failed. This likely indicates a bug in Triton. "
105142
"Skipping failing config.\n"
106143
f"Config: {kernel_decorator}\n"
107-
f"Error: {type(err).__name__}: {err}\n\n"
108-
f"Generated Triton code:\n{triton_code}"
144+
f"Error: {type(err).__name__}: {err}\n"
145+
f"{SUPPRESSED_TRITON_CODE_MSG}"
109146
)
110147

111148

0 commit comments

Comments
 (0)