@@ -96,6 +96,37 @@ fp16_vec_L2sqr_sve(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
96
96
return svaddv_f32 (pg_32, total_sum);
97
97
}
98
98
99
+ float
100
+ fp16_vec_inner_product_sve (const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
101
+ svfloat32_t sum1 = svdup_f32 (0 .0f );
102
+ svfloat32_t sum2 = svdup_f32 (0 .0f );
103
+ size_t i = 0 ;
104
+
105
+ svbool_t pg_16 = svptrue_b16 ();
106
+ svbool_t pg_32 = svptrue_b32 ();
107
+
108
+ while (i < d) {
109
+ if (d - i < svcnth ())
110
+ pg_16 = svwhilelt_b16 (i, d);
111
+
112
+ svfloat16_t a_fp16 = svld1_f16 (pg_16, reinterpret_cast <const __fp16*>(x + i));
113
+ svfloat16_t b_fp16 = svld1_f16 (pg_16, reinterpret_cast <const __fp16*>(y + i));
114
+
115
+ svfloat32_t a_fp32_low = svcvt_f32_f16_z (pg_32, svtrn1_f16 (a_fp16, a_fp16));
116
+ svfloat32_t a_fp32_high = svcvt_f32_f16_z (pg_32, svtrn2_f16 (a_fp16, a_fp16));
117
+ svfloat32_t b_fp32_low = svcvt_f32_f16_z (pg_32, svtrn1_f16 (b_fp16, b_fp16));
118
+ svfloat32_t b_fp32_high = svcvt_f32_f16_z (pg_32, svtrn2_f16 (b_fp16, b_fp16));
119
+
120
+ sum1 = svmla_f32_m (pg_32, sum1, a_fp32_low, b_fp32_low);
121
+ sum2 = svmla_f32_m (pg_32, sum2, a_fp32_high, b_fp32_high);
122
+
123
+ i += svcnth ();
124
+ }
125
+
126
+ svfloat32_t total_sum = svadd_f32_m (pg_32, sum1, sum2);
127
+ return svaddv_f32 (pg_32, total_sum);
128
+ }
129
+
99
130
float
100
131
fvec_L1_sve (const float * x, const float * y, size_t d) {
101
132
svfloat32_t sum = svdup_f32 (0 .0f );
@@ -308,6 +339,14 @@ fvec_L2sqr_ny_sve(float* dis, const float* x, const float* y, size_t d, size_t n
308
339
}
309
340
}
310
341
342
+ void
343
+ fvec_inner_products_ny_sve (float * ip, const float * x, const float * y, size_t d, size_t ny) {
344
+ for (size_t i = 0 ; i < ny; ++i) {
345
+ ip[i] = fvec_inner_product_sve (x, y, d);
346
+ y += d;
347
+ }
348
+ }
349
+
311
350
} // namespace faiss
312
351
313
352
#endif
0 commit comments