diff --git a/src/opencl.c b/src/opencl.c index 7dc1857..f2adebf 100644 --- a/src/opencl.c +++ b/src/opencl.c @@ -239,6 +239,46 @@ hs_opencl_run( exit(1); }; + /* Query the max work group size. */ + size_t max_wg_size = 0; + err = clGetDeviceInfo(dids[options->device], CL_DEVICE_MAX_WORK_GROUP_SIZE, + sizeof(size_t), &max_wg_size, NULL); + + if (err != CL_SUCCESS) { + printf("failed to get CL_DEVICE_MAX_WORK_GROUP_SIZE: %d\n", err); + free(dids); + exit(1); + } + + /* Query the max work item dimensions. */ + size_t max_wi_dim; /* NOTE(review): not initialized before the query; the OpenCL reference gives cl_uint as the result type of CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS, so any bytes of this size_t that the call does not write would stay indeterminate - confirm. */ + err = clGetDeviceInfo(dids[options->device], + CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS, sizeof(max_wi_dim), &max_wi_dim, NULL); + + if (err != CL_SUCCESS) { + printf("failed to get CL_DEVICE_MAX_WORK_ITEM_DIMENSIONS: %d\n", err); + free(dids); + exit(1); + } + + /* Query the max work item size for each dimension. */ + size_t *max_wi_size = (size_t *)malloc(sizeof(size_t) * max_wi_dim); /* NOTE(review): the malloc result is passed to clGetDeviceInfo without a NULL check. */ + err = clGetDeviceInfo(dids[options->device], CL_DEVICE_MAX_WORK_ITEM_SIZES, + sizeof(size_t) * max_wi_dim, max_wi_size, NULL); + + if (err != CL_SUCCESS) { + printf("failed to get CL_DEVICE_MAX_WORK_ITEM_SIZES: %d\n", err); + free(max_wi_size); + free(dids); + exit(1); + } + + /* Upper bound used for total_work_items: the product of the per-dimension limits in max_wi_size. NOTE(review): the kernel is enqueued with a single dimension - confirm this product is the intended cap. */ + size_t max_work_items = 1; + for (size_t i = 0; i < max_wi_dim; i++) + max_work_items *= max_wi_size[i]; + + free(max_wi_size); free(dids); /* Create a kernel. */ @@ -259,12 +299,29 @@ hs_opencl_run( exit(1); } - size_t global_size = options->threads; - size_t local_size = options->blocks; + size_t total_work_items = options->threads; + size_t work_group_size = options->blocks; + + if (total_work_items > max_work_items || total_work_items < 1) + total_work_items = max_work_items; + + if (work_group_size > max_wg_size || work_group_size < 1) + work_group_size = max_wg_size; + + /** + * If total_work_items is not divisible by work_group_size, + * round total_work_items down to the largest multiple + * of work_group_size that does not exceed it. 
+ */ + if (total_work_items % work_group_size != 0) + total_work_items = total_work_items / work_group_size * work_group_size; /* NOTE(review): when total_work_items is smaller than work_group_size this integer division leaves total_work_items at 0. */ + + printf("Total work items: %lu\n", total_work_items); + printf("Work group size: %lu\n", work_group_size); /* NOTE(review): both printed values are size_t, for which %zu is the matching conversion rather than %lu. */ /* Enqueue kernel. */ err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, - &global_size, &local_size, 0, NULL, NULL); + &total_work_items, &work_group_size, 0, NULL, NULL); if (err != CL_SUCCESS) { printf("failed to enqueue the kernel: %d\n", err);