Skip to content

Commit cf1e478

Browse files
committed
ggml-cpu: add repack GEMM and GEMV for floating-point
1 parent 137435f commit cf1e478

5 files changed

Lines changed: 6607 additions & 50 deletions

File tree

ggml/src/ggml-cpu/arch-fallback.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
3737
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
3838
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
39+
#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1
40+
#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1
3941
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
4042
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
4143
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
@@ -75,13 +77,31 @@
7577
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
7678
#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
7779
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
80+
#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16
81+
#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16
82+
#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16
83+
#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16
84+
#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32
85+
#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32
86+
#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32
87+
#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32
7888
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
7989
#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
8090
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
91+
#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16
92+
#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16
93+
#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16
94+
#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16
95+
#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32
96+
#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32
97+
#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32
98+
#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32
8199
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
82100
// repack.cpp
83101
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
84102
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
103+
#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1
104+
#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1
85105
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
86106
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
87107
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
@@ -116,6 +136,8 @@
116136
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
117137
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
118138
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
139+
#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1
140+
#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1
119141
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
120142
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
121143
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
@@ -160,6 +182,8 @@
160182
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
161183
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
162184
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
185+
#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1
186+
#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1
163187
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
164188
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
165189
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
@@ -206,6 +230,8 @@
206230
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
207231
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
208232
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
233+
#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1
234+
#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1
209235
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
210236
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
211237
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
@@ -254,6 +280,8 @@
254280
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
255281
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
256282
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
283+
#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1
284+
#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1
257285
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
258286
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
259287
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
@@ -306,6 +334,8 @@
306334
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
307335
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
308336
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
337+
#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1
338+
#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1
309339
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
310340
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
311341
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0

