@@ -1663,10 +1663,16 @@ kernel void kernel_ssm_conv_f32(
16631663 device const void * src0,
16641664 device const void * src1,
16651665 device float * dst,
1666+ threadgroup float * shared [[threadgroup(0 )]],
16661667 constant ggml_metal_kargs_ssm_conv & args,
1667- uint3 tgpig[[threadgroup_position_in_grid]],
1668- uint3 tpitg[[thread_position_in_threadgroup]],
1669- uint3 ntg[[threads_per_threadgroup]]) {
1668+ uint3 tgpig[[threadgroup_position_in_grid]],
1669+ uint3 tpitg[[thread_position_in_threadgroup]],
1670+ ushort sgitg[[simdgroup_index_in_threadgroup]],
1671+ ushort tiisg[[thread_index_in_simdgroup]],
1672+ ushort sgptg[[simdgroups_per_threadgroup]],
1673+ uint3 tgpg[[threadgroups_per_grid]]) {
1674+
1675+ const int64_t i0 = tpitg.x ;
16701676 const int64_t ir = tgpig.x ;
16711677 const int64_t i2 = tgpig.y ;
16721678 const int64_t i3 = tgpig.z ;
@@ -1681,13 +1687,31 @@ kernel void kernel_ssm_conv_f32(
16811687 device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11 );
16821688 device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2 );
16831689
1684- float sumf = 0 . 0f ;
1690+ float sumf = s[i0] * c[i0] ;
16851691
1686- for ( int64_t i0 = 0 ; i0 < nc; ++i0) {
1687- sumf += s[i0] * c[i0];
1688- }
1692+ // Parallel sum: first sum over threads in simd group, then sum over simd
1693+ // group sums
1694+ sumf = simd_sum (sumf);
16891695
1690- x[0 ] = sumf;
1696+ // If multiple simd groups per threadgroup, sum over simd group sums
1697+ if (sgptg > 1 ) {
1698+ if (tiisg == 0 ) {
1699+ shared[sgitg] = sumf;
1700+ }
1701+ threadgroup_barrier (mem_flags::mem_threadgroup);
1702+ sumf = 0 .0f ;
1703+ if (sgitg == 0 ) {
1704+ if (tiisg < sgptg) {
1705+ sumf = shared[tiisg];
1706+ }
1707+ sumf = simd_sum (sumf);
1708+ if (tiisg == 0 ) {
1709+ x[0 ] = sumf;
1710+ }
1711+ }
1712+ } else if (tiisg == 0 ) {
1713+ x[0 ] = sumf;
1714+ }
16911715}
16921716
16931717// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
0 commit comments