Skip to content

[VectorCombine] Scalarize vector intrinsics with scalar arguments #146530

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented Jul 1, 2025

Some intrinsics like llvm.abs or llvm.powi have a scalar argument even when the overloaded type is a vector.
This patch handles these in scalarizeOpOrCmp to allow scalarizing them.

In the test the leftover vector powi isn't folded away to poison, this should be fixed in a separate patch.

Some intrinsics like llvm.abs or llvm.powi have a scalar argument even when the overloaded type is a vector.
This patch handles these in scalarizeOpOrCmp to allow scalarizing them.

In the test the leftover vector powi isn't folded away to poison, this should be fixed in a separate patch.
@llvmbot
Copy link
Member

llvmbot commented Jul 1, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-vectorizers

Author: Luke Lau (lukel97)

Changes

Some intrinsics like llvm.abs or llvm.powi have a scalar argument even when the overloaded type is a vector.
This patch handles these in scalarizeOpOrCmp to allow scalarizing them.

In the test the leftover vector powi isn't folded away to poison, this should be fixed in a separate patch.


Full diff: https://github.com/llvm/llvm-project/pull/146530.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+31-27)
  • (modified) llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll (+3-3)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 19e82099e87f0..9be684a61c8f0 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -22,6 +22,7 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/TargetFolder.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
@@ -1091,12 +1092,14 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
     return false;
 
   // TODO: Allow intrinsics with different argument types
-  // TODO: Allow intrinsics with scalar arguments
-  if (II && (!isTriviallyVectorizable(II->getIntrinsicID()) ||
-             !all_of(II->args(), [&II](Value *Arg) {
-               return Arg->getType() == II->getType();
-             })))
-    return false;
+  if (II) {
+    if (!isTriviallyVectorizable(II->getIntrinsicID()))
+      return false;
+    for (auto [Idx, Arg] : enumerate(II->args()))
+      if (Arg->getType() != II->getType() &&
+          !isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, &TTI))
+        return false;
+  }
 
   // Do not convert the vector condition of a vector select into a scalar
   // condition. That may cause problems for codegen because of differences in
@@ -1109,19 +1112,18 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
 
   // Match constant vectors or scalars being inserted into constant vectors:
   // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
-  SmallVector<Constant *> VecCs;
-  SmallVector<Value *> ScalarOps;
+  SmallVector<Value *> VecCs, ScalarOps;
   std::optional<uint64_t> Index;
 
   auto Ops = II ? II->args() : I.operands();
-  for (Value *Op : Ops) {
+  for (auto [OpNum, Op] : enumerate(Ops)) {
     Constant *VecC;
     Value *V;
     uint64_t InsIdx = 0;
-    VectorType *OpTy = cast<VectorType>(Op->getType());
-    if (match(Op, m_InsertElt(m_Constant(VecC), m_Value(V),
-                              m_ConstantInt(InsIdx)))) {
+    if (match(Op.get(), m_InsertElt(m_Constant(VecC), m_Value(V),
+                                    m_ConstantInt(InsIdx)))) {
       // Bail if any inserts are out of bounds.
+      VectorType *OpTy = cast<VectorType>(Op->getType());
       if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
         return false;
       // All inserts must have the same index.
@@ -1132,7 +1134,11 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
         return false;
       VecCs.push_back(VecC);
       ScalarOps.push_back(V);
-    } else if (match(Op, m_Constant(VecC))) {
+    } else if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(),
+                                                        OpNum, &TTI)) {
+      VecCs.push_back(Op.get());
+      ScalarOps.push_back(Op.get());
+    } else if (match(Op.get(), m_Constant(VecC))) {
       VecCs.push_back(VecC);
       ScalarOps.push_back(nullptr);
     } else {
@@ -1176,16 +1182,17 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
   // Fold the vector constants in the original vectors into a new base vector to
   // get more accurate cost modelling.
   Value *NewVecC = nullptr;
+  TargetFolder Folder(*DL);
   if (CI)
-    NewVecC = ConstantFoldCompareInstOperands(CI->getPredicate(), VecCs[0],
-                                              VecCs[1], *DL);
+    NewVecC = Folder.FoldCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
   else if (UO)
-    NewVecC = ConstantFoldUnaryOpOperand(Opcode, VecCs[0], *DL);
+    NewVecC =
+        Folder.FoldUnOpFMF(UO->getOpcode(), VecCs[0], UO->getFastMathFlags());
   else if (BO)
-    NewVecC = ConstantFoldBinaryOpOperands(Opcode, VecCs[0], VecCs[1], *DL);
+    NewVecC = Folder.FoldBinOp(BO->getOpcode(), VecCs[0], VecCs[1]);
   else if (II->arg_size() == 2)
-    NewVecC = ConstantFoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
-                                          VecCs[1], II->getType(), II);
+    NewVecC = Folder.FoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
+                                         VecCs[1], II->getType(), &I);
 
   // Get cost estimate for the insert element. This cost will factor into
   // both sequences.
