Skip to content

Commit ed89e3b

Browse files
committed
[DAG] Handle truncated splat in isBoolConstant
This allows truncated splat / buildvector in isBoolConstant, to allow certain not instructions to be recognized post-legalization, and allow vselect to optimize. An override for x86 avx512 predicated vectors is required to avoid an infinite recursion from the code that detects zero vectors. From: ``` // Check if the first operand is all zeros and Cond type is vXi1. // If this an avx512 target we can improve the use of zero masking by // swapping the operands and inverting the condition. ```
1 parent 8d9911e commit ed89e3b

File tree

13 files changed

+2422
-2683
lines changed

13 files changed

+2422
-2683
lines changed

llvm/include/llvm/CodeGen/SelectionDAG.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2479,8 +2479,7 @@ class SelectionDAG {
24792479

24802480
/// Check if a value \op N is a constant using the target's BooleanContent for
24812481
/// its type.
2482-
LLVM_ABI std::optional<bool>
2483-
isBoolConstant(SDValue N, bool AllowTruncation = false) const;
2482+
LLVM_ABI std::optional<bool> isBoolConstant(SDValue N) const;
24842483

24852484
/// Set CallSiteInfo to be associated with Node.
24862485
void addCallSiteInfo(const SDNode *Node, CallSiteInfo &&CallInfo) {

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4375,6 +4375,8 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
43754375
Op.getOpcode() == ISD::SPLAT_VECTOR_PARTS;
43764376
}
43774377

4378+
virtual bool isTargetCanonicalSelect(SDNode *N) const { return false; }
4379+
43784380
struct DAGCombinerInfo {
43794381
void *DC; // The DAG Combiner object.
43804382
CombineLevel Level;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12967,8 +12967,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
1296712967
return V;
1296812968

1296912969
// vselect (not Cond), N1, N2 -> vselect Cond, N2, N1
12970-
if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
12971-
return DAG.getSelect(DL, VT, F, N2, N1);
12970+
if (!TLI.isTargetCanonicalSelect(N))
12971+
if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false))
12972+
return DAG.getSelect(DL, VT, F, N2, N1);
1297212973

1297312974
// select (sext m), (add X, C), X --> (add X, (and C, (sext m))))
1297412975
if (N1.getOpcode() == ISD::ADD && N1.getOperand(0) == N2 && N1->hasOneUse() &&

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10349,7 +10349,7 @@ SDValue SelectionDAG::simplifySelect(SDValue Cond, SDValue T, SDValue F) {
1034910349

1035010350
// select true, T, F --> T
1035110351
// select false, T, F --> F
10352-
if (auto C = isBoolConstant(Cond, /*AllowTruncation=*/true))
10352+
if (auto C = isBoolConstant(Cond))
1035310353
return *C ? T : F;
1035410354

1035510355
// select ?, T, T --> T
@@ -13562,13 +13562,14 @@ bool SelectionDAG::isConstantFPBuildVectorOrConstantFP(SDValue N) const {
1356213562
return false;
1356313563
}
1356413564

13565-
std::optional<bool> SelectionDAG::isBoolConstant(SDValue N,
13566-
bool AllowTruncation) const {
13567-
ConstantSDNode *Const = isConstOrConstSplat(N, false, AllowTruncation);
13565+
std::optional<bool> SelectionDAG::isBoolConstant(SDValue N) const {
13566+
ConstantSDNode *Const =
13567+
isConstOrConstSplat(N, false, /*AllowTruncation=*/true);
1356813568
if (!Const)
1356913569
return std::nullopt;
1357013570

13571-
const APInt &CVal = Const->getAPIntValue();
13571+
EVT VT = N->getValueType(0);
13572+
const APInt CVal = Const->getAPIntValue().trunc(VT.getScalarSizeInBits());
1357213573
switch (TLI->getBooleanContents(N.getValueType())) {
1357313574
case TargetLowering::ZeroOrOneBooleanContent:
1357413575
if (CVal.isOne())

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4975,6 +4975,15 @@ X86TargetLowering::getTargetConstantFromLoad(LoadSDNode *LD) const {
49754975
return getTargetConstantFromNode(LD);
49764976
}
49774977

4978+
bool X86TargetLowering::isTargetCanonicalSelect(SDNode *N) const {
4979+
SDValue Cond = N->getOperand(0);
4980+
SDValue RHS = N->getOperand(2);
4981+
EVT CondVT = Cond.getValueType();
4982+
return N->getOpcode() == ISD::VSELECT && Subtarget.hasAVX512() &&
4983+
CondVT.getVectorElementType() == MVT::i1 &&
4984+
ISD::isBuildVectorAllZeros(RHS.getNode());
4985+
}
4986+
49784987
// Extract raw constant bits from constant pools.
49794988
static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
49804989
APInt &UndefElts,

llvm/lib/Target/X86/X86ISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,6 +1356,8 @@ namespace llvm {
13561356
TargetLowering::isTargetCanonicalConstantNode(Op);
13571357
}
13581358

1359+
bool isTargetCanonicalSelect(SDNode *N) const override;
1360+
13591361
const Constant *getTargetConstantFromLoad(LoadSDNode *LD) const override;
13601362

13611363
SDValue unwrapAddress(SDValue N) const override;

0 commit comments

Comments
 (0)