@@ -545,8 +545,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
545
545
/* dtype=*/ noneVal);
546
546
return success ();
547
547
});
548
+ // onnx.ReduceMean with axes provided as argument introduced in opset 18
548
549
patterns.onOp (
549
- " ReduceMean" , 13 ,
550
+ " ReduceMean" , 18 ,
550
551
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
551
552
Torch::ValueTensorType resultType;
552
553
Value data;
@@ -632,6 +633,82 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
632
633
/* dtype=*/ noneVal);
633
634
return success ();
634
635
});
636
+
637
+ // onnx.ReduceMean with axes provided as attribute
638
+ patterns.onOp (
639
+ " ReduceMean" , 1 ,
640
+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
641
+ Torch::ValueTensorType resultType;
642
+ Value data;
643
+ llvm::SmallVector<int64_t > axes;
644
+ int64_t keepDims;
645
+ int64_t noop_with_empty_axes;
646
+ if (binder.tensorOperand (data) ||
647
+ binder.tensorResultType (resultType) ||
648
+ binder.s64IntegerArrayAttr (axes, " axes" , 0 ) ||
649
+ binder.s64IntegerAttr (keepDims, " keepdims" , 1 ) ||
650
+ binder.s64IntegerAttr (noop_with_empty_axes, " noop_with_empty_axes" ,
651
+ 0 ))
652
+ return failure ();
653
+ SmallVector<Value> dimList;
654
+ SmallVector<int64_t > selectSizes;
655
+ selectSizes.push_back (1 );
656
+ Value noneVal = rewriter.create <Torch::ConstantNoneOp>(binder.getLoc ());
657
+ // deal with case when axes is empty
658
+ if (axes.size () == 0 ) {
659
+ if (noop_with_empty_axes == 0 ) {
660
+ Value keepDimsConstInt = rewriter.create <Torch::ConstantIntOp>(
661
+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
662
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), keepDims));
663
+ Value keepDimsBool = rewriter.create <Torch::AtenBoolIntOp>(
664
+ binder.getLoc (), keepDimsConstInt);
665
+ rewriter.replaceOpWithNewOp <Torch::AtenMeanDimOp>(
666
+ binder.op , resultType, data, /* dim=*/ noneVal, keepDimsBool,
667
+ /* dtype=*/ noneVal);
668
+ } else {
669
+ rewriter.replaceOp (binder.op , data);
670
+ }
671
+ return success ();
672
+ }
673
+ Value zero = rewriter.create <Torch::ConstantIntOp>(
674
+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
675
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), 0 ));
676
+ int64_t adjustmentInt =
677
+ cast<Torch::ValueTensorType>(data.getType ()).getSizes ().size ();
678
+ Value adjustment = rewriter.create <Torch::ConstantIntOp>(
679
+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
680
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ),
681
+ adjustmentInt));
682
+ // convert axes (tensor) into torch int list while dealing with neg axis
683
+ for (int i = 0 ; i < axes.size (); i++) {
684
+ // Go through the axes list and get each dim in the list
685
+ int64_t dim = axes[i];
686
+ if (dim < 0 ) {
687
+ dim += adjustmentInt;
688
+ }
689
+ // deal with neg axis: if (axis < 0) axis += rank
690
+ Value finalDim = rewriter.create <Torch::ConstantIntOp>(
691
+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
692
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), dim));
693
+ dimList.push_back (finalDim);
694
+ }
695
+ Value dimValueList = rewriter.create <Torch::PrimListConstructOp>(
696
+ binder.getLoc (),
697
+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
698
+ dimList);
699
+ Value keepDimBool;
700
+ if (keepDims == 1 ) {
701
+ keepDimBool =
702
+ rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), true );
703
+ } else {
704
+ keepDimBool =
705
+ rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), false );
706
+ }
707
+ rewriter.replaceOpWithNewOp <Torch::AtenMeanDimOp>(
708
+ binder.op , resultType, data, dimValueList, keepDimBool,
709
+ /* dtype=*/ noneVal);
710
+ return success ();
711
+ });
635
712
patterns.onOp (
636
713
" ReduceMin" , 13 ,
637
714
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
0 commit comments