@@ -1193,8 +1200,9 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
   InstructionCost NewCost =
       ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
                                             CostKind, *Index, NewVecC);
-  for (auto [Op, VecC, Scalar] : zip(Ops, VecCs, ScalarOps)) {
-    if (!Scalar)
+  for (auto [Idx, Op, VecC, Scalar] : enumerate(Ops, VecCs, ScalarOps)) {
+    if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
+                              II->getIntrinsicID(), Idx, &TTI)))
       continue;
     InstructionCost InsertCost = TTI.getVectorInstrCost(
         Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
@@ -1238,16 +1246,12 @@ bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
 
   // Create a new base vector if the constant folding failed.
   if (!NewVecC) {
-    SmallVector<Value *> VecCValues;
-    VecCValues.reserve(VecCs.size());
-    append_range(VecCValues, VecCs);
     if (CI)
       NewVecC = Builder.CreateCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
     else if (UO || BO)
-      NewVecC = Builder.CreateNAryOp(Opcode, VecCValues);
+      NewVecC = Builder.CreateNAryOp(Opcode, VecCs);
     else
-      NewVecC =
-          Builder.CreateIntrinsic(VecTy, II->getIntrinsicID(), VecCValues);
+      NewVecC = Builder.CreateIntrinsic(VecTy, II->getIntrinsicID(), VecCs);
   }
   Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
   replaceValue(I, *Insert);
diff --git a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
index 58b7f8de004d0..9e43a28bf1e59 100644
--- a/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
+++ b/llvm/test/Transforms/VectorCombine/intrinsic-scalarize.ll
@@ -152,12 +152,12 @@ define <vscale x 4 x float> @fma_scalable(float %x, float %y, float %z) {
   ret <vscale x 4 x float> %v
 }
 
-; TODO: We should be able to scalarize this if we preserve the scalar argument.
 define <4 x float> @scalar_argument(float %x) {
 ; CHECK-LABEL: define <4 x float> @scalar_argument(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[X_INSERT:%.*]] = insertelement <4 x float> poison, float [[X]], i32 0
-; CHECK-NEXT:    [[V:%.*]] = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> [[X_INSERT]], i32 42)
+; CHECK-NEXT:    [[V_SCALAR:%.*]] = call float @llvm.powi.f32.i32(float [[X]], i32 42)
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> poison, i32 42)
+; CHECK-NEXT:    [[V:%.*]] = insertelement <4 x float> [[TMP1]], float [[V_SCALAR]], i64 0
 ; CHECK-NEXT:    ret <4 x float> [[V]]
 ;
   %x.insert = insertelement <4 x float> poison, float %x, i32 0

Comment on lines +1185 to +1195
TargetFolder Folder(*DL);
if (CI)
NewVecC = ConstantFoldCompareInstOperands(CI->getPredicate(), VecCs[0],
VecCs[1], *DL);
NewVecC = Folder.FoldCmp(CI->getPredicate(), VecCs[0], VecCs[1]);
else if (UO)
NewVecC = ConstantFoldUnaryOpOperand(Opcode, VecCs[0], *DL);
NewVecC =
Folder.FoldUnOpFMF(UO->getOpcode(), VecCs[0], UO->getFastMathFlags());
else if (BO)
NewVecC = ConstantFoldBinaryOpOperands(Opcode, VecCs[0], VecCs[1], *DL);
NewVecC = Folder.FoldBinOp(BO->getOpcode(), VecCs[0], VecCs[1]);
else if (II->arg_size() == 2)
NewVecC = ConstantFoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
VecCs[1], II->getType(), II);
NewVecC = Folder.FoldBinaryIntrinsic(II->getIntrinsicID(), VecCs[0],
VecCs[1], II->getType(), &I);
Copy link
Contributor Author

@lukel97 lukel97 Jul 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've switched the API here to use ConstantFolder which basically calls the same underlying ConstantFoldFoo functions, the only difference being that it also checks to see if the arguments are constant which allows us to make VecCs Values instead of Constants.

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

@lukel97 lukel97 merged commit 7931a8f into llvm:main Jul 2, 2025
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants