@@ -408,34 +408,41 @@ def DotOp : AVX_LowOp<"dot", [Pure,
}];
}

-
//----------------------------------------------------------------------------//
- // AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
+ // AVX: Convert BF16/F16 to F32 and broadcast into packed F32
//----------------------------------------------------------------------------//

- def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
+ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
- let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
+ let summary = "AVX: Broadcasts BF16/F16 into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:

- Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
- memory locations starting at location `__A` to packed single-precision
- (32-bit) floating-point elements, and store the results in `dst`.
+ Convert scalar BF16 or F16 (16-bit) floating-point element stored at memory locations
+ starting at location `__A` to a single-precision (32-bit) floating-point,
+ broadcast it to packed single-precision (32-bit) floating-point elements,
+ and store the results in `dst`.

Example:
```mlir
- %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
```
}];
- let arguments = (ins AnyMemRef:$a);
+ let arguments = (ins MemRefOf<[BF16, F16]>:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";

let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
- std::string intr = "llvm.x86.vcvtneebf162ps";
+ auto elementType =
+ getA().getType().getElementType();
+ std::string intr = "llvm.x86.";
+ if (elementType.isBF16())
+ intr += "vbcstnebf162ps";
+ if (elementType.isF16())
+ intr += "vbcstnesh2ps";
VectorType vecType = getDst().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -447,31 +454,43 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
}];
+
}

- def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
+ //------------------------------------------------------------------------------//
+ // AVX: Convert packed BF16/F16 even-indexed/odd-indexed elements into packed F32
+ //------------------------------------------------------------------------------//
+
+ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
- let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
+ let summary = "AVX: Convert packed BF16/F16 even-indexed elements into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:

- Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
+ Convert packed BF16 or F16 (16-bit) floating-point even-indexed elements stored at
memory locations starting at location `__A` to packed single-precision
(32-bit) floating-point elements, and store the results in `dst`.

Example:
```mlir
- %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
```
}];
- let arguments = (ins AnyMemRef:$a);
+ let arguments = (ins MemRefOf<[BF16, F16]>:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";

let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
- std::string intr = "llvm.x86.vcvtneobf162ps";
+ auto elementType =
+ getA().getType().getElementType();
+ std::string intr = "llvm.x86.";
+ if (elementType.isBF16())
+ intr += "vcvtneebf162ps";
+ if (elementType.isF16())
+ intr += "vcvtneeph2ps";
VectorType vecType = getDst().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -485,34 +504,36 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
}];
}

- //----------------------------------------------------------------------------//
- // AVX: Convert BF16 to F32 and broadcast into packed F32
- //----------------------------------------------------------------------------//
-
- def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
+ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [MemoryEffects<[MemRead]>,
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
- let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
+ let summary = "AVX: Convert packed BF16/F16 odd-indexed elements into packed F32 Data.";
let description = [{
#### From the Intel Intrinsics Guide:

- Convert scalar BF16 (16-bit) floating-point element stored at memory locations
- starting at location `__A` to a single-precision (32-bit) floating-point,
- broadcast it to packed single-precision (32-bit) floating-point elements,
- and store the results in `dst`.
+ Convert packed BF16 or F16 (16-bit) floating-point odd-indexed elements stored at
+ memory locations starting at location `__A` to packed single-precision
+ (32-bit) floating-point elements, and store the results in `dst`.

Example:
```mlir
- %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+ %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+ %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
```
}];
- let arguments = (ins AnyMemRef:$a);
+ let arguments = (ins MemRefOf<[BF16, F16]>:$a);
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
let assemblyFormat =
"$a attr-dict`:` type($a)`->` type($dst)";

let extraClassDefinition = [{
std::string $cppClass::getIntrinsicName() {
- std::string intr = "llvm.x86.vbcstnebf162ps";
+ auto elementType =
+ getA().getType().getElementType();
+ std::string intr = "llvm.x86.";
+ if (elementType.isBF16())
+ intr += "vcvtneobf162ps";
+ if (elementType.isF16())
+ intr += "vcvtneoph2ps";
VectorType vecType = getDst().getType();
unsigned elemBitWidth = vecType.getElementTypeBitWidth();
unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
@@ -521,10 +542,8 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[Me
}
}];

- let extraClassDeclaration = [{
+ let extraClassDeclaration = [{
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
}];
-
}
-
#endif // X86VECTOR_OPS
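
For context, a minimal sketch of how the three generalized ops above compose in IR; the function name, SSA value names, and the trailing `arith.addf` reduction are illustrative assumptions rather than part of this patch, while the operand/result types follow the op definitions (`MemRefOf<[BF16, F16]>` in, packed F32 vector out):

```mlir
// Hypothetical usage of the ops defined above; the surrounding func and
// arith ops are illustrative only.
func.func @cvt_bf16_f16_to_f32(%bf : memref<16xbf16>, %fh : memref<16xf16>,
                               %s : memref<1xf16>) -> vector<8xf32> {
  // Even-indexed elements to packed F32; a bf16 element type selects the
  // vcvtneebf162ps intrinsic family, f16 selects vcvtneeph2ps.
  %even = x86vector.avx.cvt.packed.even.indexed_to_f32 %bf : memref<16xbf16> -> vector<8xf32>
  // Odd-indexed elements to packed F32 (vcvtneobf162ps / vcvtneoph2ps).
  %odd = x86vector.avx.cvt.packed.odd.indexed_to_f32 %fh : memref<16xf16> -> vector<8xf32>
  // Scalar element broadcast to packed F32 (vbcstnebf162ps / vbcstnesh2ps).
  %bcst = x86vector.avx.bcst_to_f32.packed %s : memref<1xf16> -> vector<8xf32>
  %acc = arith.addf %even, %odd : vector<8xf32>
  %res = arith.addf %acc, %bcst : vector<8xf32>
  return %res : vector<8xf32>
}
```

Keying the intrinsic choice on the memref element type in `getIntrinsicName()` is what lets each conversion flavor remain a single op instead of splitting into separate bf16 and f16 variants.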