diff --git a/CMakeLists.txt b/CMakeLists.txt index 770b4ba30..c10bb6e62 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -115,8 +115,13 @@ if(BUILD_CUDA) message(STATUS "CMake < 3.23.0; determining CUDA architectures supported...") # 11.4+ supports these at a minimum. - set(CMAKE_CUDA_ARCHITECTURES_ALL 50 52 53 60 61 62 70 72 75 80 86 87) - set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 50 60 70 80) + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS "13.0") + set(CMAKE_CUDA_ARCHITECTURES_ALL 50 52 53 60 61 62 70 72 75 80 86 87) + set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 50 60 70 80) + else() + set(CMAKE_CUDA_ARCHITECTURES_ALL 75 80 86 87) + set(CMAKE_CUDA_ARCHITECTURES_ALL_MAJOR 75 80) + endif() # CUDA 11.8 adds support for Ada and Hopper. if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.8") diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 97b80f050..c988e7c6a 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -11,6 +11,7 @@ #include #include #include +#include <cuda/functional> #include #include #include @@ -416,7 +417,7 @@ __global__ void kQuantizeBlockwise( for (int j = 0; j < NUM_PER_TH; j++) local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); - local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items); + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cuda::maximum<>{}, valid_items); if (threadIdx.x == 0) { smem_absmax_value[0] = 1.0f / local_abs_max; @@ -1002,12 +1003,12 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b } __syncthreads(); - local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cuda::maximum<>{}, valid_items); __syncthreads(); - local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cuda::maximum<>{}, valid_items); if (unorm != NULL) { __syncthreads(); - 
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cuda::std::plus<>{}, valid_items); } if (threadIdx.x == 0) { @@ -1213,13 +1214,13 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b } __syncthreads(); - local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cuda::maximum<>{}, valid_items); if (threadIdx.x == 0) { atomicMax(&new_max1[0], local_max_s1); } if (unorm != NULL) { __syncthreads(); - local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cuda::std::plus<>{}, valid_items); if (threadIdx.x == 0) { atomicAdd(&unorm[0], local_unorm); } @@ -1524,11 +1525,11 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise( } // reduce: 2.51/1.60 -> 2.67/1.69 - new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); - new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max()); + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cuda::maximum<>{}); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cuda::maximum<>{}); if (OPTIMIZER == ADEMAMIX) { - new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max()); + new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cuda::maximum<>{}); } if (threadIdx.x == 0) { @@ -1737,7 +1738,7 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise( } // reduce: 2.51/1.60 -> 2.67/1.69 - new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cuda::maximum<>{}); if (threadIdx.x == 0) smem_exchange1[0] = 
new_local_abs_max1; @@ -1843,7 +1844,7 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ } // Reduce thread-local absmax across the block. - const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); + const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cuda::maximum<>{}, cols); if (threadIdx.x == 0) { // Save our block's absmax to shared memory for the quantization step. rowStats[row_id] = smem_row_absmax = row_absmax; @@ -1898,7 +1899,7 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ // Reduce thread-local absmax across the block. // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY - const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); + const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cuda::maximum<>{}, cols); if (threadIdx.x == 0) { // Save our block's absmax to shared memory for the quantization step. 
rowStats[row_id] = row_absmax; diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 9c4cab9cc..be2c6c5dc 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -5,6 +5,27 @@ #if BUILD_CUDA #include <ops.cuh> +#include <cuda_runtime_api.h> + +#if CUDART_VERSION >= 13000 +static inline cudaError_t bnb_cudaMemPrefetchAsync(const void* ptr, + size_t bytes, + int device, + cudaStream_t stream) { + cudaMemLocation loc{}; + loc.type = cudaMemLocationTypeDevice; + loc.id = device; + // flags = 0 + return cudaMemPrefetchAsync(ptr, bytes, loc, 0u, stream); +} +#else +static inline cudaError_t bnb_cudaMemPrefetchAsync(const void* ptr, + size_t bytes, + int device, + cudaStream_t stream) { + return cudaMemPrefetchAsync(ptr, bytes, device, stream); +} +#endif #endif #if BUILD_HIP #include <ops_hip.cuh> @@ -623,7 +644,7 @@ void cprefetch(void* ptr, size_t bytes, int device) { if (hasPrefetch == 0) return; - CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); + CUDA_CHECK_RETURN(bnb_cudaMemPrefetchAsync(ptr, bytes, device, 0)); CUDA_CHECK_RETURN(cudaPeekAtLastError()); }