@@ -1201,34 +1201,23 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
}

template<class T>
-__global__ void norm_kernel(float* out, const T* data, size_t count) {
+__global__ void global_norm_kernel(float* out, const T* data, size_t count) {
    // we want as few atomics as possible, so each block tries to do
    // the maximum amount of work (so no fixed chunk, but instead iterating
    // until we run out of data), and then we reduce inside the block
    // and finally have just one atomic per block.
-    namespace cg = cooperative_groups;
-    cg::thread_block block = cg::this_thread_block();
-    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
-
-    __shared__ float block_result[32];
-
-    // out will be updated atomically from all thread blocks
+    // out will be updated atomically from all thread blocks. It is a float, so the
+    // atomic op is unproblematic
    size_t index = threadIdx.x + blockDim.x * blockIdx.x;
    size_t grid_width = blockDim.x * gridDim.x;
    float accumulator = 0.f;
    for (size_t i = index; i < count; i += grid_width) {
        accumulator += (float)data[i] * (float)data[i];
    }
    // warp-level reduce
-    float warp_result = cg::reduce(warp, accumulator, cg::plus<float>{});
-    block_result[warp.meta_group_rank()] = warp_result;
-    block.sync();
-    if (warp.meta_group_rank() == 0) {
-        float gather = warp.thread_rank() < warp.meta_group_size() ? block_result[warp.thread_rank()] : 0.f;
-        float block_sum = cg::reduce(warp, gather, cg::plus<float>{});
-        if (warp.thread_rank() == 0) {
-            atomicAdd(out, block_sum);
-        }
+    float block_sum = blockReduce<warpReduceSum>(accumulator);
+    if (threadIdx.x == 0) {
+        atomicAdd(out, block_sum);
    }
}
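Note: the blockReduce<warpReduceSum> call in the new kernel relies on helpers defined elsewhere in the file, outside this hunk. A minimal sketch of how such shuffle-based helpers are commonly written (the names and call signature come from the call site above; everything else is an assumption, not this file's actual code):

// Sketch only: warp/block sum reduction via warp shuffles; the real
// warpReduceSum / blockReduce definitions in the file may differ in detail.
__device__ float warpReduceSum(float val) {
    // butterfly reduction across the 32 lanes of a warp
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFFu, val, offset);
    }
    return val;
}

template<float (*warp_reduction)(float)>
__device__ float blockReduce(float val) {
    __shared__ float shared[32];            // one partial sum per warp (<= 32 warps/block)
    int lane_id = threadIdx.x % 32;
    int warp_id = threadIdx.x / 32;
    int num_warps = (blockDim.x + 31) / 32;

    float warp_sum = warp_reduction(val);   // reduce within each warp
    if (lane_id == 0) { shared[warp_id] = warp_sum; }
    __syncthreads();
    // every warp re-reduces the per-warp partials, so the block total ends up
    // in all threads; the kernel above only consumes it in threadIdx.x == 0
    float partial = (lane_id < num_warps) ? shared[lane_id] : 0.f;
    return warp_reduction(partial);
}

With helpers along these lines, the kernel needs one __syncthreads() and a single atomicAdd per block, which is exactly what the comment at the top of the kernel is aiming for.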
@@ -1716,9 +1705,9 @@ void global_norm(float* out, const T* values, size_t count) {
    // one block too many is catastrophic, since it only can start once all the other
    // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512
    // on all gpus, so the division really is going to be exact.
-    const int grid_size = cuda_threads_per_SM * cuda_num_SMs / block_size;
+    const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;
    assert(grid_size > 0); // gives a better error than letting the call below fail
-    norm_kernel<<<grid_size, block_size>>>(out, values, count);
+    global_norm_kernel<<<grid_size, block_size>>>(out, values, count);
    cudaCheck(cudaGetLastError());
}
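Note: the new grid-size line assumes a cudaDeviceProp struct (here named deviceProp) that was queried beforehand. A standalone sketch of that setup and of the resulting launch arithmetic; the init function, the file-scope variable, and the A100 numbers are illustrative assumptions for the example, not part of this change:

#include <assert.h>
#include <stdio.h>
#include <cuda_runtime.h>

static cudaDeviceProp deviceProp;   // assumed to be filled once at startup

void init_device_props(int device_id) {
    cudaSetDevice(device_id);
    cudaGetDeviceProperties(&deviceProp, device_id);
}

int main(void) {
    init_device_props(0);
    const int block_size = 512;
    // e.g. on an A100: 2048 maxThreadsPerMultiProcessor * 108 SMs / 512 = 432 blocks,
    // i.e. exactly 4 resident blocks per SM and no leftover block waiting for a slot
    int grid_size = deviceProp.maxThreadsPerMultiProcessor
                  * deviceProp.multiProcessorCount / block_size;
    assert(grid_size > 0);
    printf("grid_size = %d, block_size = %d\n", grid_size, block_size);
    return 0;
}

The formula itself is unchanged; it is just sourced from the device properties struct instead of the cuda_threads_per_SM / cuda_num_SMs globals. It launches exactly as many blocks as can be resident at once, so no block is left waiting for all the others to finish.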