From 7fc16e9818d6b8dbaf18341693eb2245bdf6844a Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Fri, 27 Jun 2025 14:16:45 +0000 Subject: [PATCH 1/4] [mlir][vector] Avoid setting padding by default in vector transfer read, prefer ub.poisson Signed-off-by: Fabian Mora --- mlir/include/mlir/Dialect/Arith/IR/Arith.h | 3 ++ mlir/include/mlir/Dialect/Vector/IR/Vector.td | 5 ++- .../mlir/Dialect/Vector/IR/VectorOps.td | 18 ++++----- .../Affine/Transforms/SuperVectorize.cpp | 3 +- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 5 +++ .../Transforms/LegalizeVectorStorage.cpp | 1 + .../Linalg/Transforms/Vectorization.cpp | 38 +++++++++++------- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 39 +++++++++---------- .../Vector/Transforms/VectorDistribute.cpp | 2 +- .../Transforms/VectorTransferOpTransforms.cpp | 3 +- .../Affine/SuperVectorize/vectorize_1d.mlir | 12 +++--- .../vectorize_affine_apply.mlir | 6 +-- .../ArmSVE/legalize-transfer-read.mlir | 8 ++-- .../Vector/vector-transfer-flatten.mlir | 24 ++++++------ .../Vector/vector-warp-distribute.mlir | 4 +- 15 files changed, 95 insertions(+), 76 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h index 0bee876ac9bfa..84d1a2535e863 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -154,6 +154,9 @@ Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs); arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred); + +/// Creates an `arith.constant` operation with a zero value of type `type`. +Value getZeroConstant(OpBuilder &builder, Location loc, Type type); } // namespace arith } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td index 1922cc63ef353..5125ae7c13717 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/Vector.td +++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td @@ -21,7 +21,10 @@ def Vector_Dialect : Dialect { let useDefaultAttributePrinterParser = 1; let hasConstantMaterializer = 1; - let dependentDialects = ["arith::ArithDialect"]; + let dependentDialects = [ + "arith::ArithDialect", + "ub::UBDialect" + ]; } // Base class for Vector dialect ops. diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index e6b85de5a522a..c1fcc5299416e 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1543,30 +1543,28 @@ def Vector_TransferReadOp : }]; let builders = [ - /// 1. Builder that sets padding to zero and an empty mask (variant with attrs). + /// 1. Builder that sets padding to `padding` or poisson if not provided and + /// an empty mask (variant with attrs). OpBuilder<(ins "VectorType":$vectorType, "Value":$source, "ValueRange":$indices, + "std::optional":$padding, "AffineMapAttr":$permutationMapAttr, "ArrayAttr":$inBoundsAttr)>, - /// 2. Builder that sets padding to zero and an empty mask (variant without attrs). + /// 2. Builder that sets padding to `padding` or poisson if not provided and + /// an empty mask (variant without attrs). OpBuilder<(ins "VectorType":$vectorType, "Value":$source, "ValueRange":$indices, + "std::optional":$padding, "AffineMap":$permutationMap, CArg<"std::optional>", "::std::nullopt">:$inBounds)>, /// 3. Builder that sets permutation map to 'getMinorIdentityMap'. 
OpBuilder<(ins "VectorType":$vectorType, "Value":$source, "ValueRange":$indices, - "Value":$padding, - CArg<"std::optional>", "::std::nullopt">:$inBounds)>, - /// 4. Builder that sets padding to zero and permutation map to - /// 'getMinorIdentityMap'. - OpBuilder<(ins "VectorType":$vectorType, - "Value":$source, - "ValueRange":$indices, - CArg<"std::optional>", "::std::nullopt">:$inBounds)>, + "std::optional":$padding, + CArg<"std::optional>", "::std::nullopt">:$inBounds)> ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index f6f192a6d964a..6e8f7126df325 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -1257,7 +1257,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp, LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = state.builder.create( - loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap); + loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, std::nullopt, + permutationMap); // Register replacement for future uses in the scope. state.registerOpVectorReplacement(loadOp, transfer); diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 5194f2b58669a..c9fe579a0b8a9 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -292,6 +292,11 @@ bool arith::ConstantIndexOp::classof(Operation *op) { return false; } +Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc, + Type type) { + return builder.create(loc, builder.getZeroAttr(type)); +} + //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp index d52ff4d4257c7..3dbb93b8a0669 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -426,6 +426,7 @@ struct LegalizeTransferRead : public OpRewritePattern { // Create the new `transfer_read`. auto newReadOp = rewriter.create( readOp.getLoc(), collapsedVT, collapsedMem, indices, + readOp.getPadding(), ArrayRef(origInBounds).drop_back(numCollapseDims - 1)); // Cast back to the original vector type. 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 830ae5414c6bd..444396aaeccfc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1191,6 +1191,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, auto transferReadOp = rewriter.create( loc, resultType, extractOp.getTensor(), transferReadIdxs, + arith::getZeroConstant(rewriter, loc, resultType.getElementType()), permutationMap, inBounds); // Mask this broadcasting xfer_read here rather than relying on the generic @@ -1227,8 +1228,9 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, } auto transferReadOp = rewriter.create( - loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap, - inBounds); + loc, resultType, extractOp.getTensor(), transferReadIdxs, + arith::getZeroConstant(rewriter, loc, resultType.getElementType()), + permutationMap, inBounds); LDBG("Vectorised as contiguous load: " << extractOp); return VectorizationHookResult{VectorizationHookStatus::NewOp, @@ -1384,7 +1386,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, /// performed to the maximal common vector size implied by the `linalgOp` /// iteration space. This eager broadcasting is introduced in the /// permutation_map of the vector.transfer_read operations. The eager -/// broadcasting makes it trivial to detrmine where broadcast, transposes and +/// broadcasting makes it trivial to determine where broadcast, transposes and /// reductions should occur, without any bookkeeping. The tradeoff is that, in /// the absence of good canonicalizations, the amount of work increases. /// This is not deemed a problem as we expect canonicalizations and foldings to @@ -1439,7 +1441,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, SmallVector indices(linalgOp.getShape(opOperand).size(), zero); Operation *read = rewriter.create( - loc, readType, opOperand->get(), indices, readMap); + loc, readType, opOperand->get(), indices, + arith::getZeroConstant(rewriter, loc, elemType), readMap); read = state.maskOperation(rewriter, read, linalgOp, indexingMap); Value readValue = read->getResult(0); @@ -2641,6 +2644,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, Value readValue = rewriter.create( loc, readType, copyOp.getSource(), indices, + arith::getZeroConstant(rewriter, loc, srcElementType), rewriter.getMultiDimIdentityMap(srcType.getRank())); if (cast(readValue.getType()).getRank() == 0) { readValue = @@ -3487,15 +3491,18 @@ struct Conv1DGenerator SmallVector resPadding(resShape.size(), zero); // Read the whole lhs, rhs and res in one shot (with zero padding). - Value lhs = rewriter.create(loc, lhsType, lhsShaped, - lhsPadding); + Value lhs = rewriter.create( + loc, lhsType, lhsShaped, lhsPadding, + arith::getZeroConstant(rewriter, loc, lhsEltType)); // This is needed only for Conv. 
Value rhs = nullptr; if (oper == ConvOperationKind::Conv) - rhs = rewriter.create(loc, rhsType, rhsShaped, - rhsPadding); - Value res = rewriter.create(loc, resType, resShaped, - resPadding); + rhs = rewriter.create( + loc, rhsType, rhsShaped, rhsPadding, + arith::getZeroConstant(rewriter, loc, rhsEltType)); + Value res = rewriter.create( + loc, resType, resShaped, resPadding, + arith::getZeroConstant(rewriter, loc, resEltType)); // The base vectorization case for channeled convolution is input: // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern @@ -3742,19 +3749,22 @@ struct Conv1DGenerator // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, // 0]. Value lhs = rewriter.create( - loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}); + loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}, + arith::getZeroConstant(rewriter, loc, lhsEltType)); auto maybeMaskedLhs = maybeMaskXferOp( lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp()); // Read rhs slice of size {kw, c} @ [0, 0]. - Value rhs = rewriter.create(loc, rhsType, rhsShaped, - ValueRange{zero, zero}); + Value rhs = rewriter.create( + loc, rhsType, rhsShaped, ValueRange{zero, zero}, + arith::getZeroConstant(rewriter, loc, rhsEltType)); auto maybeMaskedRhs = maybeMaskXferOp( rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp()); // Read res slice of size {n, w, c} @ [0, 0, 0]. Value res = rewriter.create( - loc, resType, resShaped, ValueRange{zero, zero, zero}); + loc, resType, resShaped, ValueRange{zero, zero, zero}, + arith::getZeroConstant(rewriter, loc, resEltType)); auto maybeMaskedRes = maybeMaskXferOp( resType.getShape(), resType.getScalableDims(), res.getDefiningOp()); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a11dbe2589205..fc7ed7e479b49 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4261,33 +4261,39 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns( /// 1. Builder that sets padding to zero and an empty mask (variant with attrs). void TransferReadOp::build(OpBuilder &builder, OperationState &result, VectorType vectorType, Value source, - ValueRange indices, AffineMapAttr permutationMapAttr, + ValueRange indices, std::optional padding, + AffineMapAttr permutationMapAttr, /*optional*/ ArrayAttr inBoundsAttr) { + Type elemType = llvm::cast(source.getType()).getElementType(); - Value padding = builder.create( - result.location, elemType, builder.getZeroAttr(elemType)); + if (!padding) + padding = builder.create(result.location, elemType); build(builder, result, vectorType, source, indices, permutationMapAttr, - padding, /*mask=*/Value(), inBoundsAttr); + *padding, /*mask=*/Value(), inBoundsAttr); } /// 2. Builder that sets padding to zero an empty mask (variant without attrs). void TransferReadOp::build(OpBuilder &builder, OperationState &result, VectorType vectorType, Value source, - ValueRange indices, AffineMap permutationMap, + ValueRange indices, std::optional padding, + AffineMap permutationMap, std::optional> inBounds) { auto permutationMapAttr = AffineMapAttr::get(permutationMap); auto inBoundsAttr = (inBounds && !inBounds.value().empty()) ? 
builder.getBoolArrayAttr(inBounds.value()) : builder.getBoolArrayAttr( SmallVector(vectorType.getRank(), false)); - build(builder, result, vectorType, source, indices, permutationMapAttr, - inBoundsAttr); + Type elemType = llvm::cast(source.getType()).getElementType(); + if (!padding) + padding = builder.create(result.location, elemType); + build(builder, result, vectorType, source, indices, *padding, + permutationMapAttr, inBoundsAttr); } /// 3. Builder that sets permutation map to 'getMinorIdentityMap'. void TransferReadOp::build(OpBuilder &builder, OperationState &result, VectorType vectorType, Value source, - ValueRange indices, Value padding, + ValueRange indices, std::optional padding, std::optional> inBounds) { AffineMap permutationMap = getTransferMinorIdentityMap( llvm::cast(source.getType()), vectorType); @@ -4296,23 +4302,14 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result, ? builder.getBoolArrayAttr(inBounds.value()) : builder.getBoolArrayAttr( SmallVector(vectorType.getRank(), false)); + Type elemType = llvm::cast(source.getType()).getElementType(); + if (!padding) + padding = builder.create(result.location, elemType); build(builder, result, vectorType, source, indices, permutationMapAttr, - padding, + *padding, /*mask=*/Value(), inBoundsAttr); } -/// 4. Builder that sets padding to zero and permutation map to -/// 'getMinorIdentityMap'. -void TransferReadOp::build(OpBuilder &builder, OperationState &result, - VectorType vectorType, Value source, - ValueRange indices, - std::optional> inBounds) { - Type elemType = llvm::cast(source.getType()).getElementType(); - Value padding = builder.create( - result.location, elemType, builder.getZeroAttr(elemType)); - build(builder, result, vectorType, source, indices, padding, inBounds); -} - template static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError) { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index af90ed8f5deaf..ba9f39c6393ce 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper { } SmallVector inBounds(indices.size(), true); return b.create( - loc, cast(type), buffer, indices, + loc, cast(type), buffer, indices, std::nullopt, ArrayRef(inBounds.begin(), inBounds.end())); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 785a8aaf3f0a9..efdae93e730bd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -660,7 +660,8 @@ class FlattenContiguousRowMajorTransferReadPattern VectorType flatVectorType = VectorType::get({vectorType.getNumElements()}, vectorType.getElementType()); vector::TransferReadOp flatRead = rewriter.create( - loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap); + loc, flatVectorType, collapsedSource, collapsedIndices, + transferReadOp.getPadding(), collapsedMap); flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); // 4. 
Replace the old transfer_read with the new one reading from the diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir index 81b04ccceaf27..72ced5b53879b 100644 --- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir +++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir @@ -21,7 +21,7 @@ func.func @vec1d_1(%A : memref, %B : memref) { // CHECK: for {{.*}} step 128 // CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]]) // CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]]) -// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32 +// CHECK-NEXT: %{{.*}} = ub.poison : f32 // CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref, vector<128xf32> affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector %a0 = affine.load %A[%c0, %c0] : memref @@ -47,7 +47,7 @@ func.func @vec1d_2(%A : memref, %B : memref) { %P = memref.dim %B, %c2 : memref // CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 -// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32 +// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32 // CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %[[CST]] : memref, vector<128xf32> affine.for %i3 = 0 to %M { // vectorized %a3 = affine.load %A[%c0, %i3] : memref @@ -76,7 +76,7 @@ func.func @vec1d_3(%A : memref, %B : memref) { // CHECK-NEXT: for [[IV9:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] { // CHECK-NEXT: %[[APP9_0:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]]) // CHECK-NEXT: %[[APP9_1:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]]) -// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32 +// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32 // CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%[[APP9_0]], %[[APP9_1]]], %[[CST]] : memref, vector<128xf32> affine.for %i8 = 0 to %M { // vectorized affine.for %i9 = 0 to %N { @@ -280,7 +280,7 @@ func.func @vec_rejected_3(%A : memref, %B : memref) { // CHECK:for [[IV4:%[0-9a-zA-Z_]+]] = 0 to [[ARG_M]] step 128 { // CHECK-NEXT: for [[IV5:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] { -// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32 +// CHECK-NEXT: %{{.*}} = ub.poison : f32 // CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{[a-zA-Z0-9_]*}} : memref, vector<128xf32> affine.for %i4 = 0 to %M { // vectorized affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1 @@ -424,7 +424,7 @@ func.func @vec_rejected_8(%A : memref, %B : memref) { // CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 // CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}}) // CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}}) -// CHECK: %{{.*}} = arith.constant 0.0{{.*}}: f32 +// CHECK: %{{.*}} = ub.poison : f32 // CHECK: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref, vector<128xf32> affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %{{.*}} in DFS post-order prevents vectorizing %{{.*}} affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector @@ -458,7 +458,7 @@ func.func @vec_rejected_9(%A : memref, %B : memref) { // CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 // CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}}) // CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}}) -// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32 +// CHECK-NEXT: %{{.*}} = ub.poison : f32 // CHECK-NEXT: 
{{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref, vector<128xf32> affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %{{.*}} affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir index 15a7133cf0f65..7d4d111c09799 100644 --- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir +++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_affine_apply.mlir @@ -11,7 +11,7 @@ func.func @vec_affine_apply(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48xf3 // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 { // CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID0]](%[[ARG3]]) // CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID1]](%[[ARG4]]) -// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32 // CHECK-NEXT: %[[S2:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[S0]], %[[S1]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32> // CHECK-NEXT: vector.transfer_write %[[S2]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32> // CHECK-NEXT: } @@ -42,7 +42,7 @@ func.func @vec_affine_apply_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24x48x // CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 12 { // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 48 step 8 { // CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID2]](%[[ARG4]]) -// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32 // CHECK-NEXT: %[[S1:.*]] = vector.transfer_read %[[ARG0]][%[[ARG2]], %[[ARG3]], %[[S0]]], %[[CST]] : memref<8x12x16xf32>, vector<8xf32> // CHECK-NEXT: vector.transfer_write %[[S1]], %[[ARG1]][%[[ARG2]], %[[ARG3]], %[[ARG4]]] : vector<8xf32>, memref<8x24x48xf32> // CHECK-NEXT: } @@ -140,7 +140,7 @@ func.func @affine_map_with_expr_2(%arg0: memref<8x12x16xf32>, %arg1: memref<8x24 // CHECK-NEXT: %[[S0:.*]] = affine.apply #[[$MAP_ID3]](%[[ARG3]], %[[ARG4]], %[[I0]]) // CHECK-NEXT: %[[S1:.*]] = affine.apply #[[$MAP_ID4]](%[[ARG3]], %[[ARG4]], %[[I0]]) // CHECK-NEXT: %[[S2:.*]] = affine.apply #[[$MAP_ID5]](%[[ARG3]], %[[ARG4]], %[[I0]]) -// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32 // CHECK-NEXT: %[[S3:.*]] = vector.transfer_read %[[ARG0]][%[[S0]], %[[S1]], %[[S2]]], %[[CST]] {permutation_map = #[[$MAP_ID6]]} : memref<8x12x16xf32>, vector<8xf32> // CHECK-NEXT: vector.transfer_write %[[S3]], %[[ARG1]][%[[ARG3]], %[[ARG4]], %[[ARG5]]] : vector<8xf32>, memref<8x24x48xf32> // CHECK-NEXT: } diff --git a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir index 5f923cdafb956..49bd2eddbdedd 100644 --- a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir +++ b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir @@ -11,8 +11,8 @@ // CHECK-LABEL: @base_case // CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]: -// CHECK: %[[PAD:.+]] = arith.constant 0 : i8 -// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[PAD:.+]] = arith.constant 123 : i8 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]] // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]] // CHECK-SAME: : memref into memref @@ -36,8 +36,8 @@ func.func 
@base_case(%i : index, %j : index, %M : memref) -> vector< // CHECK-LABEL: @with_3d_vector // CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]: -// CHECK: %[[PAD:.+]] = arith.constant 0 : i8 -// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]] +// CHECK-DAG: %[[PAD:.+]] = arith.constant 123 : i8 +// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]] // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] // CHECK-SAME: : memref into memref // CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %[[PAD]] {in_bounds = [true]} diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index 0f04d3b79b535..d18edd0ac5563 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -85,8 +85,8 @@ func.func @transfer_read_dims_mismatch_contiguous( // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous( // CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> { -// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8 -// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]] // CHECK-SAME{LITERAL}: [[0], [1, 2, 3]] // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}> @@ -116,8 +116,8 @@ func.func @transfer_read_dims_mismatch_contiguous_unit_dims( // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims( // CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>) // CHECK-SAME: -> vector<1x1x4x3x2xi8> -// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8 -// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] // CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]] // CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>> @@ -149,8 +149,8 @@ func.func @transfer_read_non_contiguous_unit_dims( // CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims( // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> { -// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8 -// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i8 +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]] // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>> @@ -182,8 +182,8 @@ func.func @transfer_read_dims_mismatch_non_zero_indices( // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices( // CHECK-SAME: %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index, // CHECK-SAME: %[[MEM:.+]]: memref<1x43x4x6xi32> -// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32 -// CHECK: %[[C_0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C_0:.+]] = arith.constant 0 : index // CHECK: %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]] // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]] // CHECK-SAME: : memref<1x43x4x6xi32> into memref<1x43x24xi32> @@ -241,8 +241,8 @@ func.func @transfer_read_leading_dynamic_dims( // CHECK-LABEL: func 
@transfer_read_leading_dynamic_dims // CHECK-SAME: %[[MEM:.+]]: memref, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index -// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8 -// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_I8:.+]] = arith.constant 0 : i8 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]] // CHECK-SAME: : memref into memref @@ -304,8 +304,8 @@ func.func @transfer_read_dynamic_dim_to_flatten( // CHECK-SAME: %[[IDX_1:arg0]] // CHECK-SAME: %[[IDX_2:arg1]] // CHECK-SAME: %[[MEM:arg2]] -// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32 -// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_I32:.+]] = arith.constant 0 : i32 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]] // CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?x24xi32> diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index 7cfbcdf101d11..1161dbd4b2166 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1132,8 +1132,8 @@ func.func @warp_execute_nd_distribute(%laneid: index, %v0: vector<1x64x1xf32>, % // CHECK-SCF-IF: gpu.barrier // CHECK-SCF-IF: %[[WID:.*]] = affine.apply #[[$TIMES2]]()[%[[LANEID]]] - // CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32> - // CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %cst {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32> + // CHECK-SCF-IF-DAG: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[LANEID]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<32x64x1xf32, 3>, vector<1x64x1xf32> + // CHECK-SCF-IF-DAG: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[WID]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : memref<1x64x128xf32, 3>, vector<1x2x128xf32> // CHECK-SCF-IF: return %[[R0]], %[[R1]] : vector<1x64x1xf32>, vector<1x2x128xf32> return %r#0, %r#1 : vector<1x64x1xf32>, vector<1x2x128xf32> } From 4575f849bf76575282ea4560724a44502d09c98e Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Fri, 27 Jun 2025 16:11:36 +0000 Subject: [PATCH 2/4] address reviewer comments --- .../mlir/Dialect/Vector/IR/VectorOps.td | 3 ++- .../Affine/Transforms/SuperVectorize.cpp | 4 +-- .../Linalg/Transforms/Vectorization.cpp | 26 ++++++++++--------- .../Vector/Transforms/VectorDistribute.cpp | 2 +- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index c1fcc5299416e..9c111914fe518 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1559,7 +1559,8 @@ def Vector_TransferReadOp : "std::optional":$padding, "AffineMap":$permutationMap, CArg<"std::optional>", "::std::nullopt">:$inBounds)>, - /// 3. Builder that sets permutation map to 'getMinorIdentityMap'. + /// 3. Builder that sets padding to `padding` or poisson if not provided and + /// permutation map to 'getMinorIdentityMap'. 
OpBuilder<(ins "VectorType":$vectorType, "Value":$source, "ValueRange":$indices, diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index 6e8f7126df325..7fae260767e0a 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -1257,8 +1257,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp, LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = state.builder.create( - loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, std::nullopt, - permutationMap); + loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, + /*padding=*/std::nullopt, permutationMap); // Register replacement for future uses in the scope. state.registerOpVectorReplacement(loadOp, transfer); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 444396aaeccfc..b467114c72f7d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1183,6 +1183,10 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, auto srcRank = extractOp.getTensor().getType().getRank(); SmallVector inBounds(dstRank, true); + // Get the value to pad transfer reads with 0. + Value padding = + arith::getZeroConstant(rewriter, loc, resultType.getElementType()); + // 2a. Handle scalar broadcast access. if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) { MLIRContext *ctx = rewriter.getContext(); @@ -1190,8 +1194,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx); auto transferReadOp = rewriter.create( - loc, resultType, extractOp.getTensor(), transferReadIdxs, - arith::getZeroConstant(rewriter, loc, resultType.getElementType()), + loc, resultType, extractOp.getTensor(), transferReadIdxs, padding, permutationMap, inBounds); // Mask this broadcasting xfer_read here rather than relying on the generic @@ -1228,8 +1231,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, } auto transferReadOp = rewriter.create( - loc, resultType, extractOp.getTensor(), transferReadIdxs, - arith::getZeroConstant(rewriter, loc, resultType.getElementType()), + loc, resultType, extractOp.getTensor(), transferReadIdxs, padding, permutationMap, inBounds); LDBG("Vectorised as contiguous load: " << extractOp); @@ -1442,7 +1444,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, Operation *read = rewriter.create( loc, readType, opOperand->get(), indices, - arith::getZeroConstant(rewriter, loc, elemType), readMap); + /*padding=*/arith::getZeroConstant(rewriter, loc, elemType), readMap); read = state.maskOperation(rewriter, read, linalgOp, indexingMap); Value readValue = read->getResult(0); @@ -2644,7 +2646,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, Value readValue = rewriter.create( loc, readType, copyOp.getSource(), indices, - arith::getZeroConstant(rewriter, loc, srcElementType), + /*padding=*/arith::getZeroConstant(rewriter, loc, srcElementType), rewriter.getMultiDimIdentityMap(srcType.getRank())); if (cast(readValue.getType()).getRank() == 0) { readValue = @@ -3493,16 +3495,16 @@ struct Conv1DGenerator // Read the whole lhs, rhs and res in one shot (with zero padding). 
Value lhs = rewriter.create( loc, lhsType, lhsShaped, lhsPadding, - arith::getZeroConstant(rewriter, loc, lhsEltType)); + /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType)); // This is needed only for Conv. Value rhs = nullptr; if (oper == ConvOperationKind::Conv) rhs = rewriter.create( loc, rhsType, rhsShaped, rhsPadding, - arith::getZeroConstant(rewriter, loc, rhsEltType)); + /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType)); Value res = rewriter.create( loc, resType, resShaped, resPadding, - arith::getZeroConstant(rewriter, loc, resEltType)); + /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType)); // The base vectorization case for channeled convolution is input: // {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern @@ -3750,21 +3752,21 @@ struct Conv1DGenerator // 0]. Value lhs = rewriter.create( loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}, - arith::getZeroConstant(rewriter, loc, lhsEltType)); + /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType)); auto maybeMaskedLhs = maybeMaskXferOp( lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp()); // Read rhs slice of size {kw, c} @ [0, 0]. Value rhs = rewriter.create( loc, rhsType, rhsShaped, ValueRange{zero, zero}, - arith::getZeroConstant(rewriter, loc, rhsEltType)); + /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType)); auto maybeMaskedRhs = maybeMaskXferOp( rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp()); // Read res slice of size {n, w, c} @ [0, 0, 0]. Value res = rewriter.create( loc, resType, resShaped, ValueRange{zero, zero, zero}, - arith::getZeroConstant(rewriter, loc, resEltType)); + /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType)); auto maybeMaskedRes = maybeMaskXferOp( resType.getShape(), resType.getScalableDims(), res.getDefiningOp()); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index ba9f39c6393ce..fb99e22c77ea0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper { } SmallVector inBounds(indices.size(), true); return b.create( - loc, cast(type), buffer, indices, std::nullopt, + loc, cast(type), buffer, indices, /*padding=*/std::nullopt, ArrayRef(inBounds.begin(), inBounds.end())); } From 0ce2576b38f9ba2ee07c8c5e8e68dfa6d9be00dc Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Fri, 27 Jun 2025 16:28:57 +0000 Subject: [PATCH 3/4] address reviewer comments --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 9c111914fe518..dfb2756e57bea 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1543,7 +1543,7 @@ def Vector_TransferReadOp : }]; let builders = [ - /// 1. Builder that sets padding to `padding` or poisson if not provided and + /// 1. Builder that sets padding to `padding` or poison if not provided and /// an empty mask (variant with attrs). OpBuilder<(ins "VectorType":$vectorType, "Value":$source, @@ -1551,7 +1551,7 @@ def Vector_TransferReadOp : "std::optional":$padding, "AffineMapAttr":$permutationMapAttr, "ArrayAttr":$inBoundsAttr)>, - /// 2. Builder that sets padding to `padding` or poisson if not provided and + /// 2. 
Builder that sets padding to `padding` or poison if not provided and /// an empty mask (variant without attrs). OpBuilder<(ins "VectorType":$vectorType, "Value":$source, @@ -1559,7 +1559,7 @@ def Vector_TransferReadOp : "std::optional":$padding, "AffineMap":$permutationMap, CArg<"std::optional>", "::std::nullopt">:$inBounds)>, - /// 3. Builder that sets padding to `padding` or poisson if not provided and + /// 3. Builder that sets padding to `padding` or poison if not provided and /// permutation map to 'getMinorIdentityMap'. OpBuilder<(ins "VectorType":$vectorType, "Value":$source, From bfdbd714b6ef59bd35861b707804029d76cc08be Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Mon, 30 Jun 2025 18:18:10 +0000 Subject: [PATCH 4/4] add assertions --- mlir/include/mlir/Dialect/Arith/IR/Arith.h | 4 +++- mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h index 84d1a2535e863..c2de3c9021b0b 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -155,7 +155,9 @@ Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred); -/// Creates an `arith.constant` operation with a zero value of type `type`. +/// Creates an `arith.constant` operation with a zero value of type `type`. This +/// method asserts if `type` is invalid for representing zero with +/// `arith.constant`. Value getZeroConstant(OpBuilder &builder, Location loc, Type type); } // namespace arith } // namespace mlir diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index c9fe579a0b8a9..ec96a35ae82e7 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -294,7 +294,12 @@ bool arith::ConstantIndexOp::classof(Operation *op) { Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc, Type type) { - return builder.create(loc, builder.getZeroAttr(type)); + // TODO: Incorporate this check to `FloatAttr::get*`. + assert(!isa(getElementTypeOrSelf(type)) && + "type doesn't have a zero representation"); + TypedAttr zeroAttr = builder.getZeroAttr(type); + assert(zeroAttr && "unsupported type for zero attribute"); + return builder.create(loc, zeroAttr); } //===----------------------------------------------------------------------===//
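
A minimal usage sketch of the updated builders, assuming the final signatures
from the series above; `rewriter`, `loc`, `vecType`, `source`, `indices`, and
`permMap` are illustrative placeholders rather than names taken from the
patches:

    // Padding omitted: the builder materializes `ub.poison` of the element
    // type as the padding value.
    auto poisonRead = rewriter.create<vector::TransferReadOp>(
        loc, vecType, source, indices, /*padding=*/std::nullopt, permMap);

    // Explicit zero padding, for callers (such as the Linalg vectorizer
    // updated above) that still rely on reading 0 for out-of-bounds positions.
    auto zeroRead = rewriter.create<vector::TransferReadOp>(
        loc, vecType, source, indices,
        /*padding=*/arith::getZeroConstant(rewriter, loc,
                                           vecType.getElementType()),
        permMap);

With the first form, the padding operand is now defined by `ub.poison : f32`
(for an `f32` element type) instead of `arith.constant 0.0 : f32`, which is
what the updated CHECK lines in the SuperVectorize and vectorize_affine_apply
tests verify.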