@@ -1663,16 +1663,10 @@ kernel void kernel_ssm_conv_f32(
16631663        device const   void  * src0,
16641664        device const   void  * src1,
16651665        device       float  * dst,
1666-         threadgroup  float  * shared [[threadgroup(0 )]],
16671666        constant ggml_metal_kargs_ssm_conv & args,
1668-         uint3  tgpig[[threadgroup_position_in_grid]],
1669-         uint3  tpitg[[thread_position_in_threadgroup]],
1670-         ushort sgitg[[simdgroup_index_in_threadgroup]],
1671-         ushort tiisg[[thread_index_in_simdgroup]],
1672-         ushort sgptg[[simdgroups_per_threadgroup]],
1673-         uint3   tgpg[[threadgroups_per_grid]]) {
1674- 
1675-     const  int64_t  i0 = tpitg.x ;
1667+         uint3 tgpig[[threadgroup_position_in_grid]],
1668+         uint3 tpitg[[thread_position_in_threadgroup]],
1669+         uint3   ntg[[threads_per_threadgroup]]) {
16761670    const  int64_t  ir = tgpig.x ;
16771671    const  int64_t  i2 = tgpig.y ;
16781672    const  int64_t  i3 = tgpig.z ;
@@ -1687,31 +1681,13 @@ kernel void kernel_ssm_conv_f32(
16871681    device const  float  * c = (device const  float  *) ((device const  char  *) src1 + ir*args.nb11 );
16881682    device       float  * x = (device       float  *) ((device       char  *) dst  + ir*args.nb0   + i2*args.nb1   + i3*args.nb2 );
16891683
1690-     float  sumf = s[i0] * c[i0];
1691- 
1692-     //  Parallel sum: first sum over threads in simd group, then sum over simd
1693-     //  group sums
1694-     sumf = simd_sum (sumf);
1684+     float  sumf = 0 .0f ;
16951685
1696-     //  If multiple simd groups per threadgroup, sum over simd group sums
1697-     if  (sgptg > 1 ) {
1698-         if  (tiisg == 0 ) {
1699-             shared[sgitg] = sumf;
1700-         }
1701-         threadgroup_barrier (mem_flags::mem_threadgroup);
1702-         sumf = 0 .0f ;
1703-         if  (sgitg == 0 ) {
1704-             if  (tiisg < sgptg) {
1705-                 sumf = shared[tiisg];
1706-             }
1707-             sumf = simd_sum (sumf);
1708-             if  (tiisg == 0 ) {
1709-                 x[0 ] = sumf;
1710-             }
1711-         }
1712-     } else  if  (tiisg == 0 ) {
1713-         x[0 ] = sumf;
1686+     for  (int64_t  i0 = 0 ; i0 < nc; ++i0) {
1687+         sumf += s[i0] * c[i0];
17141688    }
1689+ 
1690+     x[0 ] = sumf;
17151691}
17161692
17171693//  ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
0 commit comments