diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index b19d3f949f03..a04e4da13a83 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1095,16 +1095,22 @@ class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
 };
 } // namespace
 
-// Decompose aten.pixel_shuffle into: aten.permute and aten.reshape operations.
+// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
+// prims.collapse operations.
 //
-// If input is a tensor of shape (*leading_dims, C*r*r, H, W), where
-// leading_dims is of size N, then
+// If input is a tensor of shape
+//     (*leading_dims, C*r*r, H, W),
+//
+// where leading_dims is of size N, then
 //     X = pixel_shuffle(input, upscale_factor)
 //
 // gets replaced with
-//     A = input.reshape(*leading_dims, C, r, r, H, W)
-//     B = A.permute(0, ..., N, N+3, N+1, N+4, N+2)
-//     X = B.reshape(*leading_dims, C, r*H, r*W)
+//     X = input.split_dim(...)  # shape (*leading_dims, C, r*r, H, W)
+//     X = X.split_dim(...)      # shape (*leading_dims, C, r, r, H, W)
+//     X = X.permute(0, ..., N, N+3, N+1, N+4, N+2)
+//                               # shape (*leading_dims, C, H, r, W, r)
+//     X = X.collapse(...)       # shape (*leading_dims, C, H, r, r*W)
+//     X = X.collapse(...)       # shape (*leading_dims, C, r*H, r*W)
 //
 // 'r' above is referred to as the 'upscale factor' or just 'factor' below.
 namespace {
@@ -1115,7 +1121,6 @@ class DecomposeAtenPixelShuffleOp
   LogicalResult matchAndRewrite(AtenPixelShuffleOp op,
                                 PatternRewriter &rewriter) const override {
 
-    Location loc = op.getLoc();
     Value inValue = op.getSelf();
     auto inType = inValue.getType().cast<BaseTensorType>();
 
@@ -1127,22 +1132,6 @@ class DecomposeAtenPixelShuffleOp
     auto inShape = maybeSizes.value();
     auto inRank = inShape.size();
 
-    // TODO support dynamic shapes, probably by lowering pixel_shuffle to linalg
-    // directly. Pixel shuffle does a reshape that is hard to recover
-    // through pure torch (view) ops, especially in dynamic cases.
-    //
-    // See: https://github.com/llvm/torch-mlir/issues/2559
-    //
-    // For now, we just fail the decomposition here so that a sensible error is
-    // provided:
-    for (auto dimSize : inShape) {
-      if (dimSize == kUnknownSize) {
-        return rewriter.notifyMatchFailure(
-            op, "Currently we only decompose pixel_shuffle if the input tensor "
-                "is statically shaped");
-      }
-    }
-
     // The input tensor must have at least 3 dimensions: (1) the channel
     // dimension which gets smaller by 'factor*factor', (2) the H channel which
     // gets larger by 'factor' and (3) the W channel which gets larger by
@@ -1152,6 +1141,29 @@ class DecomposeAtenPixelShuffleOp
       return rewriter.notifyMatchFailure(
           op, "Expected input tensor to have rank greater than 2.");
 
+    const auto inOptionalDType = inType.getOptionalDtype();
+
+    auto getTypeFromShape = [inOptionalDType](auto &&vals) {
+      // Get a vector of integers from a vector of Values.
+      auto getIntShape = [](auto &&vals) {
+        SmallVector<int64_t> shape;
+        shape.reserve(vals.size());
+        for (auto v : vals) {
+          int64_t cst_val;
+          if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
+            shape.push_back(cst_val);
+          } else {
+            shape.push_back(kUnknownSize);
+          }
+        }
+        return shape;
+      };
+
+      const auto intShape = getIntShape(vals);
+      return ValueTensorType::get(vals[0].getContext(),
+                                  llvm::ArrayRef(intShape), inOptionalDType);
+    };
+
     auto nLeadingDims = inRank - 3;
 
     // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
@@ -1169,106 +1181,94 @@ class DecomposeAtenPixelShuffleOp
     auto factor = op.getUpscaleFactor();
 
-    Value factorSquared = rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
+    Value outC = rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, factorSquared);
 
     Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, factor);
     Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, factor);
 
-    // Shape of 'A' in the comment at the top
-    SmallVector<Value> prePermuteShape;
-    prePermuteShape.reserve(nLeadingDims + 5);
-
-    // Shape of 'B' in the comment at the top.
-    SmallVector<Value> postPermuteShape;
-    postPermuteShape.reserve(nLeadingDims + 5);
-
-    SmallVector<Value> outShape;
-    outShape.reserve(nLeadingDims + 3);
-
-    SmallVector<Value> permutation;
-    permutation.reserve(nLeadingDims + 5);
+    SmallVector<Value> dimensionConstants;
+    dimensionConstants.reserve(inRank + 2);
+    for (unsigned i = 0; i < inRank + 2; ++i) {
+      dimensionConstants.push_back(
+          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
+    }
 
+    SmallVector<Value> leadingDims;
+    leadingDims.reserve(nLeadingDims);
     for (unsigned i = 0; i < nLeadingDims; ++i) {
-      auto dimensionAttr = rewriter.getI64IntegerAttr(i);
-      Value dimensionValue = rewriter.create<Torch::ConstantIntOp>(loc, dimensionAttr);
-      Value leadingDimSize =
-          rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dimensionValue);
-      prePermuteShape.push_back(leadingDimSize);
-      postPermuteShape.push_back(leadingDimSize);
-      outShape.push_back(leadingDimSize);
-      permutation.push_back(dimensionValue);
-
+      Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
+          loc, inValue, dimensionConstants[i]);
+      leadingDims.push_back(leadingDimSize);
     }
 
-    const auto inOptionalDType = inType.getOptionalDtype();
-
-    auto getTypeFromShape = [inOptionalDType](auto &&vals) {
-      // Get a vector of integers from a vector of Values.
-      auto getIntShape = [](auto &&vals) {
-        SmallVector<int64_t> shape;
-        shape.reserve(vals.size());
-        for (auto v : vals) {
-          int64_t cst_val;
-          if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
-            shape.push_back(cst_val);
-          } else {
-            shape.push_back(kUnknownSize);
-          }
-        }
-        return shape;
-      };
+    SmallVector<Value> partiallyExpandedShape = leadingDims;
+    partiallyExpandedShape.append({outC, factorSquared, inH, inW});
 
-      const auto intShape = getIntShape(vals);
-      return ValueTensorType::get(vals[0].getContext(),
-                                  llvm::ArrayRef(intShape), inOptionalDType);
-    };
+    SmallVector<Value> prePermuteShape = leadingDims;
+    prePermuteShape.append({outC, factor, factor, inH, inW});
 
-    prePermuteShape.insert(prePermuteShape.end(),
-                           {outC, factor, factor, inH, inW});
+    SmallVector<Value> postPermuteShape = leadingDims;
+    postPermuteShape.append({outC, inH, factor, inW, factor});
 
-    postPermuteShape.insert(postPermuteShape.end(),
-                            {outC, inH, factor, inW, factor});
+    SmallVector<Value> partiallyCollapsedShape = leadingDims;
+    partiallyCollapsedShape.append({outC, inH, factor, outW});
 
-    outShape.insert(outShape.end(), {outC, outH, outW});
+    SmallVector<Value> outShape = leadingDims;
+    outShape.append({outC, outH, outW});
 
+    SmallVector<Value> permutation{dimensionConstants.begin(),
+                                   dimensionConstants.begin() + nLeadingDims};
     SmallVector<uint64_t> permutationTail{0, 3, 1, 4, 2};
     for (uint64_t d : permutationTail) {
-      permutation.push_back(rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(nLeadingDims + d)));
+      permutation.push_back(dimensionConstants[nLeadingDims + d]);
     }
 
-    auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
-
-    Value shapeA =
-        rewriter.create<PrimListConstructOp>(loc, listType, prePermuteShape);
-
-    Value A = rewriter.create<AtenReshapeOp>(
-        loc, getTypeFromShape(prePermuteShape), inValue, shapeA);
-
     Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
         loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
         permutation);
 
-    Value B = rewriter.create<AtenPermuteOp>(
-        loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder);
+    // Split input channel inC -> (inC, factorSquared)
+    auto partiallyExpanded =
+        rewriter
+            .create<PrimsSplitDimOp>(
+                loc, getTypeFromShape(partiallyExpandedShape), inValue,
+                dimensionConstants[nLeadingDims], outC)
+            .getResult();
+
+    // Split new dimension factorSquared -> (factor, factor)
+    auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
+        loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
+        dimensionConstants[nLeadingDims + 1], factor);
+
+    // Perform the permutation
+    auto permuted =
+        rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
+                                       fullyExpanded, permuteDimsOrder);
 
-    Value outShapeList =
-        rewriter.create<PrimListConstructOp>(loc, listType, outShape);
+    // Collapse the final 2 dimensions
+    auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
+        loc, getTypeFromShape(partiallyCollapsedShape), permuted,
+        dimensionConstants[nLeadingDims + 3],
+        dimensionConstants[nLeadingDims + 4]);
+
+    // Collapse back to original rank
+    rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
+        op, op.getType(), partiallyCollapsed,
+        dimensionConstants[nLeadingDims + 1],
+        dimensionConstants[nLeadingDims + 2]);
 
-    rewriter.replaceOpWithNewOp<AtenReshapeOp>(op, op.getType(), B,
-                                               outShapeList);
     return success();
   }
 };
 } // namespace
 
 // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
-static Value
-getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) {
+static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
+                             Value input) {
   BaseTensorType inputType = input.getType().cast<BaseTensorType>();
 
   Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
@@ -1815,7 +1815,7 @@ class DecomposeAtenUnflattenIntOp
     auto inputTensorType = self.getType().cast<BaseTensorType>();
     if (!inputTensorType || !inputTensorType.hasSizes()) {
       return rewriter.notifyMatchFailure(op,
-                                          "Expected input type having sizes");
+                                         "Expected input type having sizes");
     }
     ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
 
@@ -1851,7 +1851,7 @@ class DecomposeAtenUnflattenIntOp
       Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dimValue);
       if (i == dimInt) {
-        int64_t inferredSizeInt = inputShape[i]; 
+        int64_t inferredSizeInt = inputShape[i];
         int64_t inferredDim;
         for (unsigned j = 0; j < sizesInts.size(); ++j) {
           if (sizesInts[j] == -1) {
@@ -1865,11 +1865,9 @@ class DecomposeAtenUnflattenIntOp
           }
         }
         if (inferred) {
-          Value inferredSize =
-              rewriter.create<Torch::ConstantIntOp>(
+          Value inferredSize = rewriter.create<Torch::ConstantIntOp>(
               loc, rewriter.getI64IntegerAttr(inferredSizeInt));
-          newSizes.insert(
-              newSizes.begin() + inferredDim + i, inferredSize);
+          newSizes.insert(newSizes.begin() + inferredDim + i, inferredSize);
         }
       } else {
         newSizes.push_back(dimSize);
@@ -4095,7 +4093,7 @@ class DecomposeAtenClampMaxOp : public OpRewritePattern<AtenClampMaxOp> {
 } // namespace
 
 namespace {
-class DecomposeAtenCosineSimilarityOp 
+class DecomposeAtenCosineSimilarityOp
     : public OpRewritePattern<AtenCosineSimilarityOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenCosineSimilarityOp op,
@@ -4122,7 +4120,7 @@ class DecomposeAtenCosineSimilarityOp
                                          indexBroadcastShapeTorchList);
 
     // Compute the mul of A and B
-    Value dotProduct = 
+    Value dotProduct =
         rewriter.create<AtenMulTensorOp>(loc, broadcastType, x1, x2);
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
     Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
@@ -4133,17 +4131,17 @@ class DecomposeAtenCosineSimilarityOp
         loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList,
         /*keepdim=*/cstFalse, /*dtype=*/cstNone);
-    
+
     // Compute the norm of A and B
-    Value ord = rewriter.create<Torch::ConstantFloatOp>(loc,
-                                                        rewriter.getF64FloatAttr(2.0));
+    Value ord = rewriter.create<Torch::ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr(2.0));
     Value normA = rewriter.create<AtenLinalgVectorNormOp>(
         loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse,
         /*dtype=*/cstNone);
     Value normB = rewriter.create<AtenLinalgVectorNormOp>(
         loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse,
         /*dtype=*/cstNone);
-    
+
     // Compute the product of the norms
     Value normProduct = rewriter.create<AtenMulTensorOp>(loc, op.getType(),
                                                          normA, normB);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 595afbb74f01..eeaad869089d 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -948,8 +948,6 @@
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
-    "PixelShuffleModuleStaticRank3Int64_basic",
-    "PixelShuffleModuleStaticRank4Float32_basic",
     "IscloseStaticModule_basic",
     "IscloseStaticModuleTrue_basic",
     "TileBigDimsSizeModule_basic",
@@ -1371,6 +1369,9 @@
     "SplitDimDynamicModule_basic",
     "PixelShuffleModuleStaticRank3Int64_basic",
     "PixelShuffleModuleStaticRank4Float32_basic",
+    "PixelShuffleModuleFullDynamic_basic",
+    "PixelShuffleModuleSpatiallyDynamic_basic",
+    "PixelShuffleModuleSpatiallyStatic_basic",
     "_Convolution2DAllFalseModule_basic",
     "_Convolution2DBenchmarkModule_basic",
     "_Convolution2DCudnnModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 9eb1a8986d4c..a9d9e2e9ce35 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -686,7 +686,56 @@ def forward(self, x):
 def PixelShuffleModuleStaticRank3Int64_basic(module, tu: TestUtils):
     module.forward(tu.randint(12, 2, 3, low = 0, high = 100))
 
+# ==============================================================================
+
+
+class PixelShuffleModuleFullDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1,-1,-1,-1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_shuffle(x, 2)
+
+@register_test_case(module_factory=lambda: PixelShuffleModuleFullDynamic())
+def PixelShuffleModuleFullDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1,8,3,3, low = 0, high = 100))
+
+# ==============================================================================
+
+
+class PixelShuffleModuleSpatiallyDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([2,1,8,-1,-1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_shuffle(x, 2)
+
+@register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyDynamic())
+def PixelShuffleModuleSpatiallyDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2,1,8,2,3, low = 0, high = 100))
+
+# ==============================================================================
+
+class PixelShuffleModuleSpatiallyStatic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1,-1,-1,3,1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_shuffle(x, 2)
+
+@register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyStatic())
+def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1,2,12,3,1, low = 0, high = 100))
+
+
+# ==============================================================================
 
 class TensorsConcatModule(torch.nn.Module):
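
For reference, the following is a minimal PyTorch sketch (not part of the patch) of the split -> permute -> collapse sequence described in the comment above DecomposeAtenPixelShuffleOp, written with plain reshape/permute so it can be checked against torch.ops.aten.pixel_shuffle. The helper name reference_pixel_shuffle and the example sizes are illustrative only.

import torch

def reference_pixel_shuffle(x: torch.Tensor, r: int) -> torch.Tensor:
    # Same steps as the decomposition: split C*r*r -> (C, r, r), permute the
    # two 'r' dims next to H and W, then collapse (H, r) and (W, r).
    *leading, c_rr, h, w = x.shape
    c, n = c_rr // (r * r), len(leading)
    x = x.reshape(*leading, c, r, r, h, w)
    x = x.permute(*range(n), n, n + 3, n + 1, n + 4, n + 2)  # (..., C, H, r, W, r)
    return x.reshape(*leading, c, h * r, w * r)

x = torch.arange(2 * 8 * 3 * 3).reshape(2, 8, 3, 3)
assert torch.equal(reference_pixel_shuffle(x, 2),
                   torch.ops.aten.pixel_shuffle(x, 2))

Because every step is a reshape or permute with sizes taken from the input at runtime, the same sequence also covers the dynamically shaped cases exercised by the new PixelShuffleModule*Dynamic tests.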