Skip to content

Commit d6fdaf4

Browse files
committed
ggml-cpu: add repack GEMM and GEMV for floating-point (#4)
1 parent cf1e478 commit d6fdaf4

3 files changed

Lines changed: 56 additions & 97 deletions

File tree

ggml/src/ggml-cpu/arch/riscv/repack.cpp

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
343343

344344
template<int ncols_interleaved>
345345
static inline void ggml_gemv_f16_1xM_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
346+
GGML_UNUSED(bs);
347+
346348
const int nb = n / 1;
347349

348350
assert (nr == 1);
@@ -369,39 +371,41 @@ static inline void ggml_gemv_f16_1xM_f16(int n, float * GGML_RESTRICT s, size_t
369371
}
370372

371373
void ggml_gemv_f16_1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
372-
#if defined __riscv_v_intrinsic
374+
#if defined __riscv_zvfh
373375
ggml_gemv_f16_1xM_f16<16>(n, s, bs, vx, vy, nr, nc);
374-
return;
375-
#endif
376+
#else
376377
ggml_gemv_f16_1x16_f16_generic(n, s, bs, vx, vy, nr, nc);
378+
#endif
377379
}
378380

379381
void ggml_gemv_f16_1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
380-
#if defined __riscv_v_intrinsic
382+
#if defined __riscv_zvfh
381383
ggml_gemv_f16_1xM_f16<32>(n, s, bs, vx, vy, nr, nc);
382-
return;
383-
#endif
384+
#else
384385
ggml_gemv_f16_1x32_f16_generic(n, s, bs, vx, vy, nr, nc);
386+
#endif
385387
}
386388

387389
void ggml_gemv_f16_1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
388-
#if defined __riscv_v_intrinsic
390+
#if defined __riscv_zvfh
389391
ggml_gemv_f16_1xM_f16<64>(n, s, bs, vx, vy, nr, nc);
390-
return;
391-
#endif
392+
#else
392393
ggml_gemv_f16_1x64_f16_generic(n, s, bs, vx, vy, nr, nc);
394+
#endif
393395
}
394396

395397
void ggml_gemv_f16_1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
396-
#if defined __riscv_v_intrinsic
398+
#if defined __riscv_zvfh
397399
ggml_gemv_f16_1xM_f16<128>(n, s, bs, vx, vy, nr, nc);
398-
return;
399-
#endif
400+
#else
400401
ggml_gemv_f16_1x128_f16_generic(n, s, bs, vx, vy, nr, nc);
402+
#endif
401403
}
402404

403405
template<int ncols_interleaved>
404406
static inline void ggml_gemv_f32_1xM_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
407+
GGML_UNUSED(bs);
408+
405409
const int nb = n / 1;
406410

407411
assert (nr == 1);
@@ -428,35 +432,35 @@ static inline void ggml_gemv_f32_1xM_f32(int n, float * GGML_RESTRICT s, size_t
428432
}
429433

430434
void ggml_gemv_f32_1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
431-
#if defined __riscv_v_intrinsic
435+
#if defined __riscv_zvfh
432436
ggml_gemv_f32_1xM_f32<16>(n, s, bs, vx, vy, nr, nc);
433-
return;
434-
#endif
437+
#else
435438
ggml_gemv_f32_1x16_f32_generic(n, s, bs, vx, vy, nr, nc);
439+
#endif
436440
}
437441

438442
void ggml_gemv_f32_1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
439-
#if defined __riscv_v_intrinsic
443+
#if defined __riscv_zvfh
440444
ggml_gemv_f32_1xM_f32<32>(n, s, bs, vx, vy, nr, nc);
441-
return;
442-
#endif
445+
#else
443446
ggml_gemv_f32_1x32_f32_generic(n, s, bs, vx, vy, nr, nc);
447+
#endif
444448
}
445449

446450
void ggml_gemv_f32_1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
447-
#if defined __riscv_v_intrinsic
451+
#if defined __riscv_zvfh
448452
ggml_gemv_f32_1xM_f32<64>(n, s, bs, vx, vy, nr, nc);
449-
return;
450-
#endif
453+
#else
451454
ggml_gemv_f32_1x64_f32_generic(n, s, bs, vx, vy, nr, nc);
455+
#endif
452456
}
453457

