@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
333
333
334
334
cl_int alignment;
335
335
size_t max_alloc_size;
336
+ size_t max_workgroup_size;
336
337
bool fp16_support;
337
338
bool has_vector_subgroup_broadcast;
338
339
bool disable_fusion;
@@ -2218,6 +2219,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
2218
2219
clGetDeviceInfo (device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof (size_t ), &backend_ctx->max_alloc_size , NULL );
2219
2220
GGML_LOG_INFO (" ggml_opencl: max mem alloc size: %zu MB\n " , backend_ctx->max_alloc_size /1024 /1024 );
2220
2221
2222
+ clGetDeviceInfo (device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof (size_t ), &backend_ctx->max_workgroup_size , NULL );
2223
+ GGML_LOG_INFO (" ggml_opencl: device max workgroup size: %lu\n " , backend_ctx->max_workgroup_size );
2224
+
2221
2225
// Check SVM.
2222
2226
cl_device_svm_capabilities svm_caps;
2223
2227
CL_CHECK (clGetDeviceInfo (device, CL_DEVICE_SVM_CAPABILITIES, sizeof (cl_device_svm_capabilities), &svm_caps, 0 ));
@@ -2533,7 +2537,8 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
2533
2537
}
2534
2538
2535
2539
static bool ggml_opencl_supports_op (ggml_backend_dev_t dev, const struct ggml_tensor * op) {
2536
- GGML_UNUSED (dev);
2540
+ ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context ;
2541
+ ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx ;
2537
2542
2538
2543
switch (op->op ) {
2539
2544
case GGML_OP_NONE:
@@ -2708,8 +2713,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2708
2713
}
2709
2714
case GGML_OP_IM2COL:
2710
2715
return true ;
2711
- case GGML_OP_ARGSORT:
2712
- return op->src [0 ]->type == GGML_TYPE_F32;
2716
+ case GGML_OP_ARGSORT: {
2717
+ cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32 ;
2718
+ int max_workgroup_size = backend_ctx->get_kernel_workgroup_size (kernel);
2719
+
2720
+ int cols = 1 ;
2721
+ while (cols < op->ne [0 ]) {
2722
+ cols *= 2 ;
2723
+ }
2724
+
2725
+ return cols <= max_workgroup_size && op->src [0 ]->type == GGML_TYPE_F32;
2726
+ }
2713
2727
case GGML_OP_SUM_ROWS:
2714
2728
return op->src [0 ]->type == GGML_TYPE_F32 && ggml_is_contiguous (op->src [0 ]);
2715
2729
case GGML_OP_FLASH_ATTN_EXT:
0 commit comments