Add DoNotPoisonEltMask to several SimplifyDemanded functions in TargetLowering #145903

Open

wants to merge 4 commits into base: users/bjope/demandedbits_bitcast_1

Conversation

bjope (Collaborator) commented Jun 26, 2025

This PR currently contains several commits. It should perhaps be split into several pull requests.

bjope added 4 commits June 26, 2025 16:08
The fix for #138513 resulted in a number of regressions because we had
to demand all source elements that feed a demanded part of a bitcast
result, even when the corresponding bits weren't actually used. The
problem was that if we did not demand those elements, the calls to
SimplifyDemandedVectorElts could end up turning those unused elements
into poison, making the bitcast result poison.

This patch tries to avoid such regressions by adding a new element
mask ('DoNotPoisonEltMask') to SimplifyDemandedVectorElts that
identifies elements that aren't really demanded, but that must not be
made more poisonous during simplifications.
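
As a rough illustration of the new mask (a hedged sketch using LLVM's APInt, not code taken from this patch): when only the low i64 element of a <2 x i64> bitcast result is demanded, the i32 source elements feeding that element must stay non-poison even though not all of their bits are used. APIntOps::ScaleBitMask performs the mask scaling:

#include "llvm/ADT/APInt.h"
using namespace llvm;

void doNotPoisonExample() {
  // Source is <4 x i32>, destination is <2 x i64>; only destination
  // element 0 is demanded.
  unsigned NumSrcElts = 4;
  APInt DemandedElts(2, 0b01);
  // Scale the destination mask to the source element count. These are
  // the elements that recursive simplifications must not turn into
  // poison, even if their bits end up not being demanded.
  APInt DoNotPoisonSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
  // DoNotPoisonSrcElts is now 0b0011 (the two low i32 elements).
}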

Add DoNotPoisonEltMask to SimplifyDemandedVectorEltsForTargetNode
and try to handle it for a number of X86 opcodes. In some situations
we just fall back and assume that the DoNotPoisonEltMask elements
are demanded.

The goal is to reduce the number of regressions after the fix for #138513.
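
For target nodes without precise handling, the fallback mentioned above amounts to widening the demanded set (a hedged sketch of the pattern, not the verbatim X86 hook):

#include "llvm/ADT/APInt.h"
using namespace llvm;

// Conservative fallback: treat every "do not poison" element as
// demanded. This is always safe, since it can only block
// simplifications, never enable incorrect ones.
APInt conservativeDemandedElts(const APInt &DemandedElts,
                               const APInt &DoNotPoisonEltMask) {
  return DemandedElts | DoNotPoisonEltMask;
}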

Add DoNotPoisonEltMask to SimplifyMultipleUseDemandedBits and
SimplifyMultipleUseDemandedVectorElts.

The goal is to reduce the number of regressions after the fix for #138513.
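
As an example of what respecting the mask means in a look-through, here is a standalone restatement (a hedged sketch mirroring the shuffle handling in the diff below): a lane is skipped only if it is neither demanded nor protected, and identity checks must also hold for protected lanes:

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
using namespace llvm;

// Returns true if the shuffle acts as an identity on every lane that is
// either demanded or protected by DoNotPoisonEltMask.
bool isIdentityOnRelevantLanes(ArrayRef<int> ShuffleMask,
                               const APInt &DemandedElts,
                               const APInt &DoNotPoisonEltMask) {
  unsigned NumElts = DemandedElts.getBitWidth();
  for (unsigned i = 0; i != NumElts; ++i) {
    int M = ShuffleMask[i];
    if (M < 0 || (!DemandedElts[i] && !DoNotPoisonEltMask[i]))
      continue; // lane is neither used nor protected
    if (M != (int)i)
      return false;
  }
  return true;
}
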
llvmbot (Member) commented Jun 26, 2025

@llvm/pr-subscribers-backend-x86

Author: Björn Pettersson (bjope)

Changes

This PR currently contains several commits. It should perhaps be split into several pull requests.


Patch is 1.15 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145903.diff

112 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+24-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+4-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+225-125)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+117-84)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.h (+9-12)
  • (modified) llvm/test/CodeGen/AMDGPU/load-constant-i1.ll (+278-324)
  • (modified) llvm/test/CodeGen/AMDGPU/load-constant-i16.ll (+154-184)
  • (modified) llvm/test/CodeGen/AMDGPU/load-constant-i8.ll (+445-517)
  • (modified) llvm/test/CodeGen/AMDGPU/shift-i128.ll (+7-9)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v2bf16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v3bf16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v4bf16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v2f16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v3f16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v4f16.ll (+15-30)
  • (modified) llvm/test/CodeGen/ARM/fpclamptosat_vec.ll (+210-234)
  • (modified) llvm/test/CodeGen/Thumb2/mve-fpclamptosat_vec.ll (+38-42)
  • (modified) llvm/test/CodeGen/Thumb2/mve-gather-ind8-unscaled.ll (+5)
  • (modified) llvm/test/CodeGen/Thumb2/mve-laneinterleaving.ll (+41-45)
  • (modified) llvm/test/CodeGen/Thumb2/mve-pred-ext.ll (-1)
  • (modified) llvm/test/CodeGen/Thumb2/mve-satmul-loops.ll (+95-118)
  • (modified) llvm/test/CodeGen/Thumb2/mve-scatter-ind8-unscaled.ll (+7-2)
  • (modified) llvm/test/CodeGen/Thumb2/mve-vecreduce-addpred.ll (+4-4)
  • (modified) llvm/test/CodeGen/Thumb2/mve-vecreduce-mlapred.ll (+4-4)
  • (modified) llvm/test/CodeGen/Thumb2/mve-vst3.ll (+19-19)
  • (modified) llvm/test/CodeGen/X86/avx512-intrinsics-fast-isel.ll (+2-2)
  • (modified) llvm/test/CodeGen/X86/avx512-intrinsics-upgrade.ll (+14-32)
  • (modified) llvm/test/CodeGen/X86/avx512fp16-mov.ll (+16-16)
  • (modified) llvm/test/CodeGen/X86/avx512vl-intrinsics-upgrade.ll (+14-38)
  • (modified) llvm/test/CodeGen/X86/avx512vl-vec-masked-cmp.ll (+30-60)
  • (modified) llvm/test/CodeGen/X86/bitcast-and-setcc-128.ll (+12-16)
  • (modified) llvm/test/CodeGen/X86/bitcast-setcc-128.ll (+5-7)
  • (modified) llvm/test/CodeGen/X86/bitcast-vector-bool.ll (-4)
  • (modified) llvm/test/CodeGen/X86/buildvec-widen-dotproduct.ll (+15-19)
  • (modified) llvm/test/CodeGen/X86/combine-pmuldq.ll (+18-69)
  • (modified) llvm/test/CodeGen/X86/combine-sdiv.ll (+19-20)
  • (modified) llvm/test/CodeGen/X86/combine-sra.ll (+31-32)
  • (modified) llvm/test/CodeGen/X86/combine-udiv.ll (+1-4)
  • (modified) llvm/test/CodeGen/X86/extractelement-load.ll (+5-4)
  • (modified) llvm/test/CodeGen/X86/f16c-intrinsics-fast-isel.ll (-4)
  • (modified) llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll (+45-44)
  • (modified) llvm/test/CodeGen/X86/gfni-funnel-shifts.ll (+150-150)
  • (modified) llvm/test/CodeGen/X86/half.ll (+3-3)
  • (modified) llvm/test/CodeGen/X86/hoist-and-by-const-from-shl-in-eqcmp-zero.ll (+28-20)
  • (modified) llvm/test/CodeGen/X86/known-never-zero.ll (+1-2)
  • (modified) llvm/test/CodeGen/X86/known-pow2.ll (+6-6)
  • (modified) llvm/test/CodeGen/X86/known-signbits-shl.ll (+1-2)
  • (modified) llvm/test/CodeGen/X86/known-signbits-vector.ll (+7-23)
  • (modified) llvm/test/CodeGen/X86/masked_store.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/movmsk-cmp.ll (+8-35)
  • (modified) llvm/test/CodeGen/X86/mulvi32.ll (+7-8)
  • (modified) llvm/test/CodeGen/X86/omit-urem-of-power-of-two-or-zero-when-comparing-with-zero.ll (+11-14)
  • (modified) llvm/test/CodeGen/X86/pmul.ll (+66-74)
  • (modified) llvm/test/CodeGen/X86/pmulh.ll (+2-4)
  • (modified) llvm/test/CodeGen/X86/pr107423.ll (+13-13)
  • (modified) llvm/test/CodeGen/X86/pr35918.ll (+2-2)
  • (modified) llvm/test/CodeGen/X86/pr41619.ll (-2)
  • (modified) llvm/test/CodeGen/X86/pr42727.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/pr45563-2.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/pr45833.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/pr77459.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/psubus.ll (+123-132)
  • (modified) llvm/test/CodeGen/X86/rotate-extract-vector.ll (-2)
  • (modified) llvm/test/CodeGen/X86/sadd_sat_vec.ll (+93-114)
  • (modified) llvm/test/CodeGen/X86/sat-add.ll (+9-12)
  • (modified) llvm/test/CodeGen/X86/sdiv-exact.ll (+13-17)
  • (modified) llvm/test/CodeGen/X86/shrink_vmul.ll (+20-14)
  • (modified) llvm/test/CodeGen/X86/srem-seteq-vec-nonsplat.ll (+9-11)
  • (modified) llvm/test/CodeGen/X86/sshl_sat_vec.ll (+38-38)
  • (modified) llvm/test/CodeGen/X86/ssub_sat_vec.ll (+123-151)
  • (modified) llvm/test/CodeGen/X86/test-shrink-bug.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/udiv-exact.ll (+13-17)
  • (modified) llvm/test/CodeGen/X86/urem-seteq-illegal-types.ll (+6-4)
  • (modified) llvm/test/CodeGen/X86/urem-seteq-vec-nonsplat.ll (+122-172)
  • (modified) llvm/test/CodeGen/X86/ushl_sat_vec.ll (+14-15)
  • (modified) llvm/test/CodeGen/X86/vec_minmax_sint.ll (+66-92)
  • (modified) llvm/test/CodeGen/X86/vec_minmax_uint.ll (+66-92)
  • (modified) llvm/test/CodeGen/X86/vec_smulo.ll (+32-38)
  • (modified) llvm/test/CodeGen/X86/vec_umulo.ll (+60-70)
  • (modified) llvm/test/CodeGen/X86/vector-compare-all_of.ll (+15-21)
  • (modified) llvm/test/CodeGen/X86/vector-compare-any_of.ll (+15-21)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-128.ll (+139-159)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-256.ll (+47-48)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-rot-128.ll (+37-56)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-rot-256.ll (+14-15)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-128.ll (+27-33)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-256.ll (+19-20)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-rot-128.ll (+28-33)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-rot-256.ll (+2-2)
  • (modified) llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll (+839-863)
  • (modified) llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-8.ll (+214-222)
  • (modified) llvm/test/CodeGen/X86/vector-mul.ll (+34-34)
  • (modified) llvm/test/CodeGen/X86/vector-pcmp.ll (+1-2)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-fmaximum.ll (+81-82)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-mul.ll (+57-113)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-smax.ll (+69-99)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-smin.ll (+65-96)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-umax.ll (+69-99)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-umin.ll (+65-96)
  • (modified) llvm/test/CodeGen/X86/vector-rotate-128.ll (+37-56)
  • (modified) llvm/test/CodeGen/X86/vector-rotate-256.ll (+14-15)
  • (modified) llvm/test/CodeGen/X86/vector-shift-shl-128.ll (+24-28)
  • (modified) llvm/test/CodeGen/X86/vector-shift-shl-256.ll (+20-22)
  • (modified) llvm/test/CodeGen/X86/vector-shift-shl-sub128.ll (+48-56)
  • (modified) llvm/test/CodeGen/X86/vector-shuffle-combining-avx.ll (+6-6)
  • (modified) llvm/test/CodeGen/X86/vector-shuffle-combining-ssse3.ll (+6-3)
  • (modified) llvm/test/CodeGen/X86/vector-shuffle-combining.ll (+3-6)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-packus.ll (+806-945)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-ssat.ll (+609-734)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-usat.ll (+368-398)
  • (modified) llvm/test/CodeGen/X86/vector_splat-const-shift-of-constmasked.ll (+5-4)
  • (modified) llvm/test/CodeGen/X86/vselect.ll (+18-8)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 727526055e592..191f4cc78fcc5 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4187,6 +4187,16 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// More limited version of SimplifyDemandedBits that can be used to "look
   /// through" ops that don't contribute to the DemandedBits/DemandedElts -
   /// bitwise ops etc.
+  /// Vector elements that aren't demanded can be turned into poison unless the
+  /// corresponding bit in the \p DoNotPoisonEltMask is set.
+  SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits,
+                                          const APInt &DemandedElts,
+                                          const APInt &DoNotPoisonEltMask,
+                                          SelectionDAG &DAG,
+                                          unsigned Depth = 0) const;
+
+  /// Helper wrapper around SimplifyMultipleUseDemandedBits, with
+  /// DoNotPoisonEltMask being set to zero.
   SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits,
                                           const APInt &DemandedElts,
                                           SelectionDAG &DAG,
@@ -4202,6 +4212,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// bits from only some vector elements.
   SDValue SimplifyMultipleUseDemandedVectorElts(SDValue Op,
                                                 const APInt &DemandedElts,
+                                                const APInt &DoNotPoisonEltMask,
                                                 SelectionDAG &DAG,
                                                 unsigned Depth = 0) const;
 
@@ -4219,6 +4230,15 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   ///    results of this function, because simply replacing TLO.Old
   ///    with TLO.New will be incorrect when this parameter is true and TLO.Old
   ///    has multiple uses.
+  /// Vector elements that aren't demanded can be turned into poison unless the
+  /// corresponding bit in \p DoNotPoisonEltMask is set.
+  bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedEltMask,
+                                  const APInt &DoNotPoisonEltMask,
+                                  APInt &KnownUndef, APInt &KnownZero,
+                                  TargetLoweringOpt &TLO, unsigned Depth = 0,
+                                  bool AssumeSingleUse = false) const;
+  /// Version of SimplifyDemandedVectorElts without the DoNotPoisonEltMask
+  /// argument. All undemanded elements can be turned into poison.
   bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedEltMask,
                                   APInt &KnownUndef, APInt &KnownZero,
                                   TargetLoweringOpt &TLO, unsigned Depth = 0,
@@ -4303,8 +4323,9 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// (used to simplify the caller). The KnownUndef/Zero elements may only be
   /// accurate for those bits in the DemandedMask.
   virtual bool SimplifyDemandedVectorEltsForTargetNode(
-      SDValue Op, const APInt &DemandedElts, APInt &KnownUndef,
-      APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth = 0) const;
+      SDValue Op, const APInt &DemandedElts, const APInt &DoNotPoisonEltMask,
+      APInt &KnownUndef, APInt &KnownZero, TargetLoweringOpt &TLO,
+      unsigned Depth = 0) const;
 
   /// Attempt to simplify any target nodes based on the demanded bits/elts,
   /// returning true on success. Otherwise, analyze the
@@ -4323,6 +4344,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// bitwise ops etc.
   virtual SDValue SimplifyMultipleUseDemandedBitsForTargetNode(
       SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+      const APInt &DoNotPoisonEltMask,
       SelectionDAG &DAG, unsigned Depth) const;
 
   /// Return true if this function can prove that \p Op is never poison
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c8088d68b6f1b..1f53742b2c6f7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1457,8 +1457,10 @@ bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
                                              bool AssumeSingleUse) {
   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
   APInt KnownUndef, KnownZero;
-  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
-                                      TLO, 0, AssumeSingleUse))
+  APInt DoNotPoisonElts = APInt::getZero(DemandedElts.getBitWidth());
+  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, DoNotPoisonElts,
+                                      KnownUndef, KnownZero, TLO, 0,
+                                      AssumeSingleUse))
     return false;
 
   // Revisit the node.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 524c97ab3eab8..931154b922924 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -679,7 +679,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
 // TODO: Under what circumstances can we create nodes? Constant folding?
 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
-    SelectionDAG &DAG, unsigned Depth) const {
+    const APInt &DoNotPoisonEltMask, SelectionDAG &DAG, unsigned Depth) const {
   EVT VT = Op.getValueType();
 
   // Limit search depth.
@@ -713,9 +713,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     unsigned NumDstEltBits = DstVT.getScalarSizeInBits();
     if (NumSrcEltBits == NumDstEltBits)
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedBits, DemandedElts, DAG, Depth + 1))
+              Src, DemandedBits, DemandedElts, DoNotPoisonEltMask, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
 
+    // FIXME: Handle DoNotPoisonEltMask better?
+
     if (SrcVT.isVector() && (NumDstEltBits % NumSrcEltBits) == 0) {
       unsigned Scale = NumDstEltBits / NumSrcEltBits;
       unsigned NumSrcElts = SrcVT.getVectorNumElements();
@@ -731,9 +734,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
       // destination element, since recursive calls below may turn not demanded
       // elements into poison.
       APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      APInt DoNotPoisonSrcElts =
+          APIntOps::ScaleBitMask(DoNotPoisonEltMask, NumSrcElts);
 
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
+              Src, DemandedSrcBits, DemandedSrcElts, DoNotPoisonSrcElts, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
     }
 
@@ -743,15 +749,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
       unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
       APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
+      APInt DoNotPoisonSrcElts = APInt::getZero(NumSrcElts);
       for (unsigned i = 0; i != NumElts; ++i)
         if (DemandedElts[i]) {
           unsigned Offset = (i % Scale) * NumDstEltBits;
           DemandedSrcBits.insertBits(DemandedBits, Offset);
           DemandedSrcElts.setBit(i / Scale);
+        } else if (DoNotPoisonEltMask[i]) {
+          DoNotPoisonSrcElts.setBit(i / Scale);
         }
 
+      // FIXME: Handle DoNotPoisonEltMask better?
+
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
+              Src, DemandedSrcBits, DemandedSrcElts, DoNotPoisonSrcElts, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
     }
 
@@ -759,7 +771,8 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   }
   case ISD::FREEZE: {
     SDValue N0 = Op.getOperand(0);
-    if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
+    if (DAG.isGuaranteedNotToBeUndefOrPoison(N0,
+                                             DemandedElts | DoNotPoisonEltMask,
                                              /*PoisonOnly=*/false))
       return N0;
     break;
@@ -815,12 +828,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SHL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (std::optional<uint64_t> MaxSA =
-            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+    if (std::optional<uint64_t> MaxSA = DAG.getValidMaximumShiftAmount(
+            Op, DemandedElts | DoNotPoisonEltMask, Depth + 1)) {
       SDValue Op0 = Op.getOperand(0);
       unsigned ShAmt = *MaxSA;
-      unsigned NumSignBits =
-          DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+      unsigned NumSignBits = DAG.ComputeNumSignBits(
+          Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
       if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
         return Op0;
@@ -830,15 +843,15 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SRL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (std::optional<uint64_t> MaxSA =
-            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+    if (std::optional<uint64_t> MaxSA = DAG.getValidMaximumShiftAmount(
+            Op, DemandedElts | DoNotPoisonEltMask, Depth + 1)) {
       SDValue Op0 = Op.getOperand(0);
       unsigned ShAmt = *MaxSA;
       // Must already be signbits in DemandedBits bounds, and can't demand any
       // shifted in zeroes.
       if (DemandedBits.countl_zero() >= ShAmt) {
-        unsigned NumSignBits =
-            DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+        unsigned NumSignBits = DAG.ComputeNumSignBits(
+            Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
         if (DemandedBits.countr_zero() >= (BitWidth - NumSignBits))
           return Op0;
       }
@@ -875,7 +888,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
         shouldRemoveRedundantExtend(Op))
       return Op0;
     // If the input is already sign extended, just drop the extension.
-    unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+    // FIXME: Can we skip DoNotPoisonEltMask here?
+    unsigned NumSignBits = DAG.ComputeNumSignBits(
+        Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
     if (NumSignBits >= (BitWidth - ExBits + 1))
       return Op0;
     break;
@@ -891,7 +906,8 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     EVT DstVT = Op.getValueType();
-    if (IsLE && DemandedElts == 1 &&
+    // FIXME: Can we skip DoNotPoisonEltMask here?
+    if (IsLE && (DemandedElts | DoNotPoisonEltMask) == 1 &&
         DstVT.getSizeInBits() == SrcVT.getSizeInBits() &&
         DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) {
       return DAG.getBitcast(DstVT, Src);
@@ -906,8 +922,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Vec = Op.getOperand(0);
     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
     EVT VecVT = Vec.getValueType();
+    // FIXME: Handle DoNotPoisonEltMask better.
     if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
-        !DemandedElts[CIdx->getZExtValue()])
+        !DemandedElts[CIdx->getZExtValue()] &&
+        !DoNotPoisonEltMask[CIdx->getZExtValue()])
       return Vec;
     break;
   }
@@ -920,8 +938,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     uint64_t Idx = Op.getConstantOperandVal(2);
     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
+    APInt DoNotPoisonSubElts = DoNotPoisonEltMask.extractBits(NumSubElts, Idx);
     // If we don't demand the inserted subvector, return the base vector.
-    if (DemandedSubElts == 0)
+    // FIXME: Handle DoNotPoisonEltMask better.
+    if (DemandedSubElts == 0 && DoNotPoisonSubElts == 0)
       return Vec;
     break;
   }
@@ -934,9 +954,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     bool AllUndef = true, IdentityLHS = true, IdentityRHS = true;
     for (unsigned i = 0; i != NumElts; ++i) {
       int M = ShuffleMask[i];
-      if (M < 0 || !DemandedElts[i])
+      if (M < 0 || (!DemandedElts[i] && !DoNotPoisonEltMask[i]))
         continue;
-      AllUndef = false;
+      if (DemandedElts[i])
+        AllUndef = false;
       IdentityLHS &= (M == (int)i);
       IdentityRHS &= ((M - NumElts) == i);
     }
@@ -957,13 +978,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
 
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
       if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
-              Op, DemandedBits, DemandedElts, DAG, Depth))
+              Op, DemandedBits, DemandedElts, DoNotPoisonEltMask, DAG, Depth))
         return V;
     break;
   }
   return SDValue();
 }
 
+SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
+    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+    SelectionDAG &DAG, unsigned Depth) const {
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts,
+                                         DoNotPoisonEltMask, DAG, Depth);
+}
+
 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
     unsigned Depth) const {
@@ -974,13 +1003,14 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
-                                         Depth);
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts,
+                                         DoNotPoisonEltMask, DAG, Depth);
 }
 
 SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
-    SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG,
-    unsigned Depth) const {
+    SDValue Op, const APInt &DemandedElts, const APInt &DoNotPoisonEltMask,
+    SelectionDAG &DAG, unsigned Depth) const {
   APInt DemandedBits = APInt::getAllOnes(Op.getScalarValueSizeInBits());
   return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
                                          Depth);
@@ -2756,34 +2786,51 @@ bool TargetLowering::SimplifyDemandedBits(
                              TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt));
       }
     }
-
     // Bitcast from a vector using SimplifyDemanded Bits/VectorElts.
     // Demand the elt/bit if any of the original elts/bits are demanded.
     if (SrcVT.isVector() && (BitWidth % NumSrcEltBits) == 0) {
       unsigned Scale = BitWidth / NumSrcEltBits;
       unsigned NumSrcElts = SrcVT.getVectorNumElements();
       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
+      APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
       for (unsigned i = 0; i != Scale; ++i) {
         unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
         unsigned BitOffset = EltOffset * NumSrcEltBits;
         APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
-        if (!Sub.isZero())
+        if (!Sub.isZero()) {
           DemandedSrcBits |= Sub;
+          for (unsigned j = 0; j != NumElts; ++j)
+            if (DemandedElts[j])
+              DemandedSrcElts.setBit((j * Scale) + i);
+        }
       }
-      // Need to demand all smaller source elements that maps to a demanded
-      // destination element, since recursive calls below may turn not demanded
-      // elements into poison.
-      APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      // Need to "semi demand" all smaller source elements that maps to a
+      // demanded destination element, since recursive calls below may turn not
+      // demanded elements into poison. Instead of demanding such elements we
+      // use a special bitmask to indicate that the recursive calls must not
+      // turn such elements into poison.
+      APInt NoPoisonSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
 
       APInt KnownSrcUndef, KnownSrcZero;
-      if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
-                                     KnownSrcZero, TLO, Depth + 1))
+      if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, NoPoisonSrcElts,
+                                     KnownSrcUndef, KnownSrcZero, TLO,
+                                     Depth + 1))
         return true;
 
       KnownBits KnownSrcBits;
-      if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
-                               KnownSrcBits, TLO, Depth + 1))
+      if (SimplifyDemandedBits(Src, DemandedSrcBits,
+                               DemandedSrcElts | NoPoisonSrcElts, KnownSrcBits,
+                               TLO, Depth + 1))
         return true;
+
+      if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
+        if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
+                Src, DemandedSrcBits, DemandedSrcElts, NoPoisonSrcElts, TLO.DAG,
+                Depth + 1)) {
+          SDValue NewOp = TLO.DAG.getBitcast(VT, DemandedSrc);
+          return TLO.CombineTo(Op, NewOp);
+        }
+      }
     } else if (IsLE && (NumSrcEltBits % BitWidth) == 0) {
       // TODO - bigendian once we have test coverage.
       unsigned Scale = NumSrcEltBits / BitWidth;
@@ -3090,8 +3137,9 @@ bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op,
                         !DCI.isBeforeLegalizeOps());
 
   APInt KnownUndef, KnownZero;
-  bool Simplified =
-      SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO);
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  bool Simplified = SimplifyDemandedVectorElts(
+      Op, DemandedElts, DoNotPoisonEltMask, KnownUndef, KnownZero, TLO);
   if (Simplified) {
     DCI.AddToWorklist(Op.getNode());
     DCI.CommitTargetLoweringOpt(TLO);
@@ -3152,6 +3200,16 @@ bool TargetLowering::SimplifyDemandedVectorElts(
     SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef,
     APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth,
     bool AssumeSingleUse) const {
+  APInt DoNotPoisonEltMask = APInt::getZero(OriginalDemandedElts.getBitWidth());
+  return SimplifyDemandedVectorElts(Op, OriginalDemandedElts,
+                                    DoNotPoisonEltMask, KnownUndef, KnownZero,
+                                    TLO, Depth, AssumeSingleUse);
+}
+
+bool TargetLowering::SimplifyDemandedVectorElts(
+    SDValue Op, const APInt &OriginalDemandedElts,
+    const APInt &DoNotPoisonEltMask, APInt &KnownUndef, APInt &KnownZero,
+    TargetLoweringOpt &TLO, unsigned Depth, bool AssumeSingleUse) const {
   EVT VT = Op.getValueType();
   unsigned Opcode = Op.getOpcode();
   APInt DemandedElts = OriginalDemandedElts;
@@ -3190,6 +3248,7 @@ bool TargetLowering::SimplifyDemandedVectorElts(
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return false;
 
+  APInt DemandedEltsInclDoNotPoison = DemandedElts | DoNotPoisonEltMask;
   SDLoc DL(Op);
   unsigned EltSizeInBits = VT.getScalarSizeInBits();
   bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
@@ -3197,10 +3256,10 @@ bool TargetLowering::SimplifyDemandedVectorElts(
   // Helper for demanding the specified elements and all the bits of both binary
   // operands.
   auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) {
-    SDValue NewOp0 = S...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Björn Pettersson (bjope)

Changes

This is currently several commits in one PR. Should perhaps be splitted in several pull requests.


Patch is 1.15 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145903.diff

112 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+24-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+4-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+225-125)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+117-84)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.h (+9-12)
  • (modified) llvm/test/CodeGen/AMDGPU/load-constant-i1.ll (+278-324)
  • (modified) llvm/test/CodeGen/AMDGPU/load-constant-i16.ll (+154-184)
  • (modified) llvm/test/CodeGen/AMDGPU/load-constant-i8.ll (+445-517)
  • (modified) llvm/test/CodeGen/AMDGPU/shift-i128.ll (+7-9)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v2bf16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v3bf16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v4bf16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v2f16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v3f16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v4f16.ll (+15-30)
  • (modified) llvm/test/CodeGen/ARM/fpclamptosat_vec.ll (+210-234)
  • (modified) llvm/test/CodeGen/Thumb2/mve-fpclamptosat_vec.ll (+38-42)
  • (modified) llvm/test/CodeGen/Thumb2/mve-gather-ind8-unscaled.ll (+5)
  • (modified) llvm/test/CodeGen/Thumb2/mve-laneinterleaving.ll (+41-45)
  • (modified) llvm/test/CodeGen/Thumb2/mve-pred-ext.ll (-1)
  • (modified) llvm/test/CodeGen/Thumb2/mve-satmul-loops.ll (+95-118)
  • (modified) llvm/test/CodeGen/Thumb2/mve-scatter-ind8-unscaled.ll (+7-2)
  • (modified) llvm/test/CodeGen/Thumb2/mve-vecreduce-addpred.ll (+4-4)
  • (modified) llvm/test/CodeGen/Thumb2/mve-vecreduce-mlapred.ll (+4-4)
  • (modified) llvm/test/CodeGen/Thumb2/mve-vst3.ll (+19-19)
  • (modified) llvm/test/CodeGen/X86/avx512-intrinsics-fast-isel.ll (+2-2)
  • (modified) llvm/test/CodeGen/X86/avx512-intrinsics-upgrade.ll (+14-32)
  • (modified) llvm/test/CodeGen/X86/avx512fp16-mov.ll (+16-16)
  • (modified) llvm/test/CodeGen/X86/avx512vl-intrinsics-upgrade.ll (+14-38)
  • (modified) llvm/test/CodeGen/X86/avx512vl-vec-masked-cmp.ll (+30-60)
  • (modified) llvm/test/CodeGen/X86/bitcast-and-setcc-128.ll (+12-16)
  • (modified) llvm/test/CodeGen/X86/bitcast-setcc-128.ll (+5-7)
  • (modified) llvm/test/CodeGen/X86/bitcast-vector-bool.ll (-4)
  • (modified) llvm/test/CodeGen/X86/buildvec-widen-dotproduct.ll (+15-19)
  • (modified) llvm/test/CodeGen/X86/combine-pmuldq.ll (+18-69)
  • (modified) llvm/test/CodeGen/X86/combine-sdiv.ll (+19-20)
  • (modified) llvm/test/CodeGen/X86/combine-sra.ll (+31-32)
  • (modified) llvm/test/CodeGen/X86/combine-udiv.ll (+1-4)
  • (modified) llvm/test/CodeGen/X86/extractelement-load.ll (+5-4)
  • (modified) llvm/test/CodeGen/X86/f16c-intrinsics-fast-isel.ll (-4)
  • (modified) llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll (+45-44)
  • (modified) llvm/test/CodeGen/X86/gfni-funnel-shifts.ll (+150-150)
  • (modified) llvm/test/CodeGen/X86/half.ll (+3-3)
  • (modified) llvm/test/CodeGen/X86/hoist-and-by-const-from-shl-in-eqcmp-zero.ll (+28-20)
  • (modified) llvm/test/CodeGen/X86/known-never-zero.ll (+1-2)
  • (modified) llvm/test/CodeGen/X86/known-pow2.ll (+6-6)
  • (modified) llvm/test/CodeGen/X86/known-signbits-shl.ll (+1-2)
  • (modified) llvm/test/CodeGen/X86/known-signbits-vector.ll (+7-23)
  • (modified) llvm/test/CodeGen/X86/masked_store.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/movmsk-cmp.ll (+8-35)
  • (modified) llvm/test/CodeGen/X86/mulvi32.ll (+7-8)
  • (modified) llvm/test/CodeGen/X86/omit-urem-of-power-of-two-or-zero-when-comparing-with-zero.ll (+11-14)
  • (modified) llvm/test/CodeGen/X86/pmul.ll (+66-74)
  • (modified) llvm/test/CodeGen/X86/pmulh.ll (+2-4)
  • (modified) llvm/test/CodeGen/X86/pr107423.ll (+13-13)
  • (modified) llvm/test/CodeGen/X86/pr35918.ll (+2-2)
  • (modified) llvm/test/CodeGen/X86/pr41619.ll (-2)
  • (modified) llvm/test/CodeGen/X86/pr42727.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/pr45563-2.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/pr45833.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/pr77459.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/psubus.ll (+123-132)
  • (modified) llvm/test/CodeGen/X86/rotate-extract-vector.ll (-2)
  • (modified) llvm/test/CodeGen/X86/sadd_sat_vec.ll (+93-114)
  • (modified) llvm/test/CodeGen/X86/sat-add.ll (+9-12)
  • (modified) llvm/test/CodeGen/X86/sdiv-exact.ll (+13-17)
  • (modified) llvm/test/CodeGen/X86/shrink_vmul.ll (+20-14)
  • (modified) llvm/test/CodeGen/X86/srem-seteq-vec-nonsplat.ll (+9-11)
  • (modified) llvm/test/CodeGen/X86/sshl_sat_vec.ll (+38-38)
  • (modified) llvm/test/CodeGen/X86/ssub_sat_vec.ll (+123-151)
  • (modified) llvm/test/CodeGen/X86/test-shrink-bug.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/udiv-exact.ll (+13-17)
  • (modified) llvm/test/CodeGen/X86/urem-seteq-illegal-types.ll (+6-4)
  • (modified) llvm/test/CodeGen/X86/urem-seteq-vec-nonsplat.ll (+122-172)
  • (modified) llvm/test/CodeGen/X86/ushl_sat_vec.ll (+14-15)
  • (modified) llvm/test/CodeGen/X86/vec_minmax_sint.ll (+66-92)
  • (modified) llvm/test/CodeGen/X86/vec_minmax_uint.ll (+66-92)
  • (modified) llvm/test/CodeGen/X86/vec_smulo.ll (+32-38)
  • (modified) llvm/test/CodeGen/X86/vec_umulo.ll (+60-70)
  • (modified) llvm/test/CodeGen/X86/vector-compare-all_of.ll (+15-21)
  • (modified) llvm/test/CodeGen/X86/vector-compare-any_of.ll (+15-21)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-128.ll (+139-159)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-256.ll (+47-48)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-rot-128.ll (+37-56)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-rot-256.ll (+14-15)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-128.ll (+27-33)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-256.ll (+19-20)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-rot-128.ll (+28-33)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-rot-256.ll (+2-2)
  • (modified) llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll (+839-863)
  • (modified) llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-8.ll (+214-222)
  • (modified) llvm/test/CodeGen/X86/vector-mul.ll (+34-34)
  • (modified) llvm/test/CodeGen/X86/vector-pcmp.ll (+1-2)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-fmaximum.ll (+81-82)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-mul.ll (+57-113)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-smax.ll (+69-99)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-smin.ll (+65-96)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-umax.ll (+69-99)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-umin.ll (+65-96)
  • (modified) llvm/test/CodeGen/X86/vector-rotate-128.ll (+37-56)
  • (modified) llvm/test/CodeGen/X86/vector-rotate-256.ll (+14-15)
  • (modified) llvm/test/CodeGen/X86/vector-shift-shl-128.ll (+24-28)
  • (modified) llvm/test/CodeGen/X86/vector-shift-shl-256.ll (+20-22)
  • (modified) llvm/test/CodeGen/X86/vector-shift-shl-sub128.ll (+48-56)
  • (modified) llvm/test/CodeGen/X86/vector-shuffle-combining-avx.ll (+6-6)
  • (modified) llvm/test/CodeGen/X86/vector-shuffle-combining-ssse3.ll (+6-3)
  • (modified) llvm/test/CodeGen/X86/vector-shuffle-combining.ll (+3-6)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-packus.ll (+806-945)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-ssat.ll (+609-734)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-usat.ll (+368-398)
  • (modified) llvm/test/CodeGen/X86/vector_splat-const-shift-of-constmasked.ll (+5-4)
  • (modified) llvm/test/CodeGen/X86/vselect.ll (+18-8)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 727526055e592..191f4cc78fcc5 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4187,6 +4187,16 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// More limited version of SimplifyDemandedBits that can be used to "look
   /// through" ops that don't contribute to the DemandedBits/DemandedElts -
   /// bitwise ops etc.
+  /// Vector elements that aren't demanded can be turned into poison unless the
+  /// corresponding bit in the \p DoNotPoisonEltMask is set.
+  SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits,
+                                          const APInt &DemandedElts,
+                                          const APInt &DoNotPoisonEltMask,
+                                          SelectionDAG &DAG,
+                                          unsigned Depth = 0) const;
+
+  /// Helper wrapper around SimplifyMultipleUseDemandedBits, with
+  /// DoNotPoisonEltMask being set to zero.
   SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits,
                                           const APInt &DemandedElts,
                                           SelectionDAG &DAG,
@@ -4202,6 +4212,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// bits from only some vector elements.
   SDValue SimplifyMultipleUseDemandedVectorElts(SDValue Op,
                                                 const APInt &DemandedElts,
+                                                const APInt &DoNotPoisonEltMask,
                                                 SelectionDAG &DAG,
                                                 unsigned Depth = 0) const;
 
@@ -4219,6 +4230,15 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   ///    results of this function, because simply replacing TLO.Old
   ///    with TLO.New will be incorrect when this parameter is true and TLO.Old
   ///    has multiple uses.
+  /// Vector elements that aren't demanded can be turned into poison unless the
+  /// corresponding bit in \p DoNotPoisonEltMask is set.
+  bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedEltMask,
+                                  const APInt &DoNotPoisonEltMask,
+                                  APInt &KnownUndef, APInt &KnownZero,
+                                  TargetLoweringOpt &TLO, unsigned Depth = 0,
+                                  bool AssumeSingleUse = false) const;
+  /// Version of SimplifyDemandedVectorElts without the DoNotPoisonEltMask
+  /// argument. All undemanded elements can be turned into poison.
   bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedEltMask,
                                   APInt &KnownUndef, APInt &KnownZero,
                                   TargetLoweringOpt &TLO, unsigned Depth = 0,
@@ -4303,8 +4323,9 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// (used to simplify the caller). The KnownUndef/Zero elements may only be
   /// accurate for those bits in the DemandedMask.
   virtual bool SimplifyDemandedVectorEltsForTargetNode(
-      SDValue Op, const APInt &DemandedElts, APInt &KnownUndef,
-      APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth = 0) const;
+      SDValue Op, const APInt &DemandedElts, const APInt &DoNotPoisonEltMask,
+      APInt &KnownUndef, APInt &KnownZero, TargetLoweringOpt &TLO,
+      unsigned Depth = 0) const;
 
   /// Attempt to simplify any target nodes based on the demanded bits/elts,
   /// returning true on success. Otherwise, analyze the
@@ -4323,6 +4344,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// bitwise ops etc.
   virtual SDValue SimplifyMultipleUseDemandedBitsForTargetNode(
       SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+      const APInt &DoNotPoisonEltMask,
       SelectionDAG &DAG, unsigned Depth) const;
 
   /// Return true if this function can prove that \p Op is never poison
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c8088d68b6f1b..1f53742b2c6f7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1457,8 +1457,10 @@ bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
                                              bool AssumeSingleUse) {
   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
   APInt KnownUndef, KnownZero;
-  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
-                                      TLO, 0, AssumeSingleUse))
+  APInt DoNotPoisonElts = APInt::getZero(DemandedElts.getBitWidth());
+  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, DoNotPoisonElts,
+                                      KnownUndef, KnownZero, TLO, 0,
+                                      AssumeSingleUse))
     return false;
 
   // Revisit the node.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 524c97ab3eab8..931154b922924 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -679,7 +679,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
 // TODO: Under what circumstances can we create nodes? Constant folding?
 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
-    SelectionDAG &DAG, unsigned Depth) const {
+    const APInt &DoNotPoisonEltMask, SelectionDAG &DAG, unsigned Depth) const {
   EVT VT = Op.getValueType();
 
   // Limit search depth.
@@ -713,9 +713,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     unsigned NumDstEltBits = DstVT.getScalarSizeInBits();
     if (NumSrcEltBits == NumDstEltBits)
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedBits, DemandedElts, DAG, Depth + 1))
+              Src, DemandedBits, DemandedElts, DoNotPoisonEltMask, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
 
+    // FIXME: Handle DoNotPoisonEltMask better?
+
     if (SrcVT.isVector() && (NumDstEltBits % NumSrcEltBits) == 0) {
       unsigned Scale = NumDstEltBits / NumSrcEltBits;
       unsigned NumSrcElts = SrcVT.getVectorNumElements();
@@ -731,9 +734,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
       // destination element, since recursive calls below may turn not demanded
       // elements into poison.
       APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      APInt DoNotPoisonSrcElts =
+          APIntOps::ScaleBitMask(DoNotPoisonEltMask, NumSrcElts);
 
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
+              Src, DemandedSrcBits, DemandedSrcElts, DoNotPoisonSrcElts, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
     }
 
@@ -743,15 +749,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
       unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
       APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
+      APInt DoNotPoisonSrcElts = APInt::getZero(NumSrcElts);
       for (unsigned i = 0; i != NumElts; ++i)
         if (DemandedElts[i]) {
           unsigned Offset = (i % Scale) * NumDstEltBits;
           DemandedSrcBits.insertBits(DemandedBits, Offset);
           DemandedSrcElts.setBit(i / Scale);
+        } else if (DoNotPoisonEltMask[i]) {
+          DoNotPoisonSrcElts.setBit(i / Scale);
         }
 
+      // FIXME: Handle DoNotPoisonEltMask better?
+
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
+              Src, DemandedSrcBits, DemandedSrcElts, DoNotPoisonSrcElts, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
     }
 
@@ -759,7 +771,8 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   }
   case ISD::FREEZE: {
     SDValue N0 = Op.getOperand(0);
-    if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
+    if (DAG.isGuaranteedNotToBeUndefOrPoison(N0,
+                                             DemandedElts | DoNotPoisonEltMask,
                                              /*PoisonOnly=*/false))
       return N0;
     break;
@@ -815,12 +828,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SHL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (std::optional<uint64_t> MaxSA =
-            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+    if (std::optional<uint64_t> MaxSA = DAG.getValidMaximumShiftAmount(
+            Op, DemandedElts | DoNotPoisonEltMask, Depth + 1)) {
       SDValue Op0 = Op.getOperand(0);
       unsigned ShAmt = *MaxSA;
-      unsigned NumSignBits =
-          DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+      unsigned NumSignBits = DAG.ComputeNumSignBits(
+          Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
       if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
         return Op0;
@@ -830,15 +843,15 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SRL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (std::optional<uint64_t> MaxSA =
-            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+    if (std::optional<uint64_t> MaxSA = DAG.getValidMaximumShiftAmount(
+            Op, DemandedElts | DoNotPoisonEltMask, Depth + 1)) {
       SDValue Op0 = Op.getOperand(0);
       unsigned ShAmt = *MaxSA;
       // Must already be signbits in DemandedBits bounds, and can't demand any
       // shifted in zeroes.
       if (DemandedBits.countl_zero() >= ShAmt) {
-        unsigned NumSignBits =
-            DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+        unsigned NumSignBits = DAG.ComputeNumSignBits(
+            Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
         if (DemandedBits.countr_zero() >= (BitWidth - NumSignBits))
           return Op0;
       }
@@ -875,7 +888,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
         shouldRemoveRedundantExtend(Op))
       return Op0;
     // If the input is already sign extended, just drop the extension.
-    unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+    // FIXME: Can we skip DoNotPoisonEltMask here?
+    unsigned NumSignBits = DAG.ComputeNumSignBits(
+        Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
     if (NumSignBits >= (BitWidth - ExBits + 1))
       return Op0;
     break;
@@ -891,7 +906,8 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     EVT DstVT = Op.getValueType();
-    if (IsLE && DemandedElts == 1 &&
+    // FIXME: Can we skip DoNotPoisonEltMask here?
+    if (IsLE && (DemandedElts | DoNotPoisonEltMask) == 1 &&
         DstVT.getSizeInBits() == SrcVT.getSizeInBits() &&
         DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) {
       return DAG.getBitcast(DstVT, Src);
@@ -906,8 +922,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Vec = Op.getOperand(0);
     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
     EVT VecVT = Vec.getValueType();
+    // FIXME: Handle DoNotPoisonEltMask better.
     if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
-        !DemandedElts[CIdx->getZExtValue()])
+        !DemandedElts[CIdx->getZExtValue()] &&
+        !DoNotPoisonEltMask[CIdx->getZExtValue()])
       return Vec;
     break;
   }
@@ -920,8 +938,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     uint64_t Idx = Op.getConstantOperandVal(2);
     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
+    APInt DoNotPoisonSubElts = DoNotPoisonEltMask.extractBits(NumSubElts, Idx);
     // If we don't demand the inserted subvector, return the base vector.
-    if (DemandedSubElts == 0)
+    // FIXME: Handle DoNotPoisonEltMask better.
+    if (DemandedSubElts == 0 && DoNotPoisonSubElts == 0)
       return Vec;
     break;
   }
@@ -934,9 +954,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     bool AllUndef = true, IdentityLHS = true, IdentityRHS = true;
     for (unsigned i = 0; i != NumElts; ++i) {
       int M = ShuffleMask[i];
-      if (M < 0 || !DemandedElts[i])
+      if (M < 0 || (!DemandedElts[i] && !DoNotPoisonEltMask[i]))
         continue;
-      AllUndef = false;
+      if (DemandedElts[i])
+        AllUndef = false;
       IdentityLHS &= (M == (int)i);
       IdentityRHS &= ((M - NumElts) == i);
     }
@@ -957,13 +978,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
 
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
       if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
-              Op, DemandedBits, DemandedElts, DAG, Depth))
+              Op, DemandedBits, DemandedElts, DoNotPoisonEltMask, DAG, Depth))
         return V;
     break;
   }
   return SDValue();
 }
 
+SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
+    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+    SelectionDAG &DAG, unsigned Depth) const {
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts,
+                                         DoNotPoisonEltMask, DAG, Depth);
+}
+
 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
     unsigned Depth) const {
@@ -974,13 +1003,14 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
-                                         Depth);
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts,
+                                         DoNotPoisonEltMask, DAG, Depth);
 }
 
 SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
-    SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG,
-    unsigned Depth) const {
+    SDValue Op, const APInt &DemandedElts, const APInt &DoNotPoisonEltMask,
+    SelectionDAG &DAG, unsigned Depth) const {
   APInt DemandedBits = APInt::getAllOnes(Op.getScalarValueSizeInBits());
   return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
                                          Depth);
@@ -2756,34 +2786,51 @@ bool TargetLowering::SimplifyDemandedBits(
                              TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt));
       }
     }
-
     // Bitcast from a vector using SimplifyDemanded Bits/VectorElts.
     // Demand the elt/bit if any of the original elts/bits are demanded.
     if (SrcVT.isVector() && (BitWidth % NumSrcEltBits) == 0) {
       unsigned Scale = BitWidth / NumSrcEltBits;
       unsigned NumSrcElts = SrcVT.getVectorNumElements();
       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
+      APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
       for (unsigned i = 0; i != Scale; ++i) {
         unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
         unsigned BitOffset = EltOffset * NumSrcEltBits;
         APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
-        if (!Sub.isZero())
+        if (!Sub.isZero()) {
           DemandedSrcBits |= Sub;
+          for (unsigned j = 0; j != NumElts; ++j)
+            if (DemandedElts[j])
+              DemandedSrcElts.setBit((j * Scale) + i);
+        }
       }
-      // Need to demand all smaller source elements that maps to a demanded
-      // destination element, since recursive calls below may turn not demanded
-      // elements into poison.
-      APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      // Need to "semi demand" all smaller source elements that maps to a
+      // demanded destination element, since recursive calls below may turn not
+      // demanded elements into poison. Instead of demanding such elements we
+      // use a special bitmask to indicate that the recursive calls must not
+      // turn such elements into poison.
+      APInt NoPoisonSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
 
       APInt KnownSrcUndef, KnownSrcZero;
-      if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
-                                     KnownSrcZero, TLO, Depth + 1))
+      if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, NoPoisonSrcElts,
+                                     KnownSrcUndef, KnownSrcZero, TLO,
+                                     Depth + 1))
         return true;
 
       KnownBits KnownSrcBits;
-      if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
-                               KnownSrcBits, TLO, Depth + 1))
+      if (SimplifyDemandedBits(Src, DemandedSrcBits,
+                               DemandedSrcElts | NoPoisonSrcElts, KnownSrcBits,
+                               TLO, Depth + 1))
         return true;
+
+      if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
+        if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
+                Src, DemandedSrcBits, DemandedSrcElts, NoPoisonSrcElts, TLO.DAG,
+                Depth + 1)) {
+          SDValue NewOp = TLO.DAG.getBitcast(VT, DemandedSrc);
+          return TLO.CombineTo(Op, NewOp);
+        }
+      }
     } else if (IsLE && (NumSrcEltBits % BitWidth) == 0) {
       // TODO - bigendian once we have test coverage.
       unsigned Scale = NumSrcEltBits / BitWidth;
@@ -3090,8 +3137,9 @@ bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op,
                         !DCI.isBeforeLegalizeOps());
 
   APInt KnownUndef, KnownZero;
-  bool Simplified =
-      SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO);
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  bool Simplified = SimplifyDemandedVectorElts(
+      Op, DemandedElts, DoNotPoisonEltMask, KnownUndef, KnownZero, TLO);
   if (Simplified) {
     DCI.AddToWorklist(Op.getNode());
     DCI.CommitTargetLoweringOpt(TLO);
@@ -3152,6 +3200,16 @@ bool TargetLowering::SimplifyDemandedVectorElts(
     SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef,
     APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth,
     bool AssumeSingleUse) const {
+  APInt DoNotPoisonEltMask = APInt::getZero(OriginalDemandedElts.getBitWidth());
+  return SimplifyDemandedVectorElts(Op, OriginalDemandedElts,
+                                    DoNotPoisonEltMask, KnownUndef, KnownZero,
+                                    TLO, Depth, AssumeSingleUse);
+}
+
+bool TargetLowering::SimplifyDemandedVectorElts(
+    SDValue Op, const APInt &OriginalDemandedElts,
+    const APInt &DoNotPoisonEltMask, APInt &KnownUndef, APInt &KnownZero,
+    TargetLoweringOpt &TLO, unsigned Depth, bool AssumeSingleUse) const {
   EVT VT = Op.getValueType();
   unsigned Opcode = Op.getOpcode();
   APInt DemandedElts = OriginalDemandedElts;
@@ -3190,6 +3248,7 @@ bool TargetLowering::SimplifyDemandedVectorElts(
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return false;
 
+  APInt DemandedEltsInclDoNotPoison = DemandedElts | DoNotPoisonEltMask;
   SDLoc DL(Op);
   unsigned EltSizeInBits = VT.getScalarSizeInBits();
   bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
@@ -3197,10 +3256,10 @@ bool TargetLowering::SimplifyDemandedVectorElts(
   // Helper for demanding the specified elements and all the bits of both binary
   // operands.
   auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) {
-    SDValue NewOp0 = S...
[truncated]
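
A minimal sketch of how the new overload could be called (illustrative only, not part of the patch; Op, TLI, DAG, LegalTypes, LegalOperations and NumElts stand for values a DAG combine already has in scope), demanding the low half of the lanes while protecting the high half, per the semantics documented in the TargetLowering.h hunk above:

  // Demand only the low half of the lanes...
  APInt DemandedElts = APInt::getLowBitsSet(NumElts, NumElts / 2);
  // ...but forbid simplifications from turning the remaining lanes into
  // poison (e.g. because a bitcast user still reads those bits).
  APInt DoNotPoisonEltMask = ~DemandedElts;
  APInt KnownUndef, KnownZero;
  TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
  if (TLI.SimplifyDemandedVectorElts(Op, DemandedElts, DoNotPoisonEltMask,
                                     KnownUndef, KnownZero, TLO))
    return true; // Simplified without poisoning the protected lanes.

In the bitcast hunk this mask is derived with APIntOps::ScaleBitMask: for example, bitcasting v4i16 to v2i32 with only destination element 1 demanded scales DemandedElts = 0b10 up to NoPoisonSrcElts = 0b1100, so source elements 2 and 3 may still be simplified but must not be turned into poison.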

@llvmbot
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-llvm-selectiondag

@llvmbot
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-backend-arm

Author: Björn Pettersson (bjope)

Changes

This is currently several commits in one PR. Should perhaps be splitted in several pull requests.


Patch is 1.15 MiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145903.diff

112 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+24-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+4-2)
  • (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+225-125)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+117-84)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.h (+9-12)
  • (modified) llvm/test/CodeGen/AMDGPU/load-constant-i1.ll (+278-324)
  • (modified) llvm/test/CodeGen/AMDGPU/load-constant-i16.ll (+154-184)
  • (modified) llvm/test/CodeGen/AMDGPU/load-constant-i8.ll (+445-517)
  • (modified) llvm/test/CodeGen/AMDGPU/shift-i128.ll (+7-9)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v2bf16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v3bf16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3bf16.v4bf16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v2f16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v3f16.ll (+15-30)
  • (modified) llvm/test/CodeGen/AMDGPU/shufflevector.v3f16.v4f16.ll (+15-30)
  • (modified) llvm/test/CodeGen/ARM/fpclamptosat_vec.ll (+210-234)
  • (modified) llvm/test/CodeGen/Thumb2/mve-fpclamptosat_vec.ll (+38-42)
  • (modified) llvm/test/CodeGen/Thumb2/mve-gather-ind8-unscaled.ll (+5)
  • (modified) llvm/test/CodeGen/Thumb2/mve-laneinterleaving.ll (+41-45)
  • (modified) llvm/test/CodeGen/Thumb2/mve-pred-ext.ll (-1)
  • (modified) llvm/test/CodeGen/Thumb2/mve-satmul-loops.ll (+95-118)
  • (modified) llvm/test/CodeGen/Thumb2/mve-scatter-ind8-unscaled.ll (+7-2)
  • (modified) llvm/test/CodeGen/Thumb2/mve-vecreduce-addpred.ll (+4-4)
  • (modified) llvm/test/CodeGen/Thumb2/mve-vecreduce-mlapred.ll (+4-4)
  • (modified) llvm/test/CodeGen/Thumb2/mve-vst3.ll (+19-19)
  • (modified) llvm/test/CodeGen/X86/avx512-intrinsics-fast-isel.ll (+2-2)
  • (modified) llvm/test/CodeGen/X86/avx512-intrinsics-upgrade.ll (+14-32)
  • (modified) llvm/test/CodeGen/X86/avx512fp16-mov.ll (+16-16)
  • (modified) llvm/test/CodeGen/X86/avx512vl-intrinsics-upgrade.ll (+14-38)
  • (modified) llvm/test/CodeGen/X86/avx512vl-vec-masked-cmp.ll (+30-60)
  • (modified) llvm/test/CodeGen/X86/bitcast-and-setcc-128.ll (+12-16)
  • (modified) llvm/test/CodeGen/X86/bitcast-setcc-128.ll (+5-7)
  • (modified) llvm/test/CodeGen/X86/bitcast-vector-bool.ll (-4)
  • (modified) llvm/test/CodeGen/X86/buildvec-widen-dotproduct.ll (+15-19)
  • (modified) llvm/test/CodeGen/X86/combine-pmuldq.ll (+18-69)
  • (modified) llvm/test/CodeGen/X86/combine-sdiv.ll (+19-20)
  • (modified) llvm/test/CodeGen/X86/combine-sra.ll (+31-32)
  • (modified) llvm/test/CodeGen/X86/combine-udiv.ll (+1-4)
  • (modified) llvm/test/CodeGen/X86/extractelement-load.ll (+5-4)
  • (modified) llvm/test/CodeGen/X86/f16c-intrinsics-fast-isel.ll (-4)
  • (modified) llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll (+45-44)
  • (modified) llvm/test/CodeGen/X86/gfni-funnel-shifts.ll (+150-150)
  • (modified) llvm/test/CodeGen/X86/half.ll (+3-3)
  • (modified) llvm/test/CodeGen/X86/hoist-and-by-const-from-shl-in-eqcmp-zero.ll (+28-20)
  • (modified) llvm/test/CodeGen/X86/known-never-zero.ll (+1-2)
  • (modified) llvm/test/CodeGen/X86/known-pow2.ll (+6-6)
  • (modified) llvm/test/CodeGen/X86/known-signbits-shl.ll (+1-2)
  • (modified) llvm/test/CodeGen/X86/known-signbits-vector.ll (+7-23)
  • (modified) llvm/test/CodeGen/X86/masked_store.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/movmsk-cmp.ll (+8-35)
  • (modified) llvm/test/CodeGen/X86/mulvi32.ll (+7-8)
  • (modified) llvm/test/CodeGen/X86/omit-urem-of-power-of-two-or-zero-when-comparing-with-zero.ll (+11-14)
  • (modified) llvm/test/CodeGen/X86/pmul.ll (+66-74)
  • (modified) llvm/test/CodeGen/X86/pmulh.ll (+2-4)
  • (modified) llvm/test/CodeGen/X86/pr107423.ll (+13-13)
  • (modified) llvm/test/CodeGen/X86/pr35918.ll (+2-2)
  • (modified) llvm/test/CodeGen/X86/pr41619.ll (-2)
  • (modified) llvm/test/CodeGen/X86/pr42727.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/pr45563-2.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/pr45833.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/pr77459.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/psubus.ll (+123-132)
  • (modified) llvm/test/CodeGen/X86/rotate-extract-vector.ll (-2)
  • (modified) llvm/test/CodeGen/X86/sadd_sat_vec.ll (+93-114)
  • (modified) llvm/test/CodeGen/X86/sat-add.ll (+9-12)
  • (modified) llvm/test/CodeGen/X86/sdiv-exact.ll (+13-17)
  • (modified) llvm/test/CodeGen/X86/shrink_vmul.ll (+20-14)
  • (modified) llvm/test/CodeGen/X86/srem-seteq-vec-nonsplat.ll (+9-11)
  • (modified) llvm/test/CodeGen/X86/sshl_sat_vec.ll (+38-38)
  • (modified) llvm/test/CodeGen/X86/ssub_sat_vec.ll (+123-151)
  • (modified) llvm/test/CodeGen/X86/test-shrink-bug.ll (+1-1)
  • (modified) llvm/test/CodeGen/X86/udiv-exact.ll (+13-17)
  • (modified) llvm/test/CodeGen/X86/urem-seteq-illegal-types.ll (+6-4)
  • (modified) llvm/test/CodeGen/X86/urem-seteq-vec-nonsplat.ll (+122-172)
  • (modified) llvm/test/CodeGen/X86/ushl_sat_vec.ll (+14-15)
  • (modified) llvm/test/CodeGen/X86/vec_minmax_sint.ll (+66-92)
  • (modified) llvm/test/CodeGen/X86/vec_minmax_uint.ll (+66-92)
  • (modified) llvm/test/CodeGen/X86/vec_smulo.ll (+32-38)
  • (modified) llvm/test/CodeGen/X86/vec_umulo.ll (+60-70)
  • (modified) llvm/test/CodeGen/X86/vector-compare-all_of.ll (+15-21)
  • (modified) llvm/test/CodeGen/X86/vector-compare-any_of.ll (+15-21)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-128.ll (+139-159)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-256.ll (+47-48)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-rot-128.ll (+37-56)
  • (modified) llvm/test/CodeGen/X86/vector-fshl-rot-256.ll (+14-15)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-128.ll (+27-33)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-256.ll (+19-20)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-rot-128.ll (+28-33)
  • (modified) llvm/test/CodeGen/X86/vector-fshr-rot-256.ll (+2-2)
  • (modified) llvm/test/CodeGen/X86/vector-interleaved-load-i16-stride-7.ll (+839-863)
  • (modified) llvm/test/CodeGen/X86/vector-interleaved-store-i8-stride-8.ll (+214-222)
  • (modified) llvm/test/CodeGen/X86/vector-mul.ll (+34-34)
  • (modified) llvm/test/CodeGen/X86/vector-pcmp.ll (+1-2)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-fmaximum.ll (+81-82)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-mul.ll (+57-113)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-smax.ll (+69-99)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-smin.ll (+65-96)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-umax.ll (+69-99)
  • (modified) llvm/test/CodeGen/X86/vector-reduce-umin.ll (+65-96)
  • (modified) llvm/test/CodeGen/X86/vector-rotate-128.ll (+37-56)
  • (modified) llvm/test/CodeGen/X86/vector-rotate-256.ll (+14-15)
  • (modified) llvm/test/CodeGen/X86/vector-shift-shl-128.ll (+24-28)
  • (modified) llvm/test/CodeGen/X86/vector-shift-shl-256.ll (+20-22)
  • (modified) llvm/test/CodeGen/X86/vector-shift-shl-sub128.ll (+48-56)
  • (modified) llvm/test/CodeGen/X86/vector-shuffle-combining-avx.ll (+6-6)
  • (modified) llvm/test/CodeGen/X86/vector-shuffle-combining-ssse3.ll (+6-3)
  • (modified) llvm/test/CodeGen/X86/vector-shuffle-combining.ll (+3-6)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-packus.ll (+806-945)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-ssat.ll (+609-734)
  • (modified) llvm/test/CodeGen/X86/vector-trunc-usat.ll (+368-398)
  • (modified) llvm/test/CodeGen/X86/vector_splat-const-shift-of-constmasked.ll (+5-4)
  • (modified) llvm/test/CodeGen/X86/vselect.ll (+18-8)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 727526055e592..191f4cc78fcc5 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4187,6 +4187,16 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// More limited version of SimplifyDemandedBits that can be used to "look
   /// through" ops that don't contribute to the DemandedBits/DemandedElts -
   /// bitwise ops etc.
+  /// Vector elements that aren't demanded can be turned into poison unless the
+  /// corresponding bit in the \p DoNotPoisonEltMask is set.
+  SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits,
+                                          const APInt &DemandedElts,
+                                          const APInt &DoNotPoisonEltMask,
+                                          SelectionDAG &DAG,
+                                          unsigned Depth = 0) const;
+
+  /// Helper wrapper around SimplifyMultipleUseDemandedBits, with
+  /// DoNotPoisonEltMask being set to zero.
   SDValue SimplifyMultipleUseDemandedBits(SDValue Op, const APInt &DemandedBits,
                                           const APInt &DemandedElts,
                                           SelectionDAG &DAG,
@@ -4202,6 +4212,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// bits from only some vector elements.
   SDValue SimplifyMultipleUseDemandedVectorElts(SDValue Op,
                                                 const APInt &DemandedElts,
+                                                const APInt &DoNotPoisonEltMask,
                                                 SelectionDAG &DAG,
                                                 unsigned Depth = 0) const;
 
@@ -4219,6 +4230,15 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   ///    results of this function, because simply replacing TLO.Old
   ///    with TLO.New will be incorrect when this parameter is true and TLO.Old
   ///    has multiple uses.
+  /// Vector elements that aren't demanded can be turned into poison unless the
+  /// corresponding bit in \p DoNotPoisonEltMask is set.
+  bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedEltMask,
+                                  const APInt &DoNotPoisonEltMask,
+                                  APInt &KnownUndef, APInt &KnownZero,
+                                  TargetLoweringOpt &TLO, unsigned Depth = 0,
+                                  bool AssumeSingleUse = false) const;
+  /// Version of SimplifyDemandedVectorElts without the DoNotPoisonEltMask
+  /// argument. All undemanded elements can be turned into poison.
   bool SimplifyDemandedVectorElts(SDValue Op, const APInt &DemandedEltMask,
                                   APInt &KnownUndef, APInt &KnownZero,
                                   TargetLoweringOpt &TLO, unsigned Depth = 0,
@@ -4303,8 +4323,9 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// (used to simplify the caller). The KnownUndef/Zero elements may only be
   /// accurate for those bits in the DemandedMask.
   virtual bool SimplifyDemandedVectorEltsForTargetNode(
-      SDValue Op, const APInt &DemandedElts, APInt &KnownUndef,
-      APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth = 0) const;
+      SDValue Op, const APInt &DemandedElts, const APInt &DoNotPoisonEltMask,
+      APInt &KnownUndef, APInt &KnownZero, TargetLoweringOpt &TLO,
+      unsigned Depth = 0) const;
 
   /// Attempt to simplify any target nodes based on the demanded bits/elts,
   /// returning true on success. Otherwise, analyze the
@@ -4323,6 +4344,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
   /// bitwise ops etc.
   virtual SDValue SimplifyMultipleUseDemandedBitsForTargetNode(
       SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+      const APInt &DoNotPoisonEltMask,
       SelectionDAG &DAG, unsigned Depth) const;
 
   /// Return true if this function can prove that \p Op is never poison
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index c8088d68b6f1b..1f53742b2c6f7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1457,8 +1457,10 @@ bool DAGCombiner::SimplifyDemandedVectorElts(SDValue Op,
                                              bool AssumeSingleUse) {
   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
   APInt KnownUndef, KnownZero;
-  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero,
-                                      TLO, 0, AssumeSingleUse))
+  APInt DoNotPoisonElts = APInt::getZero(DemandedElts.getBitWidth());
+  if (!TLI.SimplifyDemandedVectorElts(Op, DemandedElts, DoNotPoisonElts,
+                                      KnownUndef, KnownZero, TLO, 0,
+                                      AssumeSingleUse))
     return false;
 
   // Revisit the node.
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 524c97ab3eab8..931154b922924 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -679,7 +679,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op, const APInt &DemandedBits,
 // TODO: Under what circumstances can we create nodes? Constant folding?
 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
-    SelectionDAG &DAG, unsigned Depth) const {
+    const APInt &DoNotPoisonEltMask, SelectionDAG &DAG, unsigned Depth) const {
   EVT VT = Op.getValueType();
 
   // Limit search depth.
@@ -713,9 +713,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     unsigned NumDstEltBits = DstVT.getScalarSizeInBits();
     if (NumSrcEltBits == NumDstEltBits)
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedBits, DemandedElts, DAG, Depth + 1))
+              Src, DemandedBits, DemandedElts, DoNotPoisonEltMask, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
 
+    // FIXME: Handle DoNotPoisonEltMask better?
+
     if (SrcVT.isVector() && (NumDstEltBits % NumSrcEltBits) == 0) {
       unsigned Scale = NumDstEltBits / NumSrcEltBits;
       unsigned NumSrcElts = SrcVT.getVectorNumElements();
@@ -731,9 +734,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
       // destination element, since recursive calls below may turn not demanded
       // elements into poison.
       APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      APInt DoNotPoisonSrcElts =
+          APIntOps::ScaleBitMask(DoNotPoisonEltMask, NumSrcElts);
 
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
+              Src, DemandedSrcBits, DemandedSrcElts, DoNotPoisonSrcElts, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
     }
 
@@ -743,15 +749,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
       unsigned NumSrcElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
       APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
+      APInt DoNotPoisonSrcElts = APInt::getZero(NumSrcElts);
       for (unsigned i = 0; i != NumElts; ++i)
         if (DemandedElts[i]) {
           unsigned Offset = (i % Scale) * NumDstEltBits;
           DemandedSrcBits.insertBits(DemandedBits, Offset);
           DemandedSrcElts.setBit(i / Scale);
+        } else if (DoNotPoisonEltMask[i]) {
+          DoNotPoisonSrcElts.setBit(i / Scale);
         }
 
+      // FIXME: Handle DoNotPoisonEltMask better?
+
       if (SDValue V = SimplifyMultipleUseDemandedBits(
-              Src, DemandedSrcBits, DemandedSrcElts, DAG, Depth + 1))
+              Src, DemandedSrcBits, DemandedSrcElts, DoNotPoisonSrcElts, DAG,
+              Depth + 1))
         return DAG.getBitcast(DstVT, V);
     }
 
@@ -759,7 +771,8 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   }
   case ISD::FREEZE: {
     SDValue N0 = Op.getOperand(0);
-    if (DAG.isGuaranteedNotToBeUndefOrPoison(N0, DemandedElts,
+    if (DAG.isGuaranteedNotToBeUndefOrPoison(N0,
+                                             DemandedElts | DoNotPoisonEltMask,
                                              /*PoisonOnly=*/false))
       return N0;
     break;
@@ -815,12 +828,12 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SHL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (std::optional<uint64_t> MaxSA =
-            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+    if (std::optional<uint64_t> MaxSA = DAG.getValidMaximumShiftAmount(
+            Op, DemandedElts | DoNotPoisonEltMask, Depth + 1)) {
       SDValue Op0 = Op.getOperand(0);
       unsigned ShAmt = *MaxSA;
-      unsigned NumSignBits =
-          DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+      unsigned NumSignBits = DAG.ComputeNumSignBits(
+          Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
       unsigned UpperDemandedBits = BitWidth - DemandedBits.countr_zero();
       if (NumSignBits > ShAmt && (NumSignBits - ShAmt) >= (UpperDemandedBits))
         return Op0;
@@ -830,15 +843,15 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   case ISD::SRL: {
     // If we are only demanding sign bits then we can use the shift source
     // directly.
-    if (std::optional<uint64_t> MaxSA =
-            DAG.getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+    if (std::optional<uint64_t> MaxSA = DAG.getValidMaximumShiftAmount(
+            Op, DemandedElts | DoNotPoisonEltMask, Depth + 1)) {
       SDValue Op0 = Op.getOperand(0);
       unsigned ShAmt = *MaxSA;
       // Must already be signbits in DemandedBits bounds, and can't demand any
       // shifted in zeroes.
       if (DemandedBits.countl_zero() >= ShAmt) {
-        unsigned NumSignBits =
-            DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+        unsigned NumSignBits = DAG.ComputeNumSignBits(
+            Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
         if (DemandedBits.countr_zero() >= (BitWidth - NumSignBits))
           return Op0;
       }
@@ -875,7 +888,9 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
         shouldRemoveRedundantExtend(Op))
       return Op0;
     // If the input is already sign extended, just drop the extension.
-    unsigned NumSignBits = DAG.ComputeNumSignBits(Op0, DemandedElts, Depth + 1);
+    // FIXME: Can we skip DoNotPoisonEltMask here?
+    unsigned NumSignBits = DAG.ComputeNumSignBits(
+        Op0, DemandedElts | DoNotPoisonEltMask, Depth + 1);
     if (NumSignBits >= (BitWidth - ExBits + 1))
       return Op0;
     break;
@@ -891,7 +906,8 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     EVT DstVT = Op.getValueType();
-    if (IsLE && DemandedElts == 1 &&
+    // FIXME: Can we skip DoNotPoisonEltMask here?
+    if (IsLE && (DemandedElts | DoNotPoisonEltMask) == 1 &&
         DstVT.getSizeInBits() == SrcVT.getSizeInBits() &&
         DemandedBits.getActiveBits() <= SrcVT.getScalarSizeInBits()) {
       return DAG.getBitcast(DstVT, Src);
@@ -906,8 +922,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Vec = Op.getOperand(0);
     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
     EVT VecVT = Vec.getValueType();
+    // FIXME: Handle DoNotPoisonEltMask better.
     if (CIdx && CIdx->getAPIntValue().ult(VecVT.getVectorNumElements()) &&
-        !DemandedElts[CIdx->getZExtValue()])
+        !DemandedElts[CIdx->getZExtValue()] &&
+        !DoNotPoisonEltMask[CIdx->getZExtValue()])
       return Vec;
     break;
   }
@@ -920,8 +938,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     uint64_t Idx = Op.getConstantOperandVal(2);
     unsigned NumSubElts = Sub.getValueType().getVectorNumElements();
     APInt DemandedSubElts = DemandedElts.extractBits(NumSubElts, Idx);
+    APInt DoNotPoisonSubElts = DoNotPoisonEltMask.extractBits(NumSubElts, Idx);
     // If we don't demand the inserted subvector, return the base vector.
-    if (DemandedSubElts == 0)
+    // FIXME: Handle DoNotPoisonEltMask better.
+    if (DemandedSubElts == 0 && DoNotPoisonSubElts == 0)
       return Vec;
     break;
   }
@@ -934,9 +954,10 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     bool AllUndef = true, IdentityLHS = true, IdentityRHS = true;
     for (unsigned i = 0; i != NumElts; ++i) {
       int M = ShuffleMask[i];
-      if (M < 0 || !DemandedElts[i])
+      if (M < 0 || (!DemandedElts[i] && !DoNotPoisonEltMask[i]))
         continue;
-      AllUndef = false;
+      if (DemandedElts[i])
+        AllUndef = false;
       IdentityLHS &= (M == (int)i);
       IdentityRHS &= ((M - NumElts) == i);
     }
@@ -957,13 +978,21 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
 
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
       if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
-              Op, DemandedBits, DemandedElts, DAG, Depth))
+              Op, DemandedBits, DemandedElts, DoNotPoisonEltMask, DAG, Depth))
         return V;
     break;
   }
   return SDValue();
 }
 
+SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
+    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+    SelectionDAG &DAG, unsigned Depth) const {
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts,
+                                         DoNotPoisonEltMask, DAG, Depth);
+}
+
 SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
     SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
     unsigned Depth) const {
@@ -974,13 +1003,14 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedBits(
   APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
-  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
-                                         Depth);
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts,
+                                         DoNotPoisonEltMask, DAG, Depth);
 }
 
 SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
-    SDValue Op, const APInt &DemandedElts, SelectionDAG &DAG,
-    unsigned Depth) const {
+    SDValue Op, const APInt &DemandedElts, const APInt &DoNotPoisonEltMask,
+    SelectionDAG &DAG, unsigned Depth) const {
   APInt DemandedBits = APInt::getAllOnes(Op.getScalarValueSizeInBits());
   return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
                                          Depth);
@@ -2756,34 +2786,51 @@ bool TargetLowering::SimplifyDemandedBits(
                              TLO.DAG.getNode(ISD::SHL, dl, VT, Sign, ShAmt));
       }
     }
-
     // Bitcast from a vector using SimplifyDemanded Bits/VectorElts.
     // Demand the elt/bit if any of the original elts/bits are demanded.
     if (SrcVT.isVector() && (BitWidth % NumSrcEltBits) == 0) {
       unsigned Scale = BitWidth / NumSrcEltBits;
       unsigned NumSrcElts = SrcVT.getVectorNumElements();
       APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
+      APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
       for (unsigned i = 0; i != Scale; ++i) {
         unsigned EltOffset = IsLE ? i : (Scale - 1 - i);
         unsigned BitOffset = EltOffset * NumSrcEltBits;
         APInt Sub = DemandedBits.extractBits(NumSrcEltBits, BitOffset);
-        if (!Sub.isZero())
+        if (!Sub.isZero()) {
           DemandedSrcBits |= Sub;
+          for (unsigned j = 0; j != NumElts; ++j)
+            if (DemandedElts[j])
+              DemandedSrcElts.setBit((j * Scale) + i);
+        }
       }
-      // Need to demand all smaller source elements that maps to a demanded
-      // destination element, since recursive calls below may turn not demanded
-      // elements into poison.
-      APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+      // Need to "semi demand" all smaller source elements that maps to a
+      // demanded destination element, since recursive calls below may turn not
+      // demanded elements into poison. Instead of demanding such elements we
+      // use a special bitmask to indicate that the recursive calls must not
+      // turn such elements into poison.
+      APInt NoPoisonSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
 
       APInt KnownSrcUndef, KnownSrcZero;
-      if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, KnownSrcUndef,
-                                     KnownSrcZero, TLO, Depth + 1))
+      if (SimplifyDemandedVectorElts(Src, DemandedSrcElts, NoPoisonSrcElts,
+                                     KnownSrcUndef, KnownSrcZero, TLO,
+                                     Depth + 1))
         return true;
 
       KnownBits KnownSrcBits;
-      if (SimplifyDemandedBits(Src, DemandedSrcBits, DemandedSrcElts,
-                               KnownSrcBits, TLO, Depth + 1))
+      if (SimplifyDemandedBits(Src, DemandedSrcBits,
+                               DemandedSrcElts | NoPoisonSrcElts, KnownSrcBits,
+                               TLO, Depth + 1))
         return true;
+
+      if (!DemandedSrcBits.isAllOnes() || !DemandedSrcElts.isAllOnes()) {
+        if (SDValue DemandedSrc = SimplifyMultipleUseDemandedBits(
+                Src, DemandedSrcBits, DemandedSrcElts, NoPoisonSrcElts, TLO.DAG,
+                Depth + 1)) {
+          SDValue NewOp = TLO.DAG.getBitcast(VT, DemandedSrc);
+          return TLO.CombineTo(Op, NewOp);
+        }
+      }
     } else if (IsLE && (NumSrcEltBits % BitWidth) == 0) {
       // TODO - bigendian once we have test coverage.
       unsigned Scale = NumSrcEltBits / BitWidth;
@@ -3090,8 +3137,9 @@ bool TargetLowering::SimplifyDemandedVectorElts(SDValue Op,
                         !DCI.isBeforeLegalizeOps());
 
   APInt KnownUndef, KnownZero;
-  bool Simplified =
-      SimplifyDemandedVectorElts(Op, DemandedElts, KnownUndef, KnownZero, TLO);
+  APInt DoNotPoisonEltMask = APInt::getZero(DemandedElts.getBitWidth());
+  bool Simplified = SimplifyDemandedVectorElts(
+      Op, DemandedElts, DoNotPoisonEltMask, KnownUndef, KnownZero, TLO);
   if (Simplified) {
     DCI.AddToWorklist(Op.getNode());
     DCI.CommitTargetLoweringOpt(TLO);
@@ -3152,6 +3200,16 @@ bool TargetLowering::SimplifyDemandedVectorElts(
     SDValue Op, const APInt &OriginalDemandedElts, APInt &KnownUndef,
     APInt &KnownZero, TargetLoweringOpt &TLO, unsigned Depth,
     bool AssumeSingleUse) const {
+  APInt DoNotPoisonEltMask = APInt::getZero(OriginalDemandedElts.getBitWidth());
+  return SimplifyDemandedVectorElts(Op, OriginalDemandedElts,
+                                    DoNotPoisonEltMask, KnownUndef, KnownZero,
+                                    TLO, Depth, AssumeSingleUse);
+}
+
+bool TargetLowering::SimplifyDemandedVectorElts(
+    SDValue Op, const APInt &OriginalDemandedElts,
+    const APInt &DoNotPoisonEltMask, APInt &KnownUndef, APInt &KnownZero,
+    TargetLoweringOpt &TLO, unsigned Depth, bool AssumeSingleUse) const {
   EVT VT = Op.getValueType();
   unsigned Opcode = Op.getOpcode();
   APInt DemandedElts = OriginalDemandedElts;
@@ -3190,6 +3248,7 @@ bool TargetLowering::SimplifyDemandedVectorElts(
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return false;
 
+  APInt DemandedEltsInclDoNotPoison = DemandedElts | DoNotPoisonEltMask;
   SDLoc DL(Op);
   unsigned EltSizeInBits = VT.getScalarSizeInBits();
   bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
@@ -3197,10 +3256,10 @@ bool TargetLowering::SimplifyDemandedVectorElts(
   // Helper for demanding the specified elements and all the bits of both binary
   // operands.
   auto SimplifyDemandedVectorEltsBinOp = [&](SDValue Op0, SDValue Op1) {
-    SDValue NewOp0 = S...
[truncated]
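
As context for reviewers, here is a minimal standalone sketch of what the bitcast hunk above computes. It is not part of the patch: the vector types and masks are made up for illustration, but the helpers it calls (APInt::extractBits, APIntOps::ScaleBitMask) are the ones the diff uses. It shows how the demanded source elements end up distinct from the "do not poison" ones:

// Sketch only: bitcast v8i16 -> v4i32 on a little-endian target, so each
// i32 result element covers Scale = 2 adjacent i16 source elements.
#include "llvm/ADT/APInt.h"
using namespace llvm;

int main() {
  unsigned NumSrcEltBits = 16, BitWidth = 32;
  unsigned Scale = BitWidth / NumSrcEltBits; // 2 src elts per dest elt
  unsigned NumElts = 4, NumSrcElts = 8;

  APInt DemandedElts(NumElts, 0b0101);      // dest elements 0 and 2
  APInt DemandedBits(BitWidth, 0x0000FFFF); // only the low half of each

  APInt DemandedSrcBits = APInt::getZero(NumSrcEltBits);
  APInt DemandedSrcElts = APInt::getZero(NumSrcElts);
  for (unsigned i = 0; i != Scale; ++i) {
    APInt Sub = DemandedBits.extractBits(NumSrcEltBits, i * NumSrcEltBits);
    if (!Sub.isZero()) {
      DemandedSrcBits |= Sub;
      for (unsigned j = 0; j != NumElts; ++j)
        if (DemandedElts[j])
          DemandedSrcElts.setBit(j * Scale + i);
    }
  }

  // Every i16 half of a demanded i32 element must stay non-poison, even the
  // halves whose bits are unused, or the bitcast result itself could become
  // poison. That is what the new mask expresses:
  APInt NoPoisonSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);

  // DemandedSrcElts == 0b00010001 (low i16 of dest elements 0 and 2)
  // NoPoisonSrcElts == 0b00110011 (both i16 halves of dest elements 0 and 2)
  return 0;
}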

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff HEAD~1 HEAD --extensions h,cpp -- llvm/include/llvm/CodeGen/TargetLowering.h llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp llvm/lib/Target/X86/X86ISelLowering.cpp llvm/lib/Target/X86/X86ISelLowering.h
View the diff from clang-format here.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 191f4cc78..5e60ec56f 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4344,8 +4344,7 @@ public:
   /// bitwise ops etc.
   virtual SDValue SimplifyMultipleUseDemandedBitsForTargetNode(
       SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
-      const APInt &DoNotPoisonEltMask,
-      SelectionDAG &DAG, unsigned Depth) const;
+      const APInt &DoNotPoisonEltMask, SelectionDAG &DAG, unsigned Depth) const;
 
   /// Return true if this function can prove that \p Op is never poison
   /// and, if \p PoisonOnly is false, does not have undef bits. The DemandedElts
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 931154b92..2ac90f390 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -4055,15 +4055,13 @@ bool TargetLowering::SimplifyDemandedBitsForTargetNode(
           Op.getOpcode() == ISD::INTRINSIC_VOID) &&
          "Should use SimplifyDemandedBits if you don't know whether Op"
          " is a target node!");
-  computeKnownBitsForTargetNode(Op, Known, DemandedElts,
-                                TLO.DAG, Depth);
+  computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
   return false;
 }
 
 SDValue TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
-    const APInt &DoNotPoisonEltMask,
-    SelectionDAG &DAG, unsigned Depth) const {
+    const APInt &DoNotPoisonEltMask, SelectionDAG &DAG, unsigned Depth) const {
   assert(
       (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
        Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN ||
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 528e75246..e30a85f76 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44959,8 +44959,7 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
 
 SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
     SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
-    const APInt &DoNotPoisonEltMask,
-    SelectionDAG &DAG, unsigned Depth) const {
+    const APInt &DoNotPoisonEltMask, SelectionDAG &DAG, unsigned Depth) const {
   int NumElts = DemandedElts.getBitWidth();
   unsigned Opc = Op.getOpcode();
   EVT VT = Op.getValueType();
@@ -45009,7 +45008,8 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
     SDValue LHS = Op.getOperand(1);
     SDValue RHS = Op.getOperand(2);
 
-    KnownBits CondKnown = DAG.computeKnownBits(Cond, DemandedElts | DoNotPoisonEltMask, Depth + 1);
+    KnownBits CondKnown = DAG.computeKnownBits(
+        Cond, DemandedElts | DoNotPoisonEltMask, Depth + 1);
     if (CondKnown.isNegative())
       return LHS;
     if (CondKnown.isNonNegative())
@@ -45021,8 +45021,10 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
     SDValue LHS = Op.getOperand(0);
     SDValue RHS = Op.getOperand(1);
 
-    KnownBits LHSKnown = DAG.computeKnownBits(LHS, DemandedElts | DoNotPoisonEltMask, Depth + 1);
-    KnownBits RHSKnown = DAG.computeKnownBits(RHS, DemandedElts | DoNotPoisonEltMask, Depth + 1);
+    KnownBits LHSKnown =
+        DAG.computeKnownBits(LHS, DemandedElts | DoNotPoisonEltMask, Depth + 1);
+    KnownBits RHSKnown =
+        DAG.computeKnownBits(RHS, DemandedElts | DoNotPoisonEltMask, Depth + 1);
 
     // If all of the demanded bits are known 0 on LHS and known 0 on RHS, then
     // the (inverted) LHS bits cannot contribute to the result of the 'andn' in
@@ -45036,9 +45038,9 @@ SDValue X86TargetLowering::SimplifyMultipleUseDemandedBitsForTargetNode(
   APInt ShuffleUndef, ShuffleZero;
   SmallVector<int, 16> ShuffleMask;
   SmallVector<SDValue, 2> ShuffleOps;
-  if (getTargetShuffleInputs(Op, DemandedElts | DoNotPoisonEltMask,
-                             ShuffleOps, ShuffleMask,
-                             ShuffleUndef, ShuffleZero, DAG, Depth, false)) {
+  if (getTargetShuffleInputs(Op, DemandedElts | DoNotPoisonEltMask, ShuffleOps,
+                             ShuffleMask, ShuffleUndef, ShuffleZero, DAG, Depth,
+                             false)) {
     // If all the demanded elts are from one operand and are inline,
     // then we can use the operand directly.
     int NumOps = ShuffleOps.size();
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 81a6e480d..ca1682017 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1322,8 +1322,8 @@ namespace llvm {
 
     SDValue SimplifyMultipleUseDemandedBitsForTargetNode(
         SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
-        const APInt &DoNotPoisonEltMask,
-        SelectionDAG &DAG, unsigned Depth) const override;
+        const APInt &DoNotPoisonEltMask, SelectionDAG &DAG,
+        unsigned Depth) const override;
 
     bool isGuaranteedNotToBeUndefOrPoisonForTargetNode(
         SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,

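A note for anyone applying these fixes locally: the two-commit --diff form above only prints the suggested patch. If I read the git-clang-format behavior correctly, the single-commit form rewrites the changed lines in place (it modifies checked-out files, so make sure the working tree is committed first):

git-clang-format --extensions h,cpp HEAD~1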