Skip to content

Commit 0d3c067

Browse files
committed
[VectorCombine] Scalarize extracts of ZExt if profitable.
Add a new scalarization transform that tries to convert extracts of a vector ZExt to a set of scalar shift and mask operations. This can be profitable if the cost of extracting is the same as or higher than the cost of 2 scalar ops, which is the case on AArch64, for example. For AArch64, this shows up in a number of workloads, including av1aom, gmsh, minizinc and astc-encoder.
1 parent 619f7af commit 0d3c067

File tree

2 files changed

+155
-24
lines changed

2 files changed

+155
-24
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ class VectorCombine {
121121
bool foldBinopOfReductions(Instruction &I);
122122
bool foldSingleElementStore(Instruction &I);
123123
bool scalarizeLoadExtract(Instruction &I);
124+
bool scalarizeExtExtract(Instruction &I);
124125
bool foldConcatOfBoolMasks(Instruction &I);
125126
bool foldPermuteOfBinops(Instruction &I);
126127
bool foldShuffleOfBinops(Instruction &I);
@@ -1770,6 +1771,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
17701771
return true;
17711772
}
17721773

1774+
/// Try to convert a vector zext that feeds only constant-index
/// extractelements into scalar operations on the source vector reinterpreted
/// as one wide integer:
///
///   extractelement (zext <N x iS> %src to <N x iD>), Idx
///     -->  (bitcast(%src) >> ShiftAmt(Idx)) & ((1 << S) - 1)
///
/// The transform only applies when the whole source vector has exactly as
/// many bits as one destination element (N * S == D), so the bitcast result
/// already has the extracted scalar type. It is performed only if the target
/// cost model rates the scalar sequence no more expensive than the vector
/// zext plus the extracts.
///
/// \param I the candidate instruction; anything but a zext is rejected.
/// \returns true if the extract users of \p I were rewritten.
bool VectorCombine::scalarizeExtExtract(Instruction &I) {
  if (!match(&I, m_ZExt(m_Value())))
    return false;

  auto *Ext = cast<ZExtInst>(&I);
  // Only a fixed-width vector can be reinterpreted as a single integer. The
  // source may be a scalable vector, so bail out instead of asserting.
  auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
  if (!SrcTy)
    return false;
  // A zext keeps the element count, so the result is fixed-width as well.
  auto *DstTy = cast<FixedVectorType>(Ext->getType());
  Type *ScalarDstTy = DstTy->getElementType();

  if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy))
    return false;

  const unsigned TotalBits = DL->getTypeSizeInBits(SrcTy);
  const unsigned SrcEltSizeInBits =
      DL->getTypeSizeInBits(SrcTy->getElementType());
  const unsigned NumElts = SrcTy->getNumElements();
  const bool IsBigEndian = DL->isBigEndian();

  // Bit offset of vector lane `Lane` inside the bitcast integer. On
  // big-endian targets lane 0 occupies the most significant bits, so the lane
  // order has to be mirrored.
  auto GetShiftAmt = [&](uint64_t Lane) -> uint64_t {
    const uint64_t Pos = IsBigEndian ? NumElts - 1 - Lane : Lane;
    return Pos * SrcEltSizeInBits;
  };

  InstructionCost VectorCost = TTI.getCastInstrCost(
      Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
  unsigned ExtCnt = 0;
  // Set if some live extract reads the lane that needs no shift at all.
  bool ExtNoShift = false;
  for (User *U : Ext->users()) {
    const APInt *Idx;
    if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
      return false;
    // An out-of-range index makes the extract poison; do not try to model it
    // with a shift. This also guarantees getZExtValue() below cannot assert.
    if (Idx->uge(NumElts))
      return false;
    // Dead extracts cost nothing and need no replacement.
    if (cast<Instruction>(U)->use_empty())
      continue;
    ExtCnt += 1;
    ExtNoShift |= GetShiftAmt(Idx->getZExtValue()) == 0;
    VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
                                         CostKind, Idx->getZExtValue(), U);
  }

  // Nothing live to rewrite: avoid emitting a dead freeze/bitcast and
  // reporting a change that did not improve anything.
  if (ExtCnt == 0)
    return false;

  // Every live extract needs a mask; all but the zero-offset lane also need a
  // shift.
  InstructionCost ScalarCost =
      ExtCnt * TTI.getArithmeticInstrCost(
                   Instruction::And, ScalarDstTy, CostKind,
                   {TTI::OK_AnyValue, TTI::OP_None},
                   {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
      (ExtCnt - ExtNoShift) *
          TTI.getArithmeticInstrCost(
              Instruction::LShr, ScalarDstTy, CostKind,
              {TTI::OK_AnyValue, TTI::OP_None},
              {TTI::OK_NonUniformConstantValue, TTI::OP_None});
  if (ScalarCost > VectorCost)
    return false;

  Value *ScalarV = Ext->getOperand(0);
  // After the bitcast a single poison lane would poison the whole integer and
  // therefore every extracted value, so freeze the source unless it is known
  // not to be poison.
  if (!isGuaranteedNotToBePoison(ScalarV, &AC))
    ScalarV = Builder.CreateFreeze(ScalarV);
  auto *ScalarIntTy = IntegerType::get(SrcTy->getContext(), TotalBits);
  ScalarV = Builder.CreateBitCast(ScalarV, ScalarIntTy);

  // Build the mask as an APInt: `(1ull << Bits) - 1` is undefined for 64-bit
  // source elements and cannot represent wider ones.
  Value *EltBitMask = ConstantInt::get(
      ScalarIntTy, APInt::getLowBitsSet(TotalBits, SrcEltSizeInBits));
  for (auto *U : to_vector(Ext->users())) {
    auto *Extract = cast<ExtractElementInst>(U);
    if (Extract->use_empty())
      continue;
    uint64_t Idx =
        cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
    auto *S = Builder.CreateLShr(
        ScalarV, ConstantInt::get(ScalarIntTy, GetShiftAmt(Idx)));
    auto *A = Builder.CreateAnd(S, EltBitMask);
    U->replaceAllUsesWith(A);
  }
  return true;
}
1839+
17731840
/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
17741841
/// to "(bitcast (concat X, Y))"
17751842
/// where X/Y are bitcasted from i1 mask vectors.
@@ -3655,6 +3722,7 @@ bool VectorCombine::run() {
36553722
if (IsVectorType) {
36563723
MadeChange |= scalarizeOpOrCmp(I);
36573724
MadeChange |= scalarizeLoadExtract(I);
3725+
MadeChange |= scalarizeExtExtract(I);
36583726
MadeChange |= scalarizeVPIntrinsic(I);
36593727
MadeChange |= foldInterleaveIntrinsics(I);
36603728
}

llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,25 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) {
99
; CHECK-LABEL: define void @zext_v4i8_all_lanes_used(
1010
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
1111
; CHECK-NEXT: [[ENTRY:.*:]]
12+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
13+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
14+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
15+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
16+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
17+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
18+
; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
19+
; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
20+
; CHECK-NEXT: [[TMP8:%.*]] = lshr i32 [[TMP1]], 0
21+
; CHECK-NEXT: [[TMP9:%.*]] = and i32 [[TMP8]], 255
1222
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
1323
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
1424
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
1525
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
1626
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
17-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
18-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
19-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
20-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
27+
; CHECK-NEXT: call void @use.i32(i32 [[TMP9]])
28+
; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
29+
; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
30+
; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
2131
; CHECK-NEXT: ret void
2232
;
2333
entry:
@@ -68,13 +78,21 @@ define void @zext_v4i8_3_lanes_used_1(<4 x i8> %src) {
6878
; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_1(
6979
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
7080
; CHECK-NEXT: [[ENTRY:.*:]]
81+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
82+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
83+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
84+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
85+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
86+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
87+
; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
88+
; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
7189
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
7290
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
7391
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
7492
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
75-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
76-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
77-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
93+
; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
94+
; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
95+
; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
7896
; CHECK-NEXT: ret void
7997
;
8098
entry:
@@ -93,13 +111,21 @@ define void @zext_v4i8_3_lanes_used_2(<4 x i8> %src) {
93111
; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_2(
94112
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
95113
; CHECK-NEXT: [[ENTRY:.*:]]
114+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
115+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
116+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
117+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
118+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
119+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
120+
; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 0
121+
; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
96122
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
97123
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
98124
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
99125
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
100-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
101-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
102-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
126+
; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
127+
; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
128+
; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
103129
; CHECK-NEXT: ret void
104130
;
105131
entry:
@@ -118,11 +144,17 @@ define void @zext_v4i8_2_lanes_used_1(<4 x i8> %src) {
118144
; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_1(
119145
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
120146
; CHECK-NEXT: [[ENTRY:.*:]]
147+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
148+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
149+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
150+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
151+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
152+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
121153
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
122154
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
123155
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
124-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
125-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
156+
; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
157+
; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
126158
; CHECK-NEXT: ret void
127159
;
128160
entry:
@@ -139,11 +171,17 @@ define void @zext_v4i8_2_lanes_used_2(<4 x i8> %src) {
139171
; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_2(
140172
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
141173
; CHECK-NEXT: [[ENTRY:.*:]]
174+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
175+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
176+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
177+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
178+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 0
179+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
142180
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
143181
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
144182
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
145-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
146-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
183+
; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
184+
; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
147185
; CHECK-NEXT: ret void
148186
;
149187
entry:
@@ -160,15 +198,24 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) {
160198
; CHECK-LABEL: define void @zext_v4i8_all_lanes_used_noundef(
161199
; CHECK-SAME: <4 x i8> noundef [[SRC:%.*]]) {
162200
; CHECK-NEXT: [[ENTRY:.*:]]
201+
; CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32
202+
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 24
203+
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 255
204+
; CHECK-NEXT: [[TMP3:%.*]] = lshr i32 [[TMP0]], 16
205+
; CHECK-NEXT: [[TMP4:%.*]] = and i32 [[TMP3]], 255
206+
; CHECK-NEXT: [[TMP5:%.*]] = lshr i32 [[TMP0]], 8
207+
; CHECK-NEXT: [[TMP6:%.*]] = and i32 [[TMP5]], 255
208+
; CHECK-NEXT: [[TMP7:%.*]] = lshr i32 [[TMP0]], 0
209+
; CHECK-NEXT: [[TMP8:%.*]] = and i32 [[TMP7]], 255
163210
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
164211
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
165212
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
166213
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
167214
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
168-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
169-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
170-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
171-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
215+
; CHECK-NEXT: call void @use.i32(i32 [[TMP8]])
216+
; CHECK-NEXT: call void @use.i32(i32 [[TMP6]])
217+
; CHECK-NEXT: call void @use.i32(i32 [[TMP4]])
218+
; CHECK-NEXT: call void @use.i32(i32 [[TMP2]])
172219
; CHECK-NEXT: ret void
173220
;
174221
entry:
@@ -221,15 +268,25 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) {
221268
; CHECK-LABEL: define void @zext_v4i16_all_lanes_used(
222269
; CHECK-SAME: <4 x i16> [[SRC:%.*]]) {
223270
; CHECK-NEXT: [[ENTRY:.*:]]
271+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i16> [[SRC]]
272+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to i64
273+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 48
274+
; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP2]], 65535
275+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 32
276+
; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 65535
277+
; CHECK-NEXT: [[TMP6:%.*]] = lshr i64 [[TMP1]], 16
278+
; CHECK-NEXT: [[TMP7:%.*]] = and i64 [[TMP6]], 65535
279+
; CHECK-NEXT: [[TMP8:%.*]] = lshr i64 [[TMP1]], 0
280+
; CHECK-NEXT: [[TMP9:%.*]] = and i64 [[TMP8]], 65535
224281
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i16> [[SRC]] to <4 x i64>
225282
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i64> [[EXT9]], i64 0
226283
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i64> [[EXT9]], i64 1
227284
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i64> [[EXT9]], i64 2
228285
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i64> [[EXT9]], i64 3
229-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]])
230-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]])
231-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_2]])
232-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_3]])
286+
; CHECK-NEXT: call void @use.i64(i64 [[TMP9]])
287+
; CHECK-NEXT: call void @use.i64(i64 [[TMP7]])
288+
; CHECK-NEXT: call void @use.i64(i64 [[TMP5]])
289+
; CHECK-NEXT: call void @use.i64(i64 [[TMP3]])
233290
; CHECK-NEXT: ret void
234291
;
235292
entry:
@@ -250,11 +307,17 @@ define void @zext_v2i32_all_lanes_used(<2 x i32> %src) {
250307
; CHECK-LABEL: define void @zext_v2i32_all_lanes_used(
251308
; CHECK-SAME: <2 x i32> [[SRC:%.*]]) {
252309
; CHECK-NEXT: [[ENTRY:.*:]]
310+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <2 x i32> [[SRC]]
311+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to i64
312+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 32
313+
; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP2]], 4294967295
314+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 0
315+
; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 4294967295
253316
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <2 x i32> [[SRC]] to <2 x i64>
254317
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <2 x i64> [[EXT9]], i64 0
255318
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <2 x i64> [[EXT9]], i64 1
256-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]])
257-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]])
319+
; CHECK-NEXT: call void @use.i64(i64 [[TMP5]])
320+
; CHECK-NEXT: call void @use.i64(i64 [[TMP3]])
258321
; CHECK-NEXT: ret void
259322
;
260323
entry:

0 commit comments

Comments
 (0)