diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index b69cbabe0dde9..9e8fbbc24acaf 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -253,14 +253,9 @@ struct ExtractStridedMetadataOpSubviewFolder
 /// Compute the expanded sizes of the given \p expandShape for the
 /// \p groupId-th reassociation group.
 /// \p origSizes hold the sizes of the source shape as values.
-/// This is used to compute the new sizes in cases of dynamic shapes.
-///
-/// sizes#i =
-///     baseSizes#groupId / product(expandShapeSizes#j,
-///                                  for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+/// For static dim sizes, we take the values from the result type
+/// of \p expandShape. For dynamic dims, we take the values from the
+/// output_shape operands of the op.
 ///
 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
@@ -275,42 +270,28 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                  unsigned groupId) {
   unsigned groupSize = reassocGroup.size();
   SmallVector<OpFoldResult> expandedSizes(groupSize);
-
-  uint64_t productOfAllStaticSizes = 1;
-  std::optional<unsigned> dynSizeIdx;
   MemRefType expandShapeType = expandShape.getResultType();
-
-  // Fill up all the statically known sizes.
+  DenseMap<unsigned, Value> dynSizes;
+  Operation::operand_range dynOutShapes = expandShape.getOutputShape();
+  for (unsigned i = 0, dynCount = 0, e = expandShapeType.getRank(); i < e;
+       i++) {
+    if (expandShapeType.isDynamicDim(i))
+      dynSizes[i] = dynOutShapes[dynCount++];
+  }
   for (unsigned i = 0; i < groupSize; ++i) {
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+    unsigned index = reassocGroup[i];
+    uint64_t dimSize = expandShapeType.getDimSize(index);
     if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
+      expandedSizes[i] = dynSizes[index];
       continue;
     }
-    productOfAllStaticSizes *= dimSize;
     expandedSizes[i] = builder.getIndexAttr(dimSize);
   }
-
-  // Compute the dynamic size using the original size and all the other known
-  // static sizes:
-  // expandSize = origSize / productOfAllStaticSizes.
-  if (dynSizeIdx) {
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
-        origSizes[groupId]);
-  }
-
   return expandedSizes;
 }
 
 /// Compute the expanded strides of the given \p expandShape for the
 /// \p groupId-th reassociation group.
-/// \p origStrides and \p origSizes hold respectively the strides and sizes
-/// of the source shape as values.
-/// This is used to compute the strides in cases of dynamic shapes and/or
-/// dynamic stride for this reassociation group.
 ///
 /// strides#i =
 ///     origStrides#reassDim * product(expandShapeSizes#j, for j in
@@ -320,11 +301,8 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
 /// and expandShapeSizes#j is either:
 /// - The constant size at dimension j, derived directly from the result type of
 ///   the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-///   element.)
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+/// - The dynamic size at dimension j, taken from the output_shape operands
+///   of the expand_shape op.
 ///
 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
@@ -334,74 +312,56 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                              ArrayRef<OpFoldResult> origSizes,
                                              ArrayRef<OpFoldResult> origStrides,
                                              unsigned groupId) {
-  SmallVector<int64_t, 2> reassocGroup =
-      expandShape.getReassociationIndices()[groupId];
+  auto reassocIndices = expandShape.getReassociationIndices();
+  unsigned currIdx = 0;
+  for (unsigned i = 0; i < groupId; i++)
+    currIdx += reassocIndices[i].size();
+  SmallVector<int64_t, 2> reassocGroup = reassocIndices[groupId];
   assert(!reassocGroup.empty() &&
          "Reassociation group should have at least one dimension");
 
   unsigned groupSize = reassocGroup.size();
   MemRefType expandShapeType = expandShape.getResultType();
-
-  std::optional<unsigned> dynSizeIdx;
-
   // Fill up the expanded strides, with the information we can deduce from the
   // resulting shape.
-  uint64_t currentStride = 1;
+  Location loc = expandShape.getLoc();
   SmallVector<OpFoldResult> expandedStrides(groupSize);
-  for (int i = groupSize - 1; i >= 0; --i) {
-    expandedStrides[i] = builder.getIndexAttr(currentStride);
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-
-    currentStride *= dimSize;
+  DenseMap<unsigned, Value> dynSizes;
+  unsigned dynCount = 0;
+  Operation::operand_range dynOutShapes = expandShape.getOutputShape();
+  for (unsigned i = 0, e = expandShapeType.getRank(); i < e; i++) {
+    if (expandShapeType.isDynamicDim(i))
+      dynSizes[i] = dynOutShapes[dynCount++];
   }
-
-  // Collect the statically known information about the original stride.
-  Value source = expandShape.getSrc();
-  auto sourceType = cast<MemRefType>(source.getType());
-  auto [strides, offset] = sourceType.getStridesAndOffset();
-
-  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
-                                ? origStrides[groupId]
-                                : builder.getIndexAttr(strides[groupId]);
-
-  // Apply the original stride to all the strides.
-  int64_t doneStrideIdx = 0;
-  // If we saw a dynamic dimension, we need to fix-up all the strides up to
-  // that dimension with the dynamic size.
-  if (dynSizeIdx) {
-    int64_t productOfAllStaticSizes = currentStride;
-    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
-           "We shouldn't be able to change dynamicity");
-    OpFoldResult origSize = origSizes[groupId];
-
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    AffineExpr s1 = builder.getAffineSymbolExpr(1);
-    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
-      int64_t baseExpandedStride =
-          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-              .getInt();
-      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-          builder, expandShape.getLoc(),
-          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
-          {origSize, origStride});
-    }
-  }
-
-  // Now apply the origStride to the remaining dimensions.
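+  // The new computation below walks each group right-to-left. For example
+  // (illustrative values only): expanding a source dimension of stride `s`
+  // into sizes [%n, 4, 8], with %n taken from the output_shape operands,
+  // yields strides [(4 * 8) * s, 8 * s, s]; the source size is no longer
+  // divided by the product of the static sizes.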
+  OpFoldResult origStride = origStrides[groupId];
   AffineExpr s0 = builder.getAffineSymbolExpr(0);
-  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
-    int64_t baseExpandedStride =
-        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-            .getInt();
-    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
+  int64_t resultOffset;
+  SmallVector<int64_t> resultStrides;
+  (void)expandShapeType.getStridesAndOffset(resultStrides, resultOffset);
+  expandedStrides[groupSize - 1] =
+      !ShapedType::isDynamic(resultStrides[currIdx + groupSize - 1])
+          ? builder.getIndexAttr(resultStrides[currIdx + groupSize - 1])
+          : origStride;
+  OpFoldResult currentStride = builder.getIndexAttr(1);
+  for (int i = groupSize - 2; i >= 0; i--) {
+    unsigned index = reassocGroup[i + 1];
+    // Multiply `currentStride` by the size of dimension `index`.
+    currentStride =
+        expandShapeType.isDynamicDim(index)
+            ? makeComposedFoldedAffineApply(builder, loc, s0 * s1,
+                                            {currentStride, dynSizes[index]})
+            : makeComposedFoldedAffineApply(
+                  builder, loc, s0 * expandShapeType.getDimSize(index),
+                  {currentStride});
+    // Multiply `origStride` into each stride of the current reassociation
+    // group.
+    expandedStrides[i] = makeComposedFoldedAffineApply(
+        builder, loc, s0 * s1, {currentStride, origStride});
+  }
+  for (unsigned i = 0; i < groupSize; i++) {
+    if (!ShapedType::isDynamic(resultStrides[currIdx + i]))
+      expandedStrides[i] = builder.getIndexAttr(resultStrides[currIdx + i]);
   }
-
   return expandedStrides;
 }
diff --git a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
index 5517eafb588e8..c20f44a7daa9a 100644
--- a/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/expand-then-convert-to-llvm.mlir
@@ -572,36 +572,30 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>, %sz0: index) -> memref<1x2x?xf32> {
 // CHECK-LABEL:   func.func @expand_shape_dynamic(
 // CHECK-SAME:      %[[ARG:.*]]: memref<1x?xf32>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32> {
+// CHECK:           %[[SZ0_I64:.*]] = builtin.unrealized_conversion_cast %[[SZ0]] : index to i64
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
+// CHECK:           %[[DESC0:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC0]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC1]][1] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK:           %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
-// CHECK:           %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
-// CHECK:           %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]] : i64
-// CHECK:           %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
-// CHECK:           %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]] : i64
-// CHECK:           %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]] : i64
-// CHECK:           %[[FINAL_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
-// CHECK:           %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_SIZE2]] : i64 to index
-// CHECK:           %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[C0]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC3:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC3]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC4]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C0]], %[[DESC5]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[C1]], %[[DESC2]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC3]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[C2]], %[[DESC4]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[C1]], %[[DESC6]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC7]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
 // In this example stride1 and size2 are the same.
 // Hence with CSE, we get the same SSA value.
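 // With the new lowering both values come straight from %[[SZ0]]: size2 is
 // the output_shape operand itself, and stride1 = size2 * stride2 = %[[SZ0]] * 1.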
-// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC5]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC6]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[C1]], %[[DESC7]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC8]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32>
+// CHECK:           %[[DESC9:.*]] = llvm.insertvalue %[[C2]], %[[DESC8]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC10:.*]] = llvm.insertvalue %[[SZ0_I64]], %[[DESC9]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC11:.*]] = llvm.insertvalue %[[SZ0_I64]], %[[DESC10]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC12:.*]] = llvm.insertvalue %[[C1]], %[[DESC11]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC12]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32>
 // CHECK:           return %[[RES]] : memref<1x2x?xf32>
 // CHECK:         }
@@ -617,39 +611,33 @@ func.func @expand_shape_dynamic_with_non_identity_layout(
 }
 // CHECK-LABEL:   func.func @expand_shape_dynamic_with_non_identity_layout(
 // CHECK-SAME:      %[[ARG:.*]]: memref<1x?xf32, strided<[?, ?], offset: ?>>, %[[SZ0:.*]]: index) -> memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
+// CHECK:           %[[SZ0_I64:.*]] = builtin.unrealized_conversion_cast %[[SZ0]] : index to i64
 // CHECK:           %[[MEM:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<1x?xf32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[BASE_BUFFER:.*]] = llvm.extractvalue %[[MEM]][0] : !llvm.struct<(ptr, ptr, i64,
 // CHECK:           %[[ALIGNED_BUFFER:.*]] = llvm.extractvalue %[[MEM]][1] : !llvm.struct<(ptr, ptr, i64,
+// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[DESC0:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64)>
+// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC0]][1] : !llvm.struct<(ptr, ptr, i64)>
 // CHECK:           %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK:           %[[OFFSET:.*]] = llvm.extractvalue %[[MEM]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[SIZE1:.*]] = llvm.extractvalue %[[MEM]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[STRIDE1:.*]] = llvm.extractvalue %[[MEM]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK:           %[[CMINUS1:.*]] = llvm.mlir.constant(-1 : index) : i64
-// CHECK:           %[[IS_NEGATIVE_SIZE1:.*]] = llvm.icmp "slt" %[[SIZE1]], %[[C0]] : i64
-// CHECK:           %[[ABS_SIZE1_MINUS_1:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE1]] : i64
-// CHECK:           %[[ADJ_SIZE1:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[ABS_SIZE1_MINUS_1]], %[[SIZE1]] : i1, i64
-// CHECK:           %[[SIZE2:.*]] = llvm.sdiv %[[ADJ_SIZE1]], %[[C2]] : i64
-// CHECK:           %[[NEGATIVE_SIZE2:.*]] = llvm.sub %[[CMINUS1]], %[[SIZE2]] : i64
-// CHECK:           %[[TMP_SIZE2:.*]] = llvm.select %[[IS_NEGATIVE_SIZE1]], %[[NEGATIVE_SIZE2]], %[[SIZE2]] : i1, i64
-// CHECK:           %[[SIZE2_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[TMP_SIZE2]] : i64 to index
-// CHECK:           %[[FINAL_SIZE2:.*]] = builtin.unrealized_conversion_cast %[[SIZE2_TO_IDX]] : index to i64
-// CHECK:           %[[FINAL_STRIDE1:.*]] = llvm.mul %[[TMP_SIZE2]], %[[STRIDE1]]
-// CHECK:           %[[STRIDE1_TO_IDX:.*]] = builtin.unrealized_conversion_cast %[[FINAL_STRIDE1]] : i64 to index
-// CHECK:           %[[FINAL_STRIDE1:.*]] = builtin.unrealized_conversion_cast %[[STRIDE1_TO_IDX]] : index to i64
-// CHECK:           %[[DESC:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC1:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC2:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC1]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC2]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[MUL:.*]] = llvm.mul %[[SZ0_I64]], %[[STRIDE1]] overflow<nsw> : i64
+// CHECK:           %[[MUL_INDEX:.*]] = builtin.unrealized_conversion_cast %[[MUL]] : i64 to index
+// CHECK:           %[[MUL_I64:.*]] = builtin.unrealized_conversion_cast %[[MUL_INDEX]] : index to i64
+// CHECK:           %[[DESC2:.*]] = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC3:.*]] = llvm.insertvalue %[[BASE_BUFFER]], %[[DESC2]][0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[ALIGNED_BUFFER]], %[[DESC3]][1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESC4]][2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK:           %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK:           %[[DESC4:.*]] = llvm.insertvalue %[[C1]], %[[DESC3]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC4]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C2]], %[[DESC5]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[FINAL_STRIDE1]], %[[DESC6]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[FINAL_SIZE2]], %[[DESC7]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[DESC9:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESC8]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
-// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC9]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK:           %[[DESC6:.*]] = llvm.insertvalue %[[C1]], %[[DESC5]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC7:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESC6]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[C2:.*]] = llvm.mlir.constant(2 : index) : i64
+// CHECK:           %[[DESC8:.*]] = llvm.insertvalue %[[C2]], %[[DESC7]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC9:.*]] = llvm.insertvalue %[[MUL_I64]], %[[DESC8]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC10:.*]] = llvm.insertvalue %[[SZ0_I64]], %[[DESC9]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[DESC11:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESC10]][4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
+// CHECK:           %[[RES:.*]] = builtin.unrealized_conversion_cast %[[DESC11]] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> to memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 // CHECK:           return %[[RES]] : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 // CHECK:         }
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 647731db439c0..946fbac457a74 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -354,67 +354,22 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
 
 // Check that we properly simplify expand_shape into:
 // reinterpret_cast(extract_strided_metadata) +
-//
-// Here we have:
-// For the group applying to dim0:
-// size 0 = baseSizes#0 / (all static sizes in that group)
-//        = baseSizes#0 / (7 * 8 * 9)
-//        = baseSizes#0 / 504
-// size 1 = 7
-// size 2 = 8
-// size 3 = 9
-// stride 0 = baseStrides#0 * 7 * 8 * 9
-//          = baseStrides#0 * 504
-// stride 1 = baseStrides#0 * 8 * 9
-//          = baseStrides#0 * 72
-// stride 2 = baseStrides#0 * 9
-// stride 3 = baseStrides#0
-//
-// For the group applying to dim1:
-// size 4 = 10
-// size 5 = 2
-// size 6 = baseSizes#1 / (all static sizes in that group)
-//        = baseSizes#1 / (10 * 2 * 3)
-//        = baseSizes#1 / 60
-// size 7 = 3
-// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
-//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 60) * 6
-// and since we know that baseSizes#1 is a multiple of 60:
-//          = baseStrides#1 * (baseSizes#1 / 10)
-// stride 5 = baseStrides#1 * size 6 * size 7
-//          = baseStrides#1 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 20)
-// stride 6 = baseStrides#1 * size 7
-//          = baseStrides#1 * 3
-// stride 7 = baseStrides#1
-//
-// Base and offset are unchanged.
-//
-// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
-// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
-//
-// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
-// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
-// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
-// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
-// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 72)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 * 504)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> ((s1 * s0) * 3)>
+// CHECK-DAG: #[[$MAP5:.*]] = affine_map<()[s0, s1] -> ((s1 * s0) * 6)>
 // CHECK-LABEL: func @simplify_expand_shape
-// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>,
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index, %[[sz1:.*]]: index
 // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
-//
-// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
-// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
-//
-// CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[DYN_SIZE0]], 7, 8, 9, 10, 2, %[[DYN_SIZE6]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$MAP1]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$MAP2]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE3:.*]] = affine.apply #[[$MAP3]]()[%[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$MAP4]]()[%[[STRIDES]]#1, %[[sz1]]]
+// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$MAP5]]()[%[[STRIDES]]#1, %[[sz1]]]
+// CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[sz0]], 7, 8, 9, 10, 2, %[[sz1]], 3], strides: [%[[DYN_STRIDE2]], %[[DYN_STRIDE1]], %[[DYN_STRIDE0]], %[[STRIDES]]#0, %[[DYN_STRIDE5]], %[[DYN_STRIDE4]], %[[DYN_STRIDE3]], %[[STRIDES]]#1]
 //
 // CHECK: return %[[REINTERPRET_CAST]]
 func.func @simplify_expand_shape(
@@ -525,73 +480,34 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
 //    2. We properly compute the strides affected by dynamic shapes. (When the
 //       dynamic dimension is not the first one.)
 //
-// Here we have:
-// For the group applying to dim0:
-// size 0 = baseSizes#0 / (all static sizes in that group)
-//        = baseSizes#0 / (7 * 8 * 9)
-//        = baseSizes#0 / 504
-// size 1 = 7
-// size 2 = 8
-// size 3 = 9
-// stride 0 = baseStrides#0 * 7 * 8 * 9
-//          = baseStrides#0 * 504
-// stride 1 = baseStrides#0 * 8 * 9
-//          = baseStrides#0 * 72
-// stride 2 = baseStrides#0 * 9
-// stride 3 = baseStrides#0
-//
-// For the group applying to dim1:
-// size 4 = 10
-// size 5 = 2
-// size 6 = baseSizes#1 / (all static sizes in that group)
-//        = baseSizes#1 / (10 * 2 * 3)
-//        = baseSizes#1 / 60
-// size 7 = 3
-// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
-//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 60) * 6
-// and since we know that baseSizes#1 is a multiple of 60:
-//          = baseStrides#1 * (baseSizes#1 / 10)
-// stride 5 = baseStrides#1 * size 6 * size 7
-//          = baseStrides#1 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 20)
-// stride 6 = baseStrides#1 * size 7
-//          = baseStrides#1 * 3
-// stride 7 = baseStrides#1
-//
 // Base and offset are unchanged.
 //
-// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
-// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
-//
-// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
-// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
-// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
-// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
-// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 72)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 * 504)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> ((s1 * s0) * 3)>
+// CHECK-DAG: #[[$MAP5:.*]] = affine_map<()[s0, s1] -> ((s1 * s0) * 6)>
 // CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
-// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>,
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index, %[[sz1:.*]]: index
 //
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
 // CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
 // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
-// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //
 // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
 //
-// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
-// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
-// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
-
-// CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
+// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$MAP1]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$MAP2]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE3:.*]] = affine.apply #[[$MAP3]]()[%[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$MAP4]]()[%[[STRIDES]]#1, %[[sz1]]]
+// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$MAP5]]()[%[[STRIDES]]#1, %[[sz1]]]
+//
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[sz0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[sz1]], %[[C3]], %[[DYN_STRIDE2]], %[[DYN_STRIDE1]], %[[DYN_STRIDE0]], %[[STRIDES]]#0, %[[DYN_STRIDE5]], %[[DYN_STRIDE4]], %[[DYN_STRIDE3]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
 func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
     %base: memref<?x?xf32, strided<[?, ?], offset: ?>>,
     %offset0: index, %offset1: index, %offset2: index,
@@ -620,7 +536,6 @@ func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
     index, index, index, index, index, index, index, index
 }
-
 // -----
 
 // Check that we properly handle extract_strided_metadata of expand_shape for
@@ -1582,3 +1497,27 @@ func.func @extract_strided_metadata_of_memory_space_cast_no_base(%base: memref<2
 // CHECK-LABEL: func @extract_strided_metadata_of_memory_space_cast_no_base
 // CHECK-NOT: memref.memory_space_cast
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1, s2] -> ((s1 * s2) * s0)>
+// CHECK-LABEL: expand_shape_dynamic
+// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+// CHECK: %[[ALLOC:.*]] = memref.alloc
+// CHECK: %[[ID0:.*]] = arith.index_cast
+// CHECK: %[[ID1:.*]] = arith.index_cast
+// CHECK: %[[ID2:.*]] = arith.index_cast
+// CHECK: %[[MAP_RES0:.*]] = affine.apply #[[$MAP]]()[%[[ID2]], %[[C256]]]
+// CHECK: %[[MAP_RES1:.*]] = affine.apply #[[$MAP1]]()[%[[C256]], %[[ID2]], %[[ID1]]]
+// CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [%[[ID0]], %[[ID1]], %[[ID2]], 256], strides: [%[[MAP_RES1]], %[[MAP_RES0]], 256, 1] : memref<?x256xf32> to memref<?x?x?x256xf32>
+// CHECK: return %[[REINTERPRET_CAST]] : memref<?x?x?x256xf32>
+func.func @expand_shape_dynamic(%arg0: index, %arg1: i64, %arg2: i64, %arg3: i64) -> memref<?x?x?x256xf32>
+{
+  %alloc_52 = memref.alloc(%arg0) {alignment = 64 : i64} : memref<?x256xf32>
+  %120 = arith.index_cast %arg1 : i64 to index
+  %121 = arith.index_cast %arg2 : i64 to index
+  %122 = arith.index_cast %arg3 : i64 to index
+  %expand_shape = memref.expand_shape %alloc_52 [[0, 1, 2], [3]] output_shape [%120, %121, %122, 256] : memref<?x256xf32> into memref<?x?x?x256xf32>
+  return %expand_shape : memref<?x?x?x256xf32>
+}
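
For readers who want to see the new behavior in isolation, here is a minimal, self-contained sketch (not part of the patch; the function and value names are invented) that can be run through `mlir-opt -expand-strided-metadata`:

// The dynamic size %n is forwarded verbatim from output_shape into the
// resulting reinterpret_cast, and each expanded stride is the product of
// the sizes to its right times the source stride (here 1): the pass
// produces sizes [%n, 4] and strides [4, 1], with no division of the
// source size by the static sizes.
func.func @illustrate_output_shape(%src: memref<?xf32>, %n: index) -> memref<?x4xf32> {
  %e = memref.expand_shape %src [[0, 1]] output_shape [%n, 4]
      : memref<?xf32> into memref<?x4xf32>
  return %e : memref<?x4xf32>
}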