diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index c5b08d6aa022b..dad08305b2a64 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -695,14 +695,14 @@ def Vector_ExtractOp : %1 = vector.extract %0[3]: vector<8x16xf32> from vector<4x8x16xf32> %2 = vector.extract %0[2, 1, 3]: f32 from vector<4x8x16xf32> %3 = vector.extract %1[]: vector from vector - %4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32> - %5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32> + %4 = vector.extract %0[%a, %b, %c : index] : f32 from vector<4x8x16xf32> + %5 = vector.extract %0[2, %b : index] : vector<16xf32> from vector<4x8x16xf32> ``` }]; let arguments = (ins AnyVectorOfAnyRank:$vector, - Variadic:$dynamic_position, + Variadic:$dynamic_position, DenseI64ArrayAttr:$static_position ); let results = (outs AnyType:$result); @@ -737,7 +737,8 @@ def Vector_ExtractOp : let assemblyFormat = [{ $vector `` - custom($dynamic_position, $static_position) + custom($dynamic_position, $static_position, + type($dynamic_position)) attr-dict `:` type($result) `from` type($vector) }]; @@ -883,15 +884,15 @@ def Vector_InsertOp : %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32> %5 = vector.insert %3, %4[2, 1, 3] : f32 into vector<4x8x16xf32> %8 = vector.insert %6, %7[] : f32 into vector - %11 = vector.insert %9, %10[%a, %b, %c] : vector into vector<4x8x16xf32> - %12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32> + %11 = vector.insert %9, %10[%a, %b, %c : index] : vector into vector<4x8x16xf32> + %12 = vector.insert %4, %10[2, %b : index] : vector<16xf32> into vector<4x8x16xf32> ``` }]; let arguments = (ins AnyType:$source, AnyVectorOfAnyRank:$dest, - Variadic:$dynamic_position, + Variadic:$dynamic_position, DenseI64ArrayAttr:$static_position ); let results = (outs AnyVectorOfAnyRank:$result); @@ -926,7 +927,9 @@ def Vector_InsertOp : }]; let assemblyFormat = [{ - $source `,` $dest custom($dynamic_position, $static_position) + $source `,` $dest + custom($dynamic_position, $static_position, + type($dynamic_position)) attr-dict `:` type($source) `into` type($dest) }]; @@ -1344,7 +1347,7 @@ def Vector_TransferReadOp : %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref // Update the temporary gathered slice with the individual element %slice = memref.load %tmp : memref> -> vector<3x4x5xf32> - %updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32> + %updated = vector.insert %a, %slice[%i, %j, %k : index] : f32 into vector<3x4x5xf32> memref.store %updated, %tmp : memref> }}} // At this point we gathered the elements from the original @@ -1367,7 +1370,7 @@ def Vector_TransferReadOp : %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref %slice = memref.load %tmp : memref> -> vector<3x4x5xf32> // Here we only store to the first element in dimension one - %updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32> + %updated = vector.insert %a, %slice[%i, 0, %k : index] : f32 into vector<3x4x5xf32> memref.store %updated, %tmp : memref> }} // At this point we gathered the elements from the original diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index a7222794f320b..699dd1da863b6 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -794,16 +794,26 @@ class AsmParser { }; /// Parse a list of comma-separated items with an optional delimiter. If a - /// delimiter is provided, then an empty list is allowed. If not, then at + /// delimiter is provided, then an empty list is allowed. If not, then at /// least one element will be parsed. /// + /// `parseSuffixFn` is an optional function to parse any suffix that can be + /// appended to the comma separated list within the delimiter. + /// /// contextMessage is an optional message appended to "expected '('" sorts of /// diagnostics when parsing the delimeters. - virtual ParseResult + virtual ParseResult parseCommaSeparatedList( + Delimiter delimiter, function_ref parseElementFn, + std::optional> parseSuffixFn = std::nullopt, + StringRef contextMessage = StringRef()) = 0; + ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref parseElementFn, - StringRef contextMessage = StringRef()) = 0; - + StringRef contextMessage) { + return parseCommaSeparatedList(delimiter, parseElementFn, + /*parseSuffixFn=*/std::nullopt, + contextMessage); + } /// Parse a comma separated list of elements that must have at least one entry /// in it. ParseResult @@ -1319,6 +1329,9 @@ class AsmParser { virtual ParseResult parseOptionalColonTypeList(SmallVectorImpl &result) = 0; + /// Parse an optional colon followed by a type. + virtual ParseResult parseOptionalColonType(Type &result) = 0; + /// Parse a keyword followed by a type. ParseResult parseKeywordType(const char *keyword, Type &result) { return failure(parseKeyword(keyword) || parseType(result)); diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h index 3dcbd2f1af193..1971c25a8f20b 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -96,8 +96,10 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final /// in `integers` is `kDynamic` or (2) the next value otherwise. If `valueTypes` /// is non-empty, it is expected to contain as many elements as `values` /// indicating their types. This allows idiomatic printing of mixed value and -/// integer attributes in a list. E.g. -/// `[%arg0 : index, 7, 42, %arg42 : i32]`. +/// integer attributes in a list. E.g., `[%arg0 : index, 7, 42, %arg42 : i32]`. +/// If `hasSameTypeDynamicValues` is `true`, `valueTypes` are expected to be the +/// same and only one type is printed at the end of the list. E.g., +/// `[0, %arg2, 3, %arg42, 2 : i8]`. /// /// Indices can be scalable. For example, "4" in "[2, [4], 8]" is scalable. /// This notation is similar to how scalable dims are marked when defining @@ -108,7 +110,8 @@ void printDynamicIndexList( OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, ArrayRef scalables, TypeRange valueTypes = TypeRange(), - AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square, + bool hasSameTypeDynamicValues = false); inline void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, @@ -123,6 +126,13 @@ inline void printDynamicIndexList( return printDynamicIndexList(printer, op, values, integers, {}, valueTypes, delimiter); } +inline void printSameTypeDynamicIndexList( + OpAsmPrinter &printer, Operation *op, OperandRange values, + ArrayRef integers, TypeRange valueTypes = TypeRange(), + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { + return printDynamicIndexList(printer, op, values, integers, {}, valueTypes, + delimiter, /*hasSameTypeDynamicValues=*/true); +} /// Parser hook for custom directive in assemblyFormat. /// @@ -150,7 +160,8 @@ ParseResult parseDynamicIndexList( SmallVectorImpl &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableVals, SmallVectorImpl *valueTypes = nullptr, - AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square); + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square, + bool hasSameTypeDynamicValues = false); inline ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl &values, @@ -188,6 +199,16 @@ inline ParseResult parseDynamicIndexList( return parseDynamicIndexList(parser, values, integers, scalableVals, &valueTypes, delimiter); } +inline ParseResult parseSameTypeDynamicIndexList( + OpAsmParser &parser, + SmallVectorImpl &values, + DenseI64ArrayAttr &integers, SmallVectorImpl &valueTypes, + AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) { + DenseBoolArrayAttr scalableVals = {}; + return parseDynamicIndexList(parser, values, integers, scalableVals, + &valueTypes, delimiter, + /*hasSameTypeDynamicValues=*/true); +} /// Verify that a the `values` has as many elements as the number of entries in /// `attr` for which `isDynamic` evaluates to true. diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index 04250f63dcd25..4d5b93ec09d17 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -340,12 +340,16 @@ class AsmParserImpl : public BaseT { /// Parse a list of comma-separated items with an optional delimiter. If a /// delimiter is provided, then an empty list is allowed. If not, then at /// least one element will be parsed. - ParseResult parseCommaSeparatedList(Delimiter delimiter, - function_ref parseElt, - StringRef contextMessage) override { - return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage); + ParseResult parseCommaSeparatedList( + Delimiter delimiter, function_ref parseElt, + std::optional> parseSuffix, + StringRef contextMessage) override { + return parser.parseCommaSeparatedList(delimiter, parseElt, parseSuffix, + contextMessage); } + using BaseT::parseCommaSeparatedList; + //===--------------------------------------------------------------------===// // Keyword Parsing //===--------------------------------------------------------------------===// @@ -590,6 +594,17 @@ class AsmParserImpl : public BaseT { return parser.parseTypeListNoParens(result); } + /// Parse an optional colon followed by a type. + ParseResult parseOptionalColonType(Type &result) override { + SmallVector types; + ParseResult parseResult = parseOptionalColonTypeList(types); + if (llvm::succeeded(parseResult) && types.size() > 1) + return emitError(getCurrentLocation(), "expected single type"); + if (!types.empty()) + result = types[0]; + return parseResult; + } + ParseResult parseDimensionList(SmallVectorImpl &dimensions, bool allowDynamic, bool withTrailingX) override { diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index 8f19487d80fa3..6476910f71eb7 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -80,10 +80,10 @@ AsmParserCodeCompleteContext::~AsmParserCodeCompleteContext() = default; /// Parse a list of comma-separated items with an optional delimiter. If a /// delimiter is provided, then an empty list is allowed. If not, then at /// least one element will be parsed. -ParseResult -Parser::parseCommaSeparatedList(Delimiter delimiter, - function_ref parseElementFn, - StringRef contextMessage) { +ParseResult Parser::parseCommaSeparatedList( + Delimiter delimiter, function_ref parseElementFn, + std::optional> parseSuffixFn, + StringRef contextMessage) { switch (delimiter) { case Delimiter::None: break; @@ -144,6 +144,9 @@ Parser::parseCommaSeparatedList(Delimiter delimiter, return failure(); } + if (parseSuffixFn && (*parseSuffixFn)()) + return failure(); + switch (delimiter) { case Delimiter::None: return success(); diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index bf91831798056..1ebca05bbcb2e 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -46,10 +46,17 @@ class Parser { /// Parse a list of comma-separated items with an optional delimiter. If a /// delimiter is provided, then an empty list is allowed. If not, then at /// least one element will be parsed. + ParseResult parseCommaSeparatedList( + Delimiter delimiter, function_ref parseElementFn, + std::optional> parseSuffixFn = std::nullopt, + StringRef contextMessage = StringRef()); ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref parseElementFn, - StringRef contextMessage = StringRef()); + StringRef contextMessage) { + return parseCommaSeparatedList(delimiter, parseElementFn, std::nullopt, + contextMessage); + } /// Parse a comma separated list of elements that must have at least one entry /// in it. diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index 55965d9c2a531..c5c3353bf0477 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -501,13 +501,14 @@ struct VectorOuterProductToArmSMELowering /// /// Example: /// ``` -/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32> +/// %el = vector.extract %tile[%row, %col : index] : i32 from +/// vector<[4]x[4]xi32> /// ``` /// Becomes: /// ``` /// %slice = arm_sme.extract_tile_slice %tile[%row] /// : vector<[4]xi32> from vector<[4]x[4]xi32> -/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32> +/// %el = vector.extract %slice[%col : index] : i32 from vector<[4]xi32> /// ``` struct VectorExtractToArmSMELowering : public OpRewritePattern { @@ -561,8 +562,9 @@ struct VectorExtractToArmSMELowering /// ``` /// %slice = arm_sme.extract_tile_slice %tile[%row] /// : vector<[4]xi32> from vector<[4]x[4]xi32> -/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32> -/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row] +/// %new_slice = vector.insert %el, %slice[%col : index] : i32 into +/// vector<[4]xi32> %new_tile = arm_sme.insert_tile_slice %new_slice, +/// %tile[%row] /// : vector<[4]xi32> into vector<[4]x[4]xi32> /// ``` struct VectorInsertToArmSMELowering diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 3a4dc806efe97..b623a86c53ee7 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1050,10 +1050,10 @@ getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) { /// %vscale = vector.vscale /// %c4_vscale = arith.muli %vscale, %c4 : index /// scf.for %idx = %c0 to %c4_vscale step %c1 { -/// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32> -/// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32> -/// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32> -/// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32> +/// %4 = vector.extract %0[%idx : index] : f32 from vector<[4]xf32> +/// %5 = vector.extract %1[%idx : index] : f32 from vector<[4]xf32> +/// %6 = vector.extract %2[%idx : index] : f32 from vector<[4]xf32> +/// %7 = vector.extract %3[%idx : index] : f32 from vector<[4]xf32> /// %slice_i = affine.apply #map(%idx)[%i] /// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32> /// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]} diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp index ca33636336bf0..8e44ff60eec87 100644 --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -114,7 +114,8 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef integers, ArrayRef scalables, TypeRange valueTypes, - AsmParser::Delimiter delimiter) { + AsmParser::Delimiter delimiter, + bool hasSameTypeDynamicValues) { char leftDelimiter = getLeftDelimiter(delimiter); char rightDelimiter = getRightDelimiter(delimiter); printer << leftDelimiter; @@ -130,7 +131,7 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, printer << "["; if (ShapedType::isDynamic(integer)) { printer << values[dynamicValIdx]; - if (!valueTypes.empty()) + if (!hasSameTypeDynamicValues && !valueTypes.empty()) printer << " : " << valueTypes[dynamicValIdx]; ++dynamicValIdx; } else { @@ -142,6 +143,13 @@ void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, scalableIndexIdx++; }); + if (hasSameTypeDynamicValues && !valueTypes.empty()) { + assert(std::all_of(valueTypes.begin(), valueTypes.end(), + [&](Type type) { return type == valueTypes[0]; }) && + "Expected the same value types"); + printer << " : " << valueTypes[0]; + } + printer << rightDelimiter; } @@ -149,7 +157,8 @@ ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalables, - SmallVectorImpl *valueTypes, AsmParser::Delimiter delimiter) { + SmallVectorImpl *valueTypes, AsmParser::Delimiter delimiter, + bool hasSameTypeDynamicValues) { SmallVector integerVals; SmallVector scalableVals; @@ -163,7 +172,8 @@ ParseResult mlir::parseDynamicIndexList( if (res.has_value() && succeeded(res.value())) { values.push_back(operand); integerVals.push_back(ShapedType::kDynamic); - if (valueTypes && parser.parseColonType(valueTypes->emplace_back())) + if (!hasSameTypeDynamicValues && valueTypes && + parser.parseColonType(valueTypes->emplace_back())) return failure(); } else { int64_t integer; @@ -178,10 +188,34 @@ ParseResult mlir::parseDynamicIndexList( return failure(); return success(); }; + auto parseColonType = [&]() -> ParseResult { + if (hasSameTypeDynamicValues) { + assert(valueTypes && "Expected non-null value types"); + assert(valueTypes->empty() && "Expected no parsed value types"); + + Type dynValType; + if (parser.parseOptionalColonType(dynValType)) + return failure(); + + if (!dynValType && !values.empty()) + return parser.emitError(parser.getNameLoc()) + << "expected a type for dynamic indices"; + if (dynValType) { + if (values.empty()) + return parser.emitError(parser.getNameLoc()) + << "expected no type for constant indices"; + + // Broadcast the single type to all the dynamic values. + valueTypes->append(values.size(), dynValType); + } + } + return success(); + }; if (parser.parseCommaSeparatedList(delimiter, parseIntegerOrValue, - " in dynamic index list")) + parseColonType, " in dynamic index list")) return parser.emitError(parser.getNameLoc()) - << "expected SSA value or integer"; + << "expected a valid list of SSA values or integers"; + integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); scalables = parser.getBuilder().getDenseBoolArrayAttr(scalableVals); return success(); diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir index ff7b4bcb5f65a..c93dbf8836f6c 100644 --- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir +++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir @@ -151,7 +151,7 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest // CHECK-NOT: arm_sme.store_tile_slice func.func @transfer_write_slice_unsupported_permutation(%vector: vector<[4]x[4]xf32>, %dest : memref, %slice_index: index) { %c0 = arith.constant 0 : index - %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32> + %slice = vector.extract %vector[%slice_index : index] : vector<[4]xf32> from vector<[4]x[4]xf32> vector.transfer_write %slice, %dest[%slice_index, %c0] { permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true] }: vector<[4]xf32>, memref return } @@ -202,7 +202,7 @@ func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index { // CHECK-NOT: arm_sve.psel %mask = vector.create_mask %a, %b : vector<[4]x[32]xi1> - %slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1> + %slice = vector.extract %mask[%index : index] : vector<[32]xi1> from vector<[4]x[32]xi1> return %slice : vector<[32]xi1> } @@ -215,7 +215,7 @@ func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index { // CHECK-NOT: arm_sve.psel %mask = vector.create_mask %a, %b : vector<4x[8]xi1> - %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1> + %slice = vector.extract %mask[%index : index] : vector<[8]xi1> from vector<4x[8]xi1> return %slice : vector<[8]xi1> } @@ -227,7 +227,7 @@ func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1> { // CHECK-NOT: arm_sve.psel - %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> + %slice = vector.extract %mask[%index : index] : vector<[8]xi1> from vector<[4]x[8]xi1> return %slice : vector<[8]xi1> } @@ -240,6 +240,6 @@ func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index { // CHECK-NOT: arm_sve.psel %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> - %el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1> + %el = vector.extract %mask[2, %index : index] : i1 from vector<[4]x[8]xi1> return %el : i1 } diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir index 0f973af799634..6ca19c5746ea1 100644 --- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir +++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir @@ -345,7 +345,7 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb // CHECK: arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref, vector<[4]xi1>, vector<[4]x[4]xf32> func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref, %slice_index: index) { %c0 = arith.constant 0 : index - %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32> + %slice = vector.extract %vector[%slice_index : index] : vector<[4]xf32> from vector<[4]x[4]xf32> vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref return } @@ -361,7 +361,7 @@ func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref, vector<[4]xi1>, vector<[4]x[4]xf32> func.func @transfer_write_slice_with_mask(%vector: vector<[4]x[4]xf32>, %dest : memref, %mask: vector<[4]xi1>, %slice_index: index) { %c0 = arith.constant 0 : index - %slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32> + %slice = vector.extract %vector[%slice_index : index] : vector<[4]xf32> from vector<[4]x[4]xf32> vector.transfer_write %slice, %dest[%slice_index, %c0], %mask { in_bounds = [true] }: vector<[4]xf32>, memref return } @@ -927,7 +927,7 @@ func.func @vector_insert_slice_i32(%slice: vector<[4]xi32>, %row: index) -> vect // CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> // CHECK-NEXT: arm_sme.insert_tile_slice %[[SLICE]], %[[TILE]][%[[INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32> %tile = arm_sme.get_tile : vector<[4]x[4]xi32> - %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xi32> into vector<[4]x[4]xi32> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[4]xi32> into vector<[4]x[4]xi32> return %new_tile : vector<[4]x[4]xi32> } @@ -937,7 +937,7 @@ func.func @vector_insert_slice_i32(%slice: vector<[4]xi32>, %row: index) -> vect func.func @vector_insert_slice_i8(%slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[16]xi8> into vector<[16]x[16]xi8> %tile = arm_sme.get_tile : vector<[16]x[16]xi8> - %new_tile = vector.insert %slice, %tile[%row] : vector<[16]xi8> into vector<[16]x[16]xi8> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[16]xi8> into vector<[16]x[16]xi8> return %new_tile : vector<[16]x[16]xi8> } @@ -947,7 +947,7 @@ func.func @vector_insert_slice_i8(%slice: vector<[16]xi8>, %row: index) -> vecto func.func @vector_insert_slice_i16(%slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xi16> into vector<[8]x[8]xi16> %tile = arm_sme.get_tile : vector<[8]x[8]xi16> - %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xi16> into vector<[8]x[8]xi16> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[8]xi16> into vector<[8]x[8]xi16> return %new_tile : vector<[8]x[8]xi16> } @@ -957,7 +957,7 @@ func.func @vector_insert_slice_i16(%slice: vector<[8]xi16>, %row: index) -> vect func.func @vector_insert_slice_i64(%slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[2]xi64> into vector<[2]x[2]xi64> %tile = arm_sme.get_tile : vector<[2]x[2]xi64> - %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xi64> into vector<[2]x[2]xi64> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[2]xi64> into vector<[2]x[2]xi64> return %new_tile : vector<[2]x[2]xi64> } @@ -967,7 +967,7 @@ func.func @vector_insert_slice_i64(%slice: vector<[2]xi64>, %row: index) -> vect func.func @vector_insert_slice_i128(%slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[1]xi128> into vector<[1]x[1]xi128> %tile = arm_sme.get_tile : vector<[1]x[1]xi128> - %new_tile = vector.insert %slice, %tile[%row] : vector<[1]xi128> into vector<[1]x[1]xi128> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[1]xi128> into vector<[1]x[1]xi128> return %new_tile : vector<[1]x[1]xi128> } @@ -977,7 +977,7 @@ func.func @vector_insert_slice_i128(%slice: vector<[1]xi128>, %row: index) -> ve func.func @vector_insert_slice_f16(%slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16> %tile = arm_sme.get_tile : vector<[8]x[8]xf16> - %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xf16> into vector<[8]x[8]xf16> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[8]xf16> into vector<[8]x[8]xf16> return %new_tile : vector<[8]x[8]xf16> } @@ -987,7 +987,7 @@ func.func @vector_insert_slice_f16(%slice: vector<[8]xf16>, %row: index) -> vect func.func @vector_insert_slice_bf16(%slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[8]xbf16> into vector<[8]x[8]xbf16> %tile = arm_sme.get_tile : vector<[8]x[8]xbf16> - %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xbf16> into vector<[8]x[8]xbf16> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[8]xbf16> into vector<[8]x[8]xbf16> return %new_tile : vector<[8]x[8]xbf16> } @@ -997,7 +997,7 @@ func.func @vector_insert_slice_bf16(%slice: vector<[8]xbf16>, %row: index) -> ve func.func @vector_insert_slice_f32(%slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[4]xf32> into vector<[4]x[4]xf32> %tile = arm_sme.get_tile : vector<[4]x[4]xf32> - %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xf32> into vector<[4]x[4]xf32> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[4]xf32> into vector<[4]x[4]xf32> return %new_tile : vector<[4]x[4]xf32> } @@ -1007,7 +1007,7 @@ func.func @vector_insert_slice_f32(%slice: vector<[4]xf32>, %row: index) -> vect func.func @vector_insert_slice_f64(%slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> { // CHECK: arm_sme.insert_tile_slice %{{.*}} : vector<[2]xf64> into vector<[2]x[2]xf64> %tile = arm_sme.get_tile : vector<[2]x[2]xf64> - %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xf64> into vector<[2]x[2]xf64> + %new_tile = vector.insert %slice, %tile[%row : index] : vector<[2]xf64> into vector<[2]x[2]xf64> return %new_tile : vector<[2]x[2]xf64> } @@ -1020,10 +1020,10 @@ func.func @vector_insert_slice_f64(%slice: vector<[2]xf64>, %row: index) -> vect func.func @vector_insert_element_i32(%el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> { // CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> // CHECK-NEXT: %[[SLICE:.*]] = arm_sme.extract_tile_slice %[[TILE]][%[[ROW]]] : vector<[4]xi32> from vector<[4]x[4]xi32> - // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]]] : i32 into vector<[4]xi32> + // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]] : index] : i32 into vector<[4]xi32> // CHECK-NEXT: arm_sme.insert_tile_slice %[[NEW_SLICE]], %[[TILE]][%[[ROW]]] : vector<[4]xi32> into vector<[4]x[4]xi32> %tile = arm_sme.get_tile : vector<[4]x[4]xi32> - %new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32> + %new_tile = vector.insert %el, %tile[%row, %col : index] : i32 into vector<[4]x[4]xi32> return %new_tile : vector<[4]x[4]xi32> } @@ -1035,7 +1035,7 @@ func.func @vector_insert_element_i8(%el: i8, %row: index, %col: index) -> vector // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[16]xi8> into vector<[16]x[16]xi8> %tile = arm_sme.get_tile : vector<[16]x[16]xi8> - %new_tile = vector.insert %el, %tile[%row, %col] : i8 into vector<[16]x[16]xi8> + %new_tile = vector.insert %el, %tile[%row, %col : index] : i8 into vector<[16]x[16]xi8> return %new_tile : vector<[16]x[16]xi8> } @@ -1047,7 +1047,7 @@ func.func @vector_insert_element_i16(%el: i16, %row: index, %col: index) -> vect // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[8]xi16> into vector<[8]x[8]xi16> %tile = arm_sme.get_tile : vector<[8]x[8]xi16> - %new_tile = vector.insert %el, %tile[%row, %col] : i16 into vector<[8]x[8]xi16> + %new_tile = vector.insert %el, %tile[%row, %col : index] : i16 into vector<[8]x[8]xi16> return %new_tile : vector<[8]x[8]xi16> } @@ -1059,7 +1059,7 @@ func.func @vector_insert_element_i64(%el: i64, %row: index, %col: index) -> vect // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[2]xi64> into vector<[2]x[2]xi64> %tile = arm_sme.get_tile : vector<[2]x[2]xi64> - %new_tile = vector.insert %el, %tile[%row, %col] : i64 into vector<[2]x[2]xi64> + %new_tile = vector.insert %el, %tile[%row, %col : index] : i64 into vector<[2]x[2]xi64> return %new_tile : vector<[2]x[2]xi64> } @@ -1071,7 +1071,7 @@ func.func @vector_insert_element_i128(%el: i128, %row: index, %col: index) -> ve // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[1]xi128> into vector<[1]x[1]xi128> %tile = arm_sme.get_tile : vector<[1]x[1]xi128> - %new_tile = vector.insert %el, %tile[%row, %col] : i128 into vector<[1]x[1]xi128> + %new_tile = vector.insert %el, %tile[%row, %col : index] : i128 into vector<[1]x[1]xi128> return %new_tile : vector<[1]x[1]xi128> } @@ -1083,7 +1083,7 @@ func.func @vector_insert_element_f16(%el: f16, %row: index, %col: index) -> vect // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[8]xf16> into vector<[8]x[8]xf16> %tile = arm_sme.get_tile : vector<[8]x[8]xf16> - %new_tile = vector.insert %el, %tile[%row, %col] : f16 into vector<[8]x[8]xf16> + %new_tile = vector.insert %el, %tile[%row, %col : index] : f16 into vector<[8]x[8]xf16> return %new_tile : vector<[8]x[8]xf16> } @@ -1095,7 +1095,7 @@ func.func @vector_insert_element_bf16(%el: bf16, %row: index, %col: index) -> ve // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[8]xbf16> into vector<[8]x[8]xbf16> %tile = arm_sme.get_tile : vector<[8]x[8]xbf16> - %new_tile = vector.insert %el, %tile[%row, %col] : bf16 into vector<[8]x[8]xbf16> + %new_tile = vector.insert %el, %tile[%row, %col : index] : bf16 into vector<[8]x[8]xbf16> return %new_tile : vector<[8]x[8]xbf16> } @@ -1107,7 +1107,7 @@ func.func @vector_insert_element_f32(%el: f32, %row: index, %col: index) -> vect // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[4]xf32> into vector<[4]x[4]xf32> %tile = arm_sme.get_tile : vector<[4]x[4]xf32> - %new_tile = vector.insert %el, %tile[%row, %col] : f32 into vector<[4]x[4]xf32> + %new_tile = vector.insert %el, %tile[%row, %col : index] : f32 into vector<[4]x[4]xf32> return %new_tile : vector<[4]x[4]xf32> } @@ -1119,7 +1119,7 @@ func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vect // CHECK: arm_sme.extract_tile_slice %[[TILE]]{{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64> // CHECK: arm_sme.insert_tile_slice %{{.*}}, %[[TILE]][%{{.*}}] : vector<[2]xf64> into vector<[2]x[2]xf64> %tile = arm_sme.get_tile : vector<[2]x[2]xf64> - %new_tile = vector.insert %el, %tile[%row, %col] : f64 into vector<[2]x[2]xf64> + %new_tile = vector.insert %el, %tile[%row, %col : index] : f64 into vector<[2]x[2]xf64> return %new_tile : vector<[2]x[2]xf64> } @@ -1135,7 +1135,7 @@ func.func @vector_extract_slice_i32(%row: index) -> vector<[4]xi32> { // CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> // CHECK: arm_sme.extract_tile_slice %[[TILE]][%[[INDEX]]] : vector<[4]xi32> from vector<[4]x[4]xi32> %tile = arm_sme.get_tile : vector<[4]x[4]xi32> - %slice = vector.extract %tile[%row] : vector<[4]xi32> from vector<[4]x[4]xi32> + %slice = vector.extract %tile[%row : index] : vector<[4]xi32> from vector<[4]x[4]xi32> return %slice : vector<[4]xi32> } @@ -1145,7 +1145,7 @@ func.func @vector_extract_slice_i32(%row: index) -> vector<[4]xi32> { func.func @vector_extract_slice_i8(%row: index) -> vector<[16]xi8> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8> %tile = arm_sme.get_tile : vector<[16]x[16]xi8> - %slice = vector.extract %tile[%row] : vector<[16]xi8> from vector<[16]x[16]xi8> + %slice = vector.extract %tile[%row : index] : vector<[16]xi8> from vector<[16]x[16]xi8> return %slice : vector<[16]xi8> } @@ -1155,7 +1155,7 @@ func.func @vector_extract_slice_i8(%row: index) -> vector<[16]xi8> { func.func @vector_extract_slice_i16(%row: index) -> vector<[8]xi16> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16> %tile = arm_sme.get_tile : vector<[8]x[8]xi16> - %slice = vector.extract %tile[%row] : vector<[8]xi16> from vector<[8]x[8]xi16> + %slice = vector.extract %tile[%row : index] : vector<[8]xi16> from vector<[8]x[8]xi16> return %slice : vector<[8]xi16> } @@ -1165,7 +1165,7 @@ func.func @vector_extract_slice_i16(%row: index) -> vector<[8]xi16> { func.func @vector_extract_slice_i64(%row: index) -> vector<[2]xi64> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64> %tile = arm_sme.get_tile : vector<[2]x[2]xi64> - %slice = vector.extract %tile[%row] : vector<[2]xi64> from vector<[2]x[2]xi64> + %slice = vector.extract %tile[%row : index] : vector<[2]xi64> from vector<[2]x[2]xi64> return %slice : vector<[2]xi64> } @@ -1175,7 +1175,7 @@ func.func @vector_extract_slice_i64(%row: index) -> vector<[2]xi64> { func.func @vector_extract_slice_i128(%row: index) -> vector<[1]xi128> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128> %tile = arm_sme.get_tile : vector<[1]x[1]xi128> - %slice = vector.extract %tile[%row] : vector<[1]xi128> from vector<[1]x[1]xi128> + %slice = vector.extract %tile[%row : index] : vector<[1]xi128> from vector<[1]x[1]xi128> return %slice : vector<[1]xi128> } @@ -1185,7 +1185,7 @@ func.func @vector_extract_slice_i128(%row: index) -> vector<[1]xi128> { func.func @vector_extract_slice_f16(%row: index) -> vector<[8]xf16> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16> %tile = arm_sme.get_tile : vector<[8]x[8]xf16> - %slice = vector.extract %tile[%row] : vector<[8]xf16> from vector<[8]x[8]xf16> + %slice = vector.extract %tile[%row : index] : vector<[8]xf16> from vector<[8]x[8]xf16> return %slice : vector<[8]xf16> } @@ -1195,7 +1195,7 @@ func.func @vector_extract_slice_f16(%row: index) -> vector<[8]xf16> { func.func @vector_extract_slice_bf16(%row: index) -> vector<[8]xbf16> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16> %tile = arm_sme.get_tile : vector<[8]x[8]xbf16> - %slice = vector.extract %tile[%row] : vector<[8]xbf16> from vector<[8]x[8]xbf16> + %slice = vector.extract %tile[%row : index] : vector<[8]xbf16> from vector<[8]x[8]xbf16> return %slice : vector<[8]xbf16> } @@ -1205,7 +1205,7 @@ func.func @vector_extract_slice_bf16(%row: index) -> vector<[8]xbf16> { func.func @vector_extract_slice_f32(%row: index) -> vector<[4]xf32> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32> %tile = arm_sme.get_tile : vector<[4]x[4]xf32> - %slice = vector.extract %tile[%row] : vector<[4]xf32> from vector<[4]x[4]xf32> + %slice = vector.extract %tile[%row : index] : vector<[4]xf32> from vector<[4]x[4]xf32> return %slice : vector<[4]xf32> } @@ -1215,7 +1215,7 @@ func.func @vector_extract_slice_f32(%row: index) -> vector<[4]xf32> { func.func @vector_extract_slice_f64(%row: index) -> vector<[2]xf64> { // CHECK: arm_sme.extract_tile_slice {{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64> %tile = arm_sme.get_tile : vector<[2]x[2]xf64> - %slice = vector.extract %tile[%row] : vector<[2]xf64> from vector<[2]x[2]xf64> + %slice = vector.extract %tile[%row : index] : vector<[2]xf64> from vector<[2]x[2]xf64> return %slice : vector<[2]xf64> } @@ -1227,9 +1227,9 @@ func.func @vector_extract_slice_f64(%row: index) -> vector<[2]xf64> { func.func @vector_extract_element(%row: index, %col: index) -> i32 { // CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32> // CHECK-NEXT: %[[SLICE:.*]] = arm_sme.extract_tile_slice %[[TILE]][%[[ROW]]] : vector<[4]xi32> from vector<[4]x[4]xi32> - // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]]] : i32 from vector<[4]xi32> + // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]] : index] : i32 from vector<[4]xi32> %tile = arm_sme.get_tile : vector<[4]x[4]xi32> - %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32> + %el = vector.extract %tile[%row, %col : index] : i32 from vector<[4]x[4]xi32> return %el : i32 } @@ -1238,9 +1238,9 @@ func.func @vector_extract_element(%row: index, %col: index) -> i32 { // CHECK-LABEL: @vector_extract_element_i8 func.func @vector_extract_element_i8(%row: index, %col: index) -> i8 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[16]xi8> from vector<[16]x[16]xi8> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i8 from vector<[16]xi8> %tile = arm_sme.get_tile : vector<[16]x[16]xi8> - %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8> + %el = vector.extract %tile[%row, %col : index] : i8 from vector<[16]x[16]xi8> return %el : i8 } @@ -1249,9 +1249,9 @@ func.func @vector_extract_element_i8(%row: index, %col: index) -> i8 { // CHECK-LABEL: @vector_extract_element_i16 func.func @vector_extract_element_i16(%row: index, %col: index) -> i16 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xi16> from vector<[8]x[8]xi16> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i16 from vector<[8]xi16> %tile = arm_sme.get_tile : vector<[8]x[8]xi16> - %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16> + %el = vector.extract %tile[%row, %col : index] : i16 from vector<[8]x[8]xi16> return %el : i16 } @@ -1260,9 +1260,9 @@ func.func @vector_extract_element_i16(%row: index, %col: index) -> i16 { // CHECK-LABEL: @vector_extract_element_i64 func.func @vector_extract_element_i64(%row: index, %col: index) -> i64 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[2]xi64> from vector<[2]x[2]xi64> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i64 from vector<[2]xi64> %tile = arm_sme.get_tile : vector<[2]x[2]xi64> - %el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64> + %el = vector.extract %tile[%row, %col : index] : i64 from vector<[2]x[2]xi64> return %el : i64 } @@ -1271,9 +1271,9 @@ func.func @vector_extract_element_i64(%row: index, %col: index) -> i64 { // CHECK-LABEL: @vector_extract_element_i128 func.func @vector_extract_element_i128(%row: index, %col: index) -> i128 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[1]xi128> from vector<[1]x[1]xi128> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : i128 from vector<[1]xi128> %tile = arm_sme.get_tile : vector<[1]x[1]xi128> - %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128> + %el = vector.extract %tile[%row, %col : index] : i128 from vector<[1]x[1]xi128> return %el : i128 } @@ -1282,9 +1282,9 @@ func.func @vector_extract_element_i128(%row: index, %col: index) -> i128 { // CHECK-LABEL: @vector_extract_element_f16 func.func @vector_extract_element_f16(%row: index, %col: index) -> f16 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xf16> from vector<[8]x[8]xf16> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : f16 from vector<[8]xf16> %tile = arm_sme.get_tile : vector<[8]x[8]xf16> - %el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16> + %el = vector.extract %tile[%row, %col : index] : f16 from vector<[8]x[8]xf16> return %el : f16 } @@ -1293,9 +1293,9 @@ func.func @vector_extract_element_f16(%row: index, %col: index) -> f16 { // CHECK-LABEL: @vector_extract_element_bf16 func.func @vector_extract_element_bf16(%row: index, %col: index) -> bf16 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[8]xbf16> from vector<[8]x[8]xbf16> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : bf16 from vector<[8]xbf16> %tile = arm_sme.get_tile : vector<[8]x[8]xbf16> - %el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16> + %el = vector.extract %tile[%row, %col : index] : bf16 from vector<[8]x[8]xbf16> return %el : bf16 } @@ -1304,9 +1304,9 @@ func.func @vector_extract_element_bf16(%row: index, %col: index) -> bf16 { // CHECK-LABEL: @vector_extract_element_f32 func.func @vector_extract_element_f32(%row: index, %col: index) -> f32 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[4]xf32> from vector<[4]x[4]xf32> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : f32 from vector<[4]xf32> %tile = arm_sme.get_tile : vector<[4]x[4]xf32> - %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32> + %el = vector.extract %tile[%row, %col : index] : f32 from vector<[4]x[4]xf32> return %el : f32 } @@ -1315,9 +1315,9 @@ func.func @vector_extract_element_f32(%row: index, %col: index) -> f32 { // CHECK-LABEL: @vector_extract_element_f64 func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 { // CHECK: %[[SLICE:.*]] = arm_sme.extract_tile_slice %{{.*}} : vector<[2]xf64> from vector<[2]x[2]xf64> - // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}} : index] : f64 from vector<[2]xf64> %tile = arm_sme.get_tile : vector<[2]x[2]xf64> - %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64> + %el = vector.extract %tile[%row, %col : index] : f64 from vector<[2]x[2]xf64> return %el : f64 } @@ -1335,7 +1335,7 @@ func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: ind // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1> // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1> %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> - %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> + %slice = vector.extract %mask[%index : index] : vector<[8]xi1> from vector<[4]x[8]xi1> return %slice : vector<[8]xi1> } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 03bcb341efea2..acbf0f71b38d2 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1119,6 +1119,38 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 { // CHECK: llvm.extractelement {{.*}}[{{.*}} : i64] : vector<16xf32> // CHECK: return {{.*}} : f32 +// ----- + +func.func @extract_i32_index(%arg0: vector<16xf32>, %arg1: i32) -> f32 { + %0 = vector.extract %arg0[%arg1 : i32]: f32 from vector<16xf32> + return %0 : f32 +} +// CHECK-LABEL: @extract_i32_index +// CHECK: llvm.extractelement {{.*}}[{{.*}} : i32] : vector<16xf32> +// CHECK: return {{.*}} : f32 + +// ----- + +func.func @extract_i8_index(%arg0: vector<16xf32>, %arg1: i8) -> f32 { + %0 = vector.extract %arg0[%arg1 : i8]: f32 from vector<16xf32> + return %0 : f32 +} +// CHECK-LABEL: @extract_i8_index +// CHECK: llvm.extractelement {{.*}}[{{.*}} : i8] : vector<16xf32> +// CHECK: return {{.*}} : f32 + +// ----- + +func.func @extract_i1_index(%arg0: vector<16xf32>, %arg1: i1) -> f32 { + %0 = vector.extract %arg0[%arg1 : i1]: f32 from vector<16xf32> + return %0 : f32 +} +// CHECK-LABEL: @extract_i1_index +// CHECK: llvm.extractelement {{.*}}[{{.*}} : i1] : vector<16xf32> +// CHECK: return {{.*}} : f32 + +// ----- + func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 { %0 = vector.extract %arg0[15]: f32 from vector<[16]xf32> return %0 : f32 @@ -1239,7 +1271,7 @@ func.func @extract_scalar_from_vec_3d_f32_scalable(%arg0: vector<4x3x[16]xf32>) // ----- func.func @extract_scalar_from_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %arg1: index) -> f32 { - %0 = vector.extract %arg0[%arg1]: f32 from vector<16xf32> + %0 = vector.extract %arg0[%arg1 : index] : f32 from vector<16xf32> return %0 : f32 } // CHECK-LABEL: @extract_scalar_from_vec_1d_f32_dynamic_idx @@ -1247,8 +1279,10 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %ar // CHECK: %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64 // CHECK: llvm.extractelement %[[VEC]][%[[UC]] : i64] : vector<16xf32> +// ----- + func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16]xf32>, %arg1: index) -> f32 { - %0 = vector.extract %arg0[%arg1]: f32 from vector<[16]xf32> + %0 = vector.extract %arg0[%arg1 : index] : f32 from vector<[16]xf32> return %0 : f32 } // CHECK-LABEL: @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable @@ -1259,7 +1293,7 @@ func.func @extract_scalar_from_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16 // ----- func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: index) -> f32 { - %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32> + %0 = vector.extract %arg0[0, %arg1 : index] : f32 from vector<1x16xf32> return %0 : f32 } @@ -1268,8 +1302,10 @@ func.func @extract_scalar_from_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, % // CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idx( // CHECK: vector.extract +// ----- + func.func @extract_scalar_from_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: index) -> f32 { - %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x[16]xf32> + %0 = vector.extract %arg0[0, %arg1 : index] : f32 from vector<1x[16]xf32> return %0 : f32 } @@ -1356,6 +1392,38 @@ func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> ve // CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32> // CHECK: return {{.*}} : vector<4xf32> +// ----- + +func.func @insert_i32_index(%arg0: f32, %arg1: vector<4xf32>, %arg2: i32) -> vector<4xf32> { + %0 = vector.insert %arg0, %arg1[%arg2 : i32] : f32 into vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: @insert_i32_index +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i32] : vector<4xf32> +// CHECK: return {{.*}} : vector<4xf32> + +// ----- + +func.func @insert_i8_index(%arg0: f32, %arg1: vector<4xf32>, %arg2: i8) -> vector<4xf32> { + %0 = vector.insert %arg0, %arg1[%arg2 : i8] : f32 into vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: @insert_i8_index +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i8] : vector<4xf32> +// CHECK: return {{.*}} : vector<4xf32> + +// ----- + +func.func @insert_i1_index(%arg0: f32, %arg1: vector<4xf32>, %arg2: i1) -> vector<4xf32> { + %0 = vector.insert %arg0, %arg1[%arg2 : i1] : f32 into vector<4xf32> + return %0 : vector<4xf32> +} +// CHECK-LABEL: @insert_i1_index +// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i1] : vector<4xf32> +// CHECK: return {{.*}} : vector<4xf32> + +// ----- + func.func @insert_scalar_into_vec_1d_f32_scalable(%arg0: f32, %arg1: vector<[4]xf32>) -> vector<[4]xf32> { %0 = vector.insert %arg0, %arg1[3] : f32 into vector<[4]xf32> return %0 : vector<[4]xf32> @@ -1460,7 +1528,7 @@ func.func @insert_scalar_into_vec_3d_f32_scalable(%arg0: f32, %arg1: vector<4x8x func.func @insert_scalar_into_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %arg1: f32, %arg2: index) -> vector<16xf32> { - %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<16xf32> + %0 = vector.insert %arg1, %arg0[%arg2 : index] : f32 into vector<16xf32> return %0 : vector<16xf32> } @@ -1471,7 +1539,7 @@ func.func @insert_scalar_into_vec_1d_f32_dynamic_idx(%arg0: vector<16xf32>, %arg func.func @insert_scalar_into_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16]xf32>, %arg1: f32, %arg2: index) -> vector<[16]xf32> { - %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<[16]xf32> + %0 = vector.insert %arg1, %arg0[%arg2 : index] : f32 into vector<[16]xf32> return %0 : vector<[16]xf32> } @@ -1484,7 +1552,7 @@ func.func @insert_scalar_into_vec_1d_f32_dynamic_idx_scalable(%arg0: vector<[16] func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %arg1: f32, %idx: index) -> vector<1x16xf32> { - %0 = vector.insert %arg1, %arg0[0, %idx]: f32 into vector<1x16xf32> + %0 = vector.insert %arg1, %arg0[0, %idx : index] : f32 into vector<1x16xf32> return %0 : vector<1x16xf32> } @@ -1495,7 +1563,7 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[16]xf32>, %arg1: f32, %idx: index) -> vector<1x[16]xf32> { - %0 = vector.insert %arg1, %arg0[0, %idx]: f32 into vector<1x[16]xf32> + %0 = vector.insert %arg1, %arg0[0, %idx : index] : f32 into vector<1x[16]xf32> return %0 : vector<1x[16]xf32> } diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir index 5a6da3a06387a..7d25d2b1c1e99 100644 --- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir @@ -828,10 +828,10 @@ func.func @scalable_transpose_store_unmasked(%vec: vector<4x[4]xf32>, %dest: mem // FULL-UNROLL: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index // FULL-UNROLL: scf.for %[[VAL_13:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] { // FULL-UNROLL: %[[SLICE_I:.*]] = affine.apply #[[$SLICE_MAP]](%[[VAL_13]]){{\[}}%[[I]]] -// FULL-UNROLL: %[[ELEM_0:.*]] = vector.extract %[[SLICE_0]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32> -// FULL-UNROLL: %[[ELEM_1:.*]] = vector.extract %[[SLICE_1]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32> -// FULL-UNROLL: %[[ELEM_2:.*]] = vector.extract %[[SLICE_2]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32> -// FULL-UNROLL: %[[ELEM_3:.*]] = vector.extract %[[SLICE_3]]{{\[}}%[[VAL_13]]] : f32 from vector<[4]xf32> +// FULL-UNROLL: %[[ELEM_0:.*]] = vector.extract %[[SLICE_0]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32> +// FULL-UNROLL: %[[ELEM_1:.*]] = vector.extract %[[SLICE_1]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32> +// FULL-UNROLL: %[[ELEM_2:.*]] = vector.extract %[[SLICE_2]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32> +// FULL-UNROLL: %[[ELEM_3:.*]] = vector.extract %[[SLICE_3]]{{\[}}%[[VAL_13]] : index] : f32 from vector<[4]xf32> // FULL-UNROLL: %[[TRANSPOSE_SLICE:.*]] = vector.from_elements %[[ELEM_0]], %[[ELEM_1]], %[[ELEM_2]], %[[ELEM_3]] : vector<4xf32> // FULL-UNROLL: vector.transfer_write %[[TRANSPOSE_SLICE]], %[[DEST]]{{\[}}%[[SLICE_I]], %[[J]]] {in_bounds = [true]} : vector<4xf32>, memref diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 8796f153c4911..7b7f128c1180b 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -191,7 +191,7 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 { // CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // CHECK: return %[[R]] func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f32 { - %0 = vector.extract %arg0[%id] : f32 from vector<1xf32> + %0 = vector.extract %arg0[%id : index] : f32 from vector<1xf32> return %0: f32 } @@ -202,16 +202,38 @@ func.func @extract_size1_vector_dynamic(%arg0 : vector<1xf32>, %id : index) -> f // CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 // CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 func.func @extract_dynamic(%arg0 : vector<4xf32>, %id : index) -> f32 { - %0 = vector.extract %arg0[%id] : f32 from vector<4xf32> + %0 = vector.extract %arg0[%id : index] : f32 from vector<4xf32> return %0: f32 } +// ----- + +// CHECK-LABEL: @extract_i32_index +// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32 +// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32 +func.func @extract_i32_index(%arg0 : vector<4xf32>, %id : i32) -> f32 { + %0 = vector.extract %arg0[%id : i32] : f32 from vector<4xf32> + return %0: f32 +} + +// ----- + +// CHECK-LABEL: @extract_i8_index +// CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i8 +// CHECK: spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i8 +func.func @extract_i8_index(%arg0 : vector<4xf32>, %id : i8) -> f32 { + %0 = vector.extract %arg0[%id : i8] : f32 from vector<4xf32> + return %0: f32 +} + +// ----- + // CHECK-LABEL: @extract_dynamic_cst // CHECK-SAME: %[[V:.*]]: vector<4xf32> // CHECK: spirv.CompositeExtract %[[V]][1 : i32] : vector<4xf32> func.func @extract_dynamic_cst(%arg0 : vector<4xf32>) -> f32 { %idx = arith.constant 1 : index - %0 = vector.extract %arg0[%idx] : f32 from vector<4xf32> + %0 = vector.extract %arg0[%idx : index] : f32 from vector<4xf32> return %0: f32 } @@ -252,7 +274,7 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3 // CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]] // CHECK: return %[[R]] func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : index) -> vector<1xf32> { - %1 = vector.insert %arg1, %arg0[%id] : f32 into vector<1xf32> + %1 = vector.insert %arg1, %arg0[%id : index] : f32 into vector<1xf32> return %1 : vector<1xf32> } @@ -263,7 +285,29 @@ func.func @insert_size1_vector_dynamic(%arg0 : vector<1xf32>, %arg1: f32, %id : // CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32 // CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vector<4xf32> { - %0 = vector.insert %val, %arg0[%id] : f32 into vector<4xf32> + %0 = vector.insert %val, %arg0[%id : index] : f32 into vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_i32_index +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: i32 +// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i32 +// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i32 +func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : i32) -> vector<4xf32> { + %0 = vector.insert %val, %arg0[%id : i32] : f32 into vector<4xf32> + return %0: vector<4xf32> +} + +// ----- + +// CHECK-LABEL: @insert_i8_index +// CHECK-SAME: %[[VAL:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[ARG2:.*]]: i8 +// CHECK: %[[ID:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : index to i8 +// CHECK: spirv.VectorInsertDynamic %[[VAL]], %[[V]][%[[ID]]] : vector<4xf32>, i8 +func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : i8) -> vector<4xf32> { + %0 = vector.insert %val, %arg0[%id : i8] : f32 into vector<4xf32> return %0: vector<4xf32> } @@ -274,7 +318,7 @@ func.func @insert_dynamic(%val: f32, %arg0 : vector<4xf32>, %id : index) -> vect // CHECK: spirv.CompositeInsert %[[VAL]], %[[V]][2 : i32] : f32 into vector<4xf32> func.func @insert_dynamic_cst(%val: f32, %arg0 : vector<4xf32>) -> vector<4xf32> { %idx = arith.constant 2 : index - %0 = vector.insert %val, %arg0[%idx] : f32 into vector<4xf32> + %0 = vector.insert %val, %arg0[%idx : index] : f32 into vector<4xf32> return %0: vector<4xf32> } diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir index 9000551783576..bac1c1cb5615e 100644 --- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir +++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir @@ -814,12 +814,12 @@ func.func @extract_from_arith_ext(%src: vector<4x[8]xi8>) -> vector<[8]xi32> { // CHECK-LABEL: @non_constant_extract_from_arith_ext( // CHECK-SAME: %[[SRC:[a-z0-9]+]]: vector<4x[8]xi8>, // CHECK-SAME: %[[DIM:[a-z0-9]+]]: index -// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][%[[DIM]]] : vector<[8]xi8> from vector<4x[8]xi8> +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][%[[DIM]] : index] : vector<[8]xi8> from vector<4x[8]xi8> // CHECK: %[[EXTEND:.*]] = arith.extui %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32> // CHECK: return %[[EXTEND]] func.func @non_constant_extract_from_arith_ext(%src: vector<4x[8]xi8>, %dim: index) -> vector<[8]xi32> { %0 = arith.extui %src : vector<4x[8]xi8> to vector<4x[8]xi32> - %1 = vector.extract %0[%dim] : vector<[8]xi32> from vector<4x[8]xi32> + %1 = vector.extract %0[%dim : index] : vector<[8]xi32> from vector<4x[8]xi32> return %1 : vector<[8]xi32> } diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir index 458906a187982..61b6981b194a6 100644 --- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir +++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir @@ -179,10 +179,10 @@ func.func @transfer_write_f16_scalable_16x8(%dest: memref, %vec: vector // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] { - // CHECK-NEXT: %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16> + // CHECK-NEXT: %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]] : index] : vector<[8]xf16> from vector<[8]x[8]xf16> // CHECK-NEXT: vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref // CHECK-NEXT: %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index - // CHECK-NEXT: %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16> + // CHECK-NEXT: %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]] : index] : vector<[8]xf16> from vector<[8]x[8]xf16> // CHECK-NEXT: vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref // CHECK-NEXT: } // CHECK-NEXT: return @@ -224,20 +224,20 @@ func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref, %dim0: // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1> // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] { - // CHECK-NEXT: %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1> + // CHECK-NEXT: %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]] : index] : vector<[8]xi1> from vector<[8]x[8]xi1> // CHECK-NEXT: %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1> - // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1> - // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index - // CHECK-NEXT: %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1> + // CHECK-NEXT: %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]] : index] : vector<[8]xi1> from vector<[8]x[8]xi1> // CHECK-NEXT: %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1> - // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1> - // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE:.*]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: } %c0 = arith.constant 0 : index @@ -313,16 +313,16 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref, %dest: me // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref, vector<[4]x[4]xf32> // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref, vector<[4]x[4]xf32> // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] { - // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index - // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index - // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index - // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32> + // CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]] : index] : vector<[4]xf32> from vector<[4]x[4]xf32> // CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref // CHECK-NEXT: } // CHECK-NEXT: return @@ -399,7 +399,7 @@ func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: ind // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1> // CHECK-NEXT: return %[[EXTRACT]] %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1> - %extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1> + %extract = vector.extract %mask[%index : index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1> return %extract : vector<[4]x[4]xi1> } diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index 4e1035e038ca5..1f077409a6c66 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -734,7 +734,7 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func.func @hoist_vector_broadcasts // CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> { -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]] : index] : vector<4xf32> from vector<3x4xf32> // CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} { // CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32> // CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32> @@ -744,7 +744,7 @@ module attributes {transform.with_named_sequence} { func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> { %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> { - %extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32> + %extract = vector.extract %iarg[%pos : index] : vector<4xf32> from vector<3x4xf32> %use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32> %broadcast = vector.broadcast %use : vector<4xf32> to vector<3x4xf32> scf.yield %broadcast : vector<3x4xf32> diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir index fbebb97a11983..fe108e47d5dd3 100644 --- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir @@ -88,7 +88,7 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): %0 = transform.param.constant 2 : i64 -> !transform.param // expected-error@below {{expected ']' in dynamic index list}} - // expected-error@below {{custom op 'transform.structured.vectorize' expected SSA value or integer}} + // expected-error@below {{custom op 'transform.structured.vectorize' expected a valid list of SSA values or integers}} transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param, 2] : !transform.any_op, !transform.param } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 5ae769090dac6..db15a0562ef4e 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -126,7 +126,7 @@ func.func @extract_from_create_mask_dynamic_position(%dim0: index, %index: index %mask = vector.create_mask %c3, %c4, %dim0 : vector<4x4x6xi1> // CHECK: vector.create_mask %[[DIM0]] : vector<6xi1> // CHECK-NOT: vector.extract - %extract = vector.extract %mask[2, %index] : vector<6xi1> from vector<4x4x6xi1> + %extract = vector.extract %mask[2, %index : index] : vector<6xi1> from vector<4x4x6xi1> return %extract : vector<6xi1> } @@ -140,7 +140,7 @@ func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %in %mask = vector.create_mask %c1, %c0, %dim0 : vector<1x4x6xi1> // CHECK: arith.constant dense : vector<6xi1> // CHECK-NOT: vector.extract - %extract = vector.extract %mask[0, %index] : vector<6xi1> from vector<1x4x6xi1> + %extract = vector.extract %mask[0, %index : index] : vector<6xi1> from vector<1x4x6xi1> return %extract : vector<6xi1> } @@ -153,8 +153,8 @@ func.func @extract_from_create_mask_dynamic_position_unknown(%dim0: index, %inde %mask = vector.create_mask %c2, %dim0 : vector<4x6xi1> // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[DIM0]] : vector<4x6xi1> - // CHECK-NEXT: vector.extract %[[MASK]][%[[INDEX]]] : vector<6xi1> from vector<4x6xi1> - %extract = vector.extract %mask[%index] : vector<6xi1> from vector<4x6xi1> + // CHECK-NEXT: vector.extract %[[MASK]][%[[INDEX]] : index] : vector<6xi1> from vector<4x6xi1> + %extract = vector.extract %mask[%index : index] : vector<6xi1> from vector<4x6xi1> return %extract : vector<6xi1> } @@ -167,8 +167,8 @@ func.func @extract_from_create_mask_mixed_position_unknown(%dim0: index, %index0 %mask = vector.create_mask %c2, %c2, %dim0 : vector<2x4x4xi1> // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C2]], %[[C2]], %[[DIM0]] : vector<2x4x4xi1> - // CHECK-NEXT: vector.extract %[[MASK]][1, %[[INDEX]]] : vector<4xi1> from vector<2x4x4xi1> - %extract = vector.extract %mask[1, %index0] : vector<4xi1> from vector<2x4x4xi1> + // CHECK-NEXT: vector.extract %[[MASK]][1, %[[INDEX]] : index] : vector<4xi1> from vector<2x4x4xi1> + %extract = vector.extract %mask[1, %index0 : index] : vector<4xi1> from vector<2x4x4xi1> return %extract : vector<4xi1> } @@ -1918,10 +1918,10 @@ func.func @extract_insert_chain(%a: vector<2x16xf32>, %b: vector<12x8x16xf32>, % // CHECK-LABEL: extract_from_extract_chain_should_not_fold_dynamic_extracts // CHECK-SAME: (%[[VEC:.*]]: vector<2x4xf32>, %[[IDX:.*]]: index) -// CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]]] : vector<4xf32> from vector<2x4xf32> +// CHECK: %[[A:.*]] = vector.extract %[[VEC]][%[[IDX]] : index] : vector<4xf32> from vector<2x4xf32> // CHECK: %[[B:.*]] = vector.extract %[[A]][1] : f32 from vector<4xf32> func.func @extract_from_extract_chain_should_not_fold_dynamic_extracts(%v: vector<2x4xf32>, %index: index) -> f32 { - %0 = vector.extract %v[%index] : vector<4xf32> from vector<2x4xf32> + %0 = vector.extract %v[%index : index] : vector<4xf32> from vector<2x4xf32> %1 = vector.extract %0[1] : f32 from vector<4xf32> return %1 : f32 } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index d591c60acb64e..90a71b8e52425 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -148,6 +148,39 @@ func.func @extract_vector_type(%arg0: index) { %1 = vector.extract %arg0[] : index from index } +// ----- +func.func @extract_mixed_index_types(%arg0 : vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8) { + // expected-error@+2 {{use of value '%i32_idx' expects different type than prior uses: 'i8' vs 'i32'}} + // expected-note@-2 {{prior use here}} + %1 = vector.extract %arg0[%i32_idx, %i8_idx : i8] : f32 from vector<8x16xf32> +} + +// ----- +func.func @extract_index_vals_no_type(%arg0 : vector<8xf32>, + %i32_idx: i32) { + // expected-error@+2 {{expected a type for dynamic indices}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.extract %arg0[%i32_idx] : f32 from vector<8x16xf32> +} + +// ----- +func.func @extract_index_vals_multiple_types(%arg0 : vector<8xf32>, + %i8_idx : i8, + %i32_idx : i32) { + // expected-error@+2 {{expected single type}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.extract %arg0[%i8_idx, %i32_idx : i8, i32] : f32 from vector<8x16xf32> +} + +// ----- +func.func @extract_index_consts_type(%arg0 : vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8) { + // expected-error@+2 {{'vector.extract' expected no type for constant indices}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.extract %arg0[5, 3 : index] : f32 from vector<8x16xf32> +} + // ----- func.func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) { @@ -271,6 +304,38 @@ func.func @insert_0d(%a: f32, %b: vector) { %1 = vector.insert %a, %b[0] : f32 into vector } +// ----- +func.func @insert_mixed_index_types(%arg0 : f32, %arg1 : vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8) { + // expected-error@+2 {{use of value '%i32_idx' expects different type than prior uses: 'i8' vs 'i32'}} + // expected-note@-2 {{prior use here}} + %1 = vector.insert %arg0, %arg1[%i32_idx, %i8_idx : i8] : f32 into vector<8x16xf32> +} + +// ----- +func.func @insert_index_vals_no_type(%arg0 : f32, %arg1 : vector<8xf32>, + %i32_idx: i32) { + // expected-error@+2 {{expected a type for dynamic indices}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.insert %arg0, %arg1[%i32_idx] : f32 into vector<8x16xf32> +} + +// ----- +func.func @insert_index_vals_multiple_types(%arg0 : f32, %arg1 : vector<8xf32>, + %i8_idx : i8, %i32_idx : i32) { + // expected-error@+2 {{expected single type}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.insert %arg0, %arg1[%i8_idx, %i32_idx : i8, i32] : f32 into vector<8x16xf32> +} + +// ----- +func.func @insert_index_consts_type(%arg0 : f32, %arg1 : vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8) { + // expected-error@+2 {{'vector.insert' expected no type for constant indices}} + // expected-error@+1 {{expected a valid list of SSA values or integers}} + %1 = vector.insert %arg0, %arg1[5, 3 : index] : f32 into vector<8x16xf32> +} + // ----- func.func @outerproduct_num_operands(%arg0: f32) { diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 3baacba9b6124..5cc2ba366febc 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -222,17 +222,33 @@ func.func @extract_const_idx(%arg0: vector<4x8x16xf32>) // CHECK-LABEL: @extract_val_idx // CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[IDX:.+]]: index -func.func @extract_val_idx(%arg0: vector<4x8x16xf32>, %idx: index) - -> (vector<8x16xf32>, vector<16xf32>, f32) { - // CHECK: vector.extract %[[VEC]][%[[IDX]]] : vector<8x16xf32> from vector<4x8x16xf32> - %0 = vector.extract %arg0[%idx] : vector<8x16xf32> from vector<4x8x16xf32> - // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]]] : vector<16xf32> from vector<4x8x16xf32> - %1 = vector.extract %arg0[%idx, %idx] : vector<16xf32> from vector<4x8x16xf32> - // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], 5, %[[IDX]]] : f32 from vector<4x8x16xf32> - %2 = vector.extract %arg0[%idx, 5, %idx] : f32 from vector<4x8x16xf32> +func.func @extract_index_as_index(%arg0: vector<4x8x16xf32>, %idx: index) + -> (vector<8x16xf32>, vector<16xf32>, f32) { + // CHECK: vector.extract %[[VEC]][%[[IDX]] : index] : vector<8x16xf32> from vector<4x8x16xf32> + %0 = vector.extract %arg0[%idx : index] : vector<8x16xf32> from vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], %[[IDX]] : index] : vector<16xf32> from vector<4x8x16xf32> + %1 = vector.extract %arg0[%idx, %idx : index] : vector<16xf32> from vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[IDX]], 5, %[[IDX]] : index] : f32 from vector<4x8x16xf32> + %2 = vector.extract %arg0[%idx, 5, %idx : index] : f32 from vector<4x8x16xf32> return %0, %1, %2 : vector<8x16xf32>, vector<16xf32>, f32 } +// CHECK-LABEL: @extract_val_int +// CHECK-SAME: %[[VEC:.+]]: vector<4x8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8, %[[I1_IDX:.+]]: i1 +func.func @extract_index_as_int(%arg0: vector<4x8x16xf32>, %i32_idx: i32, + %i8_idx: i8, %i1_idx: i1) + -> (vector<8x16xf32>, vector<16xf32>, f32, vector<16xf32>) { + // CHECK: vector.extract %[[VEC]][%[[I32_IDX]] : i32] : vector<8x16xf32> from vector<4x8x16xf32> + %0 = vector.extract %arg0[%i32_idx : i32] : vector<8x16xf32> from vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[I8_IDX]], %[[I8_IDX]] : i8] : vector<16xf32> from vector<4x8x16xf32> + %1 = vector.extract %arg0[%i8_idx, %i8_idx : i8] : vector<16xf32> from vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[I8_IDX]], 5, %[[I8_IDX]] : i8] : f32 from vector<4x8x16xf32> + %2 = vector.extract %arg0[%i8_idx, 5, %i8_idx : i8] : f32 from vector<4x8x16xf32> + // CHECK-NEXT: vector.extract %[[VEC]][%[[I1_IDX]], 2 : i1] : vector<16xf32> from vector<4x8x16xf32> + %3 = vector.extract %arg0[%i1_idx, 2 : i1] : vector<16xf32> from vector<4x8x16xf32> + return %0, %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32, vector<16xf32> +} + // CHECK-LABEL: @extract_0d func.func @extract_0d(%a: vector) -> f32 { // CHECK-NEXT: vector.extract %{{.*}}[] : f32 from vector @@ -272,17 +288,33 @@ func.func @insert_const_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, // CHECK-LABEL: @insert_val_idx // CHECK-SAME: %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[IDX:.+]]: index -func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, - %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { - // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]]] : vector<8x16xf32> into vector<4x8x16xf32> - %0 = vector.insert %c, %res[%idx] : vector<8x16xf32> into vector<4x8x16xf32> - // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]]] : vector<16xf32> into vector<4x8x16xf32> - %1 = vector.insert %b, %res[%idx, %idx] : vector<16xf32> into vector<4x8x16xf32> - // CHECK: vector.insert %[[A]], %{{.*}}[%[[IDX]], 5, %[[IDX]]] : f32 into vector<4x8x16xf32> - %2 = vector.insert %a, %res[%idx, 5, %idx] : f32 into vector<4x8x16xf32> +func.func @insert_index_as_index(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, + %idx: index, %res: vector<4x8x16xf32>) -> vector<4x8x16xf32> { + // CHECK: vector.insert %[[C]], %{{.*}}[%[[IDX]] : index] : vector<8x16xf32> into vector<4x8x16xf32> + %0 = vector.insert %c, %res[%idx : index] : vector<8x16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %[[B]], %{{.*}}[%[[IDX]], %[[IDX]] : index] : vector<16xf32> into vector<4x8x16xf32> + %1 = vector.insert %b, %res[%idx, %idx : index] : vector<16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %[[A]], %{{.*}}[%[[IDX]], 5, %[[IDX]] : index] : f32 into vector<4x8x16xf32> + %2 = vector.insert %a, %res[%idx, 5, %idx : index] : f32 into vector<4x8x16xf32> return %2 : vector<4x8x16xf32> } +// CHECK-LABEL: @insert_val_int +// CHECK-SAME: %[[A:.+]]: f32, %[[B:.+]]: vector<16xf32>, %[[C:.+]]: vector<8x16xf32>, %[[I32_IDX:.+]]: i32, %[[I8_IDX:.+]]: i8, %[[I1_IDX:.+]]: i1 +func.func @insert_index_as_int(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>, + %i32_idx: i32, %i8_idx: i8, %i1_idx: i1, %res: vector<4x8x16xf32>) + -> (vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32>) { + // CHECK: vector.insert %[[C]], %{{.*}}[%[[I32_IDX]] : i32] : vector<8x16xf32> into vector<4x8x16xf32> + %0 = vector.insert %c, %res[%i32_idx : i32] : vector<8x16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %[[B]], %{{.*}}[%[[I8_IDX]], %[[I8_IDX]] : i8] : vector<16xf32> into vector<4x8x16xf32> + %1 = vector.insert %b, %res[%i8_idx, %i8_idx : i8] : vector<16xf32> into vector<4x8x16xf32> + // CHECK: vector.insert %[[A]], %{{.*}}[%[[I8_IDX]], 5, %[[I8_IDX]] : i8] : f32 into vector<4x8x16xf32> + %2 = vector.insert %a, %res[%i8_idx, 5, %i8_idx : i8] : f32 into vector<4x8x16xf32> + // CHECK-NEXT: vector.insert %[[B]], %{{.*}}[%[[I1_IDX]], 2 : i1] : vector<16xf32> into vector<4x8x16xf32> + %3 = vector.insert %b, %res[%i1_idx, 2 : i1] : vector<16xf32> into vector<4x8x16xf32> + return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32>, vector<4x8x16xf32> +} + // CHECK-LABEL: @insert_0d func.func @insert_0d(%a: f32, %b: vector, %c: vector<2x3xf32>) -> (vector, vector<2x3xf32>) { // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir index 0cecaddc5733e..4bc84fcc9c31f 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir @@ -91,13 +91,13 @@ func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector // CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8> // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2> // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2> // CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]] : index] : i2 from vector<8xi2> // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index -// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]] : index] : i2 from vector<8xi2> // ----- @@ -119,13 +119,13 @@ func.func @vector_load_i2_dynamic_indexing_mixed(%idx: index) -> vector<3xi2> { // CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8> // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2> // CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2> // CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]] : index] : i2 from vector<8xi2> // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index -// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]] : index] : i2 from vector<8xi2> // ----- @@ -147,13 +147,13 @@ func.func @vector_transfer_read_i2_dynamic_indexing(%idx1: index, %idx2: index) // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8> // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2> // CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2> // CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]] : index] : i2 from vector<8xi2> // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index -// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]] : index] : i2 from vector<8xi2> // ----- @@ -176,10 +176,10 @@ func.func @vector_transfer_read_i2_dynamic_indexing_mixed(%idx1: index) -> vecto // CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8> // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2> // CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2> -// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]] : index] : i2 from vector<8xi2> // CHECK: %[[C1:.+]] = arith.constant 1 : index // CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index -// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]] : index] : i2 from vector<8xi2> // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index -// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2> +// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]] : index] : i2 from vector<8xi2>