@@ -11,6 +11,7 @@ SPDX-License-Identifier: MIT
11
11
#include < cstddef>
12
12
13
13
#include " common/LLVMWarningsPush.hpp"
14
+ #include < llvm/ADT/SmallVector.h>
14
15
#include < llvm/IR/BasicBlock.h>
15
16
#include < llvm/IR/Instructions.h>
16
17
#include < llvm/Support/raw_ostream.h>
@@ -90,16 +91,74 @@ MergeScalarPhisPass::MergeScalarPhisPass() : FunctionPass(ID) {
90
91
initializeMergeScalarPhisPassPass (*PassRegistry::getPassRegistry ());
91
92
}
92
93
93
- void MergeScalarPhisPass::cleanUpIR () {
94
+ SmallVector<PHINode*, 8 > MergeScalarPhisPass::getDuplicates (PHINode *PN, SmallVector<PHINode *, 8 >& PhiNodesToErase) {
95
+ BasicBlock *BB = PN->getParent ();
96
+ SmallVector<PHINode*, 8 > DuplicatePhis;
97
+
98
+ // Check for identical PHI nodes in the same basic block
99
+ for (auto I = std::next (BasicBlock::iterator (PN)), E = BB->end (); I != E; ++I) {
100
+ if (!isa<PHINode>(&*I))
101
+ break ;
102
+
103
+ PHINode *OtherPN = cast<PHINode>(&*I);
104
+
105
+ // Skip comparing with itself
106
+ if (PN == OtherPN)
107
+ continue ;
108
+
109
+ // Skip PHI nodes that are already marked for deletion
110
+ if (std::find (PhiNodesToErase.begin (), PhiNodesToErase.end (), OtherPN) != PhiNodesToErase.end ())
111
+ continue ;
112
+
113
+ // Check if PHI nodes are identical
114
+ if (PN->getNumIncomingValues () == OtherPN->getNumIncomingValues ()) {
115
+ bool AreIdentical = true ;
116
+
117
+ for (unsigned K = 0 ; K < PN->getNumIncomingValues (); ++K) {
118
+ if (PN->getIncomingValue (K) != OtherPN->getIncomingValue (K) ||
119
+ PN->getIncomingBlock (K) != OtherPN->getIncomingBlock (K)) {
120
+ AreIdentical = false ;
121
+ break ;
122
+ }
123
+ }
124
+
125
+ if (AreIdentical)
126
+ DuplicatePhis.push_back (OtherPN);
127
+ }
128
+ }
129
+
130
+ return DuplicatePhis;
131
+ }
132
+
133
+ void MergeScalarPhisPass::cleanUpIR (Function *F) {
94
134
for (auto *Phi : PhiNodesToRemove)
95
135
Phi->eraseFromParent ();
96
136
97
137
for (auto *EEI : ExtrElementsToRemove)
98
138
if (EEI->getNumUses () == 0 )
99
139
EEI->eraseFromParent ();
140
+
141
+ SmallVector<PHINode *, 8 > PhiNodesToErase;
142
+ for (auto &BB : *F) {
143
+ for (auto I = BB.begin (), E = BB.end (); I != E;) {
144
+ if (!isa<PHINode>(I))
145
+ break ;
146
+
147
+ PHINode *PN = cast<PHINode>(&*I++);
148
+ SmallVector<PHINode*, 8 > Duplicates = getDuplicates (PN, PhiNodesToErase);
149
+
150
+ for (PHINode *D : Duplicates) {
151
+ D->replaceAllUsesWith (PN);
152
+ PhiNodesToErase.push_back (D);
153
+ }
154
+ }
155
+ }
156
+
157
+ for (auto *PN : PhiNodesToErase)
158
+ PN->eraseFromParent ();
100
159
}
101
160
102
- bool MergeScalarPhisPass::makeChanges () {
161
+ bool MergeScalarPhisPass::makeChanges (Function *F ) {
103
162
bool Changed = VectorToPhiNodesMap.size () > 0 ;
104
163
105
164
for (const auto &Entry : VectorToPhiNodesMap) {
@@ -115,52 +174,79 @@ bool MergeScalarPhisPass::makeChanges() {
115
174
116
175
for (unsigned i = 0 ; i < NumIncValues; ++i) {
117
176
Value *Incoming = FirstPN->getIncomingValue (i);
177
+ if (isa<Constant>(Incoming) && cast<Constant>(Incoming)->isZeroValue ()) {
178
+ NewPhi->addIncoming (ConstantAggregateZero::get (VectorType), FirstPN->getIncomingBlock (i));
179
+ continue ;
180
+ }
118
181
auto *EEI = cast<ExtractElementInst>(Incoming);
119
182
NewPhi->addIncoming (EEI->getVectorOperand (), FirstPN->getIncomingBlock (i));
120
183
}
121
184
122
185
BasicBlock *BB = FirstPN->getParent ();
123
186
Builder.SetInsertPoint (BB->getFirstNonPHI ());
124
187
for (auto *PN : Entry.second ) {
125
- auto *EEI = cast<ExtractElementInst>(PN->getIncomingValue (0 ));
188
+ // Find an incoming ExtractElementInst to get the index value.
189
+ ExtractElementInst *EEI = nullptr ;
190
+ for (unsigned i = 0 ; i < PN->getNumIncomingValues (); ++i) {
191
+ EEI = dyn_cast<ExtractElementInst>(PN->getIncomingValue (i));
192
+ if (EEI)
193
+ break ;
194
+ }
195
+
126
196
auto *CI = cast<ConstantInt>(EEI->getIndexOperand ());
127
197
auto *NewEEI =
128
198
cast<ExtractElementInst>(Builder.CreateExtractElement (NewPhi, CI->getZExtValue (), " extract_merged" ));
199
+
129
200
if (EEI->getDebugLoc ())
130
201
NewEEI->setDebugLoc (EEI->getDebugLoc ());
131
202
132
203
PN->replaceAllUsesWith (NewEEI);
133
204
}
134
205
}
135
206
136
- cleanUpIR ();
207
+ cleanUpIR (F );
137
208
138
209
return Changed;
139
210
}
140
211
212
+ bool MergeScalarPhisPass::isIncomingValueZero (PHINode *pPN, unsigned IncomingIndex) {
213
+ Value *pIncVal = pPN->getIncomingValue (IncomingIndex);
214
+ return isa<Constant>(pIncVal) && cast<Constant>(pIncVal)->isZeroValue ();
215
+ }
216
+
217
+ Value *MergeScalarPhisPass::getVectorOperandForPhiNode (PHINode *PN, unsigned IncomingIndex) {
218
+ Value *IncVal = PN->getIncomingValue (IncomingIndex);
219
+ auto *EEI = dyn_cast<ExtractElementInst>(IncVal);
220
+ return EEI ? EEI->getVectorOperand () : nullptr ;
221
+ }
222
+
223
+ ExtractElementInst *MergeScalarPhisPass::getEEIFromPhi (PHINode *PN, unsigned IncomingIndex) {
224
+ Value *IncVal = PN->getIncomingValue (IncomingIndex);
225
+ return dyn_cast<ExtractElementInst>(IncVal);
226
+ }
227
+
141
228
// Collect all PHI nodes to VectorToPhiNodesMap. Use incoming vector value
142
229
// from incoming index 0 as a key for the group of PHI nodes. After collecting
143
230
// all PHI nodes, filter out phi nodes using conditions for incoming values with
144
231
// indices not equal to 0.
145
232
//
146
233
// Conditions for PHI nodes to be merged into a single vector PHI node:
147
- // Condition 1: All PHI node incoming values should be ExtractElementInsts (EEIs).
234
+ // Condition 1: At least one PHI node incoming value should be an ExtractElementInst (EEI); the others can be zeros.
148
235
// Condition 2: EEIs vector operands should have FixedVectorType.
149
236
// Condition 3: EEIs should have the same vector type.
150
237
// Condition 4: EEIs should have the same constant index value.
151
- // Condition 5: All incoming EEIs should be used only once.
152
- // Condition 6: Number of PHI nodes in a group should be equal to the vector
153
- // size.
238
+ // Condition 6: Number of PHI nodes in a group should be equal to the vector size.
154
239
void MergeScalarPhisPass::collectPhiNodes (Function &F) {
155
- auto getVectorOperandForPhiNode = [](PHINode *PN, unsigned IncomingIndex) -> Value * {
156
- Value *IncVal = PN->getIncomingValue (IncomingIndex);
157
- auto *EEI = cast<ExtractElementInst>(IncVal);
158
- return EEI->getVectorOperand ();
159
- };
160
-
161
- auto getEEIFromPhi = [](PHINode *PN, unsigned IncomingIndex) -> ExtractElementInst * {
162
- Value *IncVal = PN->getIncomingValue (IncomingIndex);
163
- return dyn_cast<ExtractElementInst>(IncVal);
240
+ auto getFirstVectorIncomingValForPhiNode = [](PHINode *PN) -> Value * {
241
+ for (unsigned i = 0 ; i < PN->getNumIncomingValues (); ++i) {
242
+ Value *IncVal = PN->getIncomingValue (i);
243
+ if (isa<Constant>(IncVal) && cast<Constant>(IncVal)->isZeroValue ())
244
+ continue ;
245
+
246
+ auto *EEI = cast<ExtractElementInst>(IncVal);
247
+ return EEI->getVectorOperand ();
248
+ }
249
+ return nullptr ;
164
250
};
165
251
166
252
clearContainers ();
@@ -176,12 +262,10 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
176
262
bool Check = true ;
177
263
Type *SavedType = nullptr ;
178
264
int IndexValue = 0 ;
179
-
180
- // Skip PHI nodes with less than 2 incoming values.
181
- // if (PN->getNumIncomingValues() < 2)
182
- // continue;
183
-
184
265
for (unsigned i = 0 ; i < PN->getNumIncomingValues (); ++i) {
266
+ if (isIncomingValueZero (PN, i))
267
+ continue ;
268
+
185
269
// Check Condition 1
186
270
ExtractElementInst *EEI = getEEIFromPhi (PN, i);
187
271
if (!EEI) {
@@ -197,7 +281,7 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
197
281
198
282
auto *CurrType = EEI->getVectorOperand ()->getType ();
199
283
int CurrIndexValue = CI->getZExtValue ();
200
- if (i == 0 ) {
284
+ if (!SavedType ) {
201
285
SavedType = CurrType;
202
286
IndexValue = CurrIndexValue;
203
287
}
@@ -218,41 +302,55 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
218
302
Check = false ;
219
303
break ;
220
304
}
221
-
222
- // Check Condition 5
223
- if (!EEI->getSingleUndroppableUse ()) {
224
- Check = false ;
225
- break ;
226
- }
227
305
}
228
306
229
307
if (Check) {
230
308
// Using vector operand of the first incoming EEI to define a key for
231
309
// the group of PHI nodes.
232
- Value *FirstEEIVectorOp = getVectorOperandForPhiNode (PN, 0 );
233
- auto CurNumIncValues = PN->getNumIncomingValues ();
310
+ Value *FirstVectorOp = getFirstVectorIncomingValForPhiNode (PN);
311
+ unsigned CurNumIncValues = PN->getNumIncomingValues ();
234
312
235
313
// All PHI nodes corresponding to the same vector value should be in one
236
314
// basic block.
237
- if (VectorToPhiNodesMap.find (FirstEEIVectorOp ) != VectorToPhiNodesMap.end ()) {
238
- if (VectorToPhiNodesMap[FirstEEIVectorOp ][0 ]->getParent () != PN->getParent ()) {
315
+ if (VectorToPhiNodesMap.find (FirstVectorOp ) != VectorToPhiNodesMap.end ()) {
316
+ if (VectorToPhiNodesMap[FirstVectorOp ][0 ]->getParent () != PN->getParent ())
239
317
continue ;
240
- }
241
318
242
319
// All PN in the group should have the same number of incoming values.
243
- if (VectorToPhiNodesMap[FirstEEIVectorOp ][0 ]->getNumIncomingValues () != CurNumIncValues) {
320
+ if (VectorToPhiNodesMap[FirstVectorOp ][0 ]->getNumIncomingValues () != CurNumIncValues)
244
321
continue ;
245
- }
246
322
}
247
-
248
- VectorToPhiNodesMap[FirstEEIVectorOp].push_back (PN);
323
+ VectorToPhiNodesMap[FirstVectorOp].push_back (PN);
249
324
}
250
325
}
251
326
}
252
327
328
+ filterOutUnexpectedIncomingConstants ();
329
+
253
330
// Filter out PHI nodes that do not meet condition 6.
254
331
// Filter out some suspicious cases (e.g. when the EEIs for a particular PHI
255
332
// node group and a particular index are not in the same basic block).
333
+ filterOutUnvectorizedPhis ();
334
+
335
+
336
+ for (auto &Entry : VectorToPhiNodesMap) {
337
+ for (auto *PN : Entry.second ) {
338
+ PhiNodesToRemove.insert (PN);
339
+ for (unsigned I = 0 ; I < PN->getNumIncomingValues (); ++I) {
340
+ ExtractElementInst *EEI = getEEIFromPhi (PN, I);
341
+ // Skip zeros incoming values.
342
+ if (EEI && EEI->hasOneUser ()) {
343
+ ExtrElementsToRemove.insert (EEI);
344
+ }
345
+ }
346
+ }
347
+ }
348
+ }
349
+
350
+ // Check that if at least one phi node incoming value is zero for a
351
+ // specific index, then all other incoming values for that index should
352
+ // be zeros as well.
353
+ void MergeScalarPhisPass::filterOutUnexpectedIncomingConstants () {
256
354
for (auto It = VectorToPhiNodesMap.begin (); It != VectorToPhiNodesMap.end ();) {
257
355
auto &PhiNodes = It->second ;
258
356
auto *FirstPhiNode = PhiNodes[0 ];
@@ -263,8 +361,47 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
263
361
continue ;
264
362
}
265
363
266
- // Check Condition 6
267
- Type *VType = getVectorOperandForPhiNode (FirstPhiNode, 0 )->getType ();
364
+ bool NeedToBreak = false ;
365
+ size_t NumIncomingValues = FirstPhiNode->getNumIncomingValues ();
366
+ for (unsigned Index = 0 ; Index < NumIncomingValues; ++Index) {
367
+ bool ExpectZeroVal = false ;
368
+ for (size_t P = 0 ; P < PhiNodes.size (); ++P) {
369
+ if (isIncomingValueZero (PhiNodes[P], Index)) {
370
+ if (!P) {
371
+ ExpectZeroVal = true ;
372
+ continue ;
373
+ }
374
+
375
+ if (!ExpectZeroVal) {
376
+ It = VectorToPhiNodesMap.erase (It);
377
+ NeedToBreak = true ;
378
+ break ;
379
+ }
380
+ } else {
381
+ if (ExpectZeroVal) {
382
+ It = VectorToPhiNodesMap.erase (It);
383
+ NeedToBreak = true ;
384
+ break ;
385
+ }
386
+ }
387
+ }
388
+
389
+ if (NeedToBreak)
390
+ break ;
391
+ }
392
+
393
+ if (!NeedToBreak)
394
+ ++It;
395
+ }
396
+ }
397
+
398
+ void MergeScalarPhisPass::filterOutUnvectorizedPhis () {
399
+ for (auto It = VectorToPhiNodesMap.begin (); It != VectorToPhiNodesMap.end ();) {
400
+ Type *VType = It->first ->getType ();
401
+ auto &PhiNodes = It->second ;
402
+ auto *FirstPhiNode = PhiNodes[0 ];
403
+
404
+ // Check Condition 6: Number of PHI nodes in a group should be equal to the vector size.
268
405
size_t NumElements = cast<VectorType>(VType)->getElementCount ().getFixedValue ();
269
406
if (NumElements != PhiNodes.size ()) {
270
407
It = VectorToPhiNodesMap.erase (It);
@@ -273,14 +410,25 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
273
410
274
411
bool NeedToBreak = false ;
275
412
size_t NumIncomingValues = FirstPhiNode->getNumIncomingValues ();
276
-
277
413
for (unsigned Index = 0 ; Index < NumIncomingValues; ++Index) {
278
414
Value *VOp = getVectorOperandForPhiNode (FirstPhiNode, Index);
279
- BasicBlock *EEIBB = getEEIFromPhi (FirstPhiNode, Index)->getParent ();
415
+ ExtractElementInst *EEI = getEEIFromPhi (FirstPhiNode, Index);
416
+
417
+ if (!VOp || !EEI)
418
+ continue ;
280
419
420
+ BasicBlock *EEIBB = EEI->getParent ();
281
421
for (size_t P = 0 ; P < PhiNodes.size (); ++P) {
282
422
Value *CurrVOp = getVectorOperandForPhiNode (PhiNodes[P], Index);
283
- BasicBlock *CurrEEIBB = getEEIFromPhi (PhiNodes[P], Index)->getParent ();
423
+ ExtractElementInst *CurrEEI = getEEIFromPhi (PhiNodes[P], Index);
424
+
425
+ if (!CurrVOp || !CurrEEI) {
426
+ It = VectorToPhiNodesMap.erase (It);
427
+ NeedToBreak = true ;
428
+ break ;
429
+ }
430
+
431
+ BasicBlock *CurrEEIBB = CurrEEI->getParent ();
284
432
285
433
// Check that all incoming values for a specific index in the PHI nodes
286
434
// group were extracted from the same vector.
@@ -297,20 +445,13 @@ void MergeScalarPhisPass::collectPhiNodes(Function &F) {
297
445
break ;
298
446
}
299
447
}
300
- }
301
448
302
- if (! NeedToBreak) {
303
- ++It ;
449
+ if (NeedToBreak)
450
+ break ;
304
451
}
305
- }
306
452
307
- for (auto &Entry : VectorToPhiNodesMap) {
308
- for (auto *PN : Entry.second ) {
309
- PhiNodesToRemove.insert (PN);
310
- for (unsigned I = 0 ; I < PN->getNumIncomingValues (); ++I) {
311
- ExtrElementsToRemove.insert (getEEIFromPhi (PN, I));
312
- }
313
- }
453
+ if (!NeedToBreak)
454
+ ++It;
314
455
}
315
456
}
316
457
@@ -329,7 +470,7 @@ bool MergeScalarPhisPass::runOnFunction(Function &F) {
329
470
330
471
// Optimize the function until optimization patterns can be found.
331
472
while (VectorToPhiNodesMap.size ()) {
332
- Changed |= makeChanges ();
473
+ Changed |= makeChanges (&F );
333
474
collectPhiNodes (F);
334
475
}
335
476
0 commit comments