@@ -142,6 +142,30 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf
         if rdim.reduction and rdim.size == size:
             return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles
+        # (e.g., via slicing) with sizes coming directly from tile block vars, we
+        # want them to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and avoids
+        # spurious "size mismatch" issues during fake-tensor broadcasting and
+        # arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            size_str = str(size)
+            for block_info in self.block_sizes:
+                if not block_info.reduction and str(block_info.var) == size_str:
+                    # Create reduction dimension with the same var to preserve
+                    # symbolic equality and ensure all later users see identical
+                    # symbols (rather than equal-but-distinct SymInts).
+                    rdim_idx = self.allocate_block_size(
+                        size,
+                        reduction=True,
+                        source=ReductionLoopBlockSizeSource(
+                            reduction_loop=len([b for b in self.block_sizes if b.reduction])
+                        ),
+                    )
+                    self.block_sizes[rdim_idx].var = block_info.var
+                    return self.block_sizes[rdim_idx]
+
         # Allocate a new reduction dimension
         rdim_idx = self.allocate_block_size(
             size,
@@ -203,6 +227,91 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
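+        # Recover the originating block id: prefer explicit tile.index provenance,
+        # then the base dimension being indexed, then the indexer's own
+        # non-broadcast size.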
+        block_id = self.get_tile_index_tensor_block_id(indexer_tensor)
+        if block_id is None and base_dim_size is not None:
+            block_id = self.get_block_id(base_dim_size)
+        if block_id is None and non_broadcast_dims:
+            block_id = self.get_block_id(non_broadcast_dims[0])
+
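+        # A known block id collapses the output to its canonical block symbol;
+        # otherwise fall back to the indexer's own non-broadcast size (or 1).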
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        if non_broadcast_dims:
+            return [non_broadcast_dims[0]]
+        return [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
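+        # If every indexer carries tile.index provenance, skip computing an explicit
+        # broadcast shape; their sizes already use canonical block symbols.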
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` in which the single non-broadcast
+        dimension is replaced with the canonical block symbol, together with the
+        block_id to register on the new tensor. If the tensor is not a tile
+        indexer, or it introduces more than one non-broadcast dimension, the
+        original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
@@ -283,6 +392,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
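+        # Carry tile.index provenance over to the fake tensor so later indexing
+        # ops can still recover the originating block id.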
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
@@ -357,9 +470,9 @@ def current() -> CompileEnvironment:
     @staticmethod
     def has_current() -> bool:
         try:
-            CompileEnvironment.current()
-            return True
-        except NoCurrentEnvironment:
+            CompileEnvironment.current()
+            return True
+        except NoCurrentEnvironment:
             return False
 
     def get_block_id(self, size: int | torch.SymInt | sympy.Expr) -> int | None:
@@ -535,3 +648,35 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:
 
 def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
+
+
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment",
+) -> list[int | torch.SymInt]:
+    """
+    Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    Args:
+        shapes: List of shapes from each tensor indexer
+        env: CompileEnvironment for size_hint and known_equal checks
+
+    Returns:
+        Broadcast shape as list of dimensions
+    """
+    if not shapes:
+        return []
+
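+    # Right-align the shapes by left-padding shorter ones with 1s (numpy-style).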
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+    broadcast_shape: list[int | torch.SymInt] = []
+
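+    # At each position, pick a non-size-1 dimension when one exists; size-1 dims
+    # broadcast, and positions where every dim is 1 stay 1.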
+    for dims_at_pos in zip(*padded, strict=True):
+        chosen: int | torch.SymInt | None = None
+        for d in dims_at_pos:
+            if env.size_hint(d) != 1:
+                if chosen is None or env.known_equal(chosen, d):
+                    chosen = d
+        broadcast_shape.append(chosen if chosen is not None else 1)
+
+    return broadcast_shape