454458
void ggml_gemv_f32_1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
455-
#if defined __riscv_v_intrinsic
459+
#if defined __riscv_zvfh
456460
ggml_gemv_f32_1xM_f32<128>(n, s, bs, vx, vy, nr, nc);
457-
return;
458-
#endif
461+
#else
459462
ggml_gemv_f32_1x128_f32_generic(n, s, bs, vx, vy, nr, nc);
463+
#endif
460464
}
461465

462466
template<int ncols_interleaved>
@@ -506,35 +510,35 @@ static inline void ggml_gemm_f16_7x1xM_f16(int n, float * GGML_RESTRICT s, size_
506510
}
507511

508512
void ggml_gemm_f16_7x1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
509-
#if defined __riscv_v_intrinsic
513+
#if defined __riscv_zvfh
510514
ggml_gemm_f16_7x1xM_f16<16>(n, s, bs, vx, vy, nr, nc);
511-
return;
512-
#endif
515+
#else
513516
ggml_gemm_f16_7x1x16_f16_generic(n, s, bs, vx, vy, nr, nc);
517+
#endif
514518
}
515519

516520
void ggml_gemm_f16_7x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
517-
#if defined __riscv_v_intrinsic
521+
#if defined __riscv_zvfh
518522
ggml_gemm_f16_7x1xM_f16<32>(n, s, bs, vx, vy, nr, nc);
519-
return;
520-
#endif
523+
#else
521524
ggml_gemm_f16_7x1x32_f16_generic(n, s, bs, vx, vy, nr, nc);
525+
#endif
522526
}
523527

524528
void ggml_gemm_f16_7x1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
525-
#if defined __riscv_v_intrinsic
529+
#if defined __riscv_zvfh
526530
ggml_gemm_f16_7x1xM_f16<64>(n, s, bs, vx, vy, nr, nc);
527-
return;
528-
#endif
531+
#else
529532
ggml_gemm_f16_7x1x64_f16_generic(n, s, bs, vx, vy, nr, nc);
533+
#endif
530534
}
531535

532536
void ggml_gemm_f16_7x1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
533-
#if defined __riscv_v_intrinsic
537+
#if defined __riscv_zvfh
534538
ggml_gemm_f16_7x1xM_f16<128>(n, s, bs, vx, vy, nr, nc);
535-
return;
536-
#endif
539+
#else
537540
ggml_gemm_f16_7x1x128_f16_generic(n, s, bs, vx, vy, nr, nc);
541+
#endif
538542
}
539543

540544
template<int ncols_interleaved>
@@ -584,33 +588,33 @@ static inline void ggml_gemm_f32_7x1xM_f32(int n, float * GGML_RESTRICT s, size_
584588
}
585589

586590
void ggml_gemm_f32_7x1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
587-
#if defined __riscv_v_intrinsic
591+
#if defined __riscv_zvfh
588592
ggml_gemm_f32_7x1xM_f32<16>(n, s, bs, vx, vy, nr, nc);
589-
return;
590-
#endif
593+
#else
591594
ggml_gemm_f32_7x1x16_f32_generic(n, s, bs, vx, vy, nr, nc);
595+
#endif
592596
}
593597

594598
void ggml_gemm_f32_7x1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
595-
#if defined __riscv_v_intrinsic
599+
#if defined __riscv_zvfh
596600
ggml_gemm_f32_7x1xM_f32<32>(n, s, bs, vx, vy, nr, nc);
597-
return;
598-
#endif
601+
#else
599602
ggml_gemm_f32_7x1x32_f32_generic(n, s, bs, vx, vy, nr, nc);
603+
#endif
600604
}
601605

602606
void ggml_gemm_f32_7x1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
603-
#if defined __riscv_v_intrinsic
607+
#if defined __riscv_zvfh
604608
ggml_gemm_f32_7x1xM_f32<64>(n, s, bs, vx, vy, nr, nc);
605-
return;
606-
#endif
609+
#else
607610
ggml_gemm_f32_7x1x64_f32_generic(n, s, bs, vx, vy, nr, nc);
611+
#endif
608612
}
609613

