Skip to content

Commit 5c3d679

Browse files
authored
[mlir][x86vector] AVX Convert/Broadcast F16 to F32 instructions (llvm#137917)
Adds AVX broadcast and conversion instructions from F16 to packed F32 (similar to PR llvm#136830). The added instructions are: VBCSTNESH2PS, VCVTNEEPH2PS, and VCVTNEOPH2PS.
1 parent 249d949 commit 5c3d679

File tree

6 files changed

+250
-65
lines changed

6 files changed

+250
-65
lines changed

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -408,34 +408,41 @@ def DotOp : AVX_LowOp<"dot", [Pure,
408408
}];
409409
}
410410

411-
412411
//----------------------------------------------------------------------------//
413-
// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
412+
// AVX: Convert BF16/F16 to F32 and broadcast into packed F32
414413
//----------------------------------------------------------------------------//
415414

416-
def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
415+
def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
417416
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
418-
let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
417+
let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
419418
let description = [{
420419
#### From the Intel Intrinsics Guide:
421420

422-
Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
423-
memory locations starting at location `__A` to packed single-precision
424-
(32-bit) floating-point elements, and store the results in `dst`.
421+
Convert scalar BF16 or F16 (16-bit) floating-point element stored at memory locations
422+
starting at location `__A` to a single-precision (32-bit) floating-point,
423+
broadcast it to packed single-precision (32-bit) floating-point elements,
424+
and store the results in `dst`.
425425

426426
Example:
427427
```mlir
428-
%dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
428+
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
429+
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
429430
```
430431
}];
431-
let arguments = (ins AnyMemRef:$a);
432+
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
432433
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
433434
let assemblyFormat =
434435
"$a attr-dict`:` type($a)`->` type($dst)";
435436

436437
let extraClassDefinition = [{
437438
std::string $cppClass::getIntrinsicName() {
438-
std::string intr = "llvm.x86.vcvtneebf162ps";
439+
auto elementType =
440+
getA().getType().getElementType();
441+
std::string intr = "llvm.x86.";
442+
if (elementType.isBF16())
443+
intr += "vbcstnebf162ps";
444+
if (elementType.isF16())
445+
intr += "vbcstnesh2ps";
439446
VectorType vecType = getDst().getType();
440447
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
441448
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -447,31 +454,43 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
447454
let extraClassDeclaration = [{
448455
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
449456
}];
457+
450458
}
451459

452-
def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
460+
//------------------------------------------------------------------------------//
461+
// AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32
462+
//------------------------------------------------------------------------------//
463+
464+
def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>,
453465
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
454-
let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
466+
let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data.";
455467
let description = [{
456468
#### From the Intel Intrinsics Guide:
457469

458-
Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
470+
Convert packed BF16 or F16 (16-bit) floating-point even-indexed elements stored at
459471
memory locations starting at location `__A` to packed single-precision
460472
(32-bit) floating-point elements, and store the results in `dst`.
461473

462474
Example:
463475
```mlir
464-
%dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
476+
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
477+
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
465478
```
466479
}];
467-
let arguments = (ins AnyMemRef:$a);
480+
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
468481
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
469482
let assemblyFormat =
470483
"$a attr-dict`:` type($a)`->` type($dst)";
471484

472485
let extraClassDefinition = [{
473486
std::string $cppClass::getIntrinsicName() {
474-
std::string intr = "llvm.x86.vcvtneobf162ps";
487+
auto elementType =
488+
getA().getType().getElementType();
489+
std::string intr = "llvm.x86.";
490+
if (elementType.isBF16())
491+
intr += "vcvtneebf162ps";
492+
if (elementType.isF16())
493+
intr += "vcvtneeph2ps";
475494
VectorType vecType = getDst().getType();
476495
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
477496
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -485,34 +504,36 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
485504
}];
486505
}
487506

488-
//----------------------------------------------------------------------------//
489-
// AVX: Convert BF16 to F32 and broadcast into packed F32
490-
//----------------------------------------------------------------------------//
491-
492-
def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
507+
def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>,
493508
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
494-
let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
509+
let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
495510
let description = [{
496511
#### From the Intel Intrinsics Guide:
497512

498-
Convert scalar BF16 (16-bit) floating-point element stored at memory locations
499-
starting at location `__A` to a single-precision (32-bit) floating-point,
500-
broadcast it to packed single-precision (32-bit) floating-point elements,
501-
and store the results in `dst`.
513+
Convert packed BF16 or F16 (16-bit) floating-point odd-indexed elements stored at
514+
memory locations starting at location `__A` to packed single-precision
515+
(32-bit) floating-point elements, and store the results in `dst`.
502516

503517
Example:
504518
```mlir
505-
%dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
519+
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
520+
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
506521
```
507522
}];
508-
let arguments = (ins AnyMemRef:$a);
523+
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
509524
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
510525
let assemblyFormat =
511526
"$a attr-dict`:` type($a)`->` type($dst)";
512527

513528
let extraClassDefinition = [{
514529
std::string $cppClass::getIntrinsicName() {
515-
std::string intr = "llvm.x86.vbcstnebf162ps";
530+
auto elementType =
531+
getA().getType().getElementType();
532+
std::string intr = "llvm.x86.";
533+
if (elementType.isBF16())
534+
intr += "vcvtneobf162ps";
535+
if (elementType.isF16())
536+
intr += "vcvtneoph2ps";
516537
VectorType vecType = getDst().getType();
517538
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
518539
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -521,10 +542,8 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
521542
}
522543
}];
523544

524-
let extraClassDeclaration = [{
545+
let extraClassDeclaration = [{
525546
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
526547
}];
527-
528548
}
529-
530549
#endif // X86VECTOR_OPS

mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,17 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
9595
return operands;
9696
}
9797

98-
SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
98+
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
9999
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
100100
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
101101
}
102102

103-
SmallVector<Value>
104-
x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
103+
SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
105104
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
106105
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
107106
}
108107

