Skip to content

[mlir][vector] Avoid setting padding by default to 0 in vector.transfer_read prefer ub.poison #146088

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/Arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc,
Value lhs, Value rhs);

arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred);

/// Creates an `arith.constant` operation with a zero value of type `type`. This
/// method asserts if `type` is invalid for representing zero with
/// `arith.constant`.
Value getZeroConstant(OpBuilder &builder, Location loc, Type type);
} // namespace arith
} // namespace mlir

Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Vector/IR/Vector.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ def Vector_Dialect : Dialect {

let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
let dependentDialects = ["arith::ArithDialect"];
let dependentDialects = [
"arith::ArithDialect",
"ub::UBDialect"
];
}

// Base class for Vector dialect ops.
Expand Down
21 changes: 10 additions & 11 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1543,30 +1543,29 @@ def Vector_TransferReadOp :
}];

let builders = [
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
/// 1. Builder that sets padding to `padding` or poison if not provided and
/// an empty mask (variant with attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
"std::optional<Value>":$padding,
"AffineMapAttr":$permutationMapAttr,
"ArrayAttr":$inBoundsAttr)>,
/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
/// 2. Builder that sets padding to `padding` or poison if not provided and
/// an empty mask (variant without attrs).
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
"std::optional<Value>":$padding,
"AffineMap":$permutationMap,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
/// 3. Builder that sets padding to `padding` or poison if not provided and
/// permutation map to 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
"Value":$padding,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
/// 4. Builder that sets padding to zero and permutation map to
/// 'getMinorIdentityMap'.
OpBuilder<(ins "VectorType":$vectorType,
"Value":$source,
"ValueRange":$indices,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>,
"std::optional<Value>":$padding,
CArg<"std::optional<ArrayRef<bool>>", "::std::nullopt">:$inBounds)>
];

let extraClassDeclaration = [{
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
LLVM_DEBUG(permutationMap.print(dbgs()));

auto transfer = state.builder.create<vector::TransferReadOp>(
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices,
/*padding=*/std::nullopt, permutationMap);

// Register replacement for future uses in the scope.
state.registerOpVectorReplacement(loadOp, transfer);
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,16 @@ bool arith::ConstantIndexOp::classof(Operation *op) {
return false;
}

Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc,
Type type) {
// TODO: Incorporate this check to `FloatAttr::get*`.
assert(!isa<Float8E8M0FNUType>(getElementTypeOrSelf(type)) &&
"type doesn't have a zero representation");
TypedAttr zeroAttr = builder.getZeroAttr(type);
assert(zeroAttr && "unsupported type for zero attribute");
return builder.create<arith::ConstantOp>(loc, zeroAttr);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this utility? It's kind of wrapping a single line statement

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. It only reduces written code size and IMO makes it clear from the get go what the function is doing, and given it's an important constant I added it. But I can remove it.


//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
// Create the new `transfer_read`.
auto newReadOp = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), collapsedVT, collapsedMem, indices,
readOp.getPadding(),
ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));

// Cast back to the original vector type.
Expand Down
42 changes: 27 additions & 15 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1183,14 +1183,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
auto srcRank = extractOp.getTensor().getType().getRank();
SmallVector<bool> inBounds(dstRank, true);

// Get the value to pad transfer reads with 0.
Value padding =
arith::getZeroConstant(rewriter, loc, resultType.getElementType());

// 2a. Handle scalar broadcast access.
if (memAccessKind == VectorMemoryAccessKind::ScalarBroadcast) {
MLIRContext *ctx = rewriter.getContext();
SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx));
auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx);

auto transferReadOp = rewriter.create<vector::TransferReadOp>(
loc, resultType, extractOp.getTensor(), transferReadIdxs,
loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
permutationMap, inBounds);

// Mask this broadcasting xfer_read here rather than relying on the generic
Expand Down Expand Up @@ -1227,8 +1231,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
}

auto transferReadOp = rewriter.create<vector::TransferReadOp>(
loc, resultType, extractOp.getTensor(), transferReadIdxs, permutationMap,
inBounds);
loc, resultType, extractOp.getTensor(), transferReadIdxs, padding,
permutationMap, inBounds);

LDBG("Vectorised as contiguous load: " << extractOp);
return VectorizationHookResult{VectorizationHookStatus::NewOp,
Expand Down Expand Up @@ -1384,7 +1388,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
/// performed to the maximal common vector size implied by the `linalgOp`
/// iteration space. This eager broadcasting is introduced in the
/// permutation_map of the vector.transfer_read operations. The eager
/// broadcasting makes it trivial to detrmine where broadcast, transposes and
/// broadcasting makes it trivial to determine where broadcast, transposes and
/// reductions should occur, without any bookkeeping. The tradeoff is that, in
/// the absence of good canonicalizations, the amount of work increases.
/// This is not deemed a problem as we expect canonicalizations and foldings to
Expand Down Expand Up @@ -1439,7 +1443,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);

