Commit 45a01df

[AMD][ROCm] Enable BF16 and fix review comments

1 parent: 110d6dd

File tree

6 files changed (+131 -30 lines)


csrc/fp_quantizer/fp_quantize.cpp renamed to csrc/fp_quantizer/fp_quantize_api.cu

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@
 
 #if defined(__HIP_PLATFORM_AMD__)
 #include <hip/hip_fp16.h>
+#if BF16_AVAILABLE
+#include <hip/hip_bf16.h>
+#endif
 #endif
 
 #define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \
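Note: <hip/hip_fp16.h> is included unconditionally, while <hip/hip_bf16.h> is pulled in only when the build advertises bf16 support, presumably because not every ROCm toolchain ships that header. A minimal sketch (illustrative only, not part of the commit; bf16_abs is a made-up name) of code that stays compilable either way:

    #if defined(__HIP_PLATFORM_AMD__)
    #include <hip/hip_fp16.h>
    #if BF16_AVAILABLE
    #include <hip/hip_bf16.h>

    // Only compiled when the bf16 header exists; call sites guard the same way.
    __device__ __hip_bfloat16 bf16_abs(__hip_bfloat16 v)
    {
        return __float2bfloat16(fabsf(__bfloat162float(v)));
    }
    #endif
    #endif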

csrc/includes/conversion_utils.h

Lines changed: 69 additions & 0 deletions
@@ -59,6 +59,7 @@ DS_D_INLINE __half to(__half val)
 {
     return val;
 }
+
 #ifdef BF16_AVAILABLE
 template <>
 DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val)
@@ -363,42 +364,74 @@ DS_D_INLINE __nv_bfloat16 to(float val)
 template <>
 DS_D_INLINE __nv_bfloat16 to(int64_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __double2bfloat16(__ll2double_rn(val));
+#else
     return __ll2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(int32_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__int2float_rn(val));
+#else
     return __int2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(int16_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__int2float_rn(val));
+#else
     return __short2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(int8_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__int2float_rn(val));
+#else
     return __int2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint64_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __double2bfloat16(__ull2double_rn(val));
+#else
     return __ull2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint32_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__uint2float_rn(val));
+#else
     return __uint2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint16_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__uint2float_rn(val));
+#else
     return __ushort2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint8_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__uint2float_rn(val));
+#else
     return __uint2bfloat16_rn(val);
+#endif
 }
 #endif
 
@@ -412,7 +445,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val)
 template <>
 DS_D_INLINE __nv_bfloat162 to(float val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __bfloat162bfloat162(__float2bfloat16(val));
+#else
     return __float2bfloat162_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat162 to(__half2 val)
@@ -444,7 +481,11 @@ DS_D_INLINE int64_t to(__half val)
 template <>
 DS_D_INLINE int64_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2ll_rn(__bfloat162float(val));
+#else
     return __bfloat162ll_rn(val);
+#endif
 }
 #endif
 
@@ -471,7 +512,11 @@ DS_D_INLINE int32_t to(__half val)
 template <>
 DS_D_INLINE int32_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2int_rn(__bfloat162float(val));
+#else
     return __bfloat162int_rn(val);
+#endif
 }
 #endif
 
@@ -498,7 +543,11 @@ DS_D_INLINE int16_t to(__half val)
 template <>
 DS_D_INLINE int16_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2int_rn(__bfloat162float(val));
+#else
     return __bfloat162int_rn(val);
+#endif
 }
 #endif
 
@@ -525,7 +574,11 @@ DS_D_INLINE int8_t to(__half val)
 template <>
 DS_D_INLINE int8_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2int_rn(__bfloat162float(val));
+#else
     return __bfloat162int_rn(val);
+#endif
 }
 #endif
 
@@ -552,7 +605,11 @@ DS_D_INLINE uint64_t to(__half val)
 template <>
 DS_D_INLINE uint64_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2ull_rn(__bfloat162float(val));
+#else
     return __bfloat162ull_rn(val);
+#endif
 }
 #endif
 
@@ -579,7 +636,11 @@ DS_D_INLINE uint32_t to(__half val)
 template <>
 DS_D_INLINE uint32_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2uint_rn(__bfloat162float(val));
+#else
     return __bfloat162uint_rn(val);
+#endif
 }
 #endif
 
@@ -606,7 +667,11 @@ DS_D_INLINE uint16_t to(__half val)
 template <>
 DS_D_INLINE uint16_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2uint_rn(__bfloat162float(val));
+#else
     return __bfloat162uint_rn(val);
+#endif
 }
 #endif
 
@@ -633,7 +698,11 @@ DS_D_INLINE uint8_t to(__half val)
 template <>
 DS_D_INLINE uint8_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2uint_rn(__bfloat162float(val));
+#else
     return __bfloat162uint_rn(val);
+#endif
 }
 #endif
 
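The pattern repeated throughout this file: ROCm's bf16 headers lack the direct integer-to-bfloat16 rounding intrinsics CUDA offers (__int2bfloat16_rn and friends), so the HIP branches widen to float first and narrow from there; 64-bit sources widen to double instead, where the intermediate conversion stays exact over a far larger range. A standalone sketch of that fallback (assuming hipcc with bf16 support; the kernel name is illustrative):

    #include <cstdint>
    #include <hip/hip_runtime.h>
    #include <hip/hip_bf16.h>

    // int32 -> bfloat16 the way the ROCm branches above do it:
    // round to float first, then narrow to bf16.
    __global__ void int32_to_bf16(const int32_t* in, __hip_bfloat16* out, int n)
    {
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx < n) { out[idx] = __float2bfloat16(__int2float_rn(in[idx])); }
    }

The reverse direction works the same way (__bfloat162float, then __float2int_rn and friends), as in the int/uint specializations above.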

csrc/includes/reduction_utils.h

Lines changed: 53 additions & 28 deletions
@@ -9,6 +9,10 @@
 #include "ds_kernel_utils.h"
 #include "memory_access_utils.h"
 
+#if defined(BF16_AVAILABLE) && defined(__HIP_PLATFORM_AMD__)
+#include <hip/hip_bfloat16.h>
+#endif
+
 namespace cg = cooperative_groups;
 
 namespace reduce {
@@ -374,7 +378,11 @@ DS_D_INLINE __half init<ROpType::Max>()
 template <>
 DS_D_INLINE __nv_bfloat16 init<ROpType::Max>()
 {
+#ifdef __HIP_PLATFORM_AMD__
+    constexpr __hip_bfloat16_raw neg_inf = {0xFF80};
+#else
     constexpr __nv_bfloat16_raw neg_inf = {0xFF80};
+#endif
     return __nv_bfloat16(neg_inf);
 }
 #endif
@@ -526,29 +534,12 @@ here (fold is C++17 only and I don't think helps and recursion feels like
 huge overkill that harms readability) that would be wonderful.
 */
 
-template <typename T>
-DS_D_INLINE T shfl_xor_helper(cg::thread_block_tile<hw_warp_size>& warp, const T& value, int i)
-{
-    return warp.shfl_xor(value, i);
-}
-
-#if defined(__HIP_PLATFORM_AMD__)
-template <>
-DS_D_INLINE __half shfl_xor_helper<__half>(cg::thread_block_tile<hw_warp_size>& warp,
-                                           const __half& value,
-                                           int i)
-{
-    float fvalue = __half2float(value);
-    return __half(warp.shfl_xor(fvalue, i));
-}
-#endif
-
 template <typename T, ROpType Op, int reduce_width = hw_warp_size>
 DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
 {
 #pragma unroll
     for (int i = 1; i < reduce_width; i *= 2) {
-        data[0] = element<Op>(data[0], shfl_xor_helper(warp, data[0], i));
+        data[0] = element<Op>(data[0], warp.shfl_xor(data[0], i));
     }
 }
 
@@ -557,8 +548,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
 {
 #pragma unroll
     for (int i = 1; i < reduce_width; i *= 2) {
-        data[0] = element<Op1>(data[0], shfl_xor_helper(warp, data[0], i));
-        data[1] = element<Op2>(data[1], shfl_xor_helper(warp, data[1], i));
+        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
+        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
     }
 }
 
@@ -567,9 +558,9 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
 {
 #pragma unroll
     for (int i = 1; i < reduce_width; i *= 2) {
-        data[0] = element<Op1>(data[0], shfl_xor_helper(warp, data[0], i));
-        data[1] = element<Op2>(data[1], shfl_xor_helper(warp, data[1], i));
-        data[2] = element<Op3>(data[2], shfl_xor_helper(warp, data[2], i));
+        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
+        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
+        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
     }
 }
 
@@ -583,13 +574,39 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
 {
 #pragma unroll
     for (int i = 1; i < reduce_width; i *= 2) {
-        data[0] = element<Op1>(data[0], shfl_xor_helper(warp, data[0], i));
-        data[1] = element<Op2>(data[1], shfl_xor_helper(warp, data[1], i));
-        data[2] = element<Op3>(data[2], shfl_xor_helper(warp, data[2], i));
-        data[3] = element<Op4>(data[3], shfl_xor_helper(warp, data[3], i));
+        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], i));
+        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], i));
+        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], i));
+        data[3] = element<Op4>(data[3], warp.shfl_xor(data[3], i));
     }
 }
 
+#if defined(__HIP_PLATFORM_AMD__)
+template <int reduce_width, typename T, ROpType... Ops>
+DS_D_INLINE void _warp_with_type_conversion(
+    cg::thread_block_tile<hw_warp_size>& warp_arg,
+    T* data)
+{
+    constexpr int elems = sizeof...(Ops);
+    if constexpr (
+        !(std::is_integral<T>::value || std::is_floating_point<T>::value)
+    ) {
+        float temp_data[elems];
+#pragma unroll
+        for (int i = 0; i < elems; i++) {
+            temp_data[i] = conversion::to<float>(data[i]);
+        }
+        _warp<float, Ops...>(warp_arg, temp_data);
+#pragma unroll
+        for (int i = 0; i < elems; i++) {
+            data[i] = conversion::to<T>(temp_data[i]);
+        }
+    } else {
+        _warp<T, Ops...>(warp_arg, data);
+    }
+}
+#endif  // defined(__HIP_PLATFORM_AMD__)
+
 /*
 Implementation for primary block reduction that serves both `block` and
 `partitioned_block`.
@@ -617,7 +634,11 @@ DS_D_INLINE void _block(cg::thread_block& tb,
 #endif
 
     // Always perform warp-scope reduction
+#ifdef __HIP_PLATFORM_AMD__
+    _warp_with_type_conversion<hw_warp_size, T, Ops...>(warp_arg, data);
+#else
     _warp<T, Ops...>(warp_arg, data);
+#endif
 
     // If max_warps == 1 let's skip the runtime check
    if (total_warps != 1) {
@@ -641,8 +662,12 @@ DS_D_INLINE void _block(cg::thread_block& tb,
     } else {
         init<Ops...>(data);
     }
-
+#ifdef __HIP_PLATFORM_AMD__
+    _warp_with_type_conversion<total_warps, T, Ops...>(warp_arg, data);
+#else
     _warp<T, Ops..., total_warps>(warp_arg, data);
+#endif
+
 
 #pragma unroll
     for (int i = 0; i < elems; i++) {
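The deleted shfl_xor_helper special-cased only __half; _warp_with_type_conversion generalizes the same trick. As this commit's AMD-only paths suggest, cooperative-groups shuffles of non-arithmetic types such as __half or __nv_bfloat16 are not reliably available on ROCm, so the reduction runs in float and converts back once at the end, rather than converting around every shuffle as the old helper did. A standalone sketch of the technique (assuming hipcc; the function name is illustrative):

    #include <hip/hip_cooperative_groups.h>
    #include <hip/hip_fp16.h>

    namespace cg = cooperative_groups;

    // Warp-wide max over __half: shuffle in float (always supported),
    // narrow back to __half once the butterfly reduction is done.
    template <int warp_size>
    __device__ __half warp_max_half(cg::thread_block_tile<warp_size>& warp, __half v)
    {
        float f = __half2float(v);
    #pragma unroll
        for (int i = 1; i < warp_size; i *= 2) { f = fmaxf(f, warp.shfl_xor(f, i)); }
        return __float2half(f);
    }

Converting once before the loop also drops the per-iteration conversions the old helper paid on every shuffle.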

op_builder/fp_quantizer.py

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,8 @@ class FPQuantizerBuilder(CUDAOpBuilder):
     def __init__(self, name=None):
         name = self.NAME if name is None else name
         super().__init__(name=name)
+        if self.is_rocm_pytorch():
+            self.enable_bf16 = True
 
     def absolute_name(self):
         return f'deepspeed.ops.fp_quantizer.{self.NAME}_op'
@@ -90,7 +92,7 @@ def filter_ccs(self, ccs):
     def sources(self):
         return [
             "csrc/fp_quantizer/fp_quantize.cu",
-            "csrc/fp_quantizer/fp_quantize.cpp",
+            "csrc/fp_quantizer/fp_quantize_api.cu",
         ]
 
     def extra_ldflags(self):

op_builder/transformer_inference.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ class InferenceBuilder(CUDAOpBuilder):
1313
def __init__(self, name=None):
1414
name = self.NAME if name is None else name
1515
super().__init__(name=name)
16+
if self.is_rocm_pytorch():
17+
self.enable_bf16 = True
1618

1719
def absolute_name(self):
1820
return f'deepspeed.ops.transformer.inference.{self.NAME}_op'
@@ -55,7 +57,7 @@ def filter_ccs(self, ccs):
5557

5658
def sources(self):
5759
return [
58-
'csrc/transformer/inference/csrc/pt_binding.cpp',
60+
'csrc/transformer/inference/csrc/pt_binding.cu',
5961
'csrc/transformer/inference/csrc/gelu.cu',
6062
'csrc/transformer/inference/csrc/relu.cu',
6163
'csrc/transformer/inference/csrc/layer_norm.cu',
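Both builders force enable_bf16 on for ROCm builds, and both binding files move from .cpp to .cu so hipcc compiles them with the device fp16/bf16 headers in reach. The flag's plumbing lives in the base CUDAOpBuilder, outside this diff; a plausible reading is that it surfaces as a BF16_AVAILABLE compile define, which is exactly what the guards added in the kernel sources test. A hypothetical one-file probe of that assumption:

    // Compile with: hipcc -DBF16_AVAILABLE=1 probe.cu
    // (assumed flag name; the builder-to-define plumbing is not shown in this diff)
    #if defined(__HIP_PLATFORM_AMD__) && BF16_AVAILABLE
    #include <hip/hip_bf16.h>
    static_assert(sizeof(__hip_bfloat16) == 2, "bf16 is 16 bits wide");
    #else
    // bf16 paths compiled out; fp16/fp32 kernels remain available.
    #endif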
