diff --git a/csrc/fp_quantizer/fp_quantize.cu b/csrc/fp_quantizer/fp_quantize.cu
index 66ea7392e011..42a1b63e424b 100644
--- a/csrc/fp_quantizer/fp_quantize.cu
+++ b/csrc/fp_quantizer/fp_quantize.cu
@@ -4,7 +4,7 @@
 // DeepSpeed Team
 
 #include 
-#include "context.h"
+#include "fp_context.h"
 #include "fp_quantize.h"
 #include "memory_access_utils.h"
 #include "reduction_utils.h"
diff --git a/csrc/fp_quantizer/includes/context.h b/csrc/fp_quantizer/includes/fp_context.h
similarity index 100%
rename from csrc/fp_quantizer/includes/context.h
rename to csrc/fp_quantizer/includes/fp_context.h
diff --git a/csrc/fp_quantizer/includes/fp_quantize.h b/csrc/fp_quantizer/includes/fp_quantize.h
index 60c75541f603..a15b8ddf5a22 100644
--- a/csrc/fp_quantizer/includes/fp_quantize.h
+++ b/csrc/fp_quantizer/includes/fp_quantize.h
@@ -9,10 +9,18 @@
 
 #include 
 #include 
 
-
-#ifdef BF16_AVAILABLE
+// Note: BF16 is supported on AMD, but we cannot include cuda_bf16.h here (hipify rewrites it to
+// the HIP BF16 header), because this header is pulled into .cpp translation units that are
+// compiled by a host-only compiler, which triggers build errors. A forward declaration of the
+// BF16 type is used instead; see the code block below.
+#if defined(BF16_AVAILABLE)
+#if !defined(__HIP_PLATFORM_AMD__)
 #include 
+#else
+struct __hip_bfloat16;
+#endif
 #endif
+
 #include 
 #include 
diff --git a/csrc/includes/conversion_utils.h b/csrc/includes/conversion_utils.h
index 3a90a3e91ddf..d6d8f11e0854 100644
--- a/csrc/includes/conversion_utils.h
+++ b/csrc/includes/conversion_utils.h
@@ -363,42 +363,74 @@ DS_D_INLINE __nv_bfloat16 to(float val)
 template <>
 DS_D_INLINE __nv_bfloat16 to(int64_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __double2bfloat16(__ll2double_rn(val));
+#else
     return __ll2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(int32_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__int2float_rn(val));
+#else
     return __int2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(int16_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__int2float_rn(val));
+#else
     return __short2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(int8_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__int2float_rn(val));
+#else
     return __int2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint64_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __double2bfloat16(__ull2double_rn(val));
+#else
     return __ull2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint32_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__uint2float_rn(val));
+#else
     return __uint2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint16_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__uint2float_rn(val));
+#else
     return __ushort2bfloat16_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat16 to(uint8_t val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2bfloat16(__uint2float_rn(val));
+#else
     return __uint2bfloat16_rn(val);
+#endif
 }
 #endif
 
@@ -412,7 +444,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val)
 template <>
 DS_D_INLINE __nv_bfloat162 to(float val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __bfloat162bfloat162(__float2bfloat16(val));
+#else
     return __float2bfloat162_rn(val);
+#endif
 }
 template <>
 DS_D_INLINE __nv_bfloat162 to(__half2 val)
@@ -444,7 +480,11 @@ DS_D_INLINE int64_t to(__half val)
 template <>
 DS_D_INLINE int64_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2ll_rn(__bfloat162float(val));
+#else
     return __bfloat162ll_rn(val);
+#endif
 }
 #endif
 
@@ -471,7 +511,11 @@ DS_D_INLINE int32_t to(__half val)
 template <>
 DS_D_INLINE int32_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2int_rn(__bfloat162float(val));
+#else
     return __bfloat162int_rn(val);
+#endif
 }
 #endif
 
@@ -498,7 +542,11 @@ DS_D_INLINE int16_t to(__half val)
 template <>
 DS_D_INLINE int16_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2int_rn(__bfloat162float(val));
