diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 19e82099e87f0..7e4fdb22342c2 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -19,9 +19,9 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
-#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/TargetFolder.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
@@ -1091,12 +1091,14 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
     return false;
 
   // TODO: Allow intrinsics with different argument types
-  // TODO: Allow intrinsics with scalar arguments
-  if (II && (!isTriviallyVectorizable(II->getIntrinsicID()) ||
-             !all_of(II->args(), [&II](Value *Arg) {
-               return Arg->getType() == II->getType();
-             })))
-    return false;
+  if (II) {
+    if (!isTriviallyVectorizable(II->getIntrinsicID()))
+      return false;
+    for (auto [Idx, Arg] : enumerate(II->args()))
+      if (Arg->getType() != II->getType() &&
+          !isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, &TTI))
+        return false;
+  }
 
   // Do not convert the vector condition of a vector select into a scalar
   // condition. That may cause problems for codegen because of differences in
@@ -1109,19 +1111,18 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
 
   // Match constant vectors or scalars being inserted into constant vectors:
   // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
-  SmallVector<Constant *> VecCs;
-  SmallVector<Value *> ScalarOps;
+  SmallVector<Value *> VecCs, ScalarOps;
   std::optional<uint64_t> Index;
 
   auto Ops = II ? II->args() : I.operands();
-  for (Value *Op : Ops) {
+  for (auto [OpNum, Op] : enumerate(Ops)) {
     Constant *VecC;
     Value *V;
     uint64_t InsIdx = 0;
-    VectorType *OpTy = cast<VectorType>(Op->getType());
-    if (match(Op, m_InsertElt(m_Constant(VecC), m_Value(V),
-                              m_ConstantInt(InsIdx)))) {
+    if (match(Op.get(), m_InsertElt(m_Constant(VecC), m_Value(V),
+                                    m_ConstantInt(InsIdx)))) {
       // Bail if any inserts are out of bounds.
+      VectorType *OpTy = cast<VectorType>(Op->getType());
       if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
         return false;
       // All inserts must have the same index.
@@ -1132,7 +1133,11 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
         return false;
       VecCs.push_back(VecC);
       ScalarOps.push_back(V);
-    } else if (match(Op, m_Constant(VecC))) {
+    } else if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(),
+                                                        OpNum, &TTI)) {
+      VecCs.push_back(Op.get());
+      ScalarOps.push_back(Op.get());
+    } else if (match(Op.get(), m_Constant(VecC))) {
       VecCs.push_back(VecC);
       ScalarOps.push_back(nullptr);
     } else {
@@ -1176,16 +1181,17 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
   // Fold the vector constants in the original vectors into a new base vector to
   // get more accurate cost modelling.
   Value *NewVecC = nullptr;
+  TargetFolder Folder(*DL);
   if (CI)
-    NewVecC = ConstantFoldCompareInstOperands(CI->getPredicate(), VecCs[0],
-                                              VecCs[1], *DL);
+    NewVecC = Folder.FoldCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
   else if (UO)
-    NewVecC = ConstantFoldUnaryOpOperand(Opcode, VecCs[0], *DL);
+    NewVecC =
+        Folder.FoldUnOpFMF(UO->getOpcode(), VecCs[0], UO->getFastMathFlags());
   else if (BO)
-    NewVecC = ConstantFoldBinaryOpOperands(Opcode, VecCs[0], VecCs[1], *DL);
+    NewVecC = Folder.FoldBinOp(BO->getOpcode(), VecCs[0], VecCs[1]);
   else if (II->arg_size() == 2)
-    NewVecC = ConstantFoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
-                                          VecCs[1], II->getType(), II);
+    NewVecC = Folder.FoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
+                                         VecCs[1], II->getType(), &I);
 
   // Get cost estimate for the insert element. This cost will factor into
   // both sequences.
@@ -1193,8 +1199,9 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
   InstructionCost NewCost =
       ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
                                             CostKind, *Index, NewVecC);
-  for (auto [Op, VecC, Scalar] : zip(Ops, VecCs, ScalarOps)) {
-    if (!Scalar)
+  for (auto [Idx, Op, VecC, Scalar] : enumerate(Ops, VecCs, ScalarOps)) {
+    if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
+                              II->getIntrinsicID(), Idx, &TTI)))
       continue;
     InstructionCost InsertCost = TTI.getVectorInstrCost(
         Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
@@ -1238,16 +1245,12 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
 
   // Create a new base vector if the constant folding failed.
   if (!NewVecC) {
-    SmallVector<Value *> VecCValues;
-    VecCValues.reserve(VecCs.size());
-    append_range(VecCValues, VecCs);
     if (CI)
       NewVecC = Builder.CreateCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
     else if (UO || BO)
-      NewVecC = Builder.CreateNAryOp(Opcode, VecCValues);
+      NewVecC = Builder.CreateNAryOp(Opcode, VecCs);
     else
-      NewVecC =
-          Builder.CreateIntrinsic(VecTy, II->getIntrinsicID(), VecCValues);
+      NewVecC = Builder.CreateIntrinsic(VecTy, II->getIntrinsicID(), VecCs);
   }
   Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
   replaceValue(I, *Insert);
diff --git a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
index 58b7f8de004d0..9e43a28bf1e59 100644
--- a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
+++ b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
@@ -152,12 +152,12 @@ define <vscale x 4 x float> @fma_scalable(float %x, float %y, float %z) {
   ret <vscale x 4 x float> %v
 }
 
-; TODO: We should be able to scalarize this if we preserve the scalar argument.
 define <4 x float> @scalar_argument(float %x) {
 ; CHECK-LABEL: define <4 x float> @scalar_argument(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <4 x float> poison, float [[X]], i32 0
-; CHECK-NEXT:    [[V:%.*]] = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> [[X_INSERT]], i32 42)
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.powi.f32.i32(float [[X]], i32 42)
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> poison, i32 42)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <4 x float> [[V]]
 ;
   %x.insert = insertelement <4 x float> poison, float %x, i32 0
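
For illustration, the behaviour change on the updated @scalar_argument test can be read as the following before/after IR. This is a sketch reconstructed from the CHECK lines above, not actual pass output; the _before/_after function suffixes and the %vec name exist only so the snippet parses as one module. It shows the scalar i32 exponent of llvm.powi being passed through unchanged to both the scalarized call and the residual vector call, instead of blocking the transform entirely.

; Before VectorCombine: the scalar %x is inserted into a poison vector and the
; whole powi call stays in vector form.
define <4 x float> @scalar_argument_before(float %x) {
  %x.insert = insertelement <4 x float> poison, float %x, i32 0
  %v = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> %x.insert, i32 42)
  ret <4 x float> %v
}

; After VectorCombine: lane 0 is computed with the scalar intrinsic and then
; re-inserted into the vector call on the remaining (poison) lanes, which the
; folder can often simplify away.
define <4 x float> @scalar_argument_after(float %x) {
  %v.scalar = call float @llvm.powi.f32.i32(float %x, i32 42)
  %vec = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> poison, i32 42)
  %v = insertelement <4 x float> %vec, float %v.scalar, i64 0
  ret <4 x float> %v
}

declare float @llvm.powi.f32.i32(float, i32)
declare <4 x float> @llvm.powi.v4f32.i32(<4 x float>, i32)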