diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index 23324a007377e..39c16fab21c4e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -21,177 +21,298 @@ #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include #define DEBUG_TYPE "vector-shape-cast-lowering" using namespace mlir; -using namespace mlir::vector; - -/// Increments n-D `indices` by `step` starting from the innermost dimension. -static void incIdx(SmallVectorImpl &indices, VectorType vecType, - int step = 1) { - for (int dim : llvm::reverse(llvm::seq(0, indices.size()))) { - assert(indices[dim] < vecType.getDimSize(dim) && - "Indices are out of bound"); - indices[dim] += step; - if (indices[dim] < vecType.getDimSize(dim)) - break; - indices[dim] = 0; - step = 1; +/// Perform the inplace update +/// rhs <- lhs + rhs +/// +/// where `rhs` is a number expressed in mixed base `base` with most significant +/// dimensions on the left. For example if `rhs` is {a,b,c} and `base` is +/// {5,3,2} then `rhs` has value a*3*2 + b*2 + c. 
+/// +/// Some examples where `base` is {5,3,2}: +/// rhs = {0,0,0}, lhs = 1 --> rhs = {0,0,1} +/// rhs = {0,0,1}, lhs = 1 --> rhs = {0,1,0} +/// rhs = {0,0,0}, lhs = 25 --> rhs = {4,0,1} +/// +/// Invalid: +/// rhs = {0,0,2}, lhs = 1 : rhs not in base {5,3,2} +/// +/// Overflows not handled correctly: +/// rhs = {4,2,1}, lhs = 2 --> rhs = {0,0,0} (not {0,0,1}) +static void inplaceAdd(int64_t lhs, ArrayRef base, + MutableArrayRef rhs) { + + // For dimensions in [rhs.size() - 1, ..., 3, 2, 1, 0]: + for (int dim : llvm::reverse(llvm::seq(0, rhs.size()))) { + int64_t dimBase = base[dim]; + assert(rhs[dim] < dimBase && "rhs not in base"); + + int64_t incremented = rhs[dim] + lhs; + + // If the incremented value exceeds the dimension base, we must spill to the + // next most significant dimension and repeat (we might need to spill to + // more significant dimensions multiple times). + lhs = incremented / dimBase; + rhs[dim] = incremented % dimBase; + if (lhs == 0) + break; + } } namespace { -/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D -/// vectors progressively. This iterates over the n-1 major dimensions of the -/// n-D vector and performs rewrites into: -/// vector.extract from n-D + vector.insert_strided_slice offset into 1-D -class ShapeCastOpNDDownCastRewritePattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(vector::ShapeCastOp op, - PatternRewriter &rewriter) const override { - auto sourceVectorType = op.getSourceVectorType(); - auto resultVectorType = op.getResultVectorType(); - if (sourceVectorType.isScalable() || resultVectorType.isScalable()) - return failure(); +/// shape_cast is converted to a sequence of extract, extract_strided_slice, +/// insert_strided_slice, and insert operations. 
The running example will be: +/// +/// %0 = vector.shape_cast %arg0 : +/// vector<2x2x3x4x7x11xi8> to vector<8x6x7x11xi8> +/// +/// In this example the source and result shapes share a common suffix of 7x11. +/// This means we can always decompose the shape_cast into extract, insert, and +/// their strided equivalents, on vectors with shape suffix 7x11. +/// +/// The greatest common divisor (gcd) of the first dimension preceding the +/// common suffix is gcd(4,6) = 2. The algorithm implemented here will operate +/// on vectors with shapes that are `multiples` of (what we define as) the +/// 'atomic shape', 2x7x11. The atomic shape is `gcd` x `common-suffix`. +/// +/// vector<2x2x3x4x7x11xi8> to +/// vector<8x6x7x11xi8> +/// | |||| +/// | ++++------------> common suffix of 7x11 +/// +-----------------> gcd(4,6) is 2 | | +/// | | | +/// v v v +/// atomic shape <----- 2x7x11 +/// +/// +/// +/// The decomposition implemented in this pattern consists of a sequence of +/// repeated steps: +/// +/// (1) Extract vectors from the suffix of the source. +/// In our example this is 2x2x3x4x7x11 -> 4x7x11. +/// +/// (2) Do extract_strided_slice down to the atomic shape. +/// In our example this is 4x7x11 -> 2x7x11. +/// +/// (3) Do insert_strided_slice to the suffix of the result. +/// In our example this is 2x7x11 -> 6x7x11. +/// +/// (4) insert these vectors into the result vector. +/// In our example this is 6x7x11 -> 8x6x7x11. +/// +/// These steps occur with different periods. In this example +/// (1) occurs 12 times, +/// (2) and (3) occur 24 times, and +/// (4) occurs 8 times. +/// +/// Two special cases are handled independently in this pattern +/// (i) A shape_cast that just does leading 1 insertion/removal +/// (ii) A shape_cast where the gcd is 1. +/// +/// These 2 cases can have more compact IR generated by not using the generic +/// algorithm described above. +/// +class ShapeCastOpRewritePattern : public OpRewritePattern { + + // Case (i) of description. 
+ // Assumes source and result shapes are identical up to some leading ones. + static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast, + PatternRewriter &rewriter) { + + const Location loc = shapeCast.getLoc(); + const VectorType sourceType = shapeCast.getSourceVectorType(); + const VectorType resultType = shapeCast.getResultVectorType(); + + const int64_t sourceRank = sourceType.getRank(); + const int64_t resultRank = resultType.getRank(); + const int64_t delta = sourceRank - resultRank; + const int64_t sourceLeading = delta > 0 ? delta : 0; + const int64_t resultLeading = delta > 0 ? 0 : -delta; + + const Value source = shapeCast.getSource(); + const Value poison = rewriter.create(loc, resultType); + const Value extracted = rewriter.create( + loc, source, SmallVector(sourceLeading, 0)); + const Value result = rewriter.create( + loc, extracted, poison, SmallVector(resultLeading, 0)); + + rewriter.replaceOp(shapeCast, result); + return success(); + } - int64_t srcRank = sourceVectorType.getRank(); - int64_t resRank = resultVectorType.getRank(); - if (srcRank < 2 || resRank != 1) - return failure(); + // Case (ii) of description. + // Assumes a shape_cast where the suffix shape of the source starting at + // `sourceDim` and the suffix shape of the result starting at `resultDim` are + // identical. + static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast, + int64_t sourceDim, + int64_t resultDim, + PatternRewriter &rewriter) { - // Compute the number of 1-D vector elements involved in the reshape. 
- int64_t numElts = 1; - for (int64_t dim = 0; dim < srcRank - 1; ++dim) - numElts *= sourceVectorType.getDimSize(dim); + const Location loc = shapeCast.getLoc(); - auto loc = op.getLoc(); - SmallVector srcIdx(srcRank - 1, 0); - SmallVector resIdx(resRank, 0); - int64_t extractSize = sourceVectorType.getShape().back(); - Value result = rewriter.create(loc, resultVectorType); + const Value source = shapeCast.getSource(); + const ArrayRef sourceShape = + shapeCast.getSourceVectorType().getShape(); - // Compute the indices of each 1-D vector element of the source extraction - // and destination slice insertion and generate such instructions. - for (int64_t i = 0; i < numElts; ++i) { - if (i != 0) { - incIdx(srcIdx, sourceVectorType, /*step=*/1); - incIdx(resIdx, resultVectorType, /*step=*/extractSize); - } + const VectorType resultType = shapeCast.getResultVectorType(); + const ArrayRef resultShape = resultType.getShape(); - Value extract = - rewriter.create(loc, op.getSource(), srcIdx); - result = rewriter.create( - loc, extract, result, - /*offsets=*/resIdx, /*strides=*/1); - } + const int64_t nSlices = + std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim, 1, + std::multiplies()); - rewriter.replaceOp(op, result); - return success(); - } -}; + SmallVector extractIndex(sourceDim, 0); + SmallVector insertIndex(resultDim, 0); + Value result = rewriter.create(loc, resultType); -/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D -/// vectors progressively. This iterates over the n-1 major dimension of the n-D -/// vector and performs rewrites into: -/// vector.extract_strided_slice from 1-D + vector.insert into n-D -/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle. 
-class ShapeCastOpNDUpCastRewritePattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; + for (int i = 0; i < nSlices; ++i) { + Value extracted = + rewriter.create(loc, source, extractIndex); - LogicalResult matchAndRewrite(vector::ShapeCastOp op, - PatternRewriter &rewriter) const override { - auto sourceVectorType = op.getSourceVectorType(); - auto resultVectorType = op.getResultVectorType(); - if (sourceVectorType.isScalable() || resultVectorType.isScalable()) - return failure(); - - int64_t srcRank = sourceVectorType.getRank(); - int64_t resRank = resultVectorType.getRank(); - if (srcRank != 1 || resRank < 2) - return failure(); - - // Compute the number of 1-D vector elements involved in the reshape. - int64_t numElts = 1; - for (int64_t dim = 0; dim < resRank - 1; ++dim) - numElts *= resultVectorType.getDimSize(dim); - - // Compute the indices of each 1-D vector element of the source slice - // extraction and destination insertion and generate such instructions. 
- auto loc = op.getLoc(); - SmallVector srcIdx(srcRank, 0); - SmallVector resIdx(resRank - 1, 0); - int64_t extractSize = resultVectorType.getShape().back(); - Value result = rewriter.create(loc, resultVectorType); - for (int64_t i = 0; i < numElts; ++i) { - if (i != 0) { - incIdx(srcIdx, sourceVectorType, /*step=*/extractSize); - incIdx(resIdx, resultVectorType, /*step=*/1); - } + result = rewriter.create(loc, extracted, result, + insertIndex); - Value extract = rewriter.create( - loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize, - /*strides=*/1); - result = rewriter.create(loc, extract, result, resIdx); + inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex); + inplaceAdd(1, resultShape.take_front(resultDim), insertIndex); } - rewriter.replaceOp(op, result); + rewriter.replaceOp(shapeCast, result); return success(); } -}; -// We typically should not lower general shape cast operations into data -// movement instructions, since the assumption is that these casts are -// optimized away during progressive lowering. For completeness, however, -// we fall back to a reference implementation that moves all elements -// into the right place if we get here. 
-class ShapeCastOpRewritePattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto sourceVectorType = op.getSourceVectorType(); - auto resultVectorType = op.getResultVectorType(); + VectorType sourceType = op.getSourceVectorType(); + VectorType resultType = op.getResultVectorType(); + + if (sourceType.isScalable() || resultType.isScalable()) + return rewriter.notifyMatchFailure( + op, + "shape_cast where vectors are scalable not handled by this pattern"); + + const ArrayRef sourceShape = sourceType.getShape(); + const ArrayRef resultShape = resultType.getShape(); + const int64_t sourceRank = sourceType.getRank(); + const int64_t resultRank = resultType.getRank(); + const int64_t numElms = sourceType.getNumElements(); + const Value source = op.getSource(); + + // Set the first dimension (starting at the end) in the source and result + // respectively where the dimension sizes differ. Using the running example: + // + // dimensions: [0 1 2 3 4 5 ] [0 1 2 3 ] + // shapes: (2,2,3,4,7,11) -> (8,6,7,11) + // ^ ^ + // | | + // sourceSuffixStartDim is 3 | + // | + // resultSuffixStartDim is 1 + int64_t sourceSuffixStartDim = sourceRank - 1; + int64_t resultSuffixStartDim = resultRank - 1; + while (sourceSuffixStartDim >= 0 && resultSuffixStartDim >= 0 && + (sourceType.getDimSize(sourceSuffixStartDim) == + resultType.getDimSize(resultSuffixStartDim))) { + --sourceSuffixStartDim; + --resultSuffixStartDim; + } - if (sourceVectorType.isScalable() || resultVectorType.isScalable()) - return failure(); - - // Special case for n-D / 1-D lowerings with better implementations. 
- int64_t srcRank = sourceVectorType.getRank(); - int64_t resRank = resultVectorType.getRank(); - if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1)) - return failure(); - - // Generic ShapeCast lowering path goes all the way down to unrolled scalar - // extract/insert chains. - int64_t numElts = 1; - for (int64_t r = 0; r < srcRank; r++) - numElts *= sourceVectorType.getDimSize(r); - // Replace with data movement operations: - // x[0,0,0] = y[0,0] - // x[0,0,1] = y[0,1] - // x[0,1,0] = y[0,2] - // etc., incrementing the two index vectors "row-major" - // within the source and result shape. - SmallVector srcIdx(srcRank, 0); - SmallVector resIdx(resRank, 0); - Value result = rewriter.create(loc, resultVectorType); - for (int64_t i = 0; i < numElts; i++) { - if (i != 0) { - incIdx(srcIdx, sourceVectorType); - incIdx(resIdx, resultVectorType); + // This is the case (i) where there are just some leading ones to contend + // with in the source or result. It can be handled with a single + // extract/insert pair. 
+ if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0) + return leadingOnesLowering(op, rewriter); + + const int64_t sourceSuffixStartDimSize = + sourceType.getDimSize(sourceSuffixStartDim); + const int64_t resultSuffixStartDimSize = + resultType.getDimSize(resultSuffixStartDim); + const int64_t greatestCommonDivisor = + std::gcd(sourceSuffixStartDimSize, resultSuffixStartDimSize); + const int64_t stridedSliceRank = sourceRank - sourceSuffixStartDim; + const size_t extractPeriod = + sourceSuffixStartDimSize / greatestCommonDivisor; + const size_t insertPeriod = + resultSuffixStartDimSize / greatestCommonDivisor; + + SmallVector atomicShape(sourceShape.begin() + sourceSuffixStartDim, + sourceShape.end()); + atomicShape[0] = greatestCommonDivisor; + + const int64_t numAtomicElms = std::accumulate( + atomicShape.begin(), atomicShape.end(), 1, std::multiplies()); + const size_t nAtomicSlices = numElms / numAtomicElms; + + // This is the case (ii) where the strided dimension size is 1. More compact + // IR is generated in this case if we just extract and insert the elements + // directly. In other words, we don't use extract_strided_slice and + // insert_strided_slice. 
+ if (greatestCommonDivisor == 1) + return noStridedSliceLowering(op, sourceSuffixStartDim + 1, + resultSuffixStartDim + 1, rewriter); + + // The insert_strided_slice result's type + const ArrayRef insertStridedShape = + resultShape.drop_front(resultSuffixStartDim); + const VectorType insertStridedType = + VectorType::get(insertStridedShape, resultType.getElementType()); + + SmallVector extractIndex(sourceSuffixStartDim, 0); + SmallVector insertIndex(resultSuffixStartDim, 0); + SmallVector extractOffsets(stridedSliceRank, 0); + SmallVector insertOffsets(stridedSliceRank, 0); + const SmallVector sizes(stridedSliceRank, 1); + + Value extracted = {}; + Value extractedStrided = {}; + Value insertedSlice = {}; + Value result = rewriter.create(loc, resultType); + const Value partResult = + rewriter.create(loc, insertStridedType); + + for (size_t i = 0; i < nAtomicSlices; ++i) { + + const size_t extractStridedPhase = i % extractPeriod; + const size_t insertStridedPhase = i % insertPeriod; + + // vector.extract + if (extractStridedPhase == 0) { + extracted = + rewriter.create(loc, source, extractIndex); + inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim), + extractIndex); } - Value extract = - rewriter.create(loc, op.getSource(), srcIdx); - result = rewriter.create(loc, extract, result, resIdx); + // vector.extract_strided_slice + extractOffsets[0] = extractStridedPhase * greatestCommonDivisor; + extractedStrided = rewriter.create( + loc, extracted, extractOffsets, atomicShape, sizes); + + // vector.insert_strided_slice + if (insertStridedPhase == 0) { + insertedSlice = partResult; + } + insertOffsets[0] = insertStridedPhase * greatestCommonDivisor; + insertedSlice = rewriter.create( + loc, extractedStrided, insertedSlice, insertOffsets, sizes); + + // vector.insert + if (insertStridedPhase + 1 == insertPeriod) { + result = rewriter.create(loc, insertedSlice, result, + insertIndex); + inplaceAdd(1, resultType.getShape().take_front(resultSuffixStartDim), + 
insertIndex); + } } rewriter.replaceOp(op, result); return success(); @@ -252,7 +373,8 @@ class ScalableShapeCastOpRewritePattern // from >= 2-D scalable vectors or scalable vectors of fixed vectors. if (!isTrailingDimScalable(sourceVectorType) || !isTrailingDimScalable(resultVectorType)) { - return failure(); + return rewriter.notifyMatchFailure( + op, "trailing dims are not scalable, not handled by this pattern"); } // The sizes of the trailing dimension of the source and result vectors, the @@ -329,8 +451,8 @@ class ScalableShapeCastOpRewritePattern // 4. Increment the insert/extract indices, stepping by minExtractionSize // for the trailing dimensions. - incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize); - incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize); + inplaceAdd(minExtractionSize, sourceVectorType.getShape(), srcIdx); + inplaceAdd(minExtractionSize, resultVectorType.getShape(), resIdx); } rewriter.replaceOp(op, result); @@ -347,8 +469,6 @@ class ScalableShapeCastOpRewritePattern void mlir::vector::populateVectorShapeCastLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), - benefit); + patterns.add( + patterns.getContext(), benefit); } diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir index ef32f8c6a1cdb..5011d8b2b2ef6 100644 --- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir @@ -1,145 +1,392 @@ // RUN: mlir-opt %s --transform-interpreter | FileCheck %s // CHECK-LABEL: func @nop_shape_cast -// CHECK-SAME: %[[A:.*]]: vector<16xf32> -// CHECK: return %[[A]] : vector<16xf32> +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> func.func @nop_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { %0 = vector.shape_cast %arg0 : vector<16xf32> to 
vector<16xf32> return %0 : vector<16xf32> } // CHECK-LABEL: func @cancel_shape_cast -// CHECK-SAME: %[[A:.*]]: vector<16xf32> -// CHECK: return %[[A]] : vector<16xf32> - +// CHECK-SAME: %[[A:.*]]: vector<16xf32> +// CHECK: return %[[A]] : vector<16xf32> func.func @cancel_shape_cast(%arg0: vector<16xf32>) -> vector<16xf32> { %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> return %1 : vector<16xf32> } -// Shape up and downcasts for 2-D vectors, for supporting conversion to -// llvm.matrix operations -// CHECK-LABEL: func @shape_casts -func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>) { - // CHECK-DAG: %[[ub22:.*]] = ub.poison : vector<2x2xf32> - // CHECK-DAG: %[[ub:.*]] = ub.poison : vector<4xf32> - // CHECK: %[[ex0:.*]] = vector.extract %{{.*}}[0] : vector<2xf32> from vector<2x2xf32> - // - // CHECK: %[[in0:.*]] = vector.insert_strided_slice %[[ex0]], %[[ub]] - // CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> - // - // CHECK: %[[ex1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32> - // - // CHECK: %[[in2:.*]] = vector.insert_strided_slice %[[ex1]], %[[in0]] - // CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> - // +// Collapse 2-D to 1-D. 
+// CHECK-LABEL: func @shape_cast_2d1d +// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>) -> vector<4xf32> { +// CHECK: %[[UB:.*]] = ub.poison : vector<4xf32> +// +// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UB]] +// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> +// +// CHECK: %[[EX1:.*]] = vector.extract %{{.*}}[1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[IN2:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]] +// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> +// CHECK: return %[[IN2]] : vector<4xf32> +func.func @shape_cast_2d1d(%a: vector<2x2xf32>) -> (vector<4xf32>) { %0 = vector.shape_cast %a : vector<2x2xf32> to vector<4xf32> - // CHECK: %[[add:.*]] = arith.addf %[[in2]], %[[in2]] : vector<4xf32> - %r0 = arith.addf %0, %0: vector<4xf32> - // - // CHECK: %[[ss0:.*]] = vector.extract_strided_slice %[[add]] - // CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : - // CHECK-SAME: vector<4xf32> to vector<2xf32> - // - // CHECK: %[[res0:.*]] = vector.insert %[[ss0]], %[[ub22]] [0] : - // CHECK-SAME: vector<2xf32> into vector<2x2xf32> - // - // CHECK: %[[s2:.*]] = vector.extract_strided_slice %[[add]] - // CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} : - // CHECK-SAME: vector<4xf32> to vector<2xf32> - // - // CHECK: %[[res1:.*]] = vector.insert %[[s2]], %[[res0]] [1] : - // CHECK-SAME: vector<2xf32> into vector<2x2xf32> - // - %1 = vector.shape_cast %r0 : vector<4xf32> to vector<2x2xf32> - // CHECK: return %[[add]], %[[res1]] : vector<4xf32>, vector<2x2xf32> - return %r0, %1 : vector<4xf32>, vector<2x2xf32> -} - -// CHECK-LABEL: func @shape_cast_2d2d -// CHECK-SAME: %[[A:.*]]: vector<3x2xf32> -// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] 
[0, 0] : f32 into vector<2x3xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32> -// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32> -// CHECK: return %[[T11]] : vector<2x3xf32> - -func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> { - %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32> - return %s : vector<2x3xf32> + return %0 : vector<4xf32> } +// Collapse 3-D to 1-D. 
// CHECK-LABEL: func @shape_cast_3d1d -// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32> -// CHECK: %[[UB:.*]] = ub.poison : vector<6xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32> -// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[UB]] -// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32> -// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]] -// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32> -// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]] -// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32> -// CHECK: return %[[T5]] : vector<6xf32> - +// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32> +// CHECK: %[[UB:.*]] = ub.poison : vector<6xf32> +// +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[UB]] +// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32> +// +// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]] +// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32> +// +// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]] +// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32> +// CHECK: return %[[T5]] : vector<6xf32> func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> { %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32> return %s : vector<6xf32> } -// CHECK-LABEL: func @shape_cast_1d3d -// CHECK-SAME: 
%[[A:.*]]: vector<6xf32> -// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32> -// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]] -// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : vector<3xf32> into vector<2x1x3xf32> -// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]] -// CHECK: {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32> -// CHECK: return %[[T3]] : vector<2x1x3xf32> +// Expand 1-D to 2-D. +// CHECK-LABEL: func.func @shape_cast_1d2d( +// CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<2x2xf32> { +// CHECK: %[[UB:.*]] = ub.poison : vector<2x2xf32> +// +// CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[A]] +// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : +// CHECK-SAME: vector<4xf32> to vector<2xf32> +// CHECK: %[[res0:.*]] = vector.insert %[[SS0]], %[[UB]] [0] : +// CHECK-SAME: vector<2xf32> into vector<2x2xf32> +// +// CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[A]] +// CHECK-SAME: {offsets = [2], sizes = [2], strides = [1]} : +// CHECK-SAME: vector<4xf32> to vector<2xf32> +// CHECK: %[[res1:.*]] = vector.insert %[[SS2]], %[[res0]] [1] : +// CHECK-SAME: vector<2xf32> into vector<2x2xf32> +// CHECK: return %[[res1]] : vector<2x2xf32> +func.func @shape_cast_1d2d(%a: vector<4xf32>) -> (vector<2x2xf32>) { + %1 = vector.shape_cast %a: vector<4xf32> to vector<2x2xf32> + return %1 : vector<2x2xf32> +} +// Expand 1-D to 3-D. 
+// CHECK-LABEL: func @shape_cast_1d3d +// CHECK-SAME: %[[A:.*]]: vector<6xf32> +// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32> +// +// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]] +// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} : +// CHECK-SAME: vector<6xf32> to vector<3xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : +// CHECK-SAME: vector<3xf32> into vector<2x1x3xf32> +// +// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]] +// CHECK-SAME: {offsets = [3], sizes = [3], strides = [1]} : +// CHECK-SAME: vector<6xf32> to vector<3xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : +// CHECK-SAME: vector<3xf32> into vector<2x1x3xf32> +// CHECK: return %[[T3]] : vector<2x1x3xf32> func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> { %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32> return %s : vector<2x1x3xf32> } -// CHECK-LABEL: func.func @shape_cast_0d1d( -// CHECK-SAME: %[[ARG0:.*]]: vector) -> vector<1xf32> { -// CHECK: %[[UB:.*]] = ub.poison : vector<1xf32> -// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][] : f32 from vector -// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : f32 into vector<1xf32> -// CHECK: return %[[RES]] : vector<1xf32> -// CHECK: } +// 2-D to 2-D where the inner-most dimensions have no common factors. This +// case requires scalar element by element extraction and insertion. 
+// CHECK-LABEL: func @shape_cast_2d2d +// CHECK-SAME: %[[A:.*]]: vector<3x2xf32> +// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32> +// +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32> +// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : +// +// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32> +// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : +// +// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32> +// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : +// +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : +// +// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32> +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : +// +// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32> +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : +// +// CHECK: return %[[T11]] : vector<2x3xf32> +func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> { + %s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32> + return %s : vector<2x3xf32> +} +// CHECK-LABEL: func.func @shape_cast_0d1d( +// CHECK-SAME: %[[A:.*]]: vector) -> vector<1xf32> { +// CHECK: %[[UB:.*]] = ub.poison : vector<1xf32> +// +// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][] : f32 from vector +// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : +// CHECK: return %[[RES]] : vector<1xf32> func.func @shape_cast_0d1d(%arg0 : vector) -> vector<1xf32> { %s = vector.shape_cast %arg0 : vector to vector<1xf32> return %s : vector<1xf32> } -// CHECK-LABEL: func.func @shape_cast_1d0d( -// CHECK-SAME: %[[ARG0:.*]]: vector<1xf32>) -> vector { -// CHECK: %[[UB:.*]] = ub.poison : vector -// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32> -// CHECK: %[[RES:.*]] = vector.insert 
%[[EXTRACT0]], %[[UB]] [] : f32 into vector
-// CHECK: return %[[RES]] : vector
-// CHECK: }
-
+// CHECK-LABEL: func.func @shape_cast_1d0d(
+// CHECK-SAME: %[[A:.*]]: vector<1xf32>) -> vector {
+// CHECK: %[[UB:.*]] = ub.poison : vector
+//
+// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[A]][0] : f32 from vector<1xf32>
+// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] :
+// CHECK: return %[[RES]] : vector
 func.func @shape_cast_1d0d(%arg0 : vector<1xf32>) -> vector {
   %s = vector.shape_cast %arg0 : vector<1xf32> to vector
   return %s : vector
 }
