OpenCL: add tiled mul_mat_f16_f32 #14535

Open
wants to merge 3 commits into master
1 change: 1 addition & 0 deletions ggml/src/ggml-opencl/CMakeLists.txt
@@ -103,6 +103,7 @@ set(GGML_OPENCL_KERNELS
tanh
pad
repeat
mul_mat_f16_f32
)

foreach (K ${GGML_OPENCL_KERNELS})
82 changes: 82 additions & 0 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -367,6 +367,7 @@ struct ggml_backend_opencl_context {
cl_program program_mul_mv_f16_f32;
cl_program program_mul_mv_f32_f32;
cl_program program_mul;
cl_program program_mul_mat_f16_f32_tiled;
cl_program program_div;
cl_program program_sub;
cl_program program_norm;
@@ -419,6 +420,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mat_f16_f32_1row;
cl_kernel kernel_mul_mat_f16_f32;
cl_kernel kernel_mul_mat_f16_f32_l4;
cl_kernel kernel_mul_mat_f16_f32_tiled;
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
@@ -1000,6 +1002,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}

// mul_mat_f16_f32_tiled
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mat_f16_f32.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mat_f16_f32.cl");
#endif
backend_ctx->program_mul_mat_f16_f32_tiled =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, "mul_mat_f16_f32", &err), err));
GGML_LOG_CONT(".");
}

// mul
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -4742,6 +4760,58 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
}

static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;

const int M = src0->ne[1];
const int N = src1->ne[1];
const int K = src0->ne[0];

cl_kernel kernel = backend_ctx->kernel_mul_mat_f16_f32_tiled;

CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int), &M));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &N));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &K));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd));

// Tiling parameters. These need to be tuned for optimal performance.
// They must match the #defines in the kernel mul_mat_f16_f32.cl.
//
// OPWM / OPWN: Output tile size per Work-Group. A work-group computes a tile of size OPWM x OPWN.
// TPWM / TPWN: Threads per Work-group. This is the work-group size.
// OPTM / OPTN: Output elements per Thread. Each thread computes OPTM x OPTN elements.
//
// The following relationships must hold:
// OPWM = TPWM * OPTM
// OPWN = TPWN * OPTN
//
const int OPWM = 64;
const int OPWN = 64;
const int TPWM = 16;
const int TPWN = 8;

Collaborator:

I think OPWM, OPWN, TPWM, TPWN, OPTM, and OPTN are related; e.g., TPWM can be calculated from OPWM and OPTM. I wonder if it is possible to do that calculation, or maybe add a comment about how they are related.

Collaborator Author:

They are all mathematically related, but keeping them as explicit compile-time constants is generally better because it allows the compiler to perform aggressive, hardware-specific optimizations: full loop unrolling that eliminates expensive branching, and placing the accumulator arrays directly in registers. These optimizations are impossible if the values are calculated as runtime variables. I will add comments about how they are related.
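
For reference, a minimal host-side sketch (not part of this PR) of how these constants could be derived from one another, using the values in this diff and mirroring the kernel's OPTM/OPTN defines:

// Illustrative only: derive the work-group dimensions from the tile size and
// the per-thread output size instead of hard-coding all four values.
const int OPTM = 4;            // output rows per thread (matches the kernel #define)
const int OPTN = 8;            // output columns per thread (matches the kernel #define)
const int OPWM = 64;           // output tile rows per work-group
const int OPWN = 64;           // output tile columns per work-group
const int TPWM = OPWM / OPTM;  // threads per work-group in M -> 16
const int TPWN = OPWN / OPTN;  // threads per work-group in N -> 8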

Would constexpr (as opposed to const) work for this?

Collaborator Author:

constexpr is a C++ feature and wouldn't apply here, as the OpenCL kernel is compiled separately at runtime using the C-like OpenCL language.


size_t local_work_size[2] = { TPWM, TPWN };
size_t global_work_size[2] = {
(size_t) ((M + OPWM - 1) / OPWM) * TPWM,
(size_t) ((N + OPWN - 1) / OPWN) * TPWN,
};

backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
}

static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@@ -4755,6 +4825,18 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co

ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

if (src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32 &&
src0->ne[1] > 32 && // M > 32
src1->ne[1] > 32 && // N > 32
src0->ne[0] > 32 && // K > 32
src0->ne[2] == 1 && src0->ne[3] == 1 &&
src1->ne[2] == 1 && src1->ne[3] == 1 &&
ggml_is_contiguous(src0) && ggml_is_contiguous(src1) &&
backend_ctx->kernel_mul_mat_f16_f32_tiled != NULL) {
ggml_cl_mul_mat_f16_f32_tiled(backend, src0, src1, dst);
return;
}

ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
130 changes: 130 additions & 0 deletions ggml/src/ggml-opencl/kernels/mul_mat_f16_f32.cl
@@ -0,0 +1,130 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

#define OPWM 64
#define OPWN 64
#define CPWK 8
#define OPTM 4
#define OPTN 8

#define WG_M (OPWM / OPTM)
#define WG_N (OPWN / OPTN)

Collaborator:

WG_M and WG_N seem to be the work-group size - can they be replaced with get_local_size()?

Collaborator Author:

I tested replacing the macros with get_local_size(), but it resulted in a significant performance regression (~17%). Using compile-time constants is critical here, as it allows the compiler to fully unroll the inner loops and pre-calculate memory address offsets, an optimization that is lost when WG_M becomes a runtime value.
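
For illustration, a small standalone OpenCL sketch (not from this PR) contrasting the two forms: with a compile-time stride the compiler can fully unroll the loop and fold the address arithmetic, while the get_local_size() variant leaves the stride as a runtime value.

// Hypothetical demo kernel, not the PR's kernel: the same accumulation loop
// written with a compile-time stride and with a get_local_size() stride.
#define OPTM 4
#define WG_M 16   // compile-time constant: trip count and offsets are known when the kernel is built

__kernel void stride_demo(__global float * out) {
    const int lidm = get_local_id(0);
    float acc = 0.0f;

    // Compile-time stride: fully unrollable, offsets fold into constants.
    int row = lidm;
    for (int wm = 0; wm < OPTM; wm++) {
        acc += (float) row;
        row += WG_M;
    }

    // Runtime stride: same arithmetic, but the stride is only known at launch time.
    const int wg_m = (int) get_local_size(0);
    int row_rt = lidm;
    for (int wm = 0; wm < OPTM; wm++) {
        acc += (float) row_rt;
        row_rt += wg_m;
    }

    out[get_global_id(0)] = acc;
}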

#define VEC_K (CPWK / 4)

REQD_SUBGROUP_SIZE_128
__kernel void mul_mat_f16_f32(
const int M, const int N, const int K,
__global const void* A_void, ulong A_offset,
__global const void* B_void, ulong B_offset,
__global void* C_void, ulong C_offset) {

__global const half* A = (__global const half* )((__global const char*)A_void + A_offset);
__global const float* B = (__global const float*)((__global const char*)B_void + B_offset);
__global float* C = (__global float*)((__global char*)C_void + C_offset);

const int lidm = get_local_id(0);
const int lidn = get_local_id(1);
const int lid = lidn * WG_M + lidm;

const int offsetM = get_group_id(0) * OPWM;
const int offsetN = get_group_id(1) * OPWN;

__local half4 Alocal[OPWM][VEC_K];
__local float4 Blocal[OPWN][VEC_K];

float sum[OPTM][OPTN];

for (int wm = 0; wm < OPTM; wm++) {
for (int wn = 0; wn < OPTN; wn++) {
sum[wm][wn] = 0.0f;
}
}

const int numTiles = (K + CPWK - 1) / CPWK;

const int load_row_a = lid % OPWM;
const int load_vec_k_a = lid / OPWM;
const int global_row_a = offsetM + load_row_a;

const int load_row_b = lid % OPWN;
const int load_vec_k_b = lid / OPWN;
const int global_row_b = offsetN + load_row_b;

for (int t = 0; t < numTiles; t++) {
const int k_start = t * CPWK;
const int k_vec_start_a = k_start + load_vec_k_a * 4;
const int k_vec_start_b = k_start + load_vec_k_b * 4;

if (global_row_a < M && k_vec_start_a < K) {
if (k_vec_start_a + 3 < K) {
Alocal[load_row_a][load_vec_k_a] = vload4(0, A + global_row_a * K + k_vec_start_a);
} else {
half4 tempA = (half4)(0.0h);
if (k_vec_start_a < K) tempA.s0 = A[global_row_a * K + k_vec_start_a];
if (k_vec_start_a + 1 < K) tempA.s1 = A[global_row_a * K + k_vec_start_a + 1];
if (k_vec_start_a + 2 < K) tempA.s2 = A[global_row_a * K + k_vec_start_a + 2];
Alocal[load_row_a][load_vec_k_a] = tempA;
}
} else {
Alocal[load_row_a][load_vec_k_a] = (half4)(0.0h);
}

if (global_row_b < N && k_vec_start_b < K) {
if (k_vec_start_b + 3 < K) {
Blocal[load_row_b][load_vec_k_b] = vload4(0, B + global_row_b * K + k_vec_start_b);
} else {
float4 tempB = (float4)(0.0f);
if (k_vec_start_b < K) tempB.s0 = B[global_row_b * K + k_vec_start_b];
if (k_vec_start_b + 1 < K) tempB.s1 = B[global_row_b * K + k_vec_start_b + 1];
if (k_vec_start_b + 2 < K) tempB.s2 = B[global_row_b * K + k_vec_start_b + 2];
Blocal[load_row_b][load_vec_k_b] = tempB;
}
} else {
Blocal[load_row_b][load_vec_k_b] = (float4)(0.0f);
}

barrier(CLK_LOCAL_MEM_FENCE);

#pragma unroll
for (int k_vec = 0; k_vec < VEC_K; k_vec++) {
float4 a_fvecs[OPTM];
int current_row_a = lidm;
for (int wm = 0; wm < OPTM; wm++) {
a_fvecs[wm] = convert_float4(Alocal[current_row_a][k_vec]);
current_row_a += WG_M;
}

float4 b_fvecs[OPTN];
int current_row_b = lidn;
for (int wn = 0; wn < OPTN; wn++) {
b_fvecs[wn] = Blocal[current_row_b][k_vec];
current_row_b += WG_N;
}

for (int wm = 0; wm < OPTM; wm++) {
for (int wn = 0; wn < OPTN; wn++) {
sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}

for (int wm = 0; wm < OPTM; wm++) {
int globalRow = offsetM + lidm + wm * WG_M;
if (globalRow < M) {
for (int wn = 0; wn < OPTN; wn++) {
int globalCol = offsetN + lidn + wn * WG_N;
if (globalCol < N) {
C[globalCol * M + globalRow] = sum[wm][wn];
}
}
}
}
}