@@ -257,6 +257,7 @@ def index(
             )
         else:
             dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
+
         mult_d1 = convert_binary_elementwise(
             ctx,
             target,
@@ -548,6 +549,9 @@ def index_put_converter(
     accumulate: bool = False,
 ) -> TRTTensor:
     # Convert 'input_indices' to TRT tensors (or keep None as is)
+    input_indices = expand_boolean_indices(
+        ctx, target, source_ir, name, input_tensor, input_indices
+    )
     indices: List[Optional[Union[TRTTensor, None]]] = []
     for i, idx in enumerate(input_indices):
         if idx is None:
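
Note: `expand_boolean_indices` is the new preprocessing step this patch adds. Judging by its name and call site, it lowers boolean mask indices to integer coordinate indices before the gather/scatter machinery runs; the equivalence it relies on, sketched here in plain PyTorch rather than the converter API, is that a mask selects the same elements as the integer positions of its True entries:

import torch

x = torch.zeros(4, 3)
mask = torch.tensor([True, False, True, False])

# x[mask] touches the same rows as x[nonzero(mask)]
int_idx = torch.nonzero(mask, as_tuple=True)[0]  # tensor([0, 2])
x[int_idx] = 1.0
assert torch.equal(x[mask], x[int_idx])
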
@@ -571,22 +575,40 @@ def index_put_converter(
     K = len(I)
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
-        index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = (
+            []
+        )  # [tensor.shape[0] for tensor in indices if tensor is not None]
+        for idx_tensor in indices:
+            if idx_tensor is not None:
+                if idx_tensor.shape[0] != DYNAMIC_DIM:
+                    index_shapes.append(idx_tensor.shape[0])
+                else:
+                    index_shapes.append(
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + "idx_shape_dim_0",
+                            idx_tensor,
+                            0,
+                        )
+                    )
         N = max(index_shapes) if index_shapes else 1
     else:
         N = 1

     # Compute shapes and volume for the free dimensions
     F_shapes = [input_tensor.shape[i] for i in F]
+    assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported"
     F_volume = trt.volume(F_shapes) if F_shapes else 1

     # Process indexed dimensions (I)
     I_tensors = []
     for i in I:
         idx = indices[i]
         assert idx is not None
-        idx_reshaped = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1)
+        idx_reshaped = impl.unsqueeze.unsqueeze(
+            ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1
         )
         expanded_idx = impl.slice.expand(
             ctx,
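
For reference, the new `index_shapes` loop above splits each index tensor's leading extent into two cases: a static size usable at build time, or `DYNAMIC_DIM` (assumed here to be the usual -1 sentinel for a size known only at runtime), which must instead be fetched through `get_shape` as a shape tensor. A toy illustration of the split:

DYNAMIC_DIM = -1  # assumed sentinel for a runtime-only dimension

shapes = [4, DYNAMIC_DIM, 7]
static = [s for s in shapes if s != DYNAMIC_DIM]                 # [4, 7]
dynamic = [i for i, s in enumerate(shapes) if s == DYNAMIC_DIM]  # [1]
# Static extents can feed max() at build time; dynamic ones need a
# runtime shape tensor (get_shape in the converter above).
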
@@ -608,46 +630,50 @@ def index_put_converter(
         )
         arange_tensors.append(arange_tensor)

-    meshgrid_tensors = []
-    for i, arange in enumerate(arange_tensors):
-        reshape_shape = [1] * len(F)
-        reshape_shape[i] = F_shapes[i]
-        arange_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_arange_F_{F[i]}",
-            arange,
-            tuple(reshape_shape),
-        )
-        expanded_arange = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_arange_F_{F[i]}",
-            arange_reshaped,
-            tuple(F_shapes),
-        )
-        meshgrid_tensors.append(expanded_arange)
-
-    meshgrid_stacked = impl.cat.cat(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_stack_meshgrid",
-        [
-            impl.shuffle.reshape(
+    if len(arange_tensors) == 1:
+        # No need to stack
+        meshgrid_stacked = arange_tensors[0]
+    else:
+        meshgrid_tensors = []
+        for i, arange in enumerate(arange_tensors):
+            reshape_shape = [1] * len(F)
+            reshape_shape[i] = F_shapes[i]
+            arange_reshaped = impl.shuffle.reshape(
                 ctx,
                 target,
                 source_ir,
-                f"{name}_reshape_mesh_{i}",
-                t,
-                (*F_shapes, 1),
+                f"{name}_reshape_arange_F_{F[i]}",
+                arange,
+                tuple(reshape_shape),
             )
-            for i, t in enumerate(meshgrid_tensors)
-        ],
-        dim=-1,
-    )
+            expanded_arange = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_arange_F_{F[i]}",
+                arange_reshaped,
+                tuple(F_shapes),
+            )
+            meshgrid_tensors.append(expanded_arange)
+
+        meshgrid_stacked = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_stack_meshgrid",
+            [
+                impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_mesh_{i}",
+                    t,
+                    (*F_shapes, 1),
+                )
+                for i, t in enumerate(meshgrid_tensors)
+            ],
+            dim=-1,
+        )
     meshgrid_reshaped = impl.shuffle.reshape(
         ctx,
         target,
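
The restructured block above (unchanged in substance, now skipped when there is a single free dimension) builds a coordinate grid over the free dimensions F and flattens it into one row of coordinates per element, F_volume rows in total. A plain PyTorch sketch of the same construction:

import torch

F_shapes = [2, 3]  # extents of the free (non-indexed) dimensions
grids = torch.meshgrid(*[torch.arange(s) for s in F_shapes], indexing="ij")
coords = torch.stack(grids, dim=-1).reshape(-1, len(F_shapes))
print(coords)
# tensor([[0, 0], [0, 1], [0, 2],
#         [1, 0], [1, 1], [1, 2]])  -> F_volume = 6 coordinate rows
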
@@ -672,21 +698,15 @@ def index_put_converter(

     # Combine all indexed dimensions (I)
     if K > 0:
-        I_combined = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_cat_I",
-            [
-                impl.shuffle.reshape(
-                    ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
-                )
-                for i, t in enumerate(I_tensors)
-            ],
-            dim=2,
-        )
+
+        I_combined = [
+            impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
+            )
+            for i, t in enumerate(I_tensors)
+        ]
     else:
-        I_combined = None
+        I_combined = []

     # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded
     ii_list = []
@@ -695,24 +715,12 @@ def index_put_converter(
     for dim in range(rank):
         unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
         if dim in I:
-            start = [0, 0, i_idx]
-            shape = [N, F_volume, 1]
-            stride = [1, 1, 1]
-            idx_tensor = impl.slice.slice(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_I_dim_{unique_suffix}",
-                I_combined,
-                start,
-                shape,
-                stride,
-            )
+            idx_tensor = I_combined[i_idx]
             ii_list.append(idx_tensor)
             i_idx += 1
         else:
             start = [0, 0, f_idx]
-            shape = [N, F_volume, 1]
+            shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1]
             stride = [1, 1, 1]
             mesh_tensor = impl.slice.slice(
                 ctx,
@@ -731,20 +739,24 @@ def index_put_converter(
     indices_cat = impl.cat.cat(
         ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2
     )
+
+    # Flatten indices_cat to (N * F_volume, rank)
     indices_cat = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
         f"{name}_reshape_indices_cat",
         indices_cat,
-        (N * F_volume, rank),
+        (-1, rank),
     )

     if not isinstance(values, TRTTensor):
         values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0)

     # Define the expected shape based on (N,) + F_shapes
-    expected_shape = (N,) + tuple(F_shapes)
+    expected_shape = (
+        (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes)
+    )

     # Broadcast 'values' to match the expected shape
     if len(values.shape) == 0 or values.shape == (1,):  # Scalar case
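
The `expected_shape` computed above is what `values` must broadcast to: one block of free-dimension elements per index row, with -1 standing in for N when it is only known at runtime. Conceptually, the broadcast rule is the ordinary one, sketched here in PyTorch:

import torch

N, F_shapes = 2, (3,)
expected_shape = (N,) + F_shapes              # (2, 3)

a = torch.tensor(5.0).expand(expected_shape)  # scalar -> filled (2, 3)
b = torch.arange(3.0).expand(expected_shape)  # (3,) -> one row per index entry
assert a.shape == b.shape == torch.Size([2, 3])
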
@@ -761,7 +773,12 @@ def index_put_converter(
         )
     else:  # Non-scalar case
         values_shape = list(values.shape)
-        if K > 0 and N in values_shape:
+        if (
+            K > 0
+            and N in values_shape
+            and (len(F) > 1 and max(F) - min(F) + 1 == len(F))
+        ):
+            # Contiguous case
             n_idx = values_shape.index(N)
             permute_order = [n_idx] + [
                 i for i in range(len(values_shape)) if i != n_idx
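
The extra conjunct in the new condition is a contiguity test on the free dimensions: the permute-based path is only taken when F forms one unbroken run of axes. A direct restatement:

def f_is_contiguous_run(F):
    # e.g. F = [1, 2] -> True; F = [0, 2] -> False (axis 1 is indexed)
    return len(F) > 1 and max(F) - min(F) + 1 == len(F)

assert f_is_contiguous_run([1, 2])
assert not f_is_contiguous_run([0, 2])
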
@@ -807,31 +824,27 @@ def index_put_converter(
                 tuple(broadcast_shape),
             )
         else:
+            # Non-contiguous case
             values_shape_padded = [1] * (
                 len(expected_shape) - len(values.shape)
             ) + list(values.shape)
             broadcast_shape = []
             for exp_dim, val_dim in zip(expected_shape, values_shape_padded):
-                if val_dim == 1 or exp_dim == val_dim:
+                if val_dim == DYNAMIC_DIM or exp_dim == DYNAMIC_DIM:
+                    broadcast_shape.append(-1)
+                elif val_dim == 1 or exp_dim == val_dim:
                     broadcast_shape.append(exp_dim)
                 else:
                     raise ValueError(
                         f"Cannot broadcast {values.shape} to {expected_shape}"
                     )
-            values_reshaped = impl.shuffle.reshape(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_reshape_values",
-                values,
-                tuple(broadcast_shape),
-            )
+
             values_expanded = impl.slice.expand(
                 ctx,
                 target,
                 source_ir,
                 f"{name}_expand_values",
-                values_reshaped,
+                values,
                 expected_shape,
             )
@@ -842,16 +855,51 @@ def index_put_converter(
         source_ir,
         f"{name}_flatten_values",
         values_expanded,
-        (N * F_volume,),
+        (-1,),
     )
-
     indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
-    # Perform Scatter ND operation
-    scatter_layer = ctx.net.add_scatter(
-        input_tensor,
-        indices_cat,
-        flattened_values,
-        trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD,
-    )
-    set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
-    return scatter_layer.get_output(0)
+    if accumulate:
+        zero_tensor = impl.full.full(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_zero_tensor",
+            [
+                get_shape(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"input_tensor_shape_dim_{i}",
+                    input_tensor,
+                    i,
+                )
+                for i in range(len(input_tensor.shape))
+            ],
+            0.0,
+            dtype=input_tensor.dtype,
+        )
+        # Perform Scatter ND operation
+        scatter_layer = ctx.net.add_scatter(
+            zero_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+
+        scatter_out = scatter_layer.get_output(0)
+        result = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor
+        )
+        return result
+
+    else:
+        scatter_layer = ctx.net.add_scatter(
+            input_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+        scatter_out = scatter_layer.get_output(0)
+        return scatter_out
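
The rewritten tail replaces the old `ScatterMode.ND_ELEMENTWISE_ADD` path with an explicit decomposition: scatter the values into a zero tensor, then add that tensor onto the input. Assuming each index appears at most once (plain ScatterND overwrites on duplicates), the two formulations agree, as this PyTorch sketch shows:

import torch

x = torch.ones(4)
idx = torch.tensor([1, 3])
vals = torch.tensor([10.0, 20.0])

zeros = torch.zeros_like(x)
zeros[idx] = vals             # scatter into zeros
out = x + zeros               # add back onto the input

ref = x.clone()
ref.index_put_((idx,), vals, accumulate=True)
assert torch.equal(out, ref)  # tensor([ 1., 11.,  1., 21.])
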