Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions llvm/lib/Target/AIE/AIELegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,7 @@ bool AIELegalizerHelper::legalizeG_FPTRUNC(LegalizerHelper &Helper,

bool AIELegalizerHelper::legalizeG_FPEXT(LegalizerHelper &Helper,
MachineInstr &MI) const {
const AIEBaseInstrInfo *II = ST.getInstrInfo();
MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
MachineRegisterInfo &MRI = *MIRBuilder.getMRI();

Expand All @@ -1206,6 +1207,67 @@ bool AIELegalizerHelper::legalizeG_FPEXT(LegalizerHelper &Helper,
LLT DstTy = MRI.getType(DstReg);
LLT SrcTy = MRI.getType(SrcReg);

// Vectors
/*
VDst = G_FPEXT VSrc
converts to
ZeroVec = G_AIE_BROADCAST_VECTOR 0
VShuffleLow = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 2
VShuffleHigh = G_AIE_SHUFFLE_VECTOR ZeroVec, VSrc, 3
VShuffleLow = G_BITCAST VShuffleLow
VShuffleHigh = G_BITCAST VShuffleHigh
VDst = G_CONCAT_VECTORS VShuffleLow, VShuffleHigh
*/
if (DstTy.isVector() && SrcTy.isVector()) {
// Extract type information
auto DstElementType = DstTy.getElementType();
auto SrcNumElements = SrcTy.getNumElements();
// Create constants for shuffle modes
Register Mode2 = MIRBuilder.buildConstant(S32, 2).getReg(0);
Register Mode3 = MIRBuilder.buildConstant(S32, 3).getReg(0);
Register Zero = MIRBuilder.buildConstant(S32, 0).getReg(0);
// Get the instructions
const unsigned BroadcastOpc = II->getGenericBroadcastVectorOpcode();
const unsigned VShuffleOpc = II->getGenericShuffleVectorOpcode();

// Step 1: Create a zero vector using broadcast
Register ZeroVec =
MIRBuilder.buildInstr(BroadcastOpc, {SrcTy}, {Zero}).getReg(0);
// Step 2: Create VSHUFFLE for lower 512 bits (mode 2)
Register VShuffleLow =
MIRBuilder.buildInstr(VShuffleOpc, {SrcTy}, {ZeroVec, SrcReg, Mode2})
.getReg(0);
// Step 3: Create VSHUFFLE for high 512 bits (mode 3)
Register VShuffleHigh =
MIRBuilder.buildInstr(VShuffleOpc, {SrcTy}, {ZeroVec, SrcReg, Mode3})
.getReg(0);
// Step 4: bitcast VShuffleLow and VShuffleHigh
// Example: <32xs16> -> <16xs32>
LLT CastToNewTy =
LLT::vector(ElementCount::getFixed(SrcNumElements / 2), DstElementType);
if (CastToNewTy.getSizeInBits() !=
MRI.getType(VShuffleLow).getSizeInBits() ||
CastToNewTy.getSizeInBits() !=
MRI.getType(VShuffleHigh).getSizeInBits()) {
llvm::errs()
<< "Error: Size mismatch in vector bitcast for G_FPEXT. Expected: "
<< CastToNewTy.getSizeInBits()
<< " bits, got: " << MRI.getType(VShuffleLow).getSizeInBits()
<< " and " << MRI.getType(VShuffleHigh).getSizeInBits() << " bits\n";
return false;
}
auto VShuffleLowCast =
MIRBuilder.buildCast(CastToNewTy, VShuffleLow).getReg(0);
auto VShuffleHighCast =
MIRBuilder.buildCast(CastToNewTy, VShuffleHigh).getReg(0);
// Step 5: Concatenate the two src vectors into dst vector
MIRBuilder.buildConcatVectors(DstReg, {VShuffleLowCast, VShuffleHighCast});

MI.eraseFromParent();
return true;
}

// Scalars
// We only handle bfloat16 to single precision conversion
if (DstTy != LLT::scalar(32) || SrcTy != LLT::scalar(16))
return false;
Expand Down
32 changes: 32 additions & 0 deletions llvm/lib/Target/AIE/aie2p/AIE2PLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,31 @@ static LegalityPredicate isValidVectorAIEP(const unsigned TypeIdx) {
};
}

// Legality predicate for a vector `Dst = G_FPEXT Src`.
// The query is accepted only when both types are vectors and:
//   - Dst is wider than Src in total size,
//   - Dst and Src have the same element count, and
//   - each Dst element is exactly twice as wide as a Src element.
// Only sizes and element counts are compared; any query involving a
// non-vector type is rejected.
static LegalityPredicate isValidVectorFPEXT(const unsigned TypeIdx_dst,
                                            const unsigned TypeIdx_src) {
  return [=](const LegalityQuery &Query) {
    const LLT DstTy = Query.Types[TypeIdx_dst];
    const LLT SrcTy = Query.Types[TypeIdx_src];

    // Scalars and scalar/vector mixes are never matched by this predicate.
    if (!DstTy.isVector() || !SrcTy.isVector())
      return false;

    auto DstEltBits = DstTy.getElementType().getSizeInBits();
    auto SrcEltBits = SrcTy.getElementType().getSizeInBits();

    const bool IsWider = DstTy.getSizeInBits() > SrcTy.getSizeInBits();
    const bool SameCount = DstTy.getElementCount() == SrcTy.getElementCount();
    const bool DoubledElt = (DstEltBits == (SrcEltBits * 2));
    return IsWider && SameCount && DoubledElt;
  };
}

static LegalityPredicate
negatePredicate(const std::function<bool(const LegalityQuery &)> &Func) {
return [=](const LegalityQuery &Query) { return !Func(Query); };
Expand Down Expand Up @@ -219,6 +244,13 @@ AIE2PLegalizerInfo::AIE2PLegalizerInfo(const AIE2PSubtarget &ST)
// Legalization rules for G_FPEXT. Type index 0 is the destination and
// type index 1 is the source.
getActionDefinitionsBuilder(G_FPEXT)
.libcallFor({{S64, S32}})
.customFor({{S32, S16}})
// Add support for vector types
// Extend vectors to have at least 512-bits
// (only the source type, index 1, is clamped here: at least 64, 32 or 16
// elements for S8, S16 and S32 element types respectively).
.clampMinNumElements(1, S8, 64)
.clampMinNumElements(1, S16, 32)
.clampMinNumElements(1, S32, 16)
// Vector pairs accepted by isValidVectorFPEXT go to custom legalization.
.customIf(isValidVectorFPEXT(0 /* Dst */, 1 /* Src */))
// .customFor({{V32S32, V32S16}})
.narrowScalarFor({{S64, S16}}, llvm::LegalizeMutations::changeTo(0, S32));

getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
Expand Down
73 changes: 73 additions & 0 deletions llvm/test/CodeGen/AIE/aie2p/GlobalIsel/legalize-vector-fpext.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
; RUN: llc -mtriple=aie2p -O0 -stop-after=legalizer %s -o - 2>&1 | FileCheck %s


; Validates bfloat -> float legalization.
; CHECK-LABEL: name: extend
; CHECK: [[COPY:%[0-9]+]]:_(<32 x s16>) = COPY $x0
; CHECK-NOT: G_SHL
; CHECK-NEXT: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
; CHECK-NEXT: [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
; CHECK-NEXT: [[C0:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
; CHECK-NEXT: [[BCAST:%[0-9]+]]:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR [[C0]](s32)
; CHECK-NEXT: [[SHUF1:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], [[COPY]], [[C2]](s32)
; CHECK-NEXT: [[SHUF2:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], [[COPY]], [[C3]](s32)
; CHECK-NEXT: [[BIT1:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF1]](<32 x s16>)
; CHECK-NEXT: [[BIT2:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF2]](<32 x s16>)
; CHECK-NEXT: [[CONCAT:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[BIT1]](<16 x s32>), [[BIT2]](<16 x s32>)

; Extends a full <32 x bfloat> vector to <32 x float>.
; NOTE(review): the leading scalar argument %o is never used in the body;
; confirm whether it is needed for the test.
define <32 x float> @extend(bfloat %o, <32 x bfloat> %in) nounwind {
%X = fpext <32 x bfloat> %in to <32 x float>
ret <32 x float> %X
}

; Pads the 17 valid values with undefined values to form a 32-element vector.

; CHECK-LABEL: name: extend_non_power_of_2
; CHECK: [[COPY:%[0-9]+]]:_(<32 x s16>) = COPY $x0
; CHECK-COUNT-17: G_AIE_SEXT_EXTRACT_VECTOR_ELT
; CHECK-COUNT-32: G_AIE_ADD_VECTOR_ELT_HI
; CHECK-NEXT: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 2
; CHECK-NEXT: [[C3:%[0-9]+]]:_(s32) = G_CONSTANT i32 3
; CHECK-NEXT: [[C0:%[0-9]+]]:_(s32) = G_CONSTANT i32 0
; CHECK-NEXT: [[BCAST:%[0-9]+]]:_(<32 x s16>) = G_AIE_BROADCAST_VECTOR [[C0]](s32)
; CHECK-NEXT: [[SHUF1:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], %{{[0-9]+}}, [[C2]](s32)
; CHECK-NEXT: [[SHUF2:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR [[BCAST]], %{{[0-9]+}}, [[C3]](s32)
; CHECK-NEXT: [[BIT1:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF1]](<32 x s16>)
; CHECK-NEXT: [[BIT2:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUF2]](<32 x s16>)
; CHECK-COUNT-17: G_AIE_SEXT_EXTRACT_VECTOR_ELT
; CHECK-COUNT-32: G_AIE_ADD_VECTOR_ELT_HI
; CHECK-NEXT: [[CONCAT:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS %{{[0-9]+}}(<16 x s32>), %{{[0-9]+}}(<16 x s32>)
; Extends a 17-element bfloat vector (element count is not a power of two).
define <17 x float> @extend_non_power_of_2(<17 x bfloat> %in) nounwind {
%X = fpext <17 x bfloat> %in to <17 x float>
ret <17 x float> %X
}

; Validates a source vector narrower than 512 bits (<16 x bfloat> is 256 bits)

; CHECK-LABEL: name: fpext_bf16_to_f32
; CHECK: bb.1
; CHECK: [[VEC_CONCAT:%[0-9]+]]:_(<32 x s16>) = G_CONCAT_VECTORS
; CHECK: G_AIE_SEXT_EXTRACT_VECTOR_ELT [[VEC_CONCAT]]
; CHECK: G_AIE_ADD_VECTOR_ELT_HI
; CHECK: [[SHUFFLE_VEC:%[0-9]+]]:_(<32 x s16>) = G_AIE_SHUFFLE_VECTOR
; CHECK-NOT: G_AIE_SHUFFLE_VECTOR
; CHECK: [[BITCAST:%[0-9]+]]:_(<16 x s32>) = G_BITCAST [[SHUFFLE_VEC]]
; CHECK: $x0 = COPY [[BITCAST]]
; Extends a 16-element bfloat vector to a 16-element float vector.
define <16 x float> @fpext_bf16_to_f32(<16 x bfloat> %in) nounwind {
%X = fpext <16 x bfloat> %in to <16 x float>
ret <16 x float> %X
}

; Validates scalar path
; CHECK-LABEL: name: fpext_scalar_bf16_to_f32
; CHECK: [[COPY:%[0-9]+]]:_(s32) = COPY $r1
; CHECK-NEXT: [[C16:%[0-9]+]]:_(s32) = G_CONSTANT i32 16
; CHECK-NEXT: [[SHL:%[0-9]+]]:_(s32) = G_SHL [[COPY]], [[C16]](s32)
; CHECK-NOT: G_AIE_SHUFFLE_VECTOR
; CHECK-NEXT: $r0 = COPY [[SHL]](s32)
; CHECK-NEXT: PseudoRET implicit $lr, implicit $r0

; Extends a single scalar bfloat value to float (no vector types involved).
define float @fpext_scalar_bf16_to_f32(bfloat %in) nounwind {
%X = fpext bfloat %in to float
ret float %X
}