2020from torch .utils ._pytree import tree_map
2121import triton
2222
23+ from ._compat import get_tensor_descriptor_fn_name
2324from ._utils import counters
25+ from .autotuner .benchmarking import compute_repeat
26+ from .autotuner .benchmarking import interleaved_bench
2427from .runtime .config import Config
25- import helion
26- from helion ._compat import get_tensor_descriptor_fn_name
27- from helion .autotuner .benchmarking import compute_repeat
28- from helion .autotuner .benchmarking import interleaved_bench
29- from helion .runtime .ref_mode import is_ref_mode_enabled
28+ from .runtime .ref_mode import is_ref_mode_enabled
29+ from .runtime .settings import RefMode
3030
3131if TYPE_CHECKING :
3232 import types
@@ -93,7 +93,7 @@ def track_run_ref_calls() -> Generator[list[int], None, None]:
9393 Yields:
9494 A list that will contain the count of run_ref calls.
9595 """
96- from helion .runtime .kernel import BoundKernel
96+ from .runtime .kernel import BoundKernel
9797
9898 original_run_ref = BoundKernel .run_ref
9999 run_ref_count = [0 ]
@@ -112,7 +112,7 @@ def tracked_run_ref(self: BoundKernel, *args: object) -> object:
112112
113113@contextlib .contextmanager
114114def assert_helion_ref_mode (
115- ref_mode : helion . RefMode = helion . RefMode .OFF ,
115+ ref_mode : RefMode = RefMode .OFF ,
116116) -> Generator [None , None , None ]:
117117 """Context manager that asserts Helion compilation behavior based on RefMode.
118118
@@ -122,12 +122,12 @@ def assert_helion_ref_mode(
122122 with track_run_ref_calls () as run_ref_count :
123123 yield
124124
125- if ref_mode == helion . RefMode .OFF :
125+ if ref_mode == RefMode .OFF :
126126 # In normal mode (RefMode.OFF), run_ref should not be called
127127 assert run_ref_count [0 ] == 0 , (
128128 f"Expected run_ref to not be called in normal mode (RefMode.OFF), but got: run_ref={ run_ref_count [0 ]} "
129129 )
130- elif ref_mode == helion . RefMode .EAGER :
130+ elif ref_mode == RefMode .EAGER :
131131 # In ref eager mode (RefMode.EAGER), run_ref should be called
132132 assert run_ref_count [0 ] > 0 , (
133133 f"Expected run_ref to be called in ref eager mode (RefMode.EAGER), but got: run_ref={ run_ref_count [0 ]} "
@@ -137,11 +137,11 @@ def assert_helion_ref_mode(
137137
138138
139139assert_helion_compilation = functools .partial (
140- assert_helion_ref_mode , ref_mode = helion . RefMode .OFF
140+ assert_helion_ref_mode , ref_mode = RefMode .OFF
141141)
142142
143143assert_ref_eager_mode = functools .partial (
144- assert_helion_ref_mode , ref_mode = helion . RefMode .EAGER
144+ assert_helion_ref_mode , ref_mode = RefMode .EAGER
145145)
146146
147147
0 commit comments