Skip to content

Commit b184ba5

Browse files
committed
[VectorCombine] New folding pattern for extract/binop/shuffle chains
Resolves #144654. Part of #143088. This adds a new `foldShuffleChainsToReduce` for horizontal reduction of patterns like: ```llvm define i16 @test_reduce_v8i16(<8 x i16> %a0) local_unnamed_addr #0 { %1 = shufflevector <8 x i16> %a0, <8 x i16> poison, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison> %2 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %a0, <8 x i16> %1) %3 = shufflevector <8 x i16> %2, <8 x i16> poison, <8 x i32> <i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison> %4 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %2, <8 x i16> %3) %5 = shufflevector <8 x i16> %4, <8 x i16> poison, <8 x i32> <i32 1, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison> %6 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %4, <8 x i16> %5) %7 = extractelement <8 x i16> %6, i64 0 ret i16 %7 } ``` ...which can be reduced to a llvm.vector.reduce.umin.v8i16(%a0) intrinsic call. The same transformation is applied for other supported ops when costs permit.
1 parent d4826cd commit b184ba5

File tree

2 files changed

+187
-0
lines changed

2 files changed

+187
-0
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class VectorCombine {
129129
bool foldShuffleOfIntrinsics(Instruction &I);
130130
bool foldShuffleToIdentity(Instruction &I);
131131
bool foldShuffleFromReductions(Instruction &I);
132+
bool foldShuffleChainsToReduce(Instruction &I);
132133
bool foldCastFromReductions(Instruction &I);
133134
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
134135
bool foldInterleaveIntrinsics(Instruction &I);
@@ -2910,6 +2911,171 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
29102911
return foldSelectShuffle(*Shuffle, true);
29112912
}
29122913

