@@ -194,17 +194,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles (e.g. via
+        # slicing) with sizes coming directly from tile block vars, we want them
+        # to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and
+        # avoids spurious "size mismatch" issues during fake-tensor broadcasting
+        # and arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            block_idx = self.get_block_id(size)
+            if block_idx is not None and not self.block_sizes[block_idx].reduction:
+                return self._clone_block_size_as_reduction(block_idx, size)
+
+            sym = size._sympy_()
+            for block_idx, block_info in enumerate(self.block_sizes):
+                if not block_info.reduction and sym == block_info.symbol():
+                    return self._clone_block_size_as_reduction(block_idx, size)
+
         # Allocate a new reduction dimension
+        return self._allocate_new_reduction(size)
+
+    def _clone_block_size_as_reduction(
+        self, block_idx: int, size: torch.SymInt | int
+    ) -> BlockSizeInfo:
+        rdim = self._allocate_new_reduction(size)
+        rdim.var = self.block_sizes[block_idx].var
+        return rdim
+
+    def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo:
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
             source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
+                self._next_reduction_loop_index()
             ),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
 
+    def _next_reduction_loop_index(self) -> int:
+        return sum(int(info.reduction) for info in self.block_sizes)
+
     def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
         source = _current_symbol_source()
         with self.shape_env.ignore_fresh_unbacked_symbols():
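
A minimal usage sketch of the new fast path above; `env` and the pre-allocated tile block are assumptions for illustration, not part of the patch. When the requested reduction size is symbolically equal to an existing non-reduction tile block, the returned reduction dim reuses that tile's block variable instead of minting a fresh symbol:

# Illustrative sketch only: `env` stands for a live CompileEnvironment with one
# non-reduction tile block already allocated (hypothetical setup).
tile_block = env.block_sizes[0]
rdim = env.allocate_reduction_dimension(tile_block.var)
assert rdim.reduction
# The reduction dim shares the tile's SymInt, so the shape env keeps them equal.
assert rdim.var is tile_block.var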
@@ -257,6 +287,90 @@ def cached_create_unbacked_symint(
             self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find block_id from various sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size is not None else None)
+            or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None)
+        )
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        return [non_broadcast_dims[0]] if non_broadcast_dims else [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` where the single non-broadcast
+        dimension is replaced with the canonical block symbol, plus the
+        associated block_id to register on the new tensor. If the tensor is
+        not a tile indexer or it introduces more than one non-broadcast
+        dimension, the original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
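
A short sketch of how the provenance helpers above are meant to compose; `env` and `idx` are assumed placeholders here (`idx` standing for the fake tensor produced for `tile.index` on block 0), not objects the patch constructs:

# Hypothetical usage, assuming a single non-broadcast dimension in `idx`.
env.register_tile_index_tensor_block_id(idx, 0)       # tag the tile.index tensor
assert env.get_tile_index_tensor_block_id(idx) == 0
# A later view/indexing op keeps the canonical block symbol and the provenance tag.
out = env.new_index_result(idx, list(idx.size()))
assert env.get_tile_index_tensor_block_id(out) == 0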
@@ -339,6 +453,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
@@ -631,6 +749,34 @@ def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
 
 
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment"
+) -> list[int | torch.SymInt]:
+    """Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    For multiple 1D tensors, this should return a shape that represents their Cartesian product.
+    For example, two tensors of shape [8] and [8] should broadcast to shape [8, 8].
+    """
+    if not shapes:
+        return []
+
+    # Special case: multiple 1D tensors form a Cartesian product
+    all_1d = all(len(shape) == 1 for shape in shapes)
+    if all_1d and len(shapes) > 1:
+        # Return the Cartesian product shape
+        return [shape[0] for shape in shapes]
+
+    # General broadcasting case
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+
+    return [
+        next((d for d in dims if env.size_hint(d) != 1), 1)
+        for dims in zip(*padded, strict=True)
+    ]
+
+
 def format_shape(shape: tuple[object, ...]) -> str:
     def _format_dim(dim: object) -> str:
         if isinstance(dim, torch.SymInt):
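
To show both branches of the module-level helper, here is a small self-contained sketch; the `_StubEnv` class is an assumption that mimics only `size_hint` for plain int dims and is not part of the library:

# Hypothetical stand-in for CompileEnvironment, exposing only size_hint.
class _StubEnv:
    def size_hint(self, d: int) -> int:
        return int(d)

env = _StubEnv()
# Two 1-D indexers form a Cartesian product: [8] x [8] -> [8, 8].
print(compute_broadcast_shape_for_tensor_indexers([[8], [8]], env))     # [8, 8]
# Mixed ranks fall back to right-aligned broadcasting: [4, 1] with [8] -> [4, 8].
print(compute_broadcast_shape_for_tensor_indexers([[4, 1], [8]], env))  # [4, 8]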