@@ -197,17 +197,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles (e.g. via
+        # slicing) with sizes coming directly from tile block vars, we want them
+        # to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and
+        # avoids spurious "size mismatch" issues during fake-tensor broadcasting
+        # and arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            block_idx = self.get_block_id(size)
+            if block_idx is not None and not self.block_sizes[block_idx].reduction:
+                return self._clone_block_size_as_reduction(block_idx, size)
+
+            sym = size._sympy_()
+            for block_idx, block_info in enumerate(self.block_sizes):
+                if not block_info.reduction and sym == block_info.symbol():
+                    return self._clone_block_size_as_reduction(block_idx, size)
+
         # Allocate a new reduction dimension
+        return self._allocate_new_reduction(size)
+
+    def _clone_block_size_as_reduction(
+        self, block_idx: int, size: torch.SymInt | int
+    ) -> BlockSizeInfo:
+        rdim = self._allocate_new_reduction(size)
+        rdim.var = self.block_sizes[block_idx].var
+        return rdim
+
+    def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo:
         rdim_idx = self.allocate_block_size(
             size,
             reduction=True,
             source=ReductionLoopBlockSizeSource(
-                sum([int(bs.reduction) for bs in self.block_sizes])
+                self._next_reduction_loop_index()
             ),
             hint=next_power_of_2(self.size_hint(size)),
         )
         return self.block_sizes[rdim_idx]
 
+    def _next_reduction_loop_index(self) -> int:
+        return sum(int(info.reduction) for info in self.block_sizes)
+
     def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
         source = _current_symbol_source()
         with self.shape_env.ignore_fresh_unbacked_symbols():
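To make the intent of this hunk concrete, here is a rough usage sketch. It is illustrative only: it assumes an already-constructed CompileEnvironment `env` whose block id 0 is an existing non-reduction tile dimension, and it is not runnable standalone.

    # Illustrative sketch, not part of the diff.
    tile_size = env.block_sizes[0].size          # SymInt of an existing tile dimension

    rdim = env.allocate_reduction_dimension(tile_size)
    # With this change the reduction dim clones the tile's block var instead of
    # minting a fresh SymInt, so later shape comparisons carry the same symbol on
    # both sides and fake-tensor broadcasting sees them as equal by construction.
    assert rdim.reduction
    assert rdim.var is env.block_sizes[0].var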
@@ -260,6 +290,90 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        # Try to find block_id from various sources
+        block_id = (
+            self.get_tile_index_tensor_block_id(indexer_tensor)
+            or (self.get_block_id(base_dim_size) if base_dim_size is not None else None)
+            or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None)
+        )
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        return [non_broadcast_dims[0]] if non_broadcast_dims else [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` where the single non-broadcast
+        dimension is replaced with the canonical block-symbol and the associated
+        block_id to register on the new tensor. If the tensor is not a tile
+        indexer or it introduces more than one non-broadcast dimension, the
+        original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if obj is None:
             return None
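The helpers above all hinge on carrying provenance as a plain attribute on the (fake) tensor. A minimal, self-contained sketch of that pattern, independent of the actual CompileEnvironment API, shows why derived tensors need explicit re-registration (the `tag` helper below is made up for the example):

    import torch

    def tag(t: torch.Tensor, block_id: int) -> torch.Tensor:
        # Same idea as register_tile_index_tensor_block_id: stash the block id
        # directly on the tensor object.
        t._tile_index_block_id = block_id  # type: ignore[attr-defined]
        return t

    idx = tag(torch.arange(8), block_id=0)
    view = idx.reshape(1, 8)  # a derived tensor does NOT inherit the attribute
    assert getattr(idx, "_tile_index_block_id", None) == 0
    assert getattr(view, "_tile_index_block_id", None) is None
    # Hence new_index_result() re-resolves the shape and re-registers the block id
    # on every tensor it creates.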
@@ -342,6 +456,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
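This hunk applies the same idea when real kernel inputs are converted to fake tensors: if the source tensor was tagged, the tag is copied onto its fake counterpart. A small sketch of that copy step (the helper name is invented for illustration):

    import torch

    def copy_tile_index_provenance(src: torch.Tensor, dst: torch.Tensor) -> None:
        # Mirrors the added lines above: re-register the block id, if any,
        # on the newly created (fake) tensor.
        if hasattr(src, "_tile_index_block_id"):
            dst._tile_index_block_id = src._tile_index_block_id  # type: ignore[attr-defined]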
@@ -634,6 +752,34 @@ def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
 
 
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment"
+) -> list[int | torch.SymInt]:
+    """Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    For multiple 1D tensors, this should return a shape that represents their Cartesian product.
+    For example, two tensors of shape [8] and [8] should broadcast to shape [8, 8].
+    """
+    if not shapes:
+        return []
+
+    # Special case: multiple 1D tensors form a Cartesian product
+    all_1d = all(len(shape) == 1 for shape in shapes)
+    if all_1d and len(shapes) > 1:
+        # Return the Cartesian product shape
+        return [shape[0] for shape in shapes]
+
+    # General broadcasting case
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+
+    return [
+        next((d for d in dims if env.size_hint(d) != 1), 1)
+        for dims in zip(*padded, strict=True)
+    ]
+
+
 def format_shape(shape: tuple[object, ...]) -> str:
     def _format_dim(dim: object) -> str:
         if isinstance(dim, torch.SymInt):
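A runnable sketch of the broadcasting rule implemented by `compute_broadcast_shape_for_tensor_indexers`, specialized to plain ints so that `size_hint()` reduces to the identity (the helper name `broadcast_indexers` is invented for the example):

    def broadcast_indexers(shapes: list[list[int]]) -> list[int]:
        if not shapes:
            return []
        # Multiple 1D indexers form a Cartesian product, e.g. [8] and [8] -> [8, 8].
        if len(shapes) > 1 and all(len(s) == 1 for s in shapes):
            return [s[0] for s in shapes]
        # Otherwise: right-aligned broadcasting, keeping the first non-1 size per axis.
        max_ndim = max(len(s) for s in shapes)
        padded = [[1] * (max_ndim - len(s)) + s for s in shapes]
        return [next((d for d in dims if d != 1), 1) for dims in zip(*padded)]

    assert broadcast_indexers([[8], [8]]) == [8, 8]
    assert broadcast_indexers([[4, 1, 8], [16, 1]]) == [4, 16, 8]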