Skip to content

Commit 6660a26

Browse files
renxida and suderman authored
lower torch.aten.isinf to linalg (#2638)
Co-authored-by: Rob Suderman <[email protected]>
1 parent 9fc212e commit 6660a26

File tree

4 files changed

+41
-3
lines changed

4 files changed

+41
-3
lines changed

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
426426
}
427427
if (isa<AtenAbsOp>(op))
428428
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
429+
if (isa<AtenIsinfOp>(op)){
430+
Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
431+
Value infinity = b.create<arith::ConstantOp>(
432+
loc, b.getFloatAttr(abs.getType(), std::numeric_limits<double>::infinity()));
433+
return createEqual(b, loc, abs.getType(), abs, infinity);
434+
}
429435
if (isa<AtenSigmoidOp>(op)) {
430436
auto negate = createCalculationForMathOpWithDtypeConversion<arith::NegFOp>(
431437
b, converter, payloadArgs[0], op);
@@ -1343,7 +1349,7 @@ class ConvertElementwiseOp : public ConversionPattern {
13431349
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
13441350
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
13451351
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
1346-
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
1352+
AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp,
13471353
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
13481354
AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op))
13491355
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@@ -1992,7 +1998,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
19921998
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
19931999
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
19942000
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
1995-
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
2001+
AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
2002+
AtenTrilOp,
19962003
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
19972004
AtenFillTensorOp, AtenRealOp, AtenImagOp>();
19982005
patterns.add<ConvertElementwiseOp>(typeConverter, context);

projects/ltc/csrc/base_lazy_backend/shape_inference.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
3939
return {Shape(self.scalar_type(), self.sizes().vec())};
4040
}
4141

42+
std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
43+
return {Shape(at::kBool, self.sizes().vec())};
44+
}
45+
4246
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
4347
const at::Tensor& self, at::IntArrayRef kernel_size,
4448
at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,6 +1033,7 @@
10331033
"ElementwiseAddScalarIntModule_basic",
10341034
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
10351035
"ElementwiseAtenDivIntScalarModule_basic",
1036+
"ElementwiseAtenIsinfOpModule_basic",
10361037
"ElementwiseAtenWhereSelfModule_basic",
10371038
"ElementwiseBinaryModule_basic",
10381039
"ElementwiseBinaryStaticShapeModule_basic",
@@ -1328,6 +1329,8 @@
13281329
"SliceWholeTensorModule_basic",
13291330
"TensorFloatModule_basic",
13301331
"TensorIntModule_basic",
1332+
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
1333+
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
13311334
}) - {
13321335
### Test failing in make_fx_tosa but not in tosa
13331336

@@ -1489,5 +1492,4 @@
14891492
"ElementwiseBitwiseAndScalarInt64Module_basic",
14901493
"ElementwiseBitwiseAndScalarInt32Module_basic",
14911494
"ElementwiseBitwiseAndScalarInt8Module_basic",
1492-
"ElementwiseIsinfModule_basic",
14931495
}

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3385,6 +3385,31 @@ def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils):
33853385
module.forward(tu.randint(4, 5, high=2).bool())
33863386

33873387

# ==============================================================================


class ElementwiseAtenIsinfOpModule(torch.nn.Module):
    """Wraps aten.isinf so the e2e harness can lower and execute it."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 5], torch.float32, True),
    ])
    def forward(self, x):
        # Element-wise test for +/- infinity; NaN entries must map to False.
        return torch.ops.aten.isinf(x)


@register_test_case(module_factory=lambda: ElementwiseAtenIsinfOpModule())
def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils):
    # Mix finite values, both infinities, and NaN to exercise every branch
    # of the isinf lowering (abs(x) == +inf comparison).
    test_input = torch.tensor(
        [
            [1, float('inf'), 2, float('-inf'), float('nan')],
            [1, float('inf'), float('-inf'), float('nan'), 3],
        ]
    )
    module.forward(test_input)
3412+
33883413
# ==============================================================================
33893414

33903415

0 commit comments

Comments (0)