201 changes: 201 additions & 0 deletions llvm/lib/Target/AIE/AIELegalizerHelper.cpp
@@ -1196,6 +1196,7 @@ bool AIELegalizerHelper::legalizeG_FPTRUNC(LegalizerHelper &Helper,

bool AIELegalizerHelper::legalizeG_FPEXT(LegalizerHelper &Helper,
MachineInstr &MI) const {
const AIEBaseInstrInfo *II = ST.getInstrInfo();
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();

@@ -1206,6 +1207,68 @@ bool AIELegalizerHelper::legalizeG_FPEXT(LegalizerHelper &Helper,
LLT DstTy = MRI.getType(DstReg);
LLT SrcTy = MRI.getType(SrcReg);

/* Vectors
VDst = G_FPEXT VSrc
converts to
ZeroVec = G_AIE_BROADCAST_VECTOR Zero
VShuffleLow = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 18
VShuffleHigh = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 19
VShuffleLow = G_BITCAST VShuffleLow
VShuffleHigh = G_BITCAST VShuffleHigh
VDst = G_CONCAT_VECTORS VShuffleLow, VShuffleHigh
*/
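// As a concrete sketch (virtual register names are illustrative, not taken
// from real legalizer output), extending <32 x s16> (bf16) to <32 x s32>
// becomes:
//   %c18:_(s32)         = G_CONSTANT i32 18
//   %c19:_(s32)         = G_CONSTANT i32 19
//   %zero:_(s32)        = G_CONSTANT i32 0
//   %zvec:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR %zero(s32)
//   %lo:_(<32 x s16>)   = G_AIE_SHUFFLE_VECTOR %zvec, %src, %c18(s32)
//   %hi:_(<32 x s16>)   = G_AIE_SHUFFLE_VECTOR %zvec, %src, %c19(s32)
//   %lo32:_(<16 x s32>) = G_BITCAST %lo(<32 x s16>)
//   %hi32:_(<16 x s32>) = G_BITCAST %hi(<32 x s16>)
//   %dst:_(<32 x s32>)  = G_CONCAT_VECTORS %lo32(<16 x s32>), %hi32(<16 x s32>)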
if (DstTy.isVector() && SrcTy.isVector()) {
// Extract type information
auto DstElementType = DstTy.getElementType();
auto SrcNumElements = SrcTy.getNumElements();
// Create constants for shuffle modes
Register Mode18 = MIRBuilder.buildConstant(S32, 18).getReg(0);
Register Mode19 = MIRBuilder.buildConstant(S32, 19).getReg(0);
Register Zero = MIRBuilder.buildConstant(S32, 0).getReg(0);
// Get the instructions
const unsigned BroadcastOpc = II->getGenericBroadcastVectorOpcode();
const unsigned VShuffleOpc = II->getGenericShuffleVectorOpcode();

// Step 1: Create a zero vector using broadcast
Register ZeroVec =
MIRBuilder.buildInstr(BroadcastOpc, {SrcTy}, {Zero}).getReg(0);
// Step 2: Create VSHUFFLE for lower 512 bits (mode 18)
Register VShuffleLow =
MIRBuilder.buildInstr(VShuffleOpc, {SrcTy}, {ZeroVec, SrcReg, Mode18})
.getReg(0);
// Step 3: Create VSHUFFLE for high 512 bits (mode 19)
Register VShuffleHigh =
MIRBuilder.buildInstr(VShuffleOpc, {SrcTy}, {ZeroVec, SrcReg, Mode19})
.getReg(0);
// Step 4: bitcast VShuffleLow and VShuffleHigh
// Example: <32xs16> -> <16xs32>
LLT CastToNewTy =
LLT::vector(ElementCount::getFixed(SrcNumElements / 2), DstElementType);
if (CastToNewTy.getSizeInBits() !=
MRI.getType(VShuffleLow).getSizeInBits() ||
CastToNewTy.getSizeInBits() !=
MRI.getType(VShuffleHigh).getSizeInBits()) {
llvm::errs()
<< "Error: Size mismatch in vector bitcast for G_FPEXT. Expected: "
<< CastToNewTy.getSizeInBits()
<< " bits, got: " << MRI.getType(VShuffleLow).getSizeInBits()
<< " and " << MRI.getType(VShuffleHigh).getSizeInBits() << " bits\n";
return false;
}
auto VShuffleLowCast =
MIRBuilder.buildCast(CastToNewTy, VShuffleLow).getReg(0);
auto VShuffleHighCast =
MIRBuilder.buildCast(CastToNewTy, VShuffleHigh).getReg(0);
// Step 5: Concatenate the two src vectors into dst vector
MIRBuilder.buildConcatVectors(DstReg, {VShuffleLowCast, VShuffleHighCast});
// TODO: The operand order above is not covered by enough tests yet; if it
// turns out to be wrong, the swapped order below is the alternative:
// MIRBuilder.buildConcatVectors(DstReg, {VShuffleHighCast, VShuffleLowCast});

MI.eraseFromParent();
return true;
}

// Scalars
// We only handle bfloat16 to single precision conversion
if (DstTy != LLT::scalar(32) || SrcTy != LLT::scalar(16))
return false;
@@ -1300,6 +1363,12 @@ bool AIELegalizerHelper::legalizeG_FMUL(LegalizerHelper &Helper,
MI.eraseFromParent();
return true;
}
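// LLT carries no floating-point semantics, so these helpers classify purely
// by scalar size: in this legalization context any s16 vector is assumed to
// hold bf16 elements and any s32 vector f32 elements.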
static bool isBF16Vector(const LLT Ty) {
return Ty.isVector() && Ty.getScalarSizeInBits() == 16;
}
static bool isF32Vector(const LLT Ty) {
return Ty.isVector() && Ty.getScalarSizeInBits() == 32;
}

bool AIELegalizerHelper::legalizeG_FADD_G_FSUB(LegalizerHelper &Helper,
MachineInstr &MI) const {
@@ -1309,6 +1378,138 @@ bool AIELegalizerHelper::legalizeG_FADD_G_FSUB(LegalizerHelper &Helper,
const Register DstReg = MI.getOperand(0).getReg();
Register SrcLHS = MI.getOperand(1).getReg();
Register SrcRHS = MI.getOperand(2).getReg();
const LLT SrcLHSTy = MRI.getType(SrcLHS);
const LLT SrcRHSTy = MRI.getType(SrcRHS);

// TODO: This can be combined with the bf16 vector case below.
if (isF32Vector(SrcLHSTy) && isF32Vector(SrcRHSTy)) {
assert(SrcLHSTy.getNumElements() == 32 && SrcRHSTy.getNumElements() == 32 &&
"Expected <32 x f32> inputs for G_FADD/G_FSUB");

// The inputs are already f32 vectors, so no FPExt is needed here.
const LLT F32VecTy = SrcLHSTy;

// Step 1: Pad the <32 x f32> inputs to <64 x f32> for AIE2P, as AccV64S32
// is legal on AIE2P.
if (ST.isAIE2P()) {
const Register UndefVec = MIRBuilder.buildUndef(F32VecTy).getReg(0);
const Register ConcatLHS = MRI.createGenericVirtualRegister(V64FP32);
const Register ConcatRHS = MRI.createGenericVirtualRegister(V64FP32);
MIRBuilder.buildConcatVectors(ConcatLHS, {SrcLHS, UndefVec});
MIRBuilder.buildConcatVectors(ConcatRHS, {SrcRHS, UndefVec});
SrcLHS = ConcatLHS;
SrcRHS = ConcatRHS;
}

// Step 2: Perform the floating-point operation
Register Res = MIRBuilder
.buildInstr(MI.getOpcode(), {MRI.getType(SrcLHS)},
{SrcLHS, SrcRHS})
.getReg(0);

// Step 3: Handle accumulator conversion based on the target
if (ST.isAIE2()) {
Res = MIRBuilder.buildBitcast(V8ACC64, Res).getReg(0);
} else if (ST.isAIE2P()) {
// Unmerge into two <32 x f32> vectors, as the FADD/FSUB was done on
// <64 x f32>.
SmallVector<Register, 2> UnmergedRegs;
const auto Unmerge = MIRBuilder.buildUnmerge(F32VecTy, Res);
getUnmergeResults(UnmergedRegs, *Unmerge);
// Take the first <32 x f32> vector; the other half is just undef padding.
Res = UnmergedRegs[0];
}

// No truncation back to bf16 is needed here: the operands and the result
// are already f32 vectors.

MIRBuilder.buildCopy(DstReg, Res);

MI.eraseFromParent();
return true;
}

// Handle bf16 vectors. This code assumes the inputs are <32 x bf16>; the
// LegalizerInfo ensures that inputs are either padded or unmerged to
// <32 x bf16>.
if (isBF16Vector(SrcLHSTy) && isBF16Vector(SrcRHSTy)) {
assert(SrcLHSTy.getNumElements() == 32 && SrcRHSTy.getNumElements() == 32 &&
"Expected <32 x bf16> inputs for G_FADD/G_FSUB");

// Step 1: Convert bf16 vectors to f32 vectors using FPExt
const LLT F32VecTy =
LLT::fixed_vector(SrcLHSTy.getNumElements(), LLT::scalar(32));
Register SrcLHSF32 = MRI.createGenericVirtualRegister(F32VecTy);
Register SrcRHSF32 = MRI.createGenericVirtualRegister(F32VecTy);
MIRBuilder.buildFPExt(SrcLHSF32, SrcLHS);
MIRBuilder.buildFPExt(SrcRHSF32, SrcRHS);

// Step 2: Pad the extended <32 x f32> input to <64 x f32> for AIE2P, as
// AccV64S32 is legal on AIE2P.
if (ST.isAIE2P()) {
const Register UndefVec = MIRBuilder.buildUndef(F32VecTy).getReg(0);
const Register ConcatLHS = MRI.createGenericVirtualRegister(V64FP32);
const Register ConcatRHS = MRI.createGenericVirtualRegister(V64FP32);
MIRBuilder.buildConcatVectors(ConcatLHS, {SrcLHSF32, UndefVec});
MIRBuilder.buildConcatVectors(ConcatRHS, {SrcRHSF32, UndefVec});
SrcLHSF32 = ConcatLHS;
SrcRHSF32 = ConcatRHS;
}

// Step 3: Perform the floating point operation
Register Res = MIRBuilder
.buildInstr(MI.getOpcode(), {MRI.getType(SrcLHSF32)},
{SrcLHSF32, SrcRHSF32})
.getReg(0);

// Step 4: Handle accumulator conversion based on the target
if (ST.isAIE2()) {
Res = MIRBuilder.buildBitcast(V8ACC64, Res).getReg(0);
} else if (ST.isAIE2P()) {
// Unmerge into two <32 x f32> vectors, as the FADD/FSUB was done on
// <64 x f32>.
SmallVector<Register, 2> UnmergedRegs;
const auto Unmerge = MIRBuilder.buildUnmerge(F32VecTy, Res);
getUnmergeResults(UnmergedRegs, *Unmerge);
// Take the first <32 x f32> vector; the other half is just undef padding.
Res = UnmergedRegs[0];
}

// Step 5: Convert back to bf16 using the truncation intrinsic
const int VecSize = MRI.getType(Res).getSizeInBits();
const LLT DstLLT = ST.isAIE2P() ? V32BF16 : V16BF16;
Res = MIRBuilder
.buildIntrinsic(getFpTrunc32ToBF16IntrID(ST, VecSize), {DstLLT},
true, false)
.addUse(Res)
.getReg(0);
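// With a 1024-bit source on AIE2P this is expected to resolve to the
// llvm.aie2p.v32accfloat.to.v32bf16 intrinsic (checked in the
// legalize-vector-fadd.ll test below).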

// Handle AIE2 padding
if (ST.isAIE2()) {
Res = emitPadUndefVector(MRI, MIRBuilder, V32BF16, Res);
}

MIRBuilder.buildCopy(DstReg, Res);

MI.eraseFromParent();
return true;
}

assert(MRI.getType(DstReg) == LLT::scalar(16) &&
"Expected bfloat16 type in custom legalization.");
62 changes: 61 additions & 1 deletion llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
@@ -73,6 +73,31 @@ static LegalityPredicate isValidVectorAIEP(const unsigned TypeIdx) {
};
}

// `V2 = G_FPEXT V1` on vectors is valid iff:
// - V1 and V2 are floating-point vectors
// - V2's total vector size is larger than V1's
// - both vectors have the same number of elements
// - V2's element size is exactly twice V1's element size
static LegalityPredicate isValidVectorFPEXT(const unsigned DstTypeIdx,
const unsigned SrcTypeIdx) {
return [=](const LegalityQuery &Query) {
const LLT DstTy = Query.Types[DstTypeIdx];
const LLT SrcTy = Query.Types[SrcTypeIdx];
if (DstTy.isVector() && SrcTy.isVector()) {
auto DstElementCount = DstTy.getElementCount();
auto SrcElementCount = SrcTy.getElementCount();
auto DstElementType = DstTy.getElementType();
auto SrcElementType = SrcTy.getElementType();
auto DstElementSize = DstElementType.getSizeInBits();
auto SrcElementSize = SrcElementType.getSizeInBits();
return DstTy.getSizeInBits() > SrcTy.getSizeInBits() &&
DstElementCount == SrcElementCount &&
(DstElementSize == (SrcElementSize * 2));
}
return false;
};
}
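// For example, {<32 x s32>, <32 x s16>} satisfies the predicate (equal
// element counts, and 32-bit elements are twice the 16-bit ones), whereas
// {<16 x s32>, <32 x s16>} (element-count mismatch) and
// {<32 x s16>, <32 x s16>} (elements not widened) do not.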

static LegalityPredicate
negatePredicate(const std::function<bool(const LegalityQuery &)> &Func) {
return [=](const LegalityQuery &Query) { return !Func(Query); };
@@ -219,6 +244,13 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)
getActionDefinitionsBuilder(G_FPEXT)
.libcallFor({{S64, S32}})
.customFor({{S32, S16}})
// Vector support: widen source vectors so they have at least 512 bits.
.clampMinNumElements(1, S8, 64)
.clampMinNumElements(1, S16, 32)
.clampMinNumElements(1, S32, 16)
.customIf(isValidVectorFPEXT(0 /* Dst */, 1 /* Src */))
.narrowScalarFor({{S64, S16}}, llvm::LegalizeMutations::changeTo(0, S32));

getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
@@ -241,7 +273,35 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)

getActionDefinitionsBuilder({G_FADD, G_FSUB})
.legalFor({AccV64S32})
.customFor({S16})
// Handle custom bf16/f32 case for both scalar and vector types
.customFor({S16, V32S16, V32S32})
// Widen vectors with at most 32 f32/bf16 elements to a legal size; the
// element type is unchanged.
.moreElementsIf(
[=](const LegalityQuery &Query) {
const LLT &Ty = Query.Types[0];
return Ty.isVector() &&
(Ty.getScalarSizeInBits() == 32 ||
Ty.getScalarSizeInBits() == 16) &&
Ty.getNumElements() <= 32;
},
[=](const LegalityQuery &Query) {
if (Query.Types[0].getScalarSizeInBits() == 32) {
// Note: this can cause a slowdown, as BUILD_VECTOR adds scalars.
return std::make_pair(0, LLT::fixed_vector(64, S32));
} else {
return std::make_pair(0, LLT::fixed_vector(32, S16));
}
})
// Converts <64xbf16> into 2 chunks of <32xbf16>
.fewerElementsIf(
[=](const LegalityQuery &Query) {
const LLT &Ty = Query.Types[0];
return Ty.isVector() && (Ty.getScalarSizeInBits() == 16) &&
Ty.getNumElements() == 64;
},
[=](const LegalityQuery &Query) {
return std::make_pair(0, LLT::fixed_vector(32, S16));
})
.libcallFor({S32, S64});
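// A sketch of how the rules above compose (bf16/f32 element types assumed;
// rules are tried in the order they are declared):
//   <N x s16>, N < 32:  moreElements -> <32 x s16>, then custom (f32 math)
//   <64 x s16>:         fewerElements -> two <32 x s16>, then custom
//   <N x s32>, N < 32:  moreElements -> <64 x s32>, legal as AccV64S32
//   <32 x s32>:         custom (padded to <64 x s32> internally)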

getActionDefinitionsBuilder({G_FDIV, G_FREM})
56 changes: 56 additions & 0 deletions llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fadd.ll
@@ -0,0 +1,56 @@
; RUN: llc -mtriple=aie2p -O0 -stop-after=legalizer %s -o - 2>&1 | FileCheck %s
; This test was carved out for upstreaming from
; iree-amd-aie/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/multi_reduction_to_reduction_sizes_types.mlir

; Ideally the reduction should lower as follows (with minor changes for each shape):
; Input1: <32xbf16> and Input2: <32xbf16>
; Extended1<32xf32> = fpext <32xbf16>
; Extended2<32xf32> = fpext <32xbf16>
; Zero<32xf32> = zeroinitializer
; Out1<64xf32> = Concat zero, <Extended1<32xf32>>
; Out2<64xf32> = Concat zero, <Extended2<32xf32>>
; Result<64xf32> = fadd <Out1<64xf32>>, <Out2<64xf32>>
; R1<32xf32>, R2<32xf32> = unmerge <Result<64xf32>>
; R2 is all 0s
; R1<32xbf16> = trunc <R1<32xf32>>

; TODO: Add CHECK lines for this function: the vadd.f, the padding to
; <32 x bf16>, checks similar to the <32 x bf16> case, and the unpadding.
define bfloat @multi_reduction_1d_16_bf16(<16 x bfloat> %0, bfloat %1) {
%3 = call reassoc bfloat @llvm.vector.reduce.fadd.v16bf16(bfloat %1, <16 x bfloat> %0)
ret bfloat %3
}

; CHECK-LABEL: name: multi_reduction_1d_32_bf16
; CHECK: G_CONSTANT i32 0
; CHECK: G_AIE_BROADCAST_VECTOR %{{[0-9]+}}(s32)
; CHECK: G_CONSTANT i32 18
; CHECK: G_CONSTANT i32 19
; CHECK: G_AIE_SHUFFLE_VECTOR %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}(s32)
; CHECK: G_AIE_SHUFFLE_VECTOR %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}(s32)
; CHECK: G_BITCAST %{{[0-9]+}}(<32 x s16>)
; CHECK: G_BITCAST %{{[0-9]+}}(<32 x s16>)
; CHECK: G_CONCAT_VECTORS %{{[0-9]+}}(<16 x s32>), %{{[0-9]+}}(<16 x s32>)
; CHECK: G_IMPLICIT_DEF
; CHECK: G_CONCAT_VECTORS %{{[0-9]+}}(<32 x s32>), %{{[0-9]+}}(<32 x s32>)
; CHECK: G_FADD %{{[0-9]+}}, %{{[0-9]+}}
; CHECK: G_UNMERGE_VALUES %{{[0-9]+}}(<64 x s32>)
; CHECK: G_INTRINSIC_W_SIDE_EFFECTS intrinsic(@llvm.aie2p.v32accfloat.to.v32bf16), %{{[0-9]+}}(<32 x s32>)
define bfloat @multi_reduction_1d_32_bf16(<32 x bfloat> %0, bfloat %1) {
%3 = call reassoc bfloat @llvm.vector.reduce.fadd.v32bf16(bfloat %1, <32 x bfloat> %0)
ret bfloat %3
}

; Converted to chunks of <32 x bf16>.
; TODO: Add CHECK lines for this function: the input split into two
; <32 x bf16> chunks, per-chunk checks similar to the <32 x bf16> case, and
; the results concatenated back to <64 x bf16>.

define bfloat @multi_reduction_1d_64_bf16(<64 x bfloat> %0, bfloat %1) {
%3 = call reassoc bfloat @llvm.vector.reduce.fadd.v64bf16(bfloat %1, <64 x bfloat> %0)
ret bfloat %3
}