610614
void ggml_gemm_f32_7x1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
611-
#if defined __riscv_v_intrinsic
615+
#if defined __riscv_zvfh
612616
ggml_gemm_f32_7x1xM_f32<128>(n, s, bs, vx, vy, nr, nc);
613-
return;
614-
#endif
617+
#else
615618
ggml_gemm_f32_7x1x128_f32_generic(n, s, bs, vx, vy, nr, nc);
619+
#endif
616620
}

ggml/src/ggml-cpu/repack.cpp

Lines changed: 2 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,51 +1413,8 @@ void ggml_gemv_q8_0_4x4_q8_0_generic(int n,
14131413
}
14141414
}
14151415

1416-
void ggml_gemv_q8_0_4x8_q8_0_generic(int n,
1417-
float * GGML_RESTRICT s,
1418-
size_t bs,
1419-
const void * GGML_RESTRICT vx,
1420-
const void * GGML_RESTRICT vy,
1421-
int nr,
1422-
int nc) {
1423-
const int qk = QK8_0;
1424-
const int nb = n / qk;
1425-
const int ncols_interleaved = 4;
1426-
const int blocklen = 8;
1427-
1428-
assert(nr == 1);
1429-
assert(n % qk == 0);
1430-
assert(nc % ncols_interleaved == 0);
1431-
1432-
UNUSED(bs);
1433-
UNUSED(nr);
1434-
1435-
float sumf[4];
1436-
int sumi;
1437-
1438-
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
1439-
for (int x = 0; x < nc / ncols_interleaved; x++) {
1440-
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
1441-
1442-
for (int j = 0; j < ncols_interleaved; j++) {
1443-
sumf[j] = 0.0;
1444-
}
1445-
for (int l = 0; l < nb; l++) {
1446-
for (int k = 0; k < (qk / blocklen); k++) {
1447-
for (int j = 0; j < ncols_interleaved; j++) {
1448-
sumi = 0;
1449-
for (int i = 0; i < blocklen; ++i) {
1450-
const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
1451-
sumi += v0 * a_ptr[l].qs[k * blocklen + i];
1452-
}
1453-
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
1454-
}
1455-
}
1456-
}
1457-
for (int j = 0; j < ncols_interleaved; j++) {
1458-
s[x * ncols_interleaved + j] = sumf[j];
1459-
}
1460-
}
1416+
void ggml_gemv_f32_1x16_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
1417+
ggml_gemv_f32_KxM_f32_generic<1, 16>(n, s, bs, vx, vy, nr, nc);
14611418
}
14621419

14631420
#if defined __riscv_zvfh

ggml/src/ggml-cpu/repack.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,25 +205,22 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
205205
void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
206206
void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
207207

208+
#ifdef __riscv_zvfh
208209
// FP16
209-
void ggml_repack_mat_f16_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
210210
void ggml_repack_mat_f16_7x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
211211
void ggml_gemv_f16_1x16_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
212212
void ggml_gemv_f16_1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
213213
void ggml_gemv_f16_1x64_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
214214
void ggml_gemv_f16_1x128_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
215-
void ggml_gemm_f16_4x1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
216215
void ggml_gemm_f16_7x1x16_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
217216
void ggml_gemm_f16_7x1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
218217
void ggml_gemm_f16_7x1x64_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
219218
void ggml_gemm_f16_7x1x128_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
220-
void ggml_repack_mat_f16_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
221219
void ggml_repack_mat_f16_7x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
222220
void ggml_gemv_f16_1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
223221
void ggml_gemv_f16_1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
224222
void ggml_gemv_f16_1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
225223
void ggml_gemv_f16_1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
226-
void ggml_gemm_f16_4x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
227224
void ggml_gemm_f16_7x1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
228225
void ggml_gemm_f16_7x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
229226
void ggml_gemm_f16_7x1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -248,6 +245,7 @@ void ggml_gemm_f32_7x1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const v
248245
void ggml_gemm_f32_7x1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
249246
void ggml_gemm_f32_7x1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
250247
void ggml_gemm_f32_7x1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
248+
#endif
251249

252250
#if defined(__cplusplus)
253251
} // extern "C"

0 commit comments

Comments
 (0)