@@ -3753,6 +3753,165 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
 };
 } // namespace

+namespace {
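+// Decompose aten.group_norm into aten.native_group_norm, deriving the N, C,
+// and HxW operands the native op expects from the input tensor's shape.
+// Roughly (a sketch; result types and the cudnn_enabled operand elided):
+//   %out = aten.group_norm %in, %groups, %w, %b, %eps
+// becomes
+//   %out, %mean, %rstd = aten.native_group_norm %in, %w, %b, %N, %C, %HxW,
+//                                               %groups, %eps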
+class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
+  using OpRewritePattern<AtenGroupNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenGroupNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getInput();
+    Value weight = op.getWeight();
+    Value bias = op.getBias();
+    Value numGroups = op.getNumGroups();
+    Value eps = op.getEps();
+
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
+
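+    // Derive N, C, and HxW from the input: N and C are the first two dims,
+    // and HxW = numel(input) / (N * C), computed with integer floor division.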
+    Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
+    Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
+    Value numElements = rewriter.create<AtenNumelOp>(loc, input);
+    Value numElementsDivN =
+        rewriter.create<AtenFloordivIntOp>(loc, numElements, N);
+    Value HxW = rewriter.create<AtenFloordivIntOp>(loc, numElementsDivN, C);
+
+    AtenNativeGroupNormOp newOp = rewriter.create<AtenNativeGroupNormOp>(
+        loc, ArrayRef<Type>{op.getResult().getType(), baseType, baseType},
+        input, weight, bias, N, C, HxW, numGroups, eps);
+
+    rewriter.replaceOp(op, newOp.getResult0());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
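+// Decompose aten.native_group_norm by reshaping the input to
+// (N, numGroups, -1), normalizing each group with its mean and variance, and
+// applying the optional affine weight and bias:
+//   out = (input - mean(group)) * rsqrt(var(group) + eps) * weight + bias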
+class DecomposeAtenNativeGroupNormOp
+    : public OpRewritePattern<AtenNativeGroupNormOp> {
+  using OpRewritePattern<AtenNativeGroupNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeGroupNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getInput();
+    Value weight = op.getWeight();
+    Value bias = op.getBias();
+    Value numGroups = op.getGroup();
+    Value eps = op.getEps();
+
+    // Check that the input and all three results have known sizes.
+    auto inputType = input.getType().cast<BaseTensorType>();
+    auto outputType = op.getResult0().getType().cast<BaseTensorType>();
+    auto meanType = op.getResult1().getType().cast<BaseTensorType>();
+    auto rsqrtVarType = op.getResult2().getType().cast<BaseTensorType>();
+    if (!inputType.hasSizes() || !outputType.hasSizes() ||
+        !meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "input/output tensors should have known sizes.");
+    }
+
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value cstNegativeOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
+
+    // GroupNorm requires the channel dimension (C) to be exactly divisible by
+    // the number of groups.
+    Value channel = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
+    Value remainder =
+        rewriter.create<AtenRemainderIntOp>(loc, channel, numGroups);
+    Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, cstZero);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr("the number of channels must be divisible by "
+                               "the number of groups"));
+
+    // Reshape the input tensor to (N, numGroups, -1) to apply normalization.
+    SmallVector<Value> newShape;
+    newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero));
+    newShape.push_back(numGroups);
+    newShape.push_back(cstNegativeOne);
+    Value reshapedInput = rewriter.create<AtenViewOp>(
+        loc, baseType, input,
+        rewriter.create<PrimListConstructOp>(
+            loc, Torch::ListType::get(IntType::get(context)), newShape));
+
+    // Now we proceed with the normalization steps across the flattened group
+    // dimension: compute the mean and variance for each group.
+    Value dimList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        ArrayRef<Value>{cstNegativeOne});
+    auto mean = rewriter.create<AtenMeanDimOp>(
+        loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue,
+        /*dtype=*/none);
+    auto var = rewriter.create<AtenVarDimOp>(
+        loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse,
+        /*keepdim=*/cstTrue);
+
+    // Compute the normalized output: (input - mean) * rsqrt(var + eps).
+    auto varPlusEps = rewriter.create<AtenAddScalarOp>(loc, baseType, var, eps,
+                                                       /*alpha=*/cstOne);
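+    // aten.rsqrt computes 1 / sqrt(x) elementwise, so invStd is the
+    // reciprocal of each group's standard deviation.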
+    auto invStd = rewriter.create<AtenRsqrtOp>(loc, baseType, varPlusEps);
+    auto inputSubMean = rewriter.create<AtenSubTensorOp>(
+        loc, baseType, reshapedInput, mean, /*alpha=*/cstOne);
+    auto normalizedOutput =
+        rewriter.create<AtenMulTensorOp>(loc, baseType, inputSubMean, invStd);
+
+    // Reshape the normalized output back to the original input shape.
+    auto inputShape = rewriter.create<AtenSizeOp>(
+        loc, Torch::ListType::get(IntType::get(context)), input);
+    auto reshapedOutput = rewriter.create<AtenViewOp>(
+        loc, inputType, normalizedOutput, /*shape=*/inputShape);
+
+    // Apply weight and bias if they are not None: reshape them to
+    // (C, 1, 1, ...) so they broadcast over the spatial dimensions.
+    SmallVector<Value> viewShape = {channel};
+    for (unsigned i = 2; i < inputType.getSizes().size(); i++) {
+      viewShape.push_back(cstOne);
+    }
+    Value viewShapeSizeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), viewShape);
+
+    Value groupNormOutput = reshapedOutput;
+    if (!weight.getType().isa<Torch::NoneType>()) {
+      auto weightReshaped = rewriter.create<AtenViewOp>(
+          loc, baseType, weight, /*shape=*/viewShapeSizeList);
+      groupNormOutput = rewriter.create<AtenMulTensorOp>(
+          loc, inputType, groupNormOutput, weightReshaped);
+    }
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      auto biasReshaped = rewriter.create<AtenViewOp>(
+          loc, baseType, bias, /*shape=*/viewShapeSizeList);
+      groupNormOutput = rewriter.create<AtenAddTensorOp>(
+          loc, inputType, groupNormOutput, biasReshaped,
+          /*alpha=*/cstOne);
+    }
+
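+    // The second and third results are the per-group mean and rstd; squeeze
+    // away the trailing keepdim dimension so they have shape (N, numGroups).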
+    Value squeezedMean =
+        rewriter.create<AtenSqueezeDimOp>(loc, meanType, mean, cstNegativeOne);
+    Value squeezedRsqrtVar = rewriter.create<AtenSqueezeDimOp>(
+        loc, rsqrtVarType, invStd, cstNegativeOne);
+
+    rewriter.replaceOp(
+        op, ArrayRef<Value>{groupNormOutput, squeezedMean, squeezedRsqrtVar});
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenNativeBatchNormOp
     : public OpRewritePattern<AtenNativeBatchNormOp> {
@@ -6204,6 +6363,8 @@ class DecomposeComplexOpsPass
         DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
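+    // aten.group_norm is decomposed to aten.native_group_norm, which is then
+    // decomposed to primitive tensor ops.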
+    addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);