@@ -392,7 +392,8 @@ kernel void kernel_rope_multi_f32(
392392 float attn_factor ,
393393 float beta_fast ,
394394 float beta_slow ,
395- int4 sections
395+ int4 sections ,
396+ int is_imrope
396397) {
397398 src0 = (global void * )((global char * )src0 + offset0 );
398399 src1 = (global int * )((global char * )src1 + offset1 );
@@ -419,17 +420,29 @@ kernel void kernel_rope_multi_f32(
419420 const int sector = (i0 / 2 ) % sect_dims ;
420421 float theta_base = 0.0f ;
421422
422- if (sector < sections .s0 ) {
423- theta_base = pos [i2 ];
424- }
425- else if (sector >= sections .s0 && sector < sec_w ) {
426- theta_base = pos [i2 + ne2 * 1 ];
427- }
428- else if (sector >= sec_w && sector < sec_w + sections .s2 ) {
429- theta_base = pos [i2 + ne2 * 2 ];
430- }
431- else if (sector >= sec_w + sections .s2 ) {
432- theta_base = pos [i2 + ne2 * 3 ];
423+ if (is_imrope ) {
424+ if (sector % 3 == 1 && sector < 3 * sections .s1 ) { // h
425+ theta_base = (float ) pos [i2 + ne02 * 1 ];
426+ } else if (sector % 3 == 2 && sector < 3 * sections .s2 ) { // w
427+ theta_base = (float ) pos [i2 + ne02 * 2 ];
428+ } else if (sector % 3 == 0 && sector < 3 * sections .s0 ) { // t
429+ theta_base = (float ) pos [i2 + ne02 * 0 ];
430+ } else { // e
431+ theta_base = (float ) pos [i2 + ne02 * 3 ];
432+ }
433+ } else {
434+ if (sector < sections .s0 ) {
435+ theta_base = pos [i2 ];
436+ }
437+ else if (sector >= sections .s0 && sector < sec_w ) {
438+ theta_base = pos [i2 + ne2 * 1 ];
439+ }
440+ else if (sector >= sec_w && sector < sec_w + sections .s2 ) {
441+ theta_base = pos [i2 + ne2 * 2 ];
442+ }
443+ else if (sector >= sec_w + sections .s2 ) {
444+ theta_base = pos [i2 + ne2 * 3 ];
445+ }
433446 }
434447
435448 const float theta = theta_base * pow (freq_base , inv_ndims * i0 );
@@ -490,7 +503,8 @@ kernel void kernel_rope_multi_f16(
490503 float attn_factor ,
491504 float beta_fast ,
492505 float beta_slow ,
493- int4 sections
506+ int4 sections ,
507+ int is_imrope
494508) {
495509 src0 = (global void * )((global char * )src0 + offset0 );
496510 src1 = (global int * )((global char * )src1 + offset1 );
@@ -517,17 +531,29 @@ kernel void kernel_rope_multi_f16(
517531 const int sector = (i0 / 2 ) % sect_dims ;
518532 float theta_base = 0.0f ;
519533
520- if (sector < sections .s0 ) {
521- theta_base = pos [i2 ];
522- }
523- else if (sector >= sections .s0 && sector < sec_w ) {
524- theta_base = pos [i2 + ne2 * 1 ];
525- }
526- else if (sector >= sec_w && sector < sec_w + sections .s2 ) {
527- theta_base = pos [i2 + ne2 * 2 ];
528- }
529- else if (sector >= sec_w + sections .s2 ) {
530- theta_base = pos [i2 + ne2 * 3 ];
534+ if (is_imrope ) {
535+ if (sector % 3 == 1 && sector < 3 * sections .s1 ) { // h
536+ theta_base = (float ) pos [i2 + ne02 * 1 ];
537+ } else if (sector % 3 == 2 && sector < 3 * sections .s2 ) { // w
538+ theta_base = (float ) pos [i2 + ne02 * 2 ];
539+ } else if (sector % 3 == 0 && sector < 3 * sections .s0 ) { // t
540+ theta_base = (float ) pos [i2 + ne02 * 0 ];
541+ } else { // e
542+ theta_base = (float ) pos [i2 + ne02 * 3 ];
543+ }
544+ } else {
545+ if (sector < sections .s0 ) {
546+ theta_base = pos [i2 ];
547+ }
548+ else if (sector >= sections .s0 && sector < sec_w ) {
549+ theta_base = pos [i2 + ne2 * 1 ];
550+ }
551+ else if (sector >= sec_w && sector < sec_w + sections .s2 ) {
552+ theta_base = pos [i2 + ne2 * 2 ];
553+ }
554+ else if (sector >= sec_w + sections .s2 ) {
555+ theta_base = pos [i2 + ne2 * 3 ];
556+ }
531557 }
532558
533559 const float theta = theta_base * pow (freq_base , inv_ndims * i0 );
0 commit comments