[TorchToLinalg] Direct lowering from Torch to Linalg for torch.aten.convolution_backward #4384
Conversation
Force-pushed: 4f1cb20 → 8e2b616
a-sidorova: @zjgarvey hey! May I ask you to take a look when you're available? Thank you in advance for the review.
zjgarvey left a comment:
Nice! This is an excellent start.
We need to keep the existing decomposition for other backends. I have a few other comments for you to look at, but that's the biggest blocker right now.
```cpp
    rewriter, loc, op.getResultTypes()[2], gradOutput, gradIntList,
    cstFalse, cstNone);
}
```
zjgarvey:
We should keep the decomposition; e.g., TOSA and StableHLO still rely on this pattern. The purpose of the `backend_legal_ops` option in `torch-decompose-complex-ops` is specifically to prevent selected decomposition patterns from being applied.
a-sidorova:
Got it, thank you for the explanation - I didn't know about this mechanism on the backend side.
Restored this pattern and the lit test in decompose-complex-ops.mlir.
```cpp
SmallVector<int64_t> weightFlipDims;
weightFlipDims.reserve(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
  weightFlipDims.push_back(spatialStartDimIdx + i);
```
zjgarvey:
If the weight shape is static at index i, and the dim size is 1 there, don't add to the flip. We definitely see a lot of 1x1 filter convs and the noop flip doesn't get folded easily IIRC.
a-sidorova:
Agree. I also noticed that I forgot to add a condition for numSpatialDims == 1 to not insert the flip.
So now we flip kernel dims only when numSpatialDims > 1 and the kernel is not 1x1, as shown in the sketch below. Also added a lit test.
Thanks!
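For illustration, a minimal sketch of the per-dimension guard zjgarvey describes (assuming `weightTy` is the weight's `RankedTensorType`; the exact condition in the PR may differ):

```cpp
// Collect the spatial dims to flip, skipping dims whose size is statically
// known to be 1: flipping those is a no-op (e.g. 1x1 filters) that does not
// get folded easily.
SmallVector<int64_t> weightFlipDims;
weightFlipDims.reserve(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i) {
  int64_t dim = spatialStartDimIdx + i;
  if (!weightTy.isDynamicDim(dim) && weightTy.getDimSize(dim) == 1)
    continue; // statically size-1 dim: nothing to flip
  weightFlipDims.push_back(dim);
}
```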
```cpp
    createZeroInitTensor(rewriter, loc, sizes, gradOutputDTy);
gradOutputSliced = tensor::InsertSliceOp::create(
    rewriter, loc,
    torch_to_linalg::removeSizeInformation(rewriter, loc,
```
zjgarvey:
Why remove the size info?
zjgarvey:
Also, maybe "sliced" is a misleading name. "Scattered"? Or something generic like "Modified", since you are just padding when stride == 1.
a-sidorova:
> Why remove the size info?
If I leave only gradOutputExpanded here, I see the following error in tests:
lit test:
<stdin>:155:34: error: unexpected error: expected type to be 'tensor<?x?x?x?xf32>' or a rank-reduced version. (size mismatch)
%result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int1, %3 : !torch.vtensor<[2,16,33,33],f32>, !torch.vtensor<[2,128,64,64],f32>, !torch.vtensor<[16,128,2,2],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.vtensor<[2,128,64,64],f32>, !torch.none, !torch.vtensor<[16],f32>
func test:
error: expected type to be 'tensor<?x?x?x?x?xf32>' or a rank-reduced version. (size mismatch)
note: see current operation: %104 = "tensor.insert_slice"(%15, %103, %73, %73, %73, %90, %98, %76, %78, %80, %82, %84, %74, %74, %74, %92, %100) <{operandSegmentSizes = array<i32: 1, 1, 5, 5, 5>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>, static_strides = array<i64: -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808>}> : (tensor<2x4x4x32x32xf32>, tensor<2x4x4x68x68xf32>, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index) -> tensor<2x4x4x68x68xf32>
To solve this problem I decided to borrow the logic from ConvertAtenConvolutionOp (torch-mlir/lib/Conversion/TorchToLinalg/Linear.cpp, lines 1647 to 1650 at 244f4b6):
```cpp
auto paddedTensor = tensor::InsertSliceOp::create(
    rewriter, loc,
    torch_to_linalg::removeSizeInformation(rewriter, loc, input),
    initMaxTensor, insertSliceOffsets, inputSizes, strideIndexValues);
```
As far as I understand, insertSliceOp expects a dynamic tensor (judging from the error message), but I was passing a static tensor, which caused this error. To pass a dynamic tensor to insertSliceOp, I remove the size info here.
Please let me know if there is a better way to handle this.
> Also maybe "sliced" is a misleading name. Scattered? Or something generic like "Modified" since you are just padding when stride == 1.
Renamed to "Modified" to align the two if-else branches. Thanks!
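For context, a sketch of what `torch_to_linalg::removeSizeInformation` does (based on its use in TorchToLinalg; the real utility may differ in detail): it casts the tensor to the same rank with all dims dynamic, so `tensor.insert_slice` verification does not reject a statically-shaped source whose sizes it cannot match against the dynamic slice parameters.

```cpp
// Cast a ranked tensor to the same rank with every dim dynamic, erasing
// static size information from the type (the underlying data is unchanged).
static Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor) {
  auto tensorType = cast<RankedTensorType>(tensor.getType());
  SmallVector<int64_t> unknownSizes(tensorType.getRank(),
                                    ShapedType::kDynamic);
  return b.create<tensor::CastOp>(loc, tensorType.clone(unknownSizes),
                                  tensor);
}
```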
```cpp
    createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
SmallVector<ReassociationIndices> gradWeightCollapseIndices;
if (isGroupedConvBwd) {
  auto gradWeightInitExpanded = expandGroups(gradWeightInit, 0);
```
zjgarvey:
Should the init just be made on the expanded shape here (instead of expanding the init)? This probably gets folded, but I think it would be better to generate simpler IR when possible.
a-sidorova:
> This probably gets folded, but I think it would be better to generate simpler IR when possible.
Good to know! I thought it would be folded by a later canonicalization pass, but let's do it ourselves here.
Thanks!
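A sketch of the agreed simplification (the name `gradWeightExpandedSizes` is hypothetical; it stands for the grouped shape `{G, F/G, C/G, k...}` however the PR computes it):

```cpp
// Build the zero-filled init directly in the grouped (expanded) shape when
// needed, instead of creating it flat and inserting a tensor.expand_shape
// that later passes would have to fold away.
Value gradWeightInit =
    isGroupedConvBwd
        ? createZeroInitTensor(rewriter, loc, gradWeightExpandedSizes,
                               weightDTy) // hypothetical expanded sizes
        : createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
```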
```cpp
// `c` is the input channel dimension, `f` is the output channel
// dimension, `o` is the input spatial dimension, `k` is the kernel
// dimension, `d0` is dilation. `x` is the input tensor, `dLdy` is the
// gradient of the output tensor. `dLdx` is the data-gradient tensor.
```
zjgarvey:
Might be good to mention that dLdy is the stride/padding modified grad output tensor here.
zjgarvey:
And that w is flipped along spatial dims.
a-sidorova:
Added, thanks
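For reference, a formula consistent with the 1-D indexing maps shown further below, incorporating both remarks (a sketch of what the updated comment presumably conveys, not a quote of it):

```cpp
// dLdx[n, c, o] = sum(dLdy_mod[n, f, d0 * k + o] * w_flipped[f, c, k]
//                     for f in range(out_channels) for k in range(kernel_size))
// where `dLdy_mod` is the stride/padding-modified gradient of the output
// and `w_flipped` is the weight flipped along its spatial dims.
```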
```cpp
}

static linalg::GenericOp
createGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, Value out,
```
zjgarvey:
There might be a util for this already, like `createReductionGeneric` or something. In any case, might be good to call this something a little more specific (pun intended).
a-sidorova:
I'd prefer to keep our own impl here for better control and understanding.
Renamed to `createConvAsGenericOp`.
Thanks!
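A minimal sketch of what such a builder can look like (the multiply-accumulate body and exact signature are assumptions; the PR's `createConvAsGenericOp` may differ, e.g. in type handling):

```cpp
// Build a linalg.generic that multiplies the two ins element-wise under the
// given indexing maps and accumulates into the out tensor.
static linalg::GenericOp
createConvAsGenericOp(OpBuilder &b, Location loc, Value in0, Value in1,
                      Value out, ArrayRef<AffineMap> indexingMaps,
                      ArrayRef<utils::IteratorType> iteratorTypes) {
  return b.create<linalg::GenericOp>(
      loc, TypeRange{out.getType()}, ValueRange{in0, in1}, ValueRange{out},
      indexingMaps, iteratorTypes,
      [](OpBuilder &nb, Location nloc, ValueRange args) {
        // args = {in0 element, in1 element, current accumulator element}
        Value mul = nb.create<arith::MulFOp>(nloc, args[0], args[1]);
        Value acc = nb.create<arith::AddFOp>(nloc, mul, args[2]);
        nb.create<linalg::YieldOp>(nloc, acc);
      });
}
```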
```cpp
  if (!isGrouped) {
    if (numSpatialDims == 1) {
      AffineExpr n, c, o, f, k;
      bindDims(context, n, c, o, f, k);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> goExprs = {n, f, d0 * k + o};
      SmallVector<AffineExpr> weiExprs = {f, c, k};
      SmallVector<AffineExpr> outExprs = {n, c, o};
      indexingMaps = {AffineMap::get(5, 0, goExprs, context),
                      AffineMap::get(5, 0, weiExprs, context),
                      AffineMap::get(5, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr n, c, oh, ow, f, kh, kw;
      bindDims(context, n, c, oh, ow, f, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> goExprs = {n, f, d0 * kh + oh, d1 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {f, c, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, c, oh, ow};
      indexingMaps = {AffineMap::get(7, 0, goExprs, context),
                      AffineMap::get(7, 0, weiExprs, context),
                      AffineMap::get(7, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction,
                       IT::reduction};
    } else {
      AffineExpr n, c, od, oh, ow, f, kd, kh, kw;
      bindDims(context, n, c, od, oh, ow, f, kd, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> goExprs = {n, f, d0 * kd + od, d1 * kh + oh,
                                         d2 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {f, c, kd, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, c, od, oh, ow};
      indexingMaps = {AffineMap::get(9, 0, goExprs, context),
                      AffineMap::get(9, 0, weiExprs, context),
                      AffineMap::get(9, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction, IT::reduction};
    }
  } else {
    if (numSpatialDims == 1) {
      AffineExpr n, g, cg, o, fg, k;
      bindDims(context, n, g, cg, o, fg, k);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * k + o};
      SmallVector<AffineExpr> weiExprs = {g, fg, cg, k};
      SmallVector<AffineExpr> outExprs = {n, g, cg, o};
      indexingMaps = {AffineMap::get(6, 0, goExprs, context),
                      AffineMap::get(6, 0, weiExprs, context),
                      AffineMap::get(6, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr n, g, cg, oh, ow, fg, kh, kw;
      bindDims(context, n, g, cg, oh, ow, fg, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * kh + oh,
                                         d1 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {g, fg, cg, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, g, cg, oh, ow};
      indexingMaps = {AffineMap::get(8, 0, goExprs, context),
                      AffineMap::get(8, 0, weiExprs, context),
                      AffineMap::get(8, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction};
    } else {
      AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw;
      bindDims(context, n, g, cg, od, oh, ow, fg, kd, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> goExprs = {
          n, g, fg, d0 * kd + od, d1 * kh + oh, d2 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {g, fg, cg, kd, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, g, cg, od, oh, ow};
      indexingMaps = {AffineMap::get(10, 0, goExprs, context),
                      AffineMap::get(10, 0, weiExprs, context),
                      AffineMap::get(10, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction, IT::reduction,
                       IT::reduction};
    }
  }
}

static void initIndexingMapsAndIteratorTypesForWeightBwd(
    OpBuilder &rewriter, MLIRContext *context, bool isGrouped,
    int numSpatialDims, const SmallVector<int64_t> &strideInts,
    const SmallVector<int64_t> &dilationInts,
    SmallVector<AffineMap> &indexingMaps, SmallVector<IT> &iteratorTypes) {
  // To calculate convolution backward-weight, we use generic operation.
  // The generic operation is a generalization of the convolution operation
  // that can handle any number of spatial dimensions.
  // The generic operation is defined as follows:
  // ```
  // dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o]
  //   for n in range(batch_size) for o in range(output_spatial_dims))
  // ```
  // where `n` is the batch dimension, `g` is the group dimension,
  // `c` is the input channel dimension, `f` is the output channel
  // dimension, `o` is the output spatial dimension, `k` is the kernel
  // dimension, `d0` is dilation and `s0` is stride. `x` is the input
  // tensor, `dLdy` is the gradient of the output tensor. `dLdw` is the
  // weight-gradient tensor.
  if (!isGrouped) {
    if (numSpatialDims == 1) {
      AffineExpr f, c, k, n, o;
      bindDims(context, f, c, k, n, o);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> inExprs = {n, c, d0 * k + s0 * o};
      SmallVector<AffineExpr> goExprs = {n, f, o};
      SmallVector<AffineExpr> outExprs = {f, c, k};
      indexingMaps = {AffineMap::get(5, 0, inExprs, context),
                      AffineMap::get(5, 0, goExprs, context),
                      AffineMap::get(5, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr f, c, kh, kw, n, oh, ow;
      bindDims(context, f, c, kh, kw, n, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> inExprs = {n, c, d0 * kh + s0 * oh,
                                         d1 * kw + s1 * ow};
      SmallVector<AffineExpr> goExprs = {n, f, oh, ow};
      SmallVector<AffineExpr> outExprs = {f, c, kh, kw};
      indexingMaps = {AffineMap::get(7, 0, inExprs, context),
                      AffineMap::get(7, 0, goExprs, context),
                      AffineMap::get(7, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction,
                       IT::reduction};
    } else {
      AffineExpr f, c, kd, kh, kw, n, od, oh, ow;
      bindDims(context, f, c, kd, kh, kw, n, od, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> inExprs = {
          n, c, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
      SmallVector<AffineExpr> goExprs = {n, f, od, oh, ow};
      SmallVector<AffineExpr> outExprs = {f, c, kd, kh, kw};
      indexingMaps = {AffineMap::get(9, 0, inExprs, context),
                      AffineMap::get(9, 0, goExprs, context),
                      AffineMap::get(9, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction, IT::reduction};
    }
  } else {
    if (numSpatialDims == 1) {
      AffineExpr g, fg, cg, k, n, o;
      bindDims(context, g, fg, cg, k, n, o);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * k + s0 * o};
      SmallVector<AffineExpr> goExprs = {n, g, fg, o};
      SmallVector<AffineExpr> outExprs = {g, fg, cg, k};
      indexingMaps = {AffineMap::get(6, 0, inExprs, context),
                      AffineMap::get(6, 0, goExprs, context),
                      AffineMap::get(6, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr g, fg, cg, kh, kw, n, oh, ow;
      bindDims(context, g, fg, cg, kh, kw, n, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * kh + s0 * oh,
                                         d1 * kw + s1 * ow};
      SmallVector<AffineExpr> goExprs = {n, g, fg, oh, ow};
      SmallVector<AffineExpr> outExprs = {g, fg, cg, kh, kw};
      indexingMaps = {AffineMap::get(8, 0, inExprs, context),
                      AffineMap::get(8, 0, goExprs, context),
                      AffineMap::get(8, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction};
    } else {
      AffineExpr g, fg, cg, kd, kh, kw, n, od, oh, ow;
      bindDims(context, g, fg, cg, kd, kh, kw, n, od, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> inExprs = {
          n, g, cg, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
      SmallVector<AffineExpr> goExprs = {n, g, fg, od, oh, ow};
      SmallVector<AffineExpr> outExprs = {g, fg, cg, kd, kh, kw};
      indexingMaps = {AffineMap::get(10, 0, inExprs, context),
                      AffineMap::get(10, 0, goExprs, context),
                      AffineMap::get(10, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction, IT::reduction,
                       IT::reduction};
    }
  }
}
```
zjgarvey:
There must be a better way.
zjgarvey:
E.g., you could make the AffineExprs for stride, dilation, spatial dims, etc. `SmallVector<AffineExpr>`s. I don't even think there need to be conditionals on anything other than:

```cpp
SmallVector<AffineExpr> lhsExprs = isGrouped ? {n, g, c} : {n, c};
// loop over spatial dims and add expressions...
```

Everything else can be like:

```cpp
int64_t numIterators = 3; // batch, parallel channel, reduction channel
numIterators += static_cast<int64_t>(isGrouped);
numIterators += numSpatialDims * 2; // parallel spatial dims, reduction spatial dims
indexingMaps = {
    AffineMap::get(numIterators, lhsExprs, context),
    AffineMap::get(numIterators, rhsExprs, context),
    AffineMap::get(numIterators, outExprs, context)
};
```

a-sidorova:
Great idea, thanks for that. Implemented.
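For concreteness, a hedged sketch of what the loop-based construction can look like for the backward-weight maps (dimension ordering and names are illustrative, not the merged code):

```cpp
// Iterator layout: [g?, f, c, k_0..k_{S-1} | n, o_0..o_{S-1}], parallel dims
// first and reduction dims last, with S = numSpatialDims.
int64_t numIterators =
    3 + static_cast<int64_t>(isGrouped) + numSpatialDims * 2;
SmallVector<AffineExpr> d(numIterators);
for (int64_t i = 0; i < numIterators; ++i)
  d[i] = rewriter.getAffineDimExpr(i);

int64_t base = isGrouped ? 1 : 0;
AffineExpr f = d[base], c = d[base + 1];
AffineExpr n = d[base + 2 + numSpatialDims];

SmallVector<AffineExpr> inExprs = {n}, goExprs = {n}, outExprs;
if (isGrouped) {
  inExprs.push_back(d[0]);
  goExprs.push_back(d[0]);
  outExprs.push_back(d[0]);
}
inExprs.push_back(c);
goExprs.push_back(f);
outExprs.push_back(f);
outExprs.push_back(c);
for (int64_t i = 0; i < numSpatialDims; ++i) {
  AffineExpr k = d[base + 2 + i];                  // kernel spatial dim
  AffineExpr o = d[base + 3 + numSpatialDims + i]; // output spatial dim
  // x is indexed at d_i * k_i + s_i * o_i, matching the doc comment above.
  inExprs.push_back(k * dilationInts[i] + o * strideInts[i]);
  goExprs.push_back(o);
  outExprs.push_back(k);
}
indexingMaps = {AffineMap::get(numIterators, 0, inExprs, context),
                AffineMap::get(numIterators, 0, goExprs, context),
                AffineMap::get(numIterators, 0, outExprs, context)};
iteratorTypes.assign(base + 2 + numSpatialDims, IT::parallel);
iteratorTypes.append(1 + numSpatialDims, IT::reduction);
```

For the 1-D non-grouped case this reproduces exactly the maps from the if-else ladder above, and it generalizes to every spatial rank without duplication.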
a-sidorova left a comment:
@zjgarvey thank you for the review! I have applied your comments in the latest commit. Could you please take a look at the changes one more time?
Force-pushed: b20062d → 30dac4b
Force-pushed: 30dac4b → 31eac08
Description:
- Implemented a direct lowering of `torch.aten.convolution_backward` from Torch to Linalg and enabled it by default. The lowering generates `linalg.generic` ops instead of `linalg.conv_<>` ops for better lowering.
- Kept `DecomposeAtenConvolutionBackwardOp` in `Torch/Transforms/DecomposeComplexOps.cpp` so that other backends can still use the decomposition.
- Added lit tests in `convolution_backward.mlir`. Also added more test cases for better test coverage.

Issue: