
Commit 68509c8

Merge branch 'pjj-fix-gpt2-xl'
2 parents 26dbbc7 + 0c73ba6

File tree: 1 file changed (+16 -13)

train_gpt2.cu
Lines changed: 16 additions & 13 deletions
@@ -809,16 +809,18 @@ __global__ void matmul_backward_bias_kernel7(float* dbias, const floatX* dout, i
         shared[idx] = 0.0f;
     }
     __syncthreads();
-    for (int idx = blockIdx.y*block_size_y + threadIdx.y; idx < B * T; idx += gridDim.y*block_size_y) {
-        x128 packed_dout = load128(dout + global_oc + idx*OC);
-        for (int k = 0; k < x128::size; k++) {
-            accumulators[k] += (float)packed_dout[k];
-        }
-    }
-    // we need to avoid shared memory bank conflicts for the atomicAdd to maximise performance
-    // so we accumulate in a conflict-free order, then reorder to match the global memory order
-    for (int k = 0; k < x128::size; k++) {
-        atomicAdd(shared + threadIdx.x + (k * block_size_x), accumulators[k]);
+    if(global_oc < OC) {
+        for (int idx = blockIdx.y*block_size_y + threadIdx.y; idx < B * T; idx += gridDim.y*block_size_y) {
+            x128 packed_dout = load128(dout + global_oc + idx*OC);
+            for (int k = 0; k < x128::size; k++) {
+                accumulators[k] += (float)packed_dout[k];
+            }
+        }
+        // we need to avoid shared memory bank conflicts for the atomicAdd to maximise performance
+        // so we accumulate in a conflict-free order, then reorder to match the global memory order
+        for (int k = 0; k < x128::size; k++) {
+            atomicAdd(shared + threadIdx.x + (k * block_size_x), accumulators[k]);
+        }
     }
     if (threadIdx.y >= x128::size) { return; } // only need this many warps to reorder the data
     __syncthreads();
@@ -831,7 +833,9 @@ __global__ void matmul_backward_bias_kernel7(float* dbias, const floatX* dout, i
     shared[local_oc + threadIdx.y] = tmp;
     __syncthreads();
     // now we do a perfectly coalesced atomic add to global memory (1x 128-byte cacheline per warp)
-    atomicAdd(dbias + i + blockIdx.x*OC_per_warp, shared[i]);
+    if (i + blockIdx.x*OC_per_warp < OC) {
+        atomicAdd(dbias + i + blockIdx.x*OC_per_warp, shared[i]);
+    }
 }

 __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with only 1024 threads?
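
Both kernel-side edits above apply the same guarded-access idiom: once the grid can be rounded up past OC, the last block along x may contain lanes with no valid output channel, so the 128-bit load path and the final coalesced atomicAdd are each wrapped in a bounds check. A minimal standalone sketch of that idiom (toy kernel with illustrative names, not code from this commit):

// toy CUDA kernel showing the bounds-guarded access pattern used above;
// 'out', 'in' and 'n' are illustrative, not identifiers from train_gpt2.cu
__global__ void guarded_accumulate(float* out, const float* in, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;  // can exceed n when gridDim.x is rounded up
    if (i < n) {                                    // same role as 'global_oc < OC' in the kernel above
        atomicAdd(out + i, in[i]);
    }
}

Without the guard, the overshooting lanes would read and write past the end of the buffers, which is why the old code asserted divisibility instead.
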
@@ -1363,11 +1367,10 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
     const int OC_per_warp = warp_size * x128::size; // 256 at BF16
     const int block_size_x = 32;
     const int block_size_y = block_size / block_size_x; // 16
-    const int grid_size_x = OC / OC_per_warp; // e.g. 3 horizontal blocks for 768 OCs at BF16
+    const int grid_size_x = CEIL_DIV(OC, OC_per_warp); // e.g. 3 horizontal blocks for 768 OCs at BF16
     const int grid_size_y = max(1, deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount
                                     / (block_size * grid_size_x)); // full GPU!

-    assert((OC % OC_per_warp) == 0); // there is no bounds checking in the kernel to maximise performance
     assert(block_size_y >= x128::size); // part of the kernel assumes this is large enough to avoid loops

     cudaMemsetAsync(dbias_buffer, 0, OC * sizeof(float), main_stream);
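
On the host side, grid_size_x now rounds up instead of assuming OC is a multiple of OC_per_warp, and the divisibility assert is dropped because the kernel performs its own bounds checks. A rough sketch of the arithmetic this enables, assuming a conventional ceiling-division helper and GPT-2 XL-like numbers for illustration (channel count 1600, OC_per_warp 256 at BF16; these exact values are an assumption, per the 'pjj-fix-gpt2-xl' branch name):

#include <stdio.h>

// conventional ceiling-division helper; whether it matches llm.c's exact macro is an assumption
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

int main(void) {
    int OC = 1600;                                // 1600 % 256 != 0, so plain integer division would drop channels
    int OC_per_warp = 32 * 8;                     // 32 lanes * 8 BF16 values per 128-bit load = 256
    int grid_size_x = CEIL_DIV(OC, OC_per_warp);  // 7 blocks; the last covers only 64 valid channels
    printf("grid_size_x = %d, valid channels in last block = %d\n",
           grid_size_x, OC - (grid_size_x - 1) * OC_per_warp);
    // the surplus lanes of that last block are exactly the ones filtered out by the
    // 'global_oc < OC' and 'i + blockIdx.x*OC_per_warp < OC' guards added in the kernel
    return 0;
}
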
