Skip to content

Commit 1a6e78f

Browse files
committed
opencl: mark argsort unsupported if cols exceed workgroup limit
1 parent 4d19698 commit 1a6e78f

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
333333

334334
cl_int alignment;
335335
size_t max_alloc_size;
336+
size_t max_workgroup_size;
336337
bool fp16_support;
337338
bool has_vector_subgroup_broadcast;
338339
bool disable_fusion;
@@ -2218,6 +2219,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
22182219
clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
22192220
GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
22202221

2222+
clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL);
2223+
GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size);
2224+
22212225
// Check SVM.
22222226
cl_device_svm_capabilities svm_caps;
22232227
CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0));
@@ -2533,7 +2537,8 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
25332537
}
25342538

25352539
static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
2536-
GGML_UNUSED(dev);
2540+
ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context;
2541+
ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx;
25372542

25382543
switch (op->op) {
25392544
case GGML_OP_NONE:
@@ -2708,8 +2713,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
27082713
}
27092714
case GGML_OP_IM2COL:
27102715
return true;
2711-
case GGML_OP_ARGSORT:
2712-
return op->src[0]->type == GGML_TYPE_F32;
2716+
case GGML_OP_ARGSORT: {
2717+
cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32;
2718+
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
2719+
2720+
int cols = 1;
2721+
while (cols < op->ne[0]) {
2722+
cols *= 2;
2723+
}
2724+
2725+
return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
2726+
}
27132727
case GGML_OP_SUM_ROWS:
27142728
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
27152729
case GGML_OP_FLASH_ATTN_EXT:

0 commit comments

Comments
 (0)