@@ -278,6 +278,72 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}

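+// Dot product of one MXFP4 row with one Q8_0 row (n values). Each MXFP4
+// block packs 32 4-bit indices into a 16-entry value table plus a shared
+// E8M0 exponent; Q8_0 supplies int8 values with an fp16 delta per block.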
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0  * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector unsigned char vshift4 = vec_splats((unsigned char)4);
+    vector float vsumf0 = vec_splats(0.0f);
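+
+    // Keep the 16 MXFP4 code values in a vector register so that vec_perm
+    // can serve as a 16-entry lookup table for the 4-bit indices.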
+    vector signed char kv = vec_xl(0, (const signed char *)kvalues_mxfp4);
+
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
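+        // Per-block scale: the Q8_0 delta times the block's shared E8M0
+        // exponent. The _HALF conversion appears to offset the kvalues_mxfp4
+        // table storing the FP4 values at twice their nominal magnitude.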
+        vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d) *
+                                      GGML_E8M0_TO_FP32_HALF(x[ib].e));
+
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
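+        // Split the 16 packed bytes into 32 4-bit indices: low nibbles are
+        // elements 0..15, high nibbles elements 16..31. vec_perm then maps
+        // each index through the code table.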
+        vector signed char qxs = (vector signed char)vec_xl(0, x[ib].qs);
+
+        vector unsigned char lo_nibbles = (vector unsigned char)vec_and(qxs, lowMask);
+        vector unsigned char hi_nibbles = (vector unsigned char)vec_sr(qxs, vshift4);
+
+        vector signed char q4x0 = vec_perm(kv, kv, lo_nibbles);
+        vector signed char q4x1 = vec_perm(kv, kv, hi_nibbles);
+
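+        // Widening multiply: vec_mule/vec_mulo produce 16-bit products of the
+        // even/odd byte lanes; adding them sums adjacent element pairs without
+        // overflow.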
+        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
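+
+        // vec_sum4s folds the 16-bit partial sums into four 32-bit lanes,
+        // which are then scaled and accumulated in float.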
+        vector signed int vsumi0 = vec_splats((int32_t)0);
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi0 = vec_sum4s(qv1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vyd, vsumf0);
+    }
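+
+    // Horizontal reduction: two rotate-and-add steps fold the four float
+    // lanes of vsumf0 into lane 0.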
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+    sumf = vec_extract(vsumf0, 0);
+    *s = sumf;
+#else
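+    // Non-POWER9 builds fall back to the generic scalar implementation.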
+    UNUSED(x);
+    UNUSED(y);
+    UNUSED(ib);
+    UNUSED(sumf);
+    ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
+#endif
+}
+
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    const int qk = QK8_0;
    const int nb = n / qk;