2914+
/// Try to fold a chain of (shufflevector, min/max intrinsic) pairs that ends
/// in an extractelement of lane 0 into a single llvm.vector.reduce.* call.
///
/// Matches the classic log2-step horizontal reduction, e.g. for <8 x i16>:
///   %1 = shufflevector %a0, poison, <4,5,6,7, poison x4>
///   %2 = call @llvm.umin(%a0, %1)
///   %3 = shufflevector %2, poison, <2,3, poison x6>
///   %4 = call @llvm.umin(%2, %3)
///   %5 = shufflevector %4, poison, <1, poison x7>
///   %6 = call @llvm.umin(%4, %5)
///   %7 = extractelement %6, i64 0
/// and rewrites %7 as: call i16 @llvm.vector.reduce.umin.v8i16(%a0).
///
/// The walk starts at the extractelement and follows the chain upwards,
/// alternating between intrinsic calls and shufflevectors, validating at each
/// step that the shuffle mask moves the upper live half of the lanes down
/// (mask halving), until the mask width reaches the full vector size.
///
/// TODO(review): no TTI cost comparison is performed before rewriting —
/// confirm the reduction intrinsic is never more expensive than the original
/// chain on targets of interest.
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
  // The pattern is rooted at an extractelement of lane 0.
  auto *EEI = dyn_cast<ExtractElementInst>(&I);
  if (!EEI)
    return false;

  std::queue<Value *> InstWorklist;
  Value *InitEEV = nullptr;
  // Intrinsic shared by every call in the chain (umin/umax/smin/smax).
  Intrinsic::ID CommonOp = 0;

  bool IsFirstEEInst = true, IsFirstCallInst = true;
  // Alternation flag: after the extract, we expect call, shuffle, call, ...
  bool ShouldBeCallInst = true;

  // PrevVecV[0]/[1]: operands of the most recent call (vector / shuffle);
  // PrevVecV[2]: vector operand of the root extractelement.
  SmallVector<Value *, 3> PrevVecV(3, nullptr);
  // Number of live lanes a shuffle is expected to move down at each step;
  // doubles every step and must reach exactly VecSize.
  int64_t ShuffleMaskHalf = -1, ExpectedShuffleMaskHalf = 1;
  int64_t VecSize = -1;

  InstWorklist.push(EEI);

  while (!InstWorklist.empty()) {
    Value *V = InstWorklist.front();
    InstWorklist.pop();

    auto *CI = dyn_cast<Instruction>(V);
    if (!CI)
      return false;

    if (auto *EEInst = dyn_cast<ExtractElementInst>(CI)) {
      // Exactly one extractelement is allowed: the chain root.
      if (!IsFirstEEInst)
        return false;
      IsFirstEEInst = false;

      auto *VecOp = EEInst->getVectorOperand();
      if (!VecOp)
        return false;

      auto *FVT = dyn_cast<FixedVectorType>(VecOp->getType());
      if (!FVT)
        return false;

      VecSize = FVT->getNumElements();
      if (VecSize < 2 || VecSize % 2 != 0)
        return false;

      auto *IndexOp = EEInst->getIndexOperand();
      if (!IndexOp)
        return false;

      // The reduction result lives in lane 0 only. dyn_cast returns null for
      // a non-constant index, so check before dereferencing (previously this
      // dereferenced a possibly-null pointer).
      auto *ConstIndex = dyn_cast<ConstantInt>(IndexOp);
      if (!ConstIndex || ConstIndex->getValue() != 0)
        return false;

      ShuffleMaskHalf = 1;
      PrevVecV[2] = VecOp;
      InitEEV = EEInst;
      InstWorklist.push(PrevVecV[2]);
    } else if (auto *CallI = dyn_cast<CallInst>(CI)) {
      if (IsFirstEEInst || !ShouldBeCallInst || !PrevVecV[2])
        return false;

      if (!IsFirstCallInst &&
          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
        return false;

      // Each call must be the value the previous step fed from: the extract's
      // vector operand for the first call, the prior call's vector operand
      // otherwise.
      if (CallI != (IsFirstCallInst ? PrevVecV[2] : PrevVecV[0]))
        return false;
      IsFirstCallInst = false;

      auto *II = dyn_cast<IntrinsicInst>(CallI);
      if (!II)
        return false;

      // All calls in the chain must use the same intrinsic.
      if (!CommonOp)
        CommonOp = II->getIntrinsicID();
      if (II->getIntrinsicID() != CommonOp)
        return false;

      switch (II->getIntrinsicID()) {
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::smin:
      case Intrinsic::smax: {
        auto *Op0 = CallI->getOperand(0);
        auto *Op1 = CallI->getOperand(1);
        PrevVecV[0] = Op0;
        PrevVecV[1] = Op1;
        break;
      }
      default:
        return false;
      }
      ShouldBeCallInst ^= 1;

      // Canonicalize so PrevVecV[1] is the shuffle and PrevVecV[0] is the
      // value it shuffles (the ops are commutative, so order is arbitrary).
      if (!isa<ShuffleVectorInst>(PrevVecV[1]))
        std::swap(PrevVecV[0], PrevVecV[1]);
      InstWorklist.push(PrevVecV[1]);
      InstWorklist.push(PrevVecV[0]);
    } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
      if (IsFirstEEInst || ShouldBeCallInst ||
          any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
        return false;

      if (SVInst != PrevVecV[1])
        return false;

      // The shuffle must read from the same vector the call reduces with.
      auto *ShuffleVec = SVInst->getOperand(0);
      if (!ShuffleVec || ShuffleVec != PrevVecV[0])
        return false;

      SmallVector<int> CurMask;
      SVInst->getShuffleMask(CurMask);

      if (ShuffleMaskHalf != ExpectedShuffleMaskHalf)
        return false;
      ExpectedShuffleMaskHalf *= 2;

      // Mask must move the next ShuffleMaskHalf lanes down to the bottom and
      // leave everything above them poison (-1).
      for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
        if (Mask < ShuffleMaskHalf && CurMask[Mask] != ShuffleMaskHalf + Mask)
          return false;
        if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
          return false;
      }
      ShuffleMaskHalf *= 2;
      // Widths doubling up to VecSize implies the chain covered every lane;
      // stop walking once the full vector is accounted for.
      if (ExpectedShuffleMaskHalf == VecSize)
        break;
      ShouldBeCallInst ^= 1;
    } else {
      return false;
    }
  }

  // Must have seen the root extract and ended on a completed call/shuffle pair.
  if (IsFirstEEInst || ShouldBeCallInst)
    return false;

  assert(VecSize != -1 && ExpectedShuffleMaskHalf == VecSize &&
         "Expected Match for Vector Size and Mask Half");

  // The source vector of the outermost (first-in-IR) shuffle is the value
  // being reduced.
  Value *FinalVecV = PrevVecV[0];
  if (!InitEEV || !FinalVecV)
    return false;

  Intrinsic::ID ReducedOp = 0;
  switch (CommonOp) {
  case Intrinsic::umin:
    ReducedOp = Intrinsic::vector_reduce_umin;
    break;
  case Intrinsic::umax:
    ReducedOp = Intrinsic::vector_reduce_umax;
    break;
  case Intrinsic::smin:
    ReducedOp = Intrinsic::vector_reduce_smin;
    break;
  case Intrinsic::smax:
    ReducedOp = Intrinsic::vector_reduce_smax;
    break;
  default:
    return false;
  }

  auto *ReducedResult =
      Builder.CreateIntrinsic(ReducedOp, {FinalVecV->getType()}, {FinalVecV});
  replaceValue(*InitEEV, *ReducedResult);

  return true;
}
3078+
29133079
/// Determine if its more efficient to fold:
29143080
/// reduce(trunc(x)) -> trunc(reduce(x)).
29153081
/// reduce(sext(x)) -> sext(reduce(x)).
@@ -3621,6 +3787,9 @@ bool VectorCombine::run() {
36213787
MadeChange |= foldShuffleFromReductions(I);
36223788
MadeChange |= foldCastFromReductions(I);
36233789
break;
3790+
case Instruction::ExtractElement:
3791+
MadeChange |= foldShuffleChainsToReduce(I);
3792+
break;
36243793
case Instruction::ICmp:
36253794
case Instruction::FCmp:
36263795
MadeChange |= foldExtractExtract(I);
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt < %s -passes=vector-combine -S | FileCheck %s

; Check that a log2-step umin reduction chain (shufflevector + llvm.umin pairs
; ending in an extractelement of lane 0) is folded by VectorCombine's
; foldShuffleChainsToReduce into a single llvm.vector.reduce.umin call.
define i16 @test_reduce_v8i16(<8 x i16> %a0) local_unnamed_addr #0 {
; CHECK-LABEL: define i16 @test_reduce_v8i16(
; CHECK-SAME: <8 x i16> [[A0:%.*]]) local_unnamed_addr {
; CHECK-NEXT:    [[TMP1:%.*]] = call i16 @llvm.vector.reduce.umin.v8i16(<8 x i16> [[A0]])
; CHECK-NEXT:    ret i16 [[TMP1]]
;
  ; Step 1: fold upper half (lanes 4-7) into lanes 0-3.
  %1 = shufflevector <8 x i16> %a0, <8 x i16> poison, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison>
  %2 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %a0, <8 x i16> %1)
  ; Step 2: fold lanes 2-3 into lanes 0-1.
  %3 = shufflevector <8 x i16> %2, <8 x i16> poison, <8 x i32> <i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
  %4 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %2, <8 x i16> %3)
  ; Step 3: fold lane 1 into lane 0, then extract the scalar result.
  %5 = shufflevector <8 x i16> %4, <8 x i16> poison, <8 x i32> <i32 1, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
  %6 = tail call <8 x i16> @llvm.umin.v8i16(<8 x i16> %4, <8 x i16> %5)
  %7 = extractelement <8 x i16> %6, i64 0
  ret i16 %7
}

0 commit comments

Comments
 (0)