Skip to content

Commit 82a7975

Browse files
ppogotov authored and igcbot committed
Support zero incoming values in MergeScalarPhisPass.
Add support to MergeScalarPhisPass for vectorizing phi instructions with constant zero incoming values.
1 parent 6cdf4a7 commit 82a7975

File tree

4 files changed

+509
-59
lines changed

4 files changed

+509
-59
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/MergeScalarPhisPass/MergeScalarPhisPass.cpp

Lines changed: 196 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ SPDX-License-Identifier: MIT
1111
#include <cstddef>
1212

1313
#include "common/LLVMWarningsPush.hpp"
14+
#include <llvm/ADT/SmallVector.h>
1415
#include <llvm/IR/BasicBlock.h>
1516
#include <llvm/IR/Instructions.h>
1617
#include <llvm/Support/raw_ostream.h>
@@ -90,16 +91,74 @@ MergeScalarPhisPass::MergeScalarPhisPass() : FunctionPass(ID) {
9091
initializeMergeScalarPhisPassPass(*PassRegistry::getPassRegistry());
9192
}
9293

93-
void MergeScalarPhisPass::cleanUpIR() {
94+
/// Collects the PHI nodes that follow \p PN in its basic block and are exact
/// duplicates of it.
///
/// Two PHI nodes are treated as duplicates when they have the same number of
/// incoming values and, for every incoming slot, both the incoming value and
/// the incoming block are the very same objects (pointer equality).
///
/// \param PN               PHI node to find duplicates of.
/// \param PhiNodesToErase  PHI nodes already scheduled for deletion; they are
///                         never reported as duplicates. Only read here.
/// \return The duplicates of \p PN, in the order they appear in the block.
SmallVector<PHINode*, 8> MergeScalarPhisPass::getDuplicates(PHINode *PN, SmallVector<PHINode *, 8>& PhiNodesToErase) {
  SmallVector<PHINode *, 8> DuplicatePhis;
  BasicBlock *BB = PN->getParent();
  const unsigned NumIncoming = PN->getNumIncomingValues();

  // The scan starts at the instruction right after PN, so PN itself is never
  // visited and no self-comparison guard is required.
  for (auto I = std::next(BasicBlock::iterator(PN)), E = BB->end(); I != E; ++I) {
    // Stop at the first instruction that is not a PHI node.
    auto *OtherPN = dyn_cast<PHINode>(&*I);
    if (!OtherPN)
      break;

    // Skip PHI nodes that are already marked for deletion.
    if (std::find(PhiNodesToErase.begin(), PhiNodesToErase.end(), OtherPN) != PhiNodesToErase.end())
      continue;

    // A different number of incoming values can never be a duplicate.
    if (OtherPN->getNumIncomingValues() != NumIncoming)
      continue;

    // Compare every (value, block) pair slot by slot.
    bool AreIdentical = true;
    for (unsigned K = 0; K < NumIncoming; ++K) {
      if (PN->getIncomingValue(K) != OtherPN->getIncomingValue(K) ||
          PN->getIncomingBlock(K) != OtherPN->getIncomingBlock(K)) {
        AreIdentical = false;
        break;
      }
    }

    if (AreIdentical)
      DuplicatePhis.push_back(OtherPN);
  }

  return DuplicatePhis;
}
132+
133+
/// Removes the instructions made obsolete by the merge and then deduplicates
/// PHI nodes in every basic block of \p F.
///
/// Steps:
///  1. Erase every PHI node recorded in PhiNodesToRemove.
///  2. Erase every recorded ExtractElementInst that has no uses left.
///  3. For each PHI node, redirect the uses of its duplicates (as reported by
///     getDuplicates) to it and erase the duplicates afterwards.
///
/// \param F Function whose basic blocks are scanned for duplicate PHI nodes.
void MergeScalarPhisPass::cleanUpIR(Function *F) {
  for (auto *Phi : PhiNodesToRemove)
    Phi->eraseFromParent();

  for (auto *EEI : ExtrElementsToRemove)
    if (EEI->use_empty())
      EEI->eraseFromParent();

  SmallVector<PHINode *, 8> PhiNodesToErase;
  for (auto &BB : *F) {
    // Nothing is erased while iterating, so a plain forward walk is safe.
    for (auto I = BB.begin(), E = BB.end(); I != E; ++I) {
      // Stop at the first instruction that is not a PHI node.
      auto *PN = dyn_cast<PHINode>(&*I);
      if (!PN)
        break;

      // A PHI node that is itself queued for deletion must not become the
      // replacement for another one: replaceAllUsesWith() calls made earlier
      // can rewrite operands so that a later PHI becomes identical to a queued
      // node, and redirecting its uses there would leave uses on an
      // instruction that is about to be erased.
      if (std::find(PhiNodesToErase.begin(), PhiNodesToErase.end(), PN) != PhiNodesToErase.end())
        continue;

      for (PHINode *D : getDuplicates(PN, PhiNodesToErase)) {
        D->replaceAllUsesWith(PN);
        PhiNodesToErase.push_back(D);
      }
    }
  }

  // Deletion is deferred until all blocks were scanned so that the iteration
  // above never touches an erased instruction.
  for (auto *PN : PhiNodesToErase)
    PN->eraseFromParent();
}
101160

102-
bool MergeScalarPhisPass::makeChanges() {
161+
bool MergeScalarPhisPass::makeChanges(Function *F) {
103162
bool Changed = VectorToPhiNodesMap.size() > 0;
104163

105164
for (const auto &Entry : VectorToPhiNodesMap) {
@@ -115,52 +174,79 @@ bool MergeScalarPhisPass::makeChanges() {
115174

116175
for (unsigned i = 0; i < NumIncValues; ++i) {
117176
Value *Incoming = FirstPN->getIncomingValue(i);
177+
if (isa<Constant>(Incoming) && cast<Constant>(Incoming)->isZeroValue()) {
178+
NewPhi->addIncoming(ConstantAggregateZero::get(VectorType), FirstPN->getIncomingBlock(i));
179+
continue;
180+
}
118181
auto *EEI = cast<ExtractElementInst>(Incoming);
119182
NewPhi->addIncoming(EEI->getVectorOperand(), FirstPN->getIncomingBlock(i));
120183
}
121184

122185
BasicBlock *BB = FirstPN->getParent();
123186
Builder.SetInsertPoint(BB->getFirstNonPHI());
124187
for (auto *PN : Entry.second) {
125-
auto *EEI = cast<ExtractElementInst>(PN->getIncomingValue(0));
188+
// Find an incoming ExtractElementInst to get the index value.
189+
ExtractElementInst *EEI = nullptr;
190+
for (unsigned i = 0; i < PN->getNumIncomingValues(); ++i) {
191+
EEI = dyn_cast<ExtractElementInst>(PN->getIncomingValue(i));
192+
if (EEI)
193+
break;
194+
}
195+
126196
auto *CI = cast<ConstantInt>(EEI->getIndexOperand());
127197
auto *NewEEI =
128198
cast<ExtractElementInst>(Builder.CreateExtractElement(NewPhi, CI->getZExtValue(), "extract_merged"));
199+
129200
if (EEI->getDebugLoc())
130201
NewEEI->setDebugLoc(EEI->getDebugLoc());
131202

132203
PN->replaceAllUsesWith(NewEEI);
133204
}
134205
}
135206

136-
cleanUpIR();
207+
cleanUpIR(F);
137208

138209
return Changed;
139210
}
140211

212+
/// Tells whether the incoming value of \p pPN in slot \p IncomingIndex is a
/// Constant for which isZeroValue() holds.
///
/// \param pPN            PHI node to inspect.
/// \param IncomingIndex  Index of the incoming value to test.
/// \return true if the incoming value is a zero constant, false otherwise
///         (including when it is not a Constant at all).
bool MergeScalarPhisPass::isIncomingValueZero(PHINode *pPN, unsigned IncomingIndex) {
  if (auto *ConstVal = dyn_cast<Constant>(pPN->getIncomingValue(IncomingIndex)))
    return ConstVal->isZeroValue();

  return false;
}
216+
217+
/// Returns the vector that the incoming value of \p PN in slot
/// \p IncomingIndex is extracted from.
///
/// \param PN             PHI node to inspect.
/// \param IncomingIndex  Index of the incoming value to look at.
/// \return The vector operand of the incoming ExtractElementInst, or nullptr
///         when the incoming value is not an ExtractElementInst.
Value *MergeScalarPhisPass::getVectorOperandForPhiNode(PHINode *PN, unsigned IncomingIndex) {
  if (auto *Extract = dyn_cast<ExtractElementInst>(PN->getIncomingValue(IncomingIndex)))
    return Extract->getVectorOperand();

  return nullptr;
}
222+
223+
/// Returns the incoming value of \p PN in slot \p IncomingIndex viewed as an
/// ExtractElementInst.
///
/// \param PN             PHI node to inspect.
/// \param IncomingIndex  Index of the incoming value to look at.
/// \return The incoming ExtractElementInst, or nullptr when the incoming value
///         is of any other kind.
ExtractElementInst *MergeScalarPhisPass::getEEIFromPhi(PHINode *PN, unsigned IncomingIndex) {
  return dyn_cast<ExtractElementInst>(PN->getIncomingValue(IncomingIndex));
}
227+
141228
// Collect all PHI nodes to VectorToPhiNodesMap. Use incoming vector value
142229
// from incoming index 0 as a key for the group of PHI nodes. After collecting
143230
// all PHI nodes, filter out phi nodes using conditions for incoming values with
144231
// indices not equal to 0.
145232
//
146233
// Conditions for PHI nodes to be merged into a single vector PHI node:
147-
// Condition 1: All PHI node incoming values should be ExtractElementInsts (EEIs).
234+
// Condition 1: At least one PHI node incoming value should be an ExtractElementInst (EEI); the others may be zeros.
148235
// Condition 2: EEIs vector operands should have FixedVectorType.
149236
// Condition 3: EEIs should have the same vector type.
150237
// Condition 4: EEIs should have the same constant index value.
151-
// Condition 5: All incoming EEIs should be used only once.
152-
// Condition 6: Number of PHI nodes in a group should be equal to the vector
153-
// size.
238+
// Condition 6: Number of PHI nodes in a group should be equal to the vector size.
154239
void MergeScalarPhisPass::collectPhiNodes(Function &F) {
155-
auto getVectorOperandForPhiNode = [](PHINode *PN, unsigned IncomingIndex) -> Value * {
156-
Value *IncVal = PN->getIncomingValue(IncomingIndex);
157-
auto *EEI = cast<ExtractElementInst>(IncVal);
158-
return EEI->getVectorOperand();
159-
};
160-
161-
auto getEEIFromPhi = [](PHINode *PN, unsigned IncomingIndex) -> ExtractElementInst * {
162-
Value *IncVal = PN->getIncomingValue(IncomingIndex);
163-
return dyn_cast<ExtractElementInst>(IncVal);
240+
auto getFirstVectorIncomingValForPhiNode = [](PHINode *PN) -> Value * {
241+
for (unsigned i = 0; i < PN->getNumIncomingValues(); ++i) {
242+
Value *IncVal = PN->getIncomingValue(i);
243+
if (isa<Constant>(IncVal) && cast<Constant>(IncVal)->isZeroValue())
244+
continue;
245+
246+
auto *EEI = cast<ExtractElementInst>(IncVal);
247+
return EEI->getVectorOperand();
248+
}
249+
return nullptr;
164250
};
165251

166252
clearContainers();
@@ -176,12 +262,10 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
176262
bool Check = true;
177263
Type *SavedType = nullptr;
178264
int IndexValue = 0;
179-
180-
// Skip PHI nodes with less than 2 incoming values.
181-
// if (PN->getNumIncomingValues() < 2)
182-
// continue;
183-
184265
for (unsigned i = 0; i < PN->getNumIncomingValues(); ++i) {
266+
if (isIncomingValueZero(PN, i))
267+
continue;
268+
185269
// Check Condition 1
186270
ExtractElementInst *EEI = getEEIFromPhi(PN, i);
187271
if (!EEI) {
@@ -197,7 +281,7 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
197281

198282
auto *CurrType = EEI->getVectorOperand()->getType();
199283
int CurrIndexValue = CI->getZExtValue();
200-
if (i == 0) {
284+
if (!SavedType) {
201285
SavedType = CurrType;
202286
IndexValue = CurrIndexValue;
203287
}
@@ -218,41 +302,55 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
218302
Check = false;
219303
break;
220304
}
221-
222-
// Check Condition 5
223-
if (!EEI->getSingleUndroppableUse()) {
224-
Check = false;
225-
break;
226-
}
227305
}
228306

229307
if (Check) {
230308
// Using vector operand of the first incoming EEI to define a key for
231309
// the group of PHI nodes.
232-
Value *FirstEEIVectorOp = getVectorOperandForPhiNode(PN, 0);
233-
auto CurNumIncValues = PN->getNumIncomingValues();
310+
Value *FirstVectorOp = getFirstVectorIncomingValForPhiNode(PN);
311+
unsigned CurNumIncValues = PN->getNumIncomingValues();
234312

235313
// All PHI nodes corresponding to the same vector value should be in one
236314
// basic block.
237-
if (VectorToPhiNodesMap.find(FirstEEIVectorOp) != VectorToPhiNodesMap.end()) {
238-
if (VectorToPhiNodesMap[FirstEEIVectorOp][0]->getParent() != PN->getParent()) {
315+
if (VectorToPhiNodesMap.find(FirstVectorOp) != VectorToPhiNodesMap.end()) {
316+
if (VectorToPhiNodesMap[FirstVectorOp][0]->getParent() != PN->getParent())
239317
continue;
240-
}
241318

242319
// All PN in the group should have the same number of incoming values.
243-
if (VectorToPhiNodesMap[FirstEEIVectorOp][0]->getNumIncomingValues() != CurNumIncValues) {
320+
if (VectorToPhiNodesMap[FirstVectorOp][0]->getNumIncomingValues() != CurNumIncValues)
244321
continue;
245-
}
246322
}
247-
248-
VectorToPhiNodesMap[FirstEEIVectorOp].push_back(PN);
323+
VectorToPhiNodesMap[FirstVectorOp].push_back(PN);
249324
}
250325
}
251326
}
252327

328+
filterOutUnexpectedIncomingConstants();
329+
253330
// Filter out PHI nodes that do not meet condition 6.
254331
// Filter out some suspicious cases (e.g. when the EEIs for a particular PHI
255332
// node group and a particular index are not in the same basic block).
333+
filterOutUnvectorizedPhis();
334+
335+
336+
for (auto &Entry : VectorToPhiNodesMap) {
337+
for (auto *PN : Entry.second) {
338+
PhiNodesToRemove.insert(PN);
339+
for (unsigned I = 0; I < PN->getNumIncomingValues(); ++I) {
340+
ExtractElementInst *EEI = getEEIFromPhi(PN, I);
341+
// Skip zeros incoming values.
342+
if (EEI && EEI->hasOneUser()) {
343+
ExtrElementsToRemove.insert(EEI);
344+
}
345+
}
346+
}
347+
}
348+
}
349+
350+
// Check that if at least one phi node incoming value is zero for a
351+
// specific index, then all other incoming values for that index should
352+
// be zeros as well.
353+
void MergeScalarPhisPass::filterOutUnexpectedIncomingConstants() {
256354
for (auto It = VectorToPhiNodesMap.begin(); It != VectorToPhiNodesMap.end();) {
257355
auto &PhiNodes = It->second;
258356
auto *FirstPhiNode = PhiNodes[0];
@@ -263,8 +361,47 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
263361
continue;
264362
}
265363

266-
// Check Condition 6
267-
Type *VType = getVectorOperandForPhiNode(FirstPhiNode, 0)->getType();
364+
bool NeedToBreak = false;
365+
size_t NumIncomingValues = FirstPhiNode->getNumIncomingValues();
366+
for (unsigned Index = 0; Index < NumIncomingValues; ++Index) {
367+
bool ExpectZeroVal = false;
368+
for (size_t P = 0; P < PhiNodes.size(); ++P) {
369+
if (isIncomingValueZero(PhiNodes[P], Index)) {
370+
if (!P) {
371+
ExpectZeroVal = true;
372+
continue;
373+
}
374+
375+
if (!ExpectZeroVal) {
376+
It = VectorToPhiNodesMap.erase(It);
377+
NeedToBreak = true;
378+
break;
379+
}
380+
} else {
381+
if (ExpectZeroVal) {
382+
It = VectorToPhiNodesMap.erase(It);
383+
NeedToBreak = true;
384+
break;
385+
}
386+
}
387+
}
388+
389+
if (NeedToBreak)
390+
break;
391+
}
392+
393+
if (!NeedToBreak)
394+
++It;
395+
}
396+
}
397+
398+
void MergeScalarPhisPass::filterOutUnvectorizedPhis() {
399+
for (auto It = VectorToPhiNodesMap.begin(); It != VectorToPhiNodesMap.end();) {
400+
Type *VType = It->first->getType();
401+
auto &PhiNodes = It->second;
402+
auto *FirstPhiNode = PhiNodes[0];
403+
404+
// Check Condition 6: Number of PHI nodes in a group should be equal to the vector size.
268405
size_t NumElements = cast<VectorType>(VType)->getElementCount().getFixedValue();
269406
if (NumElements != PhiNodes.size()) {
270407
It = VectorToPhiNodesMap.erase(It);
@@ -273,14 +410,25 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
273410

274411
bool NeedToBreak = false;
275412
size_t NumIncomingValues = FirstPhiNode->getNumIncomingValues();
276-
277413
for (unsigned Index = 0; Index < NumIncomingValues; ++Index) {
278414
Value *VOp = getVectorOperandForPhiNode(FirstPhiNode, Index);
279-
BasicBlock *EEIBB = getEEIFromPhi(FirstPhiNode, Index)->getParent();
415+
ExtractElementInst *EEI = getEEIFromPhi(FirstPhiNode, Index);
416+
417+
if (!VOp || !EEI)
418+
continue;
280419

420+
BasicBlock *EEIBB = EEI->getParent();
281421
for (size_t P = 0; P < PhiNodes.size(); ++P) {
282422
Value *CurrVOp = getVectorOperandForPhiNode(PhiNodes[P], Index);
283-
BasicBlock *CurrEEIBB = getEEIFromPhi(PhiNodes[P], Index)->getParent();
423+
ExtractElementInst *CurrEEI = getEEIFromPhi(PhiNodes[P], Index);
424+
425+
if (!CurrVOp || !CurrEEI) {
426+
It = VectorToPhiNodesMap.erase(It);
427+
NeedToBreak = true;
428+
break;
429+
}
430+
431+
BasicBlock *CurrEEIBB = CurrEEI->getParent();
284432

285433
// Check that all incoming values for a specific index in the PHI nodes
286434
// group were extracted from the same vector.
@@ -297,20 +445,13 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
297445
break;
298446
}
299447
}
300-
}
301448

302-
if (!NeedToBreak) {
303-
++It;
449+
if (NeedToBreak)
450+
break;
304451
}
305-
}
306452

307-
for (auto &Entry : VectorToPhiNodesMap) {
308-
for (auto *PN : Entry.second) {
309-
PhiNodesToRemove.insert(PN);
310-
for (unsigned I = 0; I < PN->getNumIncomingValues(); ++I) {
311-
ExtrElementsToRemove.insert(getEEIFromPhi(PN, I));
312-
}
313-
}
453+
if (!NeedToBreak)
454+
++It;
314455
}
315456
}
316457

@@ -329,7 +470,7 @@ bool MergeScalarPhisPass::runOnFunction(Function &F) {
329470

330471
// Optimize the function until optimization patterns can be found.
331472
while (VectorToPhiNodesMap.size()) {
332-
Changed |= makeChanges();
473+
Changed |= makeChanges(&F);
333474
collectPhiNodes(F);
334475
}
335476

0 commit comments

Comments
 (0)