Commit 7cf52ae

[Torch Dialect] Add Support for AtenGroupNormOp and AtenNativeGroupNormOp (#2591)
Co-authored-by: LiuYuanqiang <[email protected]>
1 parent 74f7a0c commit 7cf52ae
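
This commit wires `aten::group_norm` end to end through the Torch dialect: torch_ods_gen.py registers the op (the GeneratedTorchOps.td entry below is regenerated from it), abstract_interp_lib_gen.py adds shape and dtype transfer functions (regenerated into AbstractInterpLibrary.cpp), and DecomposeComplexOps.cpp decomposes `aten.group_norm` into `aten.native_group_norm` and then into primitive tensor ops. For orientation, here is a minimal PyTorch sketch of what the two ops compute; it is not part of the commit, and the sizes are only illustrative.

import torch
import torch.nn.functional as F

# Illustrative sizes: N=2 samples, C=6 channels, 4x4 spatial, 3 groups.
x = torch.randn(2, 6, 4, 4)
weight, bias = torch.randn(6), torch.randn(6)

# aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)
out = F.group_norm(x, 3, weight, bias, eps=1e-5)

# aten::native_group_norm additionally returns the per-group mean and rstd,
# each of shape (N, group), which is why its new shape function returns a
# 3-tuple of lists.
out2, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, 2, 6, 16, 3, 1e-5)  # (input, weight, bias, N, C, HxW, group, eps)
print(out.shape, mean.shape, rstd.shape)
# torch.Size([2, 6, 4, 4]) torch.Size([2, 3]) torch.Size([2, 3])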

9 files changed: 298 additions, 3 deletions

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 28 additions & 0 deletions
@@ -5640,6 +5640,34 @@ def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [
   }];
 }
 
+def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    Torch_IntType:$num_groups,
+    AnyTorchOptionalTensorType:$weight,
+    AnyTorchOptionalTensorType:$bias,
+    Torch_FloatType:$eps,
+    Torch_BoolType:$cudnn_enabled
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 1);
+    }
+    void AtenGroupNormOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 1);
+    }
+  }];
+}
+
 def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 40 additions & 0 deletions
@@ -8074,6 +8074,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.group_norm\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.native_group_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    %1 = torch.prim.ListConstruct %arg3, %arg6 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %2 = torch.prim.ListConstruct %arg3, %arg6 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
+"    return %3 : !torch.tuple<list<int>, list<int>, list<int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -8748,6 +8759,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.group_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.native_group_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple<int, int, int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+"    %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
+"    return %3 : !torch.tuple<int, int, int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

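These string-encoded functions are not hand-written; they are regenerated from the Python definitions in abstract_interp_lib_gen.py later in this commit (the 〇/〡 characters in those Python names stand in for separators that are not legal in Python identifiers). The new dtype functions reject integer element types and otherwise propagate the input dtype, mirroring eager PyTorch. A small sketch of the behaviour the dtype rule encodes; it is not part of the commit:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)                       # float input: dtype is propagated
print(F.group_norm(x, 2).dtype)                # torch.float32

x_int = torch.ones(2, 4, 8, dtype=torch.int64)
try:
    F.group_norm(x_int, 2)                     # integer input: rejected by PyTorch,
except RuntimeError as err:                    # matching the is_integer_dtype assert
    print("rejected:", err)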
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 161 additions & 0 deletions
@@ -3753,6 +3753,165 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern<OpTy> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenGroupNormOp : public OpRewritePattern<AtenGroupNormOp> {
+  using OpRewritePattern<AtenGroupNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenGroupNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getInput();
+    Value weight = op.getWeight();
+    Value bias = op.getBias();
+    Value numGroups = op.getNumGroups();
+    Value eps = op.getEps();
+
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
+
+    Value N = rewriter.create<AtenSizeIntOp>(loc, input, cstZero);
+    Value C = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
+    Value numElements = rewriter.create<AtenNumelOp>(loc, input);
+    Value numElementsDivN =
+        rewriter.create<AtenFloordivIntOp>(loc, numElements, N);
+    Value HxW = rewriter.create<AtenFloordivIntOp>(loc, numElementsDivN, C);
+
+    AtenNativeGroupNormOp newOp = rewriter.create<AtenNativeGroupNormOp>(
+        loc, ArrayRef<Type>{op.getResult().getType(), baseType, baseType},
+        input, weight, bias, N, C, HxW, numGroups, eps);
+
+    rewriter.replaceOp(op, newOp.getResult0());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAtenNativeGroupNormOp
+    : public OpRewritePattern<AtenNativeGroupNormOp> {
+  using OpRewritePattern<AtenNativeGroupNormOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNativeGroupNormOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+
+    Value input = op.getInput();
+    Value weight = op.getWeight();
+    Value bias = op.getBias();
+    Value numGroups = op.getGroup();
+    Value eps = op.getEps();
+
+    // Check the rank of the input/outputs tensor.
+    auto inputType = input.getType().cast<BaseTensorType>();
+    auto outputType = op.getResult0().getType().cast<BaseTensorType>();
+    auto meanType = op.getResult1().getType().cast<BaseTensorType>();
+    auto rsqrtVarType = op.getResult2().getType().cast<BaseTensorType>();
+    if (!inputType.hasSizes() || !outputType.hasSizes() ||
+        !meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "input/outputs tensor should have known sizes.");
+    }
+
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value cstNegtiveOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
+    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
+
+    // GroupNorm requires the channel dimension (C) to be exactly divisible by
+    // the number of groups.
+    Value channel = rewriter.create<AtenSizeIntOp>(loc, input, cstOne);
+    Value remainder =
+        rewriter.create<AtenRemainderIntOp>(loc, channel, numGroups);
+    Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, remainder, cstZero);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr("the number of channels must be divisible by "
+                               "the number of groups"));
+
+    // Reshape the input tensor to (N, numGroups, -1) to apply normalization.
+    SmallVector<Value> newShape;
+    newShape.push_back(rewriter.create<AtenSizeIntOp>(loc, input, cstZero));
+    newShape.push_back(numGroups);
+    newShape.push_back(cstNegtiveOne);
+    Value reshapedInput = rewriter.create<AtenViewOp>(
+        loc, baseType, input,
+        rewriter.create<PrimListConstructOp>(
+            loc, Torch::ListType::get(IntType::get(context)), newShape));
+
+    // Now we proceed with the normalization steps across the 'groupSize'
+    // Compute the mean and variance for each group
+    Value dimList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        ArrayRef<Value>{cstNegtiveOne});
+    auto mean = rewriter.create<AtenMeanDimOp>(
+        loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue,
+        /*dtype=*/none);
+    auto var = rewriter.create<AtenVarDimOp>(
+        loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse,
+        /*keepdim=*/cstTrue);
+
+    // Compute the normalized output: (input - mean) * rsqrt(var + eps)
+    auto varPlusEps = rewriter.create<AtenAddScalarOp>(loc, baseType, var, eps,
+                                                       /*alpha=*/cstOne);
+    auto invStd = rewriter.create<AtenRsqrtOp>(loc, baseType, varPlusEps);
+    auto inputSubMean = rewriter.create<AtenSubTensorOp>(
+        loc, baseType, reshapedInput, mean, /*alpha=*/cstOne);
+    auto normalizedOutput =
+        rewriter.create<AtenMulTensorOp>(loc, baseType, inputSubMean, invStd);
+
+    // Reshape normalized output back to the original input shape
+    auto inputShape = rewriter.create<AtenSizeOp>(
+        loc, Torch::ListType::get(IntType::get(context)), input);
+    auto reshapedOutput = rewriter.create<AtenViewOp>(
+        loc, inputType, normalizedOutput, /*shape=*/inputShape);
+
+    // Apply weight and bias if they are not None
+    // Reshape weight and bias to C,1,1,...
+    SmallVector<Value> viewShape = {channel};
+    for (unsigned i = 2; i < inputType.getSizes().size(); i++) {
+      viewShape.push_back(cstOne);
+    }
+    Value viewShapeSizeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), viewShape);
+
+    Value groupNormOutput = reshapedOutput;
+    if (!weight.getType().isa<Torch::NoneType>()) {
+      auto weightReshaped = rewriter.create<AtenViewOp>(
+          loc, baseType, weight, /*shape=*/viewShapeSizeList);
+      groupNormOutput = rewriter.create<AtenMulTensorOp>(
+          loc, inputType, groupNormOutput, weightReshaped);
+    }
+    if (!bias.getType().isa<Torch::NoneType>()) {
+      auto biasReshaped = rewriter.create<AtenViewOp>(
+          loc, baseType, bias, /*shape=*/viewShapeSizeList);
+      groupNormOutput = rewriter.create<AtenAddTensorOp>(
+          loc, inputType, groupNormOutput, biasReshaped,
+          /*alpha=*/cstOne);
+    }
+
+    Value squeezedMean =
+        rewriter.create<AtenSqueezeDimOp>(loc, meanType, mean, cstNegtiveOne);
+    Value squeezedRsqrtVar = rewriter.create<AtenSqueezeDimOp>(
+        loc, rsqrtVarType, invStd, cstNegtiveOne);
+
+    rewriter.replaceOp(
+        op, ArrayRef<Value>{groupNormOutput, squeezedMean, squeezedRsqrtVar});
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenNativeBatchNormOp
     : public OpRewritePattern<AtenNativeBatchNormOp> {
@@ -6204,6 +6363,8 @@ class DecomposeComplexOpsPass
         DecomposeAtenAddCLikeOp<AtenAddcdivOp, AtenDivTensorOp>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNativeGroupNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
     addPatternIfTargetOpIsIllegal<
         DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);

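For readers following the pattern above, the decomposition of `aten.native_group_norm` can be summarized in plain PyTorch. This reference sketch is not part of the commit and uses illustrative sizes, but it performs the same steps the rewrite emits: view to (N, group, -1), biased mean/variance over the last dimension, normalize with rsqrt(var + eps), view back, apply the optional per-channel affine parameters reshaped to (C, 1, 1, ...), and squeeze the per-group statistics.

import torch

def native_group_norm_reference(x, weight, bias, N, C, HxW, group, eps):
    # Same steps as DecomposeAtenNativeGroupNormOp: normalize each group.
    g = x.view(N, group, -1)
    mean = g.mean(dim=-1, keepdim=True)
    var = g.var(dim=-1, unbiased=False, keepdim=True)
    rstd = torch.rsqrt(var + eps)
    out = ((g - mean) * rstd).view_as(x)

    # Optional per-channel affine parameters, reshaped to (C, 1, 1, ...).
    affine_shape = [C] + [1] * (x.dim() - 2)
    if weight is not None:
        out = out * weight.view(affine_shape)
    if bias is not None:
        out = out + bias.view(affine_shape)

    # The op also returns the per-group statistics with shape (N, group).
    return out, mean.squeeze(-1), rstd.squeeze(-1)

x = torch.randn(2, 6, 4, 4)
w, b = torch.randn(6), torch.randn(6)
out, mean, rstd = native_group_norm_reference(x, w, b, 2, 6, 16, 3, 1e-5)
ref = torch.ops.aten.native_group_norm(x, w, b, 2, 6, 16, 3, 1e-5)
assert torch.allclose(out, ref[0], atol=1e-5)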
lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 2 additions & 0 deletions
@@ -407,6 +407,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenAddcdivOp>();
   target.addIllegalOp<AtenLayerNormOp>();
   target.addIllegalOp<AtenNativeLayerNormOp>();
+  target.addIllegalOp<AtenGroupNormOp>();
+  target.addIllegalOp<AtenNativeGroupNormOp>();
   target.addIllegalOp<AtenNativeBatchNormOp>();
   target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
   target.addIllegalOp<AtenConvolutionBackwardOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 5 additions & 0 deletions
@@ -306,6 +306,10 @@
 
     # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4]))
     "ArangeStartOutViewModule_basic",
+
+    # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float'
+    "GroupNormModule_basic",
+    "GroupNormNoWeightAndBiasModule_basic",
 }
 
 TORCHDYNAMO_CRASHING_SET = {
@@ -586,6 +590,7 @@
     "NewFullModuleInt2DStatic_basic",
     "NewFullModuleInt2D_basic",
     "NewFullModuleInt3D_basic",
+    "GroupNormModule_basic",
     "GatherStaticModule_basic",
     "GatherModule_basic",
     "Gather2DInputModdule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 18 additions & 0 deletions
@@ -1130,6 +1130,12 @@ def aten〇convolution_backward〡shape(grad_output: List[int], input: List[int]
 def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
     return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
 
+def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optional[List[int]] = None, bias: Optional[List[int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enabled: bool = True) -> List[int]:
+    return upstream_shape_functions.unary(input)
+
+def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]:
+    return upstream_shape_functions.unary(input), [N, group], [N, group]
+
 def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
     return upstream_shape_functions.slice(self, dim, start, end, step)
 
@@ -1671,6 +1677,18 @@ def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dty
     input_rank, input_dtype = input_rank_dtype
     return input_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], error_types={*all_integer_dtypes()}, num_groups=1))
+def aten〇group_norm〡dtype(input_rank_dtype: Tuple[int, int], num_groups: int, weight_rank_dtype: Optional[Tuple[int, int]] = None, bias_rank_dtype: Optional[Tuple[int, int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enabled: bool = True) -> int:
+    input_rank, input_dtype = input_rank_dtype
+    assert not is_integer_dtype(input_dtype)
+    return input_dtype
+
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7), (3,), (3,)], error_types={*all_integer_dtypes()}, N=2, C=3, HxW=35, group=1, eps=0.000001))
+def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[int, int, int]:
+    input_rank, input_dtype = input_rank_dtype
+    assert not is_integer_dtype(input_dtype)
+    return input_dtype, input_dtype, input_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
     self_rank, self_dtype = self_rank_dtype

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 3 additions & 0 deletions
@@ -421,6 +421,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)"
     )
+    emit(
+        'aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)'
+    )
     emit(
         "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
     )

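The emit() registration above is the source of truth for the generated op: the string key must match the upstream ATen operator schema, and the ODS block added to GeneratedTorchOps.td at the top of this commit is produced from it. As a sanity check, the upstream schema can be inspected from Python; this snippet is not part of the commit and assumes a PyTorch build where overload objects expose `_schema`.

import torch

# Prints something like:
# aten::group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None,
#                  float eps=1e-05, bool cudnn_enabled=True) -> Tensor
print(torch.ops.aten.group_norm.default._schema)
print(torch.ops.aten.native_group_norm.default._schema)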
projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@
 from torch_mlir._version import torch_version_for_comparison, version
 
 COMMON_TORCH_MLIR_LOWERING_XFAILS = {
-    "NativeGroupNormModule_basic",
     "NativeGroupNormBackwardModule_basic",
     "QuantizedMLP_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
