Skip to content

Commit 726af32

Browse files
phoebewangtstellar
authored andcommitted
[X86][BF16] Fix 2 crashes with vector broadcast
Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D151808 (cherry picked from commit 801dd88)
1 parent 4fd1b86 commit 726af32

File tree

3 files changed

+74
-5
lines changed

3 files changed

+74
-5
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2195,6 +2195,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
21952195
setOperationAction(ISD::FMUL, VT, Expand);
21962196
setOperationAction(ISD::FDIV, VT, Expand);
21972197
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
2198+
setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
21982199
}
21992200
addLegalFPImmediate(APFloat::getZero(APFloat::BFloat()));
22002201
}
@@ -2207,6 +2208,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
22072208
setOperationAction(ISD::FMUL, MVT::v32bf16, Expand);
22082209
setOperationAction(ISD::FDIV, MVT::v32bf16, Expand);
22092210
setOperationAction(ISD::BUILD_VECTOR, MVT::v32bf16, Custom);
2211+
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v32bf16, Custom);
22102212
}
22112213

22122214
if (!Subtarget.useSoftFloat() && Subtarget.hasVLX()) {
@@ -18773,11 +18775,11 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
1877318775
return DAG.getBitcast(VT, DAG.getVectorShuffle(FpVT, DL, V1, V2, Mask));
1877418776
}
1877518777

18776-
if (VT == MVT::v16f16) {
18777-
V1 = DAG.getBitcast(MVT::v16i16, V1);
18778-
V2 = DAG.getBitcast(MVT::v16i16, V2);
18779-
return DAG.getBitcast(MVT::v16f16,
18780-
DAG.getVectorShuffle(MVT::v16i16, DL, V1, V2, Mask));
18778+
if (VT == MVT::v16f16 || VT.getVectorElementType() == MVT::bf16) {
18779+
MVT IVT = VT.changeVectorElementTypeToInteger();
18780+
V1 = DAG.getBitcast(IVT, V1);
18781+
V2 = DAG.getBitcast(IVT, V2);
18782+
return DAG.getBitcast(VT, DAG.getVectorShuffle(IVT, DL, V1, V2, Mask));
1878118783
}
1878218784

1878318785
switch (VT.SimpleTy) {

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12969,6 +12969,27 @@ let Predicates = [HasBF16, HasVLX] in {
1296912969
(VCVTNEPS2BF16Z256rr VR256X:$src)>;
1297012970
def : Pat<(v8bf16 (int_x86_vcvtneps2bf16256 (loadv8f32 addr:$src))),
1297112971
(VCVTNEPS2BF16Z256rm addr:$src)>;
12972+
12973+
def : Pat<(v8bf16 (X86VBroadcastld16 addr:$src)),
12974+
(VPBROADCASTWZ128rm addr:$src)>;
12975+
def : Pat<(v16bf16 (X86VBroadcastld16 addr:$src)),
12976+
(VPBROADCASTWZ256rm addr:$src)>;
12977+
12978+
def : Pat<(v8bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
12979+
(VPBROADCASTWZ128rr VR128X:$src)>;
12980+
def : Pat<(v16bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
12981+
(VPBROADCASTWZ256rr VR128X:$src)>;
12982+
12983+
// TODO: No scalar broadcast due to we don't support legal scalar bf16 so far.
12984+
}
12985+
12986+
let Predicates = [HasBF16] in {
12987+
def : Pat<(v32bf16 (X86VBroadcastld16 addr:$src)),
12988+
(VPBROADCASTWZrm addr:$src)>;
12989+
12990+
def : Pat<(v32bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
12991+
(VPBROADCASTWZrr VR128X:$src)>;
12992+
// TODO: No scalar broadcast due to we don't support legal scalar bf16 so far.
1297212993
}
1297312994

1297412995
let Constraints = "$src1 = $dst" in {

llvm/test/CodeGen/X86/avx512bf16-vl-intrinsics.ll

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,3 +356,49 @@ entry:
356356
%2 = select <4 x i1> %1, <4 x float> %0, <4 x float> %E
357357
ret <4 x float> %2
358358
}
359+
360+
define <16 x i16> @test_no_vbroadcast1() {
361+
; CHECK-LABEL: test_no_vbroadcast1:
362+
; CHECK: # %bb.0: # %entry
363+
; CHECK-NEXT: vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
364+
; CHECK-NEXT: vpbroadcastw %xmm0, %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0xc0]
365+
; CHECK-NEXT: ret{{[l|q]}} # encoding: [0xc3]
366+
entry:
367+
%0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> poison, <8 x bfloat> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
368+
%1 = bitcast <8 x bfloat> %0 to <8 x i16>
369+
%2 = shufflevector <8 x i16> %1, <8 x i16> undef, <16 x i32> zeroinitializer
370+
ret <16 x i16> %2
371+
}
372+
373+
;; FIXME: This should generate the same output as above, but let's fix the crash first.
374+
define <16 x bfloat> @test_no_vbroadcast2() nounwind {
375+
; X86-LABEL: test_no_vbroadcast2:
376+
; X86: # %bb.0: # %entry
377+
; X86-NEXT: pushl %ebp # encoding: [0x55]
378+
; X86-NEXT: movl %esp, %ebp # encoding: [0x89,0xe5]
379+
; X86-NEXT: andl $-32, %esp # encoding: [0x83,0xe4,0xe0]
380+
; X86-NEXT: subl $64, %esp # encoding: [0x83,0xec,0x40]
381+
; X86-NEXT: vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
382+
; X86-NEXT: vmovaps %xmm0, (%esp) # EVEX TO VEX Compression encoding: [0xc5,0xf8,0x29,0x04,0x24]
383+
; X86-NEXT: vpbroadcastw (%esp), %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0x04,0x24]
384+
; X86-NEXT: movl %ebp, %esp # encoding: [0x89,0xec]
385+
; X86-NEXT: popl %ebp # encoding: [0x5d]
386+
; X86-NEXT: retl # encoding: [0xc3]
387+
;
388+
; X64-LABEL: test_no_vbroadcast2:
389+
; X64: # %bb.0: # %entry
390+
; X64-NEXT: pushq %rbp # encoding: [0x55]
391+
; X64-NEXT: movq %rsp, %rbp # encoding: [0x48,0x89,0xe5]
392+
; X64-NEXT: andq $-32, %rsp # encoding: [0x48,0x83,0xe4,0xe0]
393+
; X64-NEXT: subq $64, %rsp # encoding: [0x48,0x83,0xec,0x40]
394+
; X64-NEXT: vcvtneps2bf16 %xmm0, %xmm0 # encoding: [0x62,0xf2,0x7e,0x08,0x72,0xc0]
395+
; X64-NEXT: vmovaps %xmm0, (%rsp) # EVEX TO VEX Compression encoding: [0xc5,0xf8,0x29,0x04,0x24]
396+
; X64-NEXT: vpbroadcastw (%rsp), %ymm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x7d,0x79,0x04,0x24]
397+
; X64-NEXT: movq %rbp, %rsp # encoding: [0x48,0x89,0xec]
398+
; X64-NEXT: popq %rbp # encoding: [0x5d]
399+
; X64-NEXT: retq # encoding: [0xc3]
400+
entry:
401+
%0 = tail call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> poison, <8 x bfloat> zeroinitializer, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
402+
%1 = shufflevector <8 x bfloat> %0, <8 x bfloat> undef, <16 x i32> zeroinitializer
403+
ret <16 x bfloat> %1
404+
}

0 commit comments

Comments
 (0)