@@ -31,9 +31,11 @@ using namespace std;
 #define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
 #endif
 
+#if __CUDA_ARCH__ >= 800
 #ifndef CUDART_INF_BF16
 #define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
 #endif
+#endif
 
 constexpr int32_t BITS_PER_BLOCK = 32;
 constexpr int32_t THREADS_PER_THREAD_BLOCK = 256;
@@ -50,11 +52,13 @@ __device__ __half NegativeInfinity<__half>()
     return -CUDART_INF_FP16;
 }
 
+#if __CUDA_ARCH__ >= 800
 template <>
 __device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>()
 {
     return -CUDART_INF_BF16;
 }
+#endif
 
 template <typename T, typename PackedT>
 __device__ PackedT PackedNegativeInfinity()
@@ -217,13 +221,15 @@ void ApplyTokenBitmaskInplace(Tensor logits, Tensor bitmask, std::optional<Tenso
             logits.data<half_t>(), bitmask.data<int32_t>(), indices_ptr, vocab_size, 0, 0, num_rows);
         break;
     }
+#if __CUDA_ARCH__ >= 800
     case kBfloat16: {
         ApplyTokenBitmaskInplaceDispatchToPackedT(
             logits.data<bfloat16_t>(), bitmask.data<int32_t>(), indices_ptr, vocab_size, 0, 0, num_rows);
         break;
     }
+#endif
     default:
-        TM_CHECK(false) << "logits dtype must be float, half or bfloat16.";
+        TM_CHECK(false) << "logits dtype must be float, float16 or bfloat16.";
         break;
     }
 }
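
The recurring guard in these hunks is the standard way to keep __nv_bfloat16 code out of builds for pre-Ampere GPUs: __CUDA_ARCH__ is defined only during nvcc's device compilation passes (e.g. 800 when targeting sm_80), so everything inside #if __CUDA_ARCH__ >= 800 is compiled out for older architectures, where the bfloat16 intrinsics are unavailable. Below is a minimal self-contained sketch of the same pattern; the NegativeInfinity template mirrors the diff, while the includes and the explicit -inf bit patterns are illustrative scaffolding, not the commit's code.

#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename T>
__device__ T NegativeInfinity();

// fp16 is supported on every architecture this file targets, so the
// half specialization needs no guard.
template <>
__device__ __half NegativeInfinity<__half>()
{
    // 0xFC00 is the IEEE fp16 bit pattern for -inf.
    return __ushort_as_half((unsigned short)0xFC00U);
}

#if __CUDA_ARCH__ >= 800
// Only device passes targeting sm_80 or newer compile this; pre-Ampere
// device passes never see any __nv_bfloat16 math.
template <>
__device__ __nv_bfloat16 NegativeInfinity<__nv_bfloat16>()
{
    // 0xFF80 is the bfloat16 bit pattern for -inf.
    return __ushort_as_bfloat16((unsigned short)0xFF80U);
}
#endif

Building with nvcc -gencode arch=compute_80,code=sm_80 compiles the guarded specialization in; a pre-Ampere target such as sm_75 simply omits it. One caveat when reading the last hunk: __CUDA_ARCH__ is never defined during the host compilation pass, so the same guard always evaluates false in host code regardless of the target architecture.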