+#else
     return __bfloat162int_rn(val);
+#endif
 }
 #endif
 
@@ -525,7 +573,11 @@ DS_D_INLINE int8_t to(__half val)
 template <>
 DS_D_INLINE int8_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2int_rn(__bfloat162float(val));
+#else
     return __bfloat162int_rn(val);
+#endif
 }
 #endif
 
@@ -552,7 +604,11 @@ DS_D_INLINE uint64_t to(__half val)
 template <>
 DS_D_INLINE uint64_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2ull_rn(__bfloat162float(val));
+#else
     return __bfloat162ull_rn(val);
+#endif
 }
 #endif
 
@@ -579,7 +635,11 @@ DS_D_INLINE uint32_t to(__half val)
 template <>
 DS_D_INLINE uint32_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2uint_rn(__bfloat162float(val));
+#else
     return __bfloat162uint_rn(val);
+#endif
 }
 #endif
 
@@ -606,7 +666,11 @@ DS_D_INLINE uint16_t to(__half val)
 template <>
 DS_D_INLINE uint16_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2uint_rn(__bfloat162float(val));
+#else
     return __bfloat162uint_rn(val);
+#endif
 }
 #endif
 
@@ -633,7 +697,11 @@ DS_D_INLINE uint8_t to(__half val)
 template <>
 DS_D_INLINE uint8_t to(__nv_bfloat16 val)
 {
+#ifdef __HIP_PLATFORM_AMD__
+    return __float2uint_rn(__bfloat162float(val));
+#else
     return __bfloat162uint_rn(val);
+#endif
 }
 #endif
 
diff --git a/csrc/includes/ds_kernel_utils.h b/csrc/includes/ds_kernel_utils.h
index f8b16ee6a315..cb8b0b28484e 100644
--- a/csrc/includes/ds_kernel_utils.h
+++ b/csrc/includes/ds_kernel_utils.h
@@ -13,7 +13,11 @@ used throughout the codebase.
 #include 
 #include 
 
-#ifdef BF16_AVAILABLE
+// Note: BF16 is supported on AMD, but we cannot include cuda_bf16.h here (hipify rewrites it to
+// the HIP BF16 header), because this header is pulled into .cpp translation units that are
+// compiled by a host-only compiler, which triggers build errors. A forward declaration of the
+// BF16 type is used instead; see the code block below.
+#if defined(BF16_AVAILABLE) && !defined(__HIP_PLATFORM_AMD__)
 #include 
 #endif
 
@@ -21,7 +25,9 @@
 #define DS_D_INLINE __device__ __forceinline__
 
 #ifdef __HIP_PLATFORM_AMD__
-
+#if BF16_AVAILABLE
+struct __hip_bfloat16;
+#endif
 // constexpr variant of warpSize for templating
 constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE;
 #define HALF_PRECISION_AVAILABLE = 1
diff --git a/csrc/includes/reduction_utils.h b/csrc/includes/reduction_utils.h
index eb9afb66a894..68ec106975b6 100644
--- a/csrc/includes/reduction_utils.h
+++ b/csrc/includes/reduction_utils.h
@@ -9,6 +9,10 @@
 #include "ds_kernel_utils.h"
 #include "memory_access_utils.h"
 
+#if defined(BF16_AVAILABLE) && defined(__HIP_PLATFORM_AMD__)
+#include 
+#endif
+
 namespace cg = cooperative_groups;
 
 namespace reduce {
@@ -374,7 +378,11 @@ DS_D_INLINE __half init()
 template <>
 DS_D_INLINE __nv_bfloat16 init()
 {
+#ifdef __HIP_PLATFORM_AMD__
+    constexpr __hip_bfloat16_raw neg_inf = {0xFF80};
+#else
     constexpr __nv_bfloat16_raw neg_inf = {0xFF80};
+#endif
     return __nv_bfloat16(neg_inf);
 }
 #endif
@@ -573,6 +581,24 @@ DS_D_INLINE void _warp(cg::thread_block_tile& warp, T* data)
     }
 }
 
