Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 52 additions & 48 deletions ggml/src/ggml-cpu/arch/riscv/repack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo

template<int ncols_interleaved>
static inline void ggml_gemv_f16_1xM_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
GGML_UNUSED(bs);

const int nb = n / 1;

assert (nr == 1);
Expand All @@ -369,39 +371,41 @@ static inline void ggml_gemv_f16_1xM_f16(int n, float * GGML_RESTRICT s, size_t
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemv_f16_1xM_f16 kernel with ncols_interleaved = 16 when both the RVV
// intrinsics and Zvfh are available at compile time, and to
// ggml_gemv_f16_1x16_f16_generic otherwise.
//
// The two feature tests are combined into one condition so that every build
// configuration selects exactly one implementation; with nested #if blocks a
// build that has __riscv_v_intrinsic but not __riscv_zvfh compiled to an
// empty function body and left the output untouched.
void ggml_gemv_f16_1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemv_f16_1xM_f16<16>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemv_f16_1x16_f16_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemv_f16_1xM_f16 kernel with ncols_interleaved = 32 when both the RVV
// intrinsics and Zvfh are available at compile time, and to
// ggml_gemv_f16_1x32_f16_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
void ggml_gemv_f16_1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemv_f16_1xM_f16<32>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemv_f16_1x32_f16_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemv_f16_1xM_f16 kernel with ncols_interleaved = 64 when both the RVV
// intrinsics and Zvfh are available at compile time, and to
// ggml_gemv_f16_1x64_f16_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
void ggml_gemv_f16_1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemv_f16_1xM_f16<64>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemv_f16_1x64_f16_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemv_f16_1xM_f16 kernel with ncols_interleaved = 128 when both the RVV
// intrinsics and Zvfh are available at compile time, and to
// ggml_gemv_f16_1x128_f16_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
void ggml_gemv_f16_1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemv_f16_1xM_f16<128>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemv_f16_1x128_f16_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

template<int ncols_interleaved>
static inline void ggml_gemv_f32_1xM_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
GGML_UNUSED(bs);

const int nb = n / 1;

assert (nr == 1);
Expand All @@ -428,35 +432,35 @@ static inline void ggml_gemv_f32_1xM_f32(int n, float * GGML_RESTRICT s, size_t
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemv_f32_1xM_f32 kernel with ncols_interleaved = 16 when both the RVV
// intrinsics and Zvfh are available at compile time, and to
// ggml_gemv_f32_1x16_f32_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
//
// NOTE(review): the vector path stays gated on __riscv_zvfh as before;
// confirm whether the f32 kernel really needs half-precision vector support.
void ggml_gemv_f32_1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemv_f32_1xM_f32<16>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemv_f32_1x16_f32_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemv_f32_1xM_f32 kernel with ncols_interleaved = 32 when both the RVV
// intrinsics and Zvfh are available at compile time, and to
// ggml_gemv_f32_1x32_f32_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
//
// NOTE(review): the vector path stays gated on __riscv_zvfh as before;
// confirm whether the f32 kernel really needs half-precision vector support.
void ggml_gemv_f32_1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemv_f32_1xM_f32<32>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemv_f32_1x32_f32_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemv_f32_1xM_f32 kernel with ncols_interleaved = 64 when both the RVV
// intrinsics and Zvfh are available at compile time, and to
// ggml_gemv_f32_1x64_f32_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
//
// NOTE(review): the vector path stays gated on __riscv_zvfh as before;
// confirm whether the f32 kernel really needs half-precision vector support.
void ggml_gemv_f32_1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemv_f32_1xM_f32<64>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemv_f32_1x64_f32_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemv_f32_1xM_f32 kernel with ncols_interleaved = 128 when both the RVV
// intrinsics and Zvfh are available at compile time, and to
// ggml_gemv_f32_1x128_f32_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
//
// NOTE(review): the vector path stays gated on __riscv_zvfh as before;
// confirm whether the f32 kernel really needs half-precision vector support.
void ggml_gemv_f32_1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemv_f32_1xM_f32<128>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemv_f32_1x128_f32_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

template<int ncols_interleaved>
Expand Down Expand Up @@ -506,35 +510,35 @@ static inline void ggml_gemm_f16_7x1xM_f16(int n, float * GGML_RESTRICT s, size_
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemm_f16_7x1xM_f16 kernel with ncols_interleaved = 16 when both the
// RVV intrinsics and Zvfh are available at compile time, and to
// ggml_gemm_f16_7x1x16_f16_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
void ggml_gemm_f16_7x1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemm_f16_7x1xM_f16<16>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_f16_7x1x16_f16_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemm_f16_7x1xM_f16 kernel with ncols_interleaved = 32 when both the
// RVV intrinsics and Zvfh are available at compile time, and to
// ggml_gemm_f16_7x1x32_f16_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
void ggml_gemm_f16_7x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemm_f16_7x1xM_f16<32>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_f16_7x1x32_f16_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemm_f16_7x1xM_f16 kernel with ncols_interleaved = 64 when both the
// RVV intrinsics and Zvfh are available at compile time, and to
// ggml_gemm_f16_7x1x64_f16_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
void ggml_gemm_f16_7x1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemm_f16_7x1xM_f16<64>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_f16_7x1x64_f16_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemm_f16_7x1xM_f16 kernel with ncols_interleaved = 128 when both the
// RVV intrinsics and Zvfh are available at compile time, and to
// ggml_gemm_f16_7x1x128_f16_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
void ggml_gemm_f16_7x1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemm_f16_7x1xM_f16<128>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_f16_7x1x128_f16_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

template<int ncols_interleaved>
Expand Down Expand Up @@ -584,33 +588,33 @@ static inline void ggml_gemm_f32_7x1xM_f32(int n, float * GGML_RESTRICT s, size_
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemm_f32_7x1xM_f32 kernel with ncols_interleaved = 16 when both the
// RVV intrinsics and Zvfh are available at compile time, and to
// ggml_gemm_f32_7x1x16_f32_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
//
// NOTE(review): the vector path stays gated on __riscv_zvfh as before;
// confirm whether the f32 kernel really needs half-precision vector support.
void ggml_gemm_f32_7x1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemm_f32_7x1xM_f32<16>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_f32_7x1x16_f32_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemm_f32_7x1xM_f32 kernel with ncols_interleaved = 32 when both the
// RVV intrinsics and Zvfh are available at compile time, and to
// ggml_gemm_f32_7x1x32_f32_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
//
// NOTE(review): the vector path stays gated on __riscv_zvfh as before;
// confirm whether the f32 kernel really needs half-precision vector support.
void ggml_gemm_f32_7x1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemm_f32_7x1xM_f32<32>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_f32_7x1x32_f32_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemm_f32_7x1xM_f32 kernel with ncols_interleaved = 64 when both the
// RVV intrinsics and Zvfh are available at compile time, and to
// ggml_gemm_f32_7x1x64_f32_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
//
// NOTE(review): the vector path stays gated on __riscv_zvfh as before;
// confirm whether the f32 kernel really needs half-precision vector support.
void ggml_gemm_f32_7x1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemm_f32_7x1xM_f32<64>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_f32_7x1x64_f32_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

// Dispatch wrapper: forwards all arguments unchanged to the templated
// ggml_gemm_f32_7x1xM_f32 kernel with ncols_interleaved = 128 when both the
// RVV intrinsics and Zvfh are available at compile time, and to
// ggml_gemm_f32_7x1x128_f32_generic otherwise.
//
// A single combined condition guarantees that exactly one implementation is
// compiled in; nested #if blocks left the body empty when
// __riscv_v_intrinsic was defined without __riscv_zvfh.
//
// NOTE(review): the vector path stays gated on __riscv_zvfh as before;
// confirm whether the f32 kernel really needs half-precision vector support.
void ggml_gemm_f32_7x1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
#if defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
    ggml_gemm_f32_7x1xM_f32<128>(n, s, bs, vx, vy, nr, nc);
#else
    ggml_gemm_f32_7x1x128_f32_generic(n, s, bs, vx, vy, nr, nc);
#endif
}
Loading
Loading