Commit 9e6848b

Commit message: wip
1 parent 44d624f commit 9e6848b

File tree

8 files changed, +500 -82 lines changed


helion/_compiler/ast_extension.py

Lines changed: 21 additions & 0 deletions
@@ -11,6 +11,8 @@
 from typing import TYPE_CHECKING
 from typing import TypeVar
 
+import torch
+
 from .. import exc
 from .output_lines import OutputLines
 from .source_location import SourceLocation
@@ -87,10 +89,29 @@ def __repr__(self) -> str:
 
     def update_type_info(self, type_info: TypeInfo) -> TypeInfo:
         if self._type_info is not None and type_info != self._type_info:
+            prev_rank = self._tensor_rank(self._type_info)
+            new_rank = self._tensor_rank(type_info)
+            if (
+                prev_rank is not None
+                and new_rank is not None
+                and prev_rank != new_rank
+            ):
+                self._type_info = type_info
+                return self._type_info
             type_info = self._type_info.merge(type_info)
         self._type_info = type_info
         return self._type_info
 
+    @staticmethod
+    def _tensor_rank(type_info: "TypeInfo") -> int | None:
+        for attr in ["fake_value", "tensor"]:
+            obj = getattr(type_info, attr, None)
+            if attr == "tensor" and obj is not None:
+                obj = getattr(obj, "fake_value", None)
+            if isinstance(obj, torch.Tensor):
+                return obj.dim()
+        return None
+
     def debug_annotations(self) -> list[str]:
         result = []
         if self._type_info:
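
Note: the rank-change branch above replaces conflicting type info outright when the inferred tensor rank changes between propagation passes, instead of merging it. A minimal standalone sketch of that decision rule, using a hypothetical FakeTypeInfo stand-in (the real TypeInfo class carries much more than a fake tensor value):

    import torch

    class FakeTypeInfo:
        """Hypothetical stand-in: only carries a fake tensor value."""
        def __init__(self, fake_value: torch.Tensor) -> None:
            self.fake_value = fake_value

    def tensor_rank(info: object) -> int | None:
        # Mirrors _tensor_rank: look for a torch.Tensor under ``fake_value``.
        obj = getattr(info, "fake_value", None)
        return obj.dim() if isinstance(obj, torch.Tensor) else None

    old = FakeTypeInfo(torch.empty(8))      # rank 1 from an earlier pass
    new = FakeTypeInfo(torch.empty(8, 8))   # rank 2 after re-propagation
    # Ranks differ, so the new info would win outright rather than be merged.
    assert tensor_rank(old) != tensor_rank(new)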

helion/_compiler/compile_environment.py

Lines changed: 147 additions & 1 deletion
@@ -197,17 +197,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles (e.g. via
+        # slicing) with sizes coming directly from tile block vars, we want them
+        # to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and
+        # avoids spurious "size mismatch" issues during fake-tensor broadcasting
+        # and arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            block_idx = self.get_block_id(size)
+            if block_idx is not None and not self.block_sizes[block_idx].reduction:
+                return self._clone_block_size_as_reduction(block_idx, size)
+
+            sym = size._sympy_()
+            for block_idx, block_info in enumerate(self.block_sizes):
+                if not block_info.reduction and sym == block_info.symbol():
+                    return self._clone_block_size_as_reduction(block_idx, size)
+
         # Allocate a new reduction dimension
+        return self._allocate_new_reduction(size)
+
+    def _clone_block_size_as_reduction(
+        self, block_idx: int, size: torch.SymInt | int
+    ) -> BlockSizeInfo:
+        rdim = self._allocate_new_reduction(size)
+        rdim.var = self.block_sizes[block_idx].var
+        return rdim
+
+    def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo:
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
             source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
+                self._next_reduction_loop_index()
             ),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
 
+    def _next_reduction_loop_index(self) -> int:
+        return sum(int(info.reduction) for info in self.block_sizes)
+
     def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
         source = _current_symbol_source()
         with self.shape_env.ignore_fresh_unbacked_symbols():
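
Note: the comment in allocate_reduction_dimension motivates reusing a single symbol for sizes that are equal by construction. A small sympy-only sketch (purely illustrative, outside Helion's shape environment) of why two fresh symbols lose that equality:

    import sympy

    s0 = sympy.Symbol("s0", positive=True, integer=True)
    s1 = sympy.Symbol("s1", positive=True, integer=True)

    # Distinct fresh symbols: equality stays unknown, so shape checks can fail.
    print(sympy.Eq(s0, s1))   # stays symbolic, not provably True
    # Reusing the same symbol makes the equality trivially true.
    print(sympy.Eq(s0, s0))   # True
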
@@ -260,6 +290,90 @@ def cached_create_unbacked_symint(
             self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find block_id from various sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size is not None else None)
+            or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None)
+        )
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        return [non_broadcast_dims[0]] if non_broadcast_dims else [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` where the single non-broadcast
+        dimension is replaced with the canonical block-symbol and the associated
+        block_id to register on the new tensor. If the tensor is not a tile
+        indexer or it introduces more than one non-broadcast dimension, the
+        original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
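
Note: the new helpers track ``tile.index`` provenance by stashing a plain Python attribute on the (fake) tensor and re-registering it on derived tensors; attributes set this way do not survive ``new_empty``-style construction on their own. A standalone sketch of the pattern (illustrative names, outside CompileEnvironment):

    import torch

    def tag_block_id(t: torch.Tensor, block_id: int) -> torch.Tensor:
        # Provenance lives in a plain attribute, as in register_tile_index_tensor_block_id.
        t._tile_index_block_id = block_id  # type: ignore[attr-defined]
        return t

    def block_id_of(t: torch.Tensor) -> int | None:
        return getattr(t, "_tile_index_block_id", None)

    idx = tag_block_id(torch.arange(8), block_id=0)
    derived = idx.new_empty(8, 1)          # a view-like result
    assert block_id_of(derived) is None    # the attribute is not inherited...
    tag_block_id(derived, 0)               # ...so new_index_result re-registers it
    assert block_id_of(derived) == 0
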
@@ -342,6 +456,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
@@ -634,6 +752,34 @@ def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
 
 
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment"
+) -> list[int | torch.SymInt]:
+    """Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    For multiple 1D tensors, this should return a shape that represents their Cartesian product.
+    For example, two tensors of shape [8] and [8] should broadcast to shape [8, 8].
+    """
+    if not shapes:
+        return []
+
+    # Special case: multiple 1D tensors form a Cartesian product
+    all_1d = all(len(shape) == 1 for shape in shapes)
+    if all_1d and len(shapes) > 1:
+        # Return the Cartesian product shape
+        return [shape[0] for shape in shapes]
+
+    # General broadcasting case
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+
+    return [
+        next((d for d in dims if env.size_hint(d) != 1), 1)
+        for dims in zip(*padded, strict=True)
+    ]
+
+
 def format_shape(shape: tuple[object, ...]) -> str:
     def _format_dim(dim: object) -> str:
         if isinstance(dim, torch.SymInt):
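
Note: compute_broadcast_shape_for_tensor_indexers applies right-aligned broadcasting with a Cartesian-product special case for multiple 1D indexers. A standalone sketch over plain int shapes (no CompileEnvironment; any dim != 1 is treated as non-broadcast, standing in for size_hint):

    def broadcast_indexer_shapes(shapes: list[list[int]]) -> list[int]:
        if not shapes:
            return []
        # Multiple 1D indexers index independent axes: Cartesian product.
        if len(shapes) > 1 and all(len(s) == 1 for s in shapes):
            return [s[0] for s in shapes]
        # Otherwise, ordinary right-aligned broadcasting.
        max_ndim = max(len(s) for s in shapes)
        padded = [[1] * (max_ndim - len(s)) + s for s in shapes]
        return [next((d for d in dims if d != 1), 1) for dims in zip(*padded)]

    assert broadcast_indexer_shapes([[8], [8]]) == [8, 8]        # Cartesian product
    assert broadcast_indexer_shapes([[4, 1], [1, 8]]) == [4, 8]  # right-aligned broadcast
    assert broadcast_indexer_shapes([[8]]) == [8]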
