From ad99063fb51f3b2ce0819b2878a889e3e5c3fcbb Mon Sep 17 00:00:00 2001 From: Bodhi Hu Date: Fri, 14 Mar 2025 17:28:03 +0800 Subject: [PATCH 1/4] MUSA: enable fastfp16, correct warp reduce impl and other changes --- ggml/src/ggml-cuda/common.cuh | 66 +++++++++++++++++++++++++++-- ggml/src/ggml-cuda/ggml-cuda.cu | 49 +++++++++++++++++++--- ggml/src/ggml-cuda/mmvq.cu | 21 ++++++++++ ggml/src/ggml-cuda/quantize.cu | 5 ++- ggml/src/ggml-cuda/vecdotq.cuh | 26 +++++++++++- ggml/src/ggml-cuda/vendors/musa.h | 2 + ggml/src/ggml-musa/CMakeLists.txt | 7 ++++ ggml/src/ggml-musa/ggml-musa.h | 69 +++++++++++++++++++++++++++++++ 8 files changed, 232 insertions(+), 13 deletions(-) create mode 100644 ggml/src/ggml-musa/ggml-musa.h diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 4d4ac47c034e1..663c57102d4a7 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1,4 +1,5 @@ #pragma once +#define GGML_USE_MUSA #include "ggml.h" #include "ggml-cuda.h" @@ -223,7 +224,11 @@ static bool fast_fp16_available(const int cc) { // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fast_fp16_hardware_available(const int cc) { +#ifdef GGML_USE_MUSA + return true; +#else return cc >= GGML_CUDA_CC_PASCAL && cc != 610; +#endif // GGML_USE_MUSA } // Any FP16 tensor core instructions are available for ggml code. @@ -254,6 +259,8 @@ static bool cp_async_available(const int cc) { static constexpr __device__ int ggml_cuda_get_physical_warp_size() { #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) return __AMDGCN_WAVEFRONT_SIZE; +#elif defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY2 + return 128; #else return 32; #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) @@ -284,22 +291,30 @@ static __device__ void no_device_code( template static __device__ __forceinline__ int warp_reduce_sum(int x) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE return __reduce_add_sync(0xffffffff, x); #else #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { +#ifdef GGML_USE_MUSA + x += musa_shfl_xor_sync(x, offset); +#else x += __shfl_xor_sync(0xffffffff, x, offset, width); +#endif // GGML_USE_MUSA } return x; -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE } template static __device__ __forceinline__ float warp_reduce_sum(float x) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { +#ifdef GGML_USE_MUSA + x += musa_shfl_xor_sync(x, offset); +#else x += __shfl_xor_sync(0xffffffff, x, offset, width); +#endif // GGML_USE_MUSA } return x; } @@ -308,8 +323,13 @@ template static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { +#ifdef GGML_USE_MUSA + a.x += musa_shfl_xor_sync(a.x, offset); + a.y += musa_shfl_xor_sync(a.y, offset); +#else a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width); a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width); +#endif // GGML_USE_MUSA } return a; } @@ -319,7 +339,11 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #ifdef FP16_AVAILABLE #pragma unroll for (int offset = width/2; 
offset > 0; offset >>= 1) { +#ifdef GGML_USE_MUSA + a = __hadd2(a, musa_shfl_xor_sync(a, offset)); +#else a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width)); +#endif // GGML_USE_MUSA } return a; @@ -333,7 +357,11 @@ template static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { +#ifdef GGML_USE_MUSA + x = fmaxf(x, musa_shfl_xor_sync(x, offset)); +#else x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width)); +#endif // GGML_USE_MUSA } return x; } @@ -373,16 +401,46 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal template static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) || defined(GGML_USE_MUSA) #pragma unroll for (int offset = width/2; offset > 0; offset >>= 1) { +#ifdef GGML_USE_MUSA + x = ggml_cuda_hmax2(x, musa_shfl_xor_sync(x, offset)); +#else x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width)); +#endif // GGML_USE_MUSA } return x; #else GGML_UNUSED(x); NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) || defined(GGML_USE_MUSA) +} + +template +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#ifdef GGML_USE_MUSA +#pragma unroll + for (int offset = qk_size/2; offset > 0; offset >>= 1) { + x += musa_shfl_xor_sync(x, offset); + } + return x; +#else + return warp_reduce_sum(x); +#endif // GGML_USE_MUSA +} + +template +static __device__ __forceinline__ float warp_reduce_max(float x) { +#ifdef GGML_USE_MUSA +#pragma unroll + for (int offset = qk_size/2; offset > 0; offset >>= 1) { + x = fmaxf(x, musa_shfl_xor_sync(x, offset)); + } + return x; +#else + return warp_reduce_max(x); +#endif // GGML_USE_MUSA } #if CUDART_VERSION < CUDART_HMASK diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 497de37be8210..0fb896f364b44 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1358,11 +1358,22 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync( #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } +// CLI: +// ./build_musa/bin/llama-cli -m ~/models/deepseek-r1-7B-Q4_K_M.gguf -ngl 28 -t 8 -p "摩尔线程是一家HQ北京的国产GPU 及 AI 公司,他们正在" -n 10 -no-cnv --cache-type-k q8_0 -fa +static double ticks_total, ticks_quant, ticks_op; +// stats: | ticks_total | ticks_quant | ticks_mul_mat +// base | | | +// | | | +// dnn-shfl | 2.119177099 | 0.019638826 | 2.096463255 +// | | 0.92% | 98.93% static void ggml_cuda_op_mul_mat( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, quantize_cuda_t quantize_src1) { + std::chrono::system_clock::time_point tick_start = std::chrono::system_clock::now(); + std::chrono::system_clock::time_point _tick; + const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; @@ -1504,6 +1515,7 @@ static void ggml_cuda_op_mul_mat( 
dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1)); } + _tick = std::chrono::system_clock::now(); if (quantize_src1) { size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; if (quantize_src1 == quantize_mmq_q8_1_cuda) { @@ -1516,6 +1528,7 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(cudaGetLastError()); } } + ticks_quant += std::chrono::duration(std::chrono::system_clock::now() - _tick).count(); if (dst_on_device) { dev[id].dst_dd = (float *) dst->data; @@ -1606,20 +1619,24 @@ static void ggml_cuda_op_mul_mat( GGML_ABORT("fatal error"); } + _tick = std::chrono::system_clock::now(); if (quantize_src1 && !src1_is_contiguous) { quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } + ticks_quant += std::chrono::duration((std::chrono::system_clock::now() - _tick)).count(); if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) { CUDA_CHECK(ggml_cuda_cpy_tensor_2d( src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream)); } + _tick = std::chrono::system_clock::now(); // do the computation op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); + ticks_op += std::chrono::duration((std::chrono::system_clock::now() - _tick)).count(); // copy dst to host or other device if necessary if (!dst_on_device) { @@ -1666,6 +1683,14 @@ static void ggml_cuda_op_mul_mat( } } } + + // ticks_total += std::chrono::duration(std::chrono::system_clock::now() - tick_start).count(); + // FILE *stat_file = fopen("cuda_op_mul_mat_stats.log", "a"); + // fprintf(stat_file, + // ">> ticks_total = %2.9f, ticks_quant = %2.9f, ticks_op = %2.9f\n", + // ticks_total, ticks_quant, ticks_op + // ); + // fclose(stat_file); } static __global__ void k_compute_batched_ptrs( @@ -1884,23 +1909,41 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); + // stats: | mmv_cnt | mm_batched_cublas_cnt | op_mmv_cnt | op_mmvq_cnt | op_mmq_cnt | op_mm_cublas_cnt + // FA=ON: | 0.0K | 0.0K | 0.0K | 7.2K | 0.0K | 0.0K + // FA=OFF: | 1.016K | 0.0K | 0.0K | 8.256K | 0.0K | 0.016K + static int mmv_cnt = 0, mm_batched_cublas_cnt = 0, op_mmv_cnt = 0, op_mmvq_cnt = 0, op_mmq_cnt = 0, op_mm_cublas_cnt = 0; + if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) ggml_cuda_mul_mat_vec(ctx, src0, src1, dst); + mmv_cnt++; } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // general KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); + mm_batched_cublas_cnt++; } else if (use_mul_mat_vec) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, 
nullptr); + op_mmv_cnt++; } else if (use_mul_mat_vec_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); + op_mmvq_cnt++; } else if (use_mul_mat_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); + op_mmq_cnt++; } else { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); + op_mm_cublas_cnt++; } + + // FILE *file = fopen("cuda_op_stats.log", "a"); + // fprintf(file, + // ">> mmv_cnt = %4.5fK, mm_batched_cublas_cnt = %4.5fK, op_mmv_cnt = %4.5fK, op_mmvq_cnt = %4.5fK, op_mmq_cnt = %4.5fK, op_mm_cublas_cnt = %4.5fK\n", + // mmv_cnt/1000.0, mm_batched_cublas_cnt/1000.0, op_mmv_cnt/1000.0, op_mmvq_cnt/1000.0, op_mmq_cnt/1000.0, op_mm_cublas_cnt/1000.0 + // ); + // fclose(file); } struct mmid_row_mapping { @@ -3008,12 +3051,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { return false; } -#ifdef GGML_USE_MUSA - if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 && - !ggml_is_transposed(a) && !ggml_is_transposed(b)) { - return false; - } -#endif // GGML_USE_MUSA switch (a->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index a7d518a574ddc..0c0cec85d2896 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -47,6 +47,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { 1; } +<<<<<<< HEAD enum mmvq_parameter_table_id { MMVQ_PARAMETERS_GENERIC = 0, MMVQ_PARAMETERS_GCN, @@ -126,6 +127,9 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int ta } return 1; } +======= +static __device__ uint64_t ticks_total = 0, ticks_vecdotq = 0, ticks_reduce_sum = 0; +>>>>>>> c9e3fd9c (MUSA: enable fastfp16, correct warp reduce impl and other changes) template // tell the compiler to use as many registers as it wants, see nwarps definition below @@ -134,6 +138,8 @@ static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + uint64_t tick_start = clock64(), _clock, _ticks_vecdotq; + constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; constexpr int vdr = get_vdr_mmvq(type); @@ -155,6 +161,7 @@ static __global__ void mul_mat_vec_q( const block_q8_1 * y = (const block_q8_1 *) vy; + _clock = clock64(); for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx @@ -169,6 +176,7 @@ static __global__ void mul_mat_vec_q( } } } + _ticks_vecdotq = clock64() - _clock; __shared__ float tmp_shared[nwarps-1 > 0 ? 
nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
@@ -185,6 +193,7 @@ static __global__ void mul_mat_vec_q(
         return;
     }
 
