-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[VectorCombine] New folding pattern for extract/binop/shuffle chains #145232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,6 +130,7 @@ class VectorCombine { | |
bool foldShuffleOfIntrinsics(Instruction &I); | ||
bool foldShuffleToIdentity(Instruction &I); | ||
bool foldShuffleFromReductions(Instruction &I); | ||
bool foldShuffleChainsToReduce(Instruction &I); | ||
bool foldCastFromReductions(Instruction &I); | ||
bool foldSelectShuffle(Instruction &I, bool FromReduction = false); | ||
bool foldInterleaveIntrinsics(Instruction &I); | ||
|
@@ -2988,6 +2989,240 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) { | |
return foldSelectShuffle(*Shuffle, true); | ||
} | ||
|
||
/// Fold a lane-halving reduction chain that ends in an extract of lane 0 into
/// a single vector-reduce intrinsic call. For a <4 x i32> input %v:
///
///   %s1 = shufflevector %v,  poison, <2, 3, poison, poison>
///   %r1 = add %v, %s1
///   %s2 = shufflevector %r1, poison, <1, poison, poison, poison>
///   %r2 = add %r1, %s2
///   %e  = extractelement %r2, 0
///     -->
///   %e  = call @llvm.vector.reduce.add(%v)
///
/// The chain is walked bottom-up, starting at the extract's vector operand.
/// Every level must apply the same operation (one of add/mul/or/and/xor, or
/// one of the umin/umax/smin/smax intrinsics) to a vector and to a shuffle of
/// that same vector. The shuffle mask of a level must move lanes
/// [Half, 2 * Half) down to lanes [0, Half) and leave all other lanes poison,
/// where Half doubles per level from 1 up to NumElts / 2.
///
/// \param I the instruction to inspect; only an extractelement with a zero
///          index can match.
/// \returns true if \p I was replaced by a reduction call, false otherwise.
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
  Value *VecOp;
  if (!match(&I, m_ExtractElt(m_Value(VecOp), m_Zero())))
    return false;

  auto *FVT = dyn_cast<FixedVectorType>(VecOp->getType());
  if (!FVT)
    return false;

  // The halving pattern can only terminate exactly at NumElts when the element
  // count is a power of two; reject everything else up front.
  const unsigned VecSize = FVT->getNumElements();
  if (VecSize < 2 || !isPowerOf2_32(VecSize))
    return false;

  // The single operation shared by all levels. Exactly one of the two ends up
  // set; 0 / not_intrinsic means "not seen".
  Intrinsic::ID CommonCallOp = Intrinsic::not_intrinsic;
  unsigned CommonBinOp = 0;

  // Cost of the shuffles of the original chain, and the number of levels.
  InstructionCost ShuffleCost = 0;
  unsigned NumLevels = 0;

  // Value expected to be the combining operation of the current level. Once
  // the walk is complete this is the vector that feeds the whole chain.
  Value *ChainV = VecOp;

  for (unsigned Half = 1; Half < VecSize; Half *= 2) {
    auto *OpI = dyn_cast<Instruction>(ChainV);
    if (!OpI)
      return false;

    if (auto *II = dyn_cast<IntrinsicInst>(OpI)) {
      // Mixing a binop level with an intrinsic level is not a reduction.
      if (CommonBinOp)
        return false;

      const Intrinsic::ID IID = II->getIntrinsicID();
      switch (IID) {
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::smin:
      case Intrinsic::smax:
        break;
      default:
        return false;
      }

      if (CommonCallOp == Intrinsic::not_intrinsic)
        CommonCallOp = IID;
      else if (IID != CommonCallOp)
        return false;
    } else if (auto *BO = dyn_cast<BinaryOperator>(OpI)) {
      // Mixing an intrinsic level with a binop level is not a reduction.
      if (CommonCallOp != Intrinsic::not_intrinsic)
        return false;

      const unsigned Opcode = BO->getOpcode();
      switch (Opcode) {
      case Instruction::Add:
      case Instruction::Mul:
      case Instruction::Or:
      case Instruction::And:
      case Instruction::Xor:
        break;
      default:
        return false;
      }

      if (!CommonBinOp)
        CommonBinOp = Opcode;
      else if (Opcode != CommonBinOp)
        return false;
    } else {
      return false;
    }

    // All accepted operations are commutative, so the shuffle may be either
    // operand. Prefer operand 1 as the shuffle, as the original ordering did.
    Value *SrcV = OpI->getOperand(0);
    Value *ShufV = OpI->getOperand(1);
    if (!isa<ShuffleVectorInst>(ShufV))
      std::swap(SrcV, ShufV);

    // The shuffle has to permute exactly the vector it is combined with.
    auto *Shuf = dyn_cast<ShuffleVectorInst>(ShufV);
    if (!Shuf || Shuf->getOperand(0) != SrcV)
      return false;

    // Lanes [0, Half) must read lanes [Half, 2 * Half); the rest must be
    // poison (-1 in the mask).
    ArrayRef<int> Mask = Shuf->getShuffleMask();
    for (unsigned Lane = 0, NumLanes = Mask.size(); Lane != NumLanes; ++Lane) {
      const int Expected = Lane < Half ? static_cast<int>(Half + Lane) : -1;
      if (Mask[Lane] != Expected)
        return false;
    }

    ShuffleCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
                                      FVT, FVT, Mask, CostKind);
    ++NumLevels;
    ChainV = SrcV;
  }

  Value *FinalVecV = ChainV;

  Intrinsic::ID ReducedOp = Intrinsic::not_intrinsic;
  if (CommonCallOp != Intrinsic::not_intrinsic) {
    switch (CommonCallOp) {
    case Intrinsic::umin:
      ReducedOp = Intrinsic::vector_reduce_umin;
      break;
    case Intrinsic::umax:
      ReducedOp = Intrinsic::vector_reduce_umax;
      break;
    case Intrinsic::smin:
      ReducedOp = Intrinsic::vector_reduce_smin;
      break;
    case Intrinsic::smax:
      ReducedOp = Intrinsic::vector_reduce_smax;
      break;
    default:
      return false;
    }
  } else {
    switch (CommonBinOp) {
    case Instruction::Add:
      ReducedOp = Intrinsic::vector_reduce_add;
      break;
    case Instruction::Mul:
      ReducedOp = Intrinsic::vector_reduce_mul;
      break;
    case Instruction::Or:
      ReducedOp = Intrinsic::vector_reduce_or;
      break;
    case Instruction::And:
      ReducedOp = Intrinsic::vector_reduce_and;
      break;
    case Instruction::Xor:
      ReducedOp = Intrinsic::vector_reduce_xor;
      break;
    default:
      return false;
    }
  }

  // Cost of the original chain: per level one shuffle (accumulated above) and
  // one instance of the real combining operation, plus the final extract.
  InstructionCost OpCost;
  if (CommonCallOp != Intrinsic::not_intrinsic) {
    IntrinsicCostAttributes OpICA(CommonCallOp, FVT, {FVT, FVT});
    OpCost = TTI.getIntrinsicInstrCost(OpICA, CostKind);
  } else {
    OpCost = TTI.getArithmeticInstrCost(CommonBinOp, FVT, CostKind);
  }
  OpCost *= NumLevels;

  InstructionCost OrigCost = ShuffleCost + OpCost;
  OrigCost +=
      TTI.getVectorInstrCost(Instruction::ExtractElement, FVT, CostKind, 0);
  // NOTE(review): intermediate chain values may have users outside the chain,
  // in which case they stay alive after the fold and OrigCost overstates the
  // saving — confirm whether a use-count check is wanted here.

  // A vector reduction returns a scalar of the element type.
  IntrinsicCostAttributes ICA(ReducedOp, FVT->getElementType(), {FinalVecV});
  InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);

  // Only fold when the reduction is strictly cheaper; a target without native
  // support may expand it back into a sequence like the original.
  if (NewCost >= OrigCost)
    return false;

  auto *ReducedResult =
      Builder.CreateIntrinsic(ReducedOp, {FinalVecV->getType()}, {FinalVecV});
  replaceValue(I, *ReducedResult);

  return true;
}
|
||
/// Determine if its more efficient to fold: | ||
/// reduce(trunc(x)) -> trunc(reduce(x)). | ||
/// reduce(sext(x)) -> sext(reduce(x)). | ||
|
@@ -3705,6 +3940,9 @@ bool VectorCombine::run() { | |
MadeChange |= foldShuffleFromReductions(I); | ||
MadeChange |= foldCastFromReductions(I); | ||
break; | ||
case Instruction::ExtractElement: | ||
MadeChange |= foldShuffleChainsToReduce(I); | ||
break; | ||
case Instruction::ICmp: | ||
case Instruction::FCmp: | ||
MadeChange |= foldExtractExtract(I); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it work for non-power-of-2 vector types?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At the moment it works for even integers (not limited to powers of 2). For odd, I will need to check again for a suitable pattern and resolve the parity.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, looking again: yes, it only works for powers of 2. For example, let's say I had:
14 -> 7 (odd); here the partition might not work.
Let me see.