@@ -424,22 +424,23 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
424424 // 1x16 integer accumulator
425425 vint32m2_t sumi = __riscv_vmv_v_x_i32m2 (0 .0f , 16 );
426426
427- // Load `b_ptr`.
428- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2 ((const uint8_t *)b_ptr[l].qs , 16 );
429- const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2 (values, __riscv_vand_vx_u8mf2 (b_0_packed, 0xf , 16 ), 16 );
430- const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2 (values, __riscv_vsrl_vx_u8mf2 (b_0_packed, 4 , 16 ), 16 );
431- // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
432- // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
433-
434427 // Accumulation loop.
435- for (int i = 0 ; i < 16 ; i++) {
428+ for (int i = 0 ; i < QK4_NL / 2 ; i++) {
429+ // Load `b_ptr`.
430+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2 ((const uint8_t *)&b_ptr[l].qs [i * 16 ], 16 );
431+ const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2 (values, __riscv_vand_vx_u8mf2 (b_0_packed, 0xf , 16 ), 16 );
432+ const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2 (values, __riscv_vsrl_vx_u8mf2 (b_0_packed, 4 , 16 ), 16 );
433+ // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
434+ // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
435+
436436 const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1 (b_0_lo, a_ptr[l].qs [i], 16 );
437437 const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1 (b_0_hi, a_ptr[l].qs [16 + i], 16 );
438438 sumi = __riscv_vadd_vv_i32m2 (sumi, __riscv_vwadd_vv_i32m2 (sumi_lo, sumi_hi, 16 ), 16 );
439439 }
440440
441- vfloat16m1_t b_d = __riscv_vle16_v_f16m1 ((_Float16 *)b_ptr[l].d , 16 );
442- vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2 (b_d, *(const _Float16 *)&a_ptr[l].d , 16 );
441+ const vfloat16m1_t b_d = __riscv_vle16_v_f16m1 ((_Float16 *)b_ptr[l].d , 16 );
442+ const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2 (b_d, *(const _Float16 *)&a_ptr[l].d , 16 );
443+
443444 sumf = __riscv_vfmacc_vv_f32m2 (sumf, __riscv_vfcvt_f_x_v_f32m2 (sumi, 16 ), d_0, 16 );
444445 }
445446
@@ -545,7 +546,7 @@ void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
545546 }
546547 return ;
547548#endif
548- ggml_gemm_iq4_nl_16x1_q8_0_generic (n, s, bs, vx, vy, nr, nc);
549+ ggml_gemm_iq4_nl_4x16_q8_0_generic (n, s, bs, vx, vy, nr, nc);
549550}
550551
551552void ggml_gemm_iq4_nl_16x1_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -589,15 +590,15 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
589590 vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2 (0 .0f , 16 );
590591 vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2 (0 .0f , 16 );
591592
592- // Load `b_ptr`.
593- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2 ((const uint8_t *)b_ptr[l].qs , 16 );
594- const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2 (values, __riscv_vand_vx_u8mf2 (b_0_packed, 0xf , 16 ), 16 );
595- const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2 (values, __riscv_vsrl_vx_u8mf2 (b_0_packed, 4 , 16 ), 16 );
596- // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
597- // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
598-
599593 // Accumulation loop.
600- for (int i = 0 ; i < 16 ; i++) {
594+ for (int i = 0 ; i < QK4_NL / 2 ; i++) {
595+ // Load `b_ptr`.
596+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2 ((const uint8_t *)&b_ptr[l].qs [i * 16 ], 16 );
597+ const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2 (values, __riscv_vand_vx_u8mf2 (b_0_packed, 0xf , 16 ), 16 );
598+ const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2 (values, __riscv_vsrl_vx_u8mf2 (b_0_packed, 4 , 16 ), 16 );
599+ // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
600+ // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
601+
601602 const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1 (b_0_lo, a_ptr[l].qs [i * 4 ], 16 );
602603 const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1 (b_0_lo, a_ptr[l].qs [i * 4 + 1 ], 16 );
603604 const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1 (b_0_lo, a_ptr[l].qs [i * 4 + 2 ], 16 );
@@ -614,11 +615,11 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
614615 sumi_3 = __riscv_vadd_vv_i32m2 (sumi_3, __riscv_vwadd_vv_i32m2 (sumi_3_lo, sumi_3_hi, 16 ), 16 );
615616 }
616617
617- vfloat16m1_t b_d = __riscv_vle16_v_f16m1 ((_Float16 *)b_ptr[l].d , 16 );
618- vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2 (b_d, *(const _Float16 *)&a_ptr[l].d [0 ], 16 );
619- vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2 (b_d, *(const _Float16 *)&a_ptr[l].d [1 ], 16 );
620- vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2 (b_d, *(const _Float16 *)&a_ptr[l].d [2 ], 16 );
621- vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2 (b_d, *(const _Float16 *)&a_ptr[l].d [3 ], 16 );
618+ const vfloat16m1_t b_d = __riscv_vle16_v_f16m1 ((_Float16 *)b_ptr[l].d , 16 );
619+ const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2 (b_d, *(const _Float16 *)&a_ptr[l].d [0 ], 16 );
620+ const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2 (b_d, *(const _Float16 *)&a_ptr[l].d [1 ], 16 );
621+ const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2 (b_d, *(const _Float16 *)&a_ptr[l].d [2 ], 16 );
622+ const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2 (b_d, *(const _Float16 *)&a_ptr[l].d [3 ], 16 );
622623
623624 sumf_0 = __riscv_vfmacc_vv_f32m2 (sumf_0, __riscv_vfcvt_f_x_v_f32m2 (sumi_0, 16 ), d_0, 16 );
624625 sumf_1 = __riscv_vfmacc_vv_f32m2 (sumf_1, __riscv_vfcvt_f_x_v_f32m2 (sumi_1, 16 ), d_1, 16 );
0 commit comments