-
Notifications
You must be signed in to change notification settings - Fork 12.3k
OpenCL: add tiled mul_mat_f16_f32 #14535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,6 +103,7 @@ set(GGML_OPENCL_KERNELS | |
tanh | ||
pad | ||
repeat | ||
mul_mat_f16_f32 | ||
) | ||
|
||
foreach (K ${GGML_OPENCL_KERNELS}) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -367,6 +367,7 @@ struct ggml_backend_opencl_context { | |
cl_program program_mul_mv_f16_f32; | ||
cl_program program_mul_mv_f32_f32; | ||
cl_program program_mul; | ||
cl_program program_mul_mat_f16_f32_tiled; | ||
cl_program program_div; | ||
cl_program program_sub; | ||
cl_program program_norm; | ||
|
@@ -419,6 +420,7 @@ struct ggml_backend_opencl_context { | |
cl_kernel kernel_mul_mat_f16_f32_1row; | ||
cl_kernel kernel_mul_mat_f16_f32; | ||
cl_kernel kernel_mul_mat_f16_f32_l4; | ||
cl_kernel kernel_mul_mat_f16_f32_tiled; | ||
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; | ||
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; | ||
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; | ||
|
@@ -1000,6 +1002,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve | |
GGML_LOG_CONT("."); | ||
} | ||
|
||
// mul_mat_f16_f32_tiled | ||
{ | ||
#ifdef GGML_OPENCL_EMBED_KERNELS | ||
const std::string kernel_src { | ||
#include "mul_mat_f16_f32.cl.h" | ||
}; | ||
#else | ||
const std::string kernel_src = read_file("mul_mat_f16_f32.cl"); | ||
#endif | ||
backend_ctx->program_mul_mat_f16_f32_tiled = | ||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); | ||
|
||
CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, "mul_mat_f16_f32", &err), err)); | ||
GGML_LOG_CONT("."); | ||
} | ||
|
||
// mul | ||
{ | ||
#ifdef GGML_OPENCL_EMBED_KERNELS | ||
|
@@ -4742,6 +4760,58 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor | |
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); | ||
} | ||
|
||
// Dispatch the tiled f16 x f32 matrix multiplication kernel.
// src0 is the f16 weight matrix (K x M), src1 the f32 activation matrix (K x N),
// and dst receives the f32 result. Only 2D contiguous tensors reach this path
// (the caller checks ne[2] == ne[3] == 1 and contiguity).
static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) backend->context;

    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *) src0->extra;
    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *) src1->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *) dst->extra;

    // Byte offsets into the device buffers, including any view offset.
    cl_ulong off0 = extra0->offset + src0->view_offs;
    cl_ulong off1 = extra1->offset + src1->view_offs;
    cl_ulong offd = extrad->offset + dst->view_offs;

    // GEMM dimensions: dst is M x N, shared dimension is K.
    const int M = src0->ne[1];
    const int N = src1->ne[1];
    const int K = src0->ne[0];

    cl_kernel kernel = backend_ctx->kernel_mul_mat_f16_f32_tiled;

    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int),      &M));
    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int),      &N));
    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),      &K));
    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem),   &extra0->data_device));
    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &off0));
    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem),   &extra1->data_device));
    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &off1));
    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem),   &extrad->data_device));
    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offd));

    // Tiling parameters. They must stay in sync with the #defines in
    // mul_mat_f16_f32.cl (kept as compile-time constants there so the
    // compiler can fully unroll the inner loops and keep the accumulators
    // in registers).
    //
    //   OPWM / OPWN: output tile computed by one work-group (OPWM x OPWN).
    //   TPWM / TPWN: work-group size (threads per work-group in each dim).
    //   OPTM / OPTN: outputs computed per thread (OPTM x OPTN).
    //
    // Invariants: OPWM == TPWM * OPTM and OPWN == TPWN * OPTN.
    const int OPWM = 64;
    const int OPWN = 64;
    const int TPWM = 16;
    const int TPWN = 8;

    size_t local_work_size[2] = { TPWM, TPWN };

    // One work-group per output tile, rounding up on both axes so partial
    // tiles at the edges are still covered (the kernel bounds-checks).
    size_t global_work_size[2] = {
        (size_t) ((M + OPWM - 1) / OPWM) * TPWM,
        (size_t) ((N + OPWN - 1) / OPWN) * TPWN,
    };

    backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}