Operation *read = rewriter.create<vector::TransferReadOp>(
loc, readType, opOperand->get(), indices, readMap);
loc, readType, opOperand->get(), indices,
/*padding=*/arith::getZeroConstant(rewriter, loc, elemType), readMap);
read = state.maskOperation(rewriter, read, linalgOp, indexingMap);
Value readValue = read->getResult(0);

Expand Down Expand Up @@ -2641,6 +2646,7 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,

Value readValue = rewriter.create<vector::TransferReadOp>(
loc, readType, copyOp.getSource(), indices,
/*padding=*/arith::getZeroConstant(rewriter, loc, srcElementType),
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (cast<VectorType>(readValue.getType()).getRank() == 0) {
readValue =
Expand Down Expand Up @@ -3487,15 +3493,18 @@ struct Conv1DGenerator
SmallVector<Value> resPadding(resShape.size(), zero);

// Read the whole lhs, rhs and res in one shot (with zero padding).
Value lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsShaped,
lhsPadding);
Value lhs = rewriter.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, lhsPadding,
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
// This is needed only for Conv.
Value rhs = nullptr;
if (oper == ConvOperationKind::Conv)
rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
rhsPadding);
Value res = rewriter.create<vector::TransferReadOp>(loc, resType, resShaped,
resPadding);
rhs = rewriter.create<vector::TransferReadOp>(
loc, rhsType, rhsShaped, rhsPadding,
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
Value res = rewriter.create<vector::TransferReadOp>(
loc, resType, resShaped, resPadding,
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));

// The base vectorization case for channeled convolution is input:
// {n,w,c}, weight: {kw,c,f}, output: {n,w,f}. To reuse the base pattern
Expand Down Expand Up @@ -3742,19 +3751,22 @@ struct Conv1DGenerator
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
// 0].
Value lhs = rewriter.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType));
auto maybeMaskedLhs = maybeMaskXferOp(
lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp());

// Read rhs slice of size {kw, c} @ [0, 0].
Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
ValueRange{zero, zero});
Value rhs = rewriter.create<vector::TransferReadOp>(
loc, rhsType, rhsShaped, ValueRange{zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType));
auto maybeMaskedRhs = maybeMaskXferOp(
rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp());

// Read res slice of size {n, w, c} @ [0, 0, 0].
Value res = rewriter.create<vector::TransferReadOp>(
loc, resType, resShaped, ValueRange{zero, zero, zero});
loc, resType, resShaped, ValueRange{zero, zero, zero},
/*padding=*/arith::getZeroConstant(rewriter, loc, resEltType));
auto maybeMaskedRes = maybeMaskXferOp(
resType.getShape(), resType.getScalableDims(), res.getDefiningOp());

Expand Down
39 changes: 18 additions & 21 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4261,33 +4261,39 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, AffineMapAttr permutationMapAttr,
ValueRange indices, std::optional<Value> padding,
AffineMapAttr permutationMapAttr,
/*optional*/ ArrayAttr inBoundsAttr) {

Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
Value padding = builder.create<arith::ConstantOp>(
result.location, elemType, builder.getZeroAttr(elemType));
if (!padding)
padding = builder.create<ub::PoisonOp>(result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
padding, /*mask=*/Value(), inBoundsAttr);
*padding, /*mask=*/Value(), inBoundsAttr);
}

/// 2. Builder that sets padding to zero an empty mask (variant without attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, AffineMap permutationMap,
ValueRange indices, std::optional<Value> padding,
AffineMap permutationMap,
std::optional<ArrayRef<bool>> inBounds) {
auto permutationMapAttr = AffineMapAttr::get(permutationMap);
auto inBoundsAttr = (inBounds && !inBounds.value().empty())
? builder.getBoolArrayAttr(inBounds.value())
: builder.getBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false));
build(builder, result, vectorType, source, indices, permutationMapAttr,
inBoundsAttr);
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
padding = builder.create<ub::PoisonOp>(result.location, elemType);
build(builder, result, vectorType, source, indices, *padding,
permutationMapAttr, inBoundsAttr);
}