+    _clock = clock64();
     // sum up partial sums and write back result
 #pragma unroll
     for (int j = 0; j < ncols_y; ++j) {
@@ -201,6 +210,13 @@ static __global__ void mul_mat_vec_q(
             dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
+
+    // atomicAddUint64(&ticks_vecdotq, _ticks_vecdotq);
+    // atomicAddUint64(&ticks_reduce_sum, clock64() - _clock);
+    // atomicAddUint64(&ticks_total, clock64() - tick_start);
+    // printf(">> ticks_total = %12llu, ticks_vecdotq = %12llu, ticks_reduce_sum = %12llu\n",
+    //     ticks_total, ticks_vecdotq, ticks_reduce_sum
+    // );
 }
 
 static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
@@ -222,6 +238,11 @@ static void mul_mat_vec_q_cuda(
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
     const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
+    // printf(
+    //     ">> nblocks = %5d, nwarps = %5d, ncols_y = %5d, nrows_y = %5d, ncols_x = %5d, nrows_x = %5d, rows_per_cuda_block = %5d\n",
+    //     nblocks, nwarps, ncols_y, nrows_y, ncols_x, nrows_x, rows_per_cuda_block
+    // );
+
     switch (ncols_y) {
         case 1:
         {
diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu
index 1702e4ce2feba..d4445a1ce16d6 100644
--- a/ggml/src/ggml-cuda/quantize.cu
+++ b/ggml/src/ggml-cuda/quantize.cu
@@ -3,6 +3,7 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
     const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+    const int warp_size = ggml_cuda_get_physical_warp_size();
 
     if (ix0 >= kx0_padded) {
         return;
@@ -21,8 +22,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     float amax = fabsf(xi);
     float sum = xi;
 
-    amax = warp_reduce_max(amax);
-    sum = warp_reduce_sum(sum);
+    amax = warp_reduce_max<warp_size>(amax);
+    sum = warp_reduce_sum<warp_size>(sum);
 
     const float d = amax / 127;
     const int8_t q = amax == 0.0f ?
0 : roundf(xi / d); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 40091a0ef07b4..37a7669ee4c5b 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -668,9 +668,21 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); } +static __device__ uint64_t __ticks_total = 0, __ticks1 = 0, __ticks2 = 0; +static __device__ void atomicAddUint64(uint64_t *address, uint64_t val) { + atomicAdd((unsigned long long*)address, (unsigned long long)val); +} static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + uint64_t tick_start = clock64(); + + // __shared__ block_q4_K bq4_K_shm; + // if (threadIdx.x == 0) { + // bq4_K_shm = *((const block_q4_K *)vbq + kbx); + // } + // __syncthreads(); + // const block_q4_K * bq4_K = &bq4_K_shm; const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx; int v[2]; @@ -710,8 +722,20 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( u[2*i+0] = q8[0]; u[2*i+1] = q8[4]; } + uint64_t _tick1 = clock64(); + + float ret = vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + + // uint64_t tick_end = clock64(); + + // atomicAddUint64(&__ticks1, _tick1 - tick_start); + // atomicAddUint64(&__ticks2, tick_end - _tick1); + // atomicAddUint64(&__ticks_total, tick_end - tick_start); + // printf(">> __ticks_total = %12llu, __ticks1 = %12llu, __ticks2 = %12llu\n", + // __ticks_total, __ticks1, __ticks2 + // ); - return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + return ret; } static __device__ __forceinline__ float vec_dot_q5_K_q8_1( diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 6cc1b69ee3390..5abf7b5980b5a 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -5,6 +5,8 @@ #include #include #include +#include "ggml-musa/ggml-musa.h" + #define CUBLAS_COMPUTE_16F CUDA_R_16F #define CUBLAS_COMPUTE_32F CUDA_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 166970ca6bfb8..740d1e9fb22a8 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -50,6 +50,13 @@ if (MUSAToolkit_FOUND) set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) foreach(SOURCE ${GGML_SOURCES_MUSA}) set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu") + if (GGML_MUSA_CC_SHOW_RES_USAGE) + set(COMPILE_FLAGS "${COMPILE_FLAGS} -resource-usage") + endif() + if (GGML_MUSA_CC_EMIT_IR) + set(MUSA_ARCHITECTURES "22") + set(COMPILE_FLAGS "${COMPILE_FLAGS} -S --cuda-device-only -emit-llvm") + endif() foreach(ARCH ${MUSA_ARCHITECTURES}) set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}") endforeach() diff --git a/ggml/src/ggml-musa/ggml-musa.h b/ggml/src/ggml-musa/ggml-musa.h new file mode 100644 index 0000000000000..13843155772e4 --- /dev/null +++ b/ggml/src/ggml-musa/ggml-musa.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include +#include +#include + +// FULL_MASK +#define MASK_SHFL_128 ((~(128 - 1)) & 0x7f) << 7 | (128 - 1) +#define MASK_SHFL_64 ((~(64 - 1)) & 0x7f) << 7 | (64 - 1) +#define MASK_SHFL_32 ((~(32 - 1)) & 0x7f) << 7 | (32 - 1) +#define MASK_SHFL_16 ((~(16 - 1)) & 0x7f) << 7 | (16 - 1) +#define MASK_SHFL_8 ((~(8 - 1)) & 0x7f) << 7 | (8 - 1) +#define MASK_SHFL_4 ((~(4 - 
1)) & 0x7f) << 7 | (4 - 1) +#define MASK_SHFL_2 ((~(2 - 1)) & 0x7f) << 7 | (2 - 1) + +#define MASK_SHFL_UP_128 ((~(128 - 1)) & 0x7f) << 7 +#define MASK_SHFL_UP_64 ((~(64 - 1)) & 0x7f) << 7 +#define MASK_SHFL_UP_32 ((~(32 - 1)) & 0x7f) << 7 +#define MASK_SHFL_UP_16 ((~(16 - 1)) & 0x7f) << 7 +#define MASK_SHFL_UP_8 ((~(8 - 1)) & 0x7f) << 7 +#define MASK_SHFL_UP_4 ((~(4 - 1)) & 0x7f) << 7 +#define MASK_SHFL_UP_2 ((~(2 - 1)) & 0x7f) << 7 + +template +__device__ __forceinline__ T musa_shfl_xor_sync(T val, int lane_mask) { +#if (defined(__MUSA_ARCH__) && (__MUSA_ARCH__ > 220)) // MUSIFY_EXCL_LINE + return __shfl_xor_sync(0xffffffff, val, lane_mask, width); +#elif (defined(__MUSA_ARCH__) && (__MUSA_ARCH__ > 210)) // MUSIFY_EXCL_LINE + static_assert((width >= 2) && (width <= 128) && ((width & (width - 1)) == 0)); + + auto shfl_func = [&](int& var) { + if constexpr (width == 128) { + return __musa_shfl_xor_sync_i32(var, lane_mask & 0x7f, MASK_SHFL_128); + } else if constexpr (width == 64) { + return __musa_shfl_xor_sync_i32(var, lane_mask & 0x3f, MASK_SHFL_64); + } else if constexpr (width == 32) { + return __musa_shfl_xor_sync_i32(var, lane_mask & 0x1f, MASK_SHFL_32); + } else if constexpr (width == 16) { + return __musa_shfl_xor_sync_i32(var, lane_mask & 0xf, MASK_SHFL_16); + } else if constexpr (width == 8) { + return __musa_shfl_xor_sync_i32(var, lane_mask & 0x7, MASK_SHFL_8); + } else if constexpr (width == 4) { + return __musa_shfl_xor_sync_i32(var, lane_mask & 0x3, MASK_SHFL_4); + } else if constexpr (width == 2) { + return __musa_shfl_xor_sync_i32(var, lane_mask & 0x1, MASK_SHFL_2); + } + }; + + if constexpr (sizeof(T) == 4) { + int var = *(reinterpret_cast(&val)); + int ret = shfl_func(var); + return *(reinterpret_cast(&ret)); + } else { + struct __Bits { + int __a, __b; + }; + __Bits __tmp; + memcpy(&__tmp, &val, sizeof(val)); + __tmp.__a = shfl_func(__tmp.__a); + __tmp.__b = shfl_func(__tmp.__b); + int64_t ret = *(reinterpret_cast(&__tmp)); + return *(reinterpret_cast(&ret)); + } +#else + return 0; +#endif +} From 30839a3a5145a5e23622611b1f6514f3b9305827 Mon Sep 17 00:00:00 2001 From: Bodhi Hu Date: Fri, 14 Mar 2025 17:34:11 +0800 Subject: [PATCH 2/4] update comments --- ggml/src/ggml-cuda/ggml-cuda.cu | 7 +------ ggml/src/ggml-cuda/mmvq.cu | 5 +---- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0fb896f364b44..a10de7196d0ac 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1358,14 +1358,9 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync( #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } -// CLI: -// ./build_musa/bin/llama-cli -m ~/models/deepseek-r1-7B-Q4_K_M.gguf -ngl 28 -t 8 -p "摩尔线程是一家HQ北京的国产GPU 及 AI 公司,他们正在" -n 10 -no-cnv --cache-type-k q8_0 -fa static double ticks_total, ticks_quant, ticks_op; // stats: | ticks_total | ticks_quant | ticks_mul_mat -// base | | | -// | | | -// dnn-shfl | 2.119177099 | 0.019638826 | 2.096463255 -// | | 0.92% | 98.93% +// | 2.119177099 | 0.019638826 | 2.096463255 static void ggml_cuda_op_mul_mat( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 0c0cec85d2896..4641781f067a4 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -47,7 +47,6 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { 1; } -<<<<<<< HEAD enum 
mmvq_parameter_table_id { MMVQ_PARAMETERS_GENERIC = 0, MMVQ_PARAMETERS_GCN, @@ -127,10 +126,8 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int ta } return 1; } -======= -static __device__ uint64_t ticks_total = 0, ticks_vecdotq = 0, ticks_reduce_sum = 0; ->>>>>>> c9e3fd9c (MUSA: enable fastfp16, correct warp reduce impl and other changes) +static __device__ uint64_t ticks_total = 0, ticks_vecdotq = 0, ticks_reduce_sum = 0; template // tell the compiler to use as many registers as it wants, see nwarps definition below __launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) From fa00f572ce9a5fef594553c2e509f55709c8bf23 Mon Sep 17 00:00:00 2001 From: Bodhi Hu Date: Mon, 17 Mar 2025 19:21:32 +0800 Subject: [PATCH 3/4] update perf profiling --- ggml/src/ggml-cuda/common.cuh | 22 ++++++++++++++++++++++ ggml/src/ggml-cuda/ggml-cuda.cu | 24 ++++++++++++------------ ggml/src/ggml-cuda/mmvq.cu | 21 ++++++++++++++------- ggml/src/ggml-cuda/vecdotq.cuh | 32 +++++++++++++++++++------------- 4 files changed, 67 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 663c57102d4a7..9469921dcdc1a 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -845,3 +845,25 @@ struct ggml_backend_cuda_context { return pool(device); } }; + +// #define GGML_PERF_ON + +#ifdef GGML_PERF_ON +#define GGML_PERF_CLOCK(t) std::chrono::system_clock::time_point t = std::chrono::system_clock::now() +#define GGML_PERF_CLOCK_NOW(t) t = std::chrono::system_clock::now() +#define GGML_PERF_CLOCK_COUNT(t) std::chrono::duration(std::chrono::system_clock::now() - t).count() +#define GGML_PERF_CLOCK_COUNT_ADD(s, t) s += std::chrono::duration(std::chrono::system_clock::now() - t).count() +#define GGML_PERF_GPU_CLOCK(t) uint64_t t = clock64() +#define GGML_PERF_GPU_CLOCK_NOW(t) t = clock64() +#define GGML_PERF_GPU_CLOCK_COUNT(t) clock64() - t +#define GGML_PERF_GPU_CLOCK_COUNT_ADD(s, t) s += (clock64() - t) +#else +#define GGML_PERF_CLOCK(t) +#define GGML_PERF_CLOCK_NOW(t) +#define GGML_PERF_CLOCK_COUNT(t) +#define GGML_PERF_CLOCK_COUNT_ADD(s, t) +#define GGML_PERF_GPU_CLOCK(t) +#define GGML_PERF_GPU_CLOCK_NOW(t) +#define GGML_PERF_GPU_CLOCK_COUNT(t) +#define GGML_PERF_GPU_CLOCK_COUNT_ADD(s, t) +#endif // GGML_PERF_ON diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index a10de7196d0ac..506b8d1b93465 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1366,8 +1366,8 @@ static void ggml_cuda_op_mul_mat( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, quantize_cuda_t quantize_src1) { - std::chrono::system_clock::time_point tick_start = std::chrono::system_clock::now(); - std::chrono::system_clock::time_point _tick; + GGML_PERF_CLOCK(tick_start); + GGML_PERF_CLOCK(_tick); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -1510,7 +1510,7 @@ static void ggml_cuda_op_mul_mat( dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1)); } - _tick = std::chrono::system_clock::now(); + GGML_PERF_CLOCK_NOW(_tick); if (quantize_src1) { size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; if (quantize_src1 == quantize_mmq_q8_1_cuda) { @@ -1523,7 +1523,7 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(cudaGetLastError()); } } - ticks_quant += std::chrono::duration(std::chrono::system_clock::now() - _tick).count(); + 
GGML_PERF_CLOCK_COUNT_ADD(ticks_quant, _tick); if (dst_on_device) { dev[id].dst_dd = (float *) dst->data; @@ -1614,24 +1614,24 @@ static void ggml_cuda_op_mul_mat( GGML_ABORT("fatal error"); } - _tick = std::chrono::system_clock::now(); + GGML_PERF_CLOCK_NOW(_tick); if (quantize_src1 && !src1_is_contiguous) { quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } - ticks_quant += std::chrono::duration((std::chrono::system_clock::now() - _tick)).count(); + GGML_PERF_CLOCK_COUNT_ADD(ticks_quant, _tick); if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) { CUDA_CHECK(ggml_cuda_cpy_tensor_2d( src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream)); } - _tick = std::chrono::system_clock::now(); + GGML_PERF_CLOCK_NOW(_tick); // do the computation op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream); CUDA_CHECK(cudaGetLastError()); - ticks_op += std::chrono::duration((std::chrono::system_clock::now() - _tick)).count(); + GGML_PERF_CLOCK_COUNT_ADD(ticks_op, _tick); // copy dst to host or other device if necessary if (!dst_on_device) { @@ -1679,7 +1679,7 @@ static void ggml_cuda_op_mul_mat( } } - // ticks_total += std::chrono::duration(std::chrono::system_clock::now() - tick_start).count(); + // GGML_PERF_CLOCK_COUNT_ADD(ticks_total, tick_start); // FILE *stat_file = fopen("cuda_op_mul_mat_stats.log", "a"); // fprintf(stat_file, // ">> ticks_total = %2.9f, ticks_quant = %2.9f, ticks_op = %2.9f\n", @@ -1904,9 +1904,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - // stats: | mmv_cnt | mm_batched_cublas_cnt | op_mmv_cnt | op_mmvq_cnt | op_mmq_cnt | op_mm_cublas_cnt - // FA=ON: | 0.0K | 0.0K | 0.0K | 7.2K | 0.0K | 0.0K - // FA=OFF: | 1.016K | 0.0K | 0.0K | 8.256K | 0.0K | 0.016K + // stats(Q4_K/Q8): | mmv_cnt | mm_batched_cublas_cnt | op_mmv_cnt | op_mmvq_cnt | op_mmq_cnt | op_mm_cublas_cnt + // FA=ON | 0.0K | 0.0K | 0.0K | 7.2K | 0.0K | 0.0K + // FA=OFF | 1.016K | 0.0K | 0.0K | 8.256K | 0.0K | 0.016K static int mmv_cnt = 0, mm_batched_cublas_cnt = 0, op_mmv_cnt = 0, op_mmvq_cnt = 0, op_mmq_cnt = 0, op_mm_cublas_cnt = 0; if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 4641781f067a4..b3ed0421f71f7 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -135,7 +135,9 @@ static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { - uint64_t tick_start = clock64(), _clock, _ticks_vecdotq; + GGML_PERF_GPU_CLOCK(tick_start); + GGML_PERF_GPU_CLOCK(_clock); + GGML_PERF_GPU_CLOCK(_ticks_vecdotq); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; @@ -158,7 +160,7 @@ static __global__ void mul_mat_vec_q( const block_q8_1 * y = (const block_q8_1 *) vy; - _clock = clock64(); 
+ GGML_PERF_GPU_CLOCK_NOW(_clock); for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx @@ -173,7 +175,7 @@ static __global__ void mul_mat_vec_q( } } } - _ticks_vecdotq = clock64() - _clock; + GGML_PERF_GPU_CLOCK_COUNT_ADD(_ticks_vecdotq, _clock); __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size]; if (threadIdx.y > 0) { @@ -190,7 +192,7 @@ static __global__ void mul_mat_vec_q( return; } - _clock = clock64(); + GGML_PERF_GPU_CLOCK_NOW(_clock); // sum up partial sums and write back result #pragma unroll for (int j = 0; j < ncols_y; ++j) { @@ -208,10 +210,15 @@ static __global__ void mul_mat_vec_q( } } + // Stats: ticks_total | ticks_vecdotq | ticks_reduce_sum + // 267342976 | 211809728 | 45522960 + // | 79.23% | 17.03% + // ------------------------------------------------------------------------------------------- + // GGML_PERF_GPU_CLOCK(tick_end); // atomicAddUint64(&ticks_vecdotq, _ticks_vecdotq); - // atomicAddUint64(&ticks_reduce_sum, clock64() - _clock); - // atomicAddUint64(&ticks_total, clock64() - tick_start); - // printf(">> ticks_total = %12llu, ticks_vecdotq = %12llu, ticks_reduce_sum = %12llu\n", + // atomicAddUint64(&ticks_reduce_sum, tick_end - _clock); + // atomicAddUint64(&ticks_total, tick_end - tick_start); + // printf(">> [mmvq] ticks_total = %12llu, ticks_vecdotq = %12llu, ticks_reduce_sum = %12llu\n", // ticks_total, ticks_vecdotq, ticks_reduce_sum // ); } diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 37a7669ee4c5b..c1942cb24de08 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -668,21 +668,15 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); } -static __device__ uint64_t __ticks_total = 0, __ticks1 = 0, __ticks2 = 0; +static __device__ uint64_t __ticks_total = 0, __ticks1 = 0, __ticks2 = 0, __ticks3 = 0, __ticks4 = 0; static __device__ void atomicAddUint64(uint64_t *address, uint64_t val) { atomicAdd((unsigned long long*)address, (unsigned long long)val); } static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - uint64_t tick_start = clock64(); + GGML_PERF_GPU_CLOCK(tick_start); - // __shared__ block_q4_K bq4_K_shm; - // if (threadIdx.x == 0) { - // bq4_K_shm = *((const block_q4_K *)vbq + kbx); - // } - // __syncthreads(); - // const block_q4_K * bq4_K = &bq4_K_shm; const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx; int v[2]; @@ -697,10 +691,14 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + GGML_PERF_GPU_CLOCK(_tick1); + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); v[0] = q4[0]; v[1] = q4[4]; + GGML_PERF_GPU_CLOCK(_tick2); + const uint16_t * scales = (const uint16_t *)bq4_K->scales; uint16_t aux[2]; const int j = bq8_offset/2; @@ -714,6 +712,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const uint8_t * sc = (const uint8_t *)aux; const uint8_t * m = sc + 2; + GGML_PERF_GPU_CLOCK(_tick3); + for (int i = 0; i < QR4_K; ++i) { const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; d8[i] = __low2float(bq8i->ds); @@ -722,17 +722,23 @@ static __device__ 
__forceinline__ float vec_dot_q4_K_q8_1( u[2*i+0] = q8[0]; u[2*i+1] = q8[4]; } - uint64_t _tick1 = clock64(); + GGML_PERF_GPU_CLOCK(_tick4); float ret = vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); - // uint64_t tick_end = clock64(); + GGML_PERF_GPU_CLOCK(tick_end); + // Stats: __ticks_total | __ticks1 | __ticks2 | __ticks3 | __ticks4 | vmmq + // 161989088 | 5698656 | 1895872 | 68142496 | 14016416 | 72235648 + // | 3.52% | 1.17% | 42.07% | 8.65% | 44.59% + // ---------------------------------------------------------------------------------------- // atomicAddUint64(&__ticks1, _tick1 - tick_start); - // atomicAddUint64(&__ticks2, tick_end - _tick1); + // atomicAddUint64(&__ticks2, _tick2 - _tick1); + // atomicAddUint64(&__ticks3, _tick3 - _tick2); + // atomicAddUint64(&__ticks4, _tick4 - _tick3); // atomicAddUint64(&__ticks_total, tick_end - tick_start); - // printf(">> __ticks_total = %12llu, __ticks1 = %12llu, __ticks2 = %12llu\n", - // __ticks_total, __ticks1, __ticks2 + // printf(">> [dotq] __ticks_total = %12llu, __ticks1 = %12llu, __ticks2 = %12llu, __ticks3 = %12llu, __ticks4 = %12llu\n", + // __ticks_total, __ticks1, __ticks2, __ticks3, __ticks4 // ); return ret; From a97b50dd38e0c7ff1d56bf8231e007d878a582bb Mon Sep 17 00:00:00 2001 From: Bodhi Hu Date: Mon, 17 Mar 2025 20:31:06 +0800 Subject: [PATCH 4/4] separate profiling codes --- ggml/src/ggml-cuda/__mp.cuh | 112 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/common.cuh | 22 ------- ggml/src/ggml-cuda/ggml-cuda.cu | 1 + ggml/src/ggml-cuda/mmvq.cu | 4 +- ggml/src/ggml-cuda/vecdotq.cuh | 36 ++-------- 5 files changed, 120 insertions(+), 55 deletions(-) create mode 100644 ggml/src/ggml-cuda/__mp.cuh diff --git a/ggml/src/ggml-cuda/__mp.cuh b/ggml/src/ggml-cuda/__mp.cuh new file mode 100644 index 0000000000000..956c4389d220a --- /dev/null +++ b/ggml/src/ggml-cuda/__mp.cuh @@ -0,0 +1,112 @@ +#pragma once + +// #define GGML_PERF_ON + +static __device__ void atomicAddUint64(uint64_t *address, uint64_t val) { + atomicAdd((unsigned long long*)address, (unsigned long long)val); +} + +#ifdef GGML_PERF_ON +#define GGML_PERF_CLOCK(t) std::chrono::system_clock::time_point t = std::chrono::system_clock::now() +#define GGML_PERF_CLOCK_NOW(t) t = std::chrono::system_clock::now() +#define GGML_PERF_CLOCK_COUNT(t) std::chrono::duration(std::chrono::system_clock::now() - t).count() +#define GGML_PERF_CLOCK_COUNT_ADD(s, t) s += std::chrono::duration(std::chrono::system_clock::now() - t).count() +#define GGML_PERF_GPU_CLOCK(t) uint64_t t = clock64() +#define GGML_PERF_GPU_CLOCK_NOW(t) t = clock64() +#define GGML_PERF_GPU_CLOCK_COUNT(t) clock64() - t +#define GGML_PERF_GPU_CLOCK_COUNT_ADD(s, t) s += (clock64() - t) +#else +#define GGML_PERF_CLOCK(t) +#define GGML_PERF_CLOCK_NOW(t) +#define GGML_PERF_CLOCK_COUNT(t) +#define GGML_PERF_CLOCK_COUNT_ADD(s, t) +#define GGML_PERF_GPU_CLOCK(t) +#define GGML_PERF_GPU_CLOCK_NOW(t) +#define GGML_PERF_GPU_CLOCK_COUNT(t) +#define GGML_PERF_GPU_CLOCK_COUNT_ADD(s, t) +#endif // GGML_PERF_ON + + +#include "common.cuh" +#include "vecdotq.cuh" +#include + +static __device__ uint64_t __ticks_total = 0, __ticks1 = 0, __ticks2 = 0, __ticks3 = 0, __ticks4 = 0, __ticks5 = 0; +static __device__ __forceinline__ float __vec_dot_q4_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + GGML_PERF_GPU_CLOCK(tick_start); + + const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx; + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + + // iqs is in 
0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); + + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + GGML_PERF_GPU_CLOCK(_tick1); + + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + v[0] = q4[0]; + v[1] = q4[4]; + + GGML_PERF_GPU_CLOCK(_tick2); + + // const uint16_t * scales = (const uint16_t *)bq4_K->scales; + uint16_t scales[K_SCALE_SIZE/2]; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + GGML_PERF_GPU_CLOCK(_tick3); + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + GGML_PERF_GPU_CLOCK(_tick4); + + float ret = vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + + GGML_PERF_GPU_CLOCK(tick_end); + + // Stats: __ticks_total | __ticks1 | __ticks2 | __ticks3 | __ticks4 | __ticks5 + // 161989088 | 5698656 | 1895872 | 68142496 | 14016416 | 72235648 + // | 3.52% | 1.17% | 42.07% | 8.65% | 44.59% + // ---------------------------------------------------------------------------------------- + // 62014000 | 10536672 | 568288 | 493632 | 1359488 | 49060384 + // | 17.00% | 0.91% | 0.80% | 2.19% | 79.11% + // ---------------------------------------------------------------------------------------- +#ifdef GGML_PERF_ON + atomicAddUint64(&__ticks1, _tick1 - tick_start); + atomicAddUint64(&__ticks2, _tick2 - _tick1); + atomicAddUint64(&__ticks3, _tick3 - _tick2); + atomicAddUint64(&__ticks4, _tick4 - _tick3); + atomicAddUint64(&__ticks5, tick_end - _tick4); + atomicAddUint64(&__ticks_total, tick_end - tick_start); + printf(">> [dotq] __ticks_total = %12llu, __ticks1 = %12llu, __ticks2 = %12llu, __ticks3 = %12llu, __ticks4 = %12llu, __ticks5 = %12llu\n", + __ticks_total, __ticks1, __ticks2, __ticks3, __ticks4, __ticks5 + ); +#endif // GGML_PERF_ON + + return ret; +} diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9469921dcdc1a..663c57102d4a7 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -845,25 +845,3 @@ struct ggml_backend_cuda_context { return pool(device); } }; - -// #define GGML_PERF_ON - -#ifdef GGML_PERF_ON -#define GGML_PERF_CLOCK(t) std::chrono::system_clock::time_point t = std::chrono::system_clock::now() -#define GGML_PERF_CLOCK_NOW(t) t = std::chrono::system_clock::now() -#define GGML_PERF_CLOCK_COUNT(t) std::chrono::duration(std::chrono::system_clock::now() - t).count() -#define GGML_PERF_CLOCK_COUNT_ADD(s, t) s += std::chrono::duration(std::chrono::system_clock::now() - t).count() -#define GGML_PERF_GPU_CLOCK(t) uint64_t t = clock64() -#define GGML_PERF_GPU_CLOCK_NOW(t) t = clock64() -#define GGML_PERF_GPU_CLOCK_COUNT(t) clock64() - t -#define GGML_PERF_GPU_CLOCK_COUNT_ADD(s, t) s += (clock64() - t) -#else -#define GGML_PERF_CLOCK(t) -#define GGML_PERF_CLOCK_NOW(t) -#define GGML_PERF_CLOCK_COUNT(t) -#define 
GGML_PERF_CLOCK_COUNT_ADD(s, t) -#define GGML_PERF_GPU_CLOCK(t) -#define GGML_PERF_GPU_CLOCK_NOW(t) -#define GGML_PERF_GPU_CLOCK_COUNT(t) -#define GGML_PERF_GPU_CLOCK_COUNT_ADD(s, t) -#endif // GGML_PERF_ON diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 506b8d1b93465..dfca676da6f40 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -38,6 +38,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv6.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/__mp.cuh" #include "ggml.h" #include diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index b3ed0421f71f7..136e9d5308a61 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,5 +1,6 @@ #include "mmvq.cuh" #include "vecdotq.cuh" +#include "__mp.cuh" typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); @@ -11,7 +12,8 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 : type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 : type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 : - type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : + // type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : + type == GGML_TYPE_Q4_K ? __vec_dot_q4_K_q8_1 : type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 : type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 : type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 : diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index c1942cb24de08..ba195e1d100d3 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1,3 +1,5 @@ +#pragma once + #include "common.cuh" #include @@ -668,15 +670,9 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); } -static __device__ uint64_t __ticks_total = 0, __ticks1 = 0, __ticks2 = 0, __ticks3 = 0, __ticks4 = 0; -static __device__ void atomicAddUint64(uint64_t *address, uint64_t val) { - atomicAdd((unsigned long long*)address, (unsigned long long)val); -} static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - GGML_PERF_GPU_CLOCK(tick_start); - const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx; int v[2]; @@ -691,14 +687,10 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 - GGML_PERF_GPU_CLOCK(_tick1); - const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); v[0] = q4[0]; v[1] = q4[4]; - GGML_PERF_GPU_CLOCK(_tick2); - const uint16_t * scales = (const uint16_t *)bq4_K->scales; uint16_t aux[2]; const int j = bq8_offset/2; @@ -712,8 +704,6 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const uint8_t * sc = (const uint8_t *)aux; const uint8_t * m = sc + 2; - GGML_PERF_GPU_CLOCK(_tick3); - for (int i = 0; i < QR4_K; ++i) { const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; d8[i] = __low2float(bq8i->ds); @@ -722,26 +712,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( u[2*i+0] = q8[0]; u[2*i+1] = q8[4]; } - GGML_PERF_GPU_CLOCK(_tick4); - - float ret = vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); - - GGML_PERF_GPU_CLOCK(tick_end); - - // Stats: __ticks_total | __ticks1 | __ticks2 | __ticks3 | __ticks4 | 
vmmq
-    // 161989088 | 5698656 | 1895872 | 68142496 | 14016416 | 72235648
-    // | 3.52% | 1.17% | 42.07% | 8.65% | 44.59%
-    // ----------------------------------------------------------------------------------------
-    // atomicAddUint64(&__ticks1, _tick1 - tick_start);
-    // atomicAddUint64(&__ticks2, _tick2 - _tick1);
-    // atomicAddUint64(&__ticks3, _tick3 - _tick2);
-    // atomicAddUint64(&__ticks4, _tick4 - _tick3);
-    // atomicAddUint64(&__ticks_total, tick_end - tick_start);
-    // printf(">> [dotq] __ticks_total = %12llu, __ticks1 = %12llu, __ticks2 = %12llu, __ticks3 = %12llu, __ticks4 = %12llu\n",
-    //     __ticks_total, __ticks1, __ticks2, __ticks3, __ticks4
-    // );
-
-    return ret;
+
+    return vec_dot_q4_K_q8_1_impl_vmmq<VDR_Q4_K_Q8_1_MMVQ>(v, u, sc, m, bq4_K->dm, d8);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1(