diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 862ed7bae1fbb..d48ff4705b2dc 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -398,6 +398,18 @@ std::optional vector::getConstantVscaleMultiplier(Value value) { return {}; } +/// Converts an IntegerAttr to have the specified type if needed. +/// This handles cases where constant attributes have a different type than the +/// target element type. If the input attribute is not an IntegerAttr or already +/// has the correct type, returns it unchanged. +static Attribute convertIntegerAttr(Attribute attr, Type expectedType) { + if (auto intAttr = mlir::dyn_cast(attr)) { + if (intAttr.getType() != expectedType) + return IntegerAttr::get(expectedType, intAttr.getInt()); + } + return attr; +} + //===----------------------------------------------------------------------===// // CombiningKindAttr //===----------------------------------------------------------------------===// @@ -2459,8 +2471,37 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) { return {}; } +/// Fold vector.from_elements to a constant when all operands are constants. +/// Example: +/// %c1 = arith.constant 1 : i32 +/// %c2 = arith.constant 2 : i32 +/// %v = vector.from_elements %c1, %c2 : vector<2xi32> +/// => +/// %v = arith.constant dense<[1, 2]> : vector<2xi32> +/// +static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, + ArrayRef elements) { + if (llvm::any_of(elements, [](Attribute attr) { return !attr; })) + return {}; + + auto destVecType = fromElementsOp.getDest().getType(); + auto destEltType = destVecType.getElementType(); + // Constant attributes might have a different type than the return type. + // Convert them before creating the dense elements attribute. + auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) { + return convertIntegerAttr(attr, destEltType); + }); + + return DenseElementsAttr::get(destVecType, convertedElements); +} + OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { - return foldFromElementsToElements(*this); + if (auto res = foldFromElementsToElements(*this)) + return res; + if (auto res = foldFromElementsToConstant(*this, adaptor.getElements())) + return res; + + return {}; } /// Rewrite a vector.from_elements into a vector.splat if all elements are the @@ -3322,17 +3363,6 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, /// Converts the expected type to an IntegerAttr if there's /// a mismatch. - auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute { - if (auto intAttr = mlir::dyn_cast(attr)) { - if (intAttr.getType() != expectedType) - return IntegerAttr::get(expectedType, intAttr.getInt()); - } - return attr; - }; - - // The `convertIntegerAttr` method specifically handles the case - // for `llvm.mlir.constant` which can hold an attribute with a - // different type than the return type. if (auto denseSource = llvm::dyn_cast(srcAttr)) { for (auto value : denseSource.getValues()) insertedValues.push_back(convertIntegerAttr(value, destEltType)); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 65b73375831da..0282e9cac5e02 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -3075,6 +3075,33 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x // ----- +// CHECK-LABEL: func @from_elements_all_elements_constant( +func.func @from_elements_all_elements_constant() -> vector<2x2xi32> { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %c3_i32 = arith.constant 3 : i32 + // CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32> + %res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32> + // CHECK: return %[[RES]] + return %res : vector<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @from_elements_partial_elements_constant( +// CHECK-SAME: %[[A:.*]]: f32 +func.func @from_elements_partial_elements_constant(%arg0: f32) -> vector<2xf32> { + // CHECK: %[[C:.*]] = arith.constant 1.000000e+00 : f32 + %c = arith.constant 1.0 : f32 + // CHECK: %[[RES:.*]] = vector.from_elements %[[A]], %[[C]] : vector<2xf32> + %res = vector.from_elements %arg0, %c : vector<2xf32> + // CHECK: return %[[RES]] + return %res : vector<2xf32> +} + +// ----- + // CHECK-LABEL: func @vector_insert_const_regression( // CHECK: llvm.mlir.undef // CHECK: vector.insert