/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices, Value padding,
ValueRange indices, std::optional<Value> padding,
std::optional<ArrayRef<bool>> inBounds) {
AffineMap permutationMap = getTransferMinorIdentityMap(
llvm::cast<ShapedType>(source.getType()), vectorType);
Expand All @@ -4296,23 +4302,14 @@ void TransferReadOp::build(OpBuilder &builder, OperationState &result,
? builder.getBoolArrayAttr(inBounds.value())
: builder.getBoolArrayAttr(
SmallVector<bool>(vectorType.getRank(), false));
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
if (!padding)
padding = builder.create<ub::PoisonOp>(result.location, elemType);
build(builder, result, vectorType, source, indices, permutationMapAttr,
padding,
*padding,
/*mask=*/Value(), inBoundsAttr);
}

/// 4. Builder that sets padding to zero and permutation map to
/// 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
VectorType vectorType, Value source,
ValueRange indices,
std::optional<ArrayRef<bool>> inBounds) {
Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
Value padding = builder.create<arith::ConstantOp>(
result.location, elemType, builder.getZeroAttr(elemType));
build(builder, result, vectorType, source, indices, padding, inBounds);
}

template <typename EmitFun>
static LogicalResult verifyPermutationMap(AffineMap permutationMap,
EmitFun emitOpError) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ struct DistributedLoadStoreHelper {
}
SmallVector<bool> inBounds(indices.size(), true);
return b.create<vector::TransferReadOp>(
loc, cast<VectorType>(type), buffer, indices,
loc, cast<VectorType>(type), buffer, indices, /*padding=*/std::nullopt,
ArrayRef<bool>(inBounds.begin(), inBounds.end()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,8 @@ class FlattenContiguousRowMajorTransferReadPattern
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
loc, flatVectorType, collapsedSource, collapsedIndices,
transferReadOp.getPadding(), collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));

// 4. Replace the old transfer_read with the new one reading from the
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Dialect/Affine/SuperVectorize/vectorize_1d.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func.func @vec1d_1(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for {{.*}} step 128
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%[[C0]])
// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i0 = 0 to %M { // vectorized due to scalar -> vector
%a0 = affine.load %A[%c0, %c0] : memref<?x?xf32>
Expand All @@ -47,7 +47,7 @@ func.func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
%P = memref.dim %B, %c2 : memref<?x?x?xf32>

// CHECK:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %[[CST]] : memref<?x?xf32>, vector<128xf32>
affine.for %i3 = 0 to %M { // vectorized
%a3 = affine.load %A[%c0, %i3] : memref<?x?xf32>
Expand Down Expand Up @@ -76,7 +76,7 @@ func.func @vec1d_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK-NEXT: for [[IV9:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
// CHECK-NEXT: %[[APP9_0:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
// CHECK-NEXT: %[[APP9_1:[0-9a-zA-Z_]+]] = affine.apply {{.*}}([[IV9]], [[IV8]])
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %[[CST:.*]] = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%[[APP9_0]], %[[APP9_1]]], %[[CST]] : memref<?x?xf32>, vector<128xf32>
affine.for %i8 = 0 to %M { // vectorized
affine.for %i9 = 0 to %N {
Expand Down Expand Up @@ -280,7 +280,7 @@ func.func @vec_rejected_3(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {

// CHECK:for [[IV4:%[0-9a-zA-Z_]+]] = 0 to [[ARG_M]] step 128 {
// CHECK-NEXT: for [[IV5:%[0-9a-zA-Z_]*]] = 0 to [[ARG_N]] {
// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{[a-zA-Z0-9_]*}} : memref<?x?xf32>, vector<128xf32>
affine.for %i4 = 0 to %M { // vectorized
affine.for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1
Expand Down Expand Up @@ -424,7 +424,7 @@ func.func @vec_rejected_8(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK: %{{.*}} = ub.poison : f32
// CHECK: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %{{.*}} in DFS post-order prevents vectorizing %{{.*}}
affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
Expand Down Expand Up @@ -458,7 +458,7 @@ func.func @vec_rejected_9(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
// CHECK: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
// CHECK: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK-NEXT: %{{.*}} = affine.apply #[[$map_id1]](%{{.*}})
// CHECK-NEXT: %{{.*}} = arith.constant 0.0{{.*}}: f32
// CHECK-NEXT: %{{.*}} = ub.poison : f32
// CHECK-NEXT: {{.*}} = vector.transfer_read %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} {permutation_map = #[[$map_proj_d0d1_0]]} : memref<?x?xf32>, vector<128xf32>
affine.for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %{{.*}}
affine.for %i18 = 0 to %M { // vectorized due to scalar -> vector
Expand Down
Loading