Skip to content

Commit c5023da

Browse files
authored
opencl: support imrope (ggml-org#16914)
* opencl: support imrope
* opencl: fix whitespace
1 parent e7da30b commit c5023da

File tree

2 files changed

+56
-24
lines changed

2 files changed

+56
-24
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8399,6 +8399,7 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
83998399
const bool is_neox = mode & 2;
84008400
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
84018401
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
8402+
const int is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
84028403

84038404
if (is_mrope) {
84048405
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
@@ -8489,9 +8490,14 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const
84898490
CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor));
84908491
CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast));
84918492
CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow));
8493+
// both mrope and vision kernels have sections
84928494
if (is_mrope || is_vision) {
84938495
CL_CHECK(clSetKernelArg(kernel, 33, sizeof(int32_t)*4, &sections));
84948496
}
8497+
// only the mrope (non-vision) kernels take the is_imrope argument
8498+
if (is_mrope && !is_vision) {
8499+
CL_CHECK(clSetKernelArg(kernel, 34, sizeof(int), &is_imrope));
8500+
}
84958501

84968502
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
84978503
size_t local_work_size[] = {(size_t)nth, 1, 1};

ggml/src/ggml-opencl/kernels/rope.cl

Lines changed: 50 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -392,7 +392,8 @@ kernel void kernel_rope_multi_f32(
392392
float attn_factor,
393393
float beta_fast,
394394
float beta_slow,
395-
int4 sections
395+
int4 sections,
396+
int is_imrope
396397
) {
397398
src0 = (global void*)((global char*)src0 + offset0);
398399
src1 = (global int*)((global char*)src1 + offset1);
@@ -419,17 +420,29 @@ kernel void kernel_rope_multi_f32(
419420
const int sector = (i0 / 2) % sect_dims;
420421
float theta_base = 0.0f;
421422

422-
if (sector < sections.s0) {
423-
theta_base = pos[i2];
424-
}
425-
else if (sector >= sections.s0 && sector < sec_w) {
426-
theta_base = pos[i2 + ne2 * 1];
427-
}
428-
else if (sector >= sec_w && sector < sec_w + sections.s2) {
429-
theta_base = pos[i2 + ne2 * 2];
430-
}
431-
else if (sector >= sec_w + sections.s2) {
432-
theta_base = pos[i2 + ne2 * 3];
423+
if (is_imrope) {
424+
if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
425+
theta_base = (float) pos[i2 + ne02 * 1];
426+
} else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
427+
theta_base = (float) pos[i2 + ne02 * 2];
428+
} else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
429+
theta_base = (float) pos[i2 + ne02 * 0];
430+
} else { // e
431+
theta_base = (float) pos[i2 + ne02 * 3];
432+
}
433+
} else {
434+
if (sector < sections.s0) {
435+
theta_base = pos[i2];
436+
}
437+
else if (sector >= sections.s0 && sector < sec_w) {
438+
theta_base = pos[i2 + ne2 * 1];
439+
}
440+
else if (sector >= sec_w && sector < sec_w + sections.s2) {
441+
theta_base = pos[i2 + ne2 * 2];
442+
}
443+
else if (sector >= sec_w + sections.s2) {
444+
theta_base = pos[i2 + ne2 * 3];
445+
}
433446
}
434447

435448
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
@@ -490,7 +503,8 @@ kernel void kernel_rope_multi_f16(
490503
float attn_factor,
491504
float beta_fast,
492505
float beta_slow,
493-
int4 sections
506+
int4 sections,
507+
int is_imrope
494508
) {
495509
src0 = (global void*)((global char*)src0 + offset0);
496510
src1 = (global int*)((global char*)src1 + offset1);
@@ -517,17 +531,29 @@ kernel void kernel_rope_multi_f16(
517531
const int sector = (i0 / 2) % sect_dims;
518532
float theta_base = 0.0f;
519533

520-
if (sector < sections.s0) {
521-
theta_base = pos[i2];
522-
}
523-
else if (sector >= sections.s0 && sector < sec_w) {
524-
theta_base = pos[i2 + ne2 * 1];
525-
}
526-
else if (sector >= sec_w && sector < sec_w + sections.s2) {
527-
theta_base = pos[i2 + ne2 * 2];
528-
}
529-
else if (sector >= sec_w + sections.s2) {
530-
theta_base = pos[i2 + ne2 * 3];
534+
if (is_imrope) {
535+
if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
536+
theta_base = (float) pos[i2 + ne02 * 1];
537+
} else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
538+
theta_base = (float) pos[i2 + ne02 * 2];
539+
} else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
540+
theta_base = (float) pos[i2 + ne02 * 0];
541+
} else { // e
542+
theta_base = (float) pos[i2 + ne02 * 3];
543+
}
544+
} else {
545+
if (sector < sections.s0) {
546+
theta_base = pos[i2];
547+
}
548+
else if (sector >= sections.s0 && sector < sec_w) {
549+
theta_base = pos[i2 + ne2 * 1];
550+
}
551+
else if (sector >= sec_w && sector < sec_w + sections.s2) {
552+
theta_base = pos[i2 + ne2 * 2];
553+
}
554+
else if (sector >= sec_w + sections.s2) {
555+
theta_base = pos[i2 + ne2 * 3];
556+
}
531557
}
532558

533559
const float theta = theta_base * pow(freq_base, inv_ndims*i0);

0 commit comments

Comments
 (0)