ggml/src/ggml-cpu/arch/riscv/repack.cpp

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,277 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
340340
#endif
341341
ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
342342
}
343+
344+
template<int ncols_interleaved>
345+
static inline void ggml_gemv_f16_1xM_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
346+
const int nb = n / 1;
347+
348+
assert (nr == 1);
349+
assert(n % 1 == 0);
350+
assert(nc % ncols_interleaved == 0);
351+
352+
const _Float16 * a_ptr = (const _Float16 *) vy;
353+
for (int x = 0; x < nc / ncols_interleaved; x++) {
354+
const block_f16<ncols_interleaved, 1> * b_ptr = (const block_f16<ncols_interleaved, 1> *) vx + (x * nb);
355+
356+
// Accumulators
357+
vfloat32m4_t sumf_0 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
358+
359+
for (int l = 0; l < nb; l++) {
360+
vfloat16m2_t b_0 = __riscv_vle16_v_f16m2((const _Float16 *)&b_ptr[l].d[0], ncols_interleaved);
361+
362+
sumf_0 = __riscv_vfwmacc_vf_f32m4(sumf_0, *(const _Float16*)(&a_ptr[l]), b_0, ncols_interleaved);
363+
}
364+
365+
__riscv_vse32_v_f32m4(&s[x * ncols_interleaved], sumf_0, ncols_interleaved);
366+
}
367+
368+
return;
369+
}
370+
371+
void ggml_gemv_f16_1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
372+
#if defined __riscv_v_intrinsic
373+
ggml_gemv_f16_1xM_f16<16>(n, s, bs, vx, vy, nr, nc);
374+
return;
375+
#endif
376+
ggml_gemv_f16_1x16_f16_generic(n, s, bs, vx, vy, nr, nc);
377+
}
378+
379+
void ggml_gemv_f16_1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
380+
#if defined __riscv_v_intrinsic
381+
ggml_gemv_f16_1xM_f16<32>(n, s, bs, vx, vy, nr, nc);
382+
return;
383+
#endif
384+
ggml_gemv_f16_1x32_f16_generic(n, s, bs, vx, vy, nr, nc);
385+
}
386+
387+
void ggml_gemv_f16_1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
388+
#if defined __riscv_v_intrinsic
389+
ggml_gemv_f16_1xM_f16<64>(n, s, bs, vx, vy, nr, nc);
390+
return;
391+
#endif
392+
ggml_gemv_f16_1x64_f16_generic(n, s, bs, vx, vy, nr, nc);
393+
}
394+
395+
void ggml_gemv_f16_1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
396+
#if defined __riscv_v_intrinsic
397+
ggml_gemv_f16_1xM_f16<128>(n, s, bs, vx, vy, nr, nc);
398+
return;
399+
#endif
400+
ggml_gemv_f16_1x128_f16_generic(n, s, bs, vx, vy, nr, nc);
401+
}
402+
403+
template<int ncols_interleaved>
404+
static inline void ggml_gemv_f32_1xM_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
405+
const int nb = n / 1;
406+
407+
assert (nr == 1);
408+
assert(n % 1 == 0);
409+
assert(nc % ncols_interleaved == 0);
410+
411+
const float * a_ptr = (const float *) vy;
412+
for (int x = 0; x < nc / ncols_interleaved; x++) {
413+
const block_f32<ncols_interleaved, 1> * b_ptr = (const block_f32<ncols_interleaved, 1> *) vx + (x * nb);
414+
415+
// Accumulators
416+
vfloat32m4_t sumf_0 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
417+
418+
for (int l = 0; l < nb; l++) {
419+
vfloat32m4_t b_0 = __riscv_vle32_v_f32m4((const float *)&b_ptr[l].d[0], ncols_interleaved);
420+
421+
sumf_0 = __riscv_vfmacc_vf_f32m4(sumf_0, *(const float*)(&a_ptr[l]), b_0, ncols_interleaved);
422+
}
423+
424+
__riscv_vse32_v_f32m4(&s[x * ncols_interleaved], sumf_0, ncols_interleaved);
425+
}
426+
427+
return;
428+
}
429+
430+
void ggml_gemv_f32_1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
431+
#if defined __riscv_v_intrinsic
432+
ggml_gemv_f32_1xM_f32<16>(n, s, bs, vx, vy, nr, nc);
433+
return;
434+
#endif
435+
ggml_gemv_f32_1x16_f32_generic(n, s, bs, vx, vy, nr, nc);
436+
}
437+
438+
void ggml_gemv_f32_1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
439+
#if defined __riscv_v_intrinsic
440+
ggml_gemv_f32_1xM_f32<32>(n, s, bs, vx, vy, nr, nc);
441+
return;
442+
#endif
443+
ggml_gemv_f32_1x32_f32_generic(n, s, bs, vx, vy, nr, nc);
444+
}
445+
446+
void ggml_gemv_f32_1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
447+
#if defined __riscv_v_intrinsic
448+
ggml_gemv_f32_1xM_f32<64>(n, s, bs, vx, vy, nr, nc);
449+
return;
450+
#endif
451+
ggml_gemv_f32_1x64_f32_generic(n, s, bs, vx, vy, nr, nc);
452+
}
453+
454+
void ggml_gemv_f32_1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
455+
#if defined __riscv_v_intrinsic
456+
ggml_gemv_f32_1xM_f32<128>(n, s, bs, vx, vy, nr, nc);
457+
return;
458+
#endif
459+
ggml_gemv_f32_1x128_f32_generic(n, s, bs, vx, vy, nr, nc);
460+
}
461+
462+
template<int ncols_interleaved>
463+
static inline void ggml_gemm_f16_7x1xM_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
464+
const int nb = n / 1;
465+
466+
assert (nr % 7 == 0);
467+
assert(n % 1 == 0);
468+
assert(nc % ncols_interleaved == 0);
469+
470+
for (int y = 0; y < nr / 7; y++) {
471+
const block_f16_7x1 * a_ptr = (const block_f16_7x1*) vy + (y * nb);
472+
for (int x = 0; x < nc / ncols_interleaved; x++) {
473+
const block_f16<ncols_interleaved, 1> * b_ptr = (const block_f16<ncols_interleaved, 1> *) vx + (x * nb);
474+
475+
// Accumulators
476+
vfloat32m4_t sumf_0 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
477+
vfloat32m4_t sumf_1 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
478+
vfloat32m4_t sumf_2 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
479+
vfloat32m4_t sumf_3 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
480+
vfloat32m4_t sumf_4 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
481+
vfloat32m4_t sumf_5 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
482+
vfloat32m4_t sumf_6 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
483+
484+
for (int l = 0; l < nb; l++) {
485+
vfloat16m2_t b_0 = __riscv_vle16_v_f16m2((const _Float16 *)&b_ptr[l].d[0], ncols_interleaved);
486+
487+
sumf_0 = __riscv_vfwmacc_vf_f32m4(sumf_0, *(const _Float16*)&a_ptr[l].d[0], b_0, ncols_interleaved);
488+
sumf_1 = __riscv_vfwmacc_vf_f32m4(sumf_1, *(const _Float16*)&a_ptr[l].d[1], b_0, ncols_interleaved);
489+
sumf_2 = __riscv_vfwmacc_vf_f32m4(sumf_2, *(const _Float16*)&a_ptr[l].d[2], b_0, ncols_interleaved);
490+
sumf_3 = __riscv_vfwmacc_vf_f32m4(sumf_3, *(const _Float16*)&a_ptr[l].d[3], b_0, ncols_interleaved);
491+
sumf_4 = __riscv_vfwmacc_vf_f32m4(sumf_4, *(const _Float16*)&a_ptr[l].d[4], b_0, ncols_interleaved);
492+
sumf_5 = __riscv_vfwmacc_vf_f32m4(sumf_5, *(const _Float16*)&a_ptr[l].d[5], b_0, ncols_interleaved);
493+
sumf_6 = __riscv_vfwmacc_vf_f32m4(sumf_6, *(const _Float16*)&a_ptr[l].d[6], b_0, ncols_interleaved);
494+
}
495+
496+
__riscv_vse32_v_f32m4(&s[(y * 7 + 0) * bs + x * ncols_interleaved], sumf_0, ncols_interleaved);
497+
__riscv_vse32_v_f32m4(&s[(y * 7 + 1) * bs + x * ncols_interleaved], sumf_1, ncols_interleaved);
498+
__riscv_vse32_v_f32m4(&s[(y * 7 + 2) * bs + x * ncols_interleaved], sumf_2, ncols_interleaved);
499+
__riscv_vse32_v_f32m4(&s[(y * 7 + 3) * bs + x * ncols_interleaved], sumf_3, ncols_interleaved);
500+
__riscv_vse32_v_f32m4(&s[(y * 7 + 4) * bs + x * ncols_interleaved], sumf_4, ncols_interleaved);
501+
__riscv_vse32_v_f32m4(&s[(y * 7 + 5) * bs + x * ncols_interleaved], sumf_5, ncols_interleaved);
502+
__riscv_vse32_v_f32m4(&s[(y * 7 + 6) * bs + x * ncols_interleaved], sumf_6, ncols_interleaved);
503+
}
504+
}
505+
return;
506+
}
507+
508+
void ggml_gemm_f16_7x1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
509+
#if defined __riscv_v_intrinsic
510+
ggml_gemm_f16_7x1xM_f16<16>(n, s, bs, vx, vy, nr, nc);
511+
return;
512+
#endif
513+
ggml_gemm_f16_7x1x16_f16_generic(n, s, bs, vx, vy, nr, nc);
514+
}
515+
516+
void ggml_gemm_f16_7x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
517+
#if defined __riscv_v_intrinsic
518+
ggml_gemm_f16_7x1xM_f16<32>(n, s, bs, vx, vy, nr, nc);
519+
return;
520+
#endif
521+
ggml_gemm_f16_7x1x32_f16_generic(n, s, bs, vx, vy, nr, nc);
522+
}
523+
524+
void ggml_gemm_f16_7x1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
525+
#if defined __riscv_v_intrinsic
526+
ggml_gemm_f16_7x1xM_f16<64>(n, s, bs, vx, vy, nr, nc);
527+
return;
528+
#endif
529+
ggml_gemm_f16_7x1x64_f16_generic(n, s, bs, vx, vy, nr, nc);
530+
}
531+
532+
void ggml_gemm_f16_7x1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
533+
#if defined __riscv_v_intrinsic
534+
ggml_gemm_f16_7x1xM_f16<128>(n, s, bs, vx, vy, nr, nc);
535+
return;
536+
#endif
537+
ggml_gemm_f16_7x1x128_f16_generic(n, s, bs, vx, vy, nr, nc);
538+
}
539+
540+
template<int ncols_interleaved>
541+
static inline void ggml_gemm_f32_7x1xM_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
542+
const int nb = n / 1;
543+
544+
assert (nr % 7 == 0);
545+
assert(n % 1 == 0);
546+
assert(nc % ncols_interleaved == 0);
547+
548+
for (int y = 0; y < nr / 7; y++) {
549+
const block_f32_7x1 * a_ptr = (const block_f32_7x1*) vy + (y * nb);
550+
for (int x = 0; x < nc / ncols_interleaved; x++) {
551+
const block_f32<ncols_interleaved, 1> * b_ptr = (const block_f32<ncols_interleaved, 1> *) vx + (x * nb);
552+
553+
// Accumulators
554+
vfloat32m4_t sumf_0 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
555+
vfloat32m4_t sumf_1 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
556+
vfloat32m4_t sumf_2 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
557+
vfloat32m4_t sumf_3 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
558+
vfloat32m4_t sumf_4 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
559+
vfloat32m4_t sumf_5 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
560+
vfloat32m4_t sumf_6 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved);
561+
562+
for (int l = 0; l < nb; l++) {
563+
vfloat32m4_t b_0 = __riscv_vle32_v_f32m4((const float*)&b_ptr[l].d[0], ncols_interleaved);
564+
565+
sumf_0 = __riscv_vfmacc_vf_f32m4(sumf_0, *(const float*)&a_ptr[l].d[0], b_0, ncols_interleaved);
566+
sumf_1 = __riscv_vfmacc_vf_f32m4(sumf_1, *(const float*)&a_ptr[l].d[1], b_0, ncols_interleaved);
567+
sumf_2 = __riscv_vfmacc_vf_f32m4(sumf_2, *(const float*)&a_ptr[l].d[2], b_0, ncols_interleaved);
568+
sumf_3 = __riscv_vfmacc_vf_f32m4(sumf_3, *(const float*)&a_ptr[l].d[3], b_0, ncols_interleaved);
569+
sumf_4 = __riscv_vfmacc_vf_f32m4(sumf_4, *(const float*)&a_ptr[l].d[4], b_0, ncols_interleaved);
570+
sumf_5 = __riscv_vfmacc_vf_f32m4(sumf_5, *(const float*)&a_ptr[l].d[5], b_0, ncols_interleaved);
571+
sumf_6 = __riscv_vfmacc_vf_f32m4(sumf_6, *(const float*)&a_ptr[l].d[6], b_0, ncols_interleaved);
572+
}
573+
574+
__riscv_vse32_v_f32m4(&s[(y * 7 + 0) * bs + x * ncols_interleaved], sumf_0, ncols_interleaved);
575+
__riscv_vse32_v_f32m4(&s[(y * 7 + 1) * bs + x * ncols_interleaved], sumf_1, ncols_interleaved);
576+
__riscv_vse32_v_f32m4(&s[(y * 7 + 2) * bs + x * ncols_interleaved], sumf_2, ncols_interleaved);
577+
__riscv_vse32_v_f32m4(&s[(y * 7 + 3) * bs + x * ncols_interleaved], sumf_3, ncols_interleaved);
578+
__riscv_vse32_v_f32m4(&s[(y * 7 + 4) * bs + x * ncols_interleaved], sumf_4, ncols_interleaved);
579+
__riscv_vse32_v_f32m4(&s[(y * 7 + 5) * bs + x * ncols_interleaved], sumf_5, ncols_interleaved);
580+
__riscv_vse32_v_f32m4(&s[(y * 7 + 6) * bs + x * ncols_interleaved], sumf_6, ncols_interleaved);
581+
}
582+
}
583+
return;
584+
}
585+
586+
void ggml_gemm_f32_7x1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
587+
#if defined __riscv_v_intrinsic
588+
ggml_gemm_f32_7x1xM_f32<16>(n, s, bs, vx, vy, nr, nc);
589+
return;
590+
#endif
591+
ggml_gemm_f32_7x1x16_f32_generic(n, s, bs, vx, vy, nr, nc);
592+
}
593+
594+
void ggml_gemm_f32_7x1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
595+
#if defined __riscv_v_intrinsic
596+
ggml_gemm_f32_7x1xM_f32<32>(n, s, bs, vx, vy, nr, nc);
597+
return;
598+
#endif
599+
ggml_gemm_f32_7x1x32_f32_generic(n, s, bs, vx, vy, nr, nc);
600+
}
601+
602+
void ggml_gemm_f32_7x1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
603+
#if defined __riscv_v_intrinsic
604+
ggml_gemm_f32_7x1xM_f32<64>(n, s, bs, vx, vy, nr, nc);
605+
return;
606+
#endif
607+
ggml_gemm_f32_7x1x64_f32_generic(n, s, bs, vx, vy, nr, nc);
608+
}
609+
610+
void ggml_gemm_f32_7x1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
611+
#if defined __riscv_v_intrinsic
612+
ggml_gemm_f32_7x1xM_f32<128>(n, s, bs, vx, vy, nr, nc);
613+
return;
614+
#endif
615+
ggml_gemm_f32_7x1x128_f32_generic(n, s, bs, vx, vy, nr, nc);
616+
}

0 commit comments

Comments
 (0)