-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[VectorCombine] Scalarize extracts of ZExt if profitable. #142976
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -123,6 +123,7 @@ class VectorCombine { | |
bool foldBinopOfReductions(Instruction &I); | ||
bool foldSingleElementStore(Instruction &I); | ||
bool scalarizeLoadExtract(Instruction &I); | ||
bool scalarizeExtExtract(Instruction &I); | ||
bool foldConcatOfBoolMasks(Instruction &I); | ||
bool foldPermuteOfBinops(Instruction &I); | ||
bool foldShuffleOfBinops(Instruction &I); | ||
|
@@ -1774,6 +1775,73 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { | |
return true; | ||
} | ||
|
||
/// Try to scalarize a fixed-vector zext whose users are all extractelements
/// with constant indices.
///
/// When the whole source vector is exactly as wide as one destination
/// element, the source is (frozen if it may be poison and) bitcast to a single
/// integer, and every extract is replaced by
///   (Packed >> ShiftForLane) & ((1 << SrcEltBits) - 1)
/// provided the TTI cost of the scalar form does not exceed the vector form.
///
/// \param I the instruction to inspect; anything other than a ZExtInst is
///          rejected.
/// \returns true if IR was changed.
bool VectorCombine::scalarizeExtExtract(Instruction &I) {
  auto *Ext = dyn_cast<ZExtInst>(&I);
  if (!Ext)
    return false;

  // Only fixed-width vector sources can be reinterpreted as one integer.
  auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
  if (!SrcTy)
    return false;
  auto *DstTy = cast<FixedVectorType>(Ext->getType());

  // The packed source must have exactly the width of one result element so
  // that the shift+mask result already has the extract's type.
  Type *ScalarDstTy = DstTy->getElementType();
  if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy))
    return false;

  const unsigned NumElts = SrcTy->getNumElements();
  // After the bitcast, the lane that occupies the low bits (and therefore
  // needs no shift) is lane 0 on little-endian targets and the last lane on
  // big-endian targets.
  const unsigned UnshiftedLane = DL->isBigEndian() ? NumElts - 1 : 0;

  InstructionCost VectorCost =
      TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy,
                           TTI::CastContextHint::None, CostKind, Ext);
  unsigned ExtCnt = 0;
  bool ExtUnshiftedLane = false;
  for (User *U : Ext->users()) {
    const APInt *Idx;
    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
      return false;
    // Reject out-of-range lanes up front: they would otherwise produce a
    // bogus shift amount below (and getZExtValue() may assert on very wide
    // index constants).
    if (Idx->uge(NumElts))
      return false;
    // Dead extracts contribute nothing to either side of the comparison.
    if (cast<Instruction>(U)->use_empty())
      continue;
    ExtCnt += 1;
    ExtUnshiftedLane |= (*Idx == UnshiftedLane);
    VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
                                         CostKind, Idx->getZExtValue(), U);
  }

  // Nothing live to rewrite; avoid emitting dead IR and reporting a change.
  if (ExtCnt == 0)
    return false;

  // One 'and' per live extract, plus one 'lshr' for every live extract that
  // is not the unshifted lane.
  // NOTE(review): the vector->integer bitcast is not costed here; it is
  // assumed to be free. TTI.getCastInstrCost(Instruction::BitCast, ...) could
  // model it for targets where the SIMD->GPR transfer is expensive.
  InstructionCost ScalarCost =
      ExtCnt * TTI.getArithmeticInstrCost(
                   Instruction::And, ScalarDstTy, CostKind,
                   {TTI::OK_AnyValue, TTI::OP_None},
                   {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
      (ExtCnt - ExtUnshiftedLane) *
          TTI.getArithmeticInstrCost(
              Instruction::LShr, ScalarDstTy, CostKind,
              {TTI::OK_AnyValue, TTI::OP_None},
              {TTI::OK_NonUniformConstantValue, TTI::OP_None});
  if (ScalarCost > VectorCost)
    return false;

  const uint64_t SrcEltSizeInBits =
      DL->getTypeSizeInBits(SrcTy->getElementType());
  const uint64_t TotalBits = DL->getTypeSizeInBits(SrcTy);
  auto *PackedTy = IntegerType::get(SrcTy->getContext(), TotalBits);

  // A single poison lane would poison the whole packed integer, and with it
  // every rewritten extract, so freeze the source unless it is known clean.
  Value *ScalarV = Ext->getOperand(0);
  if (!isGuaranteedNotToBePoison(ScalarV, &AC, dyn_cast<Instruction>(ScalarV),
                                 &DT))
    ScalarV = Builder.CreateFreeze(ScalarV);
  ScalarV = Builder.CreateBitCast(ScalarV, PackedTy);

  // Build the mask as an APInt so it is correct for source elements of 64
  // bits or more (a uint64_t '1 << 64' is undefined and cannot describe wider
  // elements).
  Value *EltMask = ConstantInt::get(
      PackedTy, APInt::getLowBitsSet(TotalBits, SrcEltSizeInBits));

  for (User *U : Ext->users()) {
    auto *Extract = cast<ExtractElementInst>(U);
    const uint64_t Idx =
        cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
    // Distance of lane Idx from the low end of the packed integer. Idx was
    // checked to be < NumElts above, so neither form can wrap.
    const uint64_t ShiftAmt = DL->isBigEndian()
                                  ? (NumElts - 1 - Idx) * SrcEltSizeInBits
                                  : Idx * SrcEltSizeInBits;
    Value *LShr = Builder.CreateLShr(ScalarV, ShiftAmt);
    Value *And = Builder.CreateAnd(LShr, EltMask);
    U->replaceAllUsesWith(And);
  }
  return true;
}
|
||
/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))" | ||
/// to "(bitcast (concat X, Y))" | ||
/// where X/Y are bitcasted from i1 mask vectors. | ||
|
@@ -3662,6 +3730,7 @@ bool VectorCombine::run() { | |
if (IsVectorType) { | ||
MadeChange |= scalarizeOpOrCmp(I); | ||
MadeChange |= scalarizeLoadExtract(I); | ||
MadeChange |= scalarizeExtExtract(I); | ||
MadeChange |= scalarizeVPIntrinsic(I); | ||
MadeChange |= foldInterleaveIntrinsics(I); | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.