
Commit 864c0a1

[TorchToLinalg] Enable lowering of AtenMaxPool1dWithIndicesOp to linalg dialect (#4215)

This PR takes care of #4214.
- Enable lowering of `AtenMaxPool1dWithIndicesOp` to the linalg backend
- Add missing shape & dtype inference in the abstract interpretation library (fixes the **TorchScript IR --> Torch IR** conversion)
- Update xfail sets
- Add support for **ceil_mode** in the torch-to-stablehlo lowering of `AtenMaxPool1dWithIndicesOp`

Signed-off-by: Zahid Wakeel <[email protected]>
1 parent 5d374ba
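For context, the op being lowered returns both the pooled values and the int64 position of each maximum. A minimal eager-mode PyTorch illustration (not part of the patch):

```python
import torch

# max_pool1d_with_indices returns the pooled values together with the
# (int64) index of each maximum along the input's last dimension.
x = torch.arange(8, dtype=torch.float32).reshape(1, 1, 8)
values, indices = torch.ops.aten.max_pool1d_with_indices(
    x, kernel_size=[2], stride=[2], padding=[0], dilation=[1], ceil_mode=False
)
print(values)   # tensor([[[1., 3., 5., 7.]]])
print(indices)  # tensor([[[1, 3, 5, 7]]]), dtype=torch.int64
```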

6 files changed: +106 −1 lines changed


lib/Conversion/TorchToLinalg/Pooling.cpp

Lines changed: 9 additions & 1 deletion
@@ -392,6 +392,10 @@ template <> struct DimensionTraits<AtenMaxPool1dOp> {
   static_assert(Dim == Dim);
 };
 
+template <>
+struct DimensionTraits<AtenMaxPool1dWithIndicesOp>
+    : DimensionTraits<AtenMaxPool1dOp> {};
+
 template <> struct DimensionTraits<AtenMaxPool2dOp> {
   static constexpr int64_t Dim = 2;
   // unused const variable warning suppression:
@@ -417,7 +421,8 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
   using OpConversionPattern<OpTy>::OpConversionPattern;
 
   static const bool withIndices =
-      llvm::is_one_of<OpTy, AtenMaxPool2dWithIndicesOp,
+      llvm::is_one_of<OpTy, AtenMaxPool1dWithIndicesOp,
+                      AtenMaxPool2dWithIndicesOp,
                       AtenMaxPool3dWithIndicesOp>::value;
 
 private:
@@ -1766,8 +1771,11 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dOp>>(typeConverter, context);
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dOp>>(typeConverter, context);
 
+  target.addIllegalOp<AtenMaxPool1dWithIndicesOp>();
   target.addIllegalOp<AtenMaxPool2dWithIndicesOp>();
   target.addIllegalOp<AtenMaxPool3dWithIndicesOp>();
+  patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool1dWithIndicesOp>>(typeConverter,
+                                                                 context);
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool2dWithIndicesOp>>(typeConverter,
                                                                  context);
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,

lib/Conversion/TorchToStablehlo/Pooling.cpp

Lines changed: 23 additions & 0 deletions
@@ -132,6 +132,29 @@ LogicalResult ConvertAtenOp<AtenMaxPool1dWithIndicesOp>::matchAndRewrite(
   stablehloPadding[stablehloPadding.size() - 1] = padding[0];
   stablehloPadding[stablehloPadding.size() - 2] = padding[0];
 
+  if (ceilMode) {
+    // Match PyTorch output shape with extra padding. See
+    // https://github.com/pytorch/pytorch/blob/c5de6ff079e3e5b453d6ff5190c90f02db458928/aten/src/ATen/native/Pool.h#L79
+    const int64_t inputSize = inputShape[inputRank - 1];
+    const int64_t numerator =
+        (inputSize + 2 * padding[0] - dilation[0] * (kernelSize[0] - 1) - 1);
+    const int64_t floor_output_size = (numerator) / stride[0] + 1;
+    const int64_t adj = (stride[0] - 1);
+    int64_t ceil_output_size = std::ceil((numerator + adj) / stride[0]) + 1;
+
+    // Ensure last pooling starts inside input
+    if ((ceil_output_size - 1) * stride[0] >= inputSize + padding[0]) {
+      ceil_output_size--;
+    }
+
+    // Add extra padding to make output size same as torch
+    if (ceil_output_size > floor_output_size) {
+      const int64_t sizeDiff = ceil_output_size - floor_output_size;
+      const int64_t extraPadding = sizeDiff * stride[0];
+      stablehloPadding[stablehloPadding.size() - 1] += extraPadding;
+    }
+  }
+
   Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
 
   auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize);
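To make the ceil_mode arithmetic above concrete, here is a Python sketch of the same computation (`ceil_mode_extra_padding` is a hypothetical helper for illustration, not part of the patch), followed by a worked case where extra right-edge padding is actually needed:

```python
import math

def ceil_mode_extra_padding(input_size, kernel, stride, padding, dilation):
    # Mirrors the C++ above: compare floor- and ceil-mode output sizes and
    # pad the right edge so reduce_window yields the ceil-mode size.
    # Assumes a non-negative numerator.
    numerator = input_size + 2 * padding - dilation * (kernel - 1) - 1
    floor_output_size = numerator // stride + 1
    ceil_output_size = math.ceil(numerator / stride) + 1
    # Ensure the last pooling window starts inside the (left-padded) input.
    if (ceil_output_size - 1) * stride >= input_size + padding:
        ceil_output_size -= 1
    return max(ceil_output_size - floor_output_size, 0) * stride

# input_size=6, kernel=3, stride=2, padding=1, dilation=1:
# numerator = 5, floor size = 3, ceil size = 4 -> two extra padded columns.
print(ceil_mode_extra_padding(6, 3, 2, 1, 1))  # 2
```

Note that the C++ uses `(numerator + stride - 1) / stride`, i.e. integer ceiling division, which matches `math.ceil(numerator / stride)` for non-negative numerators.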

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 11 additions & 0 deletions
@@ -8555,6 +8555,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.max_pool1d_with_indices\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
+"    %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
+"    %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
+"    return %1 : !torch.tuple<list<int>, list<int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.adaptive_avg_pool1d(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -13182,6 +13187,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.max_pool1d_with_indices\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.tuple<int, int> {\n"
+"    %int4 = torch.constant.int 4\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
+"    return %1 : !torch.tuple<int, int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
@@ -3044,6 +3044,8 @@
     "LogCumsumExpModule_basic",
     "LogCumsumExpStaticNegativeDimModule_basic",
     "LogCumsumExpStaticFloat64DtypeModule_basic",
+    "MaxPool1dWithIndicesModule_basic",
+    "MaxPool1dWithIndicesCeilModeModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",
@@ -3806,6 +3808,8 @@
     "LogCumsumExpStaticNegativeDimModule_basic",
     "LogCumsumExpStaticFloat64DtypeModule_basic",
     "MaskedScatterStaticBasic_basic",
+    "MaxPool1dWithIndicesModule_basic",
+    "MaxPool1dWithIndicesCeilModeModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",
@@ -4664,6 +4668,8 @@
     "Matmul_4d",
     "Matmul_matvec",
     "Matmul_vecmat",
+    "MaxPool1dWithIndicesModule_basic",
+    "MaxPool1dWithIndicesCeilModeModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dModule_basic",
     "MaxPool2dCeilModeTrueModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 8 additions & 0 deletions
@@ -1297,6 +1297,10 @@ def aten〇avg_pool1d〡shape(self: List[int], kernel_size: List[int], stride: L
 def aten〇max_pool1d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> List[int]:
     return pool1d(self, kernel_size, stride, padding, ceil_mode)
 
+def aten〇max_pool1d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> Tuple[List[int], List[int]]:
+    maxpool1d = indices = pool1d(self, kernel_size, stride, padding, ceil_mode)
+    return maxpool1d, indices
+
 def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) -> List[int]:
     return adaptive_avg_pool1d(self, output_size)
 
@@ -3636,6 +3640,10 @@ def aten〇max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇max_pool1d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), dilation: List[int] = (1,), ceil_mode: bool = False) -> Tuple[int, int]:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype, torch.int64
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
 def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> int:
     self_rank, self_dtype = self_rank_dtype
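Both generated functions mirror the eager op's contract: the shape function reuses `pool1d` for values and indices alike, and the dtype function pairs the input dtype with `torch.int64` (ScalarType 4, hence the `%int4` constant in the generated IR above). A quick eager-mode check of that contract:

```python
import torch

x = torch.rand(2, 8, 16)
values, indices = torch.ops.aten.max_pool1d_with_indices(
    x, kernel_size=[3], stride=[2], padding=[1]
)
assert values.shape == indices.shape  # shape fn returns the same list twice
assert indices.dtype == torch.int64   # dtype fn returns (self_dtype, int64)
```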

projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py

Lines changed: 49 additions & 0 deletions
@@ -180,6 +180,55 @@ def AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class MaxPool1dWithIndicesModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool1d_with_indices(
+            x, kernel_size=[6], stride=[2], padding=[3], dilation=2, ceil_mode=False
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool1dWithIndicesModule())
+def MaxPool1dWithIndicesModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 64, 112, low=-1))
+
+
+class MaxPool1dWithIndicesCeilModeModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.max_pool1d_with_indices(
+            x, kernel_size=[4], stride=[2], padding=[2], dilation=2, ceil_mode=True
+        )
+
+
+@register_test_case(module_factory=lambda: MaxPool1dWithIndicesCeilModeModule())
+def MaxPool1dWithIndicesCeilModeModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 25, 37, low=-1))
+
+
+# ==============================================================================
+
+
 class MaxPool1dModule(torch.nn.Module):
 
     def __init__(self):
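As a plain-PyTorch sanity check of the two new test configurations (run outside the e2e harness), the expected output sizes work out as follows:

```python
import torch

# Mirrors MaxPool1dWithIndicesModule: floor mode.
v1, _ = torch.ops.aten.max_pool1d_with_indices(
    torch.rand(1, 64, 112), kernel_size=[6], stride=[2], padding=[3],
    dilation=[2], ceil_mode=False
)
print(v1.shape)  # torch.Size([1, 64, 54])

# Mirrors MaxPool1dWithIndicesCeilModeModule: ceil mode.
v2, _ = torch.ops.aten.max_pool1d_with_indices(
    torch.rand(3, 25, 37), kernel_size=[4], stride=[2], padding=[2],
    dilation=[2], ceil_mode=True
)
print(v2.shape)  # torch.Size([3, 25, 18])
```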
