From ef6bf1807e7cac01c6dcd9416d31191aa1d54f97 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Mon, 17 Nov 2025 17:00:02 +0400 Subject: [PATCH 1/3] Direct lowering of torch.aten.convolution_backward from torch to linalg --- lib/Conversion/TorchToLinalg/Linear.cpp | 627 ++++++++++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 401 ----------- .../Transforms/LowerToBackendContract.cpp | 1 - .../configs/jit_importer_backend.py | 1 + .../test_suite/backprop.py | 69 ++ .../TorchToLinalg/convolution_bwd.mlir | 318 +++++++++ test/Dialect/Torch/decompose-complex-ops.mlir | 34 - 7 files changed, 1015 insertions(+), 436 deletions(-) create mode 100644 test/Conversion/TorchToLinalg/convolution_bwd.mlir diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 5b30e93e6fe4..31b8387be3f2 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1673,6 +1673,631 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding( } } +namespace { +class ConvertAtenConvolutionBackwardOp + : public OpConversionPattern { + using IT = utils::IteratorType; + +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenConvolutionBackwardOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); + Value gradOutput = adaptor.getGradOutput(); + Value input = adaptor.getInput(); + Value weight = adaptor.getWeight(); + + auto gradOutputDTy = + cast(gradOutput.getType()).getElementType(); + auto inputDTy = cast(input.getType()).getElementType(); + auto weightDTy = cast(weight.getType()).getElementType(); + if (!isa(gradOutputDTy) || + !isa(inputDTy) || !isa(weightDTy)) + return op.emitError("unimplemented: only fp convolution bwd supported"); + + size_t gradRank = cast(gradOutput.getType()).getRank(); + size_t numSpatialDims = gradRank - 2; + if (numSpatialDims < 1 || numSpatialDims > 3) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1d-3d convolution bwd currently supported"); + + // Transposed convolution backward is not handled here yet. + bool transposed = false; + if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) + return rewriter.notifyMatchFailure( + op, "only support constant bool for transposed"); + if (transposed) + return rewriter.notifyMatchFailure( + op, "unimplemented: transposed convolution backward"); + + // The `outMask` contains 3 boolean values for the results `grad_input`, + // `grad_weight`, and `grad_bias` respectively. The value being `false` + // means that the corresponding result will be none. + SmallVector outMask; + if (!matchPattern(op.getOutputMask(), + m_TorchListOfConstantBools(outMask)) || + outMask.size() != 3) + return rewriter.notifyMatchFailure( + op, "only constant bool output_mask list of size 3 is supported."); + for (unsigned i = 0; i < outMask.size(); i++) { + if (outMask[i] == false) { + Value result = op->getResults()[i]; + if (!result.getUsers().empty()) + return rewriter.notifyMatchFailure( + op, "unimplemented: false value supported for output_mask only " + "when the result tensor corresponding to that has no users."); + } + } + + // Checks for valid group size + int64_t numGroups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) + return rewriter.notifyMatchFailure(op, + "only constant group size supported."); + bool isGroupedConvBwd = numGroups > 1; + int64_t spatialStartDimIdx = isGroupedConvBwd ? 3 : 2; + + // Stride, padding, dilation for the backward conv. We only support constant + // lists here, consistent with forward convolution lowering. + SmallVector paddingIntValues; + SmallVector strideInts, dilationInts, outputPaddingInts; + + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) + return rewriter.notifyMatchFailure(op, + "only support constant int strides"); + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) + return rewriter.notifyMatchFailure(op, + "only support constant int dilations"); + if (!matchPattern(op.getOutputPadding(), + m_TorchListOfConstantInts(outputPaddingInts))) + return rewriter.notifyMatchFailure( + op, "only support constant int output paddings"); + if (!llvm::all_of(outputPaddingInts, + [](int64_t outPad) { return outPad == 0; })) + return rewriter.notifyMatchFailure( + op, "unimplemented: only output padding of 0 supported."); + + if (!getListConstructElements(op.getPadding(), paddingIntValues)) + return rewriter.notifyMatchFailure( + op, "only support padding from a list construct"); + paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), + paddingIntValues); + + // The expandGroups lambda function below is used to expand the group + // dimension for weights and input, output tensors. + // For input tensor (dim = 1) : N,C,H,W -> N,G,C/G,H,W + // For grad_output tensor (dim = 1): N,F,H,W -> N,G,F/G,H,W + // For weight tensor (dim = 0) : F,C,H,W -> G,F/G,C,H,W + auto expandGroups = [&](Value tensor, int64_t dim) { + auto inType = cast(tensor.getType()); + auto inShape = makeShapeTorchCompatible(inType.getShape()); + + SmallVector outShape; + for (auto i = 0; i < static_cast(inShape.size()); i++) { + if (i == dim) { + outShape.push_back(numGroups); + outShape.push_back(inShape[i] == kUnknownSize + ? kUnknownSize + : inShape[i] / numGroups); + } else { + outShape.push_back(inShape[i]); + } + } + + SmallVector indices; + for (auto i = 0; i <= static_cast(inShape.size()); i++) { + if (i == dim) { + indices.push_back({i, ++i}); + continue; + } + indices.push_back({i}); + } + + auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); + return tensor::ExpandShapeOp::create(rewriter, loc, retType, tensor, + indices); + }; + + SmallVector newResults(op->getNumResults()); + + // Computing Backward-Input Convolution. + if (outMask[0]) { + // If convolution bwd is grouped, `grad_output` should be expanded. + auto gradOutputExpanded = + isGroupedConvBwd ? expandGroups(gradOutput, 1) : gradOutput; + // If convolution bwd is grouped, `weight` should be expanded + auto weightExpanded = isGroupedConvBwd ? expandGroups(weight, 0) : weight; + + // Flip weight along spatial dims only if number of spatial dims > 1. + SmallVector weightFlipDims; + weightFlipDims.reserve(numSpatialDims); + for (int64_t i = 0; i < static_cast(numSpatialDims); ++i) + weightFlipDims.push_back(spatialStartDimIdx + i); + weightExpanded = torch_to_linalg::flipTensor( + rewriter, loc, weightExpanded, weightFlipDims); + + // For backward-input, padding must be adjusted to: + // p'[i] = d[i] * (K[i] - 1) - p[i] + Value c1 = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + SmallVector dilationIntValues = + getAsConstantIntValues(rewriter, loc, dilationInts); + SmallVector weiSizes = + getTensorSizes(rewriter, loc, weightExpanded); + SmallVector paddingValues(numSpatialDims); + for (size_t i = 0; i < numSpatialDims; ++i) { + Value kSize = + castIndexToInt64(rewriter, loc, weiSizes[spatialStartDimIdx + i]); + Value kMinusOne = rewriter.createOrFold(loc, kSize, c1); + Value mul = rewriter.createOrFold(loc, kMinusOne, + dilationIntValues[i]); + paddingValues[i] = + arith::SubIOp::create(rewriter, loc, mul, paddingIntValues[i]); + + if (isValueNegative(paddingValues[i])) + return rewriter.notifyMatchFailure( + op, "unimplemented: negative padding values are not supported."); + } + + // If there are not unit strides, we have to scatter `grad_output` into a + // zero-initialized tensor. + SmallVector gradInputSizes = getTensorSizes(rewriter, loc, input); + Value gradOutputSliced; + if (llvm::any_of(strideInts, [](int64_t stride) { return stride > 1; })) { + // Destination spatial sizes are computed as: + // size[i] = (D[i] - 1) + d[i] * (K[i] - 1) + 1 + // Offsets on spatial dims are paddings + // Strides on spatial dims are the original stride[i]. + Value zero = + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); + Value one = + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); + + // Initialize slice strides, sizes and offsets + SmallVector goSizes = + getTensorSizes(rewriter, loc, gradOutputExpanded); + SmallVector sizes(goSizes.begin(), + goSizes.begin() + spatialStartDimIdx); + SmallVector offsets(spatialStartDimIdx, zero); + SmallVector strides(spatialStartDimIdx, one); + for (size_t i = 0; i < numSpatialDims; ++i) { + // Shapes of `grad_input` has not been expanded yet + // if it's needed for group conv even + Value h = gradInputSizes[2 + i]; + Value k = weiSizes[spatialStartDimIdx + i]; + Value hMinusOne = rewriter.createOrFold(loc, h, one); + Value kMinusOne = rewriter.createOrFold(loc, k, one); + Value mul = rewriter.createOrFold( + loc, castIntToIndex(rewriter, loc, dilationIntValues[i]), + kMinusOne); + Value sum = rewriter.createOrFold(loc, hMinusOne, mul); + sizes.push_back(rewriter.createOrFold(loc, sum, one)); + offsets.push_back(castIntToIndex(rewriter, loc, paddingValues[i])); + + Value strideIntValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(strideInts[i])); + strides.push_back(castIntToIndex(rewriter, loc, strideIntValue)); + } + + Value zeroInit = + createZeroInitTensor(rewriter, loc, sizes, gradOutputDTy); + gradOutputSliced = tensor::InsertSliceOp::create( + rewriter, loc, + torch_to_linalg::removeSizeInformation(rewriter, loc, + gradOutputExpanded), + zeroInit, offsets, goSizes, strides); + } else { + // If there unit strides, pad `grad_output` spatial dims with zeros. + // If conv is grouped, output has shape: + // N x G x F/G x . Otherwise: N x F x . + Value padVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(gradOutputDTy, 0.0)); + gradOutputSliced = torch_to_linalg::getDynamicZeroPaddedTensor( + op, rewriter, gradOutputExpanded, paddingValues, spatialStartDimIdx, + padVal); + } + + // Initialize output buffer. For grouped, compute into an expanded + // [N, G, C/G, D*] tensor and collapse back to the original input shape. + Value gradInputInit = + createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy); + SmallVector gradInputCollapseIndices; + if (isGroupedConvBwd) { + auto gradInputInitExpanded = expandGroups(gradInputInit, 1); + gradInputInit = gradInputInitExpanded.getResult(); + gradInputCollapseIndices = + gradInputInitExpanded.getReassociationIndices(); + } + + // Generate GenericOp. + SmallVector indexingMaps; + SmallVector iteratorTypes; + initIndexingMapsAndIteratorTypesForDataBwd( + rewriter, context, isGroupedConvBwd, numSpatialDims, dilationInts, + indexingMaps, iteratorTypes); + auto genericRes = + createGenericOp(rewriter, loc, gradOutputSliced, weightExpanded, + gradInputInit, indexingMaps, iteratorTypes) + .getResult(0); + + // Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the generic op + // if it is grouped. + if (isGroupedConvBwd) { + genericRes = tensor::CollapseShapeOp::create( + rewriter, loc, input.getType(), genericRes, + gradInputCollapseIndices); + } + + // Cast to the final result type expected by the type converter. + newResults[0] = tensor::CastOp::create(rewriter, loc, + getTypeConverter()->convertType( + op->getResult(0).getType()), + genericRes) + .getResult(); + } + + // Computing Backward-Weight Convolution. + if (outMask[1]) { + // If convolution bwd is grouped, `grad_output` should be expanded. + auto gradOutputExpanded = + isGroupedConvBwd ? expandGroups(gradOutput, 1) : gradOutput; + // If convolution bwd is grouped, `input` should be expanded + auto inputExpanded = isGroupedConvBwd ? expandGroups(input, 1) : input; + + // Pad input spatial dims with zeros. If grouped, input has shape: + // N x G x C/G x . Otherwise: N x C x . + // We should only pad the spatial dims, so set unpaddedDims accordingly. + Value padVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(inputDTy, 0.0)); + Value paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( + op, rewriter, inputExpanded, paddingIntValues, spatialStartDimIdx, + padVal); + + // Initialize output buffer. For grouped, compute into an expanded + // [G, F/G, C/G, K*] tensor and collapse back to the original weight + // shape. + SmallVector gradWeightSizes = + getTensorSizes(rewriter, loc, weight); + Value gradWeightInit = + createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy); + SmallVector gradWeightCollapseIndices; + if (isGroupedConvBwd) { + auto gradWeightInitExpanded = expandGroups(gradWeightInit, 0); + gradWeightInit = gradWeightInitExpanded.getResult(); + gradWeightCollapseIndices = + gradWeightInitExpanded.getReassociationIndices(); + } + + // Generate GenericOp. + SmallVector indexingMaps; + SmallVector iteratorTypes; + initIndexingMapsAndIteratorTypesForWeightBwd( + rewriter, context, isGroupedConvBwd, numSpatialDims, strideInts, + dilationInts, indexingMaps, iteratorTypes); + auto genericRes = + createGenericOp(rewriter, loc, paddedInput, gradOutputExpanded, + gradWeightInit, indexingMaps, iteratorTypes) + .getResult(0); + + // Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the generic op + // if it is grouped. + if (isGroupedConvBwd) { + genericRes = tensor::CollapseShapeOp::create( + rewriter, loc, weight.getType(), genericRes, + gradWeightCollapseIndices); + } + + // Cast to the final result type expected by the type converter. + newResults[1] = tensor::CastOp::create(rewriter, loc, + getTypeConverter()->convertType( + op->getResult(1).getType()), + genericRes) + .getResult(); + } + + // Computing Backward-Bias Convolution. + if (outMask[2]) { + // Sum grad_output along all dims except F using linalg. + DenseSet reduceDims; + reduceDims.insert(0); + for (int64_t i = 2; i < static_cast(gradRank); ++i) + reduceDims.insert(i); + + torch_to_linalg::ReductionOpInfo opInfo{false, gradOutput, reduceDims}; + + // Zero init for the element type (arith.constant expects a scalar attr). + Value initSum = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(gradOutputDTy)); + + auto reductionBody = [&](OpBuilder &b, Location loc, ValueRange args) { + Value x = args[0]; + Value acc = args[1]; + Value sum = arith::AddFOp::create(b, loc, x, acc); + linalg::YieldOp::create(b, loc, sum); + }; + + Value gradBias = torch_to_linalg::createReductionLinalgGeneric( + rewriter, loc, opInfo, initSum, reductionBody); + + newResults[2] = tensor::CastOp::create(rewriter, loc, + getTypeConverter()->convertType( + op->getResult(2).getType()), + gradBias) + .getResult(); + } + + rewriter.replaceOp(op, newResults); + + return success(); + } + +private: + static void initIndexingMapsAndIteratorTypesForDataBwd( + OpBuilder &rewriter, MLIRContext *context, bool isGrouped, + int numSpatialDims, const SmallVector &dilationInts, + SmallVector &indexingMaps, SmallVector &iteratorTypes) { + // To calculate convolution backward-data, we use generic operation. + // The generic operation is a generalization of the convolution operation + // that can handle any number of spatial dimensions. + // The generic operation is defined as follows: + // ``` + // dLdx[n, g, c, o] = sum(dLdy[n, g, f, d0 * k + o] * w[g, f, c, k] + // for n in range(batch_size) for o in range(in_spatial_dims)) + // ``` + // where `n` is the batch dimension, `g` is the group dimension, + // `c` is the input channel dimension, `f` is the output channel + // dimension, `o` is the input spatial dimension, `k` is the kernel + // dimension, `d0` is dilation. `x` is the input tensor, `dLdy` is the + // gradient of the output tensor. `dLdx` is the data-gradient tensor. + if (!isGrouped) { + if (numSpatialDims == 1) { + AffineExpr n, c, o, f, k; + bindDims(context, n, c, o, f, k); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + SmallVector goExprs = {n, f, d0 * k + o}; + SmallVector weiExprs = {f, c, k}; + SmallVector outExprs = {n, c, o}; + indexingMaps = {AffineMap::get(5, 0, goExprs, context), + AffineMap::get(5, 0, weiExprs, context), + AffineMap::get(5, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::reduction, IT::reduction}; + } else if (numSpatialDims == 2) { + AffineExpr n, c, oh, ow, f, kh, kw; + bindDims(context, n, c, oh, ow, f, kh, kw); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + SmallVector goExprs = {n, f, d0 * kh + oh, d1 * kw + ow}; + SmallVector weiExprs = {f, c, kh, kw}; + SmallVector outExprs = {n, c, oh, ow}; + indexingMaps = {AffineMap::get(7, 0, goExprs, context), + AffineMap::get(7, 0, weiExprs, context), + AffineMap::get(7, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::reduction, IT::reduction, + IT::reduction}; + } else { + AffineExpr n, c, od, oh, ow, f, kd, kh, kw; + bindDims(context, n, c, od, oh, ow, f, kd, kh, kw); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); + SmallVector goExprs = {n, f, d0 * kd + od, d1 * kh + oh, + d2 * kw + ow}; + SmallVector weiExprs = {f, c, kd, kh, kw}; + SmallVector outExprs = {n, c, od, oh, ow}; + indexingMaps = {AffineMap::get(9, 0, goExprs, context), + AffineMap::get(9, 0, weiExprs, context), + AffineMap::get(9, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::reduction, + IT::reduction, IT::reduction, IT::reduction}; + } + } else { + if (numSpatialDims == 1) { + AffineExpr n, g, cg, o, fg, k; + bindDims(context, n, g, cg, o, fg, k); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + SmallVector goExprs = {n, g, fg, d0 * k + o}; + SmallVector weiExprs = {g, fg, cg, k}; + SmallVector outExprs = {n, g, cg, o}; + indexingMaps = {AffineMap::get(6, 0, goExprs, context), + AffineMap::get(6, 0, weiExprs, context), + AffineMap::get(6, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::reduction, IT::reduction}; + } else if (numSpatialDims == 2) { + AffineExpr n, g, cg, oh, ow, fg, kh, kw; + bindDims(context, n, g, cg, oh, ow, fg, kh, kw); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + SmallVector goExprs = {n, g, fg, d0 * kh + oh, + d1 * kw + ow}; + SmallVector weiExprs = {g, fg, cg, kh, kw}; + SmallVector outExprs = {n, g, cg, oh, ow}; + indexingMaps = {AffineMap::get(8, 0, goExprs, context), + AffineMap::get(8, 0, weiExprs, context), + AffineMap::get(8, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::reduction, + IT::reduction, IT::reduction}; + } else { + AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw; + bindDims(context, n, g, cg, od, oh, ow, fg, kd, kh, kw); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); + SmallVector goExprs = { + n, g, fg, d0 * kd + od, d1 * kh + oh, d2 * kw + ow}; + SmallVector weiExprs = {g, fg, cg, kd, kh, kw}; + SmallVector outExprs = {n, g, cg, od, oh, ow}; + indexingMaps = {AffineMap::get(10, 0, goExprs, context), + AffineMap::get(10, 0, weiExprs, context), + AffineMap::get(10, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::parallel, + IT::reduction, IT::reduction, IT::reduction, + IT::reduction}; + } + } + } + + static void initIndexingMapsAndIteratorTypesForWeightBwd( + OpBuilder &rewriter, MLIRContext *context, bool isGrouped, + int numSpatialDims, const SmallVector &strideInts, + const SmallVector &dilationInts, + SmallVector &indexingMaps, SmallVector &iteratorTypes) { + // To calculate convolution backward-weight, we use generic operation. + // The generic operation is a generalization of the convolution operation + // that can handle any number of spatial dimensions. + // The generic operation is defined as follows: + // ``` + // dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o] + // for n in range(batch_size) for o in range(output_spatial_dims)) + // ``` + // where `n` is the batch dimension, `g` is the group dimension, + // `c` is the input channel dimension, `f` is the output channel + // dimension, `o` is the output spatial dimension, `k` is the kernel + // dimension, `d0` is dilation and `s0` is stride. `x` is the input + // tensor, `dLdy` is the gradient of the output tensor. `dLdw` is the + // weight-gradient tensor. + if (!isGrouped) { + if (numSpatialDims == 1) { + AffineExpr f, c, k, n, o; + bindDims(context, f, c, k, n, o); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + SmallVector inExprs = {n, c, d0 * k + s0 * o}; + SmallVector goExprs = {n, f, o}; + SmallVector outExprs = {f, c, k}; + indexingMaps = {AffineMap::get(5, 0, inExprs, context), + AffineMap::get(5, 0, goExprs, context), + AffineMap::get(5, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::reduction, IT::reduction}; + } else if (numSpatialDims == 2) { + AffineExpr f, c, kh, kw, n, oh, ow; + bindDims(context, f, c, kh, kw, n, oh, ow); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + SmallVector inExprs = {n, c, d0 * kh + s0 * oh, + d1 * kw + s1 * ow}; + SmallVector goExprs = {n, f, oh, ow}; + SmallVector outExprs = {f, c, kh, kw}; + indexingMaps = {AffineMap::get(7, 0, inExprs, context), + AffineMap::get(7, 0, goExprs, context), + AffineMap::get(7, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::reduction, IT::reduction, + IT::reduction}; + } else { + AffineExpr f, c, kd, kh, kw, n, od, oh, ow; + bindDims(context, f, c, kd, kh, kw, n, od, oh, ow); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); + AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); + SmallVector inExprs = { + n, c, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow}; + SmallVector goExprs = {n, f, od, oh, ow}; + SmallVector outExprs = {f, c, kd, kh, kw}; + indexingMaps = {AffineMap::get(9, 0, inExprs, context), + AffineMap::get(9, 0, goExprs, context), + AffineMap::get(9, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::reduction, + IT::reduction, IT::reduction, IT::reduction}; + } + } else { + if (numSpatialDims == 1) { + AffineExpr g, fg, cg, k, n, o; + bindDims(context, g, fg, cg, k, n, o); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + SmallVector inExprs = {n, g, cg, d0 * k + s0 * o}; + SmallVector goExprs = {n, g, fg, o}; + SmallVector outExprs = {g, fg, cg, k}; + indexingMaps = {AffineMap::get(6, 0, inExprs, context), + AffineMap::get(6, 0, goExprs, context), + AffineMap::get(6, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::reduction, IT::reduction}; + } else if (numSpatialDims == 2) { + AffineExpr g, fg, cg, kh, kw, n, oh, ow; + bindDims(context, g, fg, cg, kh, kw, n, oh, ow); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + SmallVector inExprs = {n, g, cg, d0 * kh + s0 * oh, + d1 * kw + s1 * ow}; + SmallVector goExprs = {n, g, fg, oh, ow}; + SmallVector outExprs = {g, fg, cg, kh, kw}; + indexingMaps = {AffineMap::get(8, 0, inExprs, context), + AffineMap::get(8, 0, goExprs, context), + AffineMap::get(8, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::reduction, + IT::reduction, IT::reduction}; + } else { + AffineExpr g, fg, cg, kd, kh, kw, n, od, oh, ow; + bindDims(context, g, fg, cg, kd, kh, kw, n, od, oh, ow); + AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); + AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); + AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]); + AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); + AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); + AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); + SmallVector inExprs = { + n, g, cg, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow}; + SmallVector goExprs = {n, g, fg, od, oh, ow}; + SmallVector outExprs = {g, fg, cg, kd, kh, kw}; + indexingMaps = {AffineMap::get(10, 0, inExprs, context), + AffineMap::get(10, 0, goExprs, context), + AffineMap::get(10, 0, outExprs, context)}; + iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, + IT::parallel, IT::parallel, IT::parallel, + IT::reduction, IT::reduction, IT::reduction, + IT::reduction}; + } + } + } + + static linalg::GenericOp + createGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, Value out, + const SmallVector &indexingMaps, + const SmallVector &iteratorTypes) { + return linalg::GenericOp::create( + b, loc, out.getType(), ValueRange{in0, in1}, out, indexingMaps, + iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0]; + Value grad = args[1]; + Value output = args[2]; + + // Convert input and grad to accumulator type if needed + Type accType = output.getType(); + if (input.getType() != accType) { + input = arith::ExtFOp::create(b, loc, accType, input); + } + if (grad.getType() != accType) { + grad = arith::ExtFOp::create(b, loc, accType, grad); + } + + Value mul = arith::MulFOp::create(b, loc, input, grad); + Value sum = arith::AddFOp::create(b, loc, mul, output); + linalg::YieldOp::create(b, loc, sum); + }); + } +}; +} // namespace + namespace { /// Creates coefficients based on DFT definition, see @@ -1874,6 +2499,8 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 08b25c9b6f60..13a839946a6b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5836,406 +5836,6 @@ class DecomposeAtenConvTranspose3dOp }; } // namespace -// The convolution backward op is decomposed as follows: -// inputH, inputW = input.shape[2:] -// output_padding_ = [ -// inputH -// - 1 -// + 2 * padding_[0] -// - dilation_[0] * (weight.shape[2] - 1) -// - (grad_output.shape[2] - 1) * stride_[0], -// inputW -// - 1 -// + 2 * padding_[1] -// - dilation_[1] * (weight.shape[3] - 1) -// - (grad_output.shape[3] - 1) * stride_[1], -// ] -// -// decomp_grad_input = torch.nn.functional.conv_transpose2d( -// grad_output, -// weight, -// None, -// stride_, -// padding_, -// output_padding_, -// groups_, -// dilation_, -// ) -// -// input_transposed = torch.ops.aten.transpose(input, 0, 1) -// grad_output_transposed = grad_output.view( -// grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:] -// ) -// decomp_grad_weight = torch.ops.aten.convolution( -// input_transposed, -// grad_output_transposed, -// bias=None, -// stride=dilation_, -// padding=padding_, -// dilation=stride_, -// transposed=False, -// output_padding=[0, 0], -// groups=input.shape[0], -// ) -// decomp_grad_weight = torch.narrow(decomp_grad_weight, 2, 0, weight.shape[2]) -// decomp_grad_weight = torch.narrow(decomp_grad_weight, 3, 0, weight.shape[3]) -// decomp_grad_weight = decomp_grad_weight.view( -// input_transposed.shape[0], -// input_transposed.shape[1], -// grad_output.shape[1], -// *decomp_grad_weight.shape[2:] -// ) -// decomp_grad_weight = decomp_grad_weight.movedim(0, 2) -// decomp_grad_weight = decomp_grad_weight.sum(dim=0) -// -// decomp_grad_bias = torch.sum(grad_output, dim=[0, 2, 3]) -namespace { -class DecomposeAtenConvolutionBackwardOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenConvolutionBackwardOp op, - PatternRewriter &rewriter) const override { - - Location loc = op.getLoc(); - MLIRContext *context = op.getContext(); - Value input = op.getInput(); - Value weight = op.getWeight(); - Value gradOutput = op.getGradOutput(); - std::optional maybeGradRank = getTensorRank(gradOutput); - if (!maybeGradRank) { - return rewriter.notifyMatchFailure(op, - "expected grad output to have a rank"); - } - unsigned gradRank = *maybeGradRank; - if (gradRank != 4) - return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D convolutions supported."); - - Value cstZero = Torch::ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(0)); - Value cstOne = Torch::ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(1)); - Value cstTwo = Torch::ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(2)); - Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); - Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, - rewriter.getBoolAttr(false)); - - SmallVector padding, dilation, stride; - SmallVector paddingInt, dilationInt, strideInt, - outputPaddingInt; - - if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInt))) - return rewriter.notifyMatchFailure( - op, "padding must be a list of constant ints"); - - if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInt))) - return rewriter.notifyMatchFailure( - op, "stride must be a list of constant ints"); - - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInt))) - return rewriter.notifyMatchFailure( - op, "dilation must be a list of constant ints"); - if (!llvm::all_of(dilationInt, - [](int64_t dilationVal) { return dilationVal == 1; })) - return rewriter.notifyMatchFailure( - op, "unimplemented: only dilations of 1 supported."); - - if (!matchPattern(op.getOutputPadding(), - m_TorchListOfConstantInts(outputPaddingInt))) - return rewriter.notifyMatchFailure( - op, "output padding must be a list of constant ints"); - if (!llvm::all_of(outputPaddingInt, - [](int64_t outPad) { return outPad == 0; })) - return rewriter.notifyMatchFailure( - op, "unimplemented: only output padding of 0 supported."); - - // The `outMask` contains 3 boolean values for the results `grad_input`, - // `grad_weight`, and `grad_bias` respectively. The value being `false` - // means that the corresponding result will be none. - SmallVector outMask; - if (!matchPattern(op.getOutputMask(), - m_TorchListOfConstantBools(outMask)) || - outMask.size() != 3) - return rewriter.notifyMatchFailure( - op, "only constant bool output_mask list of size 3 is supported."); - for (unsigned i = 0; i < outMask.size(); i++) { - if (outMask[i] == false) { - Value result = op->getResults()[i]; - if (!result.getUsers().empty()) - return rewriter.notifyMatchFailure( - op, "unimplemented: false value supported for output_mask only " - "when the result tensor corresponding to that has no users."); - } - } - - bool transposed; - if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) - return rewriter.notifyMatchFailure( - op, "transposed arg should be a constant bool."); - if (transposed) - return rewriter.notifyMatchFailure( - op, "unimplemented: transposed convolutions are not supported."); - - Value gradInput = cstNone; - if (outMask[0]) { - // Computing Grad Input. - getListConstructElements(op.getPadding(), padding); - getListConstructElements(op.getStride(), stride); - getListConstructElements(op.getDilation(), dilation); - - // Calculate output padding for first convolution. - // output_padding_ = [ - // inputH - 1 + (2 * padding_[0]) - (dilation_[0] * (weight.size()[2] - // - 1)) - ((grad_out.size()[2] - 1) * stride_[0]), inputW - 1 + (2 * - // padding_[1]) - (dilation_[1] * (weight.size()[3] - 1)) - - // ((grad_out.size()[3] - 1) * stride_[1]), - // ] - SmallVector outputPaddingValues; - for (unsigned i = 2; i < gradRank; i++) { - Value dim = Torch::ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(i)); - Value inputVecDim = - Torch::AtenSizeIntOp::create(rewriter, loc, input, dim); - Value gradOutDim = - Torch::AtenSizeIntOp::create(rewriter, loc, gradOutput, dim); - Value weightDim = - Torch::AtenSizeIntOp::create(rewriter, loc, weight, dim); - Value inputVecDimMinusOne = - Torch::AtenSubIntOp::create(rewriter, loc, inputVecDim, cstOne); - Value gradOutDimMinusOne = - Torch::AtenSubIntOp::create(rewriter, loc, gradOutDim, cstOne); - Value weightDimMinusOne = - Torch::AtenSubIntOp::create(rewriter, loc, weightDim, cstOne); - Value twoTimesPadding = - Torch::AtenMulIntOp::create(rewriter, loc, padding[i - 2], cstTwo); - Value tmpA = Torch::AtenMulIntOp::create( - rewriter, loc, weightDimMinusOne, dilation[i - 2]); - Value tmpB = Torch::AtenMulIntOp::create( - rewriter, loc, gradOutDimMinusOne, stride[i - 2]); - Value outputPaddingVal = AtenAddIntOp::create( - rewriter, loc, inputVecDimMinusOne, twoTimesPadding); - outputPaddingVal = - AtenSubIntOp::create(rewriter, loc, outputPaddingVal, tmpA); - outputPaddingVal = - AtenSubIntOp::create(rewriter, loc, outputPaddingVal, tmpB); - outputPaddingValues.push_back(outputPaddingVal); - } - Value outputPaddingForGradInput = Torch::PrimListConstructOp::create( - rewriter, loc, ListType::get(IntType::get(context)), - outputPaddingValues); - gradInput = Torch::AtenConvTranspose2dInputOp::create( - rewriter, loc, op.getResultTypes()[0], gradOutput, weight, cstNone, - op.getStride(), op.getPadding(), outputPaddingForGradInput, - op.getGroups(), op.getDilation()); - } - - Value gradWeight = cstNone; - if (outMask[1]) { - // Computing Grad Weight. - Type transposedType; - if (failed(getTransposedType(cast(input.getType()), 0, 1, - transposedType))) - return failure(); - Value inputTransposed = Torch::AtenTransposeIntOp::create( - rewriter, loc, transposedType, input, cstZero, cstOne); - - // For the cases where the stride is non-unit, we compute the `GradWeight` - // through this implementation. - if (!llvm::all_of(strideInt, - [](int64_t stride) { return stride == 1; })) { - SmallVector gradOutputSize; - for (unsigned i = 0; i < gradRank; i++) { - gradOutputSize.push_back(Torch::AtenSizeIntOp::create( - rewriter, loc, gradOutput, - Torch::ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(i)))); - } - - Value gradOutputViewDimZero = Torch::AtenMulIntOp::create( - rewriter, loc, gradOutputSize[0], gradOutputSize[1]); - Value gradOutputViewShapeList = Torch::PrimListConstructOp::create( - rewriter, loc, - Torch::ListType::get(Torch::IntType::get(op.getContext())), - ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2], - gradOutputSize[3]}); - - BaseTensorType gradOutputTy = - cast(gradOutput.getType()); - if (!gradOutputTy.hasSizes()) - return failure(); - SmallVector gradOutputSizesInt(gradOutputTy.getSizes()); - SmallVector gradOutputViewSizesInt(gradOutputSizesInt); - if (gradOutputViewSizesInt[0] != kUnknownSize && - gradOutputViewSizesInt[1] != kUnknownSize) - gradOutputViewSizesInt[0] *= gradOutputViewSizesInt[1]; - else - gradOutputViewSizesInt[0] = kUnknownSize; - gradOutputViewSizesInt[1] = 1; - BaseTensorType gradOutputTypeForView = - cast(gradOutputTy.getWithSizesAndDtype( - llvm::ArrayRef(gradOutputViewSizesInt), - gradOutputTy.getOptionalDtype())); - Value gradOutputView = - Torch::AtenViewOp::create(rewriter, loc, gradOutputTypeForView, - gradOutput, gradOutputViewShapeList); - - BaseTensorType inputTransposedTy = - cast(inputTransposed.getType()); - if (!inputTransposedTy.hasSizes()) - return failure(); - SmallVector inputTransposedSizesInt( - inputTransposedTy.getSizes()); - SmallVector gradWeightSizesInt{inputTransposedSizesInt[0], - gradOutputViewSizesInt[0]}; - for (unsigned i = 2; i < gradRank; i++) { - if (inputTransposedSizesInt[i] != kUnknownSize && - gradOutputViewSizesInt[i] != kUnknownSize) { - int64_t kernelSizeInt = - strideInt[i - 2] * (gradOutputViewSizesInt[i] - 1) + 1; - gradWeightSizesInt.push_back( - ((inputTransposedSizesInt[i] + (paddingInt[i - 2] * 2) - - kernelSizeInt) / - dilationInt[i - 2]) + - 1); - } else { - gradWeightSizesInt.push_back(kUnknownSize); - } - } - - BaseTensorType gradWeightTy = - cast(inputTransposedTy.getWithSizesAndDtype( - llvm::ArrayRef(gradWeightSizesInt), - inputTransposedTy.getOptionalDtype())); - - Value numGroup = AtenSizeIntOp::create(rewriter, loc, input, cstZero); - gradWeight = Torch::AtenConvolutionOp::create( - rewriter, loc, gradWeightTy, inputTransposed, gradOutputView, - cstNone, - /*stride=*/op.getDilation(), op.getPadding(), - /*dilation=*/op.getStride(), op.getTransposed(), - op.getOutputPadding(), numGroup); - - BaseTensorType weightTy = cast(weight.getType()); - if (!weightTy.hasSizes()) - return failure(); - SmallVector weightSizes(weightTy.getSizes()); - for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) { - gradWeightSizesInt[i + 2] = weightSizes[i + 2]; - BaseTensorType gradWeightNarrowTy = - cast(gradWeightTy.getWithSizesAndDtype( - llvm::ArrayRef(gradWeightSizesInt), - gradWeightTy.getOptionalDtype())); - - Value dim = ConstantIntOp::create(rewriter, loc, - rewriter.getI64IntegerAttr(i + 2)); - Value length = - Torch::AtenSizeIntOp::create(rewriter, loc, weight, dim); - gradWeight = Torch::AtenNarrowOp::create( - rewriter, loc, gradWeightNarrowTy, gradWeight, dim, - /*start=*/cstZero, length); - } - - SmallVector gradWeightViewShapeInt{ - inputTransposedSizesInt[0], inputTransposedSizesInt[1]}; - gradWeightViewShapeInt.push_back(gradOutputSizesInt[1]); - gradWeightViewShapeInt.insert( - gradWeightViewShapeInt.end(), - {gradWeightSizesInt[2], gradWeightSizesInt[3]}); - - SmallVector gradWeightViewShapeValue; - for (unsigned i = 0; i < gradWeightViewShapeInt.size(); i++) { - gradWeightViewShapeValue.push_back(Torch::ConstantIntOp::create( - rewriter, loc, - rewriter.getI64IntegerAttr(gradWeightViewShapeInt[i]))); - } - - Value gradWeightViewShapeList = Torch::PrimListConstructOp::create( - rewriter, loc, - Torch::ListType::get(Torch::IntType::get(op.getContext())), - gradWeightViewShapeValue); - - BaseTensorType gradWeightTypeForView = - cast(gradWeightTy.getWithSizesAndDtype( - llvm::ArrayRef(gradWeightViewShapeInt), - gradWeightTy.getOptionalDtype())); - gradWeight = - Torch::AtenViewOp::create(rewriter, loc, gradWeightTypeForView, - gradWeight, gradWeightViewShapeList); - - gradWeightTy = cast(gradWeight.getType()); - SmallVector gradWeightDimsOrder = - computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size()); - SmallVector gradWeightMoveDimShape; - for (unsigned i = 0; i < gradWeightDimsOrder.size(); i++) { - gradWeightMoveDimShape.push_back( - gradWeightViewShapeInt[gradWeightDimsOrder[i]]); - } - BaseTensorType gradWeightTypeForMoveDim = - cast(gradWeightTy.getWithSizesAndDtype( - llvm::ArrayRef(gradWeightMoveDimShape), - gradWeightTy.getOptionalDtype())); - - gradWeight = - AtenMovedimIntOp::create(rewriter, loc, gradWeightTypeForMoveDim, - gradWeight, /*source=*/cstZero, - /*destination=*/cstTwo); - - Value gradIntList = Torch::PrimListConstructOp::create( - rewriter, loc, - Torch::ListType::get(Torch::IntType::get(op.getContext())), - llvm::ArrayRef{cstZero}); - gradWeight = Torch::AtenSumDimIntListOp::create( - rewriter, loc, op.getResultTypes()[1], /*self=*/gradWeight, - /*dim=*/gradIntList, - /*keepdim=*/cstFalse, - /*dtype=*/cstNone); - } else { - if (failed(getTransposedType(cast(gradOutput.getType()), - 0, 1, transposedType))) - return failure(); - Value gradOutputTransposed = Torch::AtenTransposeIntOp::create( - rewriter, loc, transposedType, gradOutput, cstZero, cstOne); - // Convolve input with grad_output. - if (failed( - getTransposedType(cast(op.getResultTypes()[1]), - 0, 1, transposedType))) - return failure(); - gradWeight = Torch::AtenConvolutionOp::create( - rewriter, loc, transposedType, inputTransposed, - gradOutputTransposed, cstNone, op.getStride(), op.getPadding(), - op.getDilation(), op.getTransposed(), op.getOutputPadding(), - op.getGroups()); - gradWeight = Torch::AtenTransposeIntOp::create( - rewriter, loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne); - } - } - - Value gradBias = cstNone; - if (outMask[2]) { - // Computing Grad Bias. - SmallVector dimIntList{cstZero}; - for (unsigned i = 2; i < gradRank; i++) - dimIntList.push_back(Torch::ConstantIntOp::create( - rewriter, loc, rewriter.getI64IntegerAttr(i))); - Value gradIntList = Torch::PrimListConstructOp::create( - rewriter, loc, - Torch::ListType::get(Torch::IntType::get(op.getContext())), - dimIntList); - - // Sum grad_output along dim 1. - gradBias = Torch::AtenSumDimIntListOp::create( - rewriter, loc, op.getResultTypes()[2], gradOutput, gradIntList, - cstFalse, cstNone); - } - - rewriter.replaceOp(op, {gradInput, gradWeight, gradBias}); - return success(); - } -}; -} // namespace - /** * # one dim input * t = torch.tensor([0, 0, 1, 1, 0, 0] @@ -13162,7 +12762,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal< DecomposeAten_ConvolutionLikeOp>( patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index b149d172496c..6869ce12ab42 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -445,7 +445,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py index 4f547d531294..095b4db6c5a4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py @@ -34,6 +34,7 @@ "aten.flatten.using_ints", "aten.adaptive_avg_pool1d", "aten.adaptive_avg_pool2d", + "aten.convolution_backward", "aten.unflatten.int", ], OutputType.STABLEHLO: [ diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py index 5e6e093902c4..e6bbbe3273dc 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -228,6 +228,75 @@ def ConvolutionBackwardModule2DStrided_basic(module, tu: TestUtils): module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8), tu.rand(2, 2, 3, 3)) +class ConvolutionBackwardModule2DDilated(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 2, 6, 6], torch.float32, True), + ([1, 4, 8, 8], torch.float32, True), + ([2, 4, 3, 3], torch.float32, True), + ] + ) + def forward(self, grad_out, input_vec, weight): + return torch.ops.aten.convolution_backward( + grad_out, + input_vec, + weight, + bias_sizes=[4], + stride=[1, 1], + padding=[1, 1], + dilation=[2, 2], + transposed=False, + output_padding=[0, 0], + groups=1, + output_mask=[True, True, True], + ) + + +@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DDilated()) +def ConvolutionBackwardModule2DDilated_basic(module, tu: TestUtils): + with torch.backends.mkldnn.flags(enabled=False): + module.forward(tu.rand(1, 2, 6, 6), tu.rand(1, 4, 8, 8), tu.rand(2, 4, 3, 3)) + + +class ConvolutionBackwardModule2DStridedPaddedDilatedGrouped(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 16, 32, 32], torch.float32, True), + ([2, 128, 64, 64], torch.float32, True), + ([16, 32, 2, 2], torch.float32, True), + ] + ) + def forward(self, grad_out, input_vec, weight): + return torch.ops.aten.convolution_backward( + grad_out, + input_vec, + weight, + bias_sizes=[4], + stride=[2, 2], + padding=[2, 2], + dilation=[4, 4], + transposed=False, + output_padding=[0, 0], + groups=4, + output_mask=[True, True, True], + ) + + +@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStridedPaddedDilatedGrouped()) +def ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic(module, tu: TestUtils): + with torch.backends.mkldnn.flags(enabled=False): + module.forward(tu.rand(2, 16, 32, 32), tu.rand(2, 128, 64, 64), tu.rand(16, 32, 2, 2)) + # ============================================================================== diff --git a/test/Conversion/TorchToLinalg/convolution_bwd.mlir b/test/Conversion/TorchToLinalg/convolution_bwd.mlir new file mode 100644 index 000000000000..0e1f5e67dbb8 --- /dev/null +++ b/test/Conversion/TorchToLinalg/convolution_bwd.mlir @@ -0,0 +1,318 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @convolution_backward_input_1x1s_0x0p_1x1d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,63,63],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2,16,63,63],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST1:.*]] = arith.constant 1 : index + // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,128,2,2],f32> -> tensor<16x128x2x2xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,63,63],f32> -> tensor<2x16x63x63xf32> + // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32> + // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32> + // CHECK: %[[W_REV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[T1]] : tensor<16x128x2x2xf32>) outs(%[[W_FILLED]] : tensor<16x128x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32): + // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index + // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index + // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index + // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index + // CHECK-NEXT: %[[R2:.*]] = arith.subi %[[CST1]], %[[I2]] : index + // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index + // CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[T1]][%[[I0]], %[[I1]], %[[R2]], %[[R3]]] : tensor<16x128x2x2xf32> + // CHECK-NEXT: linalg.yield %[[EX]] : f32 + // CHECK-NEXT: } -> tensor<16x128x2x2xf32> + // CHECK: %[[PAD:.*]] = tensor.pad %[[T0]] low[0, 0, 1, 1] high[0, 0, 1, 1] + // CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): + // CHECK: tensor.yield %[[CST0]] : f32 + // CHECK: } : tensor<2x16x63x63xf32> to tensor<2x16x65x65xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x128x64x64xf32>) -> tensor<2x128x64x64xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[W_REV]] : tensor<2x16x65x65xf32>, tensor<16x128x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x128x64x64xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[ACC]] : f32 + // CHECK-NEXT: } -> tensor<2x128x64x64xf32> + // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x63x63xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,63,63],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32> + return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST1:.*]] = arith.constant 1 : index + // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,128,2,2],f32> -> tensor<16x128x2x2xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> tensor<2x16x33x33xf32> + // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32> + // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32> + // CHECK: %[[W_REV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[T1]] : tensor<16x128x2x2xf32>) outs(%[[W_FILLED]] : tensor<16x128x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32): + // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index + // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index + // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index + // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index + // CHECK-NEXT: %[[R2:.*]] = arith.subi %[[CST1]], %[[I2]] : index + // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index + // CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[T1]][%[[I0]], %[[I1]], %[[R2]], %[[R3]]] : tensor<16x128x2x2xf32> + // CHECK-NEXT: linalg.yield %[[EX]] : f32 + // CHECK-NEXT: } -> tensor<16x128x2x2xf32> + // CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x16x66x66xf32> + // CHECK: %[[SLICE_FILLED:.*]] = linalg.fill ins(%cst : f32) outs(%[[SLICE_EMPTY]] : tensor<2x16x66x66xf32>) -> tensor<2x16x66x66xf32> + // CHECK: %[[SLICE:.*]] = tensor.insert_slice %[[T0]] into %[[SLICE_FILLED]][0, 0, 0, 0] [2, 16, 33, 33] [1, 1, 2, 2] : tensor<2x16x33x33xf32> into tensor<2x16x66x66xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x128x64x64xf32>) -> tensor<2x128x64x64xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d5 * 2 + d2, d6 * 2 + d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[SLICE]], %[[W_REV]] : tensor<2x16x66x66xf32>, tensor<16x128x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x128x64x64xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[ACC]] : f32 + // CHECK-NEXT: } -> tensor<2x128x64x64xf32> + // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32> + return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,32,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST1:.*]] = arith.constant 1 : index + // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,32,2,2],f32> -> tensor<16x32x2x2xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> tensor<2x16x33x33xf32> + // CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xf32> into tensor<2x4x4x33x33xf32> + // CHECK: %[[W_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} output_shape [4, 4, 32, 2, 2] : tensor<16x32x2x2xf32> into tensor<4x4x32x2x2xf32> + // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xf32> + // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32> + // CHECK: %[[W_REV:.*]] = linalg.generic {{.*}} ins(%[[W_EXP]] : tensor<4x4x32x2x2xf32>) outs(%[[W_FILLED]] : tensor<4x4x32x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32): + // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index + // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index + // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index + // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index + // CHECK-NEXT: %[[I4:.*]] = linalg.index 4 : index + // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index + // CHECK-NEXT: %[[R4:.*]] = arith.subi %[[CST1]], %[[I4]] : index + // CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[W_EXP]][%[[I0]], %[[I1]], %[[I2]], %[[R3]], %[[R4]]] : tensor<4x4x32x2x2xf32> + // CHECK-NEXT: linalg.yield %[[EX]] : f32 + // CHECK-NEXT: } -> tensor<4x4x32x2x2xf32> + // CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x4x4x66x66xf32> + // CHECK: %[[SLICE_FILLED:.*]] = linalg.fill ins(%cst : f32) outs(%[[SLICE_EMPTY]] : tensor<2x4x4x66x66xf32>) -> tensor<2x4x4x66x66xf32> + // CHECK: %[[SLICE:.*]] = tensor.insert_slice %[[T0_EXP]] into %[[SLICE_FILLED]][0, 0, 0, 0, 0] [2, 4, 4, 33, 33] [1, 1, 1, 2, 2] : tensor<2x4x4x33x33xf32> into tensor<2x4x4x66x66xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32> + // CHECK: %[[OUT_EXP:.*]] = tensor.expand_shape %[[OUT_EMPTY]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 32, 64, 64] : tensor<2x128x64x64xf32> into tensor<2x4x32x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EXP]] : tensor<2x4x32x64x64xf32>) -> tensor<2x4x32x64x64xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {{.*}} ins(%[[SLICE]], %[[W_REV]] : tensor<2x4x4x66x66xf32>, tensor<4x4x32x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x4x32x64x64xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[ACC]] : f32 + // CHECK-NEXT: } -> tensor<2x4x32x64x64xf32> + // CHECK: %[[CONV_COLLAPSED:.*]] = tensor.collapse_shape %[[CONV]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} : tensor<2x4x32x64x64xf32> into tensor<2x128x64x64xf32> + // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV_COLLAPSED]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,32,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32> + return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_weights_1x1s_0x0p_1x1d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,63,63],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_weights_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2,16,63,63],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],f32> -> tensor<2x128x64x64xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,63,63],f32> -> tensor<2x16x63x63xf32> + // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32> + // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2 + d5, d3 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d0, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[T1]], %[[T0]] : tensor<2x128x64x64xf32>, tensor<2x16x63x63xf32>) outs(%[[OUT0_FILLED]] : tensor<16x128x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32 + // CHECK-NEXT: } -> tensor<16x128x2x2xf32> + // CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<16x128x2x2xf32> -> !torch.vtensor<[16,128,2,2],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x63x63xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,63,63],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.none, !torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32> + return %result1, %result2 : !torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,32,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[32,128,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32>) { +func.func @convolution_backward_weights_2x2s_2x2p_2x2d_1g(%arg0: !torch.vtensor<[2,32,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[32,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32>) { + // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],f32> -> tensor<2x128x64x64xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,32,33,33],f32> -> tensor<2x32x33x33xf32> + // CHECK: %[[PAD:.*]] = tensor.pad %[[T1]] low[0, 0, 2, 2] high[0, 0, 2, 2] + // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): + // CHECK-NEXT: tensor.yield %[[CST]] : f32 + // CHECK-NEXT: } : tensor<2x128x64x64xf32> to tensor<2x128x68x68xf32> + // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<32x128x2x2xf32> + // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EMPTY]] : tensor<32x128x2x2xf32>) -> tensor<32x128x2x2xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2 * 2 + d5 * 2, d3 * 2 + d6 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d0, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0]] : tensor<2x128x68x68xf32>, tensor<2x32x33x33xf32>) outs(%[[OUT0_FILLED]] : tensor<32x128x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32 + // CHECK-NEXT: } -> tensor<32x128x2x2xf32> + // CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<32x128x2x2xf32> -> !torch.vtensor<[32,128,2,2],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<32xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SUM_EMPTY]] : tensor<32xf32>) -> tensor<32xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x32x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<32xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<32xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<32xf32> -> !torch.vtensor<[32],f32> + // CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,32,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[32,128,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.none, !torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32> + return %result1, %result2 : !torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,32,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],f32> -> tensor<2x128x64x64xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> tensor<2x16x33x33xf32> + // CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xf32> into tensor<2x4x4x33x33xf32> + // CHECK: %[[T1_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 32, 64, 64] : tensor<2x128x64x64xf32> into tensor<2x4x32x64x64xf32> + // CHECK: %[[PAD:.*]] = tensor.pad %[[T1_EXP]] low[0, 0, 0, 2, 2] high[0, 0, 0, 2, 2] + // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): + // CHECK-NEXT: tensor.yield %[[CST]] : f32 + // CHECK-NEXT: } : tensor<2x4x32x64x64xf32> to tensor<2x4x32x68x68xf32> + // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<16x32x2x2xf32> + // CHECK: %[[OUT0_EXP:.*]] = tensor.expand_shape %[[OUT0_EMPTY]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} output_shape [4, 4, 32, 2, 2] : tensor<16x32x2x2xf32> into tensor<4x4x32x2x2xf32> + // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EXP]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d2, d3 * 2 + d6 * 2, d4 * 2 + d7 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d1, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0_EXP]] : tensor<2x4x32x68x68xf32>, tensor<2x4x4x33x33xf32>) outs(%[[OUT0_FILLED]] : tensor<4x4x32x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32 + // CHECK-NEXT: } -> tensor<4x4x32x2x2xf32> + // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CONV]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} : tensor<4x4x32x2x2xf32> into tensor<16x32x2x2xf32> + // CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<16x32x2x2xf32> -> !torch.vtensor<[16,32,2,2],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,32,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.none, !torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32> + return %result1, %result2 : !torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32> +} + +// ----- diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 8fe502a7d686..6c1d051b570b 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -273,40 +273,6 @@ func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int { return %arg0 : !torch.int } -// ----- - -// CHECK-LABEL: func.func @convolution_backward_none_result( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,3,3],f32>, %[[VAL_1:.*]]: !torch.vtensor<[1,1,5,5],f32>, -// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1,1,3,3],f32>, -// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>) { -func.func @convolution_backward_none_result(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,3,3],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>) { - // CHECK: %[[VAL_4:.*]] = torch.constant.int 3 - // CHECK: %[[VAL_5:.*]] = torch.constant.int 2 - // CHECK: %[[VAL_6:.*]] = torch.constant.none - // CHECK: %[[VAL_7:.*]] = torch.constant.int 0 - // CHECK: %[[VAL_8:.*]] = torch.constant.bool false - // CHECK: %[[VAL_9:.*]] = torch.constant.int 1 - // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_9]], %[[VAL_9]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[VAL_12:.*]] = torch.aten.transpose.int %[[VAL_1]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,5,5],f32> - // CHECK: %[[VAL_13:.*]] = torch.aten.transpose.int %[[VAL_0]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3,3],f32> - // CHECK: %[[VAL_14:.*]] = torch.aten.convolution %[[VAL_12]], %[[VAL_13]], %[[VAL_6]], %[[VAL_10]], %[[VAL_11]], %[[VAL_10]], %[[VAL_8]], %[[VAL_11]], %[[VAL_9]] : !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,3,3],f32> - // CHECK: %[[VAL_15:.*]] = torch.aten.transpose.int %[[VAL_14]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3,3],f32> - // CHECK: %[[VAL_16:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_5]], %[[VAL_4]] : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[VAL_17:.*]] = torch.aten.sum.dim_IntList %[[VAL_0]], %[[VAL_16]], %[[VAL_8]], %[[VAL_6]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1],f32> - // CHECK: return %[[VAL_15]], %[[VAL_17]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32> - %true = torch.constant.bool true - %int0 = torch.constant.int 0 - %false = torch.constant.bool false - %int1 = torch.constant.int 1 - %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list - %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list - %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list - %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list - %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.none, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32> - return %result1, %result2 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32> -} - // ----- // CHECK-LABEL: func.func @emptyLikeNoneDtype( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> { From f50c2912c6b766539253f42838dae3bfc3546eb8 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Mon, 24 Nov 2025 09:35:13 +0400 Subject: [PATCH 2/3] Applied Zach's comments --- lib/Conversion/TorchToLinalg/Linear.cpp | 500 ++++++++---------- .../Torch/Transforms/DecomposeComplexOps.cpp | 401 ++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + .../test_suite/backprop.py | 9 +- .../TorchToLinalg/convolution_bwd.mlir | 51 +- test/Dialect/Torch/decompose-complex-ops.mlir | 34 ++ 6 files changed, 713 insertions(+), 283 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 31b8387be3f2..60093b54f8e1 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1800,6 +1800,38 @@ class ConvertAtenConvolutionBackwardOp return tensor::ExpandShapeOp::create(rewriter, loc, retType, tensor, indices); }; + // The createZeroInitExpandedGroupsTensor lambda function below is used to + // create empty tensor with already expanded group dimension. + auto createZeroInitExpandedGroupsTensor = + [&](OpBuilder &rewriter, Location loc, const SmallVector &sizes, + Type type, int64_t dim, + SmallVector &indices) { + Value groups = + mlir::arith::ConstantIndexOp::create(rewriter, loc, numGroups); + + SmallVector expandedSizes; + for (auto i = 0; i < static_cast(sizes.size()); i++) { + if (i == dim) { + expandedSizes.push_back(groups); + expandedSizes.push_back( + rewriter.createOrFold(loc, sizes[i], + groups)); + } else { + expandedSizes.push_back(sizes[i]); + } + } + + indices.clear(); + for (auto i = 0; i <= static_cast(sizes.size()); i++) { + if (i == dim) { + indices.push_back({i, ++i}); + continue; + } + indices.push_back({i}); + } + + return createZeroInitTensor(rewriter, loc, expandedSizes, type); + }; SmallVector newResults(op->getNumResults()); @@ -1811,13 +1843,22 @@ class ConvertAtenConvolutionBackwardOp // If convolution bwd is grouped, `weight` should be expanded auto weightExpanded = isGroupedConvBwd ? expandGroups(weight, 0) : weight; - // Flip weight along spatial dims only if number of spatial dims > 1. - SmallVector weightFlipDims; - weightFlipDims.reserve(numSpatialDims); - for (int64_t i = 0; i < static_cast(numSpatialDims); ++i) - weightFlipDims.push_back(spatialStartDimIdx + i); - weightExpanded = torch_to_linalg::flipTensor( - rewriter, loc, weightExpanded, weightFlipDims); + // Flip weight along spatial dims only if + // - kernel size is greater than 1, + // - the kernel is not a 1x1 or 1x1x1 kernel. + SmallVector weightDimsInt = makeShapeTorchCompatible( + cast(weightExpanded.getType()).getShape()); + bool is1x1Kernel = std::all_of(weightDimsInt.rbegin(), + weightDimsInt.rbegin() + numSpatialDims, + [](int64_t dim) { return dim == 1; }); + if (numSpatialDims > 1 && !is1x1Kernel) { + SmallVector weightFlipDims; + weightFlipDims.reserve(numSpatialDims); + for (int64_t i = 0; i < static_cast(numSpatialDims); ++i) + weightFlipDims.push_back(spatialStartDimIdx + i); + weightExpanded = torch_to_linalg::flipTensor( + rewriter, loc, weightExpanded, weightFlipDims); + } // For backward-input, padding must be adjusted to: // p'[i] = d[i] * (K[i] - 1) - p[i] @@ -1845,7 +1886,7 @@ class ConvertAtenConvolutionBackwardOp // If there are not unit strides, we have to scatter `grad_output` into a // zero-initialized tensor. SmallVector gradInputSizes = getTensorSizes(rewriter, loc, input); - Value gradOutputSliced; + Value gradOutputModified; if (llvm::any_of(strideInts, [](int64_t stride) { return stride > 1; })) { // Destination spatial sizes are computed as: // size[i] = (D[i] - 1) + d[i] * (K[i] - 1) + 1 @@ -1884,7 +1925,7 @@ class ConvertAtenConvolutionBackwardOp Value zeroInit = createZeroInitTensor(rewriter, loc, sizes, gradOutputDTy); - gradOutputSliced = tensor::InsertSliceOp::create( + gradOutputModified = tensor::InsertSliceOp::create( rewriter, loc, torch_to_linalg::removeSizeInformation(rewriter, loc, gradOutputExpanded), @@ -1895,47 +1936,40 @@ class ConvertAtenConvolutionBackwardOp // N x G x F/G x . Otherwise: N x F x . Value padVal = arith::ConstantOp::create( rewriter, loc, rewriter.getFloatAttr(gradOutputDTy, 0.0)); - gradOutputSliced = torch_to_linalg::getDynamicZeroPaddedTensor( + gradOutputModified = torch_to_linalg::getDynamicZeroPaddedTensor( op, rewriter, gradOutputExpanded, paddingValues, spatialStartDimIdx, padVal); } // Initialize output buffer. For grouped, compute into an expanded // [N, G, C/G, D*] tensor and collapse back to the original input shape. - Value gradInputInit = - createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy); SmallVector gradInputCollapseIndices; - if (isGroupedConvBwd) { - auto gradInputInitExpanded = expandGroups(gradInputInit, 1); - gradInputInit = gradInputInitExpanded.getResult(); - gradInputCollapseIndices = - gradInputInitExpanded.getReassociationIndices(); - } - - // Generate GenericOp. - SmallVector indexingMaps; - SmallVector iteratorTypes; - initIndexingMapsAndIteratorTypesForDataBwd( - rewriter, context, isGroupedConvBwd, numSpatialDims, dilationInts, - indexingMaps, iteratorTypes); - auto genericRes = - createGenericOp(rewriter, loc, gradOutputSliced, weightExpanded, - gradInputInit, indexingMaps, iteratorTypes) - .getResult(0); - - // Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the generic op + Value gradInputInit = + isGroupedConvBwd + ? createZeroInitExpandedGroupsTensor(rewriter, loc, + gradInputSizes, inputDTy, 1, + gradInputCollapseIndices) + : createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy); + + // Create convolution for data gradient + auto convRes = createConvInputGradient(rewriter, loc, context, + isGroupedConvBwd, numSpatialDims, + dilationInts, gradOutputModified, + weightExpanded, gradInputInit) + .getResult(0); + + // Collapse [N, G, C/G, D] to [N, C, D] the result of the conv // if it is grouped. if (isGroupedConvBwd) { - genericRes = tensor::CollapseShapeOp::create( - rewriter, loc, input.getType(), genericRes, - gradInputCollapseIndices); + convRes = tensor::CollapseShapeOp::create( + rewriter, loc, input.getType(), convRes, gradInputCollapseIndices); } // Cast to the final result type expected by the type converter. newResults[0] = tensor::CastOp::create(rewriter, loc, getTypeConverter()->convertType( op->getResult(0).getType()), - genericRes) + convRes) .getResult(); } @@ -1961,32 +1995,26 @@ class ConvertAtenConvolutionBackwardOp // shape. SmallVector gradWeightSizes = getTensorSizes(rewriter, loc, weight); - Value gradWeightInit = - createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy); SmallVector gradWeightCollapseIndices; - if (isGroupedConvBwd) { - auto gradWeightInitExpanded = expandGroups(gradWeightInit, 0); - gradWeightInit = gradWeightInitExpanded.getResult(); - gradWeightCollapseIndices = - gradWeightInitExpanded.getReassociationIndices(); - } - - // Generate GenericOp. - SmallVector indexingMaps; - SmallVector iteratorTypes; - initIndexingMapsAndIteratorTypesForWeightBwd( - rewriter, context, isGroupedConvBwd, numSpatialDims, strideInts, - dilationInts, indexingMaps, iteratorTypes); - auto genericRes = - createGenericOp(rewriter, loc, paddedInput, gradOutputExpanded, - gradWeightInit, indexingMaps, iteratorTypes) - .getResult(0); + Value gradWeightInit = + isGroupedConvBwd + ? createZeroInitExpandedGroupsTensor(rewriter, loc, + gradWeightSizes, weightDTy, + 0, gradWeightCollapseIndices) + : createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy); + + // Create convolution for weight gradient + auto convResult = createConvWeightGradient( + rewriter, loc, context, isGroupedConvBwd, + numSpatialDims, strideInts, dilationInts, + paddedInput, gradOutputExpanded, gradWeightInit) + .getResult(0); - // Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the generic op + // Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the conv // if it is grouped. if (isGroupedConvBwd) { - genericRes = tensor::CollapseShapeOp::create( - rewriter, loc, weight.getType(), genericRes, + convResult = tensor::CollapseShapeOp::create( + rewriter, loc, weight.getType(), convResult, gradWeightCollapseIndices); } @@ -1994,7 +2022,7 @@ class ConvertAtenConvolutionBackwardOp newResults[1] = tensor::CastOp::create(rewriter, loc, getTypeConverter()->convertType( op->getResult(1).getType()), - genericRes) + convResult) .getResult(); } @@ -2035,121 +2063,90 @@ class ConvertAtenConvolutionBackwardOp } private: - static void initIndexingMapsAndIteratorTypesForDataBwd( - OpBuilder &rewriter, MLIRContext *context, bool isGrouped, - int numSpatialDims, const SmallVector &dilationInts, - SmallVector &indexingMaps, SmallVector &iteratorTypes) { + static linalg::GenericOp createConvInputGradient( + OpBuilder &rewriter, Location loc, MLIRContext *context, bool isGrouped, + size_t numSpatialDims, const SmallVector &dilationInts, + Value gradOutput, Value weight, Value gradInputInit) { // To calculate convolution backward-data, we use generic operation. // The generic operation is a generalization of the convolution operation // that can handle any number of spatial dimensions. // The generic operation is defined as follows: // ``` - // dLdx[n, g, c, o] = sum(dLdy[n, g, f, d0 * k + o] * w[g, f, c, k] + // dLdx[n, g, c, o] = sum(dLdy[n, g, f, d * k + o] * w[g, f, c, k] // for n in range(batch_size) for o in range(in_spatial_dims)) // ``` - // where `n` is the batch dimension, `g` is the group dimension, - // `c` is the input channel dimension, `f` is the output channel - // dimension, `o` is the input spatial dimension, `k` is the kernel - // dimension, `d0` is dilation. `x` is the input tensor, `dLdy` is the - // gradient of the output tensor. `dLdx` is the data-gradient tensor. - if (!isGrouped) { - if (numSpatialDims == 1) { - AffineExpr n, c, o, f, k; - bindDims(context, n, c, o, f, k); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - SmallVector goExprs = {n, f, d0 * k + o}; - SmallVector weiExprs = {f, c, k}; - SmallVector outExprs = {n, c, o}; - indexingMaps = {AffineMap::get(5, 0, goExprs, context), - AffineMap::get(5, 0, weiExprs, context), - AffineMap::get(5, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::reduction, IT::reduction}; - } else if (numSpatialDims == 2) { - AffineExpr n, c, oh, ow, f, kh, kw; - bindDims(context, n, c, oh, ow, f, kh, kw); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); - SmallVector goExprs = {n, f, d0 * kh + oh, d1 * kw + ow}; - SmallVector weiExprs = {f, c, kh, kw}; - SmallVector outExprs = {n, c, oh, ow}; - indexingMaps = {AffineMap::get(7, 0, goExprs, context), - AffineMap::get(7, 0, weiExprs, context), - AffineMap::get(7, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::parallel, IT::reduction, IT::reduction, - IT::reduction}; - } else { - AffineExpr n, c, od, oh, ow, f, kd, kh, kw; - bindDims(context, n, c, od, oh, ow, f, kd, kh, kw); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); - AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); - SmallVector goExprs = {n, f, d0 * kd + od, d1 * kh + oh, - d2 * kw + ow}; - SmallVector weiExprs = {f, c, kd, kh, kw}; - SmallVector outExprs = {n, c, od, oh, ow}; - indexingMaps = {AffineMap::get(9, 0, goExprs, context), - AffineMap::get(9, 0, weiExprs, context), - AffineMap::get(9, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::parallel, IT::parallel, IT::reduction, - IT::reduction, IT::reduction, IT::reduction}; - } - } else { - if (numSpatialDims == 1) { - AffineExpr n, g, cg, o, fg, k; - bindDims(context, n, g, cg, o, fg, k); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - SmallVector goExprs = {n, g, fg, d0 * k + o}; - SmallVector weiExprs = {g, fg, cg, k}; - SmallVector outExprs = {n, g, cg, o}; - indexingMaps = {AffineMap::get(6, 0, goExprs, context), - AffineMap::get(6, 0, weiExprs, context), - AffineMap::get(6, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::parallel, IT::reduction, IT::reduction}; - } else if (numSpatialDims == 2) { - AffineExpr n, g, cg, oh, ow, fg, kh, kw; - bindDims(context, n, g, cg, oh, ow, fg, kh, kw); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); - SmallVector goExprs = {n, g, fg, d0 * kh + oh, - d1 * kw + ow}; - SmallVector weiExprs = {g, fg, cg, kh, kw}; - SmallVector outExprs = {n, g, cg, oh, ow}; - indexingMaps = {AffineMap::get(8, 0, goExprs, context), - AffineMap::get(8, 0, weiExprs, context), - AffineMap::get(8, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::parallel, IT::parallel, IT::reduction, - IT::reduction, IT::reduction}; - } else { - AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw; - bindDims(context, n, g, cg, od, oh, ow, fg, kd, kh, kw); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); - AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); - SmallVector goExprs = { - n, g, fg, d0 * kd + od, d1 * kh + oh, d2 * kw + ow}; - SmallVector weiExprs = {g, fg, cg, kd, kh, kw}; - SmallVector outExprs = {n, g, cg, od, oh, ow}; - indexingMaps = {AffineMap::get(10, 0, goExprs, context), - AffineMap::get(10, 0, weiExprs, context), - AffineMap::get(10, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::parallel, IT::parallel, IT::parallel, - IT::reduction, IT::reduction, IT::reduction, - IT::reduction}; - } + // where: + // - `dLdx` is the data-gradient tensor. + // - `dLdy` is the output-gradient tensor which is padded if + // there are unit strides, or scattered otherwise. + // - `w` is the weight tensor flipped along spatial dims. + // - `n` is the batch dimension. + // - `g` is the group dimension. + // - `c` is the input channel dimension. + // - `f` is the output channel dimension. + // - `o` is the input spatial dimension. + // - `k` is the kernel dimension. + // - `d` is dilations. + + // Iterators: n, c, f, g, o, k + int64_t numIterators = + 3 + static_cast(isGrouped) + numSpatialDims * 2; + + // Bind dimensions in the following order: n, g, c, o, f, k + SmallVector dims(numIterators); + bindDimsList(context, MutableArrayRef{dims}); + + auto n = [&]() { return dims[0]; }; + auto g = [&]() { + if (!isGrouped) + llvm_unreachable("g() called for non-grouped convolution."); + return dims[1]; + }; + auto c = [&]() { return dims[1 + static_cast(isGrouped)]; }; + auto o = [&](size_t i) { + return dims[1 + static_cast(isGrouped) + 1 + i]; + }; + auto f = [&]() { + return dims[1 + static_cast(isGrouped) + 1 + numSpatialDims]; + }; + auto k = [&](size_t i) { + return dims[1 + static_cast(isGrouped) + 1 + numSpatialDims + 1 + + i]; + }; + + SmallVector lhsExprs = + isGrouped ? SmallVector{n(), g(), f()} + : SmallVector{n(), f()}; + SmallVector rhsExprs = + isGrouped ? SmallVector{g(), f(), c()} + : SmallVector{f(), c()}; + SmallVector outExprs = + isGrouped ? SmallVector{n(), g(), c()} + : SmallVector{n(), c()}; + for (size_t i = 0; i < numSpatialDims; i++) { + AffineExpr d = rewriter.getAffineConstantExpr(dilationInts[i]); + lhsExprs.push_back(d * k(i) + o(i)); + rhsExprs.push_back(k(i)); + outExprs.push_back(o(i)); } + + SmallVector indexingMaps = { + AffineMap::get(numIterators, 0, lhsExprs, context), + AffineMap::get(numIterators, 0, rhsExprs, context), + AffineMap::get(numIterators, 0, outExprs, context)}; + SmallVector iteratorTypes = SmallVector(numIterators, IT::parallel); + std::fill(iteratorTypes.rbegin(), + iteratorTypes.rbegin() + (numSpatialDims + 1), IT::reduction); + + return createConvAsGenericOp(rewriter, loc, gradOutput, weight, + gradInputInit, indexingMaps, iteratorTypes); } - static void initIndexingMapsAndIteratorTypesForWeightBwd( - OpBuilder &rewriter, MLIRContext *context, bool isGrouped, - int numSpatialDims, const SmallVector &strideInts, - const SmallVector &dilationInts, - SmallVector &indexingMaps, SmallVector &iteratorTypes) { + static linalg::GenericOp createConvWeightGradient( + OpBuilder &rewriter, Location loc, MLIRContext *context, bool isGrouped, + size_t numSpatialDims, const SmallVector &strideInts, + const SmallVector &dilationInts, Value input, Value gradOutput, + Value gradWeightInit) { // To calculate convolution backward-weight, we use generic operation. // The generic operation is a generalization of the convolution operation // that can handle any number of spatial dimensions. @@ -2158,122 +2155,75 @@ class ConvertAtenConvolutionBackwardOp // dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o] // for n in range(batch_size) for o in range(output_spatial_dims)) // ``` - // where `n` is the batch dimension, `g` is the group dimension, - // `c` is the input channel dimension, `f` is the output channel - // dimension, `o` is the output spatial dimension, `k` is the kernel - // dimension, `d0` is dilation and `s0` is stride. `x` is the input - // tensor, `dLdy` is the gradient of the output tensor. `dLdw` is the - // weight-gradient tensor. - if (!isGrouped) { - if (numSpatialDims == 1) { - AffineExpr f, c, k, n, o; - bindDims(context, f, c, k, n, o); - AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - SmallVector inExprs = {n, c, d0 * k + s0 * o}; - SmallVector goExprs = {n, f, o}; - SmallVector outExprs = {f, c, k}; - indexingMaps = {AffineMap::get(5, 0, inExprs, context), - AffineMap::get(5, 0, goExprs, context), - AffineMap::get(5, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::reduction, IT::reduction}; - } else if (numSpatialDims == 2) { - AffineExpr f, c, kh, kw, n, oh, ow; - bindDims(context, f, c, kh, kw, n, oh, ow); - AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); - AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); - SmallVector inExprs = {n, c, d0 * kh + s0 * oh, - d1 * kw + s1 * ow}; - SmallVector goExprs = {n, f, oh, ow}; - SmallVector outExprs = {f, c, kh, kw}; - indexingMaps = {AffineMap::get(7, 0, inExprs, context), - AffineMap::get(7, 0, goExprs, context), - AffineMap::get(7, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::parallel, IT::reduction, IT::reduction, - IT::reduction}; - } else { - AffineExpr f, c, kd, kh, kw, n, od, oh, ow; - bindDims(context, f, c, kd, kh, kw, n, od, oh, ow); - AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); - AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); - AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); - AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); - SmallVector inExprs = { - n, c, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow}; - SmallVector goExprs = {n, f, od, oh, ow}; - SmallVector outExprs = {f, c, kd, kh, kw}; - indexingMaps = {AffineMap::get(9, 0, inExprs, context), - AffineMap::get(9, 0, goExprs, context), - AffineMap::get(9, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::parallel, IT::parallel, IT::reduction, - IT::reduction, IT::reduction, IT::reduction}; - } - } else { - if (numSpatialDims == 1) { - AffineExpr g, fg, cg, k, n, o; - bindDims(context, g, fg, cg, k, n, o); - AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - SmallVector inExprs = {n, g, cg, d0 * k + s0 * o}; - SmallVector goExprs = {n, g, fg, o}; - SmallVector outExprs = {g, fg, cg, k}; - indexingMaps = {AffineMap::get(6, 0, inExprs, context), - AffineMap::get(6, 0, goExprs, context), - AffineMap::get(6, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::parallel, IT::reduction, IT::reduction}; - } else if (numSpatialDims == 2) { - AffineExpr g, fg, cg, kh, kw, n, oh, ow; - bindDims(context, g, fg, cg, kh, kw, n, oh, ow); - AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); - AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); - SmallVector inExprs = {n, g, cg, d0 * kh + s0 * oh, - d1 * kw + s1 * ow}; - SmallVector goExprs = {n, g, fg, oh, ow}; - SmallVector outExprs = {g, fg, cg, kh, kw}; - indexingMaps = {AffineMap::get(8, 0, inExprs, context), - AffineMap::get(8, 0, goExprs, context), - AffineMap::get(8, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::parallel, IT::parallel, IT::reduction, - IT::reduction, IT::reduction}; - } else { - AffineExpr g, fg, cg, kd, kh, kw, n, od, oh, ow; - bindDims(context, g, fg, cg, kd, kh, kw, n, od, oh, ow); - AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]); - AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]); - AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]); - AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]); - AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]); - AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]); - SmallVector inExprs = { - n, g, cg, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow}; - SmallVector goExprs = {n, g, fg, od, oh, ow}; - SmallVector outExprs = {g, fg, cg, kd, kh, kw}; - indexingMaps = {AffineMap::get(10, 0, inExprs, context), - AffineMap::get(10, 0, goExprs, context), - AffineMap::get(10, 0, outExprs, context)}; - iteratorTypes = {IT::parallel, IT::parallel, IT::parallel, - IT::parallel, IT::parallel, IT::parallel, - IT::reduction, IT::reduction, IT::reduction, - IT::reduction}; - } + // - `dLdw` is the weight-gradient tensor. + // - `x` is the padded input tensor. + // - `dLdy` is the output-gradient tensor. + // - `n` is the batch dimension. + // - `g` is the group dimension. + // - `c` is the input channel dimension. + // - `f` is the output channel dimension. + // - `o` is the input spatial dimension. + // - `k` is the kernel dimension. + // - `d` and `s` are dilations and strides accordingly. + + // Iterators: n, c, f, g, o, k + int64_t numIterators = + 3 + static_cast(isGrouped) + numSpatialDims * 2; + + // Bind dimensions in the following order: g, f, c, k, n, o + SmallVector dims(numIterators); + bindDimsList(context, MutableArrayRef{dims}); + + auto g = [&]() { + if (!isGrouped) + llvm_unreachable("g() called for non-grouped convolution."); + return dims[0]; + }; + auto f = [&]() { return dims[static_cast(isGrouped)]; }; + auto c = [&]() { return dims[static_cast(isGrouped) + 1]; }; + auto k = [&](size_t i) { + return dims[static_cast(isGrouped) + 2 + i]; + }; + auto n = [&]() { + return dims[static_cast(isGrouped) + 2 + numSpatialDims]; + }; + auto o = [&](size_t i) { + return dims[static_cast(isGrouped) + 2 + numSpatialDims + 1 + i]; + }; + + SmallVector lhsExprs = + isGrouped ? SmallVector{n(), g(), c()} + : SmallVector{n(), c()}; + SmallVector rhsExprs = + isGrouped ? SmallVector{n(), g(), f()} + : SmallVector{n(), f()}; + SmallVector outExprs = + isGrouped ? SmallVector{g(), f(), c()} + : SmallVector{f(), c()}; + for (size_t i = 0; i < numSpatialDims; i++) { + AffineExpr d = rewriter.getAffineConstantExpr(dilationInts[i]); + AffineExpr s = rewriter.getAffineConstantExpr(strideInts[i]); + lhsExprs.push_back(d * k(i) + s * o(i)); + rhsExprs.push_back(o(i)); + outExprs.push_back(k(i)); } + + SmallVector indexingMaps = { + AffineMap::get(numIterators, 0, lhsExprs, context), + AffineMap::get(numIterators, 0, rhsExprs, context), + AffineMap::get(numIterators, 0, outExprs, context)}; + SmallVector iteratorTypes = SmallVector(numIterators, IT::parallel); + std::fill(iteratorTypes.rbegin(), + iteratorTypes.rbegin() + (numSpatialDims + 1), IT::reduction); + + return createConvAsGenericOp(rewriter, loc, input, gradOutput, + gradWeightInit, indexingMaps, iteratorTypes); } static linalg::GenericOp - createGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, Value out, - const SmallVector &indexingMaps, - const SmallVector &iteratorTypes) { + createConvAsGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, + Value out, const SmallVector &indexingMaps, + const SmallVector &iteratorTypes) { return linalg::GenericOp::create( b, loc, out.getType(), ValueRange{in0, in1}, out, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 13a839946a6b..08b25c9b6f60 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5836,6 +5836,406 @@ class DecomposeAtenConvTranspose3dOp }; } // namespace +// The convolution backward op is decomposed as follows: +// inputH, inputW = input.shape[2:] +// output_padding_ = [ +// inputH +// - 1 +// + 2 * padding_[0] +// - dilation_[0] * (weight.shape[2] - 1) +// - (grad_output.shape[2] - 1) * stride_[0], +// inputW +// - 1 +// + 2 * padding_[1] +// - dilation_[1] * (weight.shape[3] - 1) +// - (grad_output.shape[3] - 1) * stride_[1], +// ] +// +// decomp_grad_input = torch.nn.functional.conv_transpose2d( +// grad_output, +// weight, +// None, +// stride_, +// padding_, +// output_padding_, +// groups_, +// dilation_, +// ) +// +// input_transposed = torch.ops.aten.transpose(input, 0, 1) +// grad_output_transposed = grad_output.view( +// grad_output.shape[0] * grad_output.shape[1], 1, *grad_output.shape[2:] +// ) +// decomp_grad_weight = torch.ops.aten.convolution( +// input_transposed, +// grad_output_transposed, +// bias=None, +// stride=dilation_, +// padding=padding_, +// dilation=stride_, +// transposed=False, +// output_padding=[0, 0], +// groups=input.shape[0], +// ) +// decomp_grad_weight = torch.narrow(decomp_grad_weight, 2, 0, weight.shape[2]) +// decomp_grad_weight = torch.narrow(decomp_grad_weight, 3, 0, weight.shape[3]) +// decomp_grad_weight = decomp_grad_weight.view( +// input_transposed.shape[0], +// input_transposed.shape[1], +// grad_output.shape[1], +// *decomp_grad_weight.shape[2:] +// ) +// decomp_grad_weight = decomp_grad_weight.movedim(0, 2) +// decomp_grad_weight = decomp_grad_weight.sum(dim=0) +// +// decomp_grad_bias = torch.sum(grad_output, dim=[0, 2, 3]) +namespace { +class DecomposeAtenConvolutionBackwardOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConvolutionBackwardOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + MLIRContext *context = op.getContext(); + Value input = op.getInput(); + Value weight = op.getWeight(); + Value gradOutput = op.getGradOutput(); + std::optional maybeGradRank = getTensorRank(gradOutput); + if (!maybeGradRank) { + return rewriter.notifyMatchFailure(op, + "expected grad output to have a rank"); + } + unsigned gradRank = *maybeGradRank; + if (gradRank != 4) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D convolutions supported."); + + Value cstZero = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); + Value cstOne = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + Value cstTwo = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(2)); + Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, + rewriter.getBoolAttr(false)); + + SmallVector padding, dilation, stride; + SmallVector paddingInt, dilationInt, strideInt, + outputPaddingInt; + + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInt))) + return rewriter.notifyMatchFailure( + op, "padding must be a list of constant ints"); + + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInt))) + return rewriter.notifyMatchFailure( + op, "stride must be a list of constant ints"); + + if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInt))) + return rewriter.notifyMatchFailure( + op, "dilation must be a list of constant ints"); + if (!llvm::all_of(dilationInt, + [](int64_t dilationVal) { return dilationVal == 1; })) + return rewriter.notifyMatchFailure( + op, "unimplemented: only dilations of 1 supported."); + + if (!matchPattern(op.getOutputPadding(), + m_TorchListOfConstantInts(outputPaddingInt))) + return rewriter.notifyMatchFailure( + op, "output padding must be a list of constant ints"); + if (!llvm::all_of(outputPaddingInt, + [](int64_t outPad) { return outPad == 0; })) + return rewriter.notifyMatchFailure( + op, "unimplemented: only output padding of 0 supported."); + + // The `outMask` contains 3 boolean values for the results `grad_input`, + // `grad_weight`, and `grad_bias` respectively. The value being `false` + // means that the corresponding result will be none. + SmallVector outMask; + if (!matchPattern(op.getOutputMask(), + m_TorchListOfConstantBools(outMask)) || + outMask.size() != 3) + return rewriter.notifyMatchFailure( + op, "only constant bool output_mask list of size 3 is supported."); + for (unsigned i = 0; i < outMask.size(); i++) { + if (outMask[i] == false) { + Value result = op->getResults()[i]; + if (!result.getUsers().empty()) + return rewriter.notifyMatchFailure( + op, "unimplemented: false value supported for output_mask only " + "when the result tensor corresponding to that has no users."); + } + } + + bool transposed; + if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) + return rewriter.notifyMatchFailure( + op, "transposed arg should be a constant bool."); + if (transposed) + return rewriter.notifyMatchFailure( + op, "unimplemented: transposed convolutions are not supported."); + + Value gradInput = cstNone; + if (outMask[0]) { + // Computing Grad Input. + getListConstructElements(op.getPadding(), padding); + getListConstructElements(op.getStride(), stride); + getListConstructElements(op.getDilation(), dilation); + + // Calculate output padding for first convolution. + // output_padding_ = [ + // inputH - 1 + (2 * padding_[0]) - (dilation_[0] * (weight.size()[2] + // - 1)) - ((grad_out.size()[2] - 1) * stride_[0]), inputW - 1 + (2 * + // padding_[1]) - (dilation_[1] * (weight.size()[3] - 1)) - + // ((grad_out.size()[3] - 1) * stride_[1]), + // ] + SmallVector outputPaddingValues; + for (unsigned i = 2; i < gradRank; i++) { + Value dim = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i)); + Value inputVecDim = + Torch::AtenSizeIntOp::create(rewriter, loc, input, dim); + Value gradOutDim = + Torch::AtenSizeIntOp::create(rewriter, loc, gradOutput, dim); + Value weightDim = + Torch::AtenSizeIntOp::create(rewriter, loc, weight, dim); + Value inputVecDimMinusOne = + Torch::AtenSubIntOp::create(rewriter, loc, inputVecDim, cstOne); + Value gradOutDimMinusOne = + Torch::AtenSubIntOp::create(rewriter, loc, gradOutDim, cstOne); + Value weightDimMinusOne = + Torch::AtenSubIntOp::create(rewriter, loc, weightDim, cstOne); + Value twoTimesPadding = + Torch::AtenMulIntOp::create(rewriter, loc, padding[i - 2], cstTwo); + Value tmpA = Torch::AtenMulIntOp::create( + rewriter, loc, weightDimMinusOne, dilation[i - 2]); + Value tmpB = Torch::AtenMulIntOp::create( + rewriter, loc, gradOutDimMinusOne, stride[i - 2]); + Value outputPaddingVal = AtenAddIntOp::create( + rewriter, loc, inputVecDimMinusOne, twoTimesPadding); + outputPaddingVal = + AtenSubIntOp::create(rewriter, loc, outputPaddingVal, tmpA); + outputPaddingVal = + AtenSubIntOp::create(rewriter, loc, outputPaddingVal, tmpB); + outputPaddingValues.push_back(outputPaddingVal); + } + Value outputPaddingForGradInput = Torch::PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), + outputPaddingValues); + gradInput = Torch::AtenConvTranspose2dInputOp::create( + rewriter, loc, op.getResultTypes()[0], gradOutput, weight, cstNone, + op.getStride(), op.getPadding(), outputPaddingForGradInput, + op.getGroups(), op.getDilation()); + } + + Value gradWeight = cstNone; + if (outMask[1]) { + // Computing Grad Weight. + Type transposedType; + if (failed(getTransposedType(cast(input.getType()), 0, 1, + transposedType))) + return failure(); + Value inputTransposed = Torch::AtenTransposeIntOp::create( + rewriter, loc, transposedType, input, cstZero, cstOne); + + // For the cases where the stride is non-unit, we compute the `GradWeight` + // through this implementation. + if (!llvm::all_of(strideInt, + [](int64_t stride) { return stride == 1; })) { + SmallVector gradOutputSize; + for (unsigned i = 0; i < gradRank; i++) { + gradOutputSize.push_back(Torch::AtenSizeIntOp::create( + rewriter, loc, gradOutput, + Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i)))); + } + + Value gradOutputViewDimZero = Torch::AtenMulIntOp::create( + rewriter, loc, gradOutputSize[0], gradOutputSize[1]); + Value gradOutputViewShapeList = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), + ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2], + gradOutputSize[3]}); + + BaseTensorType gradOutputTy = + cast(gradOutput.getType()); + if (!gradOutputTy.hasSizes()) + return failure(); + SmallVector gradOutputSizesInt(gradOutputTy.getSizes()); + SmallVector gradOutputViewSizesInt(gradOutputSizesInt); + if (gradOutputViewSizesInt[0] != kUnknownSize && + gradOutputViewSizesInt[1] != kUnknownSize) + gradOutputViewSizesInt[0] *= gradOutputViewSizesInt[1]; + else + gradOutputViewSizesInt[0] = kUnknownSize; + gradOutputViewSizesInt[1] = 1; + BaseTensorType gradOutputTypeForView = + cast(gradOutputTy.getWithSizesAndDtype( + llvm::ArrayRef(gradOutputViewSizesInt), + gradOutputTy.getOptionalDtype())); + Value gradOutputView = + Torch::AtenViewOp::create(rewriter, loc, gradOutputTypeForView, + gradOutput, gradOutputViewShapeList); + + BaseTensorType inputTransposedTy = + cast(inputTransposed.getType()); + if (!inputTransposedTy.hasSizes()) + return failure(); + SmallVector inputTransposedSizesInt( + inputTransposedTy.getSizes()); + SmallVector gradWeightSizesInt{inputTransposedSizesInt[0], + gradOutputViewSizesInt[0]}; + for (unsigned i = 2; i < gradRank; i++) { + if (inputTransposedSizesInt[i] != kUnknownSize && + gradOutputViewSizesInt[i] != kUnknownSize) { + int64_t kernelSizeInt = + strideInt[i - 2] * (gradOutputViewSizesInt[i] - 1) + 1; + gradWeightSizesInt.push_back( + ((inputTransposedSizesInt[i] + (paddingInt[i - 2] * 2) - + kernelSizeInt) / + dilationInt[i - 2]) + + 1); + } else { + gradWeightSizesInt.push_back(kUnknownSize); + } + } + + BaseTensorType gradWeightTy = + cast(inputTransposedTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightSizesInt), + inputTransposedTy.getOptionalDtype())); + + Value numGroup = AtenSizeIntOp::create(rewriter, loc, input, cstZero); + gradWeight = Torch::AtenConvolutionOp::create( + rewriter, loc, gradWeightTy, inputTransposed, gradOutputView, + cstNone, + /*stride=*/op.getDilation(), op.getPadding(), + /*dilation=*/op.getStride(), op.getTransposed(), + op.getOutputPadding(), numGroup); + + BaseTensorType weightTy = cast(weight.getType()); + if (!weightTy.hasSizes()) + return failure(); + SmallVector weightSizes(weightTy.getSizes()); + for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) { + gradWeightSizesInt[i + 2] = weightSizes[i + 2]; + BaseTensorType gradWeightNarrowTy = + cast(gradWeightTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightSizesInt), + gradWeightTy.getOptionalDtype())); + + Value dim = ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i + 2)); + Value length = + Torch::AtenSizeIntOp::create(rewriter, loc, weight, dim); + gradWeight = Torch::AtenNarrowOp::create( + rewriter, loc, gradWeightNarrowTy, gradWeight, dim, + /*start=*/cstZero, length); + } + + SmallVector gradWeightViewShapeInt{ + inputTransposedSizesInt[0], inputTransposedSizesInt[1]}; + gradWeightViewShapeInt.push_back(gradOutputSizesInt[1]); + gradWeightViewShapeInt.insert( + gradWeightViewShapeInt.end(), + {gradWeightSizesInt[2], gradWeightSizesInt[3]}); + + SmallVector gradWeightViewShapeValue; + for (unsigned i = 0; i < gradWeightViewShapeInt.size(); i++) { + gradWeightViewShapeValue.push_back(Torch::ConstantIntOp::create( + rewriter, loc, + rewriter.getI64IntegerAttr(gradWeightViewShapeInt[i]))); + } + + Value gradWeightViewShapeList = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), + gradWeightViewShapeValue); + + BaseTensorType gradWeightTypeForView = + cast(gradWeightTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightViewShapeInt), + gradWeightTy.getOptionalDtype())); + gradWeight = + Torch::AtenViewOp::create(rewriter, loc, gradWeightTypeForView, + gradWeight, gradWeightViewShapeList); + + gradWeightTy = cast(gradWeight.getType()); + SmallVector gradWeightDimsOrder = + computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size()); + SmallVector gradWeightMoveDimShape; + for (unsigned i = 0; i < gradWeightDimsOrder.size(); i++) { + gradWeightMoveDimShape.push_back( + gradWeightViewShapeInt[gradWeightDimsOrder[i]]); + } + BaseTensorType gradWeightTypeForMoveDim = + cast(gradWeightTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightMoveDimShape), + gradWeightTy.getOptionalDtype())); + + gradWeight = + AtenMovedimIntOp::create(rewriter, loc, gradWeightTypeForMoveDim, + gradWeight, /*source=*/cstZero, + /*destination=*/cstTwo); + + Value gradIntList = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), + llvm::ArrayRef{cstZero}); + gradWeight = Torch::AtenSumDimIntListOp::create( + rewriter, loc, op.getResultTypes()[1], /*self=*/gradWeight, + /*dim=*/gradIntList, + /*keepdim=*/cstFalse, + /*dtype=*/cstNone); + } else { + if (failed(getTransposedType(cast(gradOutput.getType()), + 0, 1, transposedType))) + return failure(); + Value gradOutputTransposed = Torch::AtenTransposeIntOp::create( + rewriter, loc, transposedType, gradOutput, cstZero, cstOne); + // Convolve input with grad_output. + if (failed( + getTransposedType(cast(op.getResultTypes()[1]), + 0, 1, transposedType))) + return failure(); + gradWeight = Torch::AtenConvolutionOp::create( + rewriter, loc, transposedType, inputTransposed, + gradOutputTransposed, cstNone, op.getStride(), op.getPadding(), + op.getDilation(), op.getTransposed(), op.getOutputPadding(), + op.getGroups()); + gradWeight = Torch::AtenTransposeIntOp::create( + rewriter, loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne); + } + } + + Value gradBias = cstNone; + if (outMask[2]) { + // Computing Grad Bias. + SmallVector dimIntList{cstZero}; + for (unsigned i = 2; i < gradRank; i++) + dimIntList.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); + Value gradIntList = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), + dimIntList); + + // Sum grad_output along dim 1. + gradBias = Torch::AtenSumDimIntListOp::create( + rewriter, loc, op.getResultTypes()[2], gradOutput, gradIntList, + cstFalse, cstNone); + } + + rewriter.replaceOp(op, {gradInput, gradWeight, gradBias}); + return success(); + } +}; +} // namespace + /** * # one dim input * t = torch.tensor([0, 0, 1, 1, 0, 0] @@ -12762,6 +13162,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal< DecomposeAten_ConvolutionLikeOp>( patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 6869ce12ab42..b149d172496c 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -445,6 +445,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py index e6bbbe3273dc..14346e0a49d5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -292,10 +292,15 @@ def forward(self, grad_out, input_vec, weight): ) -@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DStridedPaddedDilatedGrouped()) +@register_test_case( + module_factory=lambda: ConvolutionBackwardModule2DStridedPaddedDilatedGrouped() +) def ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic(module, tu: TestUtils): with torch.backends.mkldnn.flags(enabled=False): - module.forward(tu.rand(2, 16, 32, 32), tu.rand(2, 128, 64, 64), tu.rand(16, 32, 2, 2)) + module.forward( + tu.rand(2, 16, 32, 32), tu.rand(2, 128, 64, 64), tu.rand(16, 32, 2, 2) + ) + # ============================================================================== diff --git a/test/Conversion/TorchToLinalg/convolution_bwd.mlir b/test/Conversion/TorchToLinalg/convolution_bwd.mlir index 0e1f5e67dbb8..7debfea3e6e3 100644 --- a/test/Conversion/TorchToLinalg/convolution_bwd.mlir +++ b/test/Conversion/TorchToLinalg/convolution_bwd.mlir @@ -59,6 +59,47 @@ func.func @convolution_backward_input_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2 // ----- +// CHECK-LABEL: func.func @convolution_backward_input_1x1ker_1x1s_0x0p_1x1d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,64,64],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,1,1],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_1x1ker_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2,16,64,64],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,1,1],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,128,1,1],f32> -> tensor<16x128x1x1xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,64,64],f32> -> tensor<2x16x64x64xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x128x64x64xf32>) -> tensor<2x128x64x64xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[T0]], %[[T1]] : tensor<2x16x64x64xf32>, tensor<16x128x1x1xf32>) outs(%[[OUT_FILLED]] : tensor<2x128x64x64xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[ACC]] : f32 + // CHECK-NEXT: } -> tensor<2x128x64x64xf32> + // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x64x64xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,64,64],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,1,1],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32> + return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> +} + +// ----- + // CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>, @@ -145,9 +186,8 @@ func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g(%arg0: !torch.vtensor<[2 // CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x4x4x66x66xf32> // CHECK: %[[SLICE_FILLED:.*]] = linalg.fill ins(%cst : f32) outs(%[[SLICE_EMPTY]] : tensor<2x4x4x66x66xf32>) -> tensor<2x4x4x66x66xf32> // CHECK: %[[SLICE:.*]] = tensor.insert_slice %[[T0_EXP]] into %[[SLICE_FILLED]][0, 0, 0, 0, 0] [2, 4, 4, 33, 33] [1, 1, 1, 2, 2] : tensor<2x4x4x33x33xf32> into tensor<2x4x4x66x66xf32> - // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32> - // CHECK: %[[OUT_EXP:.*]] = tensor.expand_shape %[[OUT_EMPTY]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 32, 64, 64] : tensor<2x128x64x64xf32> into tensor<2x4x32x64x64xf32> - // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EXP]] : tensor<2x4x32x64x64xf32>) -> tensor<2x4x32x64x64xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x4x32x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x4x32x64x64xf32>) -> tensor<2x4x32x64x64xf32> // CHECK: %[[CONV:.*]] = linalg.generic {{.*}} ins(%[[SLICE]], %[[W_REV]] : tensor<2x4x4x66x66xf32>, tensor<4x4x32x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x4x32x64x64xf32>) { // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 @@ -281,9 +321,8 @@ func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g(%arg0: !torch.vtensor< // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): // CHECK-NEXT: tensor.yield %[[CST]] : f32 // CHECK-NEXT: } : tensor<2x4x32x64x64xf32> to tensor<2x4x32x68x68xf32> - // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<16x32x2x2xf32> - // CHECK: %[[OUT0_EXP:.*]] = tensor.expand_shape %[[OUT0_EMPTY]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} output_shape [4, 4, 32, 2, 2] : tensor<16x32x2x2xf32> into tensor<4x4x32x2x2xf32> - // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EXP]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32> + // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xf32> + // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EMPTY]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32> // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d2, d3 * 2 + d6 * 2, d4 * 2 + d7 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d1, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0_EXP]] : tensor<2x4x32x68x68xf32>, tensor<2x4x4x33x33xf32>) outs(%[[OUT0_FILLED]] : tensor<4x4x32x2x2xf32>) { // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 6c1d051b570b..8fe502a7d686 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -273,6 +273,40 @@ func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int { return %arg0 : !torch.int } +// ----- + +// CHECK-LABEL: func.func @convolution_backward_none_result( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,3,3],f32>, %[[VAL_1:.*]]: !torch.vtensor<[1,1,5,5],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1,1,3,3],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>) { +func.func @convolution_backward_none_result(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,3,3],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>) { + // CHECK: %[[VAL_4:.*]] = torch.constant.int 3 + // CHECK: %[[VAL_5:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_6:.*]] = torch.constant.none + // CHECK: %[[VAL_7:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_8:.*]] = torch.constant.bool false + // CHECK: %[[VAL_9:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_9]], %[[VAL_9]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_12:.*]] = torch.aten.transpose.int %[[VAL_1]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,5,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,5,5],f32> + // CHECK: %[[VAL_13:.*]] = torch.aten.transpose.int %[[VAL_0]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3,3],f32> + // CHECK: %[[VAL_14:.*]] = torch.aten.convolution %[[VAL_12]], %[[VAL_13]], %[[VAL_6]], %[[VAL_10]], %[[VAL_11]], %[[VAL_10]], %[[VAL_8]], %[[VAL_11]], %[[VAL_9]] : !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,3,3],f32> + // CHECK: %[[VAL_15:.*]] = torch.aten.transpose.int %[[VAL_14]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[1,1,3,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,1,3,3],f32> + // CHECK: %[[VAL_16:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_5]], %[[VAL_4]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_17:.*]] = torch.aten.sum.dim_IntList %[[VAL_0]], %[[VAL_16]], %[[VAL_8]], %[[VAL_6]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1],f32> + // CHECK: return %[[VAL_15]], %[[VAL_17]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.none, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32> + return %result1, %result2 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32> +} + // ----- // CHECK-LABEL: func.func @emptyLikeNoneDtype( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> { From 31eac083557384bb1593abd55f11649938b52ed7 Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Tue, 25 Nov 2025 12:05:50 +0400 Subject: [PATCH 3/3] Updated the file xfail_sets.py - added new xfail tests --- projects/pt1/e2e_testing/xfail_sets.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 4c8318570c7b..7791a8ef35a7 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -434,7 +434,9 @@ "ContainsIntList_False", "ContainsIntList_True", "ConvTbcModule_basic", + "ConvolutionBackwardModule2DDilated_basic", "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", @@ -692,7 +694,9 @@ "Conv2dQInt8PerChannelModule_grouped", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", + "ConvolutionBackwardModule2DDilated_basic", "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", @@ -2914,8 +2918,10 @@ "Conv_Transpose2dModule_basic", "Convolution2DModule_basic", "Convolution2DStridedModule_basic", + "ConvolutionBackwardModule2DDilated_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "ConvolutionModule2DGroups_basic", @@ -3693,7 +3699,9 @@ "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", + "ConvolutionBackwardModule2DDilated_basic", "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "ConvolutionModule2DGroups_basic", @@ -4337,8 +4345,10 @@ "Conv_Transpose2dModule_basic", "Convolution2DModule_basic", "Convolution2DStridedModule_basic", + "ConvolutionBackwardModule2DDilated_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "ConvolutionModule2DGroups_basic",