Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoBump] Merge with fixes of c1892de6 (Dec 05) (128) #518

Merged
merged 15 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6684,6 +6684,35 @@ def Torch_AtenConv3dOp : Torch_Op<"aten.conv3d", [
}];
}

def Torch_AtenConv3dPaddingOp : Torch_Op<"aten.conv3d.padding", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv3d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
Torch_StringType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_IntType:$groups
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConv3dPaddingOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenConv3dPaddingOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -6713,6 +6742,35 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
}];
}

def Torch_AtenConv2dPaddingOp : Torch_Op<"aten.conv2d.padding", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv2d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
Torch_StringType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_IntType:$groups
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConv2dPaddingOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenConv2dPaddingOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -6742,6 +6800,35 @@ def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [
}];
}

def Torch_AtenConv1dPaddingOp : Torch_Op<"aten.conv1d.padding", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::conv1d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
Torch_StringType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_IntType:$groups
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConv1dPaddingOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenConv1dPaddingOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
6 changes: 6 additions & 0 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
strideInts.clear();
strideInts.append(numSpatialDims, 1);
} else {
if ((int64_t)paddingIntValues.size() + 2 !=
cast<RankedTensorType>(input.getType()).getRank()) {
// pytorch 2.5 generates one element padding = {0} for
// Conv2dWithValidPaddingModule
return rewriter.notifyMatchFailure(op, "unexpected number of padding");
}
// Pad input
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);
Expand Down
6 changes: 6 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2385,6 +2385,12 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");

if (padding_2d.size() != 2) {
// pytorch 2.5 generates one element padding = {0} for
// Conv2dWithValidPaddingModule
return rewriter.notifyMatchFailure(op, "unexpected number of paddings");
}

// TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}.
// The Torch OFM computation uses 2*pad in each spatial direction, implying
// the same t=b and l=r values for TOSA.
Expand Down
63 changes: 63 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10045,10 +10045,65 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv2d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @__torch__._conv_padding(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %int-1 = torch.constant.int -1\n"
" %str = torch.constant.str \"same\"\n"
" %none = torch.constant.none\n"
" %str_0 = torch.constant.str \"AssertionError: conv: weight must be at least 3 dimensional.\"\n"
" %int2 = torch.constant.int 2\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.sub.int %0, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n"
" %4 = torch.aten.mul.left_t %3, %2 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %5 = torch.aten.eq.str %arg2, %str : !torch.str, !torch.str -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" %6 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %7 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %8 = torch.aten.__range_length %6, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %10 = torch.prim.min.self_int %9 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %10, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %11 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.__derive_index %arg3, %6, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %13 = torch.aten.add.int %int2, %12 : !torch.int, !torch.int -> !torch.int\n"
" %14 = torch.aten.__getitem__.t %arg0, %13 : !torch.list<int>, !torch.int -> !torch.int\n"
" %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %16 = torch.aten.mul.int %11, %15 : !torch.int, !torch.int -> !torch.int\n"
" %17 = torch.aten.floordiv.int %16, %int2 : !torch.int, !torch.int -> !torch.int\n"
" %18 = torch.aten._set_item.t %4, %12, %17 : !torch.list<int>, !torch.int, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.If.yield\n"
" }\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv3d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose2d.input\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg3 : !torch.list<int> to !torch.optional<list<int>>\n"
" %1 = torch.derefine %arg4 : !torch.list<int> to !torch.optional<list<int>>\n"
Expand Down Expand Up @@ -10118,6 +10173,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv1d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %int1 = torch.constant.int 1\n"
" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n"
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %2 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %false, %1, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n"
" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
Expand Down
82 changes: 82 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5175,6 +5175,82 @@ class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> {
};
} // namespace

// Decompose aten.conv(1/2/3)d.padding to aten.convolution
namespace {
template <typename ConvPaddingOp>
class DecomposeAtenConvPaddingOp : public OpRewritePattern<ConvPaddingOp> {
public:
using OpRewritePattern<ConvPaddingOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ConvPaddingOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();

Value weight = op.getWeight();
std::optional<unsigned> maybeRank = getTensorRank(weight);
if (!maybeRank) {
return rewriter.notifyMatchFailure(op, "expected weight to have a rank");
}
unsigned rank = *maybeRank;
// first 2 dimensions of weight are out_channels and in_channels / groups
if (rank < 3)
return rewriter.notifyMatchFailure(
op, "ConvPaddingOp weight must be at least 3 dimensional.");

std::string padding_str;
if (!matchPattern(op.getPadding(), m_TorchConstantStr(padding_str)))
return rewriter.notifyMatchFailure(op,
"padding must be a constant string");

Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));

SmallVector<Value> paddingValues;
if (padding_str == "valid") {
// valid means no padding
for (unsigned iRank = 2; iRank < rank; iRank++) {
paddingValues.push_back(zero);
}
} else {

SmallVector<Value> dilation;
getListConstructElements(op.getDilation(), dilation);

Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value two =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
for (unsigned iRank = 2; iRank < rank; iRank++) {
Value dim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(iRank));
Value kernelSize =
rewriter.create<Torch::AtenSizeIntOp>(loc, weight, dim);
Value kernelSizeMinusOne =
rewriter.create<Torch::AtenSubIntOp>(loc, kernelSize, one);
Value padding = rewriter.create<Torch::AtenMulIntOp>(
loc, dilation[iRank - 2], kernelSizeMinusOne);
padding = rewriter.create<AtenFloordivIntOp>(loc, padding, two);
paddingValues.push_back(padding);
}
}

Value emptyList = rewriter.create<PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
SmallVector<Value>());
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
Value padding = rewriter.create<PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
paddingValues);
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
op.getStride(), padding, op.getDilation(), cstFalse, emptyList,
op.getGroups());

return success();
}
};
} // namespace

// Decompose aten.conv3d to aten.convolution
namespace {
class DecomposeAtenConv3dOp : public OpRewritePattern<AtenConv3dOp> {
Expand Down Expand Up @@ -11473,6 +11549,12 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenConv1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConv3dOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenConvPaddingOp<AtenConv1dPaddingOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenConvPaddingOp<AtenConv2dPaddingOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenConvPaddingOp<AtenConv3dPaddingOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenThresholdOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloatPowerTensorTensorOp>(
patterns);
Expand Down
Loading