
Commit a8538e1

Decompose AtenNormalFunctionalOp into AtenRandn* and other arithmetic. (#2737)
1 parent f85e5c9 commit a8538e1
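
The rewrite relies on the location–scale property of the normal distribution; stated as a formula (an explanatory note, not part of the commit):

    % If Z is a standard normal sample, scaling by sigma and shifting by mu
    % yields a sample with the requested mean and standard deviation.
    Z \sim \mathcal{N}(0, 1) \;\Longrightarrow\; \sigma Z + \mu \sim \mathcal{N}(\mu, \sigma^{2})

So aten.normal_functional(self, mean, std) can be computed as aten.randn_like(self) * std + mean, which is exactly the op sequence the new pattern in DecomposeComplexOps.cpp below emits.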

6 files changed: +85 −2 lines changed

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 17 additions & 0 deletions
@@ -7655,6 +7655,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
 " return %arg0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.normal_functional\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
+" return %arg0 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
 " %0 = torch.derefine %arg0 : !torch.float to !torch.union<float, int>\n"
 " %1 = torch.derefine %arg1 : !torch.float to !torch.union<float, int>\n"
@@ -11557,6 +11560,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " return %1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
+" %none = torch.constant.none\n"
+" %str = torch.constant.str \"AssertionError: \"\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
+" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
+" torch.prim.If %2 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" return %0#1 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
 " %str = torch.constant.str \"AssertionError: \"\n"
 " %int6 = torch.constant.int 6\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 32 additions & 2 deletions
@@ -3669,9 +3669,38 @@ class DecomposeAtenExponentialOp : public OpRewritePattern<AtenExponentialOp> {
     return success();
   }
 };
-} // namespace
 
-namespace {
+// aten.normal_functional(mean, sigma) = randn() * sigma + mean.
+class DecomposeAtenNormalFunctionalOp
+    : public OpRewritePattern<AtenNormalFunctionalOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.getGenerator().getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "The generator has to be None because only global default "
+              "generator is supported");
+
+    Location loc = op.getLoc();
+    Type resultType = op.getType();
+    Value std = op.getStd();
+    Value mean = op.getMean();
+
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    Value one =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
+    Value randN = rewriter.create<AtenRandnLikeOp>(
+        loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none,
+        /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+    Value stdRandN =
+        rewriter.create<AtenMulScalarOp>(loc, resultType, randN, std);
+    rewriter.replaceOpWithNewOp<AtenAddScalarOp>(op, resultType, stdRandN,
+                                                 mean, /*alpha=*/one);
+    return success();
+  }
+};
+
 template <typename OpTy, typename T1T2Op>
 class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -6591,6 +6620,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNormalFunctionalOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
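
For reference, the three ops the pattern materializes can be reproduced in eager PyTorch; a sketch restating the rewrite (not code from the commit):

    import torch

    x = torch.empty(3, 4, dtype=torch.float64)
    # Same sequence the rewriter emits: randn_like, mul.Scalar, add.Scalar,
    # with alpha=1 matching the ConstantFloatOp the pattern creates.
    r = torch.ops.aten.randn_like(x)
    y = torch.ops.aten.add.Scalar(torch.ops.aten.mul.Scalar(r, 2.0), -5.0, 1)

Note the pattern deliberately bails out when a generator is supplied, since only the global default generator is supported by the lowering.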

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -494,6 +494,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
     target.addIllegalOp<AtenRandnOp>();
     target.addIllegalOp<AtenRandnGeneratorOp>();
     target.addIllegalOp<AtenRandnLikeOp>();
+    target.addIllegalOp<AtenNormalFunctionalOp>();
     target.addIllegalOp<AtenVarMeanOp>();
     target.addIllegalOp<AtenCosineSimilarityOp>();
     target.addIllegalOp<AtenNewEmptyStridedOp>();
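
Marking the op illegal is what actually forces the decomposition: LowerToBackendContract keeps rerunning the simplification pipeline until no illegal op remains. Backends that lower aten.normal_functional natively can opt out; a hedged sketch assuming the pt1-era torch_mlir.compile API and its backend_legal_ops parameter:

    import torch
    import torch_mlir

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aten.normal_functional(x, mean=0.0, std=1.0)

    # Declaring the op backend-legal keeps it from being decomposed.
    module = torch_mlir.compile(M(), torch.rand(4, 4), output_type="torch",
                                backend_legal_ops=["aten.normal_functional"])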

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -1484,6 +1484,7 @@
     "VarMeanUnbiasedModule_basic",
     "RandnLikeModule_basic",
     "RandnLikeDtypeModule_basic",
+    "NormalFunctionalModule_basic",
     "BernoulliFloatModule_basic",
     "BernoulliModule_basic",
     "BernoulliPModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 13 additions & 0 deletions
@@ -902,6 +902,9 @@ def aten〇randn〡shape(size: List[int], dtype: Optional[int] = None, layout: O
 def aten〇randn〇generator〡shape(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return size
 
+def aten〇normal_functional〡shape(self: List[int], mean: float = 0., std: float = 1., generator: Any = None) -> List[int]:
+    return self
+
 def aten〇arange〇start_step〡shape(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
     return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
 
@@ -3822,6 +3825,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O
     assert not is_integer_dtype(dtype)
     return dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(
+    num_of_tensors=1,
+    error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
+def aten〇normal_functional〡dtype(self_rank_dtype: Tuple[int, int], mean: float = 0., std: float = 1., generator: Any = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    if self_dtype is None:
+        return torch.float32
+    assert not is_integer_dtype(self_dtype)
+    return self_dtype
+
 @check_dtype_function([Invocation(size=[1], generator=None),
                        Invocation(size=[1], generator=None, dtype=torch.float32),
                        ErrorInvocation(size=[1], generator=None, dtype=torch.int32),
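
The dtype rule is small enough to restate on its own; a hypothetical standalone helper mirroring aten〇normal_functional〡dtype above (the real one uses the library's is_integer_dtype):

    import torch

    _INTEGER_DTYPES = {torch.bool, torch.uint8, torch.int8, torch.int16,
                       torch.int32, torch.int64}

    def normal_functional_result_dtype(self_dtype):
        # The result keeps the input dtype; integer and bool inputs are
        # rejected, and an unknown dtype falls back to float32.
        if self_dtype is None:
            return torch.float32
        assert self_dtype not in _INTEGER_DTYPES
        return self_dtype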

projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py

Lines changed: 21 additions & 0 deletions
@@ -605,3 +605,24 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: RandnLikeDtypeModule())
 def RandnLikeDtypeModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(256, 1024).double())
+# ==============================================================================
+
+class NormalFunctionalModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float64, True),
+    ])
+    def forward(self, x):
+        a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
+        mean = torch.mean(a)
+        std = torch.std(a)
+        return mean, std
+
+
+@register_test_case(module_factory=lambda: NormalFunctionalModule())
+def NormalFunctionalModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2048, 4096).double())
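
Because the op is random, the test returns summary statistics rather than raw values: an element-wise comparison against eager PyTorch would fail for any RNG mismatch, while the mean and std over roughly 8.4M samples are stable. A quick eager-mode check of what the test expects (sizes and parameters copied from the test above):

    import torch

    x = torch.rand(2048, 4096, dtype=torch.float64)
    a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
    # Sample statistics should land very close to the requested -5.0 and 2.0.
    print(a.mean().item(), a.std().item())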
