From e5d6cd6f5b96f484179b654d7503e1f620e008d6 Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Sat, 11 Oct 2025 21:21:14 -0700
Subject: [PATCH] Verify compiled kernels in subprocess

This handles configs that hang when we run them. It should also fix the
illegal memory access (IMA) issues we have been seeing.

stack-info: PR: https://github.com/pytorch/helion/pull/914, branch: jansel/stack/173
---
 docs/api/settings.md                       |   7 +-
 helion/autotuner/base_search.py            | 417 ++++++++++++++++++---
 helion/autotuner/differential_evolution.py |   2 +-
 helion/autotuner/logger.py                 |   4 +
 helion/exc.py                              |   4 +
 helion/runtime/__init__.py                 |   1 +
 helion/runtime/kernel.py                   |   2 +-
 helion/runtime/precompile_shim.py          |  24 +-
 helion/runtime/settings.py                 |  21 +-
 test/test_autotuner.py                     | 212 +++++++----
 10 files changed, 551 insertions(+), 143 deletions(-)

diff --git a/docs/api/settings.md b/docs/api/settings.md
index 2e86a1e61..b4bc71aca 100644
--- a/docs/api/settings.md
+++ b/docs/api/settings.md
@@ -125,7 +125,12 @@ with helion.set_default_settings(

 .. autoattribute:: Settings.autotune_precompile

-   Whether to precompile kernels before autotuning. Default is ``True`` on non-Windows systems, ``False`` on Windows.
+   Selects the autotuner precompile mode, which adds parallelism and
+   checks for errors and timeouts. ``"spawn"`` (the default) warms up
+   each kernel in a fresh process, including an execution run that
+   checks for errors; ``"fork"`` is faster but skips the error-check
+   run; ``None`` disables precompile checks altogether. Controlled by
+   ``HELION_AUTOTUNE_PRECOMPILE``.

 .. autoattribute:: Settings.autotune_random_seed

diff --git a/helion/autotuner/base_search.py b/helion/autotuner/base_search.py
index 8a696db40..bdf21f17d 100644
--- a/helion/autotuner/base_search.py
+++ b/helion/autotuner/base_search.py
@@ -4,36 +4,41 @@
 import collections
 import contextlib
 import dataclasses
+import datetime
 import functools
+import inspect
 from itertools import starmap
 import logging
 import math
 from math import inf
+import multiprocessing as mp
 from multiprocessing import connection
 import os
+from pathlib import Path
 import random
 import sys
+import tempfile
 import time
+import traceback
+import types
 from typing import TYPE_CHECKING
+from typing import Any
 from typing import Callable
 from typing import NoReturn
-
-from .benchmarking import interleaved_bench
-
-if TYPE_CHECKING:
-    from triton.runtime.jit import JITFunction
-
+from typing import cast
 from unittest.mock import patch
+import uuid

 import torch
-import torch.multiprocessing as mp
 from torch.utils._pytree import tree_flatten
 from torch.utils._pytree import tree_map
 from triton.testing import do_bench

 from .. 
import exc +from ..runtime.kernel import BoundKernel from ..runtime.precompile_shim import already_compiled from ..runtime.precompile_shim import make_precompiler +from .benchmarking import interleaved_bench from .config_generation import ConfigGeneration from .config_generation import FlatConfig from .logger import LambdaLogger @@ -46,8 +51,6 @@ if TYPE_CHECKING: from collections.abc import Sequence - import triton - from ..runtime.config import Config from ..runtime.kernel import BoundKernel from ..runtime.kernel import CompiledConfig @@ -101,6 +104,8 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None: self._baseline_output: object | None = None self._baseline_post_args: Sequence[object] | None = None self._kernel_mutates_args: bool = False + self._precompile_tmpdir: tempfile.TemporaryDirectory[str] | None = None + self._precompile_args_path: str | None = None if self.settings.autotune_accuracy_check: ( self._baseline_output, @@ -108,6 +113,12 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None: self._baseline_post_args, ) = self._compute_baseline() + def cleanup(self) -> None: + if self._precompile_tmpdir is not None: + self._precompile_tmpdir.cleanup() + self._precompile_tmpdir = None + self._precompile_args_path = None + def _clone_args(self, args: Sequence[object]) -> Sequence[object]: def _clone_leaf(leaf: object) -> object: if isinstance(leaf, torch.Tensor): @@ -206,6 +217,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float: self.log.debug(lambda: f"Running benchmark for {config!r}") try: # TODO(jansel): early exit with fewer trials if early runs are slow + self.log.debug(lambda: f"Running {config} at {datetime.datetime.now()}") t0 = time.perf_counter() if self._kernel_mutates_args: self.args = self._clone_args(self._original_args) @@ -251,8 +263,8 @@ def start_precompile_and_check_for_hangs( self, config: Config, fn: CompiledConfig ) -> PrecompileFuture: """ - Unfortunately, Triton can hang when compiling a kernel. This function tries to - compile the kernel with the given configuration and checks if it hangs in a subprocess. + Run the kernel in a spawned subprocess to detect hangs during compilation or execution. + We use the subprocess timeout to guard against Triton kernels that never finish. We also do this in parallel (when called from parallel_benchmark) to do faster autotuning. Note that we compile in parallel, but we benchmark one-by-one to avoid noisy results. 
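
The pattern this patch builds on, in isolation: do the warm-up work in a child process, report status back over a pipe, and treat a join timeout as a hang. Below is a minimal runnable sketch of that pattern with hypothetical names; Helion's actual implementation is the PrecompileFuture machinery changed in the hunks that follow.

# Sketch: run warm-up work in a subprocess and treat a timeout as a hang.
# Hypothetical names, not Helion code.
import multiprocessing as mp
import time


def _warm_up(conn) -> None:
    time.sleep(0.1)  # stand-in for compiling and running the kernel
    conn.send({"status": "ok"})
    conn.close()


def check_one_config(timeout_s: float = 5.0) -> bool:
    ctx = mp.get_context("spawn")
    parent_conn, child_conn = ctx.Pipe()
    proc = ctx.Process(target=_warm_up, args=(child_conn,))
    proc.start()
    child_conn.close()  # parent keeps only its end, so EOF is detectable
    proc.join(timeout_s)
    if proc.is_alive():  # never finished: classify as a hang and kill it
        proc.terminate()
        proc.join(10)
        return False
    return bool(
        proc.exitcode == 0
        and parent_conn.poll(0)
        and parent_conn.recv() == {"status": "ok"}
    )


if __name__ == "__main__":
    assert check_one_config()
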
@@ -265,44 +277,51 @@ def start_precompile_and_check_for_hangs( """ if not self.settings.autotune_precompile: return PrecompileFuture.skip(self, config, True) - ctx = mp.get_context("fork") - - def extract_launcher( - triton_kernel: triton.JITFunction, - grid: tuple[int, ...], - *args: object, - **kwargs: object, - ) -> NoReturn: - """Custom launcher that extracts arguments instead of executing.""" - raise _ExtractedLaunchArgs(triton_kernel, grid, args, kwargs) - - try: - # Call main function with extraction launcher to extract arguments - fn(*self.args, _launcher=extract_launcher) - # Should not reach here - raise RuntimeError("Expected _ExtractedLaunchArgs exception") - except _ExtractedLaunchArgs as e: - precompiler = make_precompiler( - e.kernel, - config, - self.kernel, - )(*e.args, **e.kwargs) - if precompiler is already_compiled: - return PrecompileFuture.skip(self, config, True) - except Exception: - log.warning( - "Helion autotuner precompile error for %s\n\nGenerated Triton code:\n%s", - self.kernel.format_kernel_decorator(config, self.settings), - self.kernel.to_triton_code(config), - exc_info=True, + mode = self.settings.autotune_precompile + if mode not in {"fork", "spawn"}: + raise exc.InvalidAPIUsage("autotune_precompile must be 'fork' or 'spawn'") + if self._kernel_mutates_args: + device_args = self._clone_args(self._original_args) + else: + device_args = self.args + + decorator = self.kernel.format_kernel_decorator(config, self.settings) + + if mode == "spawn": + ctx = mp.get_context("spawn") + assert self._precompile_args_path is not None + parent_conn, child_conn = ctx.Pipe() + try: + fn_spec = _serialize_compiled_fn(fn) + except RuntimeError as err: + raise exc.AutotuneError( + "Failed to serialize compiled kernel for spawn precompile." + ' Set HELION_AUTOTUNE_PRECOMPILE="fork" to fall back to fork mode.' 
+ ) from err + process = cast( + "mp.Process", + ctx.Process( + target=_run_kernel_in_subprocess_spawn, + args=(fn_spec, self._precompile_args_path, child_conn, decorator), + ), + ) + else: + ctx = mp.get_context("fork") + parent_conn, child_conn = ctx.Pipe() + process = cast( + "mp.Process", + ctx.Process( + target=_run_kernel_in_subprocess_fork, + args=(fn, device_args, config, self.kernel, child_conn, decorator), + ), ) - raise - process: mp.Process = ctx.Process(target=precompiler) # pyright: ignore[reportAssignmentType] return PrecompileFuture( search=self, config=config, process=process, timeout=self.settings.autotune_compile_timeout, + conn=parent_conn, + child_conn=child_conn, ) def parallel_benchmark( @@ -361,8 +380,20 @@ def autotune(self) -> Config: """ start = time.perf_counter() self.log.reset() - # Autotuner triggers bugs in remote triton compile service - with patch.dict(os.environ, {"TRITON_LOCAL_BUILD": "1"}, clear=False): + exit_stack = contextlib.ExitStack() + with exit_stack: + # Autotuner triggers bugs in remote triton compile service + exit_stack.enter_context( + patch.dict(os.environ, {"TRITON_LOCAL_BUILD": "1"}, clear=False) + ) + if self.settings.autotune_precompile == "spawn": + assert self._precompile_tmpdir is None + tempdir = tempfile.TemporaryDirectory() + self._precompile_tmpdir = tempdir + args_path = os.path.join(tempdir.name, "args.pt") + torch.save(self.args, args_path) + self._precompile_args_path = args_path + exit_stack.callback(self.cleanup) best = self._autotune() end = time.perf_counter() kernel_decorator = self.kernel.format_kernel_decorator(best, self.settings) @@ -665,6 +696,11 @@ class PrecompileFuture: start_time: float | None = None end_time: float | None = None ok: bool | None = None + conn: connection.Connection | None = None + child_conn: connection.Connection | None = None + _result_received: bool = False + remote_error: RemoteError | None = None + _remote_error_handled: bool = False @property def elapsed(self) -> float: @@ -700,6 +736,10 @@ def start(self) -> None: return self.start_time = time.time() self.process.start() + if self.child_conn is not None: + with contextlib.suppress(Exception): + self.child_conn.close() + self.child_conn = None @staticmethod def skip(search: BaseSearch, config: Config, ok: bool) -> PrecompileFuture: @@ -713,6 +753,11 @@ def skip(search: BaseSearch, config: Config, ok: bool) -> PrecompileFuture: ok=ok, start_time=ts, end_time=ts, + conn=None, + child_conn=None, + _result_received=True, + remote_error=None, + _remote_error_handled=True, ) def __call__(self) -> bool: @@ -728,6 +773,7 @@ def __call__(self) -> bool: process.join(self.seconds_left()) finally: self._mark_complete() + self._handle_remote_error(raise_on_raise=True) assert self.ok is not None return self.ok @@ -800,6 +846,7 @@ def _wait_for_all_step( continue if f.started and (not f.is_alive() or f.seconds_left() <= 0): f._mark_complete() + f._handle_remote_error(raise_on_raise=True) else: remaining.append(f) return remaining @@ -820,6 +867,8 @@ def _mark_complete(self) -> bool: self.start() if not process.is_alive(): self.ok = process.exitcode == 0 + self._recv_result(block=True) + self._handle_remote_error(raise_on_raise=False) return self.ok process.terminate() process.join(10) @@ -835,26 +884,221 @@ def _mark_complete(self) -> bool: self.search.log.warning(msg) self.ok = False + self._recv_result(block=False) + self._handle_remote_error(raise_on_raise=False) return False + def _recv_result(self, *, block: bool) -> None: + if 
self._result_received or self.conn is None: + return + timeout = None if block else 0.0 + try: + if self.conn.poll(timeout): + message = self.conn.recv() + if isinstance(message, dict) and message.get("status") == "ok": + if self.ok is None: + self.ok = True + elif isinstance(message, dict): + exc_args = message.get("exc_args") + if not isinstance(exc_args, (list, tuple)): + exc_args = (message.get("traceback"),) + self.remote_error = RemoteError( + exc_type=message.get("exc_type", "RemoteError"), + exc_module=message.get("exc_module"), + exc_args=tuple(exc_args), + traceback=message.get("traceback"), + classification=message.get("classification"), + ) + self.ok = False + elif block: + self.remote_error = self.remote_error or RemoteError( + exc_type="EOFError", + exc_module=__name__, + exc_args=("No result received from subprocess.",), + traceback=None, + classification="debug", + ) + except (EOFError, OSError) as exc: + if self.remote_error is None: + self.remote_error = RemoteError( + exc_type=type(exc).__name__, + exc_module=type(exc).__module__, + exc_args=(str(exc),), + traceback=None, + classification="debug", + ) + finally: + with contextlib.suppress(Exception): + self.conn.close() + self.conn = None + self._result_received = True + + def _handle_remote_error(self, *, raise_on_raise: bool) -> None: + error = self.remote_error + if error is None or self._remote_error_handled: + return + exc_obj = error.to_exception() + classification = error.classification or classify_triton_exception(exc_obj) + if classification == "raise": + if raise_on_raise: + self._remote_error_handled = True + raise exc.TritonError( + f"{type(exc_obj).__qualname__}: {exc_obj}", + self.search.kernel.format_kernel_decorator( + self.config, self.search.settings + ), + self.search.kernel.to_triton_code(self.config), + ) from exc_obj + return + + message = format_triton_compile_failure( + self.config, exc_obj, self.search.kernel + ) + if error.traceback: + message = ( + f"{message}\nRemote traceback (spawned process):\n{error.traceback}" + ) + if classification == "warn": + self.search.log.warning(message) + else: + self.search.log.debug(message) + self._remote_error_handled = True + + +def _clone_tree(tree: object) -> object: + def _clone(leaf: object) -> object: + if isinstance(leaf, torch.Tensor): + clone = leaf.detach().clone() + clone.requires_grad_(leaf.requires_grad) + return clone + return leaf + + return tree_map(_clone, tree) + + +def _assert_args_close(actual: Sequence[object], expected: Sequence[object]) -> None: + actual_flat, _ = tree_flatten(actual) + expected_flat, _ = tree_flatten(expected) + for act, exp in zip(actual_flat, expected_flat, strict=False): + if isinstance(act, torch.Tensor) and isinstance(exp, torch.Tensor): + torch.testing.assert_close(act, exp, atol=1e-2, rtol=1e-2) + + +def _run_kernel_in_subprocess_spawn( + fn_spec: SerializedCompiledFunction, + args_path: str, + conn: connection.Connection, + decorator: str, +) -> None: + status = 0 + try: + fn = _load_compiled_fn(fn_spec) + args = torch.load(args_path) + assert isinstance(args, (tuple, list)) + torch.accelerator.synchronize() + fn(*args) + torch.accelerator.synchronize() + conn.send({"status": "ok"}) + except Exception as exc: + status = 1 + with contextlib.suppress(Exception): + try: + exc_args = tuple(exc.args) + except Exception: + exc_args = (str(exc),) + try: + classification = classify_triton_exception(exc) + except Exception: + classification = None + conn.send( + { + "status": "error", + "traceback": 
traceback.format_exc(), + "decorator": decorator, + "exc_type": type(exc).__name__, + "exc_module": type(exc).__module__, + "exc_args": exc_args, + "classification": classification, + } + ) + finally: + with contextlib.suppress(Exception): + conn.close() + os._exit(status) + + +def _run_kernel_in_subprocess_fork( + fn: CompiledConfig, + args: Sequence[object], + config: Config, + kernel: BoundKernel, + conn: connection.Connection, + decorator: str, +) -> None: + status = 0 + try: + + def extract_launcher( + triton_kernel: object, + grid: tuple[int, ...], + *launch_args: object, + **launch_kwargs: object, + ) -> NoReturn: + raise _ExtractedLaunchArgs(triton_kernel, grid, launch_args, launch_kwargs) + + try: + fn(*tuple(args), _launcher=extract_launcher) + raise RuntimeError("Expected _ExtractedLaunchArgs to be raised") + except _ExtractedLaunchArgs as extracted: + precompiler_factory = make_precompiler( + cast("Any", extracted.kernel), + config, + kernel, + ) + precompiler = precompiler_factory(*extracted.args, **extracted.kwargs) + if precompiler is not already_compiled: + precompiler() + conn.send({"status": "ok"}) + except Exception as exc: + status = 1 + with contextlib.suppress(Exception): + try: + exc_args = tuple(exc.args) + except Exception: + exc_args = (str(exc),) + try: + classification = classify_triton_exception(exc) + except Exception: + classification = None + conn.send( + { + "status": "error", + "traceback": traceback.format_exc(), + "decorator": decorator, + "exc_type": type(exc).__name__, + "exc_module": type(exc).__module__, + "exc_args": exc_args, + "classification": classification, + } + ) + finally: + with contextlib.suppress(Exception): + conn.close() + os._exit(status) + class _ExtractedLaunchArgs(Exception): """Exception that carries kernel launch arguments for precompiler extraction.""" - kernel: JITFunction[object] - grid: object - args: tuple[object, ...] - kwargs: dict[str, object] - def __init__( self, - triton_kernel: JITFunction[object], - grid: object, + kernel: object, + grid: tuple[int, ...], args: tuple[object, ...], kwargs: dict[str, object], ) -> None: super().__init__() - self.kernel = triton_kernel + self.kernel = kernel self.grid = grid self.args = args self.kwargs = kwargs @@ -862,3 +1106,70 @@ def __init__( def _unset_fn(*args: object) -> NoReturn: raise RuntimeError("Uninitialized function") + + +@dataclasses.dataclass +class SerializedCompiledFunction: + function_name: str + source_code: str + filename: str | None + module_name: str | None + + +@dataclasses.dataclass +class RemoteError: + exc_type: str + exc_module: str | None + exc_args: tuple[object, ...] 
+    traceback: str | None
+    classification: str | None
+
+    def to_exception(self) -> Exception:
+        exc_cls = types.new_class(self.exc_type, (Exception,))
+        exc_cls.__module__ = self.exc_module or __name__
+        exc_obj = exc_cls(*self.exc_args)
+        exc_obj.remote_traceback = self.traceback
+        return exc_obj
+
+
+def _serialize_compiled_fn(fn: CompiledConfig) -> SerializedCompiledFunction:
+    if "<locals>" in getattr(fn, "__qualname__", ""):
+        raise RuntimeError("Unable to serialize nested compiled functions")
+    module_name = getattr(fn, "__module__", None)
+    module = sys.modules.get(module_name) if module_name is not None else None
+    filename: str | None = None
+    source_code: str | None = None
+    if module is not None:
+        filename = getattr(module, "__file__", None)
+        if filename is not None and os.path.exists(filename):
+            source_code = Path(filename).read_text(encoding="utf-8")
+        if source_code is None:
+            with contextlib.suppress(OSError, TypeError):
+                source_code = inspect.getsource(module)
+    if source_code is None:
+        raise RuntimeError("Unable to capture source for compiled kernel")
+    return SerializedCompiledFunction(
+        function_name=fn.__name__,
+        source_code=source_code,
+        filename=filename,
+        module_name=module_name,
+    )
+
+
+def _load_compiled_fn(fn_spec: SerializedCompiledFunction) -> CompiledConfig:
+    module_name = f"_helion_autotune_subprocess_{uuid.uuid4().hex}"
+    module = types.ModuleType(module_name)
+    module.__file__ = fn_spec.filename or "<string>"
+    module.__loader__ = None
+    module.__package__ = None
+    sys.modules[module_name] = module
+    exec(
+        compile(fn_spec.source_code, module.__file__, "exec"),
+        module.__dict__,
+    )
+    fn = getattr(module, fn_spec.function_name, None)
+    if fn is None:
+        raise RuntimeError(
+            f"Unable to locate compiled kernel '{fn_spec.function_name}' in generated module"
+        )
+    return fn
diff --git a/helion/autotuner/differential_evolution.py b/helion/autotuner/differential_evolution.py
index d570ae6c4..df172ecda 100644
--- a/helion/autotuner/differential_evolution.py
+++ b/helion/autotuner/differential_evolution.py
@@ -34,7 +34,7 @@ def __init__(
     ) -> None:
         super().__init__(kernel, args)
         if immediate_update is None:
-            immediate_update = not kernel.settings.autotune_precompile
+            immediate_update = not bool(kernel.settings.autotune_precompile)
         self.population_size = population_size
         self.max_generations = max_generations
         self.crossover_rate = crossover_rate
diff --git a/helion/autotuner/logger.py b/helion/autotuner/logger.py
index 6bdbd71c5..5ca803f86 100644
--- a/helion/autotuner/logger.py
+++ b/helion/autotuner/logger.py
@@ -121,6 +121,10 @@ def format_triton_compile_failure(
             "triton.compiler.errors.CompilationError",  # Triton CompilationError
             "out of resource: shared memory",  # Triton shared memory OOM
             "ZE_RESULT_ERROR_INVALID_KERNEL_NAME",  # Level Zero compile failed
+            "an illegal memory access was encountered",  # workaround triton bugs
+            "misaligned address",  # workaround triton bugs
+            "unspecified launch failure",  # workaround ptxas bugs
+            "exceeds triton maximum tensor numel",  # needs smaller config
         ],
     )
 )
diff --git a/helion/exc.py b/helion/exc.py
index f9a231996..79dbb1423 100644
--- a/helion/exc.py
+++ b/helion/exc.py
@@ -48,6 +48,10 @@ class ClosuresNotSupported(BaseError):
     message = "A closure ({0!r}) was found in the kernel. Closures are not supported."


+class AutotuneError(BaseError):
+    message = "{0}"
+
+
 class ClosureMutation(BaseError):
     message = "Closure mutation (of {0}) is not allowed in a function arg."
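
The spawn path cannot pickle the compiled function, so _serialize_compiled_fn / _load_compiled_fn above round-trip it as module source instead, under the assumption that the function is defined at module top level (the "<locals>" guard rejects anything else). A toy, self-contained illustration of that round-trip; not Helion code:

# Toy illustration of the source-based round-trip: ship the module source and
# re-exec it into a fresh module in the receiving process.
import types
import uuid

SOURCE = "def add(a, b):\n    return a + b\n"


def load_from_source(source: str, function_name: str):
    module = types.ModuleType(f"_toy_subprocess_{uuid.uuid4().hex}")
    module.__file__ = "<string>"
    exec(compile(source, module.__file__, "exec"), module.__dict__)
    return getattr(module, function_name)


assert load_from_source(SOURCE, "add")(2, 3) == 5
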
diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py index 4052e05df..1f2c9d718 100644 --- a/helion/runtime/__init__.py +++ b/helion/runtime/__init__.py @@ -5,6 +5,7 @@ import torch +from .. import _compat as _compat # ensure Triton compatibility patches run from .config import Config as Config from .kernel import Kernel as Kernel from .kernel import kernel as kernel diff --git a/helion/runtime/kernel.py b/helion/runtime/kernel.py index 4655ed24d..c638ed129 100644 --- a/helion/runtime/kernel.py +++ b/helion/runtime/kernel.py @@ -501,7 +501,7 @@ def autotune( (config,) = self.kernel.configs else: # We have finite predetermined configs, no need to precompile - self.settings.autotune_precompile = False + self.settings.autotune_precompile = None from ..autotuner import FiniteSearch diff --git a/helion/runtime/precompile_shim.py b/helion/runtime/precompile_shim.py index 0e9824cf3..6d72dba78 100644 --- a/helion/runtime/precompile_shim.py +++ b/helion/runtime/precompile_shim.py @@ -3,7 +3,10 @@ import os import sys from typing import TYPE_CHECKING +from typing import cast +from .._compat import get_triton_find_paths_if +from .._compat import get_triton_iterable_path from ..autotuner.logger import classify_triton_exception from ..autotuner.logger import format_triton_compile_failure @@ -21,9 +24,6 @@ def make_precompiler( config: Config, bound_kernel: BoundKernel, ) -> Callable[..., Callable[[], None]]: - from triton.runtime.jit import find_paths_if - from triton.runtime.jit import get_iterable_path - from .kernel import _find_device def _make_precompiler(*args: object, **kwargs: object) -> Callable[[], None]: @@ -47,14 +47,24 @@ def _make_precompiler(*args: object, **kwargs: object) -> Callable[[], None]: sigkeys = [x.name for x in fn.params] sigvals = [x[0] for x in specialization] signature = dict(zip(sigkeys, sigvals, strict=False)) - constexprs = find_paths_if(sigvals, lambda _, val: val == "constexpr") + find_paths_if = get_triton_find_paths_if() + get_iterable_path = get_triton_iterable_path() + constexpr_paths = cast( + "list[tuple[int, ...]]", + find_paths_if(sigvals, lambda _, val: val == "constexpr"), + ) constexprs = { path: get_iterable_path(list(bound_args.values()), path) - for path in constexprs + for path in constexpr_paths } attrvals = [x[1] for x in specialization] - attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str)) - attrs = {k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attrs} + attr_paths = cast( + "list[tuple[int, ...]]", + find_paths_if(attrvals, lambda _, x: isinstance(x, str)), + ) + attrs = { + k: backend.parse_attr(get_iterable_path(attrvals, k)) for k in attr_paths + } def finish_it() -> None: src = fn.ASTSource(fn, signature, constexprs, attrs) diff --git a/helion/runtime/settings.py b/helion/runtime/settings.py index 9c01923d7..9362ac18a 100644 --- a/helion/runtime/settings.py +++ b/helion/runtime/settings.py @@ -3,7 +3,6 @@ import dataclasses import logging import os -import sys import threading import time from typing import TYPE_CHECKING @@ -132,6 +131,20 @@ def _get_autotune_effort() -> AutotuneEffort: return cast("AutotuneEffort", os.environ.get("HELION_AUTOTUNE_EFFORT", "full")) +def _get_autotune_precompile() -> str | None: + value = os.environ.get("HELION_AUTOTUNE_PRECOMPILE") + if value is None: + return "spawn" + mode = value.strip().lower() + if mode in {"", "0", "false", "none"}: + return None + if mode in {"spawn", "fork"}: + return mode + raise ValueError( + "HELION_AUTOTUNE_PRECOMPILE must be 'spawn', 
'fork', or empty to disable precompile"
+    )
+
+
 @dataclasses.dataclass
 class _Settings:
     # see __slots__ below for the doc strings that show up in help(Settings)
@@ -148,7 +161,9 @@ class _Settings:
     autotune_compile_timeout: int = int(
         os.environ.get("HELION_AUTOTUNE_COMPILE_TIMEOUT", "60")
     )
-    autotune_precompile: bool = sys.platform != "win32"
+    autotune_precompile: str | None = dataclasses.field(
+        default_factory=_get_autotune_precompile
+    )
     autotune_precompile_jobs: int | None = None
     autotune_random_seed: int = dataclasses.field(
         default_factory=_get_autotune_random_seed
@@ -196,7 +211,7 @@ class Settings(_Settings):
         "static_shapes": "If True, use static shapes for all tensors. This is a performance optimization.",
         "autotune_log_level": "Log level for autotuning using Python logging levels. Default is logging.INFO. Use 0 to disable all output.",
         "autotune_compile_timeout": "Timeout for Triton compilation in seconds used for autotuning. Default is 60 seconds.",
-        "autotune_precompile": "If True, precompile the kernel before autotuning. Requires fork-safe environment.",
+        "autotune_precompile": "Autotuner precompile mode: 'spawn', 'fork', or None/falsy to disable. Defaults to 'spawn'; set HELION_AUTOTUNE_PRECOMPILE to override.",
         "autotune_precompile_jobs": "Maximum concurrent Triton precompile processes, default to cpu count.",
         "autotune_random_seed": "Seed used for autotuner random number generation. Defaults to HELION_AUTOTUNE_RANDOM_SEED or a time-based seed.",
         "autotune_accuracy_check": "If True, validate candidate configs against the baseline kernel output before accepting them during autotuning.",
diff --git a/test/test_autotuner.py b/test/test_autotuner.py
index 2e69a4f10..7710b246e 100644
--- a/test/test_autotuner.py
+++ b/test/test_autotuner.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from contextlib import contextmanager
+from contextlib import nullcontext
 import math
 import os
 from pathlib import Path
@@ -210,6 +211,7 @@ def test_random_search(self):
             torch.randn([512, 512], device=DEVICE),
         )
         bound_kernel = examples_matmul.bind(args)
+        bound_kernel.settings.autotune_precompile = None
         random.seed(123)
         best = RandomSearch(bound_kernel, args, 20).autotune()
         fn = bound_kernel.compile_config(best)
@@ -301,41 +303,68 @@ def add_inplace(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
                 b[tile] = a[tile] + b[tile]
             return b

-        a = torch.randn([32], device=DEVICE)
-        b = torch.randn([32], device=DEVICE)
-        bound_kernel = add_inplace.bind((a, b))
-
-        original_compile = bound_kernel.compile_config
-
-        def make_bad_config_produce_wrong_output(
-            config: helion.Config, *, allow_print: bool = True
-        ):
-            fn = original_compile(config, allow_print=allow_print)
-            if config == bad_config:
-                return lambda *fn_args, **fn_kwargs: fn(*fn_args, **fn_kwargs) + 1
-            return fn
-
-        with patch.object(
-            bound_kernel,
-            "compile_config",
-            side_effect=make_bad_config_produce_wrong_output,
-        ):
-            search = FiniteSearch(
-                bound_kernel, (a, b), configs=[bad_config, good_config]
-            )
-            _, bad_time = search.benchmark(bad_config)
-            assert math.isinf(bad_time)
-            self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1)
-            search.counters["accuracy_mismatch"] = 0  # reset counter
-
-            _, good_time = search.benchmark(good_config)
-            assert not math.isinf(good_time)
-            self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0)
-            search.counters["accuracy_mismatch"] = 0  # reset counter
-
-            best = search._autotune()
-            self.assertEqual(best, good_config)
-            self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1)
+        
def run_mode(mode: str, *, expect_error: bool) -> None: + a = torch.randn([32], device=DEVICE) + b = torch.randn([32], device=DEVICE) + bound_kernel = add_inplace.bind((a, b)) + original_compile = bound_kernel.compile_config + bound_kernel.settings.autotune_precompile = mode + + def make_bad_config_produce_wrong_output( + config: helion.Config, *, allow_print: bool = True + ): + fn = original_compile(config, allow_print=allow_print) + if config == bad_config: + return lambda *fn_args, **fn_kwargs: fn(*fn_args, **fn_kwargs) + 1 + return fn + + import helion.autotuner.base_search as base_search_module + + with patch.object( + bound_kernel, + "compile_config", + side_effect=make_bad_config_produce_wrong_output, + ): + search = FiniteSearch( + bound_kernel, (a, b), configs=[bad_config, good_config] + ) + if mode == "fork": + start_cm = patch.object( + search, + "start_precompile_and_check_for_hangs", + side_effect=lambda config, + fn: base_search_module.PrecompileFuture.skip( + search, config, True + ), + ) + else: + start_cm = nullcontext() + + with start_cm: + if expect_error: + with self.assertRaisesRegex( + helion.exc.AutotuneError, + 'Set HELION_AUTOTUNE_PRECOMPILE="fork"', + ): + search.autotune() + return + + _, bad_time = search.benchmark(bad_config) + assert math.isinf(bad_time) + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) + search.counters["accuracy_mismatch"] = 0 + + _, good_time = search.benchmark(good_config) + assert not math.isinf(good_time) + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0) + search.counters["accuracy_mismatch"] = 0 + + best = search.autotune() + self.assertEqual(best, good_config) + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) + + run_mode("fork", expect_error=False) + run_mode("spawn", expect_error=True) def test_accuracy_check_filters_bad_config_wrong_arg_mutation(self) -> None: bad_config = helion.Config(block_sizes=[1], num_warps=8) @@ -347,48 +376,77 @@ def add_inplace(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: b[tile] = a[tile] + b[tile] return b - a = torch.randn([32], device=DEVICE) - b = torch.randn([32], device=DEVICE) - bound_kernel = add_inplace.bind((a, b)) - - original_compile = bound_kernel.compile_config - - def make_bad_config_produce_wrong_input_arg_mutation( - config: helion.Config, *, allow_print: bool = True - ): - fn = original_compile(config, allow_print=allow_print) - if config == bad_config: - - def wrong_fn(*fn_args, **fn_kwargs): - result = fn(*fn_args, **fn_kwargs) - # Introduce an extra mutation so inputs differ from baseline - fn_args[1].add_(1) - return result - - return wrong_fn - return fn - - with patch.object( - bound_kernel, - "compile_config", - side_effect=make_bad_config_produce_wrong_input_arg_mutation, - ): - search = FiniteSearch( - bound_kernel, (a, b), configs=[bad_config, good_config] - ) - _, bad_time = search.benchmark(bad_config) - assert math.isinf(bad_time) - self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) - search.counters["accuracy_mismatch"] = 0 # reset counter - - _, good_time = search.benchmark(good_config) - assert not math.isinf(good_time) - self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0) - search.counters["accuracy_mismatch"] = 0 # reset counter - - best = search._autotune() - self.assertEqual(best, good_config) - self.assertGreaterEqual(search.counters.get("accuracy_mismatch", 0), 1) + def run_mode(mode: str, *, expect_error: bool) -> None: + a = torch.randn([32], device=DEVICE) + b = torch.randn([32], 
device=DEVICE) + bound_kernel = add_inplace.bind((a, b)) + original_compile = bound_kernel.compile_config + bound_kernel.settings.autotune_precompile = mode + + def make_bad_config_produce_wrong_input_arg_mutation( + config: helion.Config, *, allow_print: bool = True + ): + fn = original_compile(config, allow_print=allow_print) + if config == bad_config: + + def wrong_fn(*fn_args, **fn_kwargs): + result = fn(*fn_args, **fn_kwargs) + # Introduce an extra mutation so inputs differ from baseline + fn_args[1].add_(1) + return result + + return wrong_fn + return fn + + import helion.autotuner.base_search as base_search_module + + with patch.object( + bound_kernel, + "compile_config", + side_effect=make_bad_config_produce_wrong_input_arg_mutation, + ): + search = FiniteSearch( + bound_kernel, (a, b), configs=[bad_config, good_config] + ) + if mode == "fork": + start_cm = patch.object( + search, + "start_precompile_and_check_for_hangs", + side_effect=lambda config, + fn: base_search_module.PrecompileFuture.skip( + search, config, True + ), + ) + else: + start_cm = nullcontext() + + with start_cm: + if expect_error: + with self.assertRaisesRegex( + helion.exc.AutotuneError, + 'Set HELION_AUTOTUNE_PRECOMPILE="fork"', + ): + search.autotune() + return + + _, bad_time = search.benchmark(bad_config) + assert math.isinf(bad_time) + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 1) + search.counters["accuracy_mismatch"] = 0 + + _, good_time = search.benchmark(good_config) + assert not math.isinf(good_time) + self.assertEqual(search.counters.get("accuracy_mismatch", 0), 0) + search.counters["accuracy_mismatch"] = 0 + + best = search.autotune() + self.assertEqual(best, good_config) + self.assertGreaterEqual( + search.counters.get("accuracy_mismatch", 0), 1 + ) + + run_mode("fork", expect_error=False) + run_mode("spawn", expect_error=True) def test_max_generations(self): """Autotuner max generation respects explicit kwargs then setting override."""
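
For users, the new behavior is selected through HELION_AUTOTUNE_PRECOMPILE (parsed by _get_autotune_precompile in the settings.py hunk above). A sketch of the three modes, assuming the variable is set before a Settings object is constructed, since the default is read at that point; per-kernel overrides assign to bound_kernel.settings.autotune_precompile, as the updated tests do:

# Sketch: choosing a precompile mode via the environment, per the settings.py
# changes in this patch. Set before any Settings object is created.
import os

os.environ["HELION_AUTOTUNE_PRECOMPILE"] = "fork"  # faster, compile-only check
# "spawn"                -> default: compile and run in a fresh process
# "0" / "false" / "none" -> disable precompile checks entirely
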