109-
SmallVector<Value>
110-
x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
108+
SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
111109
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
112110
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
113111
}

mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
114114

115115
void mlir::configureX86VectorLegalizeForExportTarget(
116116
LLVMConversionTarget &target) {
117-
target.addIllegalOp<
118-
MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
119-
CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
120-
CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
117+
target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
118+
Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op,
119+
CvtPackedEvenIndexedToF32Op, CvtPackedOddIndexedToF32Op,
120+
BcstToPackedF32Op, RsqrtOp, DotOp>();
121121
}

mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
100100
%a: memref<8xbf16>) -> vector<4xf32>
101101
{
102102
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
103-
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
103+
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
104104
return %0 : vector<4xf32>
105105
}
106106

@@ -109,7 +109,7 @@ func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
109109
%a: memref<16xbf16>) -> vector<8xf32>
110110
{
111111
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
112-
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
112+
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
113113
return %0 : vector<8xf32>
114114
}
115115

@@ -118,7 +118,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
118118
%a: memref<8xbf16>) -> vector<4xf32>
119119
{
120120
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
121-
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
121+
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xbf16> -> vector<4xf32>
122122
return %0 : vector<4xf32>
123123
}
124124

@@ -127,7 +127,7 @@ func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
127127
%a: memref<16xbf16>) -> vector<8xf32>
128128
{
129129
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
130-
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
130+
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
131131
return %0 : vector<8xf32>
132132
}
133133

@@ -136,7 +136,7 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_128(
136136
%a: memref<1xbf16>) -> vector<4xf32>
137137
{
138138
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
139-
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
139+
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
140140
return %0 : vector<4xf32>
141141
}
142142

@@ -145,7 +145,61 @@ func.func @avxbf16_bsct_bf16_to_f32_packed_256(
145145
%a: memref<1xbf16>) -> vector<8xf32>
146146
{
147147
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
148-
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
148+
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
149+
return %0 : vector<8xf32>
150+
}
151+
152+
// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_128
153+
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_128(
154+
%a: memref<8xf16>) -> vector<4xf32>
155+
{
156+
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps128"
157+
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
158+
return %0 : vector<4xf32>
159+
}
160+
161+
// CHECK-LABEL: func @avxf16_cvt_packed_even_indexed_f16_to_f32_256
162+
func.func @avxf16_cvt_packed_even_indexed_f16_to_f32_256(
163+
%a: memref<16xf16>) -> vector<8xf32>
164+
{
165+
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneeph2ps256"
166+
%0 = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
167+
return %0 : vector<8xf32>
168+
}
169+
170+
// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128
171+
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_128(
172+
%a: memref<8xf16>) -> vector<4xf32>
173+
{
174+
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps128"
175+
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<8xf16> -> vector<4xf32>
176+
return %0 : vector<4xf32>
177+
}
178+
179+
// CHECK-LABEL: func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256
180+
func.func @avxf16_cvt_packed_odd_indexed_f16_to_f32_256(
181+
%a: memref<16xf16>) -> vector<8xf32>
182+
{
183+
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneoph2ps256"
184+
%0 = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
185+
return %0 : vector<8xf32>
186+
}
187+
188+
// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_128
189+
func.func @avxf16_bsct_f16_to_f32_packed_128(
190+
%a: memref<1xf16>) -> vector<4xf32>
191+
{
192+
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps128"
193+
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<4xf32>
194+
return %0 : vector<4xf32>
195+
}
196+
197+
// CHECK-LABEL: func @avxf16_bsct_f16_to_f32_packed_256
198+
func.func @avxf16_bsct_f16_to_f32_packed_256(
199+
%a: memref<1xf16>) -> vector<8xf32>
200+
{
201+
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnesh2ps256"
202+
%0 = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
149203
return %0 : vector<8xf32>
150204
}
151205

0 commit comments

Comments
 (0)