diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9..997d0ccb28d76 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -332,12 +332,37 @@ def Tensor_ExtractOp : Tensor_Op<"extract", [
     ```mlir
     %4 = tensor.extract %t[%1, %2] : tensor<4x4xi32>
     %5 = tensor.extract %rt[%1, %2] : tensor<?x?xi32>
+    %6 = tensor.extract %rt[3, 4] : tensor<?x?xi32>
+    %7 = tensor.extract %rt[%1, 4] : tensor<?x?xi32>
     ```
   }];
 
-  let arguments = (ins AnyRankedTensor:$tensor, Variadic<Index>:$indices);
+  let arguments = (ins
+    AnyRankedTensor:$tensor,
+    Variadic<Index>:$indices,
+    DenseI64ArrayAttr:$static_indices
+  );
   let results = (outs AnyType:$result);
-  let assemblyFormat = "$tensor `[` $indices `]` attr-dict `:` type($tensor)";
+  let assemblyFormat = [{
+    $tensor ``
+    custom<DynamicIndexList>($indices, $static_indices)
+    attr-dict `:` type($tensor)
+  }];
+
+  let builders = [
+    // Build an ExtractOp with mixed static and dynamic indices and an inferred result type.
+    OpBuilder<(ins "Value":$tensor, "ArrayRef<OpFoldResult>":$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractOp with mixed static and dynamic indices and an explicit result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$tensor, "ArrayRef<OpFoldResult>":$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractOp with dynamic indices and an inferred result type.
+    OpBuilder<(ins "Value":$source, CArg<"ValueRange", "{}">:$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an ExtractOp with dynamic indices and an explicit result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$source, CArg<"ValueRange", "{}">:$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+  ];
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
@@ -808,16 +833,35 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
 
   let arguments = (ins AnyType:$scalar,
                        AnyRankedTensor:$dest,
-                       Variadic<Index>:$indices);
+                       Variadic<Index>:$indices,
+                       DenseI64ArrayAttr:$static_indices
+  );
   let results = (outs AnyRankedTensor:$result);
   let assemblyFormat = [{
-    $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
+    $scalar `into`
+    $dest `` custom<DynamicIndexList>($indices, $static_indices)
+    attr-dict `:` type($dest)
   }];
 
   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
   }];
 
+  let builders = [
+    // Build an InsertOp with mixed static and dynamic indices and an inferred result type.
+    OpBuilder<(ins "Value":$scalar, "Value":$dest, "ArrayRef<OpFoldResult>":$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertOp with mixed static and dynamic indices and an explicit result type.
+    OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest, "ArrayRef<OpFoldResult>":$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertOp with dynamic indices and an inferred result type.
+    OpBuilder<(ins "Value":$scalar, "Value":$dest, CArg<"ValueRange", "{}">:$indexes,
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
+    // Build an InsertOp with dynamic indices and an explicit result type.
+ OpBuilder<(ins "Type":$resultType, "Value":$scalar, "Value":$dest, CArg<"ValueRange", "{}">:$indexes, + CArg<"ArrayRef", "{}">:$attrs)>, + ]; + let hasFolder = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 8eb8e579954fa..89184f2162c2c 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1736,6 +1736,32 @@ struct ShapeOfFromReshape : public OpRewritePattern { } }; +struct ExtractFromShapeOfExtentTensor + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp op, + PatternRewriter &rewriter) const override { + auto tensorShapeOfOp = op.getTensor().getDefiningOp(); + if (!tensorShapeOfOp) + return rewriter.notifyMatchFailure(op, "producer is not shape.shape_of"); + + int64_t staticIndice = op.getStaticIndices()[0]; + Type indexType = rewriter.getIndexType(); + Value indice = + staticIndice != ShapedType::kDynamic + ? tensorShapeOfOp->getDialect() + ->materializeConstant( + rewriter, IntegerAttr::get(indexType, staticIndice), + indexType, op.getLoc()) + ->getResult(0) + : op.getIndices()[0]; + rewriter.replaceOpWithNewOp(op, tensorShapeOfOp.getArg(), + indice); + return success(); + } +}; + // Canonicalize // ``` // %0 = shape.shape_of %arg : tensor -> tensor<3xindex> diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td index cb294ae2978fc..e135105d6980b 100644 --- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td +++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td @@ -44,9 +44,3 @@ def SizeToIndexToSizeCanonicalization : Pat< def TensorCastConstShape : Pat < (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg), [(HasStaticShape $res)]>; - -// tensor.extract from shape_of -> tensor.dim. We can take the first index -// because shape_of always returns a 1D tensor. -def ExtractFromShapeOfExtentTensor : Pat< - (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices), - (Tensor_DimOp $arg, (TakeFront $indices))>; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index e11c6aaccf74d..73dc98ee93ed4 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -39,6 +39,19 @@ using llvm::divideCeilSigned; using llvm::divideFloorSigned; using llvm::mod; +static LogicalResult +checkTensorRankMatchIndices(Value tensor, ValueRange dynamicIndices, + ArrayRef staticIndices) { + auto tensorType = llvm::cast(tensor.getType()); + int64_t dynamicDimCount = llvm::count_if(staticIndices, [](int64_t element) { + return element == ShapedType::kDynamic; + }); + if (tensorType.getRank() != staticIndices.size() || + dynamicDimCount != static_cast(dynamicIndices.size())) + return LogicalResult::failure(); + return LogicalResult::success(); +} + /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *TensorDialect::materializeConstant(OpBuilder &builder, @@ -1120,10 +1133,49 @@ void ExtractOp::getAsmResultNames( setNameFn(getResult(), "extracted"); } +// Build an ExtractOp with mixed static and dynamic indexes. 
+void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
+                      ArrayRef<OpFoldResult> indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  Type resultType = llvm::cast<TensorType>(tensor.getType()).getElementType();
+  build(b, result, resultType, tensor, indices, attrs);
+}
+
+// Build an ExtractOp with mixed static and dynamic indices and an explicit
+// result type.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                      Value tensor, ArrayRef<OpFoldResult> indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticIndices;
+  SmallVector<Value> dynamicIndices;
+  dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
+  result.addAttributes(attrs);
+  build(b, result, resultType, tensor, dynamicIndices,
+        b.getDenseI64ArrayAttr(staticIndices));
+}
+
+// Build an ExtractOp with dynamic indices and an explicit result type.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                      Value tensor, ValueRange indices,
+                      ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, resultType, tensor, indicesValues, attrs);
+}
+
+// Build an ExtractOp with dynamic indices and an inferred result type.
+void ExtractOp::build(OpBuilder &b, OperationState &result, Value tensor,
+                      ValueRange indices, ArrayRef<NamedAttribute> attrs) {
+  Type resultType = llvm::cast<TensorType>(tensor.getType()).getElementType();
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, resultType, tensor, indicesValues, attrs);
+}
+
 LogicalResult ExtractOp::verify() {
   // Verify the # indices match if we have a ranked type.
-  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
-  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
+  if (failed(checkTensorRankMatchIndices(getTensor(), getIndices(),
+                                         getStaticIndices())))
     return emitOpError("incorrect number of indices for extract_element");
   return success();
 }
@@ -1137,12 +1189,18 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
 
   // Collect the constant indices into the tensor.
   SmallVector<uint64_t, 8> indices;
-  for (Attribute indice : adaptor.getIndices()) {
-    if (!indice || !llvm::isa<IntegerAttr>(indice))
-      return {};
-    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
+  auto dynamicIndicesIt = adaptor.getIndices().begin();
+  for (int64_t i : getStaticIndices()) {
+    if (i != ShapedType::kDynamic) {
+      indices.push_back(i);
+    } else {
+      Attribute index = *dynamicIndicesIt;
+      if (!index || !llvm::isa<IntegerAttr>(index))
+        return {};
+      indices.push_back(llvm::cast<IntegerAttr>(index).getInt());
+      dynamicIndicesIt++;
+    }
   }
-
   // Fold extract(from_elements(...)).
   if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
     auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
@@ -1354,10 +1412,48 @@ void InsertOp::getAsmResultNames(
   setNameFn(getResult(), "inserted");
 }
 
+// Build an InsertOp with mixed static and dynamic indices and an inferred
+// result type.
+void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
+                     Value dest, ArrayRef<OpFoldResult> indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  build(b, result, dest.getType(), scalar, dest, indices, attrs);
+}
+
+// Build an InsertOp with mixed static and dynamic indices and an explicit
+// result type.
+void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                     Value scalar, Value dest, ArrayRef<OpFoldResult> indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  SmallVector<int64_t> staticIndices;
+  SmallVector<Value> dynamicIndices;
+  dispatchIndexOpFoldResults(indices, dynamicIndices, staticIndices);
+  result.addAttributes(attrs);
+  build(b, result, resultType, scalar, dest, dynamicIndices,
+        b.getDenseI64ArrayAttr(staticIndices));
+}
+
+// Build an InsertOp with dynamic indices and an explicit result type.
+void InsertOp::build(OpBuilder &b, OperationState &result, Type resultType,
+                     Value scalar, Value dest, ValueRange indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, resultType, scalar, dest, indicesValues, attrs);
+}
+
+// Build an InsertOp with dynamic indices and an inferred result type.
+void InsertOp::build(OpBuilder &b, OperationState &result, Value scalar,
+                     Value dest, ValueRange indices,
+                     ArrayRef<NamedAttribute> attrs) {
+  SmallVector<OpFoldResult> indicesValues = llvm::to_vector<4>(
+      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; }));
+  build(b, result, dest.getType(), scalar, dest, indicesValues, attrs);
+}
+
 LogicalResult InsertOp::verify() {
   // Verify the # indices match if we have a ranked type.
-  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
-  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
+  if (failed(checkTensorRankMatchIndices(getDest(), getIndices(),
+                                         getStaticIndices())))
     return emitOpError("incorrect number of indices");
   return success();
 }
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 5b98a7790debf..8c04e574dbc51 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -1519,6 +1519,19 @@ func.func @extract_shapeof(%arg0 : tensor<?x?xf64>) -> index {
  return %result : index
 }
 
+// -----
+
+// CHECK-LABEL: func @extract_shapeof_static_index
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<?x?xf64>
+func.func @extract_shapeof_static_index(%arg0 : tensor<?x?xf64>) -> index {
+// CHECK:         %[[C1:.*]] = arith.constant 1
+  %shape = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
+// CHECK:         %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+  %result = tensor.extract %shape[1] : tensor<2xindex>
+// CHECK:         return %[[DIM]]
+  return %result : index
+}
+
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 4b8efde78cc23..8f7c7478669b4 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -137,11 +137,12 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
 
 // -----
 
 // CHECK-LABEL: func @fold_extract
-func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
+func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
   %const_1 = arith.constant 1 : index
   %const_3 = arith.constant 3 : index
   // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
+  // CHECK-DAG: [[CNEG1:%.+]] = arith.constant -1 : i32
   // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
   // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16
@@ -162,13 +163,16 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
   %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
 
+  // Fold an extract into a dense tensor with mixed dynamic and static indices.
+  %ext_5 = tensor.extract %3[%const_1, 0, 2] : tensor<2x1x4xi32>
+
   // Fold an extract into a complex constant.
   // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
   %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
-  %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
+  %ext_6 = tensor.extract %4[] : tensor<complex<f32>>
 
-  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
-  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
+  // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[CNEG1]], [[C5]]
+  return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, i32, complex<f32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 41b6529f64afa..8c594ddacb8d3 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -64,7 +64,7 @@ func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
 
 // -----
 
-func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
+func.func @extract_too_few_indices(%arg0: tensor<?xf32>) {
   // expected-error@+1 {{incorrect number of indices for extract_element}}
   %0 = tensor.extract %arg0[] : tensor<?xf32>
   return
@@ -72,7 +72,24 @@ func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
 
 // -----
 
-func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+func.func @extract_too_many_static_indices(%arg0: tensor<?xf32>) {
+  // expected-error@+1 {{incorrect number of indices for extract_element}}
+  %0 = tensor.extract %arg0[2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @extract_too_many_mixed_indices(%arg0: tensor<?xf32>) {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{incorrect number of indices for extract_element}}
+  %0 = tensor.extract %arg0[%c1, 2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @insert_too_few_indices(%arg0: f32, %arg1: tensor<?xf32>) {
   // expected-error@+1 {{incorrect number of indices}}
   %0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
   return
 }
@@ -80,6 +97,23 @@ func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
 
 // -----
 
+func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  // expected-error@+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @insert_too_many_mixed_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  %c1 = arith.constant 1 : index
+  // expected-error@+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[%c1, 2, 3] : tensor<?xf32>
+  return
+}
+
+// -----
+
 func.func @tensor.from_elements_wrong_result_type() {
   // expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}}
   %c0 = arith.constant 0 : i32
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 378137a14b59f..0a4cd08239c5b 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -58,6 +58,9 @@ func.func @empty_with_encoding(%sz: index) -> tensor<5x?x6xf32, "foo"> {
 func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
   // CHECK: tensor.extract %[[TENSOR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
   %0 = tensor.extract %arg0[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.extract %[[TENSOR]][%[[INDEX]], 2, 3] : tensor<?x?x?xf32>
+  %1 = tensor.extract %arg0[%arg1, 2, 3] : tensor<?x?x?xf32>
   return
 }
 
@@ -70,6 +73,9 @@ func.func @extract(%arg0: tensor<?x?x?xf32>, %arg1: index) {
 func.func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>) {
   // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
   %0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+
+  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], 2, 3] : tensor<?x?x?xf32>
+  %1 = tensor.insert %arg0 into %arg2[%arg1, 2, 3] : tensor<?x?x?xf32>
   return
 }
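For reviewers, a minimal C++ sketch of how the new `OpFoldResult`-based builders added by this patch could be used from a pass; the helper name `extractThenInsert` and the rank-2 index list are illustrative assumptions, not part of the change.

```cpp
// Sketch only (assumes this patch): mix a dynamic SSA index with a static
// index when building tensor.extract / tensor.insert.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Builds:  %e = tensor.extract %src[%i, 2]
//          %r = tensor.insert %e into %dst[%i, 2]
static Value extractThenInsert(OpBuilder &b, Location loc, Value src,
                               Value dst, Value i) {
  // One dynamic index (a Value) and one static index (an IntegerAttr); the
  // builder dispatches them into $indices / $static_indices internally.
  SmallVector<OpFoldResult> indices = {i, b.getIndexAttr(2)};
  Value elem = b.create<tensor::ExtractOp>(loc, src, indices);
  return b.create<tensor::InsertOp>(loc, elem, dst, indices);
}
```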