diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 5b30e93e6fe4..60093b54f8e1 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1673,6 +1673,581 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding( } } +namespace { +class ConvertAtenConvolutionBackwardOp + : public OpConversionPattern { + using IT = utils::IteratorType; + +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenConvolutionBackwardOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); + Value gradOutput = adaptor.getGradOutput(); + Value input = adaptor.getInput(); + Value weight = adaptor.getWeight(); + + auto gradOutputDTy = + cast(gradOutput.getType()).getElementType(); + auto inputDTy = cast(input.getType()).getElementType(); + auto weightDTy = cast(weight.getType()).getElementType(); + if (!isa(gradOutputDTy) || + !isa(inputDTy) || !isa(weightDTy)) + return op.emitError("unimplemented: only fp convolution bwd supported"); + + size_t gradRank = cast(gradOutput.getType()).getRank(); + size_t numSpatialDims = gradRank - 2; + if (numSpatialDims < 1 || numSpatialDims > 3) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1d-3d convolution bwd currently supported"); + + // Transposed convolution backward is not handled here yet. + bool transposed = false; + if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) + return rewriter.notifyMatchFailure( + op, "only support constant bool for transposed"); + if (transposed) + return rewriter.notifyMatchFailure( + op, "unimplemented: transposed convolution backward"); + + // The `outMask` contains 3 boolean values for the results `grad_input`, + // `grad_weight`, and `grad_bias` respectively. The value being `false` + // means that the corresponding result will be none. + SmallVector outMask; + if (!matchPattern(op.getOutputMask(), + m_TorchListOfConstantBools(outMask)) || + outMask.size() != 3) + return rewriter.notifyMatchFailure( + op, "only constant bool output_mask list of size 3 is supported."); + for (unsigned i = 0; i < outMask.size(); i++) { + if (outMask[i] == false) { + Value result = op->getResults()[i]; + if (!result.getUsers().empty()) + return rewriter.notifyMatchFailure( + op, "unimplemented: false value supported for output_mask only " + "when the result tensor corresponding to that has no users."); + } + } + + // Checks for valid group size + int64_t numGroups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) + return rewriter.notifyMatchFailure(op, + "only constant group size supported."); + bool isGroupedConvBwd = numGroups > 1; + int64_t spatialStartDimIdx = isGroupedConvBwd ? 3 : 2; + + // Stride, padding, dilation for the backward conv. We only support constant + // lists here, consistent with forward convolution lowering. 
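+ // For example, the ConvolutionBackwardModule2DDilated e2e test added in this
+ // patch uses stride = [1, 1], padding = [1, 1], dilation = [2, 2]; non-constant
+ // lists, or a non-zero output_padding, are rejected with a match failure.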
+ SmallVector paddingIntValues; + SmallVector strideInts, dilationInts, outputPaddingInts; + + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) + return rewriter.notifyMatchFailure(op, + "only support constant int strides"); + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) + return rewriter.notifyMatchFailure(op, + "only support constant int dilations"); + if (!matchPattern(op.getOutputPadding(), + m_TorchListOfConstantInts(outputPaddingInts))) + return rewriter.notifyMatchFailure( + op, "only support constant int output paddings"); + if (!llvm::all_of(outputPaddingInts, + [](int64_t outPad) { return outPad == 0; })) + return rewriter.notifyMatchFailure( + op, "unimplemented: only output padding of 0 supported."); + + if (!getListConstructElements(op.getPadding(), paddingIntValues)) + return rewriter.notifyMatchFailure( + op, "only support padding from a list construct"); + paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), + paddingIntValues); + + // The expandGroups lambda function below is used to expand the group + // dimension for weights and input, output tensors. + // For input tensor (dim = 1) : N,C,H,W -> N,G,C/G,H,W + // For grad_output tensor (dim = 1): N,F,H,W -> N,G,F/G,H,W + // For weight tensor (dim = 0) : F,C,H,W -> G,F/G,C,H,W + auto expandGroups = [&](Value tensor, int64_t dim) { + auto inType = cast(tensor.getType()); + auto inShape = makeShapeTorchCompatible(inType.getShape()); + + SmallVector outShape; + for (auto i = 0; i < static_cast(inShape.size()); i++) { + if (i == dim) { + outShape.push_back(numGroups); + outShape.push_back(inShape[i] == kUnknownSize + ? kUnknownSize + : inShape[i] / numGroups); + } else { + outShape.push_back(inShape[i]); + } + } + + SmallVector indices; + for (auto i = 0; i <= static_cast(inShape.size()); i++) { + if (i == dim) { + indices.push_back({i, ++i}); + continue; + } + indices.push_back({i}); + } + + auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); + return tensor::ExpandShapeOp::create(rewriter, loc, retType, tensor, + indices); + }; + // The createZeroInitExpandedGroupsTensor lambda function below is used to + // create empty tensor with already expanded group dimension. + auto createZeroInitExpandedGroupsTensor = + [&](OpBuilder &rewriter, Location loc, const SmallVector &sizes, + Type type, int64_t dim, + SmallVector &indices) { + Value groups = + mlir::arith::ConstantIndexOp::create(rewriter, loc, numGroups); + + SmallVector expandedSizes; + for (auto i = 0; i < static_cast(sizes.size()); i++) { + if (i == dim) { + expandedSizes.push_back(groups); + expandedSizes.push_back( + rewriter.createOrFold(loc, sizes[i], + groups)); + } else { + expandedSizes.push_back(sizes[i]); + } + } + + indices.clear(); + for (auto i = 0; i <= static_cast(sizes.size()); i++) { + if (i == dim) { + indices.push_back({i, ++i}); + continue; + } + indices.push_back({i}); + } + + return createZeroInitTensor(rewriter, loc, expandedSizes, type); + }; + + SmallVector newResults(op->getNumResults()); + + // Computing Backward-Input Convolution. + if (outMask[0]) { + // If convolution bwd is grouped, `grad_output` should be expanded. + auto gradOutputExpanded = + isGroupedConvBwd ? expandGroups(gradOutput, 1) : gradOutput; + // If convolution bwd is grouped, `weight` should be expanded + auto weightExpanded = isGroupedConvBwd ? 
expandGroups(weight, 0) : weight; + + // Flip weight along spatial dims only if + // - kernel size is greater than 1, + // - the kernel is not a 1x1 or 1x1x1 kernel. + SmallVector weightDimsInt = makeShapeTorchCompatible( + cast(weightExpanded.getType()).getShape()); + bool is1x1Kernel = std::all_of(weightDimsInt.rbegin(), + weightDimsInt.rbegin() + numSpatialDims, + [](int64_t dim) { return dim == 1; }); + if (numSpatialDims > 1 && !is1x1Kernel) { + SmallVector weightFlipDims; + weightFlipDims.reserve(numSpatialDims); + for (int64_t i = 0; i < static_cast(numSpatialDims); ++i) + weightFlipDims.push_back(spatialStartDimIdx + i); + weightExpanded = torch_to_linalg::flipTensor( + rewriter, loc, weightExpanded, weightFlipDims); + } + + // For backward-input, padding must be adjusted to: + // p'[i] = d[i] * (K[i] - 1) - p[i] + Value c1 = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + SmallVector dilationIntValues = + getAsConstantIntValues(rewriter, loc, dilationInts); + SmallVector weiSizes = + getTensorSizes(rewriter, loc, weightExpanded); + SmallVector paddingValues(numSpatialDims); + for (size_t i = 0; i < numSpatialDims; ++i) { + Value kSize = + castIndexToInt64(rewriter, loc, weiSizes[spatialStartDimIdx + i]); + Value kMinusOne = rewriter.createOrFold(loc, kSize, c1); + Value mul = rewriter.createOrFold(loc, kMinusOne, + dilationIntValues[i]); + paddingValues[i] = + arith::SubIOp::create(rewriter, loc, mul, paddingIntValues[i]); + + if (isValueNegative(paddingValues[i])) + return rewriter.notifyMatchFailure( + op, "unimplemented: negative padding values are not supported."); + } + + // If there are not unit strides, we have to scatter `grad_output` into a + // zero-initialized tensor. + SmallVector gradInputSizes = getTensorSizes(rewriter, loc, input); + Value gradOutputModified; + if (llvm::any_of(strideInts, [](int64_t stride) { return stride > 1; })) { + // Destination spatial sizes are computed as: + // size[i] = (D[i] - 1) + d[i] * (K[i] - 1) + 1 + // Offsets on spatial dims are paddings + // Strides on spatial dims are the original stride[i]. 
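+ // Worked example from the convolution_backward_input_2x2s_2x2p_2x2d_1g lit
+ // test (input spatial size 64, K = 2, d = 2, p = 2, s = 2): the adjusted
+ // padding is p' = 2 * (2 - 1) - 2 = 0 and each destination spatial size is
+ // (64 - 1) + 2 * (2 - 1) + 1 = 66, so the 33x33 grad_output is scattered at
+ // offset 0 with stride 2 into a zero-filled tensor with 66x66 spatial dims.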
+ Value zero = + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); + Value one = + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); + + // Initialize slice strides, sizes and offsets + SmallVector goSizes = + getTensorSizes(rewriter, loc, gradOutputExpanded); + SmallVector sizes(goSizes.begin(), + goSizes.begin() + spatialStartDimIdx); + SmallVector offsets(spatialStartDimIdx, zero); + SmallVector strides(spatialStartDimIdx, one); + for (size_t i = 0; i < numSpatialDims; ++i) { + // Shapes of `grad_input` has not been expanded yet + // if it's needed for group conv even + Value h = gradInputSizes[2 + i]; + Value k = weiSizes[spatialStartDimIdx + i]; + Value hMinusOne = rewriter.createOrFold(loc, h, one); + Value kMinusOne = rewriter.createOrFold(loc, k, one); + Value mul = rewriter.createOrFold( + loc, castIntToIndex(rewriter, loc, dilationIntValues[i]), + kMinusOne); + Value sum = rewriter.createOrFold(loc, hMinusOne, mul); + sizes.push_back(rewriter.createOrFold(loc, sum, one)); + offsets.push_back(castIntToIndex(rewriter, loc, paddingValues[i])); + + Value strideIntValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(strideInts[i])); + strides.push_back(castIntToIndex(rewriter, loc, strideIntValue)); + } + + Value zeroInit = + createZeroInitTensor(rewriter, loc, sizes, gradOutputDTy); + gradOutputModified = tensor::InsertSliceOp::create( + rewriter, loc, + torch_to_linalg::removeSizeInformation(rewriter, loc, + gradOutputExpanded), + zeroInit, offsets, goSizes, strides); + } else { + // If there unit strides, pad `grad_output` spatial dims with zeros. + // If conv is grouped, output has shape: + // N x G x F/G x . Otherwise: N x F x . + Value padVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(gradOutputDTy, 0.0)); + gradOutputModified = torch_to_linalg::getDynamicZeroPaddedTensor( + op, rewriter, gradOutputExpanded, paddingValues, spatialStartDimIdx, + padVal); + } + + // Initialize output buffer. For grouped, compute into an expanded + // [N, G, C/G, D*] tensor and collapse back to the original input shape. + SmallVector gradInputCollapseIndices; + Value gradInputInit = + isGroupedConvBwd + ? createZeroInitExpandedGroupsTensor(rewriter, loc, + gradInputSizes, inputDTy, 1, + gradInputCollapseIndices) + : createZeroInitTensor(rewriter, loc, gradInputSizes, inputDTy); + + // Create convolution for data gradient + auto convRes = createConvInputGradient(rewriter, loc, context, + isGroupedConvBwd, numSpatialDims, + dilationInts, gradOutputModified, + weightExpanded, gradInputInit) + .getResult(0); + + // Collapse [N, G, C/G, D] to [N, C, D] the result of the conv + // if it is grouped. + if (isGroupedConvBwd) { + convRes = tensor::CollapseShapeOp::create( + rewriter, loc, input.getType(), convRes, gradInputCollapseIndices); + } + + // Cast to the final result type expected by the type converter. + newResults[0] = tensor::CastOp::create(rewriter, loc, + getTypeConverter()->convertType( + op->getResult(0).getType()), + convRes) + .getResult(); + } + + // Computing Backward-Weight Convolution. + if (outMask[1]) { + // If convolution bwd is grouped, `grad_output` should be expanded. + auto gradOutputExpanded = + isGroupedConvBwd ? expandGroups(gradOutput, 1) : gradOutput; + // If convolution bwd is grouped, `input` should be expanded + auto inputExpanded = isGroupedConvBwd ? expandGroups(input, 1) : input; + + // Pad input spatial dims with zeros. If grouped, input has shape: + // N x G x C/G x . 
Otherwise: N x C x . + // We should only pad the spatial dims, so set unpaddedDims accordingly. + Value padVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(inputDTy, 0.0)); + Value paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( + op, rewriter, inputExpanded, paddingIntValues, spatialStartDimIdx, + padVal); + + // Initialize output buffer. For grouped, compute into an expanded + // [G, F/G, C/G, K*] tensor and collapse back to the original weight + // shape. + SmallVector gradWeightSizes = + getTensorSizes(rewriter, loc, weight); + SmallVector gradWeightCollapseIndices; + Value gradWeightInit = + isGroupedConvBwd + ? createZeroInitExpandedGroupsTensor(rewriter, loc, + gradWeightSizes, weightDTy, + 0, gradWeightCollapseIndices) + : createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy); + + // Create convolution for weight gradient + auto convResult = createConvWeightGradient( + rewriter, loc, context, isGroupedConvBwd, + numSpatialDims, strideInts, dilationInts, + paddedInput, gradOutputExpanded, gradWeightInit) + .getResult(0); + + // Collapse [G, F/G, C/G, D] to [F, C/G, D] the result of the conv + // if it is grouped. + if (isGroupedConvBwd) { + convResult = tensor::CollapseShapeOp::create( + rewriter, loc, weight.getType(), convResult, + gradWeightCollapseIndices); + } + + // Cast to the final result type expected by the type converter. + newResults[1] = tensor::CastOp::create(rewriter, loc, + getTypeConverter()->convertType( + op->getResult(1).getType()), + convResult) + .getResult(); + } + + // Computing Backward-Bias Convolution. + if (outMask[2]) { + // Sum grad_output along all dims except F using linalg. + DenseSet reduceDims; + reduceDims.insert(0); + for (int64_t i = 2; i < static_cast(gradRank); ++i) + reduceDims.insert(i); + + torch_to_linalg::ReductionOpInfo opInfo{false, gradOutput, reduceDims}; + + // Zero init for the element type (arith.constant expects a scalar attr). + Value initSum = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(gradOutputDTy)); + + auto reductionBody = [&](OpBuilder &b, Location loc, ValueRange args) { + Value x = args[0]; + Value acc = args[1]; + Value sum = arith::AddFOp::create(b, loc, x, acc); + linalg::YieldOp::create(b, loc, sum); + }; + + Value gradBias = torch_to_linalg::createReductionLinalgGeneric( + rewriter, loc, opInfo, initSum, reductionBody); + + newResults[2] = tensor::CastOp::create(rewriter, loc, + getTypeConverter()->convertType( + op->getResult(2).getType()), + gradBias) + .getResult(); + } + + rewriter.replaceOp(op, newResults); + + return success(); + } + +private: + static linalg::GenericOp createConvInputGradient( + OpBuilder &rewriter, Location loc, MLIRContext *context, bool isGrouped, + size_t numSpatialDims, const SmallVector &dilationInts, + Value gradOutput, Value weight, Value gradInputInit) { + // To calculate convolution backward-data, we use generic operation. + // The generic operation is a generalization of the convolution operation + // that can handle any number of spatial dimensions. + // The generic operation is defined as follows: + // ``` + // dLdx[n, g, c, o] = sum(dLdy[n, g, f, d * k + o] * w[g, f, c, k] + // for n in range(batch_size) for o in range(in_spatial_dims)) + // ``` + // where: + // - `dLdx` is the data-gradient tensor. + // - `dLdy` is the output-gradient tensor which is padded if + // there are unit strides, or scattered otherwise. + // - `w` is the weight tensor flipped along spatial dims. + // - `n` is the batch dimension. 
+ // - `g` is the group dimension. + // - `c` is the input channel dimension. + // - `f` is the output channel dimension. + // - `o` is the input spatial dimension. + // - `k` is the kernel dimension. + // - `d` is dilations. + + // Iterators: n, c, f, g, o, k + int64_t numIterators = + 3 + static_cast(isGrouped) + numSpatialDims * 2; + + // Bind dimensions in the following order: n, g, c, o, f, k + SmallVector dims(numIterators); + bindDimsList(context, MutableArrayRef{dims}); + + auto n = [&]() { return dims[0]; }; + auto g = [&]() { + if (!isGrouped) + llvm_unreachable("g() called for non-grouped convolution."); + return dims[1]; + }; + auto c = [&]() { return dims[1 + static_cast(isGrouped)]; }; + auto o = [&](size_t i) { + return dims[1 + static_cast(isGrouped) + 1 + i]; + }; + auto f = [&]() { + return dims[1 + static_cast(isGrouped) + 1 + numSpatialDims]; + }; + auto k = [&](size_t i) { + return dims[1 + static_cast(isGrouped) + 1 + numSpatialDims + 1 + + i]; + }; + + SmallVector lhsExprs = + isGrouped ? SmallVector{n(), g(), f()} + : SmallVector{n(), f()}; + SmallVector rhsExprs = + isGrouped ? SmallVector{g(), f(), c()} + : SmallVector{f(), c()}; + SmallVector outExprs = + isGrouped ? SmallVector{n(), g(), c()} + : SmallVector{n(), c()}; + for (size_t i = 0; i < numSpatialDims; i++) { + AffineExpr d = rewriter.getAffineConstantExpr(dilationInts[i]); + lhsExprs.push_back(d * k(i) + o(i)); + rhsExprs.push_back(k(i)); + outExprs.push_back(o(i)); + } + + SmallVector indexingMaps = { + AffineMap::get(numIterators, 0, lhsExprs, context), + AffineMap::get(numIterators, 0, rhsExprs, context), + AffineMap::get(numIterators, 0, outExprs, context)}; + SmallVector iteratorTypes = SmallVector(numIterators, IT::parallel); + std::fill(iteratorTypes.rbegin(), + iteratorTypes.rbegin() + (numSpatialDims + 1), IT::reduction); + + return createConvAsGenericOp(rewriter, loc, gradOutput, weight, + gradInputInit, indexingMaps, iteratorTypes); + } + + static linalg::GenericOp createConvWeightGradient( + OpBuilder &rewriter, Location loc, MLIRContext *context, bool isGrouped, + size_t numSpatialDims, const SmallVector &strideInts, + const SmallVector &dilationInts, Value input, Value gradOutput, + Value gradWeightInit) { + // To calculate convolution backward-weight, we use generic operation. + // The generic operation is a generalization of the convolution operation + // that can handle any number of spatial dimensions. + // The generic operation is defined as follows: + // ``` + // dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o] + // for n in range(batch_size) for o in range(output_spatial_dims)) + // ``` + // - `dLdw` is the weight-gradient tensor. + // - `x` is the padded input tensor. + // - `dLdy` is the output-gradient tensor. + // - `n` is the batch dimension. + // - `g` is the group dimension. + // - `c` is the input channel dimension. + // - `f` is the output channel dimension. + // - `o` is the input spatial dimension. + // - `k` is the kernel dimension. + // - `d` and `s` are dilations and strides accordingly. 
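+ //
+ // For illustration, in the non-grouped 2-D case with stride 2 and dilation 2
+ // (see the convolution_backward_weights_2x2s_2x2p_2x2d_1g lit test) the dims
+ // bind as (d0..d6) = (f, c, k0, k1, n, o0, o1) and the maps produced below are
+ //   padded input: (d4, d1, d2 * 2 + d5 * 2, d3 * 2 + d6 * 2)
+ //   grad_output:  (d4, d0, d5, d6)
+ //   grad_weight:  (d0, d1, d2, d3)
+ // with n, o0 and o1 marked as reduction iterators.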
+ + // Iterators: n, c, f, g, o, k + int64_t numIterators = + 3 + static_cast(isGrouped) + numSpatialDims * 2; + + // Bind dimensions in the following order: g, f, c, k, n, o + SmallVector dims(numIterators); + bindDimsList(context, MutableArrayRef{dims}); + + auto g = [&]() { + if (!isGrouped) + llvm_unreachable("g() called for non-grouped convolution."); + return dims[0]; + }; + auto f = [&]() { return dims[static_cast(isGrouped)]; }; + auto c = [&]() { return dims[static_cast(isGrouped) + 1]; }; + auto k = [&](size_t i) { + return dims[static_cast(isGrouped) + 2 + i]; + }; + auto n = [&]() { + return dims[static_cast(isGrouped) + 2 + numSpatialDims]; + }; + auto o = [&](size_t i) { + return dims[static_cast(isGrouped) + 2 + numSpatialDims + 1 + i]; + }; + + SmallVector lhsExprs = + isGrouped ? SmallVector{n(), g(), c()} + : SmallVector{n(), c()}; + SmallVector rhsExprs = + isGrouped ? SmallVector{n(), g(), f()} + : SmallVector{n(), f()}; + SmallVector outExprs = + isGrouped ? SmallVector{g(), f(), c()} + : SmallVector{f(), c()}; + for (size_t i = 0; i < numSpatialDims; i++) { + AffineExpr d = rewriter.getAffineConstantExpr(dilationInts[i]); + AffineExpr s = rewriter.getAffineConstantExpr(strideInts[i]); + lhsExprs.push_back(d * k(i) + s * o(i)); + rhsExprs.push_back(o(i)); + outExprs.push_back(k(i)); + } + + SmallVector indexingMaps = { + AffineMap::get(numIterators, 0, lhsExprs, context), + AffineMap::get(numIterators, 0, rhsExprs, context), + AffineMap::get(numIterators, 0, outExprs, context)}; + SmallVector iteratorTypes = SmallVector(numIterators, IT::parallel); + std::fill(iteratorTypes.rbegin(), + iteratorTypes.rbegin() + (numSpatialDims + 1), IT::reduction); + + return createConvAsGenericOp(rewriter, loc, input, gradOutput, + gradWeightInit, indexingMaps, iteratorTypes); + } + + static linalg::GenericOp + createConvAsGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, + Value out, const SmallVector &indexingMaps, + const SmallVector &iteratorTypes) { + return linalg::GenericOp::create( + b, loc, out.getType(), ValueRange{in0, in1}, out, indexingMaps, + iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0]; + Value grad = args[1]; + Value output = args[2]; + + // Convert input and grad to accumulator type if needed + Type accType = output.getType(); + if (input.getType() != accType) { + input = arith::ExtFOp::create(b, loc, accType, input); + } + if (grad.getType() != accType) { + grad = arith::ExtFOp::create(b, loc, accType, grad); + } + + Value mul = arith::MulFOp::create(b, loc, input, grad); + Value sum = arith::AddFOp::create(b, loc, mul, output); + linalg::YieldOp::create(b, loc, sum); + }); + } +}; +} // namespace + namespace { /// Creates coefficients based on DFT definition, see @@ -1874,6 +2449,8 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 4c8318570c7b..7791a8ef35a7 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -434,7 +434,9 @@ "ContainsIntList_False", "ContainsIntList_True", "ConvTbcModule_basic", + "ConvolutionBackwardModule2DDilated_basic", "ConvolutionBackwardModule2DPadded_basic", + 
"ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", @@ -692,7 +694,9 @@ "Conv2dQInt8PerChannelModule_grouped", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", + "ConvolutionBackwardModule2DDilated_basic", "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", @@ -2914,8 +2918,10 @@ "Conv_Transpose2dModule_basic", "Convolution2DModule_basic", "Convolution2DStridedModule_basic", + "ConvolutionBackwardModule2DDilated_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "ConvolutionModule2DGroups_basic", @@ -3693,7 +3699,9 @@ "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", + "ConvolutionBackwardModule2DDilated_basic", "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "ConvolutionModule2DGroups_basic", @@ -4337,8 +4345,10 @@ "Conv_Transpose2dModule_basic", "Convolution2DModule_basic", "Convolution2DStridedModule_basic", + "ConvolutionBackwardModule2DDilated_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "ConvolutionModule2DGroups_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py index 4f547d531294..095b4db6c5a4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py @@ -34,6 +34,7 @@ "aten.flatten.using_ints", "aten.adaptive_avg_pool1d", "aten.adaptive_avg_pool2d", + "aten.convolution_backward", "aten.unflatten.int", ], OutputType.STABLEHLO: [ diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py index 5e6e093902c4..14346e0a49d5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -228,6 +228,80 @@ def ConvolutionBackwardModule2DStrided_basic(module, tu: TestUtils): module.forward(tu.rand(1, 2, 4, 4), tu.rand(1, 2, 8, 8), tu.rand(2, 2, 3, 3)) +class ConvolutionBackwardModule2DDilated(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 2, 6, 6], torch.float32, True), + ([1, 4, 8, 8], torch.float32, True), + ([2, 4, 3, 3], torch.float32, True), + ] + ) + def forward(self, grad_out, input_vec, weight): + return torch.ops.aten.convolution_backward( + grad_out, + input_vec, + weight, + bias_sizes=[4], + stride=[1, 1], + padding=[1, 1], + dilation=[2, 2], + transposed=False, + output_padding=[0, 0], + groups=1, + output_mask=[True, True, True], + ) + + +@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DDilated()) +def ConvolutionBackwardModule2DDilated_basic(module, tu: TestUtils): + with 
torch.backends.mkldnn.flags(enabled=False): + module.forward(tu.rand(1, 2, 6, 6), tu.rand(1, 4, 8, 8), tu.rand(2, 4, 3, 3)) + + +class ConvolutionBackwardModule2DStridedPaddedDilatedGrouped(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 16, 32, 32], torch.float32, True), + ([2, 128, 64, 64], torch.float32, True), + ([16, 32, 2, 2], torch.float32, True), + ] + ) + def forward(self, grad_out, input_vec, weight): + return torch.ops.aten.convolution_backward( + grad_out, + input_vec, + weight, + bias_sizes=[4], + stride=[2, 2], + padding=[2, 2], + dilation=[4, 4], + transposed=False, + output_padding=[0, 0], + groups=4, + output_mask=[True, True, True], + ) + + +@register_test_case( + module_factory=lambda: ConvolutionBackwardModule2DStridedPaddedDilatedGrouped() +) +def ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic(module, tu: TestUtils): + with torch.backends.mkldnn.flags(enabled=False): + module.forward( + tu.rand(2, 16, 32, 32), tu.rand(2, 128, 64, 64), tu.rand(16, 32, 2, 2) + ) + + # ============================================================================== diff --git a/test/Conversion/TorchToLinalg/convolution_bwd.mlir b/test/Conversion/TorchToLinalg/convolution_bwd.mlir new file mode 100644 index 000000000000..7debfea3e6e3 --- /dev/null +++ b/test/Conversion/TorchToLinalg/convolution_bwd.mlir @@ -0,0 +1,357 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -mlir-print-local-scope -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @convolution_backward_input_1x1s_0x0p_1x1d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,63,63],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2,16,63,63],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST1:.*]] = arith.constant 1 : index + // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,128,2,2],f32> -> tensor<16x128x2x2xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,63,63],f32> -> tensor<2x16x63x63xf32> + // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32> + // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32> + // CHECK: %[[W_REV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[T1]] : tensor<16x128x2x2xf32>) outs(%[[W_FILLED]] : tensor<16x128x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32): + // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index + // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index + // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index + // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index + // CHECK-NEXT: %[[R2:.*]] = arith.subi %[[CST1]], %[[I2]] : index + // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index + // CHECK-NEXT: %[[EX:.*]] = tensor.extract 
%[[T1]][%[[I0]], %[[I1]], %[[R2]], %[[R3]]] : tensor<16x128x2x2xf32> + // CHECK-NEXT: linalg.yield %[[EX]] : f32 + // CHECK-NEXT: } -> tensor<16x128x2x2xf32> + // CHECK: %[[PAD:.*]] = tensor.pad %[[T0]] low[0, 0, 1, 1] high[0, 0, 1, 1] + // CHECK: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): + // CHECK: tensor.yield %[[CST0]] : f32 + // CHECK: } : tensor<2x16x63x63xf32> to tensor<2x16x65x65xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x128x64x64xf32>) -> tensor<2x128x64x64xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[W_REV]] : tensor<2x16x65x65xf32>, tensor<16x128x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x128x64x64xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[ACC]] : f32 + // CHECK-NEXT: } -> tensor<2x128x64x64xf32> + // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x63x63xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,63,63],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32> + return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func 
@convolution_backward_input_1x1ker_1x1s_0x0p_1x1d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,64,64],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,1,1],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_1x1ker_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2,16,64,64],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,1,1],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,128,1,1],f32> -> tensor<16x128x1x1xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,64,64],f32> -> tensor<2x16x64x64xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x128x64x64xf32>) -> tensor<2x128x64x64xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[T0]], %[[T1]] : tensor<2x16x64x64xf32>, tensor<16x128x1x1xf32>) outs(%[[OUT_FILLED]] : tensor<2x128x64x64xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[ACC]] : f32 + // CHECK-NEXT: } -> tensor<2x128x64x64xf32> + // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x64x64xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = 
torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,64,64],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,1,1],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32> + return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_2x2s_2x2p_2x2d_1g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST1:.*]] = arith.constant 1 : index + // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,128,2,2],f32> -> tensor<16x128x2x2xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> tensor<2x16x33x33xf32> + // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32> + // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32> + // CHECK: %[[W_REV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[T1]] : tensor<16x128x2x2xf32>) outs(%[[W_FILLED]] : tensor<16x128x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32): + // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index + // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index + // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index + // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index + // CHECK-NEXT: %[[R2:.*]] = arith.subi %[[CST1]], %[[I2]] : index + // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index + // CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[T1]][%[[I0]], %[[I1]], %[[R2]], %[[R3]]] : tensor<16x128x2x2xf32> + // CHECK-NEXT: linalg.yield %[[EX]] : f32 + // CHECK-NEXT: } -> tensor<16x128x2x2xf32> + // CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x16x66x66xf32> + // CHECK: %[[SLICE_FILLED:.*]] = linalg.fill ins(%cst : f32) outs(%[[SLICE_EMPTY]] : tensor<2x16x66x66xf32>) -> tensor<2x16x66x66xf32> + // CHECK: %[[SLICE:.*]] = tensor.insert_slice %[[T0]] into %[[SLICE_FILLED]][0, 0, 0, 0] [2, 16, 33, 33] [1, 1, 2, 2] : tensor<2x16x33x33xf32> into tensor<2x16x66x66xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x128x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x128x64x64xf32>) -> tensor<2x128x64x64xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d5 * 2 + d2, d6 * 2 + d3)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", 
"reduction", "reduction", "reduction"]} ins(%[[SLICE]], %[[W_REV]] : tensor<2x16x66x66xf32>, tensor<16x128x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x128x64x64xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[ACC]] : f32 + // CHECK-NEXT: } -> tensor<2x128x64x64xf32> + // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32> + return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_input_2x2s_2x2p_2x2d_4g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,32,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST1:.*]] = arith.constant 1 : index + // CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[16,32,2,2],f32> -> tensor<16x32x2x2xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> 
tensor<2x16x33x33xf32> + // CHECK: %[[T0_EXP:.*]] = tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xf32> into tensor<2x4x4x33x33xf32> + // CHECK: %[[W_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} output_shape [4, 4, 32, 2, 2] : tensor<16x32x2x2xf32> into tensor<4x4x32x2x2xf32> + // CHECK: %[[W_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xf32> + // CHECK: %[[W_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[W_EMPTY]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32> + // CHECK: %[[W_REV:.*]] = linalg.generic {{.*}} ins(%[[W_EXP]] : tensor<4x4x32x2x2xf32>) outs(%[[W_FILLED]] : tensor<4x4x32x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_W:.*]]: f32, %[[OUT_W:.*]]: f32): + // CHECK-NEXT: %[[I0:.*]] = linalg.index 0 : index + // CHECK-NEXT: %[[I1:.*]] = linalg.index 1 : index + // CHECK-NEXT: %[[I2:.*]] = linalg.index 2 : index + // CHECK-NEXT: %[[I3:.*]] = linalg.index 3 : index + // CHECK-NEXT: %[[I4:.*]] = linalg.index 4 : index + // CHECK-NEXT: %[[R3:.*]] = arith.subi %[[CST1]], %[[I3]] : index + // CHECK-NEXT: %[[R4:.*]] = arith.subi %[[CST1]], %[[I4]] : index + // CHECK-NEXT: %[[EX:.*]] = tensor.extract %[[W_EXP]][%[[I0]], %[[I1]], %[[I2]], %[[R3]], %[[R4]]] : tensor<4x4x32x2x2xf32> + // CHECK-NEXT: linalg.yield %[[EX]] : f32 + // CHECK-NEXT: } -> tensor<4x4x32x2x2xf32> + // CHECK: %[[SLICE_EMPTY:.*]] = tensor.empty() : tensor<2x4x4x66x66xf32> + // CHECK: %[[SLICE_FILLED:.*]] = linalg.fill ins(%cst : f32) outs(%[[SLICE_EMPTY]] : tensor<2x4x4x66x66xf32>) -> tensor<2x4x4x66x66xf32> + // CHECK: %[[SLICE:.*]] = tensor.insert_slice %[[T0_EXP]] into %[[SLICE_FILLED]][0, 0, 0, 0, 0] [2, 4, 4, 33, 33] [1, 1, 1, 2, 2] : tensor<2x4x4x33x33xf32> into tensor<2x4x4x66x66xf32> + // CHECK: %[[OUT_EMPTY:.*]] = tensor.empty() : tensor<2x4x32x64x64xf32> + // CHECK: %[[OUT_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[OUT_EMPTY]] : tensor<2x4x32x64x64xf32>) -> tensor<2x4x32x64x64xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {{.*}} ins(%[[SLICE]], %[[W_REV]] : tensor<2x4x4x66x66xf32>, tensor<4x4x32x2x2xf32>) outs(%[[OUT_FILLED]] : tensor<2x4x32x64x64xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[ACC:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[ACC]] : f32 + // CHECK-NEXT: } -> tensor<2x4x32x64x64xf32> + // CHECK: %[[CONV_COLLAPSED:.*]] = tensor.collapse_shape %[[CONV]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} : tensor<2x4x32x64x64xf32> into tensor<2x128x64x64xf32> + // CHECK: %[[IGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV_COLLAPSED]] : tensor<2x128x64x64xf32> -> !torch.vtensor<[2,128,64,64],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST0]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = 
torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[IGRAD]], %[[BIAS]] : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %true, %false, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,32,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32> + return %result0, %result2 : !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_weights_1x1s_0x0p_1x1d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,63,63],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,128,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_weights_1x1s_0x0p_1x1d_1g(%arg0: !torch.vtensor<[2,16,63,63],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],f32> -> tensor<2x128x64x64xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,63,63],f32> -> tensor<2x16x63x63xf32> + // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<16x128x2x2xf32> + // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EMPTY]] : tensor<16x128x2x2xf32>) -> tensor<16x128x2x2xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2 + d5, d3 + d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d0, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[T1]], %[[T0]] : tensor<2x128x64x64xf32>, tensor<2x16x63x63xf32>) outs(%[[OUT0_FILLED]] : tensor<16x128x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32 + // CHECK-NEXT: } -> tensor<16x128x2x2xf32> + // CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<16x128x2x2xf32> -> !torch.vtensor<[16,128,2,2],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: 
%[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x63x63xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,63,63],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.none, !torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32> + return %result1, %result2 : !torch.vtensor<[16,128,2,2],f32>, !torch.vtensor<[16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_1g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,32,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[32,128,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32>) { +func.func @convolution_backward_weights_2x2s_2x2p_2x2d_1g(%arg0: !torch.vtensor<[2,32,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[32,128,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32>) { + // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],f32> -> tensor<2x128x64x64xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,32,33,33],f32> -> tensor<2x32x33x33xf32> + // CHECK: %[[PAD:.*]] = tensor.pad %[[T1]] low[0, 0, 2, 2] high[0, 0, 2, 2] + // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): + // CHECK-NEXT: tensor.yield %[[CST]] : f32 + // CHECK-NEXT: } : tensor<2x128x64x64xf32> to tensor<2x128x68x68xf32> + // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<32x128x2x2xf32> + // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EMPTY]] : tensor<32x128x2x2xf32>) -> tensor<32x128x2x2xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1, d2 * 2 + d5 * 2, d3 * 2 + d6 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d0, d5, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", 
"reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0]] : tensor<2x128x68x68xf32>, tensor<2x32x33x33xf32>) outs(%[[OUT0_FILLED]] : tensor<32x128x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32 + // CHECK-NEXT: } -> tensor<32x128x2x2xf32> + // CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[CONV]] : tensor<32x128x2x2xf32> -> !torch.vtensor<[32,128,2,2],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<32xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SUM_EMPTY]] : tensor<32xf32>) -> tensor<32xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x32x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<32xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<32xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<32xf32> -> !torch.vtensor<[32],f32> + // CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,32,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[32,128,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.none, !torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32> + return %result1, %result2 : !torch.vtensor<[32,128,2,2],f32>, !torch.vtensor<[32],f32> +} + +// ----- + +// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],f32>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],f32>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32>) { +func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g(%arg0: !torch.vtensor<[2,16,33,33],f32>, %arg1: !torch.vtensor<[2,128,64,64],f32>, %arg2: !torch.vtensor<[16,32,2,2],f32>, %arg3: !torch.vtensor<[],f32>) -> (!torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32>) { + // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,128,64,64],f32> -> tensor<2x128x64x64xf32> + // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,16,33,33],f32> -> tensor<2x16x33x33xf32> + // CHECK: %[[T0_EXP:.*]] = 
tensor.expand_shape %[[T0]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 4, 33, 33] : tensor<2x16x33x33xf32> into tensor<2x4x4x33x33xf32> + // CHECK: %[[T1_EXP:.*]] = tensor.expand_shape %[[T1]] {{\[\[0\], \[1, 2\], \[3\], \[4\]\]}} output_shape [2, 4, 32, 64, 64] : tensor<2x128x64x64xf32> into tensor<2x4x32x64x64xf32> + // CHECK: %[[PAD:.*]] = tensor.pad %[[T1_EXP]] low[0, 0, 0, 2, 2] high[0, 0, 0, 2, 2] + // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): + // CHECK-NEXT: tensor.yield %[[CST]] : f32 + // CHECK-NEXT: } : tensor<2x4x32x64x64xf32> to tensor<2x4x32x68x68xf32> + // CHECK: %[[OUT0_EMPTY:.*]] = tensor.empty() : tensor<4x4x32x2x2xf32> + // CHECK: %[[OUT0_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[OUT0_EMPTY]] : tensor<4x4x32x2x2xf32>) -> tensor<4x4x32x2x2xf32> + // CHECK: %[[CONV:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d2, d3 * 2 + d6 * 2, d4 * 2 + d7 * 2)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d0, d1, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PAD]], %[[T0_EXP]] : tensor<2x4x32x68x68xf32>, tensor<2x4x4x33x33xf32>) outs(%[[OUT0_FILLED]] : tensor<4x4x32x2x2xf32>) { + // CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): + // CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[IN]], %[[IN1]] : f32 + // CHECK-NEXT: %[[CONV_RES:.*]] = arith.addf %[[MUL]], %[[OUT]] : f32 + // CHECK-NEXT: linalg.yield %[[CONV_RES]] : f32 + // CHECK-NEXT: } -> tensor<4x4x32x2x2xf32> + // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CONV]] {{\[\[0, 1\], \[2\], \[3\], \[4\]\]}} : tensor<4x4x32x2x2xf32> into tensor<16x32x2x2xf32> + // CHECK: %[[WGRAD:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor<16x32x2x2xf32> -> !torch.vtensor<[16,32,2,2],f32> + // CHECK: %[[SUM_EMPTY:.*]] = tensor.empty() : tensor<16xf32> + // CHECK: %[[SUM_FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[SUM_EMPTY]] : tensor<16xf32>) -> tensor<16xf32> + // CHECK: %[[SUM_GEN:.*]] = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1)>], iterator_types = ["reduction", "parallel", "reduction", "reduction"]} ins(%[[T0]] : tensor<2x16x33x33xf32>) outs(%[[SUM_FILLED]] : tensor<16xf32>) { + // CHECK-NEXT: ^bb0(%[[IN_B:.*]]: f32, %[[ACC_B:.*]]: f32): + // CHECK-NEXT: %[[B_RES:.*]] = arith.addf %[[IN_B]], %[[ACC_B]] : f32 + // CHECK-NEXT: linalg.yield %[[B_RES]] : f32 + // CHECK-NEXT: } -> tensor<16xf32> + // CHECK: %[[BIAS:.*]] = torch_c.from_builtin_tensor %[[SUM_GEN]] : tensor<16xf32> -> !torch.vtensor<[16],f32> + // CHECK: return %[[WGRAD]], %[[BIAS]] : !torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32> + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, 
%1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,32,2,2],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int, !torch.list -> !torch.none, !torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32> + return %result1, %result2 : !torch.vtensor<[16,32,2,2],f32>, !torch.vtensor<[16],f32> +} + +// -----