+// The shapes have 2 inner dimension sizes in common, so the extract result is rank-2.
+// CHECK-LABEL: func.func @squeeze_out_prefix_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<1x2x3xf32>) -> vector<2x3xf32> {
+//
+// CHECK: %[[EXTRACTED:.*]] = vector.extract %[[A]][0] :
+// CHECK-SAME: vector<2x3xf32> from vector<1x2x3xf32>
+// CHECK: return %[[EXTRACTED]] : vector<2x3xf32>
+func.func @squeeze_out_prefix_unit_dim(%arg0 : vector<1x2x3xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<1x2x3xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// The shapes have 1 inner dimension size in common, so the extract results are rank-1.
+// CHECK-LABEL: func.func @squeeze_out_middle_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<2x1x3xf32>) -> vector<2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
+//
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0, 0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0] :
+//
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1, 0] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1] :
+// CHECK: return %[[I1]] : vector<2x3xf32>
+func.func @squeeze_out_middle_unit_dim(%arg0 : vector<2x1x3xf32>) -> vector<2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x1x3xf32> to vector<2x3xf32>
+  return %s : vector<2x3xf32>
+}
+
+// CHECK-LABEL: func.func @prepend_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<1x2x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<1x2x3xf32>
+//
+// CHECK: %[[I0:.*]] = vector.insert %[[A]], %[[UB]] [0]
+// CHECK: return %[[I0]] : vector<1x2x3xf32>
+func.func @prepend_unit_dim(%arg0 : vector<2x3xf32>) -> vector<1x2x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<1x2x3xf32>
+  return %s : vector<1x2x3xf32>
+}
+
+// CHECK-LABEL: func.func @insert_middle_unit_dim(
+// CHECK-SAME: %[[A:.*]]: vector<2x3xf32>) -> vector<2x1x3xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x1x3xf32>
+//
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<3xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0] : vector<3xf32>
+//
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : vector<3xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0] : vector<3xf32>
+// CHECK: return %[[I1]] : vector<2x1x3xf32>
+func.func @insert_middle_unit_dim(%arg0 : vector<2x3xf32>) -> vector<2x1x3xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x3xf32> to vector<2x1x3xf32>
+  return %s : vector<2x1x3xf32>
+}
+
+// CHECK-LABEL: func.func @postpend_unit_dims(
+// CHECK-SAME: %[[A:.*]]: vector<4xf32>) -> vector<4x1x1xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<4x1x1xf32>
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : f32 from vector<4xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[E0]], %[[UB]] [0, 0, 0] : f32 into vector<4x1x1xf32>
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[E1]], %[[I0]] [1, 0, 0] : f32 into vector<4x1x1xf32>
+// CHECK: vector.extract %[[A]][2]
+// CHECK: vector.insert {{.*}} [2, 0, 0]
+// CHECK: vector.extract %[[A]][3]
+// CHECK: vector.insert {{.*}} [3, 0, 0]
+// CHECK: return
+func.func @postpend_unit_dims(%arg0 : vector<4xf32>) -> vector<4x1x1xf32> {
+  %s = vector.shape_cast %arg0 : vector<4xf32> to vector<4x1x1xf32>
+  return %s : vector<4x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @expand_inner_dims(
+// CHECK-SAME: %[[A:.*]]: vector<2x10xf32>) -> vector<2x2x5xf32> {
+// CHECK: %[[UB:.*]] = ub.poison : vector<2x2x5xf32>
+//
+// CHECK: %[[E0:.*]] = vector.extract %[[A]][0] : vector<10xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[E0]]
+// CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I0:.*]] = vector.insert %[[S0]], %[[UB]] [0, 0]
+//
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[E0]]
+// CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I1:.*]] = vector.insert %[[S1]], %[[I0]] [0, 1]
+//
+// CHECK: %[[E1:.*]] = vector.extract %[[A]][1] : vector<10xf32>
+// CHECK: %[[S2:.*]] = vector.extract_strided_slice %[[E1]]
+// CHECK-SAME: {offsets = [0], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I2:.*]] = vector.insert %[[S2]], %[[I1]] [1, 0]
+//
+// CHECK: %[[S3:.*]] = vector.extract_strided_slice %[[E1]]
+// CHECK-SAME: {offsets = [5], sizes = [5], {{.*}} to vector<5xf32>
+// CHECK: %[[I3:.*]] = vector.insert %[[S3]], %[[I2]] [1, 1]
+// CHECK: return %[[I3]] : vector<2x2x5xf32>
+func.func @expand_inner_dims(%arg0 : vector<2x10xf32>) -> vector<2x2x5xf32> {
+  %s = vector.shape_cast %arg0 : vector<2x10xf32> to vector<2x2x5xf32>
+  return %s : vector<2x2x5xf32>
+}
+
+
+// Some pseudocode describing how this function is lowered:
+//
+// func collapse_inner_dims(A : vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+// v0 = empty of shape (10)
+// v1 = empty of shape (1,2,1,10)
+// v0[0:5] = A[0,0,:]
+// v0[5:10] = A[0,1,:]
+// v1[0,0,0,:] = v0
+// v0[0:5] = A[1,0,:]
+// v0[5:10] = A[1,1,:]
+// v1[0,1,0,:] = v0
+// return v1;
+// }
+// CHECK-LABEL: func.func @collapse_inner_dims(
+// CHECK-SAME: %[[A:.*]]: vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+// CHECK-DAG: %[[UBSMALL:.*]] = ub.poison : vector<10xi8>
+// CHECK-DAG: %[[UBLARGE:.*]] = ub.poison : vector<1x2x1x10xi8>
+//
+// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0, 0]
+// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[EX0]], %[[UBSMALL]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[EX1:.*]] = vector.extract %[[A]][0, 1]
+// CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[EX1]], %[[IN0]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UBLARGE]] [0, 0, 0]
+//
+// CHECK: %[[EX2:.*]] = vector.extract %[[A]][1, 0]
+// CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[EX2]], %[[UBSMALL]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[EX3:.*]] = vector.extract %[[A]][1, 1]
+// CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[EX3]], %[[IN3]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [0, 1, 0]
+// CHECK: return %[[IN5]] : vector<1x2x1x10xi8>
+func.func @collapse_inner_dims(%arg0 : vector<2x2x5xi8>) -> vector<1x2x1x10xi8> {
+  %s = vector.shape_cast %arg0 : vector<2x2x5xi8> to vector<1x2x1x10xi8>
+  return %s : vector<1x2x1x10xi8>
+}
+
+// Some alternative pseudocode describing how this function is lowered:
+//
+// func non_dividing_gcd_decreasing(A : vector<2x15xi8>) -> vector<3x10xi8> {
+// v0 = empty of shape (10)
+// v1 = empty of shape (3,10)
+// e0 = A[0,:]
+// v0[0:5] = e0[0:5]
+// v0[5:10] = e0[5:10]
+// v1[0,:] = v0
+// v0[0:5] = e0[10:15]
+// e1 = A[1,:]
+// v0[5:10] = e1[0:5]
+// v1[1,:] = v0
+// v0[0:5] = e1[5:10]
+// v0[5:10] = e1[10:15]
+// v1[2,:] = v0
+// return v1;
+// }
+// CHECK-LABEL: func.func @non_dividing_gcd_decreasing(
+// CHECK-SAME: %[[A:.*]]: vector<2x15xi8>) -> vector<3x10xi8> {
+// CHECK-DAG: %[[UB0:.*]] = ub.poison : vector<10xi8>
+// CHECK-DAG: %[[UB1:.*]] = ub.poison : vector<3x10xi8>
+//
+// First 10 elements:
+// CHECK: %[[EX0:.*]] = vector.extract %[[A]][0] : vector<15xi8> from vector<2x15xi8>
+// CHECK: %[[SS0:.*]] = vector.extract_strided_slice %[[EX0]]
+// CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+// CHECK: %[[IN0:.*]] = vector.insert_strided_slice %[[SS0]], %[[UB0]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[SS1:.*]] = vector.extract_strided_slice %[[EX0]]
+// CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+// CHECK: %[[IN1:.*]] = vector.insert_strided_slice %[[SS1]], %[[IN0]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN2:.*]] = vector.insert %[[IN1]], %[[UB1]] [0] : vector<10xi8> into vector<3x10xi8>
+//
+// Next 10 elements:
+// CHECK: %[[SS2:.*]] = vector.extract_strided_slice %[[EX0]]
+// CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+// CHECK: %[[IN3:.*]] = vector.insert_strided_slice %[[SS2]], %[[UB0]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[EX1:.*]] = vector.extract %[[A]][1] : vector<15xi8> from vector<2x15xi8>
+// CHECK: %[[SS3:.*]] = vector.extract_strided_slice %[[EX1]]
+// CHECK-SAME: {offsets = [0], {{.*}} to vector<5xi8>
+// CHECK: %[[IN4:.*]] = vector.insert_strided_slice %[[SS3]], %[[IN3]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN5:.*]] = vector.insert %[[IN4]], %[[IN2]] [1] : vector<10xi8> into vector<3x10xi8>
+//
+// Final 10 elements:
+// CHECK: %[[SS4:.*]] = vector.extract_strided_slice %[[EX1]]
+// CHECK-SAME: {offsets = [5], {{.*}} to vector<5xi8>
+// CHECK: %[[IN6:.*]] = vector.insert_strided_slice %[[SS4]], %[[UB0]]
+// CHECK-SAME: {offsets = [0], {{.*}}
+// CHECK: %[[SS5:.*]] = vector.extract_strided_slice %[[EX1]]
+// CHECK-SAME: {offsets = [10], {{.*}} to vector<5xi8>
+// CHECK: %[[IN7:.*]] = vector.insert_strided_slice %[[SS5]], %[[IN6]]
+// CHECK-SAME: {offsets = [5], {{.*}}
+// CHECK: %[[IN8:.*]] = vector.insert %[[IN7]], %[[IN5]] [2] : vector<10xi8> into vector<3x10xi8>
+// CHECK: return %[[IN8]] : vector<3x10xi8>
+func.func @non_dividing_gcd_decreasing(%arg0 : vector<2x15xi8>) -> vector<3x10xi8> {
+  %0 = vector.shape_cast %arg0 : vector<2x15xi8> to vector<3x10xi8>
+  return %0 : vector<3x10xi8>
+}
+
+// CHECK-LABEL: func.func @non_dividing_gcd_increasing(
+// CHECK-SAME: %[[A:.*]]: vector<3x10xi8>) -> vector<2x15xi8> {
+//
+// CHECK-DAG: ub.poison : vector<15xi8>
+// CHECK-DAG: ub.poison : vector<2x15xi8>
+//
+// Collect the first 15 elements, and insert into the first row of the result.
+// CHECK: vector.extract %[[A]][0]
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.extract %[[A]][1]
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.insert {{.*}} [0] : vector<15xi8> into vector<2x15xi8>
+//
+// Collect the next 15 elements, and insert into the second row of the result.
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.extract %[[A]][2]
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: extract_strided_slice
+// CHECK: insert_strided_slice
+// CHECK: vector.insert {{.*}} [1] : vector<15xi8> into vector<2x15xi8>
+func.func @non_dividing_gcd_increasing(%arg0 : vector<3x10xi8>) -> vector<2x15xi8> {
+  %0 = vector.shape_cast %arg0 : vector<3x10xi8> to vector<2x15xi8>
+  return %0 : vector<2x15xi8>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op