Commits (20)
c131985
[AMD][ROCm] Improve support of AMD
k-artem Jul 15, 2025
4490ea5
[AMD][ROCm] Fixes review comments
k-artem Jul 25, 2025
77a7e06
[AMD][ROCm] Fixes review comments
k-artem Aug 3, 2025
110d6dd
Merge branch 'master' into improve_support_of_amd_hardware
sfc-gh-truwase Aug 16, 2025
0946828
[AMD][ROCm] Enable BF16 and fixes review's comment
k-artem Aug 18, 2025
c75a4b4
Merge branch 'master' into improve_support_of_amd_hardware
sfc-gh-truwase Aug 19, 2025
f9934bb
Merge branch 'master' into improve_support_of_amd_hardware
sfc-gh-truwase Aug 20, 2025
2d16fb1
Merge branch 'master' into improve_support_of_amd_hardware
loadams Aug 20, 2025
47cb5cc
Merge branch 'master' into improve_support_of_amd_hardware
loadams Aug 20, 2025
a23815a
[AMD][ROCm] Fix format
k-artem Aug 21, 2025
234920e
Merge branch 'master' into improve_support_of_amd_hardware
loadams Aug 28, 2025
4eade1e
Merge branch 'master' into improve_support_of_amd_hardware
loadams Sep 2, 2025
4904d94
Fix BF16 support for AMD
k-artem Oct 13, 2025
4a1d7b7
Remove unnecessary changes
k-artem Oct 13, 2025
7389a8f
Merge branch 'master' into improve_support_of_amd_hardware
k-artem Oct 13, 2025
2b14460
Merge branch 'master' into improve_support_of_amd_hardware
k-artem Oct 13, 2025
2428cb7
Merge branch 'master' into improve_support_of_amd_hardware
k-artem Oct 22, 2025
ab1af24
Merge branch 'master' into improve_support_of_amd_hardware
k-artem Oct 22, 2025
427071c
Merge branch 'master' into improve_support_of_amd_hardware
k-artem Oct 23, 2025
f08fe18
Merge branch 'master' into improve_support_of_amd_hardware
k-artem Oct 24, 2025
2 changes: 1 addition & 1 deletion csrc/fp_quantizer/fp_quantize.cu
@@ -4,7 +4,7 @@
// DeepSpeed Team

#include <stdexcept>
#include "context.h"
#include "fp_context.h"
#include "fp_quantize.h"
#include "memory_access_utils.h"
#include "reduction_utils.h"
12 changes: 10 additions & 2 deletions csrc/fp_quantizer/includes/fp_quantize.h
@@ -9,10 +9,18 @@
#include <stdint.h>

#include <cuda_fp16.h>

#ifdef BF16_AVAILABLE
// Note: BF16 is supported on AMD, but cuda_bf16.h (which becomes <hip/hip_bfloat16.h> after
// hipification) must be excluded here, because this header is pulled into .cpp translation units
// that are compiled by a host-only compiler, which triggers build errors. A forward declaration
// is used instead; see the code block below.
#if defined(BF16_AVAILABLE)
#if !defined(__HIP_PLATFORM_AMD__)
#include <cuda_bf16.h>
#else
struct __hip_bfloat16;
#endif
#endif

#include <cuda_runtime_api.h>
#include <stdio.h>

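The forward declaration works because host-only translation units only ever handle the type opaquely. A minimal sketch of that idea (not part of the PR; `quantize_to_bf16` and `host_entry` are hypothetical names used only for illustration):

```cpp
// Sketch: a host-only .cpp file can declare and pass pointers to an incomplete type.
struct __hip_bfloat16;  // forward declaration, as added above for ROCm builds

// Hypothetical host-side declaration; only the pointer type is needed here.
void quantize_to_bf16(__hip_bfloat16* out, const float* in, int num_elems);

void host_entry(void* out_buffer, const float* in, int num_elems)
{
    // No definition of __hip_bfloat16 is required to cast and forward the pointer.
    quantize_to_bf16(reinterpret_cast<__hip_bfloat16*>(out_buffer), in, num_elems);
}
```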
68 changes: 68 additions & 0 deletions csrc/includes/conversion_utils.h
@@ -363,42 +363,74 @@ DS_D_INLINE __nv_bfloat16 to(float val)
template <>
DS_D_INLINE __nv_bfloat16 to(int64_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __double2bfloat16(__ll2double_rn(val));
#else
return __ll2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(int32_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__int2float_rn(val));
#else
return __int2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(int16_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__int2float_rn(val));
#else
return __short2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(int8_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__int2float_rn(val));
#else
return __int2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint64_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __double2bfloat16(__ull2double_rn(val));
#else
return __ull2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint32_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__uint2float_rn(val));
#else
return __uint2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint16_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__uint2float_rn(val));
#else
return __ushort2bfloat16_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint8_t val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2bfloat16(__uint2float_rn(val));
#else
return __uint2bfloat16_rn(val);
#endif
}
#endif

@@ -412,7 +444,11 @@ DS_D_INLINE __nv_bfloat162 to(float2 val)
template <>
DS_D_INLINE __nv_bfloat162 to(float val)
{
#ifdef __HIP_PLATFORM_AMD__
return __bfloat162bfloat162(__float2bfloat16(val));
#else
return __float2bfloat162_rn(val);
#endif
}
template <>
DS_D_INLINE __nv_bfloat162 to(__half2 val)
@@ -444,7 +480,11 @@ DS_D_INLINE int64_t to(__half val)
template <>
DS_D_INLINE int64_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2ll_rn(__bfloat162float(val));
#else
return __bfloat162ll_rn(val);
#endif
}
#endif

@@ -471,7 +511,11 @@ DS_D_INLINE int32_t to(__half val)
template <>
DS_D_INLINE int32_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2int_rn(__bfloat162float(val));
#else
return __bfloat162int_rn(val);
#endif
}
#endif

@@ -498,7 +542,11 @@ DS_D_INLINE int16_t to(__half val)
template <>
DS_D_INLINE int16_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2int_rn(__bfloat162float(val));
#else
return __bfloat162int_rn(val);
#endif
}
#endif

@@ -525,7 +573,11 @@ DS_D_INLINE int8_t to(__half val)
template <>
DS_D_INLINE int8_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2int_rn(__bfloat162float(val));
#else
return __bfloat162int_rn(val);
#endif
}
#endif

@@ -552,7 +604,11 @@ DS_D_INLINE uint64_t to(__half val)
template <>
DS_D_INLINE uint64_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2ull_rn(__bfloat162float(val));
#else
return __bfloat162ull_rn(val);
#endif
}
#endif

@@ -579,7 +635,11 @@ DS_D_INLINE uint32_t to(__half val)
template <>
DS_D_INLINE uint32_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2uint_rn(__bfloat162float(val));
#else
return __bfloat162uint_rn(val);
#endif
}
#endif

@@ -606,7 +666,11 @@ DS_D_INLINE uint16_t to(__half val)
template <>
DS_D_INLINE uint16_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2uint_rn(__bfloat162float(val));
#else
return __bfloat162uint_rn(val);
#endif
}
#endif

@@ -633,7 +697,11 @@ DS_D_INLINE uint8_t to(__half val)
template <>
DS_D_INLINE uint8_t to(__nv_bfloat16 val)
{
#ifdef __HIP_PLATFORM_AMD__
return __float2uint_rn(__bfloat162float(val));
#else
return __bfloat162uint_rn(val);
#endif
}
#endif

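The pattern in this file is uniform: where ROCm lacks a direct integer-to-bfloat16 intrinsic, the value is routed through float, or through double for 64-bit sources, presumably because float cannot represent every 64-bit integer exactly. A condensed sketch of the two fallback shapes, restating the diff under the assumption that only the float/double conversion intrinsics plus `__float2bfloat16`/`__double2bfloat16` are available on the HIP side:

```cpp
// Sketch: 64-bit sources widen to double before narrowing to bfloat16.
DS_D_INLINE __nv_bfloat16 wide_int_to_bf16(int64_t val)
{
#ifdef __HIP_PLATFORM_AMD__
    // The 53-bit double significand keeps the intermediate closer to exact for large |val|.
    return __double2bfloat16(__ll2double_rn(val));
#else
    return __ll2bfloat16_rn(val);
#endif
}

// Sketch: 32-bit and smaller sources take the cheaper float round-trip.
DS_D_INLINE __nv_bfloat16 narrow_int_to_bf16(int32_t val)
{
#ifdef __HIP_PLATFORM_AMD__
    return __float2bfloat16(__int2float_rn(val));
#else
    return __int2bfloat16_rn(val);
#endif
}
```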
10 changes: 8 additions & 2 deletions csrc/includes/ds_kernel_utils.h
@@ -13,15 +13,21 @@ used throughout the codebase.
#include <cuda.h>
#include <cuda_fp16.h>

#ifdef BF16_AVAILABLE
// Note: BF16 is supported on AMD, but cuda_bf16.h (which becomes <hip/hip_bfloat16.h> after
// hipification) must be excluded here, because this header is pulled into .cpp translation units
// that are compiled by a host-only compiler, which triggers build errors. A forward declaration
// is used instead; see the code block below.
#if defined(BF16_AVAILABLE) && !defined(__HIP_PLATFORM_AMD__)
#include <cuda_bf16.h>
#endif

#define DS_HD_INLINE __host__ __device__ __forceinline__
#define DS_D_INLINE __device__ __forceinline__

#ifdef __HIP_PLATFORM_AMD__

#if defined(BF16_AVAILABLE)
struct __hip_bfloat16;
#endif
// constexpr variant of warpSize for templating
constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE;
#define HALF_PRECISION_AVAILABLE = 1
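`ROCM_WAVEFRONT_SIZE` matters here because cooperative-groups tiles take their width as a template parameter, and AMD wavefronts are commonly 64 lanes wide versus 32 on NVIDIA. A small sketch of how the constant is consumed (illustration only; `warp_scope_example` and the `-DROCM_WAVEFRONT_SIZE=64` default are assumptions, not code from the PR):

```cpp
// Sketch: a constexpr warp/wavefront width feeding a cooperative-groups tile.
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

#ifndef ROCM_WAVEFRONT_SIZE
#define ROCM_WAVEFRONT_SIZE 64  // assumed build-time default, for illustration only
#endif
constexpr int hw_warp_size = ROCM_WAVEFRONT_SIZE;

__device__ void warp_scope_example(float& val)
{
    cg::thread_block tb = cg::this_thread_block();
    // The tile width must be a compile-time constant, hence the constexpr hw_warp_size.
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
    // Exchange with the neighboring lane; any warp-scope collective would do here.
    val += warp.shfl_down(val, 1);
}
```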
35 changes: 34 additions & 1 deletion csrc/includes/reduction_utils.h
@@ -9,6 +9,10 @@
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"

#if defined(BF16_AVAILABLE) && defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_bfloat16.h>
#endif

namespace cg = cooperative_groups;

namespace reduce {
@@ -374,7 +378,11 @@ DS_D_INLINE __half init<ROpType::Max>()
template <>
DS_D_INLINE __nv_bfloat16 init<ROpType::Max>()
{
#ifdef __HIP_PLATFORM_AMD__
constexpr __hip_bfloat16_raw neg_inf = {0xFF80};
#else
constexpr __nv_bfloat16_raw neg_inf = {0xFF80};
#endif
return __nv_bfloat16(neg_inf);
}
#endif
@@ -573,6 +581,24 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
}
}

#if defined(__HIP_PLATFORM_AMD__)
template <int reduce_width, typename T, ROpType... Ops>
DS_D_INLINE void _warp_with_type_conversion(cg::thread_block_tile<hw_warp_size>& warp_arg, T* data)
{
constexpr int elems = sizeof...(Ops);
if constexpr (!(std::is_integral<T>::value || std::is_floating_point<T>::value)) {
float temp_data[elems];
#pragma unroll
for (int i = 0; i < elems; i++) { temp_data[i] = conversion::to<float>(data[i]); }
_warp<float, Ops...>(warp_arg, temp_data);
#pragma unroll
for (int i = 0; i < elems; i++) { data[i] = conversion::to<T>(temp_data[i]); }
} else {
_warp<T, Ops...>(warp_arg, data);
}
}
#endif // defined(__HIP_PLATFORM_AMD__)

/*
Implementation for primary block reduction that serves both `block` and
`partitioned_block`.
@@ -600,7 +626,11 @@ DS_D_INLINE void _block(cg::thread_block& tb,
#endif

// Always perform warp-scope reduction
#ifdef __HIP_PLATFORM_AMD__
_warp_with_type_conversion<hw_warp_size, T, Ops...>(warp_arg, data);
#else
_warp<T, Ops...>(warp_arg, data);
#endif

// If max_warps == 1 let's skip the runtime check
if (total_warps != 1) {
@@ -624,8 +654,11 @@
} else {
init<Ops...>(data);
}

#ifdef __HIP_PLATFORM_AMD__
_warp_with_type_conversion<total_warps, T, Ops...>(warp_arg, data);
#else
_warp<T, Ops..., total_warps>(warp_arg, data);
#endif

#pragma unroll
for (int i = 0; i < elems; i++) {
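The new `_warp_with_type_conversion` helper exists because HIP's warp shuffle collectives have no bfloat16 overloads, so non-arithmetic element types are promoted to float for the shuffle stage and narrowed back afterwards. A stripped-down sketch of that dispatch (illustration only; it reuses `conversion::to` from conversion_utils.h and assumes the surrounding cooperative-groups setup and `hw_warp_size`):

```cpp
// Sketch: promote non-arithmetic types (e.g. __hip_bfloat16) to float for warp shuffles.
#include <type_traits>

template <typename T>
DS_D_INLINE void warp_sum_one(cg::thread_block_tile<hw_warp_size>& warp, T& val)
{
    if constexpr (std::is_arithmetic<T>::value) {
        // Direct path: the shuffle understands the element type natively.
        for (int i = hw_warp_size / 2; i > 0; i /= 2) { val += warp.shfl_down(val, i); }
    } else {
        // Promotion path: shuffle in float, then convert back to T.
        float f = conversion::to<float>(val);
        for (int i = hw_warp_size / 2; i > 0; i /= 2) { f += warp.shfl_down(f, i); }
        val = conversion::to<T>(f);
    }
}
```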
1 change: 1 addition & 0 deletions op_builder/builder.py
@@ -778,6 +778,7 @@ def nvcc_args(self):
'-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR,
'-DROCM_VERSION_MINOR=%s' % ROCM_MINOR
]
self.enable_bf16 = True
else:
try:
nvcc_threads = int(os.getenv("DS_NVCC_THREADS", ""))
11 changes: 11 additions & 0 deletions op_builder/transformer_inference.py
@@ -75,3 +75,14 @@ def extra_ldflags(self):

def include_paths(self):
return ['csrc/transformer/inference/includes', 'csrc/includes']

    def nvcc_args(self):
        """BF16 is supported on AMD, but including `cuda_bf16.h` (`<hip/hip_bfloat16.h>` after
        hipification) in host-only translation units (*.cpp files) fails because GPU-specific
        builtins are pulled in together with the BF16 type. For this transformer_inference
        extension the issue cannot be avoided with forward declarations, since `pt_binding.cpp`
        explicitly requires the BF16 header, so BF16 is disabled here for now.
        """
        args = super().nvcc_args()
        if self.is_rocm_pytorch():
            self.enable_bf16 = False
        return args
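The contrast with the other extensions is that a forward declaration only helps while host code treats the type opaquely; once the complete type is needed, as in `pt_binding.cpp`, the real header must be included. A hedged C++ illustration of where an incomplete type stops being enough (not taken from the PR):

```cpp
// Sketch: host code that needs more than a pointer forces inclusion of the full BF16 header.
#include <cstddef>

struct __hip_bfloat16;  // incomplete type: fine for pointers and references only

size_t buffer_bytes(int num_elems)
{
    // return num_elems * sizeof(__hip_bfloat16);       // error: incomplete type
    // float f = static_cast<float>(__hip_bfloat16{});  // error: definition required
    // Once sizeof or conversions are needed, <hip/hip_bfloat16.h> must be included,
    // which drags device-only builtins into this host-only translation unit.
    return static_cast<size_t>(num_elems) * 2;  // placeholder to keep the sketch compiling
}
```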
4 changes: 2 additions & 2 deletions tests/unit/ops/fp_quantizer/test_fp_quant.py
@@ -57,7 +57,7 @@ def test_fp_quant_meta(dtype):

qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=group_size)
qtorch_error = (qtorch_out - x).abs().sum() / x.numel()
ds_error = (x_dequantized - x).abs().sum() / x.numel()
ds_error = (x_dequantized - ds_x).abs().sum() / x.numel()

assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"

@@ -129,6 +129,6 @@ def test_fp_quant(dtype, q_bits):
qtorch_out = qtorch_quantize(x, exp_bits=exp_bits, man_bits=man_bits, group_size=quant_config.group_size)

qtorch_error = (qtorch_out - x).abs().sum() / x.numel()
ds_error = (x_dequantized - x).abs().sum() / x.numel()
ds_error = (x_dequantized - ds_x).abs().sum() / x.numel()

assert 0.0004 > abs(qtorch_error.item() - ds_error.item()), f"failed on iteration {i}"