diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp index b2fced47b9527..55ee33a178532 100644 --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -123,6 +123,7 @@ class VectorCombine { bool foldBinopOfReductions(Instruction &I); bool foldSingleElementStore(Instruction &I); bool scalarizeLoadExtract(Instruction &I); + bool scalarizeExtExtract(Instruction &I); bool foldConcatOfBoolMasks(Instruction &I); bool foldPermuteOfBinops(Instruction &I); bool foldShuffleOfBinops(Instruction &I); @@ -1774,6 +1775,73 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) { return true; } +bool VectorCombine::scalarizeExtExtract(Instruction &I) { + auto *Ext = dyn_cast<ZExtInst>(&I); + if (!Ext) + return false; + + // Try to convert a vector zext feeding only extracts to a set of scalar + // (Src >> ExtIdx * EltBits) & ((1 << EltBits) - 1) + // if profitable. + auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType()); + if (!SrcTy) + return false; + auto *DstTy = cast<FixedVectorType>(Ext->getType()); + + Type *ScalarDstTy = DstTy->getElementType(); + if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy)) + return false; + + InstructionCost VectorCost = + TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy, + TTI::CastContextHint::None, CostKind, Ext); + unsigned ExtCnt = 0; + bool ExtLane0 = false; + for (User *U : Ext->users()) { + const APInt *Idx; + if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx)))) + return false; + if (cast<Instruction>(U)->use_empty()) + continue; + ExtCnt += 1; + ExtLane0 |= Idx->isZero(); + VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy, + CostKind, Idx->getZExtValue(), U); + } + + InstructionCost ScalarCost = + ExtCnt * TTI.getArithmeticInstrCost( + Instruction::And, ScalarDstTy, CostKind, + {TTI::OK_AnyValue, TTI::OP_None}, + {TTI::OK_NonUniformConstantValue, TTI::OP_None}) + + (ExtCnt - ExtLane0) * + TTI.getArithmeticInstrCost( + 
Instruction::LShr, ScalarDstTy, CostKind, + {TTI::OK_AnyValue, TTI::OP_None}, + {TTI::OK_NonUniformConstantValue, TTI::OP_None}); + if (ScalarCost > VectorCost) + return false; + + Value *ScalarV = Ext->getOperand(0); + if (!isGuaranteedNotToBePoison(ScalarV, &AC, dyn_cast<Instruction>(ScalarV), + &DT)) + ScalarV = Builder.CreateFreeze(ScalarV); + ScalarV = Builder.CreateBitCast( + ScalarV, + IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy))); + uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType()); + uint64_t EltBitMask = (1ull << SrcEltSizeInBits) - 1; + for (User *U : Ext->users()) { + auto *Extract = cast<ExtractElementInst>(U); + uint64_t Idx = + cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue(); + Value *LShr = Builder.CreateLShr(ScalarV, Idx * SrcEltSizeInBits); + Value *And = Builder.CreateAnd(LShr, EltBitMask); + U->replaceAllUsesWith(And); + } + return true; +} + /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))" /// to "(bitcast (concat X, Y))" /// where X/Y are bitcasted from i1 mask vectors. 
@@ -3662,6 +3730,7 @@ bool VectorCombine::run() { if (IsVectorType) { MadeChange |= scalarizeOpOrCmp(I); MadeChange |= scalarizeLoadExtract(I); + MadeChange |= scalarizeExtExtract(I); MadeChange |= scalarizeVPIntrinsic(I); MadeChange |= foldInterleaveIntrinsics(I); } diff --git a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll index 09c03991ad7c3..60700412686ea 100644 --- a/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll +++ b/llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll @@ -9,15 +9,23 @@ define void @zext_v4i8_all_lanes_used(<4 x i8> %src) { ; CHECK-LABEL: define void @zext_v4i8_all_lanes_used( ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 +; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8 +; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255 +; CHECK-NEXT: [[TMP9:%.*]] = and i32 [[TMP1]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP9]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP7]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -68,13 +76,20 @@ define void 
@zext_v4i8_3_lanes_used_1(<4 x i8> %src) { ; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_1( ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 +; CHECK-NEXT: [[TMP6:%.*]] = lshr i32 [[TMP1]], 8 +; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP6]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP7]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -93,13 +108,19 @@ define void @zext_v4i8_3_lanes_used_2(<4 x i8> %src) { ; CHECK-LABEL: define void @zext_v4i8_3_lanes_used_2( ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 +; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP1]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]]) -; 
CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP7]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -118,11 +139,17 @@ define void @zext_v4i8_2_lanes_used_1(<4 x i8> %src) { ; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_1( ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16 +; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP1]], 8 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP3]]) ; CHECK-NEXT: ret void ; entry: @@ -139,11 +166,16 @@ define void @zext_v4i8_2_lanes_used_2(<4 x i8> %src) { ; CHECK-LABEL: define void @zext_v4i8_2_lanes_used_2( ; CHECK-SAME: <4 x i8> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 16 +; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255 +; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP1]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]]) -; CHECK-NEXT: call void @use.i32(i32 
[[EXT_2]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP5]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP3]]) ; CHECK-NEXT: ret void ; entry: @@ -160,15 +192,22 @@ define void @zext_v4i8_all_lanes_used_noundef(<4 x i8> noundef %src) { ; CHECK-LABEL: define void @zext_v4i8_all_lanes_used_noundef( ; CHECK-SAME: <4 x i8> noundef [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32 +; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 24 +; CHECK-NEXT: [[TMP3:%.*]] = lshr i32 [[TMP0]], 16 +; CHECK-NEXT: [[TMP4:%.*]] = and i32 [[TMP3]], 255 +; CHECK-NEXT: [[TMP5:%.*]] = lshr i32 [[TMP0]], 8 +; CHECK-NEXT: [[TMP6:%.*]] = and i32 [[TMP5]], 255 +; CHECK-NEXT: [[TMP8:%.*]] = and i32 [[TMP0]], 255 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT9]], i64 2 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT9]], i64 3 -; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) -; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP8]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP6]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP4]]) +; CHECK-NEXT: call void @use.i32(i32 [[TMP1]]) ; CHECK-NEXT: ret void ; entry: @@ -221,15 +260,23 @@ define void @zext_v4i16_all_lanes_used(<4 x i16> %src) { ; CHECK-LABEL: define void @zext_v4i16_all_lanes_used( ; CHECK-SAME: <4 x i16> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i16> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 48 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i64 [[TMP1]], 32 +; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP4]], 65535 +; 
CHECK-NEXT: [[TMP6:%.*]] = lshr i64 [[TMP1]], 16 +; CHECK-NEXT: [[TMP7:%.*]] = and i64 [[TMP6]], 65535 +; CHECK-NEXT: [[TMP9:%.*]] = and i64 [[TMP1]], 65535 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <4 x i16> [[SRC]] to <4 x i64> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i64> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i64> [[EXT9]], i64 1 ; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i64> [[EXT9]], i64 2 ; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i64> [[EXT9]], i64 3 -; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]]) -; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]]) -; CHECK-NEXT: call void @use.i64(i64 [[EXT_2]]) -; CHECK-NEXT: call void @use.i64(i64 [[EXT_3]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP9]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP7]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP5]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -250,11 +297,15 @@ define void @zext_v2i32_all_lanes_used(<2 x i32> %src) { ; CHECK-LABEL: define void @zext_v2i32_all_lanes_used( ; CHECK-SAME: <2 x i32> [[SRC:%.*]]) { ; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[TMP0:%.*]] = freeze <2 x i32> [[SRC]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i32> [[TMP0]] to i64 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i64 [[TMP1]], 32 +; CHECK-NEXT: [[TMP5:%.*]] = and i64 [[TMP1]], 4294967295 ; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <2 x i32> [[SRC]] to <2 x i64> ; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <2 x i64> [[EXT9]], i64 0 ; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <2 x i64> [[EXT9]], i64 1 -; CHECK-NEXT: call void @use.i64(i64 [[EXT_0]]) -; CHECK-NEXT: call void @use.i64(i64 [[EXT_1]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP5]]) +; CHECK-NEXT: call void @use.i64(i64 [[TMP2]]) ; CHECK-NEXT: ret void ; entry: @@ -266,3 +317,32 @@ entry: call void @use.i64(i64 %ext.1) ret void } + +define void @zext_nxv4i8_all_lanes_used(<vscale x 4 x i8> %src) { +; CHECK-LABEL: define void @zext_nxv4i8_all_lanes_used( +; CHECK-SAME: 
<vscale x 4 x i8> [[SRC:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[EXT9:%.*]] = zext nneg <vscale x 4 x i8> [[SRC]] to <vscale x 4 x i32> +; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 0 +; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 1 +; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 2 +; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <vscale x 4 x i32> [[EXT9]], i64 3 +; CHECK-NEXT: call void @use.i32(i32 [[EXT_0]]) +; CHECK-NEXT: call void @use.i32(i32 [[EXT_1]]) +; CHECK-NEXT: call void @use.i32(i32 [[EXT_2]]) +; CHECK-NEXT: call void @use.i32(i32 [[EXT_3]]) +; CHECK-NEXT: ret void +; +entry: + %ext9 = zext nneg <vscale x 4 x i8> %src to <vscale x 4 x i32> + %ext.0 = extractelement <vscale x 4 x i32> %ext9, i64 0 + %ext.1 = extractelement <vscale x 4 x i32> %ext9, i64 1 + %ext.2 = extractelement <vscale x 4 x i32> %ext9, i64 2 + %ext.3 = extractelement <vscale x 4 x i32> %ext9, i64 3 + + call void @use.i32(i32 %ext.0) + call void @use.i32(i32 %ext.1) + call void @use.i32(i32 %ext.2) + call void @use.i32(i32 %ext.3) + ret void +}