diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index fbbf817ecff98..49ccc0f41a1a5 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -718,6 +718,8 @@ def Vector_ExtractOp : let results = (outs AnyType:$result); let builders = [ + // Builder to extract a scalar from a rank-0 vector. + OpBuilder<(ins "Value":$source)>, OpBuilder<(ins "Value":$source, "int64_t":$position)>, OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>, OpBuilder<(ins "Value":$source, "ArrayRef":$position)>, @@ -913,6 +915,8 @@ def Vector_InsertOp : let results = (outs AnyVectorOfAnyRank:$result); let builders = [ + // Builder to insert a scalar/rank-0 vector into a rank-0 vector. + OpBuilder<(ins "Value":$source, "Value":$dest)>, OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>, OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>, OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef":$position)>, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index d4c1da30d498d..51b9219944288 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -560,11 +560,9 @@ struct ElideUnitDimsInMultiDimReduction } else { // This means we are reducing all the dimensions, and all reduction // dimensions are of size 1. So a simple extraction would do. - SmallVector zeroIdx(shape.size(), 0); if (mask) - mask = rewriter.create(loc, mask, zeroIdx); - cast = rewriter.create(loc, reductionOp.getSource(), - zeroIdx); + mask = rewriter.create(loc, mask); + cast = rewriter.create(loc, reductionOp.getSource()); } Value result = @@ -698,16 +696,9 @@ struct ElideSingleElementReduction : public OpRewritePattern { return failure(); Location loc = reductionOp.getLoc(); - Value result; - if (vectorType.getRank() == 0) { - if (mask) - mask = rewriter.create(loc, mask); - result = rewriter.create(loc, reductionOp.getVector()); - } else { - if (mask) - mask = rewriter.create(loc, mask, 0); - result = rewriter.create(loc, reductionOp.getVector(), 0); - } + if (mask) + mask = rewriter.create(loc, mask); + Value result = rewriter.create(loc, reductionOp.getVector()); if (Value acc = reductionOp.getAcc()) result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), @@ -1294,6 +1285,12 @@ void ExtractOp::inferResultRanges(ArrayRef argRanges, setResultRanges(getResult(), argRanges.front()); } +void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, + Value source) { + auto vectorTy = cast(source.getType()); + build(builder, result, source, SmallVector(vectorTy.getRank(), 0)); +} + void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, Value source, int64_t position) { build(builder, result, source, ArrayRef{position}); @@ -2916,6 +2913,13 @@ void vector::InsertOp::inferResultRanges(ArrayRef argRanges, setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); } +void vector::InsertOp::build(OpBuilder &builder, OperationState &result, + Value source, Value dest) { + auto vectorTy = cast(dest.getType()); + build(builder, result, source, dest, + SmallVector(vectorTy.getRank(), 0)); +} + void vector::InsertOp::build(OpBuilder &builder, OperationState &result, Value source, Value dest, int64_t position) { build(builder, result, source, dest, ArrayRef{position}); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index fec3c6c52e5e4..11dcfe421e0c4 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -52,11 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern { // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. if (srcRank <= 1 && dstRank == 1) { - Value ext; - if (srcRank == 0) - ext = rewriter.create(loc, op.getSource()); - else - ext = rewriter.create(loc, op.getSource(), 0); + Value ext = rewriter.create(loc, op.getSource()); rewriter.replaceOpWithNewOp(op, dstType, ext); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index 9c1e5fcee91de..23324a007377e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -189,25 +189,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern { incIdx(resIdx, resultVectorType); } - Value extract; - if (srcRank == 0) { - // 0-D vector special case - assert(srcIdx.empty() && "Unexpected indices for 0-D vector"); - extract = rewriter.create( - loc, op.getSourceVectorType().getElementType(), op.getSource()); - } else { - extract = - rewriter.create(loc, op.getSource(), srcIdx); - } - - if (resRank == 0) { - // 0-D vector special case - assert(resIdx.empty() && "Unexpected indices for 0-D vector"); - result = rewriter.create(loc, extract, result); - } else { - result = - rewriter.create(loc, extract, result, resIdx); - } + Value extract = + rewriter.create(loc, op.getSource(), srcIdx); + result = rewriter.create(loc, extract, result, resIdx); } rewriter.replaceOp(op, result); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 8c9e2d889808a..62dfd439b0ad1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -929,17 +929,8 @@ class RewriteScalarWrite : public OpRewritePattern { if (!xferOp.getPermutationMap().isMinorIdentity()) return failure(); // Only float and integer element types are supported. - Value scalar; - if (vecType.getRank() == 0) { - // vector.extract does not support vector etc., so use - // vector.extractelement instead. - scalar = rewriter.create(xferOp.getLoc(), - xferOp.getVector()); - } else { - SmallVector pos(vecType.getRank(), 0); - scalar = rewriter.create(xferOp.getLoc(), - xferOp.getVector(), pos); - } + Value scalar = + rewriter.create(xferOp.getLoc(), xferOp.getVector()); // Construct a scalar store. if (isa(xferOp.getSource().getType())) { rewriter.replaceOpWithNewOp( diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 5404fdda033ee..992bc93aea959 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -187,7 +187,7 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector) -> vector<3x2xf32> { // CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector to vector<1xf32> // CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32> // CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> -// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32> // CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]] // CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]] diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index bf755b466c7eb..8bb6593d99058 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2658,7 +2658,7 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 { // CHECK-LABEL: func.func @fold_0d_vector_reduction func.func @fold_0d_vector_reduction(%arg0: vector) -> f32 { - // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector + // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector // CHECK-NEXT: return %[[RES]] : f32 %0 = vector.reduction , %arg0 : vector into f32 return %0 : f32 diff --git a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir index b4ebb14b8829e..52b0fdee184f6 100644 --- a/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir +++ b/mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir @@ -45,9 +45,7 @@ func.func @tensor_transfer_read_0d(%t: tensor, %idx: index) -> f32 { // CHECK-LABEL: func @transfer_write_0d( // CHECK-SAME: %[[m:.*]]: memref, %[[idx:.*]]: index, %[[f:.*]]: f32 -// CHECK: %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector -// CHECK: %[[extract:.*]] = vector.extractelement %[[bc]][] : vector -// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]] +// CHECK: memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]] func.func @transfer_write_0d(%m: memref, %idx: index, %f: f32) { %0 = vector.broadcast %f : f32 to vector vector.transfer_write %0, %m[%idx, %idx, %idx] : vector, memref @@ -69,9 +67,7 @@ func.func @transfer_write_1d(%m: memref, %idx: index, %f: f32) { // CHECK-LABEL: func @tensor_transfer_write_0d( // CHECK-SAME: %[[t:.*]]: tensor, %[[idx:.*]]: index, %[[f:.*]]: f32 -// CHECK: %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector -// CHECK: %[[extract:.*]] = vector.extractelement %[[bc]][] : vector -// CHECK: %[[r:.*]] = tensor.insert %[[extract]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]] +// CHECK: %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]] // CHECK: return %[[r]] func.func @tensor_transfer_write_0d(%t: tensor, %idx: index, %f: f32) -> tensor { %0 = vector.broadcast %f : f32 to vector diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir index ab30acf68b30b..ef32f8c6a1cdb 100644 --- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir @@ -117,7 +117,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> { // CHECK-LABEL: func.func @shape_cast_0d1d( // CHECK-SAME: %[[ARG0:.*]]: vector) -> vector<1xf32> { // CHECK: %[[UB:.*]] = ub.poison : vector<1xf32> -// CHECK: %[[EXTRACT0:.*]] = vector.extractelement %[[ARG0]][] : vector +// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][] : f32 from vector // CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : f32 into vector<1xf32> // CHECK: return %[[RES]] : vector<1xf32> // CHECK: } @@ -131,7 +131,7 @@ func.func @shape_cast_0d1d(%arg0 : vector) -> vector<1xf32> { // CHECK-SAME: %[[ARG0:.*]]: vector<1xf32>) -> vector { // CHECK: %[[UB:.*]] = ub.poison : vector // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32> -// CHECK: %[[RES:.*]] = vector.insertelement %[[EXTRACT0]], %[[UB]][] : vector +// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] : f32 into vector // CHECK: return %[[RES]] : vector // CHECK: }