|
||
static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||
GGML_ASSERT(src0); | ||
GGML_ASSERT(src0->extra); | ||
|
@@ -4755,6 +4825,18 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co | |
|
||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; | ||
|
||
if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 && | ||
src0->ne[1] > 32 && // M > 32 | ||
src1->ne[1] > 32 && // N > 32 | ||
src0->ne[0] > 32 && // K > 32 | ||
src0->ne[2] == 1 && src0->ne[3] == 1 && | ||
src1->ne[2] == 1 && src1->ne[3] == 1 && | ||
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && | ||
backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) { | ||
ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst); | ||
return; | ||
} | ||
|
||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; | ||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; | ||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable | ||
|
||
#if defined(cl_qcom_reqd_sub_group_size) | ||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable | ||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) | ||
#else | ||
#define REQD_SUBGROUP_SIZE_128 | ||
#endif | ||
|
||
#define OPWM 64 | ||
#define OPWN 64 | ||
#define CPWK 8 | ||
#define OPTM 4 | ||
#define OPTN 8 | ||
|
||
#define WG_M (OPWM / OPTM) | ||
#define WG_N (OPWN / OPTN) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tested replacing the macros with get_local_size(), but it resulted in a significant performance regression (~17%). Using compile-time constants is critical here, as it allows the compiler to fully unroll the inner loops and pre-calculate memory address offsets, an optimization that is lost when WG_M becomes a runtime value |
||
#define VEC_K (CPWK / 4) | ||
|
||
REQD_SUBGROUP_SIZE_128 | ||
__kernel void mul_mat_f16_f32( | ||
const int M, const int N, const int K, | ||
__global const void* A_void, ulong A_offset, | ||
__global const void* B_void, ulong B_offset, | ||
__global void* C_void, ulong C_offset) { | ||
|
||
__global const half* A = (__global const half* )((__global const char*)A_void + A_offset); | ||
__global const float* B = (__global const float*)((__global const char*)B_void + B_offset); | ||
__global float* C = (__global float*)((__global char*)C_void + C_offset); | ||
|
||
const int lidm = get_local_id(0); | ||
const int lidn = get_local_id(1); | ||
const int lid = lidn * WG_M + lidm; | ||
|
||
const int offsetM = get_group_id(0) * OPWM; | ||
const int offsetN = get_group_id(1) * OPWN; | ||
|
||
__local half4 Alocal[OPWM][VEC_K]; | ||
__local float4 Blocal[OPWN][VEC_K]; | ||
|
||
float sum[OPTM][OPTN]; | ||
|
||
for (int wm = 0; wm < OPTM; wm++) { | ||
for (int wn = 0; wn < OPTN; wn++) { | ||
sum[wm][wn] = 0.0f; | ||
} | ||
} | ||
|
||
const int numTiles = (K + CPWK - 1) / CPWK; | ||
|
||
const int load_row_a = lid % OPWM; | ||
const int load_vec_k_a = lid / OPWM; | ||
const int global_row_a = offsetM + load_row_a; | ||
|
||
const int load_row_b = lid % OPWN; | ||
const int load_vec_k_b = lid / OPWN; | ||
const int global_row_b = offsetN + load_row_b; | ||
|
||
for (int t = 0; t < numTiles; t++) { | ||
const int k_start = t * CPWK; | ||
const int k_vec_start_a = k_start + load_vec_k_a * 4; | ||
const int k_vec_start_b = k_start + load_vec_k_b * 4; | ||
|
||
if (global_row_a < M && k_vec_start_a < K) { | ||
if (k_vec_start_a + 3 < K) { | ||
Alocal[load_row_a][load_vec_k_a] = vload4(0, A + global_row_a * K + k_vec_start_a); | ||
} else { | ||
half4 tempA = (half4)(0.0h); | ||
if (k_vec_start_a < K) tempA.s0 = A[global_row_a * K + k_vec_start_a]; | ||
if (k_vec_start_a + 1 < K) tempA.s1 = A[global_row_a * K + k_vec_start_a + 1]; | ||
if (k_vec_start_a + 2 < K) tempA.s2 = A[global_row_a * K + k_vec_start_a + 2]; | ||
Alocal[load_row_a][load_vec_k_a] = tempA; | ||
} | ||
} else { | ||
Alocal[load_row_a][load_vec_k_a] = (half4)(0.0h); | ||
} | ||
|
||
if (global_row_b < N && k_vec_start_b < K) { | ||
if (k_vec_start_b + 3 < K) { | ||
Blocal[load_row_b][load_vec_k_b] = vload4(0, B + global_row_b * K + k_vec_start_b); | ||
} else { | ||
float4 tempB = (float4)(0.0f); | ||
if (k_vec_start_b < K) tempB.s0 = B[global_row_b * K + k_vec_start_b]; | ||
if (k_vec_start_b + 1 < K) tempB.s1 = B[global_row_b * K + k_vec_start_b + 1]; | ||
if (k_vec_start_b + 2 < K) tempB.s2 = B[global_row_b * K + k_vec_start_b + 2]; | ||
Blocal[load_row_b][load_vec_k_b] = tempB; | ||
} | ||
} else { | ||
Blocal[load_row_b][load_vec_k_b] = (float4)(0.0f); | ||
} | ||
|
||
barrier(CLK_LOCAL_MEM_FENCE); | ||
|
||
#pragma unroll | ||
for (int k_vec = 0; k_vec < VEC_K; k_vec++) { | ||
float4 a_fvecs[OPTM]; | ||
int current_row_a = lidm; | ||
for (int wm = 0; wm < OPTM; wm++) { | ||
a_fvecs[wm] = convert_float4(Alocal[current_row_a][k_vec]); | ||
current_row_a += WG_M; | ||
} | ||
|
||
float4 b_fvecs[OPTN]; | ||
int current_row_b = lidn; | ||
for (int wn = 0; wn < OPTN; wn++) { | ||
b_fvecs[wn] = Blocal[current_row_b][k_vec]; | ||
current_row_b += WG_N; | ||
} | ||
|
||
for (int wm = 0; wm < OPTM; wm++) { | ||
for (int wn = 0; wn < OPTN; wn++) { | ||
sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]); | ||
} | ||
} | ||
} | ||
barrier(CLK_LOCAL_MEM_FENCE); | ||
} | ||
|
||
for (int wm = 0; wm < OPTM; wm++) { | ||
int globalRow = offsetM + lidm + wm * WG_M; | ||
if (globalRow < M) { | ||
for (int wn = 0; wn < OPTN; wn++) { | ||
int globalCol = offsetN + lidn + wn * WG_N; | ||
if (globalCol < N) { | ||
C[globalCol * M + globalRow] = sum[wm][wn]; | ||
} | ||
} | ||
} | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.