
Commit f557f34

Commit message: wip
1 parent 36c491e commit f557f34

File tree

8 files changed: 500 additions, 82 deletions


helion/_compiler/ast_extension.py

Lines changed: 21 additions & 0 deletions
@@ -11,6 +11,8 @@
 from typing import TYPE_CHECKING
 from typing import TypeVar
 
+import torch
+
 from .. import exc
 from .output_lines import OutputLines
 from .source_location import SourceLocation
@@ -87,10 +89,29 @@ def __repr__(self) -> str:
 
     def update_type_info(self, type_info: TypeInfo) -> TypeInfo:
         if self._type_info is not None and type_info != self._type_info:
+            prev_rank = self._tensor_rank(self._type_info)
+            new_rank = self._tensor_rank(type_info)
+            if (
+                prev_rank is not None
+                and new_rank is not None
+                and prev_rank != new_rank
+            ):
+                self._type_info = type_info
+                return self._type_info
             type_info = self._type_info.merge(type_info)
         self._type_info = type_info
         return self._type_info
 
+    @staticmethod
+    def _tensor_rank(type_info: "TypeInfo") -> int | None:
+        for attr in ["fake_value", "tensor"]:
+            obj = getattr(type_info, attr, None)
+            if attr == "tensor" and obj is not None:
+                obj = getattr(obj, "fake_value", None)
+            if isinstance(obj, torch.Tensor):
+                return obj.dim()
+        return None
+
     def debug_annotations(self) -> list[str]:
         result = []
         if self._type_info:
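
For illustration, a minimal standalone sketch of the rank rule above: when the previously recorded type and the new type both carry fake tensors but their ranks differ, the new type replaces the old one outright instead of being merged. SimpleTypeInfo and tensor_rank below are hypothetical stand-ins, not helion's TypeInfo or _tensor_rank.

import torch


class SimpleTypeInfo:
    """Hypothetical stand-in for helion's TypeInfo; only carries a fake tensor."""

    def __init__(self, fake_value=None):
        self.fake_value = fake_value


def tensor_rank(type_info):
    # Mirrors the _tensor_rank check: report the rank of the fake tensor, if any.
    obj = getattr(type_info, "fake_value", None)
    if isinstance(obj, torch.Tensor):
        return obj.dim()
    return None


prev = SimpleTypeInfo(torch.empty(4, 8, device="meta"))    # rank 2
new = SimpleTypeInfo(torch.empty(4, 8, 2, device="meta"))  # rank 3
if tensor_rank(prev) is not None and tensor_rank(prev) != tensor_rank(new):
    # Rank changed: keep the new type info rather than merging it with the old one.
    result = new
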

helion/_compiler/compile_environment.py

Lines changed: 147 additions & 1 deletion
@@ -194,17 +194,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles (e.g. via
+        # slicing) with sizes coming directly from tile block vars, we want them
+        # to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and
+        # avoids spurious "size mismatch" issues during fake-tensor broadcasting
+        # and arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            block_idx = self.get_block_id(size)
+            if block_idx is not None and not self.block_sizes[block_idx].reduction:
+                return self._clone_block_size_as_reduction(block_idx, size)
+
+            sym = size._sympy_()
+            for block_idx, block_info in enumerate(self.block_sizes):
+                if not block_info.reduction and sym == block_info.symbol():
+                    return self._clone_block_size_as_reduction(block_idx, size)
+
         # Allocate a new reduction dimension
+        return self._allocate_new_reduction(size)
+
+    def _clone_block_size_as_reduction(
+        self, block_idx: int, size: torch.SymInt | int
+    ) -> BlockSizeInfo:
+        rdim = self._allocate_new_reduction(size)
+        rdim.var = self.block_sizes[block_idx].var
+        return rdim
+
+    def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo:
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
             source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
+                self._next_reduction_loop_index()
             ),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
 
+    def _next_reduction_loop_index(self) -> int:
+        return sum(int(info.reduction) for info in self.block_sizes)
+
     def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
         source = _current_symbol_source()
         with self.shape_env.ignore_fresh_unbacked_symbols():
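
A simplified stand-in for the reuse rule in allocate_reduction_dimension / _clone_block_size_as_reduction: when the requested size matches an existing non-reduction block, the cloned reduction dimension reuses that block's variable, so the two sizes stay equal in the shape environment. FakeBlock and clone_as_reduction below are hypothetical simplifications, not helion's BlockSizeInfo or the real allocator.

from dataclasses import dataclass


@dataclass
class FakeBlock:
    symbol: str      # sympy symbol name of the block size
    reduction: bool
    var: str         # variable handed out for this block


def clone_as_reduction(block_sizes, size_symbol):
    # Mirrors the lookup above: find a non-reduction block whose symbol matches,
    # then hand back a reduction dim that shares its variable.
    for info in block_sizes:
        if not info.reduction and info.symbol == size_symbol:
            return FakeBlock(symbol=size_symbol, reduction=True, var=info.var)
    return None


blocks = [FakeBlock("block_size_0", False, "s0")]
rdim = clone_as_reduction(blocks, "block_size_0")
assert rdim is not None and rdim.var == "s0"   # the same variable is reused
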
@@ -257,6 +287,90 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find block_id from various sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size is not None else None)
+            or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None)
+        )
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        return [non_broadcast_dims[0]] if non_broadcast_dims else [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` where the single non-broadcast
+        dimension is replaced with the canonical block-symbol and the associated
+        block_id to register on the new tensor. If the tensor is not a tile
+        indexer or it introduces more than one non-broadcast dimension, the
+        original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
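
A hypothetical standalone mirror of the single-dimension rule in resolve_tile_index_shape: a tensor stamped as a tile.index result keeps its provenance only while it contributes at most one non-broadcast dimension, and that dimension is rewritten to the block's canonical size. The resolve function below treats sizes as plain ints for simplicity; it is not helion's API.

def resolve(shape, block_var):
    # Size-1 dims are treated as broadcast dims, everything else as "real" dims.
    non_broadcast = [i for i, s in enumerate(shape) if s != 1]
    resolved = list(shape)
    if len(non_broadcast) <= 1:
        if non_broadcast:
            resolved[non_broadcast[0]] = block_var
        return resolved, True    # provenance kept; the dim uses the block symbol
    return resolved, False       # more than one real dim: shape kept, provenance dropped


print(resolve([8, 1], "block_size_0"))   # (['block_size_0', 1], True)
print(resolve([8, 8], "block_size_0"))   # ([8, 8], False)
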
@@ -339,6 +453,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
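
The hunk above copies the _tile_index_block_id stamp from the eager tensor onto its fake counterpart during conversion. A tiny illustration of the same pattern, with plain tensors standing in for the fake-tensor conversion:

import torch

real = torch.empty(16)
real._tile_index_block_id = 3                 # provenance stamped by tile.index

fake = torch.empty_like(real, device="meta")  # stand-in for the converted fake tensor
if hasattr(real, "_tile_index_block_id"):
    # Same hasattr/getattr hand-off as in _to_fake_tensor above.
    fake._tile_index_block_id = real._tile_index_block_id

assert getattr(fake, "_tile_index_block_id", None) == 3
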
@@ -631,6 +749,34 @@ def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
 
 
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment"
+) -> list[int | torch.SymInt]:
+    """Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    For multiple 1D tensors, this should return a shape that represents their Cartesian product.
+    For example, two tensors of shape [8] and [8] should broadcast to shape [8, 8].
+    """
+    if not shapes:
+        return []
+
+    # Special case: multiple 1D tensors form a Cartesian product
+    all_1d = all(len(shape) == 1 for shape in shapes)
+    if all_1d and len(shapes) > 1:
+        # Return the Cartesian product shape
+        return [shape[0] for shape in shapes]
+
+    # General broadcasting case
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+
+    return [
+        next((d for d in dims if env.size_hint(d) != 1), 1)
+        for dims in zip(*padded, strict=True)
+    ]
+
+
 def format_shape(shape: tuple[object, ...]) -> str:
     def _format_dim(dim: object) -> str:
         if isinstance(dim, torch.SymInt):
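
A quick standalone check of the broadcasting rules described in the docstring of compute_broadcast_shape_for_tensor_indexers. Here _Env.size_hint is a trivial stand-in for CompileEnvironment.size_hint, and _broadcast restates the same logic over plain ints so it runs without helion.

class _Env:
    def size_hint(self, d):
        return int(d)


def _broadcast(shapes, env):
    # Same structure as compute_broadcast_shape_for_tensor_indexers above.
    if not shapes:
        return []
    if len(shapes) > 1 and all(len(s) == 1 for s in shapes):
        return [s[0] for s in shapes]              # Cartesian product of 1D indexers
    max_ndim = max(len(s) for s in shapes)
    padded = [[1] * (max_ndim - len(s)) + s for s in shapes]
    return [next((d for d in dims if env.size_hint(d) != 1), 1) for dims in zip(*padded)]


env = _Env()
assert _broadcast([[8], [8]], env) == [8, 8]        # two 1D indexers -> 2D product
assert _broadcast([[1, 8], [4, 1]], env) == [4, 8]  # right-aligned broadcasting
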
