Skip to content

Commit b6d4517

Browse files
committed
ggml-cpu: add generic impls
1 parent fb322e1 commit b6d4517

2 files changed

Lines changed: 108 additions & 25 deletions

File tree

ggml/src/ggml-cpu/arch/riscv/repack.cpp

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -424,22 +424,23 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
424424
// 1x16 integer accumulator
425425
vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16);
426426

427-
// Load `b_ptr`.
428-
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)b_ptr[l].qs, 16);
429-
const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
430-
const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
431-
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
432-
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
433-
434427
// Accumulation loop.
435-
for (int i = 0; i < 16; i++) {
428+
for (int i = 0; i < QK4_NL / 2; i++) {
429+
// Load `b_ptr`.
430+
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);
431+
const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
432+
const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
433+
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
434+
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
435+
436436
const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16);
437437
const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16);
438438
sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16);
439439
}
440440

441-
vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16);
442-
vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
441+
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16);
442+
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
443+
443444
sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
444445
}
445446

@@ -545,7 +546,7 @@ void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
545546
}
546547
return;
547548
#endif
548-
ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
549+
ggml_gemm_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc);
549550
}
550551

551552
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -589,15 +590,15 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
589590
vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16);
590591
vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16);
591592

592-
// Load `b_ptr`.
593-
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)b_ptr[l].qs, 16);
594-
const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
595-
const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
596-
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
597-
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
598-
599593
// Accumulation loop.
600-
for (int i = 0; i < 16; i++) {
594+
for (int i = 0; i < QK4_NL / 2; i++) {
595+
// Load `b_ptr`.
596+
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);
597+
const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
598+
const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
599+
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
600+
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
601+
601602
const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16);
602603
const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16);
603604
const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16);
@@ -614,11 +615,11 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
614615
sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16);
615616
}
616617

617-
vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16);
618-
vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
619-
vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
620-
vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
621-
vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);
618+
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16);
619+
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
620+
const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
621+
const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
622+
const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);
622623

623624
sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
624625
sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);

ggml/src/ggml-cpu/repack.cpp

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,44 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
855855
}
856856
}
857857

858+
void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
859+
const int qk = QK8_0;
860+
const int nb = n / qk;
861+
const int ncols_interleaved = 16;
862+
const int blocklen = 1;
863+
864+
assert(nr == 1);
865+
assert(n % qk == 0);
866+
assert(nc % ncols_interleaved == 0);
867+
868+
UNUSED(bs);
869+
UNUSED(nr);
870+
871+
float sumf[16];
872+
int sumi;
873+
874+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
875+
for (int x = 0; x < nc / ncols_interleaved; x++) {
876+
const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
877+
878+
for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
879+
for (int l = 0; l < nb; l++) {
880+
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
881+
for (int j = 0; j < ncols_interleaved; j++) {
882+
sumi = 0;
883+
for (int i = 0; i < blocklen; ++i) {
884+
const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
885+
const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
886+
sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
887+
}
888+
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
889+
}
890+
}
891+
}
892+
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
893+
}
894+
}
895+
858896
void ggml_gemv_q8_0_4x4_q8_0_generic(int n,
859897
float * GGML_RESTRICT s,
860898
size_t bs,
@@ -1587,6 +1625,50 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
15871625
}
15881626
}
15891627

1628+
void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1629+
const int qk = QK8_0;
1630+
const int nb = n / qk;
1631+
const int ncols_interleaved = 16;
1632+
const int blocklen = 1;
1633+
1634+
assert(n % qk == 0);
1635+
assert(nr % 4 == 0);
1636+
assert(nc % ncols_interleaved == 0);
1637+
1638+
float sumf[4][16];
1639+
int sumi;
1640+
1641+
for (int y = 0; y < nr / 4; y++) {
1642+
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1643+
for (int x = 0; x < nc / ncols_interleaved; x++) {
1644+
const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
1645+
for (int m = 0; m < 4; m++) {
1646+
for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
1647+
}
1648+
for (int l = 0; l < nb; l++) {
1649+
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
1650+
for (int m = 0; m < 4; m++) {
1651+
for (int j = 0; j < ncols_interleaved; j++) {
1652+
sumi = 0;
1653+
for (int i = 0; i < blocklen; ++i) {
1654+
const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
1655+
const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
1656+
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
1657+
(v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4]));
1658+
}
1659+
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
1660+
}
1661+
}
1662+
}
1663+
}
1664+
for (int m = 0; m < 4; m++) {
1665+
for (int j = 0; j < ncols_interleaved; j++)
1666+
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1667+
}
1668+
}
1669+
}
1670+
}
1671+
15901672
void ggml_gemm_q8_0_4x4_q8_0_generic(int n,
15911673
float * GGML_RESTRICT s,
15921674
size_t bs,
@@ -2244,7 +2326,7 @@ static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_
22442326
GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
22452327
GGML_ASSERT(interleave_block == 1);
22462328

2247-
const block_iq4_nl * src = (const block_iq4_nl *)data;
2329+
const block_iq4_nl * src = (const block_iq4_nl *)data;
22482330
block_iq4_nlx16 * dst = ( block_iq4_nlx16 *)t->data;
22492331

22502332
block_iq4_nl dst_tmp[16];

0 commit comments

Comments
 (0)