vulkan : support ggml_mean #15393
First file in the diff: ggml-vulkan.cpp.

@@ -1014,6 +1014,39 @@ struct vk_op_upscale_push_constants {
    float sf0; float sf1; float sf2; float sf3;
};

struct vk_op_sum_rows_push_constants
{
    uint32_t n_cols;
    uint32_t ne01, ne02;
    uint32_t nb01, nb02, nb03;
    uint32_t nb11, nb12, nb13;
    float weight;
    uint32_t misalign_offsets;
    uint32_t ne0_12mp, ne0_12L;
    uint32_t ne0_1mp, ne0_1L;
};

vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
    uint32_t type_size = (uint32_t)ggml_type_size(src->type);
    vk_op_sum_rows_push_constants p = {};
    p.n_cols = (uint32_t)n_cols;
    p.ne01 = (uint32_t)src->ne[1];
    p.ne02 = (uint32_t)src->ne[2];
    p.nb01 = (uint32_t)src->nb[1] / type_size;
    p.nb02 = (uint32_t)src->nb[2] / type_size;
    p.nb03 = (uint32_t)src->nb[3] / type_size;
    p.nb11 = (uint32_t)dst->nb[1] / type_size;
    p.nb12 = (uint32_t)dst->nb[2] / type_size;
    p.nb13 = (uint32_t)dst->nb[3] / type_size;
    p.weight = 1.0f;
    return p;
}

template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
    init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);
    init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L);
}

// Allow pre-recording command buffers
struct vk_staging_memcpy {
    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}

@@ -3122,7 +3155,7 @@ static void ggml_vk_load_shaders(vk_device& device) {

    ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

-   ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+   ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

    ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);

@@ -7207,6 +7240,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
        return nullptr;
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
    case GGML_OP_MEAN:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_sum_rows_f32;
        }

@@ -7339,6 +7373,9 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
    case GGML_OP_CONV_2D_DW:
    case GGML_OP_IM2COL:
    case GGML_OP_SET_ROWS:
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
    case GGML_OP_MEAN:
        return true;
    default:
        return false;

@@ -7373,6 +7410,16 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
    GGML_UNUSED(src2);
}

template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);

    p.misalign_offsets = (a_offset << 16) | d_offset;

    GGML_UNUSED(src1);
    GGML_UNUSED(src2);
}

template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);

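Not part of the diff: a brief note on the misalign_offsets packing above. The misalignment returned by get_misalign_bytes is bounded by minStorageBufferOffsetAlignment (at most 256 bytes per the Vulkan limits), so after dividing by the element size both offsets comfortably fit in 16 bits and can share a single 32-bit push constant. A minimal stand-alone C++ sketch of the pack/unpack round trip, mirroring get_aoffset()/get_doffset() in the shader further down (pack_offsets is a hypothetical helper, not from the PR):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Pack two small element offsets into one 32-bit value, as
// init_pushconst_tensor_offsets does above. Assumes both fit in 16 bits,
// which holds because the misalignment is bounded by the storage buffer
// offset alignment.
static uint32_t pack_offsets(uint32_t a_offset, uint32_t d_offset) {
    assert(a_offset < (1u << 16) && d_offset < (1u << 16));
    return (a_offset << 16) | d_offset;
}

int main() {
    const uint32_t packed = pack_offsets(12, 3);
    // Unpacking matches get_aoffset()/get_doffset() in the shader.
    printf("a_offset=%u d_offset=%u\n", packed >> 16, packed & 0xFFFF); // a_offset=12 d_offset=3
    return 0;
}
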
@@ -7523,10 +7570,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
    d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);

    if (op_supports_incontiguous) {
-       x_sz = ggml_nbytes(src0);
-       y_sz = use_src1 ? ggml_nbytes(src1) : 0;
-       z_sz = use_src2 ? ggml_nbytes(src2) : 0;
-       d_sz = ggml_nbytes(dst);
+       x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
+       y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
+       z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
+       d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);

        if (x_buf_offset + x_sz >= d_X->size) {
            x_sz = VK_WHOLE_SIZE;

@@ -7554,6 +7601,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
    case GGML_OP_SOFT_MAX:
    case GGML_OP_SOFT_MAX_BACK:
    case GGML_OP_SUM_ROWS:
    case GGML_OP_MEAN:
    case GGML_OP_ARGMAX:
        {
            const uint32_t nr = ggml_nrows(src0);

@@ -8540,11 +8588,19 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
}

static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-   ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+   vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
+   ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun);
}

static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-   ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
+   vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+   ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun);
}

static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
    p.weight = 1.0f / (float)src0->ne[0];
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_MEAN, p, dryrun);
}

static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {

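Not part of the diff: with the functions above, all three ops dispatch the same sum_rows pipeline and differ only in the push constants. GGML_OP_SUM treats the whole tensor as one row of ggml_nelements(src0) columns, GGML_OP_SUM_ROWS reduces each row of ne[0] columns with weight 1.0, and GGML_OP_MEAN uses the same reduction with weight 1/ne[0]. A small stand-alone C++ reference for the per-row result the shader produces (illustrative only, not taken from the PR):

#include <cstdio>
#include <vector>

// dst[row] = weight * sum(src[row][0..n_cols)):
//   weight == 1.0f          -> GGML_OP_SUM / GGML_OP_SUM_ROWS
//   weight == 1.0f / n_cols -> GGML_OP_MEAN
static float reduce_row(const float * row, int n_cols, float weight) {
    float acc = 0.0f;
    for (int i = 0; i < n_cols; ++i) {
        acc += row[i];
    }
    return acc * weight;
}

int main() {
    std::vector<float> row = {1.0f, 2.0f, 3.0f, 4.0f};
    printf("sum_rows = %.2f\n", reduce_row(row.data(), 4, 1.0f));        // 10.00
    printf("mean     = %.2f\n", reduce_row(row.data(), 4, 1.0f / 4.0f)); // 2.50
    return 0;
}
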
@@ -9766,6 +9822,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_OP_ARGSORT:
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
    case GGML_OP_MEAN:
    case GGML_OP_ARGMAX:
    case GGML_OP_COUNT_EQUAL:
    case GGML_OP_IM2COL:

@@ -9835,6 +9892,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_OP_ARGSORT:
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
    case GGML_OP_MEAN:
    case GGML_OP_ARGMAX:
    case GGML_OP_COUNT_EQUAL:
    case GGML_OP_IM2COL:

@@ -10037,6 +10095,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
    case GGML_OP_SUM_ROWS:
        ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);

        break;
    case GGML_OP_MEAN:
        ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun);

        break;
    case GGML_OP_ARGMAX:
        ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);

@@ -10196,6 +10258,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
    case GGML_OP_ARGSORT:
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
    case GGML_OP_MEAN:
    case GGML_OP_ARGMAX:
    case GGML_OP_COUNT_EQUAL:
    case GGML_OP_IM2COL:

@@ -11426,8 +11489,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
    case GGML_OP_SOFT_MAX_BACK:
        return true;
    case GGML_OP_SUM:
    case GGML_OP_SUM_ROWS:
    case GGML_OP_MEAN:
        return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
    case GGML_OP_ARGMAX:
    case GGML_OP_COUNT_EQUAL:
    case GGML_OP_IM2COL:

Review comment (on the GGML_OP_SUM/SUM_ROWS/MEAN check above): It's a pre-existing bug, but it looks like the sum/sum_rows shader assumes the source is contiguous. Would be nice to update the check here, or update the shader to handle it (which would be more involved).

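Not part of the diff, and the exact semantics of ggml_is_contiguous_rows are assumed here rather than quoted from ggml: the updated shader strides between rows through nb01/nb02/nb03 but reads elements within a row as data_a[src_idx + i], so it only needs the elements inside each row to be packed, while the rows themselves may be strided. A hypothetical stand-alone predicate expressing that assumption (toy_tensor and rows_are_packed are illustrative names, not ggml API):

#include <cstdint>
#include <cstdio>

// Illustration only: "contiguous rows" taken to mean that the innermost
// stride equals the element size, i.e. elements within a row are packed.
// Strides between rows (nb[1..3]) may be arbitrary; the shader handles
// those via nb01/nb02/nb03.
struct toy_tensor {
    int64_t ne[4];      // elements per dimension
    size_t  nb[4];      // strides in bytes
    size_t  type_size;  // bytes per element
};

static bool rows_are_packed(const toy_tensor & t) {
    return t.ne[0] == 1 || t.nb[0] == t.type_size;
}

int main() {
    // An 8x4 f32 view whose rows are packed but padded to 64 bytes apart.
    toy_tensor t = { {8, 4, 1, 1}, {4, 64, 256, 256}, 4 };
    printf("contiguous rows: %d\n", rows_are_packed(t) ? 1 : 0); // 1
    return 0;
}
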
@@ -11983,6 +12049,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
        tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
    } else if (tensor->op == GGML_OP_SUM_ROWS) {
        tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
    } else if (tensor->op == GGML_OP_MEAN) {
        tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
    } else if (tensor->op == GGML_OP_ARGMAX) {
        tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
    } else if (tensor->op == GGML_OP_COUNT_EQUAL) {

Second file in the diff: the sum_rows compute shader.

@@ -1,26 +1,59 @@
#version 450

-#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

layout (constant_id = 0) const uint BLOCK_SIZE = 32;

layout (push_constant) uniform parameter
{
    uint n_cols;
    uint ne01, ne02;
    uint nb01, nb02, nb03;
    uint nb11, nb12, nb13;
    float weight;
    uint misalign_offsets;
    uint ne0_12mp, ne0_12L;
    uint ne0_1mp, ne0_1L;
} p;

uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }

// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
    uint msbs, lsbs;
    // msbs = mulhi(n, mp)
    umulExtended(n, mp, msbs, lsbs);
    return (msbs + n) >> L;
}

Review comment (on fastdiv above): I'd like to unify the multiple copies of these functions, but I can do it in a later change.
Reply: Yes, it would be good to share this stuff... I wanted to improve it on the host side too (e.g. to make upscale fit better), but I think a separate PR is better at this point.

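For context, and as an assumption since the host-side helper is not part of this diff: init_fastdiv_values in ggml-vulkan.cpp presumably precomputes a magic multiplier mp and shift L for an unsigned divisor d in the usual Granlund-Montgomery style, so that the shader's (mulhi(n, mp) + n) >> L reproduces n / d without an integer divide. A stand-alone C++ sketch of that precomputation, checked against ordinary division (fastdiv_values and fastdiv here are illustrative re-implementations, not the actual ggml code):

#include <cassert>
#include <cstdint>
#include <initializer_list>

// Choose L = ceil(log2(d)) and mp = floor(2^32 * (2^L - d) / d) + 1, so that
// n / d == (mulhi(n, mp) + n) >> L as long as (mulhi(n, mp) + n) does not
// overflow 32 bits -- which holds for the modest row counts used here.
static void fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) {
    L = 0;
    while (L < 32 && (uint32_t(1) << L) < d) {
        L++;
    }
    mp = (uint32_t)(((uint64_t(1) << 32) * ((uint64_t(1) << L) - d)) / d + 1);
}

static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
    const uint32_t msbs = (uint32_t)(((uint64_t)n * mp) >> 32); // mulhi(n, mp)
    return (msbs + n) >> L;
}

int main() {
    for (uint32_t d : {1u, 3u, 7u, 192u, 4096u}) {
        uint32_t mp, L;
        fastdiv_values(d, mp, L);
        for (uint32_t n : {0u, 1u, 191u, 12345u, 1u << 20}) {
            assert(fastdiv(n, mp, L) == n / d);
        }
    }
    return 0;
}
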
shared FLOAT_TYPE tmp[BLOCK_SIZE];

void main() {
    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint col = gl_LocalInvocationID.x;
    const float weight = p.weight;

    const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
    const uint i03_offset = i03 * p.ne01*p.ne02;
    const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
    const uint i01 = row - i03_offset - i02*p.ne01;

    const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
    const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;

-   tmp[col] = FLOAT_TYPE(0.0f);
+   tmp[col] = FLOAT_TYPE(0.0);

-   for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
-       tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
+   for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {
+       tmp[col] += FLOAT_TYPE(data_a[src_idx + i]);
    }

    barrier();

@@ -32,6 +65,6 @@ void main() {
    }

    if (col == 0) {
-       data_d[row] = D_TYPE(tmp[0]);
+       data_d[dst_idx] = D_TYPE(tmp[0] * weight);
    }
}

Review comment: Was it intentional to include argsort? I haven't looked at the code.
Reply: It does GGML_ASSERT(ggml_is_contiguous(dst->src[0])) like the others, so I included it since it was in the same place.