+#if defined(__HIP_PLATFORM_AMD__)
+template
+DS_D_INLINE void _warp_with_type_conversion(cg::thread_block_tile& warp_arg, T* data)
+{
+    constexpr int elems = sizeof...(Ops);
+    if constexpr (!(std::is_integral<T>::value || std::is_floating_point<T>::value)) {
+        float temp_data[elems];
+#pragma unroll
+        for (int i = 0; i < elems; i++) { temp_data[i] = conversion::to<float>(data[i]); }
+        _warp(warp_arg, temp_data);
+#pragma unroll
+        for (int i = 0; i < elems; i++) { data[i] = conversion::to<T>(temp_data[i]); }
+    } else {
+        _warp(warp_arg, data);
+    }
+}
+#endif  // defined(__HIP_PLATFORM_AMD__)
+
 /*
 Implementation for primary block reduction that serves both `block` and
 `partitioned_block`.
@@ -600,7 +626,11 @@ DS_D_INLINE void _block(cg::thread_block& tb,
 #endif
 
     // Always perform warp-scope reduction
+#ifdef __HIP_PLATFORM_AMD__
+    _warp_with_type_conversion(warp_arg, data);
+#else
     _warp(warp_arg, data);
+#endif
 
     // If max_warps == 1 let's skip the runtime check
     if (total_warps != 1) {
@@ -624,8 +654,11 @@ DS_D_INLINE void _block(cg::thread_block& tb,
         } else {
             init(data);
         }
-
+#ifdef __HIP_PLATFORM_AMD__
+        _warp_with_type_conversion(warp_arg, data);
+#else
         _warp(warp_arg, data);
+#endif
 
 #pragma unroll
         for (int i = 0; i < elems; i++) {
diff --git a/op_builder/builder.py b/op_builder/builder.py
index eb78c28cfc53..4127249d9e34 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -778,6 +778,7 @@ def nvcc_args(self):
                 '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR
             ]
+            self.enable_bf16 = True
         else:
             try:
                 nvcc_threads = int(os.getenv("DS_NVCC_THREADS", ""))
diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py
index 642aed56a192..3afa74dc31c2 100755
--- a/op_builder/transformer_inference.py
+++ b/op_builder/transformer_inference.py
@@ -75,3 +75,14 @@ def extra_ldflags(self):
 
     def include_paths(self):
         return ['csrc/transformer/inference/includes', 'csrc/includes']
+
+    def nvcc_args(self):
+        args = super().nvcc_args()
+        """BF16 is supported on AMD, but including `cuda_bf16.h` (the HIP BF16 header after hipification)
+        in host-only translation units (*.cpp files) fails because GPU-specific builtins are pulled in with the BF16 type.
+        This cannot be avoided via a forward declaration for the transformer_inference extension,
+        since the `pt_binding.cpp` code explicitly requires the BF16 header, so disable BF16 for it for now.
+ """ + if self.is_rocm_pytorch(): + self.enable_bf16 = False + return args diff --git a/tests/unit/ops/fp_quantizer/test_fp_quant.py b/tests/unit/ops/fp_quantizer/test_fp_quant.py index e9baf016310e..0655b0ce26a3 100644 --- a/tests/unit/ops/fp_quantizer/test_fp_quant.py +++ b/tests/unit/ops/fp_quantizer/test_fp_quant.py @@ -57,7 +57,7 @@ def test_fp_quant_meta(dtype): qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size) qtorch_error = (qtorch_out - x).abs().sum() / x.numel() - ds_error = (x_dequantized - x).abs().sum() / x.numel() + ds_error = (x_dequantized - ds_x).abs().sum() / x.numel() assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}" @@ -129,6 +129,6 @@ def test_fp_quant(dtype, q_bits): qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=quant_config.group_size) qtorch_error = (qtorch_out - x).abs().sum() / x.numel() - ds_error = (x_dequantized - x).abs().sum() / x.numel() + ds_error = (x_dequantized - ds_x).abs().sum() / x.numel() assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"