Skip to content

Commit 58b1ef9

Browse files
committed
[SelectionDAG] Deal with POISON for INSERT_VECTOR_ELT/INSERT_SUBVECTOR (part 1)
As reported in #141034, SelectionDAG::getNode had some unexpected behaviors when trying to create vectors with UNDEF elements. Since we treat both UNDEF and POISON as undefined (when using isUndef()), we can't just fold away INSERT_VECTOR_ELT/INSERT_SUBVECTOR based on isUndef(), as that could make the resulting vector more poisonous. The same kind of bug existed in DAGCombiner::visitINSERT_SUBVECTOR.

Here are some examples:

This fold was done even if vec[idx] was POISON:
  INSERT_VECTOR_ELT vec, UNDEF, idx -> vec

This fold was done even if any of vec[idx..idx+size] was POISON:
  INSERT_SUBVECTOR vec, UNDEF, idx -> vec

This fold was done even if the elements not extracted from vec could be POISON:
  sub = EXTRACT_SUBVECTOR vec, idx
  INSERT_SUBVECTOR UNDEF, sub, idx -> vec

With this patch we avoid such folds unless we can prove that the result isn't more poisonous when eliminating the insert.

This patch in itself results in some regressions. The goal is to try to deal with those regressions in follow-up commits.

Fixes #141034
1 parent dac0820 commit 58b1ef9

18 files changed

+766
-248
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23003,6 +23003,7 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
2300323003
auto *IndexC = dyn_cast<ConstantSDNode>(EltNo);
2300423004

2300523005
// Insert into out-of-bounds element is undefined.
23006+
// The code below relies on this special case being handled early.
2300623007
if (IndexC && VT.isFixedLengthVector() &&
2300723008
IndexC->getZExtValue() >= VT.getVectorNumElements())
2300823009
return DAG.getUNDEF(VT);
@@ -23013,14 +23014,28 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
2301323014
InVec == InVal.getOperand(0) && EltNo == InVal.getOperand(1))
2301423015
return InVec;
2301523016

