[VPlan] Emit VPVectorEndPointerRecipe for reverse interleave pointer adjustment #144864

Open: wants to merge 7 commits into main
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7764,8 +7764,9 @@ VPRecipeBuilder::tryToWidenMemory(Instruction *I, ArrayRef<VPValue *> Operands,
(CM.foldTailByMasking() || !GEP || !GEP->isInBounds())
? GEPNoWrapFlags::none()
: GEPNoWrapFlags::inBounds();
- VectorPtr = new VPVectorEndPointerRecipe(
-     Ptr, &Plan.getVF(), getLoadStoreType(I), Flags, I->getDebugLoc());
+ VectorPtr =
+     new VPVectorEndPointerRecipe(Ptr, &Plan.getVF(), getLoadStoreType(I),
+                                  /*Stride*/ -1, Flags, I->getDebugLoc());
} else {
VectorPtr = new VPVectorPointerRecipe(Ptr, getLoadStoreType(I),
GEP ? GEP->getNoWrapFlags()
14 changes: 10 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1706,17 +1706,22 @@ class VPWidenGEPRecipe : public VPRecipeWithIRFlags {

/// A recipe to compute a pointer to the last element of each part of a widened
/// memory access for widened memory accesses of IndexedTy. Used for
- /// VPWidenMemoryRecipes that are reversed.
+ /// VPWidenMemoryRecipes or VPInterleaveRecipes that are reversed.
class VPVectorEndPointerRecipe : public VPRecipeWithIRFlags,
public VPUnrollPartAccessor<2> {
Type *IndexedTy;

+ /// The constant stride of the pointer computed by this recipe.
+ int64_t Stride;
Contributor: Needs a comment
public:
VPVectorEndPointerRecipe(VPValue *Ptr, VPValue *VF, Type *IndexedTy,
-                          GEPNoWrapFlags GEPFlags, DebugLoc DL)
+                          int64_t Stride, GEPNoWrapFlags GEPFlags, DebugLoc DL)
: VPRecipeWithIRFlags(VPDef::VPVectorEndPointerSC,
ArrayRef<VPValue *>({Ptr, VF}), GEPFlags, DL),
-       IndexedTy(IndexedTy) {}
+       IndexedTy(IndexedTy), Stride(Stride) {
+   assert(Stride != 0 && "Stride cannot be zero");
+ }

VP_CLASSOF_IMPL(VPDef::VPVectorEndPointerSC)

@@ -1748,7 +1753,8 @@ class VPVectorEndPointerRecipe : public VPRecipeWithIRFlags,

VPVectorEndPointerRecipe *clone() override {
return new VPVectorEndPointerRecipe(getOperand(0), getVFValue(), IndexedTy,
-                                     getGEPNoWrapFlags(), getDebugLoc());
+                                     Stride, getGEPNoWrapFlags(),
+                                     getDebugLoc());
}

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
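The Stride argument is fixed at construction time by the two producers this patch touches: reversed consecutive accesses pass -1 (LoopVectorize.cpp above), while reversed interleave groups pass -(int64_t)Factor (VPlanTransforms.cpp below). A minimal standalone sketch of that sign convention; the helper and its name are ours for illustration, not LLVM API:

#include <cstdint>

// Elements stepped back per lane: one element for a plain reversed access,
// one whole group (Factor elements) for a reversed interleave group.
int64_t reverseStrideFor(bool IsInterleaveGroup, uint64_t Factor) {
  return IsInterleaveGroup ? -static_cast<int64_t>(Factor) : -1;
}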
38 changes: 11 additions & 27 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2326,31 +2326,34 @@ void VPWidenGEPRecipe::print(raw_ostream &O, const Twine &Indent,
}
#endif

- static Type *getGEPIndexTy(bool IsScalable, bool IsReverse,
+ static Type *getGEPIndexTy(bool IsScalable, bool IsReverse, bool IsUnitStride,
unsigned CurrentPart, IRBuilderBase &Builder) {
// Use i32 for the gep index type when the value is constant,
// or query DataLayout for a more suitable index type otherwise.
const DataLayout &DL = Builder.GetInsertBlock()->getDataLayout();
-   return IsScalable && (IsReverse || CurrentPart > 0)
+   return !IsUnitStride || (IsScalable && (IsReverse || CurrentPart > 0))
? DL.getIndexType(Builder.getPtrTy(0))
: Builder.getInt32Ty();
}

void VPVectorEndPointerRecipe::execute(VPTransformState &State) {
auto &Builder = State.Builder;
unsigned CurrentPart = getUnrollPart(*this);
+   bool IsUnitStride = Stride == 1 || Stride == -1;
Type *IndexTy = getGEPIndexTy(State.VF.isScalable(), /*IsReverse*/ true,
-                                     CurrentPart, Builder);
+                                     IsUnitStride, CurrentPart, Builder);

// The wide store needs to start at the last vector element.
Value *RunTimeVF = State.get(getVFValue(), VPLane(0));
if (IndexTy != RunTimeVF->getType())
RunTimeVF = Builder.CreateZExtOrTrunc(RunTimeVF, IndexTy);
-   // NumElt = -CurrentPart * RunTimeVF
+   // NumElt = Stride * CurrentPart * RunTimeVF
Value *NumElt = Builder.CreateMul(
-       ConstantInt::get(IndexTy, -(int64_t)CurrentPart), RunTimeVF);
-   // LastLane = 1 - RunTimeVF
-   Value *LastLane = Builder.CreateSub(ConstantInt::get(IndexTy, 1), RunTimeVF);
+       ConstantInt::get(IndexTy, Stride * (int64_t)CurrentPart), RunTimeVF);
+   // LastLane = Stride * (RunTimeVF - 1)
+   Value *LastLane = Builder.CreateSub(RunTimeVF, ConstantInt::get(IndexTy, 1));
+   if (Stride != 1)
+     LastLane = Builder.CreateMul(ConstantInt::get(IndexTy, Stride), LastLane);
Value *Ptr = State.get(getOperand(0), VPLane(0));
Value *ResultPtr =
Builder.CreateGEP(IndexedTy, Ptr, NumElt, "", getGEPNoWrapFlags());
@@ -2375,7 +2378,7 @@ void VPVectorPointerRecipe::execute(VPTransformState &State) {
auto &Builder = State.Builder;
unsigned CurrentPart = getUnrollPart(*this);
Type *IndexTy = getGEPIndexTy(State.VF.isScalable(), /*IsReverse*/ false,
-                                     CurrentPart, Builder);
+                                     /*IsUnitStride*/ true, CurrentPart, Builder);
Value *Ptr = State.get(getOperand(0), VPLane(0));

Value *Increment = createStepForVF(Builder, IndexTy, State.VF, CurrentPart);
@@ -3341,25 +3344,6 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
if (auto *I = dyn_cast<Instruction>(ResAddr))
State.setDebugLocFrom(I->getDebugLoc());

- // If the group is reverse, adjust the index to refer to the last vector lane
- // instead of the first. We adjust the index from the first vector lane,
- // rather than directly getting the pointer for lane VF - 1, because the
- // pointer operand of the interleaved access is supposed to be uniform.
- if (Group->isReverse()) {
-   Value *RuntimeVF =
-       getRuntimeVF(State.Builder, State.Builder.getInt32Ty(), State.VF);
-   Value *Index =
-       State.Builder.CreateSub(RuntimeVF, State.Builder.getInt32(1));
-   Index = State.Builder.CreateMul(Index,
-                                   State.Builder.getInt32(Group->getFactor()));
-   Index = State.Builder.CreateNeg(Index);
-
-   bool InBounds = false;
-   if (auto *Gep = dyn_cast<GetElementPtrInst>(ResAddr->stripPointerCasts()))
-     InBounds = Gep->isInBounds();
-   ResAddr = State.Builder.CreateGEP(ScalarTy, ResAddr, Index, "", InBounds);
- }

State.setDebugLocFrom(getDebugLoc());
Value *PoisonVec = PoisonValue::get(VecTy);

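As a sanity check on the arithmetic in VPVectorEndPointerRecipe::execute, here is a self-contained sketch in plain C++ (modeling the emitted computation, not calling LLVM API) of the element offset the recipe adds to a part's scalar base pointer, instantiated for the factor-2 scalable-vector test below:

#include <cassert>
#include <cstdint>
#include <cstdio>

// Models the two GEPs emitted by execute(): NumElt, then LastLane.
int64_t vectorEndOffset(int64_t Stride, int64_t CurrentPart, int64_t RunTimeVF) {
  assert(Stride != 0 && "Stride cannot be zero");
  int64_t NumElt = Stride * CurrentPart * RunTimeVF; // first GEP offset
  int64_t LastLane = Stride * (RunTimeVF - 1);       // second GEP offset
  return NumElt + LastLane;
}

int main() {
  // Reversed interleave group with Factor = 2, so Stride = -2. For part 0 the
  // offset simplifies to 2 - 2 * VF, the "sub nsw i64 2, (vscale << 3)"
  // pattern in the updated CHECK lines (VF = vscale x 4). With vscale = 2,
  // VF = 8 and the offset is -14.
  printf("%lld\n", (long long)vectorEndOffset(-2, 0, 8)); // prints -14
  return 0;
}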
29 changes: 22 additions & 7 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2482,23 +2482,23 @@ void VPlanTransforms::createInterleaveGroups(
auto *InsertPos =
cast<VPWidenMemoryRecipe>(RecipeBuilder.getRecipe(IRInsertPos));

- bool InBounds = false;
- if (auto *Gep = dyn_cast<GetElementPtrInst>(
-         getLoadStorePointerOperand(IRInsertPos)->stripPointerCasts()))
-   InBounds = Gep->isInBounds();

// Get or create the start address for the interleave group.
auto *Start =
cast<VPWidenMemoryRecipe>(RecipeBuilder.getRecipe(IG->getMember(0)));
VPValue *Addr = Start->getAddr();
VPRecipeBase *AddrDef = Addr->getDefiningRecipe();
if (AddrDef && !VPDT.properlyDominates(AddrDef, InsertPos)) {
-   // TODO: Hoist Addr's defining recipe (and any operands as needed) to
-   // InsertPos or sink loads above zero members to join it.
+   bool InBounds = false;
+   if (auto *Gep = dyn_cast<GetElementPtrInst>(
+           getLoadStorePointerOperand(IRInsertPos)->stripPointerCasts()))
+     InBounds = Gep->isInBounds();

// We cannot re-use the address of member zero because it does not
// dominate the insert position. Instead, use the address of the insert
// position and create a PtrAdd adjusting it to the address of member
// zero.
+   // TODO: Hoist Addr's defining recipe (and any operands as needed) to
+   // InsertPos or sink loads above zero members to join it.
assert(IG->getIndex(IRInsertPos) != 0 &&
"index of insert position shouldn't be zero");
auto &DL = IRInsertPos->getDataLayout();
@@ -2512,6 +2512,21 @@
Addr = InBounds ? B.createInBoundsPtrAdd(InsertPos->getAddr(), OffsetVPV)
: B.createPtrAdd(InsertPos->getAddr(), OffsetVPV);
}
+ // If the group is reverse, adjust the index to refer to the last vector
+ // lane instead of the first. We adjust the index from the first vector
+ // lane, rather than directly getting the pointer for lane VF - 1, because
+ // the pointer operand of the interleaved access is supposed to be uniform.
+ if (IG->isReverse()) {
+   auto *GEP = dyn_cast<GetElementPtrInst>(
+       getLoadStorePointerOperand(IRInsertPos)->stripPointerCasts());
Comment on lines +2520 to +2521
Contributor: would be good to share the logic to determine inbounds from the GEP with similar code above
+   auto *ReversePtr = new VPVectorEndPointerRecipe(
+       Addr, &Plan.getVF(), getLoadStoreType(IRInsertPos),
+       -(int64_t)IG->getFactor(),
+       InBounds ? GEPNoWrapFlags::inBounds() : GEPNoWrapFlags::none(),
+       InsertPos->getDebugLoc());
+   ReversePtr->insertBefore(InsertPos);
+   Addr = ReversePtr;
+ }
auto *VPIG = new VPInterleaveRecipe(IG, Addr, StoredValues,
InsertPos->getMask(), NeedsMaskForGaps, InsertPos->getDebugLoc());
VPIG->insertBefore(InsertPos);
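The recipe built here replaces the hand-rolled adjustment deleted from VPInterleaveRecipe::execute. A quick standalone check (our code, not LLVM's) that the old index and the new part-0 offset agree for any VF and factor:

#include <cassert>
#include <cstdint>

int main() {
  for (int64_t VF : {2, 4, 8, 16}) {
    for (int64_t Factor : {2, 3, 4, 8}) {
      // Old lowering (removed above): Index = -((VF - 1) * Factor),
      // computed in i32 and then sign-extended.
      int64_t OldIndex = -((VF - 1) * Factor);
      // New lowering via VPVectorEndPointerRecipe with Stride = -Factor:
      // NumElt + LastLane for unroll part 0.
      int64_t NewOffset = (-Factor) * 0 * VF + (-Factor) * (VF - 1);
      assert(OldIndex == NewOffset && "lowerings must agree");
    }
  }
  return 0;
}

For the factor-4 group in interleave_deinterleave_reverse below this gives 4 - 4 * VF elements at part 0; with VF = vscale x 4 that is exactly the "sub nsw i64 4, (vscale << 4)" the updated CHECK lines expect.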
@@ -367,10 +367,8 @@ define void @test_reversed_load2_store2(ptr noalias nocapture readonly %A, ptr n
; CHECK-NEXT: [[VEC_IND:%.*]] = phi <vscale x 4 x i32> [ [[INDUCTION]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[OFFSET_IDX:%.*]] = sub i64 1023, [[INDEX]]
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT_ST2:%.*]], ptr [[A:%.*]], i64 [[OFFSET_IDX]], i32 0
- ; CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.vscale.i32()
- ; CHECK-NEXT: [[TMP6:%.*]] = shl nuw nsw i32 [[TMP5]], 3
- ; CHECK-NEXT: [[TMP7:%.*]] = sub nsw i32 2, [[TMP6]]
- ; CHECK-NEXT: [[TMP8:%.*]] = sext i32 [[TMP7]] to i64
+ ; CHECK-NEXT: [[TMP6:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+ ; CHECK-NEXT: [[TMP8:%.*]] = sub nsw i64 2, [[TMP6]]
; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i64 [[TMP8]]
; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <vscale x 8 x i32>, ptr [[TMP9]], align 4
; CHECK-NEXT: [[STRIDED_VEC:%.*]] = call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> [[WIDE_VEC]])
@@ -381,10 +379,8 @@ define void @test_reversed_load2_store2(ptr noalias nocapture readonly %A, ptr n
; CHECK-NEXT: [[TMP12:%.*]] = add nsw <vscale x 4 x i32> [[REVERSE]], [[VEC_IND]]
; CHECK-NEXT: [[TMP13:%.*]] = sub nsw <vscale x 4 x i32> [[REVERSE1]], [[VEC_IND]]
; CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [[STRUCT_ST2]], ptr [[B:%.*]], i64 [[OFFSET_IDX]], i32 0
- ; CHECK-NEXT: [[TMP15:%.*]] = call i32 @llvm.vscale.i32()
- ; CHECK-NEXT: [[TMP16:%.*]] = shl nuw nsw i32 [[TMP15]], 3
- ; CHECK-NEXT: [[TMP17:%.*]] = sub nsw i32 2, [[TMP16]]
- ; CHECK-NEXT: [[TMP18:%.*]] = sext i32 [[TMP17]] to i64
+ ; CHECK-NEXT: [[TMP15:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+ ; CHECK-NEXT: [[TMP18:%.*]] = sub nsw i64 2, [[TMP15]]
Comment on lines +382 to +383
Contributor: It looks like previously the index types for interleave groups were i32, but now it's determined by the datalayout because of the changes in getGEPIndexTy. I don't have much of an opinion on this; was there a reason why it was needed for this PR?
Contributor (author): The original implementation uses i64 as the index type in the GEP. One reason getGEPIndexTy can't simply choose i32 is the scalable VF in this case; another is that we compute NumElt = Stride * CurrentPart * RunTimeVF, and if the stride is not a unit stride then we have to use the type determined by the datalayout.
; CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds i32, ptr [[TMP14]], i64 [[TMP18]]
; CHECK-NEXT: [[REVERSE2:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP12]])
; CHECK-NEXT: [[REVERSE3:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP13]])
@@ -1577,10 +1573,8 @@ define void @interleave_deinterleave_reverse(ptr noalias nocapture readonly %A,
; CHECK-NEXT: [[VEC_IND:%.*]] = phi <vscale x 4 x i32> [ [[INDUCTION]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[OFFSET_IDX:%.*]] = sub i64 1023, [[INDEX]]
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [[STRUCT_XYZT:%.*]], ptr [[A:%.*]], i64 [[OFFSET_IDX]], i32 0
- ; CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.vscale.i32()
- ; CHECK-NEXT: [[TMP7:%.*]] = shl nuw nsw i32 [[TMP6]], 4
- ; CHECK-NEXT: [[TMP8:%.*]] = sub nsw i32 4, [[TMP7]]
- ; CHECK-NEXT: [[TMP9:%.*]] = sext i32 [[TMP8]] to i64
+ ; CHECK-NEXT: [[TMP6:%.*]] = shl nuw nsw i64 [[TMP0]], 4
+ ; CHECK-NEXT: [[TMP9:%.*]] = sub nsw i64 4, [[TMP6]]
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[TMP5]], i64 [[TMP9]]
; CHECK-NEXT: [[WIDE_VEC:%.*]] = load <vscale x 16 x i32>, ptr [[TMP10]], align 4
; CHECK-NEXT: [[STRIDED_VEC:%.*]] = call { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave4.nxv16i32(<vscale x 16 x i32> [[WIDE_VEC]])
Expand All @@ -1597,10 +1591,8 @@ define void @interleave_deinterleave_reverse(ptr noalias nocapture readonly %A,
; CHECK-NEXT: [[TMP19:%.*]] = mul nsw <vscale x 4 x i32> [[REVERSE4]], [[VEC_IND]]
; CHECK-NEXT: [[TMP20:%.*]] = shl nuw nsw <vscale x 4 x i32> [[REVERSE5]], [[VEC_IND]]
; CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [[STRUCT_XYZT]], ptr [[B:%.*]], i64 [[OFFSET_IDX]], i32 0
- ; CHECK-NEXT: [[TMP22:%.*]] = call i32 @llvm.vscale.i32()
- ; CHECK-NEXT: [[TMP23:%.*]] = shl nuw nsw i32 [[TMP22]], 4
- ; CHECK-NEXT: [[TMP24:%.*]] = sub nsw i32 4, [[TMP23]]
- ; CHECK-NEXT: [[TMP25:%.*]] = sext i32 [[TMP24]] to i64
+ ; CHECK-NEXT: [[TMP22:%.*]] = shl nuw nsw i64 [[TMP0]], 4
+ ; CHECK-NEXT: [[TMP25:%.*]] = sub nsw i64 4, [[TMP22]]
; CHECK-NEXT: [[TMP26:%.*]] = getelementptr inbounds i32, ptr [[TMP21]], i64 [[TMP25]]
; CHECK-NEXT: [[REVERSE6:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP17]])
; CHECK-NEXT: [[REVERSE7:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP18]])
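Following up on the index-type discussion above, the new choice in getGEPIndexTy can be restated as a single predicate. A sketch whose function name is ours; the logic mirrors the patch:

// True when the GEP index must use the datalayout's pointer index type
// (i64 on AArch64) instead of i32: any non-unit stride, or a scalable
// reversed/unrolled access whose index depends on the runtime VF.
bool usesDataLayoutIndexTy(bool IsScalable, bool IsReverse, bool IsUnitStride,
                           unsigned CurrentPart) {
  return !IsUnitStride || (IsScalable && (IsReverse || CurrentPart > 0));
}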