diff --git a/dev/cuda/layernorm_backward.cu b/dev/cuda/layernorm_backward.cu
index 90dcb1674..d9502880b 100644
--- a/dev/cuda/layernorm_backward.cu
+++ b/dev/cuda/layernorm_backward.cu
@@ -856,6 +856,185 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
     }
 }
 
+__global__ void layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
+                                           const floatX* dout, const floatX* inp, const floatX* weight,
+                                           const floatX* mean, const floatX* rstd,
+                                           int B, int T, int C) {
+    constexpr int WARP_SIZE = 32;
+    int BLOCK_SIZE = blockDim.x;
+    int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block
+    extern __shared__ float shared[]; // size = 2*C + 2*BLOCK_SIZE + 1
+
+    int warpId = threadIdx.x / WARP_SIZE; // warp index within a block
+    int baseIdx = blockIdx.x * warpsInBlock + warpId;
+    int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp
+    int warpsInGrid = gridDim.x * warpsInBlock;
+    int C_per_iteration = WARP_SIZE * x128::size;
+    int iterations_C = ceil_div(C, C_per_iteration) + 2;
+
+    // the first half of shared memory is bias, second is weight
+    float* dbias_shared = shared;
+    float* dweight_shared = shared + C;
+    float* dbias_tmp_shared = shared + 2 * C;
+    float* dweight_tmp_shared = shared + 2 * C + BLOCK_SIZE;
+
+    // init shared memory to zero
+    for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE){
+        dbias_shared[i] = 0.0f;
+        dweight_shared[i] = 0.0f;
+    }
+    unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*BLOCK_SIZE);
+    __syncthreads();
+
+    for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) {
+        int b = idx / T;
+        int t = idx % T;
+
+        const floatX* dout_bt = dout + b * T * C + t * C;
+        const floatX* inp_bt = inp + b * T * C + t * C;
+        floatX* dinp_bt = dinp + b * T * C + t * C;
+        const float mean_bt = (float)mean[b * T + t];
+        const float rstd_bt = (float)rstd[b * T + t];
+
+        // first: two reduce operations
+        float dnorm_mean = 0.0f;
+        float dnorm_norm_mean = 0.0f;
+        for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) {
+            x128 dout128_i = load128(dout_bt + i);
+            x128 inp128_i = load128(inp_bt + i);
+            x128 weight128_i = load128(weight + i);
+            for (int k = 0; k < x128::size; k++) {
+                float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt;
+                float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];
+                dnorm_mean += dnorm_i;
+                dnorm_norm_mean += dnorm_i * norm_bti;
+            }
+        }
+        dnorm_mean = warpReduceSum(dnorm_mean) / C;
+        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C;
+
+        // now iterate again and accumulate all the gradients
+        // unfortunately we cannot use the same index for x128 arrays and shared memory
+        // as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper)
+        // so this would result in an 8-way bank conflict, and kill performance
+        // so instead, we use a shared memory friendly index, and reorder before the final write
+        for (int i = 0; i < iterations_C; i++) {
+            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
+            int shared_index = warpThreadIdx + (i * C_per_iteration);
+            if (global_index >= C) {
+                break;
+            }
+
+            x128 dout128 = load128cs(dout_bt + global_index);
+            x128 inp128 = load128cs(inp_bt + global_index);
+            x128 dinp128 = load128(dinp_bt + global_index);
+            x128 weight128 = load128(weight + global_index);
+
+            for (int x = 0; x < x128::size; x++) {
+                float dout_i = (float)dout128[x];
+                float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;
+                float dnorm_i = (float)weight128[x] * dout_i;
+
+                // sum up the gradients for bias and weight across the entire block
+                // this is basically a reduction (but only inter-warp, not intra-warp)
+                // doing it this way allows us to avoid using atomics while using many warps
+                if (warpId != 0) {
+                    dbias_tmp_shared[threadIdx.x] = dout_i;
+                    dweight_tmp_shared[threadIdx.x] = norm_bti * dout_i;
+                }
+                __syncthreads();
+                if (warpId == 0) {
+                    float dbias_tmp = dout_i;
+                    float dweight_tmp = norm_bti * dout_i;
+                    for (int j = 1; j < warpsInBlock; j++) {
+                        dbias_tmp += dbias_tmp_shared[threadIdx.x + j * WARP_SIZE];
+                        dweight_tmp += dweight_tmp_shared[threadIdx.x + j * WARP_SIZE];
+                    }
+                    // gradient contribution to bias (using shared memory friendly index)
+                    dbias_shared[shared_index + x*WARP_SIZE] += dbias_tmp;
+                    // gradient contribution to weight (using shared memory friendly index)
+                    dweight_shared[shared_index + x*WARP_SIZE] += dweight_tmp;
+                }
+                __syncthreads();
+
+                // gradient contribution to input
+                float dval = 0.0f;
+                dval += dnorm_i; // term 1
+                dval -= dnorm_mean; // term 2
+                dval -= norm_bti * dnorm_norm_mean; // term 3
+                dval *= rstd_bt; // final scale
+                dinp128[x] = (floatX)((float)dinp128[x] + dval);
+            }
+            // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing
+            store128cg(dinp_bt + global_index, dinp128);
+        }
+    }
+    __syncthreads();
+    // Each block writes its partial sum to global memory
+    // The last block to finish becomes responsible for summing up all the partial sums
+    // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel)
+    unsigned int* scratchFlag = (unsigned int*)(scratch);
+    // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned
+    scratch += 32;
+    float* scratch_dbias = scratch;
+    float* scratch_dweight = scratch + C;
+    for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE) {
+        // Write to global memory in the same "shared memory banking friendly" order
+        scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i];
+        scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i];
+    }
+    __syncthreads();
+    if (threadIdx.x == 0) {
+        *tmp_flag = atomicInc(scratchFlag, gridDim.x);
+    }
+    __syncthreads();
+    if (*tmp_flag == gridDim.x-1) {
+        // Reduction of the partial sums by the final block
+        // todo - there isn't enough parallelism even inside that single SM...
+        // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?!
+        for(int i = threadIdx.x * f128::size; i < C; i+= BLOCK_SIZE * f128::size) {
+            f128 dbias_accum(make_int4(0, 0, 0, 0));
+            f128 dweight_accum(make_int4(0, 0, 0, 0));
+
+            for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) {
+                int offset = i + 2*C*read_block_idx;
+                f128 dbias128 = load128(scratch_dbias + offset);
+                f128 dweight128 = load128(scratch_dweight + offset);
+                for(int k = 0; k < f128::size; k++) {
+                    dbias_accum[k] += dbias128[k];
+                    dweight_accum[k] += dweight128[k];
+                }
+            }
+            store128(dbias_shared + i, dbias_accum);
+            store128(dweight_shared + i, dweight_accum);
+        }
+        __syncthreads();
+
+        // reorder from atomic/shared memory-friendly index to real global memory index
+        // and convert from float/FP32 to floatX/BF16 for the final write
+        // this is separate also because it cannot use as many warps as the above (f128 vs x128)
+        // todo - if we split this code into another kernel, we could maybe do it at the same time?
+        for (int i = warpId; i < iterations_C; i += warpsInBlock) {
+            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
+            int shared_index = warpThreadIdx + (i * C_per_iteration);
+            if (global_index >= C) {
+                break;
+            }
+
+            x128 dbias128 = load128(dbias + global_index);
+            x128 dweight128 = load128(dweight + global_index);
+            for (int x = 0; x < x128::size; x++) {
+                float s_db = dbias_shared[shared_index + x*WARP_SIZE];
+                float s_dw = dweight_shared[shared_index + x*WARP_SIZE];
+                dbias128[x] = (floatX)(s_db + (float)dbias128[x]);
+                dweight128[x] = (floatX)(s_dw + (float)dweight128[x]);
+            }
+            store128(dbias + global_index, dbias128);
+            store128(dweight + global_index, dweight128);
+        }
+    }
+}
+
 // ----------------------------------------------------------------------------
 // kernel launchers
 
@@ -947,6 +1126,18 @@ void layernorm_backward8(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* s
     layernorm_backward_kernel8<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
 }
 
+template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
+void layernorm_backward9(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,
+                         const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
+                         int B, int T, int C, int block_size) {
+
+    const int grid_size = (1024/block_size) * cuda_num_SMs; // todo - heuristics for other GPUs?
+    size_t shared_mem_size = (2 * C + 2 * block_size + 1) * sizeof(float);
+
+    cudaMemset(scratch, 0, 1 * sizeof(float)); // just need to memset the flag for this version
+    layernorm_backward_kernel9<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
+}
+
 // kernel version dispatch
 void layernorm_backward(int kernel_num,
                         floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
@@ -982,6 +1173,9 @@ void layernorm_backward(int kernel_num,
         case 8:
             layernorm_backward8(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);
             break;
+        case 9:
+            layernorm_backward9(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);
+            break;
         default:
             printf("Invalid kernel number\n");
             exit(1);
@@ -1042,7 +1236,7 @@ int main(int argc, char **argv) {
     cudaCheck(cudaMalloc(&d_weight, C * sizeof(floatX)));
     cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(floatX)));
     cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(floatX)));
-    cudaCheck(cudaMalloc(&d_scratch, cuda_num_SMs * (2 * C + 1) * sizeof(float)));
+    cudaCheck(cudaMalloc(&d_scratch, (1024/32) * cuda_num_SMs * (2 * C + 1) * sizeof(float)));
     // copy over the "inputs" to the backward call
     cudaCheck(memcpy_convert(d_dout, dout, B * T * C));
     cudaCheck(memcpy_convert(d_inp, inp, B * T * C));
@@ -1051,7 +1245,8 @@
     cudaCheck(memcpy_convert(d_rstd, rstd, B * T));
 
     // launch the kernel
-    int block_sizes[] = {32, 64, 128, 256, 512, 768, 1024};
+    // removed 768 because it doesn't work for kernel9 despite being OK in train_gpt2.cu?!
+    int block_sizes[] = {32, 64, 128, 256, 512, /*768,*/ 1024};
     for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
         int block_size = block_sizes[j];
        // init the "outputs" of the backward call to zeros
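
Aside (not part of the patch): the scratch/flag logic above is the core change in kernel9, so here is a minimal, self-contained sketch of the same "last block finalizes" pattern in isolation. Every block publishes a partial result to its own slice of a scratch buffer, then atomically increments a flag that was cleared to zero before the launch; the one block that observes the old value gridDim.x-1 knows every other block has already written its partials and does the final reduction alone, in a fixed order, which keeps the result deterministic. The kernel name, the scratch layout, and the naive thread-0 per-block reduction below are illustrative only.

// sketch_last_block_finalizes.cu -- compile with: nvcc sketch_last_block_finalizes.cu
#include <cuda_runtime.h>
#include <cstdio>

__global__ void partial_sum_kernel(float* out, float* scratch, unsigned int* flag,
                                   const float* in, int n) {
    __shared__ bool is_last;
    // 1) naive per-block partial sum over a strided slice (thread 0 only, for clarity)
    if (threadIdx.x == 0) {
        float s = 0.0f;
        for (int i = blockIdx.x; i < n; i += gridDim.x) { s += in[i]; }
        scratch[blockIdx.x] = s;     // publish this block's partial sum
        __threadfence();             // make the partial visible to other blocks
        // 2) count finished blocks; the last one to increment sees gridDim.x-1
        is_last = (atomicInc(flag, gridDim.x) == gridDim.x - 1);
    }
    __syncthreads();
    // 3) only the last block folds all the partials, in a fixed order -> deterministic
    if (is_last && threadIdx.x == 0) {
        float total = 0.0f;
        for (int b = 0; b < gridDim.x; b++) { total += scratch[b]; }
        *out = total;
    }
}

int main() {
    const int n = 1 << 20, blocks = 64, threads = 32;
    float *d_in, *d_out, *d_scratch;
    unsigned int* d_flag;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMalloc(&d_scratch, blocks * sizeof(float));
    cudaMalloc(&d_flag, sizeof(unsigned int));
    cudaMemset(d_in, 0, n * sizeof(float));       // all zeros -> expected sum is 0
    cudaMemset(d_flag, 0, sizeof(unsigned int));  // flag must start at 0, like in the launchers above
    partial_sum_kernel<<<blocks, threads>>>(d_out, d_scratch, d_flag, d_in, n);
    float result;
    cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %f\n", result);
    return 0;
}

The flag has to be cleared before every launch, which is why the launchers in the patch only cudaMemset() the first 4 bytes of scratch instead of the whole 2*C+1 region.
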
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 77f2e6eb4..936bfa8fd 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -985,30 +985,32 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s
     }
 }
 
-__global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with only 1024 threads?
-    layernorm_backward_kernel8(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
-                               const floatX* dout, const floatX* inp, const floatX* weight,
-                               const floatX* mean, const floatX* rstd,
-                               int B, int T, int C) {
-    extern __shared__ float shared[]; // size = 2 * C + 1
-    int warpId = threadIdx.x / WARP_SIZE; // warp index within a block
+__global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads?
+    layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
+                               const floatX* dout, const floatX* inp, const floatX* weight,
+                               const floatX* mean, const floatX* rstd,
+                               int B, int T, int C) {
+    extern __shared__ float shared[]; // size = 2*C + 2*block_size + 1
     int warpsInBlock = blockDim.x / WARP_SIZE; //number of warps in block
+    int warpId = threadIdx.x / WARP_SIZE; // warp index within a block
     int baseIdx = blockIdx.x * warpsInBlock + warpId;
     int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp
     int warpsInGrid = gridDim.x * warpsInBlock;
     int C_per_iteration = WARP_SIZE * x128::size;
-    int iterations_C = C / C_per_iteration;
+    int iterations_C = CEIL_DIV(C, C_per_iteration);
 
     // the first half of shared memory is bias, second is weight
     float* dbias_shared = shared;
     float* dweight_shared = shared + C;
+    float* dbias_tmp_shared = shared + 2 * C;
+    float* dweight_tmp_shared = shared + 2 * C + blockDim.x;
 
     // init shared memory to zero
     for(int i = threadIdx.x; i < C; i+= blockDim.x){
        dbias_shared[i] = 0.0f;
        dweight_shared[i] = 0.0f;
    }
-    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
+    unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*blockDim.x);
     __syncthreads();
 
     for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) {
@@ -1046,6 +1048,10 @@ __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with
         for (int i = 0; i < iterations_C; i++) {
             int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
             int shared_index = warpThreadIdx + (i * C_per_iteration);
+            if (global_index >= C) {
+                break;
+            }
+
             x128 dout128 = load128cs(dout_bt + global_index);
             x128 inp128 = load128cs(inp_bt + global_index);
             x128 dinp128 = load128(dinp_bt + global_index);
@@ -1055,10 +1061,29 @@ __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with
                 float dout_i = (float)dout128[x];
                 float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;
                 float dnorm_i = (float)weight128[x] * dout_i;
-                // gradient contribution to bias (using shared memory friendly index)
-                atomicAdd(&dbias_shared[shared_index + x*WARP_SIZE], dout_i);
-                // gradient contribution to weight (using shared memory friendly index)
-                atomicAdd(&dweight_shared[shared_index + x*WARP_SIZE], norm_bti * dout_i);
+
+                // sum up the gradients for bias and weight across the entire block
+                // this is basically a reduction (but only inter-warp, not intra-warp)
+                // doing it this way allows us to avoid using atomics while using many warps
+                if (warpId != 0) {
+                    dbias_tmp_shared[threadIdx.x] = dout_i;
+                    dweight_tmp_shared[threadIdx.x] = norm_bti * dout_i;
+                }
+                __syncthreads();
+                if (warpId == 0) {
+                    float dbias_tmp = dout_i;
+                    float dweight_tmp = norm_bti * dout_i;
+                    for (int j = 1; j < warpsInBlock; j++) {
+                        dbias_tmp += dbias_tmp_shared[threadIdx.x + j * WARP_SIZE];
+                        dweight_tmp += dweight_tmp_shared[threadIdx.x + j * WARP_SIZE];
+                    }
+                    // gradient contribution to bias (using shared memory friendly index)
+                    dbias_shared[shared_index + x*WARP_SIZE] += dbias_tmp;
+                    // gradient contribution to weight (using shared memory friendly index)
+                    dweight_shared[shared_index + x*WARP_SIZE] += dweight_tmp;
+                }
+                __syncthreads();
+
                 // gradient contribution to input
                 float dval = 0.0f;
                 dval += dnorm_i; // term 1
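
Aside (not part of the patch): the hunk above replaces per-element atomicAdd() into shared memory with an inter-warp reduction staged through dbias_tmp_shared/dweight_tmp_shared. A stripped-down sketch of that idea, with illustrative names and a trivial test, is below; the only real difference from the kernel is that the kernel accumulates into shared_index + x*WARP_SIZE rather than into plain per-lane slots. Every thread of the block must call the helper, because it contains __syncthreads().

// sketch_interwarp_reduction.cu -- compile with: nvcc sketch_interwarp_reduction.cu
#include <cuda_runtime.h>
#include <cstdio>

// Warps 1..N-1 stage their per-lane value in tmp[]; warp 0 gathers the same lane from every
// other warp and adds the total into the persistent accumulator. One writer per slot -> no atomics.
// Assumes blockDim.x is a multiple of 32 and accum has at least 32 usable floats.
__device__ void accumulate_without_atomics(float* accum,  // persistent per-lane accumulator (shared)
                                            float* tmp,    // blockDim.x floats of scratch (shared)
                                            float value)   // this thread's contribution
{
    const int lane  = threadIdx.x % 32;
    const int warp  = threadIdx.x / 32;
    const int warps = blockDim.x / 32;

    if (warp != 0) {
        tmp[threadIdx.x] = value;          // stage contributions from warps 1..warps-1
    }
    __syncthreads();                       // staging must be visible to warp 0
    if (warp == 0) {
        float sum = value;                 // warp 0 starts from its own lane's value
        for (int j = 1; j < warps; j++) {
            sum += tmp[lane + j * 32];     // same lane, every other warp
        }
        accum[lane] += sum;                // single writer per slot, so plain += is safe
    }
    __syncthreads();                       // tmp may be reused by the caller after this
}

// tiny test: every thread contributes 1.0f, so accum[lane] should equal warps-per-block
__global__ void test_kernel(float* out) {
    __shared__ float accum[32];
    extern __shared__ float tmp[];         // blockDim.x floats, sized at launch
    if (threadIdx.x < 32) accum[threadIdx.x] = 0.0f;
    __syncthreads();
    accumulate_without_atomics(accum, tmp, 1.0f);
    if (threadIdx.x < 32) out[threadIdx.x] = accum[threadIdx.x];
}

int main() {
    float* d_out;
    cudaMalloc(&d_out, 32 * sizeof(float));
    test_kernel<<<1, 256, 256 * sizeof(float)>>>(d_out);
    float h_out[32];
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("accum[0] = %f (expect 8 with 256 threads = 8 warps)\n", h_out[0]);
    return 0;
}
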
@@ -1071,35 +1096,64 @@ __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with
             store128cg(dinp_bt + global_index, dinp128);
         }
     }
-    // Accumulate into a FP32 scratchpad
-    // BF16 atomics are potentially much slower... and this is more precise!
-    // todo - could potentially avoid the extra copy if floatX is FP32, fairly negligible though
     __syncthreads();
+    // Each block writes its partial sum to global memory
+    // The last block to finish becomes responsible for summing up all the partial sums
+    // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel)
+    unsigned int* scratchFlag = (unsigned int*)(scratch);
+    // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned
+    scratch += 32;
     float* scratch_dbias = scratch;
     float* scratch_dweight = scratch + C;
-    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
     for(int i = threadIdx.x; i < C; i+= blockDim.x) {
-        // global atomics in the same "shared memory banking friendly" order
-        atomicAdd(&scratch_dbias[i], dbias_shared[i]);
-        atomicAdd(&scratch_dweight[i], dweight_shared[i]);
+        // Write to global memory in the same "shared memory banking friendly" order
+        scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i];
+        scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i];
     }
+
+    // todo - everything below could become a separate kernel for better performance with maybe less code
+    // not enough parallelism even inside that single SM... do we need another level of reduction?!
     __syncthreads();
     if (threadIdx.x == 0) {
         *tmp_flag = atomicInc(scratchFlag, gridDim.x);
     }
     __syncthreads();
     if (*tmp_flag == gridDim.x-1) {
+        // Reduction of the partial sums by the final block
+        for(int i = threadIdx.x * f128::size; i < C; i+= blockDim.x * f128::size) {
+            f128 dbias_accum(make_int4(0, 0, 0, 0));
+            f128 dweight_accum(make_int4(0, 0, 0, 0));
+
+            for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) {
+                int offset = i + 2*C*read_block_idx;
+                f128 dbias128 = load128(scratch_dbias + offset);
+                f128 dweight128 = load128(scratch_dweight + offset);
+                for(int k = 0; k < f128::size; k++) {
+                    dbias_accum[k] += dbias128[k];
+                    dweight_accum[k] += dweight128[k];
+                }
+            }
+            store128(dbias_shared + i, dbias_accum);
+            store128(dweight_shared + i, dweight_accum);
+        }
+        __syncthreads();
+
+        // reorder from atomic/shared memory-friendly index to real global memory index
+        // and convert from float/FP32 to floatX/BF16 for the final write
+        // this is separate also because it cannot use as many warps as the above (f128 vs x128)
+        // todo - if we split this code into another kernel, we could maybe do it at the same time?
         for (int i = warpId; i < iterations_C; i += warpsInBlock) {
-            // reorder from atomic/shared memory-friendly index to real global memory index
-            // and convert from float/FP32 to floatX/BF16 for the final write
             int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
             int shared_index = warpThreadIdx + (i * C_per_iteration);
+            if (global_index >= C) {
+                break;
+            }
 
             x128 dbias128 = load128(dbias + global_index);
             x128 dweight128 = load128(dweight + global_index);
             for (int x = 0; x < x128::size; x++) {
-                float s_db = scratch_dbias[shared_index + x*WARP_SIZE];
-                float s_dw = scratch_dweight[shared_index + x*WARP_SIZE];
+                float s_db = dbias_shared[shared_index + x*WARP_SIZE];
+                float s_dw = dweight_shared[shared_index + x*WARP_SIZE];
                 dbias128[x] = (floatX)(s_db + (float)dbias128[x]);
                 dweight128[x] = (floatX)(s_dw + (float)dweight128[x]);
             }
@@ -1611,15 +1665,13 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr
                         const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd,
                         int B, int T, int C) {
     NVTX_RANGE_FN();
-    // todo - forcing 3 x 512 threads per SM maximum is a bit hacky, but more than that results in
-    // cache thrashing and lower performance on A100... is there a better way?
     const int block_size = 512;
-    const int blocks_per_sm = min(3, (deviceProp.maxThreadsPerMultiProcessor / 1024));
+    const int blocks_per_sm = 2; // supported on every architecture and less cache thrashing than 3
     const int grid_size = blocks_per_sm * deviceProp.multiProcessorCount;
-    size_t shared_mem_size = (2 * C + 1) * sizeof(float);
+    size_t shared_mem_size = (2*C + 2*block_size + 1) * sizeof(float); // see kernel
 
-    cudaMemset(scratch, 0, (2 * C + 1) * sizeof(float));
-    layernorm_backward_kernel8<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
+    cudaMemset(scratch, 0, 1 * sizeof(float)); // only need to reset the flag to 0
+    layernorm_backward_kernel9<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
     cudaCheck(cudaGetLastError());
 }
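
Aside (not part of the patch): both versions of the kernel keep two index mappings, global_index for the x128 loads/stores and shared_index for the shared-memory accumulators, and the comment in the dev kernel explains why: with x128::size == 8 elements per 128-bit load, updating shared memory in the "global" layout would make lanes 0, 4, 8, ... of a warp hit the same 4-byte bank, i.e. the 8-way conflict mentioned there. The sketch below (VEC standing in for x128::size) just prints both mappings so the bank pattern is visible; the reorder loop at the end of the kernel is what converts back from the shared layout to the global one.

// sketch_bank_friendly_index.cu -- compile with: nvcc sketch_bank_friendly_index.cu
#include <cstdio>

constexpr int VEC  = 8;   // elements per 128-bit vector load (x128::size for BF16)
constexpr int WARP = 32;  // lanes per warp; also the number of 4-byte shared memory banks

// channel index in dbias/dweight handled by (lane, element k) of chunk i
__host__ __device__ inline int global_index(int lane, int k, int i) {
    return i * (WARP * VEC) + lane * VEC + k;
}
// slot used for the shared-memory accumulators: lane-major, so the 32 lanes of a warp
// touch 32 different banks; the final reorder loop undoes this permutation
__host__ __device__ inline int shared_index(int lane, int k, int i) {
    return i * (WARP * VEC) + k * WARP + lane;
}

int main() {
    for (int lane = 0; lane < 8; lane++) {
        int g = global_index(lane, 0, 0);
        int s = shared_index(lane, 0, 0);
        printf("lane %2d: global %3d (bank %2d)   shared slot %3d (bank %2d)\n",
               lane, g, g % WARP, s, s % WARP);
    }
    return 0;
}
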