|
| 1 | +#ifndef _hip_compat_cuh |
| 2 | +#define _hip_compat_cuh |
| 3 | + |
| 4 | +// Workaround for a bug in hipamd, backported from upstream. |
| 5 | +__device__ __forceinline__ __half __compat_hrcp(__half x) { |
| 6 | + return __half_raw{ |
| 7 | + static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; |
| 8 | +} |
| 9 | + |
| 10 | +__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { |
| 11 | + return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), |
| 12 | + static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; |
| 13 | +} |
| 14 | + |
| 15 | +#define hrcp __compat_hrcp |
| 16 | +#define h2rcp __compat_h2rcp |
| 17 | + |
| 18 | +// Workaround for hipify_python using rocblas instead of hipblas. |
| 19 | +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, |
| 20 | + hipblasOperation_t transA, |
| 21 | + hipblasOperation_t transB, |
| 22 | + int m, |
| 23 | + int n, |
| 24 | + int k, |
| 25 | + const half* alpha, |
| 26 | + const half* AP, |
| 27 | + int lda, |
| 28 | + const half* BP, |
| 29 | + int ldb, |
| 30 | + const half* beta, |
| 31 | + half* CP, |
| 32 | + int ldc) { |
| 33 | + return hipblasHgemm(handle, transA, transB, m, n, k, |
| 34 | + reinterpret_cast<const hipblasHalf *>(alpha), |
| 35 | + reinterpret_cast<const hipblasHalf *>(AP), lda, |
| 36 | + reinterpret_cast<const hipblasHalf *>(BP), ldb, |
| 37 | + reinterpret_cast<const hipblasHalf *>(beta), |
| 38 | + reinterpret_cast<hipblasHalf *>(CP), ldc); |
| 39 | +} |
| 40 | + |
| 41 | +#define rocblas_handle hipblasHandle_t |
| 42 | +#define rocblas_operation_none HIPBLAS_OP_N |
| 43 | +#define rocblas_hgemm __compat_hipblasHgemm |
| 44 | + |
| 45 | +#endif |
0 commit comments