[VectorCombine] Scalarize extracts of ZExt if profitable. #142976


Open
fhahn wants to merge 3 commits into main from vector-combine-scalarize-extracts-of-zext
Conversation

@fhahn (Contributor) commented Jun 5, 2025

Add a new scalarization transform that tries to convert extracts of a vector ZExt to a set of scalar shift and mask operations. This can be profitable if the cost of extracting is the same or higher than the cost of 2 scalar ops. This is the case on AArch64 for example.

For AArch64, this shows up in a number of workloads, including av1aom, gmsh, minizinc and astc-encoder.
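
For illustration, each used lane of the zext'ed vector becomes one shift plus one mask on the scalarized source, roughly as follows (a minimal C++ sketch of the intended scalar form, not the pass itself; extractZExtLane and Bits are made-up names, and a little-endian lane layout is assumed so lane Idx of a <4 x i8> source starts at bit Idx * 8):

  #include <cstdint>

  // Scalar equivalent of: extractelement (zext <4 x i8> %src to <4 x i32>), Idx
  // where Bits is %src reinterpreted as a single 32-bit integer.
  uint32_t extractZExtLane(uint32_t Bits, unsigned Idx) {
    return (Bits >> (Idx * 8)) & 0xFFu; // lshr + and; the lshr folds away for Idx == 0
  }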

@llvmbot (Member) commented Jun 5, 2025

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: Florian Hahn (fhahn)

Changes

Add a new scalarization transform that tries to convert extracts of a vector ZExt to a set of scalar shift and mask operations. This can be profitable if the cost of extracting is the same or higher than the cost of 2 scalar ops. This is the case on AArch64 for example.

For AArch64, this shows up in a number of workloads, including av1aom, gmsh, minizinc and astc-encoder.


Full diff: https://github.com/llvm/llvm-project/pull/142976.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+68)
  • (modified) llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll (+87-24)
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 52cb1dbb33b86..85375654fbc19 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -120,6 +120,7 @@ class VectorCombine {
   bool foldBinopOfReductions(Instruction &I);
   bool foldSingleElementStore(Instruction &I);
   bool scalarizeLoadExtract(Instruction &I);
+  bool scalarizeExtExtract(Instruction &I);
   bool foldConcatOfBoolMasks(Instruction &I);
   bool foldPermuteOfBinops(Instruction &I);
   bool foldShuffleOfBinops(Instruction &I);
@@ -1710,6 +1711,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
   return true;
 }
 
+bool VectorCombine::scalarizeExtExtract(Instruction &I) {
+  if (!match(&I, m_ZExt(m_Value())))
+    return false;
+
+  // Try to convert a vector zext feeding only extracts to a set of scalar
+  // (Src >> ExtIdx * SrcEltSizeInBits) & EltBitMask operations, if profitable.
+  auto *Ext = cast<ZExtInst>(&I);
+  auto *SrcTy = cast<FixedVectorType>(Ext->getOperand(0)->getType());
+  auto *DstTy = cast<FixedVectorType>(Ext->getType());
+
+  if (DL->getTypeSizeInBits(SrcTy) !=
+      DL->getTypeSizeInBits(DstTy->getElementType()))
+    return false;
+
+  InstructionCost VectorCost = TTI.getCastInstrCost(
+      Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
+  unsigned ExtCnt = 0;
+  bool ExtLane0 = false;
+  for (User *U : Ext->users()) {
+    const APInt *Idx;
+    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
+      return false;
+    if (cast<Instruction>(U)->use_empty())
+      continue;
+    ExtCnt += 1;
+    ExtLane0 |= Idx->isZero();
+    VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
+                                         CostKind, Idx->getZExtValue(), U);
+  }
+
+  Type *ScalarDstTy = DstTy->getElementType();
+  InstructionCost ScalarCost =
+      ExtCnt * TTI.getArithmeticInstrCost(
+                   Instruction::And, ScalarDstTy, CostKind,
+                   {TTI::OK_AnyValue, TTI::OP_None},
+                   {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
+      (ExtCnt - ExtLane0) *
+          TTI.getArithmeticInstrCost(
+              Instruction::LShr, ScalarDstTy, CostKind,
+              {TTI::OK_AnyValue, TTI::OP_None},
+              {TTI::OK_NonUniformConstantValue, TTI::OP_None});
+  if (ScalarCost > VectorCost)
+    return false;
+
+  Value *ScalarV = Ext->getOperand(0);
+  if (!isGuaranteedNotToBePoison(ScalarV, &AC))
+    ScalarV = Builder.CreateFreeze(ScalarV);
+  ScalarV = Builder.CreateBitCast(
+      ScalarV,
+      IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
+  unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
+  Value *EltBitMask =
+      ConstantInt::get(ScalarV->getType(), (1ull << SrcEltSizeInBits) - 1);
+  for (auto *U : to_vector(Ext->users())) {
+    auto *Extract = cast<ExtractElementInst>(U);
+    unsigned Idx =
+        cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
+    auto *S = Builder.CreateLShr(
+        ScalarV, ConstantInt::get(ScalarV->getType(), Idx * SrcEltSizeInBits));
+    auto *A = Builder.CreateAnd(S, EltBitMask);
+    U->replaceAllUsesWith(A);
+  }
+  return true;
+}
+
 /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
 /// to "(bitcast (concat X, Y))"
 /// where X/Y are bitcasted from i1 mask vectors.
@@ -3576,6 +3643,7 @@ bool VectorCombine::run() {
     if (IsVectorType) {
       MadeChange |= scalarizeOpOrCmp(I);
       MadeChange |= scalarizeLoadExtract(I);
+      MadeChange |= scalarizeExtExtract(I);
       MadeChange |= scalarizeVPIntrinsic(I);
       MadeChange |= foldInterleaveIntrinsics(I);
     }
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
index 09c03991ad7c3..23538589ae32c 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll
@@ -9,15 +9,25 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_all_lanes_used(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP7:%.*]] = and i32 [[TMP6]], 255
+; CHECK-NEXT:    [[TMP8:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP9:%.*]] = and i32 [[TMP8]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP9]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -68,13 +78,21 @@ define void @zext_v4i8_3_lanes_used_1(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_1(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP7:%.*]] = and i32 [[TMP6]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -93,13 +111,21 @@ define void @zext_v4i8_3_lanes_used_2(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_2(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = and i32 [[TMP6]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP7]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -118,11 +144,17 @@ define void @zext_v4i8_2_lanes_used_1(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_1(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -139,11 +171,17 @@ define void @zext_v4i8_2_lanes_used_2(<4 x i8> %src) {
 ; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_2(
 ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP2]], 255
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP5]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -160,15 +198,24 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) {
 ; CHECK-LABEL: define void @zext_v4i8_all_lanes_used_noundef(
 ; CHECK-SAME: <4 x i8> noundef [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[TMP0]], 24
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 255
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i32 [[TMP0]], 16
+; CHECK-NEXT:    [[TMP4:%.*]] = and i32 [[TMP3]], 255
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i32 [[TMP0]], 8
+; CHECK-NEXT:    [[TMP6:%.*]] = and i32 [[TMP5]], 255
+; CHECK-NEXT:    [[TMP7:%.*]] = lshr i32 [[TMP0]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = and i32 [[TMP7]], 255
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i32(i32 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP8]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP6]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP4]])
+; CHECK-NEXT:    call void @use.i32(i32 [[TMP2]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -221,15 +268,25 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) {
 ; CHECK-LABEL: define void @zext_v4i16_all_lanes_used(
 ; CHECK-SAME: <4 x i16> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <4 x i16> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i64 [[TMP1]], 48
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 65535
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i64 [[TMP1]], 32
+; CHECK-NEXT:    [[TMP5:%.*]] = and i64 [[TMP4]], 65535
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i64 [[TMP1]], 16
+; CHECK-NEXT:    [[TMP7:%.*]] = and i64 [[TMP6]], 65535
+; CHECK-NEXT:    [[TMP8:%.*]] = lshr i64 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP9:%.*]] = and i64 [[TMP8]], 65535
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <4 x i16> [[SRC]] to <4 x i64>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <4 x i64> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <4 x i64> [[EXT9]], i64 1
 ; CHECK-NEXT:    [[EXT_2:%.*]] = extractelement <4 x i64> [[EXT9]], i64 2
 ; CHECK-NEXT:    [[EXT_3:%.*]] = extractelement <4 x i64> [[EXT9]], i64 3
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_1]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_2]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_3]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP9]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP7]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP5]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -250,11 +307,17 @@ define void @zext_v2i32_all_lanes_used(<2 x i32> %src) {
 ; CHECK-LABEL: define void @zext_v2i32_all_lanes_used(
 ; CHECK-SAME: <2 x i32> [[SRC:%.*]]) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze <2 x i32> [[SRC]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i64 [[TMP1]], 32
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 4294967295
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i64 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = and i64 [[TMP4]], 4294967295
 ; CHECK-NEXT:    [[EXT9:%.*]] = zext nneg <2 x i32> [[SRC]] to <2 x i64>
 ; CHECK-NEXT:    [[EXT_0:%.*]] = extractelement <2 x i64> [[EXT9]], i64 0
 ; CHECK-NEXT:    [[EXT_1:%.*]] = extractelement <2 x i64> [[EXT9]], i64 1
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_0]])
-; CHECK-NEXT:    call void @use.i64(i64 [[EXT_1]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP5]])
+; CHECK-NEXT:    call void @use.i64(i64 [[TMP3]])
 ; CHECK-NEXT:    ret void
 ;
 entry:

ScalarV = Builder.CreateFreeze(ScalarV);
ScalarV = Builder.CreateBitCast(
ScalarV,
IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
Collaborator (commenting on the lines quoted above):

Where do you account for the cost of the bitcast (vec->int transfer)?

fhahn (Contributor, Author):

Oh I assumed they are basically always free? Is there a dedicated hook to check their cost?

@RKSimon (Collaborator) commented Jun 6, 2025:

I know x86 can struggle with gpr<->simd move throughput (although I wouldn't be surprised if the cost tables are missing coverage).

getCastInstrCost should be able to handle bitcasts - by default I think they are treated as free.
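
For reference, one way the vec->int transfer could be charged before the cost comparison in scalarizeExtExtract (a sketch only, reusing TTI, SrcTy, DL, ScalarCost and CostKind from the patch; whether BitCast with CastContextHint::None models the transfer well enough per target is an assumption):

  // Charge the vector -> scalar-integer move on the scalar side so targets
  // with expensive GPR<->SIMD transfers can reject the transform.
  Type *ScalarIntTy =
      IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy));
  ScalarCost += TTI.getCastInstrCost(Instruction::BitCast, ScalarIntTy, SrcTy,
                                     TTI::CastContextHint::None, CostKind);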

fhahn (Contributor, Author):

Ah yes, that's a good point. Some tests aren't transformed after the change. The motivating case has a load feeding the zext, so there's no need for the bitcast; planning to handle this as a follow-up.

return false;

InstructionCost VectorCost = TTI.getCastInstrCost(
Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
Collaborator (commenting on the lines quoted above):

Add the Instruction *I op since you have it?
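
For example, the zext itself could be passed as the context instruction (getCastInstrCost takes a trailing const Instruction * that defaults to nullptr):

  InstructionCost VectorCost =
      TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy,
                           TTI::CastContextHint::None, CostKind, Ext);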

fhahn (Contributor, Author):

Done, thanks.

@fhahn force-pushed the vector-combine-scalarize-extracts-of-zext branch from 681cd15 to 5daa841 on June 5, 2025 18:02
@fhahn force-pushed the vector-combine-scalarize-extracts-of-zext branch 2 times, most recently from 30e314f to 09c6867 on June 20, 2025 16:27
@fhahn (Contributor, Author) commented Jun 20, 2025

ping :)

@RKSimon added the llvm::vectorcombine (Cost-based vector combine pass) label on Jun 23, 2025
fhahn added 3 commits June 30, 2025 14:50
Add a new scalarization transform that tries to convert extracts of
a vector ZExt to a set of scalar shift and mask operations. This can be
profitable if the cost of extracting is the same or higher than the cost
of 2 scalar ops. This is the case on AArch64 for example.

For AArch64, this shows up in a number of workloads, including av1aom, gmsh,
minizinc and astc-encoder.
@fhahn force-pushed the vector-combine-scalarize-extracts-of-zext branch from 09c6867 to 75ed96d on June 30, 2025 13:50
@artagnon (Contributor) left a comment:

LGTM, thanks! Some suggestions follow.

Comment on lines +1828 to +1829
unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
unsigned EltBitMask = (1ull << SrcEltSizeInBits) - 1;

Suggested change
-unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
-unsigned EltBitMask = (1ull << SrcEltSizeInBits) - 1;
+uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
+uint64_t EltBitMask = (1ull << SrcEltSizeInBits) - 1;

Comment on lines +1834 to +1836
Value *S = Builder.CreateLShr(ScalarV, Idx * SrcEltSizeInBits);
Value *A = Builder.CreateAnd(S, EltBitMask);
U->replaceAllUsesWith(A);

Suggested change
-Value *S = Builder.CreateLShr(ScalarV, Idx * SrcEltSizeInBits);
-Value *A = Builder.CreateAnd(S, EltBitMask);
-U->replaceAllUsesWith(A);
+Value *LShr = Builder.CreateLShr(ScalarV, Idx * SrcEltSizeInBits);
+Value *And = Builder.CreateAnd(LShr, EltBitMask);
+U->replaceAllUsesWith(And);

Comment on lines +333 to +334
define void @zext_nv4i8_all_lanes_used(<vscale x 4 x i8> %src) {
; CHECK-LABEL: define void @zext_nv4i8_all_lanes_used(

Suggested change
-define void @zext_nv4i8_all_lanes_used(<vscale x 4 x i8> %src) {
-; CHECK-LABEL: define void @zext_nv4i8_all_lanes_used(
+define void @zext_nxv4i8_all_lanes_used(<vscale x 4 x i8> %src) {
+; CHECK-LABEL: define void @zext_nxv4i8_all_lanes_used(

4 participants