2020from torch .utils ._pytree import tree_map
2121import triton
2222
23+ from ._compat import get_tensor_descriptor_fn_name
2324from ._utils import counters
25+ from .autotuner .benchmarking import compute_repeat
26+ from .autotuner .benchmarking import interleaved_bench
2427from .runtime .config import Config
25- import helion
26- from helion ._compat import get_tensor_descriptor_fn_name
27- from helion .autotuner .benchmarking import compute_repeat
28- from helion .autotuner .benchmarking import interleaved_bench
29- from helion .runtime .ref_mode import is_ref_mode_enabled
28+ from .runtime .ref_mode import is_ref_mode_enabled
29+ from .runtime .settings import RefMode
3030
3131if TYPE_CHECKING :
3232 import types
@@ -104,7 +104,7 @@ def track_run_ref_calls() -> Generator[list[int], None, None]:
104104 Yields:
105105 A list that will contain the count of run_ref calls.
106106 """
107- from helion .runtime .kernel import BoundKernel
107+ from .runtime .kernel import BoundKernel
108108
109109 original_run_ref = BoundKernel .run_ref
110110 run_ref_count = [0 ]
@@ -123,7 +123,7 @@ def tracked_run_ref(self: BoundKernel, *args: object) -> object:
123123
124124@contextlib .contextmanager
125125def assert_helion_ref_mode (
126- ref_mode : helion . RefMode = helion . RefMode .OFF ,
126+ ref_mode : RefMode = RefMode .OFF ,
127127) -> Generator [None , None , None ]:
128128 """Context manager that asserts Helion compilation behavior based on RefMode.
129129
@@ -133,12 +133,12 @@ def assert_helion_ref_mode(
133133 with track_run_ref_calls () as run_ref_count :
134134 yield
135135
136- if ref_mode == helion . RefMode .OFF :
136+ if ref_mode == RefMode .OFF :
137137 # In normal mode (RefMode.OFF), run_ref should not be called
138138 assert run_ref_count [0 ] == 0 , (
139139 f"Expected run_ref to not be called in normal mode (RefMode.OFF), but got: run_ref={ run_ref_count [0 ]} "
140140 )
141- elif ref_mode == helion . RefMode .EAGER :
141+ elif ref_mode == RefMode .EAGER :
142142 # In ref eager mode (RefMode.EAGER), run_ref should be called
143143 assert run_ref_count [0 ] > 0 , (
144144 f"Expected run_ref to be called in ref eager mode (RefMode.EAGER), but got: run_ref={ run_ref_count [0 ]} "
@@ -148,11 +148,11 @@ def assert_helion_ref_mode(
148148
149149
150150assert_helion_compilation = functools .partial (
151- assert_helion_ref_mode , ref_mode = helion . RefMode .OFF
151+ assert_helion_ref_mode , ref_mode = RefMode .OFF
152152)
153153
154154assert_ref_eager_mode = functools .partial (
155- assert_helion_ref_mode , ref_mode = helion . RefMode .EAGER
155+ assert_helion_ref_mode , ref_mode = RefMode .EAGER
156156)
157157
158158
0 commit comments