23016-
if (!IndexC) {
23017-
// If this is variable insert to undef vector, it might be better to splat:
23018-
// inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
23019-
if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
23020-
return DAG.getSplat(VT, DL, InVal);
23021-
return SDValue();
23017+
// If this is variable insert to undef vector, it might be better to splat:
23018+
// inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
23019+
if (!IndexC && InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
23020+
return DAG.getSplat(VT, DL, InVal);
23021+
23022+
// Try to drop insert of UNDEF/POISON elements. This is also done in getNode,
23023+
// but we also do it as a DAG combine since for example simplifications into
23024+
// SPLAT_VECTOR/BUILD_VECTOR may turn poison elements into undef/zero etc, and
23025+
// then suddenly the InVec is guaranteed to not be poison.
23026+
if (InVal.isUndef()) {
23027+
if (IndexC && VT.isFixedLengthVector()) {
23028+
APInt EltMask = APInt::getOneBitSet(VT.getVectorNumElements(),
23029+
IndexC->getZExtValue());
23030+
if (DAG.isGuaranteedNotToBePoison(InVec, EltMask))
23031+
return InVec;
23032+
}
23033+
return DAG.getFreeze(InVec);
2302223034
}
2302323035

23036+
if (!IndexC)
23037+
return SDValue();
23038+
2302423039
if (VT.isScalableVector())
2302523040
return SDValue();
2302623041

@@ -27453,18 +27468,42 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) {
2745327468
SDValue N2 = N->getOperand(2);
2745427469
uint64_t InsIdx = N->getConstantOperandVal(2);
2745527470

27456-
// If inserting an UNDEF, just return the original vector.
27457-
if (N1.isUndef())
27458-
return N0;
27471+
// If inserting an UNDEF, just return the original vector (unless it makes the
27472+
// result more poisonous).
27473+
if (N1.isUndef()) {
27474+
if (N1.getOpcode() == ISD::POISON)
27475+
return N0;
27476+
if (VT.isFixedLengthVector()) {
27477+
unsigned SubVecNumElts = N1.getValueType().getVectorNumElements();
27478+
APInt EltMask = APInt::getBitsSet(VT.getVectorNumElements(), InsIdx,
27479+
InsIdx + SubVecNumElts);
27480+
if (DAG.isGuaranteedNotToBePoison(N0, EltMask))
27481+
return N0;
27482+
}
27483+
return DAG.getFreeze(N0);
27484+
}
2745927485

27460-
// If this is an insert of an extracted vector into an undef vector, we can
27461-
// just use the input to the extract if the types match, and can simplify
27486+
// If this is an insert of an extracted vector into an undef/poison vector, we
27487+
// can just use the input to the extract if the types match, and can simplify
2746227488
// in some cases even if they don't.
2746327489
if (N0.isUndef() && N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
2746427490
N1.getOperand(1) == N2) {
27491+
EVT N1VT = N1.getValueType();
2746527492
EVT SrcVT = N1.getOperand(0).getValueType();
27466-
if (SrcVT == VT)
27467-
return N1.getOperand(0);
27493+
if (SrcVT == VT) {
27494+
// Need to ensure that result isn't more poisonous if skipping both the
27495+
// extract+insert.
27496+
if (N0.getOpcode() == ISD::POISON)
27497+
return N1.getOperand(0);
27498+
if (VT.isFixedLengthVector() && N1VT.isFixedLengthVector()) {
27499+
unsigned SubVecNumElts = N1VT.getVectorNumElements();
27500+
APInt EltMask = APInt::getBitsSet(VT.getVectorNumElements(), InsIdx,
27501+
InsIdx + SubVecNumElts);
27502+
if (DAG.isGuaranteedNotToBePoison(N1.getOperand(0), ~EltMask))
27503+
return N1.getOperand(0);
27504+
} else if (DAG.isGuaranteedNotToBePoison(N1.getOperand(0)))
27505+
return N1.getOperand(0);
27506+
}
2746827507
// TODO: To remove the zero check, need to adjust the offset to
2746927508
// a multiple of the new src type.
2747027509
if (isNullConstant(N2)) {

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7919,23 +7919,42 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
79197919
// INSERT_VECTOR_ELT into out-of-bounds element is an UNDEF, except
79207920
// for scalable vectors where we will generate appropriate code to
79217921
// deal with out-of-bounds cases correctly.
7922-
if (N3C && N1.getValueType().isFixedLengthVector() &&
7923-
N3C->getZExtValue() >= N1.getValueType().getVectorNumElements())
7922+
if (N3C && VT.isFixedLengthVector() &&
7923+
N3C->getZExtValue() >= VT.getVectorNumElements())
79247924
return getUNDEF(VT);
79257925

79267926
// Undefined index can be assumed out-of-bounds, so that's UNDEF too.
79277927
if (N3.isUndef())
79287928
return getUNDEF(VT);
79297929

7930-
// If the inserted element is an UNDEF, just use the input vector.
7931-
if (N2.isUndef())
7930+
// If inserting poison, just use the input vector.
7931+
if (N2.getOpcode() == ISD::POISON)
79327932
return N1;
79337933

7934+
// Inserting undef into undef/poison is still undef.
7935+
if (N2.getOpcode() == ISD::UNDEF && N1.isUndef())
7936+
return getUNDEF(VT);
7937+
7938+
// If the inserted element is an UNDEF, just use the input vector.
7939+
// But not if skipping the insert could make the result more poisonous.
7940+
if (N2.isUndef()) {
7941+
if (N3C && VT.isFixedLengthVector()) {
7942+
APInt EltMask =
7943+
APInt::getOneBitSet(VT.getVectorNumElements(), N3C->getZExtValue());
7944+
if (isGuaranteedNotToBePoison(N1, EltMask))
7945+
return N1;
7946+
} else if (isGuaranteedNotToBePoison(N1))
7947+
return N1;
7948+
}
79347949
break;
79357950
}
79367951
case ISD::INSERT_SUBVECTOR: {
7937-
// Inserting undef into undef is still undef.
7938-
if (N1.isUndef() && N2.isUndef())
7952+
// If inserting poison, just use the input vector.
7953+
if (N2.getOpcode() == ISD::POISON)
7954+
return N1;
7955+
7956+
// Inserting undef into undef/poison is still undef.
7957+
if (N2.getOpcode() == ISD::UNDEF && N1.isUndef())
79397958
return getUNDEF(VT);
79407959

79417960
EVT N2VT = N2.getValueType();
@@ -7964,11 +7983,37 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
79647983
if (VT == N2VT)
79657984
return N2;
79667985

7967-
// If this is an insert of an extracted vector into an undef vector, we
7968-
// can just use the input to the extract.
7986+
// If this is an insert of an extracted vector into an undef/poison vector,
7987+
// we can just use the input to the extract. But not if skipping the
7988+
// extract+insert could make the result more poisonous.
79697989
if (N1.isUndef() && N2.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
7970-
N2.getOperand(1) == N3 && N2.getOperand(0).getValueType() == VT)
7971-
return N2.getOperand(0);
7990+
N2.getOperand(1) == N3 && N2.getOperand(0).getValueType() == VT) {
7991+
if (N1.getOpcode() == ISD::POISON)
7992+
return N2.getOperand(0);
7993+
if (VT.isFixedLengthVector() && N2VT.isFixedLengthVector()) {
7994+
unsigned LoBit = N3->getAsZExtVal();
7995+
unsigned HiBit = LoBit + N2VT.getVectorNumElements();
7996+
APInt EltMask =
7997+
APInt::getBitsSet(VT.getVectorNumElements(), LoBit, HiBit);
7998+
if (isGuaranteedNotToBePoison(N2.getOperand(0), ~EltMask))
7999+
return N2.getOperand(0);
8000+
} else if (isGuaranteedNotToBePoison(N2.getOperand(0)))
8001+
return N2.getOperand(0);
8002+
}
8003+
8004+
// If the inserted subvector is UNDEF, just use the input vector.
8005+
// But not if skipping the insert could make the result more poisonous.
8006+
if (N2.isUndef()) {
8007+
if (VT.isFixedLengthVector()) {
8008+
unsigned LoBit = N3->getAsZExtVal();
8009+
unsigned HiBit = LoBit + N2VT.getVectorNumElements();
8010+
APInt EltMask =
8011+
APInt::getBitsSet(VT.getVectorNumElements(), LoBit, HiBit);
8012+
if (isGuaranteedNotToBePoison(N1, EltMask))
8013+
return N1;
8014+
} else if (isGuaranteedNotToBePoison(N1))
8015+
return N1;
8016+
}
79728017
break;
79738018
}
79748019
case ISD::BITCAST:

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3358,8 +3358,8 @@ bool TargetLowering::SimplifyDemandedVectorElts(
33583358
break;
33593359
}
33603360
case ISD::INSERT_SUBVECTOR: {
3361-
// Demand any elements from the subvector and the remainder from the src its
3362-
// inserted into.
3361+
// Demand any elements from the subvector and the remainder from the src it
3362+
// is inserted into.
33633363
SDValue Src = Op.getOperand(0);
33643364
SDValue Sub = Op.getOperand(1);
33653365
uint64_t Idx = Op.getConstantOperandVal(2);
@@ -3368,6 +3368,10 @@ bool TargetLowering::SimplifyDemandedVectorElts(
33683368
APInt DemandedSrcElts = DemandedElts;
33693369
DemandedSrcElts.clearBits(Idx, Idx + NumSubElts);
33703370

3371+
// If none of the sub operand elements are demanded, bypass the insert.
3372+
if (!DemandedSubElts)
3373+
return TLO.CombineTo(Op, Src);
3374+
33713375
APInt SubUndef, SubZero;
33723376
if (SimplifyDemandedVectorElts(Sub, DemandedSubElts, SubUndef, SubZero, TLO,
33733377
Depth + 1))

llvm/test/CodeGen/AArch64/arm64-build-vector.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ define void @widen_f16_build_vector(ptr %addr) {
5757
; CHECK-LABEL: widen_f16_build_vector:
5858
; CHECK: // %bb.0:
5959
; CHECK-NEXT: mov w8, #13294 // =0x33ee
60-
; CHECK-NEXT: movk w8, #13294, lsl #16
61-
; CHECK-NEXT: str w8, [x0]
60+
; CHECK-NEXT: dup v0.4h, w8
61+
; CHECK-NEXT: str s0, [x0]
6262
; CHECK-NEXT: ret
6363
store <2 x half> <half 0xH33EE, half 0xH33EE>, ptr %addr, align 2
6464
ret void

llvm/test/CodeGen/AArch64/concat-vector-add-combine.ll

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,17 +93,14 @@ define i32 @combine_add_8xi32(i32 %a, i32 %b, i32 %c, i32 %d, i32 %e, i32 %f, i3
9393
define i32 @combine_undef_add_8xi32(i32 %a, i32 %b, i32 %c, i32 %d) local_unnamed_addr #0 {
9494
; CHECK-LABEL: combine_undef_add_8xi32:
9595
; CHECK: // %bb.0:
96-
; CHECK-NEXT: fmov s1, w0
97-
; CHECK-NEXT: movi v0.2d, #0000000000000000
98-
; CHECK-NEXT: mov v1.s[1], w1
99-
; CHECK-NEXT: uhadd v0.4h, v0.4h, v0.4h
100-
; CHECK-NEXT: mov v1.s[2], w2
101-
; CHECK-NEXT: mov v1.s[3], w3
102-
; CHECK-NEXT: xtn v2.4h, v1.4s
103-
; CHECK-NEXT: shrn v1.4h, v1.4s, #16
104-
; CHECK-NEXT: uhadd v1.4h, v2.4h, v1.4h
105-
; CHECK-NEXT: mov v1.d[1], v0.d[0]
106-
; CHECK-NEXT: uaddlv s0, v1.8h
96+
; CHECK-NEXT: fmov s0, w0
97+
; CHECK-NEXT: mov v0.s[1], w1
98+
; CHECK-NEXT: mov v0.s[2], w2
99+
; CHECK-NEXT: mov v0.s[3], w3
100+
; CHECK-NEXT: uzp2 v1.8h, v0.8h, v0.8h
101+
; CHECK-NEXT: uzp1 v0.8h, v0.8h, v0.8h
102+
; CHECK-NEXT: uhadd v0.8h, v0.8h, v1.8h
103+
; CHECK-NEXT: uaddlv s0, v0.8h
107104
; CHECK-NEXT: fmov w0, s0
108105
; CHECK-NEXT: ret
109106
%a1 = insertelement <8 x i32> poison, i32 %a, i32 0

0 commit comments

Comments
 (0)