Skip to content

opencl: mark argsort unsupported if cols exceed workgroup limit #15375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions ggml/src/ggml-opencl/ggml-opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {

cl_int alignment;
size_t max_alloc_size;
size_t max_workgroup_size;
bool fp16_support;
bool has_vector_subgroup_broadcast;
bool disable_fusion;
Expand Down Expand Up @@ -2218,6 +2219,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);

clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);
GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size);

// Check SVM.
cl_device_svm_capabilities svm_caps;
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0));
Expand Down Expand Up @@ -2533,7 +2537,8 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
}

static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
GGML_UNUSED(dev);
ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context;
ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx;

switch (op->op) {
case GGML_OP_NONE:
Expand Down Expand Up @@ -2708,8 +2713,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
}
case GGML_OP_IM2COL:
return true;
case GGML_OP_ARGSORT:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ARGSORT: {
cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32;
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);

int cols = 1;
while (cols < op->ne[0]) {
cols *= 2;
}

return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
}
case GGML_OP_SUM_ROWS:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
Expand Down
Loading