Skip to content

Commit f51b457

Browse files
committed
Import updates
stack-info: PR: #953, branch: jansel/stack/198
1 parent cd3304d commit f51b457

File tree

9 files changed

+30
-30
lines changed

9 files changed

+30
-30
lines changed

helion/_compiler/reduction_strategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@
2626
from .tile_strategy import PersistentReductionState
2727
from .tile_strategy import TileStrategy
2828

29-
ARG_REDUCE_MAP = {"argmax": ("max", "maximum"), "argmin": ("min", "minimum")}
30-
3129
if TYPE_CHECKING:
3230
from .device_function import DeviceFunction
3331
from .inductor_lowering import CodegenState
3432

33+
ARG_REDUCE_MAP = {"argmax": ("max", "maximum"), "argmin": ("min", "minimum")}
34+
3535

3636
class ReductionStrategy(TileStrategy):
3737
def __init__(

helion/_compiler/type_propagation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from torch.utils._pytree import tree_map_only
2323

2424
from .. import exc
25+
from .. import language as language_module
2526
from ..autotuner.config_fragment import ConfigSpecFragment
2627
from ..autotuner.config_spec import BlockSizeSpec
2728
from ..language._decorators import get_device_func_replacement
@@ -54,7 +55,6 @@
5455
from .variable_origin import Origin
5556
from .variable_origin import SourceOrigin
5657
from .variable_origin import TensorSizeOrigin
57-
import helion
5858

5959
if TYPE_CHECKING:
6060
from collections.abc import Callable
@@ -97,7 +97,7 @@ def _get(self, name: str) -> TypeInfo:
9797
else:
9898
raise exc.UndefinedVariable(name) from None
9999
else:
100-
if value is helion.language:
100+
if value is language_module:
101101
origin = GlobalOrigin(name="hl", function=self.function)
102102
return TypeInfo.from_example(value, origin)
103103
if name in library_imports:

helion/_testing.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020
from torch.utils._pytree import tree_map
2121
import triton
2222

23+
from ._compat import get_tensor_descriptor_fn_name
2324
from ._utils import counters
25+
from .autotuner.benchmarking import compute_repeat
26+
from .autotuner.benchmarking import interleaved_bench
2427
from .runtime.config import Config
25-
import helion
26-
from helion._compat import get_tensor_descriptor_fn_name
27-
from helion.autotuner.benchmarking import compute_repeat
28-
from helion.autotuner.benchmarking import interleaved_bench
29-
from helion.runtime.ref_mode import is_ref_mode_enabled
28+
from .runtime.ref_mode import is_ref_mode_enabled
29+
from .runtime.settings import RefMode
3030

3131
if TYPE_CHECKING:
3232
import types
@@ -104,7 +104,7 @@ def track_run_ref_calls() -> Generator[list[int], None, None]:
104104
Yields:
105105
A list that will contain the count of run_ref calls.
106106
"""
107-
from helion.runtime.kernel import BoundKernel
107+
from .runtime.kernel import BoundKernel
108108

109109
original_run_ref = BoundKernel.run_ref
110110
run_ref_count = [0]
@@ -123,7 +123,7 @@ def tracked_run_ref(self: BoundKernel, *args: object) -> object:
123123

124124
@contextlib.contextmanager
125125
def assert_helion_ref_mode(
126-
ref_mode: helion.RefMode = helion.RefMode.OFF,
126+
ref_mode: RefMode = RefMode.OFF,
127127
) -> Generator[None, None, None]:
128128
"""Context manager that asserts Helion compilation behavior based on RefMode.
129129
@@ -133,12 +133,12 @@ def assert_helion_ref_mode(
133133
with track_run_ref_calls() as run_ref_count:
134134
yield
135135

136-
if ref_mode == helion.RefMode.OFF:
136+
if ref_mode == RefMode.OFF:
137137
# In normal mode (RefMode.OFF), run_ref should not be called
138138
assert run_ref_count[0] == 0, (
139139
f"Expected run_ref to not be called in normal mode (RefMode.OFF), but got: run_ref={run_ref_count[0]}"
140140
)
141-
elif ref_mode == helion.RefMode.EAGER:
141+
elif ref_mode == RefMode.EAGER:
142142
# In ref eager mode (RefMode.EAGER), run_ref should be called
143143
assert run_ref_count[0] > 0, (
144144
f"Expected run_ref to be called in ref eager mode (RefMode.EAGER), but got: run_ref={run_ref_count[0]}"
@@ -148,11 +148,11 @@ def assert_helion_ref_mode(
148148

149149

150150
assert_helion_compilation = functools.partial(
151-
assert_helion_ref_mode, ref_mode=helion.RefMode.OFF
151+
assert_helion_ref_mode, ref_mode=RefMode.OFF
152152
)
153153

154154
assert_ref_eager_mode = functools.partial(
155-
assert_helion_ref_mode, ref_mode=helion.RefMode.EAGER
155+
assert_helion_ref_mode, ref_mode=RefMode.EAGER
156156
)
157157

158158

helion/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def convert_size_arg(size: object) -> object:
3535
- Other values -> unchanged
3636
"""
3737
# Import here to avoid circular dependency
38-
from helion.language.ref_tile import RefTile
38+
from .language.ref_tile import RefTile
3939

4040
if isinstance(size, (list, tuple)):
4141
return [convert_size_arg(item) for item in size]
@@ -54,7 +54,7 @@ def convert_tile_indices_to_slices(index: object) -> object:
5454
Index with RefTile objects replaced by their slice objects
5555
"""
5656
# Import here to avoid circular dependency
57-
from helion.language.ref_tile import RefTile
57+
from .language.ref_tile import RefTile
5858

5959
def _extract_slice(obj: object) -> object:
6060
return obj._slice if isinstance(obj, RefTile) else obj

helion/autotuner/base_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@
4747
from .logger import format_triton_compile_failure
4848
from .progress_bar import iter_with_progress
4949

50-
log = logging.getLogger(__name__)
51-
5250
if TYPE_CHECKING:
5351
from collections.abc import Sequence
5452

@@ -58,6 +56,8 @@
5856
from ..runtime.settings import Settings
5957
from . import ConfigSpec
6058

59+
log = logging.getLogger(__name__)
60+
6161

6262
class BaseAutotuner(abc.ABC):
6363
"""

helion/language/memory_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .. import exc
1010
from .._compiler.indexing_strategy import SubscriptIndexing
1111
from . import _decorators
12-
from helion.language.stack_tensor import StackTensor
12+
from .stack_tensor import StackTensor
1313

1414
if TYPE_CHECKING:
1515
from .._compiler.inductor_lowering import CodegenState

helion/language/signal_wait.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .. import exc
1010
from .._compiler.indexing_strategy import SubscriptIndexing
1111
from . import _decorators
12-
from helion.language.stack_tensor import StackTensor
12+
from .stack_tensor import StackTensor
1313

1414
if TYPE_CHECKING:
1515
import ast

helion/runtime/ref_mode.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
import torch
1212
from torch.overrides import BaseTorchFunctionMode
1313

14-
from helion._compiler.compile_environment import CompileEnvironment
15-
from helion._compiler.compile_environment import NoCurrentEnvironment
16-
from helion._compiler.compile_environment import tls as ce_tls
17-
from helion._utils import convert_size_arg
18-
from helion._utils import create_shape_matching_slices
14+
from .._compiler.compile_environment import CompileEnvironment
15+
from .._compiler.compile_environment import NoCurrentEnvironment
16+
from .._compiler.compile_environment import tls as ce_tls
17+
from .._utils import convert_size_arg
18+
from .._utils import create_shape_matching_slices
1919

2020
if TYPE_CHECKING:
2121
from typing_extensions import Self

helion/runtime/settings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
import torch
1515
from torch._environment import is_fbcode
1616

17-
from helion import exc
18-
from helion.autotuner.effort_profile import AutotuneEffort
19-
from helion.autotuner.effort_profile import get_effort_profile
20-
from helion.runtime.ref_mode import RefMode
17+
from .. import exc
18+
from ..autotuner.effort_profile import AutotuneEffort
19+
from ..autotuner.effort_profile import get_effort_profile
20+
from .ref_mode import RefMode
2121

2222
if TYPE_CHECKING:
2323
from contextlib import AbstractContextManager

0 commit comments

Comments
 (0)