Skip to content

Commit c7b6b71

Browse files
committed
[VectorCombine] Scalarize extracts of ZExt if profitable.
Add a new scalarization transform that tries to convert extracts of a vector ZExt to a set of scalar shift and mask operations. This can be profitable if the cost of extracting is the same or higher than the cost of 2 scalar ops. This is the case on AArch64 for example. For AArch64,this shows up in a number of workloads, including av1aom, gmsh, minizinc and astc-encoder.
1 parent 6ba1955 commit c7b6b71

File tree

2 files changed

+155
-24
lines changed

2 files changed

+155
-24
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class VectorCombine {
120120
bool foldBinopOfReductions(Instruction &I);
121121
bool foldSingleElementStore(Instruction &I);
122122
bool scalarizeLoadExtract(Instruction &I);
123+
bool scalarizeExtExtract(Instruction &I);
123124
bool foldConcatOfBoolMasks(Instruction &I);
124125
bool foldPermuteOfBinops(Instruction &I);
125126
bool foldShuffleOfBinops(Instruction &I);
@@ -1710,6 +1711,72 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
17101711
return true;
17111712
}
17121713

1714+
bool VectorCombine::scalarizeExtExtract(Instruction &I) {
1715+
if (!match(&I, m_ZExt(m_Value())))
1716+
return false;
1717+
1718+
// Try to convert a vector zext feeding only extracts to a set of scalar (Src
1719+
// << ExtIdx *Size) & (Size -1), if profitable.
1720+
auto *Ext = cast<ZExtInst>(&I);
1721+
auto *SrcTy = cast<FixedVectorType>(Ext->getOperand(0)->getType());
1722+
auto *DstTy = cast<FixedVectorType>(Ext->getType());
1723+
1724+
if (DL->getTypeSizeInBits(SrcTy) !=
1725+
DL->getTypeSizeInBits(DstTy->getElementType()))
1726+
return false;
1727+
1728+
InstructionCost VectorCost = TTI.getCastInstrCost(
1729+
Instruction::ZExt, DstTy, SrcTy, TTI::CastContextHint::None, CostKind);
1730+
unsigned ExtCnt = 0;
1731+
bool ExtLane0 = false;
1732+
for (User *U : Ext->users()) {
1733+
const APInt *Idx;
1734+
if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
1735+
return false;
1736+
if (cast<Instruction>(U)->use_empty())
1737+
continue;
1738+
ExtCnt += 1;
1739+
ExtLane0 |= Idx->isZero();
1740+
VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
1741+
CostKind, Idx->getZExtValue(), U);
1742+
}
1743+
1744+
Type *ScalarDstTy = DstTy->getElementType();
1745+
InstructionCost ScalarCost =
1746+
ExtCnt * TTI.getArithmeticInstrCost(
1747+
Instruction::And, ScalarDstTy, CostKind,
1748+
{TTI::OK_AnyValue, TTI::OP_None},
1749+
{TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
1750+
(ExtCnt - ExtLane0) *
1751+
TTI.getArithmeticInstrCost(
1752+
1753+
Instruction::LShr, ScalarDstTy, CostKind,
1754+
{TTI::OK_AnyValue, TTI::OP_None},
1755+
{TTI::OK_NonUniformConstantValue, TTI::OP_None});
1756+
if (ScalarCost > VectorCost)
1757+
return false;
1758+
1759+
Value *ScalarV = Ext->getOperand(0);
1760+
if (!isGuaranteedNotToBePoison(ScalarV, &AC))
1761+
ScalarV = Builder.CreateFreeze(ScalarV);
1762+
ScalarV = Builder.CreateBitCast(
1763+
ScalarV,
1764+
IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
1765+
unsigned SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
1766+
Value *EltBitMask =
1767+
ConstantInt::get(ScalarV->getType(), (1ull << SrcEltSizeInBits) - 1);
1768+
for (auto *U : to_vector(Ext->users())) {
1769+
auto *Extract = cast<ExtractElementInst>(U);
1770+
unsigned Idx =
1771+
cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
1772+
auto *S = Builder.CreateLShr(
1773+
ScalarV, ConstantInt::get(ScalarV->getType(), Idx * SrcEltSizeInBits));
1774+
auto *A = Builder.CreateAnd(S, EltBitMask);
1775+
U->replaceAllUsesWith(A);
1776+
}
1777+
return true;
1778+
}
1779+
17131780
/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
17141781
/// to "(bitcast (concat X, Y))"
17151782
/// where X/Y are bitcasted from i1 mask vectors.
@@ -3576,6 +3643,7 @@ bool VectorCombine::run() {
35763643
if (IsVectorType) {
35773644
MadeChange |= scalarizeOpOrCmp(I);
35783645
MadeChange |= scalarizeLoadExtract(I);
3646+
MadeChange |= scalarizeExtExtract(I);
35793647
MadeChange |= scalarizeVPIntrinsic(I);
35803648
MadeChange |= foldInterleaveIntrinsics(I);
35813649
}

llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,25 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) {
99
; CHECK-LABEL: define void @zext_v4i8_all_lanes_used(
1010
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
1111
; CHECK-NEXT: [[ENTRY:.*:]]
12+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
13+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
14+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
15+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
16+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
17+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
18+
; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
19+
; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
20+
; CHECK-NEXT: [[TMP8:%.*]] = lshr i32 [[TMP1]], 0
21+
; CHECK-NEXT: [[TMP9:%.*]] = and i32 [[TMP8]], 255
1222
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
1323
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
1424
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
1525
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
1626
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
17-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
18-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
19-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
20-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
27+
; CHECK-NEXT: call void @use.i32(i32 [[TMP9]])
28+
; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
29+
; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
30+
; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
2131
; CHECK-NEXT: ret void
2232
;
2333
entry:
@@ -68,13 +78,21 @@ define void @zext_v4i8_3_lanes_used_1(<4 x i8> %src) {
6878
; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_1(
6979
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
7080
; CHECK-NEXT: [[ENTRY:.*:]]
81+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
82+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
83+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
84+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
85+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16
86+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
87+
; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8
88+
; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
7189
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
7290
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
7391
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
7492
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
75-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
76-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
77-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
93+
; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
94+
; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
95+
; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
7896
; CHECK-NEXT: ret void
7997
;
8098
entry:
@@ -93,13 +111,21 @@ define void @zext_v4i8_3_lanes_used_2(<4 x i8> %src) {
93111
; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_2(
94112
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
95113
; CHECK-NEXT: [[ENTRY:.*:]]
114+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
115+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
116+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
117+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
118+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
119+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
120+
; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 0
121+
; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255
96122
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
97123
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
98124
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
99125
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
100-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
101-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
102-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
126+
; CHECK-NEXT: call void @use.i32(i32 [[TMP7]])
127+
; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
128+
; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
103129
; CHECK-NEXT: ret void
104130
;
105131
entry:
@@ -118,11 +144,17 @@ define void @zext_v4i8_2_lanes_used_1(<4 x i8> %src) {
118144
; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_1(
119145
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
120146
; CHECK-NEXT: [[ENTRY:.*:]]
147+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
148+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
149+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
150+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
151+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8
152+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
121153
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
122154
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
123155
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
124-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
125-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
156+
; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
157+
; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
126158
; CHECK-NEXT: ret void
127159
;
128160
entry:
@@ -139,11 +171,17 @@ define void @zext_v4i8_2_lanes_used_2(<4 x i8> %src) {
139171
; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_2(
140172
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
141173
; CHECK-NEXT: [[ENTRY:.*:]]
174+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
175+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
176+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16
177+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
178+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 0
179+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
142180
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
143181
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
144182
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
145-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
146-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
183+
; CHECK-NEXT: call void @use.i32(i32 [[TMP5]])
184+
; CHECK-NEXT: call void @use.i32(i32 [[TMP3]])
147185
; CHECK-NEXT: ret void
148186
;
149187
entry:
@@ -160,15 +198,24 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) {
160198
; CHECK-LABEL: define void @zext_v4i8_all_lanes_used_noundef(
161199
; CHECK-SAME: <4 x i8> noundef [[SRC:%.*]]) {
162200
; CHECK-NEXT: [[ENTRY:.*:]]
201+
; CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32
202+
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 24
203+
; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], 255
204+
; CHECK-NEXT: [[TMP3:%.*]] = lshr i32 [[TMP0]], 16
205+
; CHECK-NEXT: [[TMP4:%.*]] = and i32 [[TMP3]], 255
206+
; CHECK-NEXT: [[TMP5:%.*]] = lshr i32 [[TMP0]], 8
207+
; CHECK-NEXT: [[TMP6:%.*]] = and i32 [[TMP5]], 255
208+
; CHECK-NEXT: [[TMP7:%.*]] = lshr i32 [[TMP0]], 0
209+
; CHECK-NEXT: [[TMP8:%.*]] = and i32 [[TMP7]], 255
163210
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
164211
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0
165212
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1
166213
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2
167214
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3
168-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]])
169-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]])
170-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]])
171-
; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]])
215+
; CHECK-NEXT: call void @use.i32(i32 [[TMP8]])
216+
; CHECK-NEXT: call void @use.i32(i32 [[TMP6]])
217+
; CHECK-NEXT: call void @use.i32(i32 [[TMP4]])
218+
; CHECK-NEXT: call void @use.i32(i32 [[TMP2]])
172219
; CHECK-NEXT: ret void
173220
;
174221
entry:
@@ -221,15 +268,25 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) {
221268
; CHECK-LABEL: define void @zext_v4i16_all_lanes_used(
222269
; CHECK-SAME: <4 x i16> [[SRC:%.*]]) {
223270
; CHECK-NEXT: [[ENTRY:.*:]]
271+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i16> [[SRC]]
272+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to i64
273+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 48
274+
; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP2]], 65535
275+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 32
276+
; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 65535
277+
; CHECK-NEXT: [[TMP6:%.*]] = lshr i64 [[TMP1]], 16
278+
; CHECK-NEXT: [[TMP7:%.*]] = and i64 [[TMP6]], 65535
279+
; CHECK-NEXT: [[TMP8:%.*]] = lshr i64 [[TMP1]], 0
280+
; CHECK-NEXT: [[TMP9:%.*]] = and i64 [[TMP8]], 65535
224281
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i16> [[SRC]] to <4 x i64>
225282
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i64> [[EXT9]], i64 0
226283
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i64> [[EXT9]], i64 1
227284
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i64> [[EXT9]], i64 2
228285
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i64> [[EXT9]], i64 3
229-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]])
230-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]])
231-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_2]])
232-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_3]])
286+
; CHECK-NEXT: call void @use.i64(i64 [[TMP9]])
287+
; CHECK-NEXT: call void @use.i64(i64 [[TMP7]])
288+
; CHECK-NEXT: call void @use.i64(i64 [[TMP5]])
289+
; CHECK-NEXT: call void @use.i64(i64 [[TMP3]])
233290
; CHECK-NEXT: ret void
234291
;
235292
entry:
@@ -250,11 +307,17 @@ define void @zext_v2i32_all_lanes_used(<2 x i32> %src) {
250307
; CHECK-LABEL: define void @zext_v2i32_all_lanes_used(
251308
; CHECK-SAME: <2 x i32> [[SRC:%.*]]) {
252309
; CHECK-NEXT: [[ENTRY:.*:]]
310+
; CHECK-NEXT: [[TMP0:%.*]] = freeze <2 x i32> [[SRC]]
311+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to i64
312+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 32
313+
; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP2]], 4294967295
314+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 0
315+
; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 4294967295
253316
; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <2 x i32> [[SRC]] to <2 x i64>
254317
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <2 x i64> [[EXT9]], i64 0
255318
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <2 x i64> [[EXT9]], i64 1
256-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]])
257-
; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]])
319+
; CHECK-NEXT: call void @use.i64(i64 [[TMP5]])
320+
; CHECK-NEXT: call void @use.i64(i64 [[TMP3]])
258321
; CHECK-NEXT: ret void
259322
;
260323
entry:

0 commit comments

Comments
 (0)