@@ -809,16 +809,18 @@ __global__ void matmul_backward_bias_kernel7(float* dbias, const floatX* dout, i
         shared[idx] = 0.0f;
     }
     __syncthreads();
-    for (int idx = blockIdx.y*block_size_y + threadIdx.y; idx < B * T; idx += gridDim.y*block_size_y) {
-        x128 packed_dout = load128(dout + global_oc + idx*OC);
-        for (int k = 0; k < x128::size; k++) {
-            accumulators[k] += (float)packed_dout[k];
-        }
-    }
-    // we need to avoid shared memory bank conflicts for the atomicAdd to maximise performance
-    // so we accumulate in a conflict-free order, then reorder to match the global memory order
-    for (int k = 0; k < x128::size; k++) {
-        atomicAdd(shared + threadIdx.x + (k * block_size_x), accumulators[k]);
+    if (global_oc < OC) {
+        for (int idx = blockIdx.y*block_size_y + threadIdx.y; idx < B * T; idx += gridDim.y*block_size_y) {
+            x128 packed_dout = load128(dout + global_oc + idx*OC);
+            for (int k = 0; k < x128::size; k++) {
+                accumulators[k] += (float)packed_dout[k];
+            }
+        }
+        // we need to avoid shared memory bank conflicts for the atomicAdd to maximise performance
+        // so we accumulate in a conflict-free order, then reorder to match the global memory order
+        for (int k = 0; k < x128::size; k++) {
+            atomicAdd(shared + threadIdx.x + (k * block_size_x), accumulators[k]);
+        }
     }
     if (threadIdx.y >= x128::size) { return; } // only need this many warps to reorder the data
     __syncthreads();
@@ -831,7 +833,9 @@ __global__ void matmul_backward_bias_kernel7(float* dbias, const floatX* dout, i
     shared[local_oc + threadIdx.y] = tmp;
     __syncthreads();
     // now we do a perfectly coalesced atomic add to global memory (1x 128-byte cacheline per warp)
-    atomicAdd(dbias + i + blockIdx.x*OC_per_warp, shared[i]);
+    if (i + blockIdx.x*OC_per_warp < OC) {
+        atomicAdd(dbias + i + blockIdx.x*OC_per_warp, shared[i]);
+    }
 }

 __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with only 1024 threads?
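A note on the bank-conflict comment kept in the kernel hunk above: shared memory on current NVIDIA GPUs is split into 32 four-byte banks, and a warp's accesses complete in a single pass only when its 32 lanes touch 32 distinct banks. The sketch below is illustrative only, not code from this commit; the bank_of helper and the lane*8 + k expression for the "global memory order" layout are assumptions made to show why the kernel accumulates in one order and reorders afterwards.

#include <stdio.h>

// Illustrative only: word index -> shared-memory bank (32 banks of 4-byte words).
static int bank_of(int word_idx) { return word_idx % 32; }

int main(void) {
    const int k = 1;    // one iteration of the k loop
    const int vec = 8;  // x128::size at BF16
    for (int lane = 0; lane < 32; lane += 4) {
        int conflict_free = lane + k * 32;   // layout used for the atomicAdd: shared[threadIdx.x + k*block_size_x]
        int global_order  = lane * vec + k;  // hypothetical layout matching dbias order: each lane's 8 channels contiguous
        printf("lane %2d: conflict-free -> bank %2d, global order -> bank %2d\n",
               lane, bank_of(conflict_free), bank_of(global_order));
    }
    // The conflict-free layout gives every lane its own bank, while the global-memory
    // order puts lanes 0, 4, 8, ... on the same bank (an 8-way conflict), which is why
    // the kernel accumulates conflict-free and only reorders at the end.
    return 0;
}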
@@ -1363,11 +1367,10 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
     const int OC_per_warp = warp_size * x128::size; // 256 at BF16
     const int block_size_x = 32;
     const int block_size_y = block_size / block_size_x; // 16
-    const int grid_size_x = OC / OC_per_warp; // e.g. 3 horizontal blocks for 768 OCs at BF16
+    const int grid_size_x = CEIL_DIV(OC, OC_per_warp); // e.g. 3 horizontal blocks for 768 OCs at BF16
     const int grid_size_y = max(1, deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount
                                  / (block_size * grid_size_x)); // full GPU!

-    assert((OC % OC_per_warp) == 0); // there is no bounds checking in the kernel to maximise performance
     assert(block_size_y >= x128::size); // part of the kernel assumes this is large enough to avoid loops

     cudaMemsetAsync(dbias_buffer, 0, OC * sizeof(float), main_stream);
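The launcher change in the last hunk is what forces the new bounds checks inside the kernel: grid_size_x is now rounded up, so when OC is not a multiple of OC_per_warp the last block of columns runs past the end of dbias. A minimal host-side sketch of the arithmetic follows; the CEIL_DIV definition and the OC = 800 example are assumptions for illustration, not values taken from the commit.

#include <stdio.h>

// Assumed round-up integer division, as used for grid_size_x in the launcher.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

int main(void) {
    const int OC_per_warp = 32 * 8;                     // warp_size * x128::size = 256 at BF16
    const int OC = 800;                                 // hypothetical channel count, not a multiple of 256
    const int grid_size_x = CEIL_DIV(OC, OC_per_warp);  // 4 blocks; plain OC / OC_per_warp would truncate to 3
    const int last_start  = (grid_size_x - 1) * OC_per_warp;
    // The last block nominally covers columns 768..1023, but only 768..799 exist,
    // hence the `global_oc < OC` and `i + blockIdx.x*OC_per_warp < OC` guards
    // that replace the old assert((OC % OC_per_warp) == 0).
    printf("grid_size_x = %d, last block covers %d..%d, valid columns end at %d\n",
           grid_size_x, last_start, grid_size_x * OC_per_warp - 1, OC - 1);
    return 0;
}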