diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 8db1583c8b9..6be784ca08b 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -521,6 +521,246 @@ void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemv_q2_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); } +template +__attribute__((optimize("no-schedule-insns"))) +static inline void ggml_gemv_q3_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + // GEMV processes 1 row against 16 columns of weights + const int N_COLS_TILE = ncols_interleaved; + + assert(nc % N_COLS_TILE == 0); + + const int num_k_blocks = n / QK_K; + + // vl = 16. Using LMUL=2 for 32-bit accumulators on VLEN=256 + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + + // Loop over output columns (16 at a time) + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + const block_q8_K * lhs_base_ptr = (const block_q8_K *) vy; + const block_q3_Kx * rhs_base_ptr = (const block_q3_Kx *) vx + (col_tile / N_COLS_TILE) * num_k_blocks; + + // Stage 3: Persistent Float Accumulator (1 vector for 16 columns) + vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_K * lhs_current = &lhs_base_ptr[k_block]; + const block_q3_Kx * rhs_current = &rhs_base_ptr[k_block]; + + const uint8_t * rhs_qs_ptr = rhs_current->qs; + const uint8_t * rhs_hmask_ptr = rhs_current->hmask; + const uint8_t * rhs_sc_low_ptr = rhs_current->scales; + const uint8_t * rhs_sc_high_ptr = rhs_current->scales + 8*ncols_interleaved; + + // Activation pointer (linear access for GEMV) + const int8_t * lhs_qs_ptr = lhs_current->qs; + + // Stage 2: Main Integer Accumulator (1 vector) + vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl); + + for (int group = 0; group < 4; ++group) { + // High scales are needed for all 4 sub-blocks + vuint8mf2_t v_sc_h_quad = __riscv_vle8_v_u8mf2(rhs_sc_high_ptr, vl); + rhs_sc_high_ptr += ncols_interleaved; + + // --- Scope 1: Sub-blocks 1 & 2 (Pair 0) --- + { + vuint8mf2_t v_sc_l_pair0 = __riscv_vle8_v_u8mf2(rhs_sc_low_ptr, vl); + rhs_sc_low_ptr += ncols_interleaved; + + // --- Sub-block 1 --- + { + // 1. Initialize Temps + vint16m1_t v_tsum = __riscv_vmv_v_x_i16m1(0, vl); + + // 2. Heavy Dot Product Loop + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + // Mask generation as requested + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + // Masked subtraction as requested + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + + // Scalar broadcast multiply-accumulate + v_tsum = __riscv_vwmacc_vx_i16m1(v_tsum, *lhs_qs_ptr, q_val, vl); + lhs_qs_ptr++; + } + } + } + + // 3. Just-In-Time Scale Calculation + vuint8mf2_t v_sc_lo = __riscv_vand_vx_u8mf2(v_sc_l_pair0, 0x0F, vl); + vuint8mf2_t v_sc_hi = __riscv_vand_vx_u8mf2(v_sc_h_quad, 0x03, vl); + vuint8mf2_t v_sc_u8 = __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + // 4. Accumulate + v_isum = __riscv_vwmacc_vv_i32m2(v_isum, v_sc_16, v_tsum, vl); + } + + // --- Sub-block 2 --- + { + vint16m1_t v_tsum = __riscv_vmv_v_x_i16m1(0, vl); + + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + + v_tsum = __riscv_vwmacc_vx_i16m1(v_tsum, *lhs_qs_ptr, q_val, vl); + lhs_qs_ptr++; + } + } + } + + // JIT Scale Calc (Shift 4) + vuint8mf2_t v_sc_lo = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_l_pair0, 4, vl), 0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 2, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + v_isum = __riscv_vwmacc_vv_i32m2(v_isum, v_sc_16, v_tsum, vl); + } + } + + // --- Scope 2: Sub-blocks 3 & 4 (Pair 1) --- + { + vuint8mf2_t v_sc_l_pair1 = __riscv_vle8_v_u8mf2(rhs_sc_low_ptr, vl); + rhs_sc_low_ptr += ncols_interleaved; + + // --- Sub-block 3 --- + { + vint16m1_t v_tsum = __riscv_vmv_v_x_i16m1(0, vl); + + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + + v_tsum = __riscv_vwmacc_vx_i16m1(v_tsum, *lhs_qs_ptr, q_val, vl); + lhs_qs_ptr++; + } + } + } + + vuint8mf2_t v_sc_lo = __riscv_vand_vx_u8mf2(v_sc_l_pair1, 0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 4, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + v_isum = __riscv_vwmacc_vv_i32m2(v_isum, v_sc_16, v_tsum, vl); + } + // --- Sub-block 4 --- + { + vint16m1_t v_tsum = __riscv_vmv_v_x_i16m1(0, vl); + + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + + v_tsum = __riscv_vwmacc_vx_i16m1(v_tsum, *lhs_qs_ptr, q_val, vl); + lhs_qs_ptr++; + } + } + } + + // JIT Scale Calc (Shift 6) + vuint8mf2_t v_sc_lo = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_l_pair1, 4, vl), 0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 6, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + v_isum = __riscv_vwmacc_vv_i32m2(v_isum, v_sc_16, v_tsum, vl); + } + } // End Scope 2 (Pair 1) + } // End group loop + + // --- Final Super-Block accumulation --- + vfloat32m2_t rhs_d = + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *) rhs_current->d, vl), vl); + float lhs_d = lhs_current->d; + vfloat32m2_t v_isum_f = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl); + + // v_sumf += isum * d_act (scalar) * d_weight (vector) + v_sumf = __riscv_vfmacc_vv_f32m2(v_sumf, __riscv_vfmul_vf_f32m2(v_isum_f, lhs_d, vl), rhs_d, vl); + + } // End k_block loop + + // --- Store Results --- + // GEMV outputs a vector 's' (1 row). We store 16 contiguous elements. + __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); + + } // End col_tile loop +} + +void ggml_gemv_q3_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q3_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q3_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q3_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +} + template static inline void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; @@ -832,7 +1072,139 @@ void ggml_gemv_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemv_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemv_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + + // 1xM Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 4 16-element sub-blocks at once. + for (int j = 0; j < QK_K / 16; j += 4) { + // Load the scales. + // + // Low bits. + vint16m4_t scales = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(&b_ptr[l].scales[j * ncols_interleaved], 4 * ncols_interleaved), 4 * ncols_interleaved); + + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_2_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_3_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + for (int i = k * 8; i < k * 8 + QK8_0 / 4; i++) { + // Load the high bits. + const vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 4 + i) * ncols_interleaved], ncols_interleaved); + + { + // Load the low bits. + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 8 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_0_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_1_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_0_hi = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x30, ncols_interleaved); + const vuint8mf2_t b_1_hi = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x30, ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_0_m = __riscv_vor_vv_u8mf2(b_0_lo, b_0_hi, ncols_interleaved); + const vuint8mf2_t b_1_m = __riscv_vor_vv_u8mf2(b_1_lo, b_1_hi, ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_0_m), 32, ncols_interleaved); + const vint8mf2_t b_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_1_m), 32, ncols_interleaved); + + // Multiply and accumulate in int16. + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 16 + 0 + i], b_0, ncols_interleaved); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 16 + 16 + i], b_1, ncols_interleaved); + } + __asm__ __volatile__("" ::: "memory"); + { + // Load the low bits. + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 8 + 16 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_2_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_3_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_2_hi = __riscv_vand_vx_u8mf2(b_hi, 0x30, ncols_interleaved); + const vuint8mf2_t b_3_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x30, ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_2_m = __riscv_vor_vv_u8mf2(b_2_lo, b_2_hi, ncols_interleaved); + const vuint8mf2_t b_3_m = __riscv_vor_vv_u8mf2(b_3_lo, b_3_hi, ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_2_m), 32, ncols_interleaved); + const vint8mf2_t b_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_3_m), 32, ncols_interleaved); + + // Multiply and accumulate in int16. + sumi_s_2_16 = __riscv_vwmacc_vx_i16m1(sumi_s_2_16, a_ptr[l].qs[j * 16 + 32 + i], b_2, ncols_interleaved); + sumi_s_3_16 = __riscv_vwmacc_vx_i16m1(sumi_s_3_16, a_ptr[l].qs[j * 16 + 48 + i], b_3, ncols_interleaved); + } + __asm__ __volatile__("" ::: "memory"); + } + + // Multiply and accumulate in int32. + sumi = __riscv_vwmacc_vv_i32m2(sumi, sumi_s_0_16, __riscv_vget_v_i16m4_i16m1(scales, 0), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, sumi_s_1_16, __riscv_vget_v_i16m4_i16m1(scales, 1), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, sumi_s_2_16, __riscv_vget_v_i16m4_i16m1(scales, 2), ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, sumi_s_3_16, __riscv_vget_v_i16m4_i16m1(scales, 3), ncols_interleaved); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, ncols_interleaved); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); + } +} + +void ggml_gemv_q6_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q6_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q6_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q6_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +} + +template +static inline void ggml_gemv_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int blocklen = 1; @@ -1691,8 +2063,302 @@ void ggml_gemm_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemm_q2_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); } + +template +__attribute__((optimize("no-schedule-insns"))) +static inline void ggml_gemm_q3_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + const int N_ROWS_TILE = 4; + const int N_COLS_TILE = ncols_interleaved; + + assert(nr % N_ROWS_TILE == 0); + assert(nc % N_COLS_TILE == 0); + + const int num_k_blocks = n / QK_K; + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + + for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) { + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + const block_q8_Kx4 * lhs_base_ptr = (const block_q8_Kx4 *) vy + (row_tile / N_ROWS_TILE) * num_k_blocks; + const block_q3_Kx * rhs_base_ptr = (const block_q3_Kx *) vx + (col_tile / N_COLS_TILE) * num_k_blocks; + + // Stage 3: Persistent Float Accumulators (8 registers) + vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_Kx4 * lhs_current = &lhs_base_ptr[k_block]; + const block_q3_Kx * rhs_current = &rhs_base_ptr[k_block]; + + const uint8_t * rhs_qs_ptr = rhs_current->qs; + const uint8_t * rhs_hmask_ptr = rhs_current->hmask; + const uint8_t * rhs_sc_low_ptr = rhs_current->scales; + const uint8_t * rhs_sc_high_ptr = rhs_current->scales + (8 * ncols_interleaved); + const int8_t * lhs_qs_ptr = lhs_current->qs; + + // Stage 2: Main Integer Accumulators (8 registers) + vint32m2_t v_isum_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int group = 0; group < 4; ++group) { + // High scales are needed for all 4 sub-blocks (0.5 register) + vuint8mf2_t v_sc_h_quad = __riscv_vle8_v_u8mf2(rhs_sc_high_ptr, 16); + rhs_sc_high_ptr += ncols_interleaved; + + // --- Scope 1: Sub-blocks 1 & 2 (Pair 0) --- + // By scoping this, v_sc_l_pair0 dies before we load pair1 + { + vuint8mf2_t v_sc_l_pair0 = __riscv_vle8_v_u8mf2(rhs_sc_low_ptr, 16); + rhs_sc_low_ptr += ncols_interleaved; + + // --- Sub-block 1 --- + { + // 1. Initialize Temps (4 registers) + vint16m1_t v_tsum_0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_3 = __riscv_vmv_v_x_i16m1(0, vl); + + // 2. Heavy Dot Product Loop + // Note: v_sc_16 is NOT live here, saving 0.5 - 1 register of pressure + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + v_tsum_0 = __riscv_vwmacc_vx_i16m1(v_tsum_0, (int8_t) lhs_qs_ptr[0], q_val, vl); + v_tsum_1 = __riscv_vwmacc_vx_i16m1(v_tsum_1, (int8_t) lhs_qs_ptr[1], q_val, vl); + v_tsum_2 = __riscv_vwmacc_vx_i16m1(v_tsum_2, (int8_t) lhs_qs_ptr[2], q_val, vl); + v_tsum_3 = __riscv_vwmacc_vx_i16m1(v_tsum_3, (int8_t) lhs_qs_ptr[3], q_val, vl); + lhs_qs_ptr += 4; + } + } + } + + // 3. Just-In-Time Scale Calculation + // Only now do we allocate the register for v_sc_16 + vuint8mf2_t v_sc_lo = __riscv_vand_vx_u8mf2(v_sc_l_pair0, 0x0F, vl); + vuint8mf2_t v_sc_hi = __riscv_vand_vx_u8mf2(v_sc_h_quad, 0x03, vl); + vuint8mf2_t v_sc_u8 = + __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + // 4. Accumulate + v_isum_0 = __riscv_vwmacc_vv_i32m2(v_isum_0, v_sc_16, v_tsum_0, vl); + v_isum_1 = __riscv_vwmacc_vv_i32m2(v_isum_1, v_sc_16, v_tsum_1, vl); + v_isum_2 = __riscv_vwmacc_vv_i32m2(v_isum_2, v_sc_16, v_tsum_2, vl); + v_isum_3 = __riscv_vwmacc_vv_i32m2(v_isum_3, v_sc_16, v_tsum_3, vl); + } + + // --- Sub-block 2 --- + { + vint16m1_t v_tsum_0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_3 = __riscv_vmv_v_x_i16m1(0, vl); + + // Dot Product Loop (Same as above) + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + v_tsum_0 = __riscv_vwmacc_vx_i16m1(v_tsum_0, (int8_t) lhs_qs_ptr[0], q_val, vl); + v_tsum_1 = __riscv_vwmacc_vx_i16m1(v_tsum_1, (int8_t) lhs_qs_ptr[1], q_val, vl); + v_tsum_2 = __riscv_vwmacc_vx_i16m1(v_tsum_2, (int8_t) lhs_qs_ptr[2], q_val, vl); + v_tsum_3 = __riscv_vwmacc_vx_i16m1(v_tsum_3, (int8_t) lhs_qs_ptr[3], q_val, vl); + lhs_qs_ptr += 4; + } + } + } + + // JIT Scale Calc + vuint8mf2_t v_sc_lo = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_l_pair0, 4, vl), 0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 2, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = + __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + v_isum_0 = __riscv_vwmacc_vv_i32m2(v_isum_0, v_sc_16, v_tsum_0, vl); + v_isum_1 = __riscv_vwmacc_vv_i32m2(v_isum_1, v_sc_16, v_tsum_1, vl); + v_isum_2 = __riscv_vwmacc_vv_i32m2(v_isum_2, v_sc_16, v_tsum_2, vl); + v_isum_3 = __riscv_vwmacc_vv_i32m2(v_isum_3, v_sc_16, v_tsum_3, vl); + } + } // v_sc_l_pair0 dies here + + // --- Scope 2: Sub-blocks 3 & 4 (Pair 1) --- + { + vuint8mf2_t v_sc_l_pair1 = __riscv_vle8_v_u8mf2(rhs_sc_low_ptr, 16); + rhs_sc_low_ptr += ncols_interleaved; + + // --- Sub-block 3 --- + { + vint16m1_t v_tsum_0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_3 = __riscv_vmv_v_x_i16m1(0, vl); + + // Dot Product Loop (Same as above) + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + v_tsum_0 = __riscv_vwmacc_vx_i16m1(v_tsum_0, (int8_t) lhs_qs_ptr[0], q_val, vl); + v_tsum_1 = __riscv_vwmacc_vx_i16m1(v_tsum_1, (int8_t) lhs_qs_ptr[1], q_val, vl); + v_tsum_2 = __riscv_vwmacc_vx_i16m1(v_tsum_2, (int8_t) lhs_qs_ptr[2], q_val, vl); + v_tsum_3 = __riscv_vwmacc_vx_i16m1(v_tsum_3, (int8_t) lhs_qs_ptr[3], q_val, vl); + lhs_qs_ptr += 4; + } + } + } + + // JIT Scale Calc + vuint8mf2_t v_sc_lo = __riscv_vand_vx_u8mf2(v_sc_l_pair1, 0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 4, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = + __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + v_isum_0 = __riscv_vwmacc_vv_i32m2(v_isum_0, v_sc_16, v_tsum_0, vl); + v_isum_1 = __riscv_vwmacc_vv_i32m2(v_isum_1, v_sc_16, v_tsum_1, vl); + v_isum_2 = __riscv_vwmacc_vv_i32m2(v_isum_2, v_sc_16, v_tsum_2, vl); + v_isum_3 = __riscv_vwmacc_vv_i32m2(v_isum_3, v_sc_16, v_tsum_3, vl); + } + + // --- Sub-block 4 --- + { + vint16m1_t v_tsum_0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_tsum_3 = __riscv_vmv_v_x_i16m1(0, vl); + + // Dot Product Loop (Same as above) + for (int i8 = 0; i8 < 2; i8++) { + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(rhs_hmask_ptr, vl); + rhs_hmask_ptr += ncols_interleaved; + uint8_t m = 1; + for (int i4 = 0; i4 < 2; i4++) { + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += ncols_interleaved; + for (int w = 0; w < 4; w++) { + vuint8mf2_t q2 = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, w * 2, vl), 0x03, vl); + vbool16_t vmask = + __riscv_vmseq_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(vqh, m, vl), 0, vl); + m <<= 1; + vint8mf2_t q_val = __riscv_vreinterpret_v_u8mf2_i8mf2(q2); + q_val = __riscv_vsub_vx_i8mf2_mu(vmask, q_val, q_val, 4, vl); + v_tsum_0 = __riscv_vwmacc_vx_i16m1(v_tsum_0, (int8_t) lhs_qs_ptr[0], q_val, vl); + v_tsum_1 = __riscv_vwmacc_vx_i16m1(v_tsum_1, (int8_t) lhs_qs_ptr[1], q_val, vl); + v_tsum_2 = __riscv_vwmacc_vx_i16m1(v_tsum_2, (int8_t) lhs_qs_ptr[2], q_val, vl); + v_tsum_3 = __riscv_vwmacc_vx_i16m1(v_tsum_3, (int8_t) lhs_qs_ptr[3], q_val, vl); + lhs_qs_ptr += 4; + } + } + } + + // JIT Scale Calc + vuint8mf2_t v_sc_lo = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_l_pair1, 4, vl), 0x0F, vl); + vuint8mf2_t v_sc_hi = + __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_sc_h_quad, 6, vl), 0x03, vl); + vuint8mf2_t v_sc_u8 = + __riscv_vor_vv_u8mf2(v_sc_lo, __riscv_vsll_vx_u8mf2(v_sc_hi, 4, vl), vl); + vint16m1_t v_sc_16 = __riscv_vsext_vf2_i16m1( + __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(v_sc_u8), 32, vl), vl); + + v_isum_0 = __riscv_vwmacc_vv_i32m2(v_isum_0, v_sc_16, v_tsum_0, vl); + v_isum_1 = __riscv_vwmacc_vv_i32m2(v_isum_1, v_sc_16, v_tsum_1, vl); + v_isum_2 = __riscv_vwmacc_vv_i32m2(v_isum_2, v_sc_16, v_tsum_2, vl); + v_isum_3 = __riscv_vwmacc_vv_i32m2(v_isum_3, v_sc_16, v_tsum_3, vl); + } + } // v_sc_l_pair1 dies here + } // End group loop + + vfloat32m2_t rhs_d = + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *) rhs_current->d, vl), vl); + const float * lhs_d_ptr = lhs_current->d; + + vfloat32m2_t v_isum_0_f = __riscv_vfcvt_f_x_v_f32m2(v_isum_0, vl); + vfloat32m2_t v_isum_1_f = __riscv_vfcvt_f_x_v_f32m2(v_isum_1, vl); + vfloat32m2_t v_isum_2_f = __riscv_vfcvt_f_x_v_f32m2(v_isum_2, vl); + vfloat32m2_t v_isum_3_f = __riscv_vfcvt_f_x_v_f32m2(v_isum_3, vl); + + v_sumf_0 = + __riscv_vfmacc_vv_f32m2(v_sumf_0, __riscv_vfmul_vf_f32m2(v_isum_0_f, lhs_d_ptr[0], vl), rhs_d, vl); + v_sumf_1 = + __riscv_vfmacc_vv_f32m2(v_sumf_1, __riscv_vfmul_vf_f32m2(v_isum_1_f, lhs_d_ptr[1], vl), rhs_d, vl); + v_sumf_2 = + __riscv_vfmacc_vv_f32m2(v_sumf_2, __riscv_vfmul_vf_f32m2(v_isum_2_f, lhs_d_ptr[2], vl), rhs_d, vl); + v_sumf_3 = + __riscv_vfmacc_vv_f32m2(v_sumf_3, __riscv_vfmul_vf_f32m2(v_isum_3_f, lhs_d_ptr[3], vl), rhs_d, vl); + + } // End k_block loop + + __riscv_vse32_v_f32m2(s + (row_tile + 0) * bs + col_tile, v_sumf_0, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 1) * bs + col_tile, v_sumf_1, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 2) * bs + col_tile, v_sumf_2, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 3) * bs + col_tile, v_sumf_3, vl); + } + } +} + +void ggml_gemm_q3_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q3_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q3_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q3_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +} + template -void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; @@ -2272,7 +2938,198 @@ void ggml_gemm_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static inline void ggml_gemm_q6_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + + // 4xM Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 2 16-element sub-blocks at once. + for (int j = 0; j < QK_K / 16; j += 4) { + // Load the scales. + // + // Low bits. + vint16m4_t scales = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(&b_ptr[l].scales[j * ncols_interleaved], 4 * ncols_interleaved), 4 * ncols_interleaved); + + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + #pragma GCC unroll 1 + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_2_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_2_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_2_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_2_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_3_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_3_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_3_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_3_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + for (int i = k * 8; i < k * 8 + QK8_0 / 4; i++) { + // Load the high bits. + vuint8mf2_t b_hi = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[(j * 4 + i) * ncols_interleaved], ncols_interleaved); + + { + // Load the low bits. + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 8 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_0_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_1_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_0_hi = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(b_hi, 4, ncols_interleaved), 0x30, ncols_interleaved); + const vuint8mf2_t b_1_hi = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x30, ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_0_m = __riscv_vor_vv_u8mf2(b_0_lo, b_0_hi, ncols_interleaved); + const vuint8mf2_t b_1_m = __riscv_vor_vv_u8mf2(b_1_lo, b_1_hi, ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_0_m), 32, ncols_interleaved); + const vint8mf2_t b_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_1_m), 32, ncols_interleaved); + + // Multiply and accumulate in int16. + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 64 + i * 4 + 0], b_0, ncols_interleaved); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 64 + i * 4 + 1], b_0, ncols_interleaved); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 64 + i * 4 + 2], b_0, ncols_interleaved); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 64 + i * 4 + 3], b_0, ncols_interleaved); + // + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 64 + 64 + i * 4 + 0], b_1, ncols_interleaved); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 64 + 64 + i * 4 + 1], b_1, ncols_interleaved); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 64 + 64 + i * 4 + 2], b_1, ncols_interleaved); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 64 + 64 + i * 4 + 3], b_1, ncols_interleaved); + } + asm volatile ("" ::: "memory"); + { + // Load the low bits. + const vuint8mf2_t b_lo = __riscv_vle8_v_u8mf2(&b_ptr[l].ql[(j * 8 + 16 + i) * ncols_interleaved], ncols_interleaved); + const vuint8mf2_t b_2_lo = __riscv_vand_vx_u8mf2(b_lo, 0xF, ncols_interleaved); + const vuint8mf2_t b_3_lo = __riscv_vsrl_vx_u8mf2(b_lo, 4, ncols_interleaved); + + // Unpack the high bits. + const vuint8mf2_t b_2_hi = __riscv_vand_vx_u8mf2(b_hi, 0x30, ncols_interleaved); + const vuint8mf2_t b_3_hi = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(b_hi, 2, ncols_interleaved), 0x30, ncols_interleaved); + + // Merge the low bits with the corresponding high bits. + const vuint8mf2_t b_2_m = __riscv_vor_vv_u8mf2(b_2_lo, b_2_hi, ncols_interleaved); + const vuint8mf2_t b_3_m = __riscv_vor_vv_u8mf2(b_3_lo, b_3_hi, ncols_interleaved); + + // Bias adjustment. + const vint8mf2_t b_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_2_m), 32, ncols_interleaved); + const vint8mf2_t b_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(b_3_m), 32, ncols_interleaved); + + // Multiply and accumulate in int16. + sumi_0_s_2_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_2_16, a_ptr[l].qs[j * 64 + 128 + i * 4 + 0], b_2, ncols_interleaved); + sumi_1_s_2_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_2_16, a_ptr[l].qs[j * 64 + 128 + i * 4 + 1], b_2, ncols_interleaved); + sumi_2_s_2_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_2_16, a_ptr[l].qs[j * 64 + 128 + i * 4 + 2], b_2, ncols_interleaved); + sumi_3_s_2_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_2_16, a_ptr[l].qs[j * 64 + 128 + i * 4 + 3], b_2, ncols_interleaved); + // + sumi_0_s_3_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_3_16, a_ptr[l].qs[j * 64 + 192 + i * 4 + 0], b_3, ncols_interleaved); + sumi_1_s_3_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_3_16, a_ptr[l].qs[j * 64 + 192 + i * 4 + 1], b_3, ncols_interleaved); + sumi_2_s_3_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_3_16, a_ptr[l].qs[j * 64 + 192 + i * 4 + 2], b_3, ncols_interleaved); + sumi_3_s_3_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_3_16, a_ptr[l].qs[j * 64 + 192 + i * 4 + 3], b_3, ncols_interleaved); + } + asm volatile ("" ::: "memory"); + } + + // Multiply and accumulate in int32. + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m4_i16m1(scales, 0), sumi_0_s_0_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m4_i16m1(scales, 0), sumi_1_s_0_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m4_i16m1(scales, 0), sumi_2_s_0_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m4_i16m1(scales, 0), sumi_3_s_0_16, ncols_interleaved); + // + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m4_i16m1(scales, 1), sumi_0_s_1_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m4_i16m1(scales, 1), sumi_1_s_1_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m4_i16m1(scales, 1), sumi_2_s_1_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m4_i16m1(scales, 1), sumi_3_s_1_16, ncols_interleaved); + // + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m4_i16m1(scales, 2), sumi_0_s_2_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m4_i16m1(scales, 2), sumi_1_s_2_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m4_i16m1(scales, 2), sumi_2_s_2_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m4_i16m1(scales, 2), sumi_3_s_2_16, ncols_interleaved); + // + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, __riscv_vget_v_i16m4_i16m1(scales, 3), sumi_0_s_3_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, __riscv_vget_v_i16m4_i16m1(scales, 3), sumi_1_s_3_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, __riscv_vget_v_i16m4_i16m1(scales, 3), sumi_2_s_3_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, __riscv_vget_v_i16m4_i16m1(scales, 3), sumi_3_s_3_16, ncols_interleaved); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); + } + } +} + +void ggml_gemm_q6_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q6_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q6_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q6_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +} + +template +static inline void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int blocklen = 1; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 989fd179027..ce5b5bfa65f 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -531,6 +531,87 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } +template +void ggml_gemv_q3_K_Mx1_q8_K_generic( + int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % ncols_interleaved == 0); + (void)bs; + + const int nb = n / QK_K; + const block_q3_Kx * x = (const block_q3_Kx *) vx; + const block_q8_K * y = (const block_q8_K *) vy; + + const int scale_high_offset = 8 * ncols_interleaved; + + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { + const block_q3_Kx * x_ptr = x + (col_tile / ncols_interleaved) * nb; + + float sumf[ncols_interleaved]; + memset(sumf, 0, sizeof(sumf)); + + for (int k_block = 0; k_block < nb; ++k_block) { + const block_q3_Kx & xb = x_ptr[k_block]; + const block_q8_K & yb = y[k_block]; + + int32_t isum[ncols_interleaved] = {0}; + + for (int sb = 0; sb < 16; ++sb) { + const int s_row_lo = sb >> 1; + const int s_row_hi = sb >> 2; + const int s_shift_lo = (sb & 1) ? 4 : 0; + const int s_shift_hi = (sb & 3) * 2; + + for (int l = 0; l < 16; ++l) { + const int k = sb * 16 + l; + const int qs_row = k >> 2; + const int qs_shift = (k & 3) * 2; + const int hm_row = k >> 3; + const int hm_shift = k & 7; + + const int8_t q8 = yb.qs[k]; + + for (int col = 0; col < ncols_interleaved; ++col) { + // Inline q3k_get_scale6_packed + const uint8_t scale_lo_byte = xb.scales[s_row_lo * ncols_interleaved + col]; + const uint8_t scale_hi_byte = xb.scales[scale_high_offset + s_row_hi * ncols_interleaved + col]; + + const uint8_t s_lo = (scale_lo_byte >> s_shift_lo) & 0x0F; + const uint8_t s_hi = (scale_hi_byte >> s_shift_hi) & 0x03; + const int sc = (int)((s_hi << 4) | s_lo) - 32; + + // Inline q3k_get_val_packed + const uint8_t qs_byte = xb.qs[qs_row * ncols_interleaved + col]; + const uint8_t hm_byte = xb.hmask[hm_row * ncols_interleaved + col]; + + const int low2 = (qs_byte >> qs_shift) & 3; + const int hb = (hm_byte >> hm_shift) & 1; + const int v = low2 - (hb ? 0 : 4); + + isum[col] += (v * sc) * q8; + } + } + } + + for (int col = 0; col < ncols_interleaved; ++col) { + sumf[col] += (float) isum[col] * (GGML_FP16_TO_FP32(xb.d[col]) * yb.d); + } + } + + for (int col = 0; col < ncols_interleaved; ++col) { + s[col_tile + col] = sumf[col]; + } + } +} + template static inline void ggml_gemv_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; @@ -687,6 +768,76 @@ static inline void ggml_gemv_q5_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } +template +static inline void ggml_gemv_q6_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[ncols_interleaved]; + + int sumi0; + int sumi1; + int sumi2; + int sumi3; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + // Processing 4 sub-blocks at once. + for (int sb = 0; sb < QK_K/16; sb += 4) { + const int8_t *scales_0 = &b_ptr[l].scales[sb * ncols_interleaved]; + const int8_t *scales_1 = &b_ptr[l].scales[(sb + 1) * ncols_interleaved]; + const int8_t *scales_2 = &b_ptr[l].scales[(sb + 2) * ncols_interleaved]; + const int8_t *scales_3 = &b_ptr[l].scales[(sb + 3) * ncols_interleaved]; + const int qh_idx = sb * 4; + for (int i = 0; i < QK8_0/2; i++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi0 = 0; + sumi1 = 0; + sumi = 0; + const uint8_t v0 = (uint8_t) (b_ptr[l].ql[(sb * 8 + i) * ncols_interleaved + j] & 0xF); + const uint8_t v1 = (uint8_t) (b_ptr[l].ql[(sb * 8 + i) * ncols_interleaved + j] >> 4); + const uint8_t v2 = (uint8_t) (b_ptr[l].ql[(sb * 8 + 16 + i) * ncols_interleaved + j] & 0xF); + const uint8_t v3 = (uint8_t) (b_ptr[l].ql[(sb * 8 + 16 + i) * ncols_interleaved + j] >> 4); + const int8_t a0 = (int8_t)(v0 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] << 4) & 0x30)) - 32; + const int8_t a1 = (int8_t)(v1 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] << 2) & 0x30)) - 32; + const int8_t a2 = (int8_t)(v2 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j]) & 0x30)) - 32; + const int8_t a3 = (int8_t)(v3 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> 2) & 0x30)) - 32; + sumi0 = (a0 * a_ptr[l].qs[sb * 16 + i]); + sumi1 = (a1 * a_ptr[l].qs[sb * 16 + 16 + i]); + sumi2 = (a2 * a_ptr[l].qs[sb * 16 + 32 + i]); + sumi3 = (a3 * a_ptr[l].qs[sb * 16 + 48 + i]); + sumi0 = sumi0 * scales_0[j]; + sumi1 = sumi1 * scales_1[j]; + sumi2 = sumi2 * scales_2[j]; + sumi3 = sumi3 * scales_3[j]; + sumi += sumi0 + sumi1 + sumi2 + sumi3; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + template static inline void ggml_gemv_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -961,6 +1112,114 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } +template +void ggml_gemm_q3_K_Mx1_q8_K_generic( + int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + + assert(n % QK_K == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + const int nb = n / QK_K; + const block_q3_Kx * x = (const block_q3_Kx *) vx; + const block_q8_Kx4 * y = (const block_q8_Kx4 *) vy; + + // Offsets for the high part of the scales (8 rows of low bytes * columns) + const int scale_high_offset = 8 * ncols_interleaved; + + for (int row_tile = 0; row_tile < nr; row_tile += 4) { + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { + const block_q3_Kx * x_ptr = x + (col_tile / ncols_interleaved) * nb; + const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; + + float sumf[4][ncols_interleaved]; + memset(sumf, 0, sizeof(sumf)); + + for (int k_block = 0; k_block < nb; ++k_block) { + const block_q3_Kx & xb = x_ptr[k_block]; + const block_q8_Kx4 & yb = y_ptr[k_block]; + + int32_t isum[4][ncols_interleaved]; + memset(isum, 0, sizeof(isum)); + + for (int sb = 0; sb < 16; ++sb) { + // Pre-calc scale indices for this sub-block + const int s_row_lo = sb >> 1; + const int s_row_hi = sb >> 2; + const int s_shift_lo = (sb & 1) ? 4 : 0; + const int s_shift_hi = (sb & 3) * 2; + + for (int l = 0; l < 16; ++l) { + const int k = sb * 16 + l; + + // Pre-calc weight indices for this k + const int qs_row = k >> 2; + const int qs_shift = (k & 3) * 2; + const int hm_row = k >> 3; + const int hm_shift = k & 7; + + const int8_t q8_0 = yb.qs[k * 4 + 0]; + const int8_t q8_1 = yb.qs[k * 4 + 1]; + const int8_t q8_2 = yb.qs[k * 4 + 2]; + const int8_t q8_3 = yb.qs[k * 4 + 3]; + + for (int col = 0; col < ncols_interleaved; ++col) { + // Inline q3k_get_scale6_packed + const uint8_t scale_lo_byte = xb.scales[s_row_lo * ncols_interleaved + col]; + const uint8_t scale_hi_byte = xb.scales[scale_high_offset + s_row_hi * ncols_interleaved + col]; + + const uint8_t s_lo = (scale_lo_byte >> s_shift_lo) & 0x0F; + const uint8_t s_hi = (scale_hi_byte >> s_shift_hi) & 0x03; + + const int sc = (int)((s_hi << 4) | s_lo) - 32; + + // Inline q3k_get_val_packed + const uint8_t qs_byte = xb.qs[qs_row * ncols_interleaved + col]; + const uint8_t hm_byte = xb.hmask[hm_row * ncols_interleaved + col]; + + const int low2 = (qs_byte >> qs_shift) & 3; + const int hb = (hm_byte >> hm_shift) & 1; + + const int v = low2 - (hb ? 0 : 4); + + const int w = v * sc; + isum[0][col] += w * q8_0; + isum[1][col] += w * q8_1; + isum[2][col] += w * q8_2; + isum[3][col] += w * q8_3; + } + } + } + + for (int col = 0; col < ncols_interleaved; ++col) { + const float d_rhs = GGML_FP16_TO_FP32(xb.d[col]); + const float g0 = d_rhs * yb.d[0]; + const float g1 = d_rhs * yb.d[1]; + const float g2 = d_rhs * yb.d[2]; + const float g3 = d_rhs * yb.d[3]; + + sumf[0][col] += (float) isum[0][col] * g0; + sumf[1][col] += (float) isum[1][col] * g1; + sumf[2][col] += (float) isum[2][col] * g2; + sumf[3][col] += (float) isum[3][col] * g3; + } + } + + for (int r = 0; r < 4; ++r) { + for (int col = 0; col < ncols_interleaved; ++col) { + s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; + } + } + } + } +} + template static inline void ggml_gemm_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; @@ -1151,6 +1410,87 @@ static inline void ggml_gemm_q5_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } +template +static inline void ggml_gemm_q6_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + + float sumf[4][ncols_interleaved]; + int sumi0; + int sumi1; + int sumi2; + int sumi3; + int sumi; + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx * b_ptr = (const block_q6_Kx *) vx + (x * nb); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } + } + + for (int l = 0; l < nb; l++) { + // Processing 4 sub-blocks at once. + for (int sb = 0; sb < QK_K/16; sb += 4) { + const int8_t *scales_0 = &b_ptr[l].scales[sb * ncols_interleaved]; + const int8_t *scales_1 = &b_ptr[l].scales[(sb + 1) * ncols_interleaved]; + const int8_t *scales_2 = &b_ptr[l].scales[(sb + 2) * ncols_interleaved]; + const int8_t *scales_3 = &b_ptr[l].scales[(sb + 3) * ncols_interleaved]; + const int qh_idx = sb * 4; + for (int i = 0; i < QK8_0/2; i++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi0 = 0; + sumi1 = 0; + sumi = 0; + const uint8_t v0 = (uint8_t) (b_ptr[l].ql[(sb * 8 + i) * ncols_interleaved + j] & 0xF); + const uint8_t v1 = (uint8_t) (b_ptr[l].ql[(sb * 8 + i) * ncols_interleaved + j] >> 4); + const uint8_t v2 = (uint8_t) (b_ptr[l].ql[(sb * 8 + 16 + i) * ncols_interleaved + j] & 0xF); + const uint8_t v3 = (uint8_t) (b_ptr[l].ql[(sb * 8 + 16 + i) * ncols_interleaved + j] >> 4); + const int8_t a0 = (int8_t)(v0 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] << 4) & 0x30)) - 32; + const int8_t a1 = (int8_t)(v1 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] << 2) & 0x30)) - 32; + const int8_t a2 = (int8_t)(v2 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j]) & 0x30)) - 32; + const int8_t a3 = (int8_t)(v3 | ((b_ptr[l].qh[(qh_idx + i) * ncols_interleaved + j] >> 2) & 0x30)) - 32;sumi0 = (a0 * a_ptr[l].qs[sb * 64 + i * 4 + m]); + sumi0 = (a0 * a_ptr[l].qs[sb * 64 + 0 + i * 4 + m]); + sumi1 = (a1 * a_ptr[l].qs[sb * 64 + 64 + i * 4 + m]); + sumi2 = (a2 * a_ptr[l].qs[sb * 64 + 128 + i * 4 + m]); + sumi3 = (a3 * a_ptr[l].qs[sb * 64 + 192 + i * 4 + m]); + sumi0 = sumi0 * scales_0[j]; + sumi1 = sumi1 * scales_1[j]; + sumi2 = sumi2 * scales_2[j]; + sumi3 = sumi3 * scales_3[j]; + sumi += sumi0 + sumi1 + sumi2 + sumi3; + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + template static inline void ggml_gemm_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -2295,6 +2635,20 @@ void ggml_gemv_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemv_q2_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q3_K +void ggml_gemv_q3_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q3_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q3_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q3_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q3_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // Q4_K void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_q4_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -2323,6 +2677,20 @@ void ggml_gemv_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemv_q5_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q6_K +void ggml_gemv_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q6_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q6_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q6_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_iq4_nl_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -3124,6 +3492,20 @@ void ggml_gemm_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemm_q2_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q3_K +void ggml_gemm_q3_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q3_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q3_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q3_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q3_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // Q4_K void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_q4_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -3152,6 +3534,20 @@ void ggml_gemm_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemm_q5_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q6_K +void ggml_gemm_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q6_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q6_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q6_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_iq4_nl_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -4060,6 +4456,122 @@ static int repack_q2_K_to_q2_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_ GGML_UNUSED(data_size); } +template +static block_q3_Kx make_block_q3_KxMx1(const block_q3_K * in) { + block_q3_Kx out; + constexpr int N_COLS = nrows_interleaved; + constexpr int scales_bytes = 12; + constexpr int hmask_bytes = 32; + constexpr int qs_bytes = 64; + for (int i = 0; i < N_COLS; i++) { + out.d[i] = in[i].d; + } + + // 2. Process each column to Linearize metadata, then Interleave + uint8_t temp_scales[scales_bytes]; + uint8_t temp_hmask[hmask_bytes]; + uint8_t temp_qs[qs_bytes]; + + for (int col = 0; col < N_COLS; ++col) { + const block_q3_K & src = in[col]; + + uint8_t scale6[16]; + for (int sb = 0; sb < 16; ++sb) { + const uint8_t lo = sb < 8 ? (src.scales[sb] & 0xF) : (src.scales[sb - 8] >> 4); + const uint8_t hi = (src.scales[8 + (sb & 3)] >> (2 * (sb >> 2))) & 0x3; + scale6[sb] = lo | (hi << 4); + } + + // Repack into linear format: + // 0-7: Low 4 bits of pairs + for (int i = 0; i < 8; ++i) { + temp_scales[i] = (scale6[2*i] & 0x0F) | ((scale6[2*i + 1] & 0x0F) << 4); + } + // 8-11: High 2 bits of quads + for (int i = 0; i < 4; ++i) { + const int base = 4*i; + temp_scales[8 + i] = + (((scale6[base + 0] >> 4) & 0x03) << 0) | + (((scale6[base + 1] >> 4) & 0x03) << 2) | + (((scale6[base + 2] >> 4) & 0x03) << 4) | + (((scale6[base + 3] >> 4) & 0x03) << 6); + } + + // --- transpose HMask --- + memset(temp_hmask, 0, sizeof(temp_hmask)); + for (int hb = 0; hb < hmask_bytes; ++hb) { + const int elem_base = hb * 8; + for (int bit = 0; bit < 8; ++bit) { + const int idx = elem_base + bit; + // We want sequential: Byte `i` contains bits for weights `8*i` to `8*i+7` + const uint8_t hi = (src.hmask[idx & 31] >> (idx >> 5)) & 0x1; + temp_hmask[hb] |= (hi << bit); + } + } + + // --- QS (De-stride) --- + memset(temp_qs, 0, sizeof(temp_qs)); + for (int qb = 0; qb < qs_bytes; ++qb) { + const int elem_base = qb * 4; + for (int lane = 0; lane < 4; ++lane) { + const int idx = elem_base + lane; + // Logic to find byte offset and shift in standard Q3_K strided layout + const int src_byte = ((idx >> 7) << 5) + (idx & 31); + const int shift = ((idx >> 5) & 0x3) << 1; + const uint8_t lo2 = (src.qs[src_byte] >> shift) & 0x3; + temp_qs[qb] |= (lo2 << (2 * lane)); + } + } + + // --- Write Interleaved to Output --- + for (int i = 0; i < scales_bytes; ++i) { + out.scales[i * N_COLS + col] = temp_scales[i]; + } + for (int i = 0; i < hmask_bytes; ++i) { + out.hmask[i * N_COLS + col] = temp_hmask[i]; + } + for (int i = 0; i < qs_bytes; ++i) { + out.qs[i * N_COLS + col] = temp_qs[i]; + } + } + + return out; +} + +template +static int repack_q3_K_to_q3_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q3_K); + + block_q3_Kx * dst = (block_q3_Kx*)t->data; + const block_q3_K * src = (const block_q3_K*) data; + + block_q3_K dst_tmp[nrows_interleaved]; + + const int nrow = ggml_nrows(t); + const int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == (size_t) nrow * nblocks * sizeof(block_q3_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + // Gather N separate blocks from N adjacent rows + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + + *dst++ = make_block_q3_KxMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + template static block_q4_Kx make_block_q4_KxMx1(block_q4_K * in) { block_q4_Kx out; @@ -4226,6 +4738,110 @@ static int repack_q5_K_to_q5_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_ GGML_UNUSED(data_size); } +template +static block_q6_Kx make_block_q6_KxMx1(block_q6_K * in) { + block_q6_Kx out; + for (int i = 0; i < nrows_interleaved; i++) { + out.d[i] = in[i].d; + } + + const int end_ls = QK_K / 2; + for (int i = 0; i < end_ls; i += 64) { + for (int l = 0; l < 2; l++) { + for (int k = 0; k < 16; k++) { + uint8_t temp[nrows_interleaved]; + int dst_offset = (i + k + (l % 2) * 16) * nrows_interleaved; + int src_offset = i + k + (l % 2) * 32; + for (int j = 0; j < nrows_interleaved; j++) { + int src_id = j % nrows_interleaved; + + temp[j] = (in[src_id].ql[src_offset] & 0xF) | ((in[src_id].ql[src_offset + 16] & 0xF) << 4); + } + for (int j = 0; j < nrows_interleaved; j++) { + out.ql[dst_offset + j] = temp[j]; + } + } + } + for (int l = 0; l < 2; l++) { + for (int k = 0; k < 16; k++) { + uint8_t temp[nrows_interleaved]; + int dst_offset = (i + 32 + k + (l % 2) * 16) * nrows_interleaved; + int src_offset = i + k + (l % 2) * 32; + for (int j = 0; j < nrows_interleaved; j++) { + int src_id = j % nrows_interleaved; + + temp[j] = (in[src_id].ql[src_offset] >> 4) | ((in[src_id].ql[src_offset + 16] >> 4) << 4); + } + for (int j = 0; j < nrows_interleaved; j++) { + out.ql[dst_offset + j] = temp[j]; + } + } + } + } + + const int end_hs = QK_K / 4; + for (int i = 0; i < end_hs; i += 32) { + for (int l = 0; l < 2; l++) { + for (int k = 0; k < 16; k++) { + uint8_t temp[nrows_interleaved]; + int dst_offset = (i + l * 16 + k) * nrows_interleaved; + int src_offset = i + k; + for (int j = 0; j < nrows_interleaved; j++) { + int src_id = j; + + uint8_t a = (in[src_id].qh[src_offset] >> (4*(l%2))) & 3; + uint8_t b = (in[src_id].qh[src_offset + 16] >> (4*(l%2))) & 3; + uint8_t c = (in[src_id].qh[src_offset] >> (2+4*(l%2))) & 3; + uint8_t d = (in[src_id].qh[src_offset + 16] >> (2+4*(l%2))) & 3; + + temp[j] = a | (b << 2) | (c << 4) | (d << 6); + } + for (int j = 0; j < nrows_interleaved; j++) { + out.qh[dst_offset + j] = temp[j]; + } + } + } + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int j = 0; j < 16; j++) { + out.scales[j * nrows_interleaved + i] = in[i].scales[j]; + } + } + + return out; +} + +template +static int repack_q6_K_to_q6_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + + block_q6_Kx * dst = (block_q6_Kx*)t->data; + const block_q6_K * src = (const block_q6_K*) data; + block_q6_K dst_tmp[nrows_interleaved]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q6_KxMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + template static block_iq4_nlx make_block_iq4_nlxMx1(block_iq4_nl * in) { block_iq4_nlx out; @@ -4565,6 +5181,20 @@ template <> int repack(struct ggml_tensor * t, const void * d return repack_q2_K_to_q2_K_Mx1_bl<64>(t, data, data_size); } +// Q3_K +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_K_to_q3_K_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_K_to_q3_K_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_K_to_q3_K_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_K_to_q3_K_Mx1_bl<64>(t, data, data_size); +} + // Q4_K template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q4_K_to_q4_K_Mx1_bl<8>(t, data, data_size); @@ -4593,6 +5223,20 @@ template <> int repack(struct ggml_tensor * t, const void * d return repack_q5_K_to_q5_K_Mx1_bl<64>(t, data, data_size); } +// Q6_K +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_Mx1_bl<64>(t, data, data_size); +} + // IQ4_NL template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_Mx1_bl<8>(t, data, data_size); @@ -4740,6 +5384,20 @@ template <> void gemv(int n, float * s, size_ ggml_gemv_q2_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q3_K +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q3_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q3_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q3_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q3_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // Q4_K template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); @@ -4768,6 +5426,20 @@ template <> void gemv(int n, float * s, size_ ggml_gemv_q5_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q6_K +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_8x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -4915,6 +5587,20 @@ template <> void gemm(int n, float * s, size_ ggml_gemm_q2_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q3_K +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q3_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q3_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q3_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q3_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // Q4_K template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); @@ -4943,6 +5629,20 @@ template <> void gemm(int n, float * s, size_ ggml_gemm_q5_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q6_K +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_8x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -5403,6 +6103,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q2_K_32x1_q8_K; static const ggml::cpu::repack::tensor_traits q2_K_64x1_q8_K; + // Q3_K + static const ggml::cpu::repack::tensor_traits q3_K_8x1_q8_K; + static const ggml::cpu::repack::tensor_traits q3_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits q3_K_32x1_q8_K; + static const ggml::cpu::repack::tensor_traits q3_K_64x1_q8_K; + // Q4_K static const ggml::cpu::repack::tensor_traits q4_K_8x1_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K; @@ -5415,6 +6121,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q5_K_32x1_q8_K; static const ggml::cpu::repack::tensor_traits q5_K_64x1_q8_K; + // Q6_K + static const ggml::cpu::repack::tensor_traits q6_K_8x1_q8_K; + static const ggml::cpu::repack::tensor_traits q6_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits q6_K_32x1_q8_K; + static const ggml::cpu::repack::tensor_traits q6_K_64x1_q8_K; + // IQ4_NL static const ggml::cpu::repack::tensor_traits iq4_nl_8x1_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; @@ -5504,6 +6216,19 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons } #endif } + } else if (cur->type == GGML_TYPE_Q3_K) { + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { if (cur->ne[1] % 8 == 0) { return &q3_K_8x1_q8_K; } break; } + case 256: { if (cur->ne[1] % 16 == 0) { return &q3_K_16x1_q8_K; } break; } + case 512: { if (cur->ne[1] % 32 == 0) { return &q3_K_32x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q3_K_64x1_q8_K; } break; } + default: { return nullptr; } + } + #endif + } + } else if (cur->type == GGML_TYPE_Q5_K) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 8 == 0) { @@ -5537,6 +6262,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q6_K_8x4_q8_K; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { if (cur->ne[1] % 8 == 0) { return &q6_K_8x1_q8_K; } break; } + case 256: { if (cur->ne[1] % 16 == 0) { return &q6_K_16x1_q8_K; } break; } + case 512: { if (cur->ne[1] % 32 == 0) { return &q6_K_32x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q6_K_64x1_q8_K; } break; } + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index aa0470a3083..ea0bdffe11d 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -82,6 +82,24 @@ static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, static_assert(sizeof(block_q2_Kx32) == sizeof(ggml_half) * 64 + QK_K * 2 + QK_K * 8, "wrong q2_K block size/padding"); static_assert(sizeof(block_q2_Kx64) == sizeof(ggml_half) * 128 + QK_K * 4 + QK_K * 16, "wrong q2_K block size/padding"); +template +struct block_q3_Kx { + ggml_half d[N]; // super-block scales + uint8_t scales[12 * N]; // 6-bit quantized scales (packed) + uint8_t hmask[N * QK_K / 8]; // high bit of weights (1 bit/weight) + uint8_t qs[N * QK_K / 4]; // low 2 bits of weights (2 bits/weight) +}; + +using block_q3_Kx8 = block_q3_Kx<8>; +using block_q3_Kx16 = block_q3_Kx<16>; +using block_q3_Kx32 = block_q3_Kx<32>; +using block_q3_Kx64 = block_q3_Kx<64>; + +static_assert(sizeof(block_q3_Kx8) == sizeof(ggml_half) * 8 + 12 * 8 + QK_K + QK_K * 2, "wrong q3_K block size/padding for x8"); +static_assert(sizeof(block_q3_Kx16) == sizeof(ggml_half) * 16 + 12 * 16 + QK_K * 2 + QK_K * 4, "wrong q3_K block size/padding for x16"); +static_assert(sizeof(block_q3_Kx32) == sizeof(ggml_half) * 32 + 12 * 32 + QK_K * 4 + QK_K * 8, "wrong q3_K block size/padding for x32"); +static_assert(sizeof(block_q3_Kx64) == sizeof(ggml_half) * 64 + 12 * 64 + QK_K * 8 + QK_K * 16, "wrong q3_K block size/padding for x64"); + template struct block_q5_Kx { ggml_half d[N]; // super-block scale for quantized scales ggml_half dmin[N]; // super-block scale for quantized mins @@ -100,15 +118,22 @@ static_assert(sizeof(block_q5_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 1 static_assert(sizeof(block_q5_Kx32) == sizeof(ggml_half) * 64 + K_SCALE_SIZE * 32 + QK_K * 20, "wrong q5_K block size/padding"); static_assert(sizeof(block_q5_Kx64) == sizeof(ggml_half) * 128 + K_SCALE_SIZE * 64 + QK_K * 40, "wrong q5_K block size/padding"); -struct block_q6_Kx8 { - ggml_half d[8]; - int8_t scales[QK_K / 16 * 8]; - uint8_t ql[QK_K / 2 * 8]; // low bits of 6-bit quants (groups of 2) - uint8_t qh[QK_K / 4 * 8]; // high bits of 6-bit quants (groups of 4) +template struct block_q6_Kx { + ggml_half d[N]; + int8_t scales[QK_K / 16 * N]; + uint8_t ql[QK_K / 2 * N]; // low bits of 6-bit quants (groups of 2) + uint8_t qh[QK_K / 4 * N]; // high bits of 6-bit quants (groups of 4) }; -static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, - "wrong q6_K block size/padding"); +using block_q6_Kx8 = block_q6_Kx<8>; +using block_q6_Kx16 = block_q6_Kx<16>; +using block_q6_Kx32 = block_q6_Kx<32>; +using block_q6_Kx64 = block_q6_Kx<64>; + +static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, "wrong q6_K block size/padding"); +static_assert(sizeof(block_q6_Kx16) == sizeof(ggml_half) * 16 + QK_K / 16 * 16 + 3 * QK_K / 4 * 16, "wrong q6_K block size/padding"); +static_assert(sizeof(block_q6_Kx32) == sizeof(ggml_half) * 32 + QK_K / 16 * 32 + 3 * QK_K / 4 * 32, "wrong q6_K block size/padding"); +static_assert(sizeof(block_q6_Kx64) == sizeof(ggml_half) * 64 + QK_K / 16 * 64 + 3 * QK_K / 4 * 64, "wrong q6_K block size/padding"); struct block_q8_Kx4 { float d[4]; // delta @@ -207,6 +232,10 @@ void ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -215,6 +244,10 @@ void ggml_gemv_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -235,6 +268,10 @@ void ggml_gemm_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -243,6 +280,10 @@ void ggml_gemm_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -305,6 +346,10 @@ void ggml_gemv_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q3_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -313,6 +358,10 @@ void ggml_gemv_q5_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q5_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -333,6 +382,10 @@ void ggml_gemm_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q3_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -341,6 +394,10 @@ void ggml_gemm_q5_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q5_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);