
[WIP] MUSA: enable fastfp16, correct warp reduce impl and perf tuning #12383

Draft: wants to merge 4 commits into master
112 changes: 112 additions & 0 deletions ggml/src/ggml-cuda/__mp.cuh
@@ -0,0 +1,112 @@
#pragma once

// #define GGML_PERF_ON

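// 64-bit atomic add helper: atomicAdd is exposed for unsigned long long, so
// uint64_t counters are accumulated through a cast.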
static __device__ void atomicAddUint64(uint64_t *address, uint64_t val) {
atomicAdd((unsigned long long*)address, (unsigned long long)val);
}

#ifdef GGML_PERF_ON
#include <chrono>  // the CPU-side timing macros below use std::chrono
#define GGML_PERF_CLOCK(t) std::chrono::system_clock::time_point t = std::chrono::system_clock::now()
#define GGML_PERF_CLOCK_NOW(t) t = std::chrono::system_clock::now()
#define GGML_PERF_CLOCK_COUNT(t) std::chrono::duration<double>(std::chrono::system_clock::now() - t).count()
#define GGML_PERF_CLOCK_COUNT_ADD(s, t) s += std::chrono::duration<double>(std::chrono::system_clock::now() - t).count()
#define GGML_PERF_GPU_CLOCK(t) uint64_t t = clock64()
#define GGML_PERF_GPU_CLOCK_NOW(t) t = clock64()
#define GGML_PERF_GPU_CLOCK_COUNT(t) clock64() - t
#define GGML_PERF_GPU_CLOCK_COUNT_ADD(s, t) s += (clock64() - t)
#else
#define GGML_PERF_CLOCK(t)
#define GGML_PERF_CLOCK_NOW(t)
#define GGML_PERF_CLOCK_COUNT(t)
#define GGML_PERF_CLOCK_COUNT_ADD(s, t)
#define GGML_PERF_GPU_CLOCK(t)
#define GGML_PERF_GPU_CLOCK_NOW(t)
#define GGML_PERF_GPU_CLOCK_COUNT(t)
#define GGML_PERF_GPU_CLOCK_COUNT_ADD(s, t)
#endif // GGML_PERF_ON
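
// Usage sketch (illustration only, mirroring how ggml-cuda.cu uses these macros;
// everything expands to nothing unless GGML_PERF_ON is defined):
//   GGML_PERF_CLOCK(t0);                      // declare and start a host-side timer
//   ... timed work ...
//   GGML_PERF_CLOCK_COUNT_ADD(total_s, t0);   // accumulate elapsed seconds into a double
// The GGML_PERF_GPU_CLOCK* variants follow the same pattern in device code via clock64().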


#include "common.cuh"
#include "vecdotq.cuh"
#include <cstdint>

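// Device-global cycle counters, accumulated by __vec_dot_q4_K_q8_1 below when GGML_PERF_ON is defined.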
static __device__ uint64_t __ticks_total = 0, __ticks1 = 0, __ticks2 = 0, __ticks3 = 0, __ticks4 = 0, __ticks5 = 0;
static __device__ __forceinline__ float __vec_dot_q4_K_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {

GGML_PERF_GPU_CLOCK(tick_start);

const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx;

int v[2];
int u[2*QR4_K];
float d8[QR4_K];

// iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));

// iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
// iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
// iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
// iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108

GGML_PERF_GPU_CLOCK(_tick1);

const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
v[0] = q4[0];
v[1] = q4[4];

GGML_PERF_GPU_CLOCK(_tick2);

    // const uint16_t * scales = (const uint16_t *)bq4_K->scales;
    // copy the packed scales into a local array first; the reads below were using it uninitialized
    uint16_t scales[K_SCALE_SIZE/2];
    for (int k = 0; k < K_SCALE_SIZE/2; ++k) { scales[k] = ((const uint16_t *) bq4_K->scales)[k]; }
uint16_t aux[2];
const int j = bq8_offset/2;
if (j < 2) {
aux[0] = scales[j+0] & 0x3f3f;
aux[1] = scales[j+2] & 0x3f3f;
} else {
aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
}
const uint8_t * sc = (const uint8_t *)aux;
const uint8_t * m = sc + 2;

GGML_PERF_GPU_CLOCK(_tick3);

for (int i = 0; i < QR4_K; ++i) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
d8[i] = __low2float(bq8i->ds);

const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
u[2*i+0] = q8[0];
u[2*i+1] = q8[4];
}
GGML_PERF_GPU_CLOCK(_tick4);

float ret = vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);

GGML_PERF_GPU_CLOCK(tick_end);

// Stats: __ticks_total | __ticks1 | __ticks2 | __ticks3 | __ticks4 | __ticks5
// 161989088 | 5698656 | 1895872 | 68142496 | 14016416 | 72235648
// | 3.52% | 1.17% | 42.07% | 8.65% | 44.59%
// ----------------------------------------------------------------------------------------
// 62014000 | 10536672 | 568288 | 493632 | 1359488 | 49060384
// | 17.00% | 0.91% | 0.80% | 2.19% | 79.11%
// ----------------------------------------------------------------------------------------
#ifdef GGML_PERF_ON
atomicAddUint64(&__ticks1, _tick1 - tick_start);
atomicAddUint64(&__ticks2, _tick2 - _tick1);
atomicAddUint64(&__ticks3, _tick3 - _tick2);
atomicAddUint64(&__ticks4, _tick4 - _tick3);
atomicAddUint64(&__ticks5, tick_end - _tick4);
atomicAddUint64(&__ticks_total, tick_end - tick_start);
printf(">> [dotq] __ticks_total = %12llu, __ticks1 = %12llu, __ticks2 = %12llu, __ticks3 = %12llu, __ticks4 = %12llu, __ticks5 = %12llu\n",
__ticks_total, __ticks1, __ticks2, __ticks3, __ticks4, __ticks5
);
#endif // GGML_PERF_ON

return ret;
}
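
A minimal host-side sketch (not part of this diff) of how the accumulated device counters could be read back after a run; it assumes the standard CUDA/MUSA runtime API and that the code lives in the same translation unit as __mp.cuh:

    uint64_t h_total = 0, h_t3 = 0;
    cudaDeviceSynchronize();                                           // wait for all kernels to finish
    cudaMemcpyFromSymbol(&h_total, __ticks_total, sizeof(uint64_t));   // copy device counters to the host
    cudaMemcpyFromSymbol(&h_t3, __ticks3, sizeof(uint64_t));
    printf(">> dotq ticks: total = %llu, phase3 = %llu\n",
           (unsigned long long) h_total, (unsigned long long) h_t3);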
66 changes: 62 additions & 4 deletions ggml/src/ggml-cuda/common.cuh
@@ -1,4 +1,5 @@
#pragma once
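// WIP: force-enables the MUSA code paths in this header while the change is under development.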
#define GGML_USE_MUSA

#include "ggml.h"
#include "ggml-cuda.h"
@@ -223,7 +224,11 @@ static bool fast_fp16_available(const int cc) {

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fast_fp16_hardware_available(const int cc) {
#ifdef GGML_USE_MUSA
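    // All MUSA devices targeted by this backend are treated as having fast FP16.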
return true;
#else
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
#endif // GGML_USE_MUSA
}

// Any FP16 tensor core instructions are available for ggml code.
@@ -254,6 +259,8 @@ static bool cp_async_available(const int cc) {
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE;
#elif defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY2
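    // MUSA devices up to the QY2 generation use a 128-lane physical warp.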
return 128;
#else
return 32;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
@@ -284,22 +291,30 @@ static __device__ void no_device_code(

template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
return __reduce_add_sync(0xffffffff, x);
#else
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
#ifdef GGML_USE_MUSA
x += musa_shfl_xor_sync<int, width>(x, offset);
#else
x += __shfl_xor_sync(0xffffffff, x, offset, width);
#endif // GGML_USE_MUSA
}
return x;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
#ifdef GGML_USE_MUSA
x += musa_shfl_xor_sync<float, width>(x, offset);
#else
x += __shfl_xor_sync(0xffffffff, x, offset, width);
#endif // GGML_USE_MUSA
}
return x;
}
@@ -308,8 +323,13 @@ template<int width = WARP_SIZE>
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
#ifdef GGML_USE_MUSA
a.x += musa_shfl_xor_sync<float, width>(a.x, offset);
a.y += musa_shfl_xor_sync<float, width>(a.y, offset);
#else
a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
#endif // GGML_USE_MUSA
}
return a;
}
@@ -319,7 +339,11 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#ifdef FP16_AVAILABLE
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
#ifdef GGML_USE_MUSA
a = __hadd2(a, musa_shfl_xor_sync<half2, width>(a, offset));
#else
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
#endif // GGML_USE_MUSA
}
return a;

@@ -333,7 +357,11 @@ template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
#ifdef GGML_USE_MUSA
x = fmaxf(x, musa_shfl_xor_sync<float, width>(x, offset));
#else
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
#endif // GGML_USE_MUSA
}
return x;
}
@@ -373,16 +401,46 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal

template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) || defined(GGML_USE_MUSA)
#pragma unroll
for (int offset = width/2; offset > 0; offset >>= 1) {
#ifdef GGML_USE_MUSA
x = ggml_cuda_hmax2(x, musa_shfl_xor_sync<half2, width>(x, offset));
#else
x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
#endif // GGML_USE_MUSA
}
return x;
#else
GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) || defined(GGML_USE_MUSA)
}

template<int width = WARP_SIZE, int qk_size>
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#ifdef GGML_USE_MUSA
#pragma unroll
for (int offset = qk_size/2; offset > 0; offset >>= 1) {
x += musa_shfl_xor_sync<float, width>(x, offset);
}
return x;
#else
return warp_reduce_sum<width>(x);
#endif // GGML_USE_MUSA
}

template<int width = WARP_SIZE, int qk_size>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#ifdef GGML_USE_MUSA
#pragma unroll
for (int offset = qk_size/2; offset > 0; offset >>= 1) {
x = fmaxf(x, musa_shfl_xor_sync<float, width>(x, offset));
}
return x;
#else
return warp_reduce_max<width>(x);
#endif // GGML_USE_MUSA
}
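
// Usage sketch (not taken from this diff): callers that reduce over a fixed
// quantization-group span can pass it explicitly, independent of the warp width, e.g.
//   const float s = warp_reduce_sum<WARP_SIZE, 32>(partial_sum);
//   const float m = warp_reduce_max<WARP_SIZE, 32>(partial_max);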

#if CUDART_VERSION < CUDART_HMASK
45 changes: 39 additions & 6 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -38,6 +38,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv6.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml-cuda/__mp.cuh"
#include "ggml.h"

#include <algorithm>
@@ -1358,11 +1359,17 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}

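// Host-side wall-clock accumulators (seconds); only updated when GGML_PERF_ON is defined in __mp.cuh.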
static double ticks_total, ticks_quant, ticks_op;
// stats:  | ticks_total | ticks_quant | ticks_op
//         | 2.119177099 | 0.019638826 | 2.096463255
static void ggml_cuda_op_mul_mat(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
quantize_cuda_t quantize_src1) {

GGML_PERF_CLOCK(tick_start);
GGML_PERF_CLOCK(_tick);

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
@@ -1504,6 +1511,7 @@ static void ggml_cuda_op_mul_mat(
dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
}

GGML_PERF_CLOCK_NOW(_tick);
if (quantize_src1) {
size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
if (quantize_src1 == quantize_mmq_q8_1_cuda) {
@@ -1516,6 +1524,7 @@
CUDA_CHECK(cudaGetLastError());
}
}
GGML_PERF_CLOCK_COUNT_ADD(ticks_quant, _tick);

if (dst_on_device) {
dev[id].dst_dd = (float *) dst->data;
@@ -1606,20 +1615,24 @@ static void ggml_cuda_op_mul_mat(
GGML_ABORT("fatal error");
}

GGML_PERF_CLOCK_NOW(_tick);
if (quantize_src1 && !src1_is_contiguous) {
quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
CUDA_CHECK(cudaGetLastError());
}
GGML_PERF_CLOCK_COUNT_ADD(ticks_quant, _tick);

if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
}

GGML_PERF_CLOCK_NOW(_tick);
// do the computation
op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
GGML_PERF_CLOCK_COUNT_ADD(ticks_op, _tick);

// copy dst to host or other device if necessary
if (!dst_on_device) {
@@ -1666,6 +1679,14 @@ static void ggml_cuda_op_mul_mat(
}
}
}

// GGML_PERF_CLOCK_COUNT_ADD(ticks_total, tick_start);
// FILE *stat_file = fopen("cuda_op_mul_mat_stats.log", "a");
// fprintf(stat_file,
// ">> ticks_total = %2.9f, ticks_quant = %2.9f, ticks_op = %2.9f\n",
// ticks_total, ticks_quant, ticks_op
// );
// fclose(stat_file);
}

static __global__ void k_compute_batched_ptrs(
@@ -1884,23 +1905,41 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

// stats(Q4_K/Q8): | mmv_cnt | mm_batched_cublas_cnt | op_mmv_cnt | op_mmvq_cnt | op_mmq_cnt | op_mm_cublas_cnt
// FA=ON | 0.0K | 0.0K | 0.0K | 7.2K | 0.0K | 0.0K
// FA=OFF | 1.016K | 0.0K | 0.0K | 8.256K | 0.0K | 0.016K
static int mmv_cnt = 0, mm_batched_cublas_cnt = 0, op_mmv_cnt = 0, op_mmvq_cnt = 0, op_mmq_cnt = 0, op_mm_cublas_cnt = 0;

if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
mmv_cnt++;
} else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// general KQ + KQV multi-batch without FlashAttention
ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
mm_batched_cublas_cnt++;
} else if (use_mul_mat_vec) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
op_mmv_cnt++;
} else if (use_mul_mat_vec_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
op_mmvq_cnt++;
} else if (use_mul_mat_q) {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
op_mmq_cnt++;
} else {
ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
op_mm_cublas_cnt++;
}

// FILE *file = fopen("cuda_op_stats.log", "a");
// fprintf(file,
// ">> mmv_cnt = %4.5fK, mm_batched_cublas_cnt = %4.5fK, op_mmv_cnt = %4.5fK, op_mmvq_cnt = %4.5fK, op_mmq_cnt = %4.5fK, op_mm_cublas_cnt = %4.5fK\n",
// mmv_cnt/1000.0, mm_batched_cublas_cnt/1000.0, op_mmv_cnt/1000.0, op_mmvq_cnt/1000.0, op_mmq_cnt/1000.0, op_mm_cublas_cnt/1000.0
// );
// fclose(file);
}

struct mmid_row_mapping {
@@ -3008,12 +3047,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}
#ifdef GGML_USE_MUSA
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
return false;
}
#endif // GGML_USE_MUSA
switch (a->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16: