Skip to content

Commit

Permalink
[AIE2P] Fix combine G_SHUFFLE_VECTOR into G_AIE_VSEL
Browse files Browse the repository at this point in the history
  • Loading branch information
katerynamuts committed Jan 29, 2025
1 parent 9b6736e commit 0f3e26a
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 92 deletions.
7 changes: 3 additions & 4 deletions llvm/lib/Target/AIE/AIECombine.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,11 @@ def combine_vector_shuffle_broadcast : GICombineRule<
[{ return matchShuffleToBroadcast(*${root}, MRI, ${matchinfo}); }]),
(apply [{ applySplatVector(*${root}, MRI, B, ${matchinfo}); }])>;

def combine_vsel_matchdata: GIDefMatchData<"std::tuple<Register, Register, Register, uint64_t>">;
def combine_vector_shuffle_vsel : GICombineRule<
(defs root:$root, combine_vsel_matchdata:$matchinfo),
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_SHUFFLE_VECTOR): $root,
[{ return matchShuffleToVSel(*${root}, MRI, ${matchinfo}); }]),
(apply [{ applyVSel(*${root}, MRI, B, ${matchinfo}); }])>;
[{ return matchShuffleToVSel(*${root}, MRI, (const AIEBaseInstrInfo &)B.getTII(), ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

def combine_shuffle_to_vextbcst : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
Expand Down
58 changes: 27 additions & 31 deletions llvm/lib/Target/AIE/AIECombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegionInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
Expand Down Expand Up @@ -1266,17 +1267,6 @@ void llvm::applyPadVector(MachineInstr &MI, MachineRegisterInfo &MRI,
MI.eraseFromParent();
}

void llvm::applyVSel(
MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
std::tuple<Register, Register, Register, uint64_t> &MatchInfo) {
B.setInstrAndDebugLoc(MI);
const AIEBaseInstrInfo &AIETII = (const AIEBaseInstrInfo &)B.getTII();
auto [DstVecReg, Src1Reg, Src2Reg, Mask] = MatchInfo;
B.buildInstr(AIETII.getGenericVSelOpcode(), {DstVecReg},
{Src1Reg, Src2Reg, Mask});
MI.eraseFromParent();
}

/// Match something like this:
/// %68:_(s32) = G_CONSTANT i32 0
/// %93:_(s32) = G_CONSTANT i32 1
Expand Down Expand Up @@ -1785,9 +1775,9 @@ bool llvm::matchShuffleToBroadcast(MachineInstr &MI, MachineRegisterInfo &MRI,
return true;
}

bool llvm::matchShuffleToVSel(
MachineInstr &MI, MachineRegisterInfo &MRI,
std::tuple<Register, Register, Register, uint64_t> &MatchInfo) {
bool llvm::matchShuffleToVSel(MachineInstr &MI, MachineRegisterInfo &MRI,
const AIEBaseInstrInfo &TII,
BuildFnTy &MatchInfo) {
assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
const Register DstReg = MI.getOperand(0).getReg();
const Register Src1Reg = MI.getOperand(1).getReg();
Expand All @@ -1796,7 +1786,8 @@ bool llvm::matchShuffleToVSel(

const LLT DstTy = MRI.getType(DstReg);
const LLT Src1Ty = MRI.getType(Src1Reg);
if (Src1Ty.getSizeInBits() != 512)
if (Src1Ty.getSizeInBits() != 512 ||
Src1Ty.getElementType() == LLT::scalar(64))
return false;

const unsigned NumDstElems = DstTy.getNumElements();
Expand All @@ -1805,30 +1796,35 @@ bool llvm::matchShuffleToVSel(
return false;

// Check that the shuffle mask can be converted into VSel mask:
// The mask contains only -1
if (std::all_of(Mask.begin(), Mask.end(),
[&](int Value) { return Value == -1; })) {
return false;
}

// 1. The shuffle mask doesn't contain indices that correspond to the same
// index in Src1 and Src2, i.e., for each i only the i-th element from Src1 or
// the i-th element from Src2 is used.
// 2. The mask indices modulo the number of elements are in strictly ascending
// order.
int PrevIdx = Mask[0] % NumSrcElems;
const size_t NumElems = Mask.size();
for (unsigned I = 1; I < NumElems; I++) {
int CurrIdx = Mask[I] % NumSrcElems;
if (CurrIdx <= PrevIdx)
return false;
}

// Create the mask
unsigned long long DstMask = 0;
for (unsigned I = 0; I < NumElems; I++) {
uint64_t DstMask = 0;
const size_t NumMaskElems = Mask.size();
for (unsigned I = 0; I < NumMaskElems; I++) {
int Idx = Mask[I];
if (Idx >= (int)NumSrcElems) {
unsigned long long ElemMask = 1 << I;
DstMask |= ElemMask;
}
if (Idx == -1 || Idx == (int)I)
continue;

if ((unsigned)Idx == I + NumSrcElems)
DstMask |= uint64_t(1) << I;
else
return false;
}

MatchInfo = std::make_tuple(DstReg, Src1Reg, Src2Reg, DstMask);
MatchInfo = [=, &TII](MachineIRBuilder &B) {
MachineInstrBuilder MaskReg = B.buildConstant(LLT::scalar(32), DstMask);
const unsigned VSelOpc = TII.getGenericVSelOpcode();
B.buildInstr(VSelOpc, {DstReg}, {Src1Reg, Src2Reg, MaskReg});
};
return true;
}

Expand Down
7 changes: 2 additions & 5 deletions llvm/lib/Target/AIE/AIECombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,8 @@ bool matchShuffleToBroadcast(MachineInstr &MI, MachineRegisterInfo &MRI,
std::pair<Register, Register> &MatchInfo);
/// Combine G_SHUFFLE_VECTOR(G_BUILD_VECTOR (VAL, UNDEF, ...), mask<0,0,...>)
/// idiom into G_AIE_VSEL
bool matchShuffleToVSel(
MachineInstr &MI, MachineRegisterInfo &MRI,
std::tuple<Register, Register, Register, uint64_t> &MatchInfo);
bool matchShuffleToVSel(MachineInstr &MI, MachineRegisterInfo &MRI,
const AIEBaseInstrInfo &TII, BuildFnTy &MatchInfo);
/// Combine a shuffle vector with a mask that extracts the only element from
/// the first source vector and broadcasts it.
bool matchShuffleToExtractBroadcast(MachineInstr &MI, MachineRegisterInfo &MRI,
Expand Down Expand Up @@ -166,8 +165,6 @@ bool matchConcatPadVector(MachineInstr &MI, MachineRegisterInfo &MRI,
Register &MatchedInputVector);
void applyPadVector(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B, Register MatchedInputVector);
void applyVSel(MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
std::tuple<Register, Register, Register, uint64_t> &MatchInfo);
bool tryToCombineVectorShiftsByZero(MachineInstr &MI, MachineRegisterInfo &MRI);

bool matchExtractConcat(MachineInstr &MI, MachineRegisterInfo &MRI,
Expand Down
Loading

0 comments on commit 0f3e26a

Please sign in to comment.