Skip to content

Commit 090e5e6

Browse files
committed
ggml-cpu: add generic impl for iq4_nl gemm/gemv
1 parent 103e71c commit 090e5e6

2 files changed

Lines changed: 108 additions & 25 deletions

File tree

ggml/src/ggml-cpu/arch/riscv/repack.cpp

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -424,22 +424,23 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
424424
// 1x16 integer accumulator
425425
vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16);
426426

427-
// Load `b_ptr`.
428-
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)b_ptr[l].qs, 16);
429-
const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
430-
const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
431-
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
432-
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
433-
434427
// Accumulation loop.
435-
for (int i = 0; i < 16; i++) {
428+
for (int i = 0; i < QK4_NL / 2; i++) {
429+
// Load `b_ptr`.
430+
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);
431+
const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
432+
const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
433+
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
434+
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
435+
436436
const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16);
437437
const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16);
438438
sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16);
439439
}
440440

441-
vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16);
442-
vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
441+
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16);
442+
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16);
443+
443444
sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16);
444445
}
445446

@@ -545,7 +546,7 @@ void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
545546
}
546547
return;
547548
#endif
548-
ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc);
549+
ggml_gemm_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc);
549550
}
550551

551552
void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -589,15 +590,15 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
589590
vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16);
590591
vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16);
591592

592-
// Load `b_ptr`.
593-
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)b_ptr[l].qs, 16);
594-
const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
595-
const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
596-
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
597-
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
598-
599593
// Accumulation loop.
600-
for (int i = 0; i < 16; i++) {
594+
for (int i = 0; i < QK4_NL / 2; i++) {
595+
// Load `b_ptr`.
596+
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16);
597+
const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16);
598+
const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16);
599+
// const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16);
600+
// const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16);
601+
601602
const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16);
602603
const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16);
603604
const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16);
@@ -614,11 +615,11 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
614615
sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16);
615616
}
616617

617-
vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16);
618-
vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
619-
vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
620-
vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
621-
vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);
618+
const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16);
619+
const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16);
620+
const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16);
621+
const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16);
622+
const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16);
622623

623624
sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16);
624625
sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16);

ggml/src/ggml-cpu/repack.cpp

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,44 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
942942
}
943943
}
944944

945+
void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
946+
const int qk = QK8_0;
947+
const int nb = n / qk;
948+
const int ncols_interleaved = 16;
949+
const int blocklen = 1;
950+
951+
assert(nr == 1);
952+
assert(n % qk == 0);
953+
assert(nc % ncols_interleaved == 0);
954+
955+
UNUSED(bs);
956+
UNUSED(nr);
957+
958+
float sumf[16];
959+
int sumi;
960+
961+
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
962+
for (int x = 0; x < nc / ncols_interleaved; x++) {
963+
const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
964+
965+
for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
966+
for (int l = 0; l < nb; l++) {
967+
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
968+
for (int j = 0; j < ncols_interleaved; j++) {
969+
sumi = 0;
970+
for (int i = 0; i < blocklen; ++i) {
971+
const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
972+
const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
973+
sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
974+
}
975+
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
976+
}
977+
}
978+
}
979+
for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
980+
}
981+
}
982+
945983
void ggml_gemv_q8_0_4x4_q8_0_generic(int n,
946984
float * GGML_RESTRICT s,
947985
size_t bs,
@@ -1777,6 +1815,50 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
17771815
}
17781816
}
17791817

1818+
void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1819+
const int qk = QK8_0;
1820+
const int nb = n / qk;
1821+
const int ncols_interleaved = 16;
1822+
const int blocklen = 1;
1823+
1824+
assert(n % qk == 0);
1825+
assert(nr % 4 == 0);
1826+
assert(nc % ncols_interleaved == 0);
1827+
1828+
float sumf[4][16];
1829+
int sumi;
1830+
1831+
for (int y = 0; y < nr / 4; y++) {
1832+
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
1833+
for (int x = 0; x < nc / ncols_interleaved; x++) {
1834+
const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
1835+
for (int m = 0; m < 4; m++) {
1836+
for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
1837+
}
1838+
for (int l = 0; l < nb; l++) {
1839+
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
1840+
for (int m = 0; m < 4; m++) {
1841+
for (int j = 0; j < ncols_interleaved; j++) {
1842+
sumi = 0;
1843+
for (int i = 0; i < blocklen; ++i) {
1844+
const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
1845+
const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
1846+
sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
1847+
(v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4]));
1848+
}
1849+
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
1850+
}
1851+
}
1852+
}
1853+
}
1854+
for (int m = 0; m < 4; m++) {
1855+
for (int j = 0; j < ncols_interleaved; j++)
1856+
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
1857+
}
1858+
}
1859+
}
1860+
}
1861+
17801862
void ggml_gemm_q8_0_4x4_q8_0_generic(int n,
17811863
float * GGML_RESTRICT s,
17821864
size_t bs,
@@ -2554,7 +2636,7 @@ static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_
25542636
GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
25552637
GGML_ASSERT(interleave_block == 1);
25562638

2557-
const block_iq4_nl * src = (const block_iq4_nl *)data;
2639+
const block_iq4_nl * src = (const block_iq4_nl *)data;
25582640
block_iq4_nlx16 * dst = ( block_iq4_nlx16 *)t->data;
25592641

25602642
block_iq4_nl dst_tmp[16];

0 commit comments

Comments
 (0)