vulkan : support ggml_mean #15393
Conversation
@@ -11428,6 +11441,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
It's a pre-existing bug, but it looks like the sum/sum_rows shader assumes the source is contiguous. It would be nice to update the check here, or update the shader to handle it (which would be more involved).

I added support for views and non-contiguous sources. It does affect performance slightly for the test with a small workload. While testing this I also stumbled upon a bug (I think) where the sub-buffer size doesn't account for misaligned offsets. The buffer range passed to the shader ends up being too small and a few elements at the end are cut off. See the last commit for the fix. I'd also like to push a backend test that uses slice/permute, but at least the cuda and sycl backends (and maybe others) would fail it: they have asserts for a contiguous source. Updated numbers:
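The misalignment fix can be sketched in a few lines. This is an illustrative C++ model, not the PR's actual code (struct and function names here are hypothetical): Vulkan requires descriptor buffer offsets to be aligned, so the tensor's byte offset is rounded down and the remainder is handed to the shader as a "misalign" value; the bound range must then grow by that remainder, or the tail of the tensor falls outside the sub-buffer.

```cpp
#include <cstdint>
#include <cassert>

// Illustrative sub-buffer descriptor: byte offset and byte range to bind.
struct subbuffer {
    uint64_t offset;
    uint64_t range;
};

// Compute the descriptor binding for a tensor at tensor_offset with
// tensor_size bytes, given the device's storage-buffer offset alignment
// (assumed to be a power of two).
static subbuffer bind_range(uint64_t tensor_offset, uint64_t tensor_size, uint64_t align) {
    uint64_t aligned  = tensor_offset & ~(align - 1); // round down to alignment
    uint64_t misalign = tensor_offset - aligned;      // remainder passed to the shader
    // Buggy version: { aligned, tensor_size } -- the last `misalign` bytes of
    // the tensor end up outside the bound range and get cut off.
    return { aligned, tensor_size + misalign };       // fixed: cover the full tensor
}
```

With an alignment of 64, a tensor at offset 68 is bound at offset 64, and the range must be 4 bytes larger than the tensor to compensate.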
Thanks, this is a nice improvement. I think you're right about the misalignment bug. If you update the supports_op callback for other backends to check ggml_is_contiguous(src0), it will make them skip the new tests as unsupported. I think your updated shader still requires ggml_is_contiguous_rows(src0) in supports_op.
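The distinction between the two checks can be illustrated with a toy model. This is a hypothetical simplification of ggml's 4-dimensional tensor layout, not the real ggml_is_contiguous / ggml_is_contiguous_rows implementations: fully contiguous means packed strides in every dimension, while contiguous rows only constrains the innermost stride, so permuted views with packed rows still pass.

```cpp
#include <cstdint>
#include <cstddef>
#include <cassert>

// Minimal stand-in for a ggml tensor: ne = elements per dimension,
// nb = stride in bytes per dimension (simplified; no type or data pointer).
struct tensor {
    int64_t ne[4];
    size_t  nb[4];
    size_t  type_size;
};

// Fully contiguous: strides are exactly the packed row-major layout.
static bool is_contiguous(const tensor &t) {
    size_t expect = t.type_size;
    for (int i = 0; i < 4; ++i) {
        if (t.nb[i] != expect) return false;
        expect *= (size_t)t.ne[i];
    }
    return true;
}

// Contiguous rows: only the innermost dimension must be packed; outer
// strides are free, which is what the updated shader can still handle.
static bool is_contiguous_rows(const tensor &t) {
    return t.nb[0] == t.type_size;
}

// supports_op-style gate, as suggested in the review.
static bool supports_sum_rows(const tensor &src0) {
    return is_contiguous_rows(src0);
}
```

A view that permutes the outer dimensions fails the full contiguity check but keeps contiguous rows, so it stays supported under the weaker requirement.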
Hm, it does respect

I think you're right and I just misread the code.
* cuda : require contiguous src for SUM_ROWS, MEAN support
* sycl : require contiguous src for SUM, SUM_ROWS, ARGSORT support
@@ -4391,10 +4391,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
Was it intentional to include argsort? I haven't looked at the code.
It does GGML_ASSERT(ggml_is_contiguous(dst->src[0])) like the others, so I included it since it was in the same place.
ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -8540,11 +8589,20 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
 }

 static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
+    p.nb00 = 1; // treat src0 as flattened 1D tensor
Is this necessary? Wouldn't it already be 1 for contiguous rows?
I wrote it with the expectation of making it work with non-contiguous rows. But since I can't easily test that and don't have a use case for it either, I will just add a contiguous-rows requirement and remove p.nb00. Better than code that pretends to work without having been tested.
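The flattening trick under discussion can be modeled on the host. This is an illustrative C++ sketch with hypothetical names, not the shader itself: sum_rows reduces n_cols elements per row with an element stride nb00, and SUM reuses that path by describing the whole (contiguous) tensor as a single row with stride 1.

```cpp
#include <cstdint>
#include <cassert>

// Reduce one row of n_cols elements, reading with element stride nb00.
// nb00 == 1 corresponds to contiguous rows; SUM passes the total element
// count as n_cols to treat the whole tensor as one flattened row.
static float reduce_row(const float *row, uint32_t n_cols, uint32_t nb00) {
    float acc = 0.0f;
    for (uint32_t i = 0; i < n_cols; ++i) {
        acc += row[i * nb00];
    }
    return acc;
}
```

With a contiguous-rows requirement in supports_op, nb00 is always 1 and the extra stride parameter becomes unnecessary, which is why it was dropped.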
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }

// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
I'd like to unify the multiple copies of these functions, but I can do it in a later change.
Yes, it would be good to share this stuff... I wanted to improve it on the host side too (e.g. to make upscale fit better), but I think a separate PR is better at this point.
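For reference, the duplicated helper implements the classic magic-multiplier division trick: replace a division by a precomputed constant with one widening multiply, an add, and a shift. Below is a host-side C++ sketch of the technique, assuming it mirrors the init_fastdiv_values / fastdiv pair referenced in the shader (the 64-bit add here keeps the intermediate from overflowing; the shader-side version obtains the high word via umulExtended).

```cpp
#include <cstdint>
#include <cassert>

// Precompute a magic multiplier mp and shift L so that for all 32-bit n,
// fastdiv(n, mp, L) == n / d (assumption: d >= 1).
static void init_fastdiv(uint32_t d, uint32_t &mp, uint32_t &L) {
    // L = ceil(log2(d))
    L = 0;
    while (L < 32 && (uint64_t{1} << L) < d) {
        ++L;
    }
    mp = (uint32_t)(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
}

// Division without a hardware divide: high 32 bits of n * mp, plus n,
// shifted right by L. The add is done in 64 bits to avoid overflow.
static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
    uint64_t hi = ((uint64_t)n * mp) >> 32;
    return (uint32_t)((hi + n) >> L);
}
```

Shaders hit this path constantly when unraveling flat thread indices into tensor coordinates, where the divisors (tensor extents) are known on the host and the precomputed (mp, L) pair is passed via push constants.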
Adds support for GGML_OP_MEAN in the Vulkan backend. It reuses the sum_rows kernel, which also affects sum: there's now an additional multiply with a push constant after the reduction. From what I can see it doesn't noticeably affect the performance of those operations; let me know if there's something else I should check.
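The kernel-reuse idea can be sketched in plain C++. This is an illustrative model, not the GLSL shader (names are hypothetical): both operations run the same row reduction, and the only difference is a scale applied after the reduction, passed as a push constant (1.0f for SUM_ROWS, 1.0f / n_cols for MEAN).

```cpp
#include <vector>
#include <cstddef>
#include <cassert>

// Reduce each row of n_cols elements, then apply one post-reduction multiply.
// scale = 1.0f reproduces SUM_ROWS; scale = 1.0f / n_cols yields MEAN.
static std::vector<float> sum_rows_scaled(const std::vector<float> &src,
                                          size_t n_cols, float scale) {
    std::vector<float> dst(src.size() / n_cols, 0.0f);
    for (size_t r = 0; r < dst.size(); ++r) {
        float acc = 0.0f;
        for (size_t c = 0; c < n_cols; ++c) {
            acc += src[r * n_cols + c];
        }
        dst[r] = acc * scale; // the extra multiply this PR adds
    }
    return dst;
}
```

Since the multiply happens once per row after the reduction loop, it adds a single flop per output element, which is consistent with the reported lack of a measurable performance impact.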