Commit 5ddb061
Merge branch 'ademeure-deterministic_layernorm'
karpathy committed May 24, 2024
2 parents bf03e7f + 25f17e6 commit 5ddb061
Showing 2 changed files with 279 additions and 32 deletions.
199 changes: 197 additions & 2 deletions dev/cuda/layernorm_backward.cu
@@ -856,6 +856,185 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
}
}

__global__ void layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
const floatX* dout, const floatX* inp, const floatX* weight,
const floatX* mean, const floatX* rstd,
int B, int T, int C) {
constexpr int WARP_SIZE = 32;
int BLOCK_SIZE = blockDim.x;
int warpsInBlock = BLOCK_SIZE / WARP_SIZE; // number of warps in block
extern __shared__ float shared[]; // size = 2*C + 2*BLOCK_SIZE + 1 floats (see launcher)

int warpId = threadIdx.x / WARP_SIZE; // warp index within a block
int baseIdx = blockIdx.x * warpsInBlock + warpId;
int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp
int warpsInGrid = gridDim.x * warpsInBlock;
int C_per_iteration = WARP_SIZE * x128::size;
int iterations_C = ceil_div(C, C_per_iteration) + 2;

// the first half of shared memory is bias, second is weight
float* dbias_shared = shared;
float* dweight_shared = shared + C;
float* dbias_tmp_shared = shared + 2 * C;
float* dweight_tmp_shared = shared + 2 * C + BLOCK_SIZE;

// init shared memory to zero
for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE){
dbias_shared[i] = 0.0f;
dweight_shared[i] = 0.0f;
}
unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*BLOCK_SIZE);
__syncthreads();

for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) {
int b = idx / T;
int t = idx % T;

const floatX* dout_bt = dout + b * T * C + t * C;
const floatX* inp_bt = inp + b * T * C + t * C;
floatX* dinp_bt = dinp + b * T * C + t * C;
const float mean_bt = (float)mean[b * T + t];
const float rstd_bt = (float)rstd[b * T + t];

// first: two reduce operations
float dnorm_mean = 0.0f;
float dnorm_norm_mean = 0.0f;
for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) {
x128 dout128_i = load128(dout_bt + i);
x128 inp128_i = load128(inp_bt + i);
x128 weight128_i = load128(weight + i);
for (int k = 0; k < x128::size; k++) {
float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt;
float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];
dnorm_mean += dnorm_i;
dnorm_norm_mean += dnorm_i * norm_bti;
}
}
dnorm_mean = warpReduceSum(dnorm_mean) / C;
dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C;

// now iterate again and accumulate all the gradients
// unfortunately we cannot use the same index for x128 arrays and shared memory
// as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper)
// so this would result in an 8-way bank conflict, and kill performance
// so instead, we use a shared memory friendly index, and reorder before the final write
for (int i = 0; i < iterations_C; i++) {
int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
int shared_index = warpThreadIdx + (i * C_per_iteration);
if (global_index >= C) {
break;
}

x128 dout128 = load128cs(dout_bt + global_index);
x128 inp128 = load128cs(inp_bt + global_index);
x128 dinp128 = load128(dinp_bt + global_index);
x128 weight128 = load128(weight + global_index);

for (int x = 0; x < x128::size; x++) {
float dout_i = (float)dout128[x];
float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;
float dnorm_i = (float)weight128[x] * dout_i;

// sum up the gradients for bias and weight across the entire block
// this is basically a reduction (but only inter-warp, not intra-warp)
// doing it this way allows us to avoid using atomics while using many warps
if (warpId != 0) {
dbias_tmp_shared[threadIdx.x] = dout_i;
dweight_tmp_shared[threadIdx.x] = norm_bti * dout_i;
}
__syncthreads();
if (warpId == 0) {
float dbias_tmp = dout_i;
float dweight_tmp = norm_bti * dout_i;
for (int j = 1; j < warpsInBlock; j++) {
dbias_tmp += dbias_tmp_shared[threadIdx.x + j * WARP_SIZE];
dweight_tmp += dweight_tmp_shared[threadIdx.x + j * WARP_SIZE];
}
// gradient contribution to bias (using shared memory friendly index)
dbias_shared[shared_index + x*WARP_SIZE] += dbias_tmp;
// gradient contribution to weight (using shared memory friendly index)
dweight_shared[shared_index + x*WARP_SIZE] += dweight_tmp;
}
__syncthreads();

// gradient contribution to input
float dval = 0.0f;
dval += dnorm_i; // term 1
dval -= dnorm_mean; // term 2
dval -= norm_bti * dnorm_norm_mean; // term 3
dval *= rstd_bt; // final scale
dinp128[x] = (floatX)((float)dinp128[x] + dval);
}
// cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing
store128cg(dinp_bt + global_index, dinp128);
}
}
__syncthreads();
// Each block writes its partial sum to global memory
// The last block to finish becomes responsible for summing up all the partial sums
// This is done by atomically incrementing a flag (cleared to 0 before launching the kernel)
unsigned int* scratchFlag = (unsigned int*)(scratch);
// Increment scratch pointer by a full cacheline so that everything remains cacheline aligned
scratch += 32;
float* scratch_dbias = scratch;
float* scratch_dweight = scratch + C;
for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE) {
// Write to global memory in the same "shared memory banking friendly" order
scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i];
scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i];
}
__syncthreads();
if (threadIdx.x == 0) {
*tmp_flag = atomicInc(scratchFlag, gridDim.x);
}
__syncthreads();
if (*tmp_flag == gridDim.x-1) {
// Reduction of the partial sums by the final block
// todo - there isn't enough parallelism even inside that single SM...
// ==> so could maybe split into another kernel with YET ANOTHER level of reduction?!
for(int i = threadIdx.x * f128::size; i < C; i+= BLOCK_SIZE * f128::size) {
f128 dbias_accum(make_int4(0, 0, 0, 0));
f128 dweight_accum(make_int4(0, 0, 0, 0));

for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) {
int offset = i + 2*C*read_block_idx;
f128 dbias128 = load128(scratch_dbias + offset);
f128 dweight128 = load128(scratch_dweight + offset);
for(int k = 0; k < f128::size; k++) {
dbias_accum[k] += dbias128[k];
dweight_accum[k] += dweight128[k];
}
}
store128(dbias_shared + i, dbias_accum);
store128(dweight_shared + i, dweight_accum);
}
__syncthreads();

// reorder from atomic/shared memory-friendly index to real global memory index
// and convert from float/FP32 to floatX/BF16 for the final write
// this is separate also because it cannot use as many warps as the above (f128 vs x128)
// todo - if we split this code into another kernel, we could maybe do it at the same time?
for (int i = warpId; i < iterations_C; i += warpsInBlock) {
int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
int shared_index = warpThreadIdx + (i * C_per_iteration);
if (global_index >= C) {
break;
}

x128 dbias128 = load128(dbias + global_index);
x128 dweight128 = load128(dweight + global_index);
for (int x = 0; x < x128::size; x++) {
float s_db = dbias_shared[shared_index + x*WARP_SIZE];
float s_dw = dweight_shared[shared_index + x*WARP_SIZE];
dbias128[x] = (floatX)(s_db + (float)dbias128[x]);
dweight128[x] = (floatX)(s_dw + (float)dweight128[x]);
}
store128(dbias + global_index, dbias128);
store128(dweight + global_index, dweight128);
}
}
}
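
For reference, the input gradient accumulated in the loop above is the standard layernorm backward expression, restated here in our own notation (with $\hat{x}_i$ the normalized activation, $g_i = w_i \, \partial L/\partial y_i$ the weighted upstream gradient, and $\sigma^{-1}$ the stored rstd):

$$
\frac{\partial L}{\partial x_i} \mathrel{{+}{=}} \sigma^{-1}\Bigl( g_i \;-\; \tfrac{1}{C}\textstyle\sum_{j} g_j \;-\; \hat{x}_i \cdot \tfrac{1}{C}\textstyle\sum_{j} g_j \hat{x}_j \Bigr)
$$

The two warp reductions compute the means $\tfrac{1}{C}\sum_j g_j$ (dnorm_mean) and $\tfrac{1}{C}\sum_j g_j\hat{x}_j$ (dnorm_norm_mean), and the three subtracted pieces are "term 1/2/3" in the inner loop, while dbias and dweight simply accumulate $\partial L/\partial y_i$ and $\hat{x}_i\,\partial L/\partial y_i$ over all (b, t) positions.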

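The determinism of this kernel comes from how dbias/dweight are reduced: instead of atomicAdd directly into the outputs (whose floating-point accumulation order changes from run to run), each block publishes FP32 partial sums to the scratch buffer, and the last block to arrive, detected via atomicInc on a flag, sums them in a fixed block order. A minimal standalone sketch of that "last block finishes" pattern, with illustrative names and shapes that are not part of this commit, might look like this:

// Minimal sketch (illustrative only, not part of this commit) of the
// "last block reduces the partial sums" pattern used by kernel 9.
// Each block sums its strided subset of rows into its own slice of `scratch`,
// then atomically increments `flag`; the block that observes the final count
// performs the last reduction in a fixed block order, so the result is
// bitwise identical across runs.
__global__ void deterministic_colsum(float* out, float* scratch, unsigned int* flag,
                                     const float* in, int rows, int N) {
    __shared__ unsigned int is_last;

    // each block accumulates a strided subset of rows into its scratch slice
    for (int i = threadIdx.x; i < N; i += blockDim.x) {
        float partial = 0.0f;
        for (int r = blockIdx.x; r < rows; r += gridDim.x) {
            partial += in[r * N + i];
        }
        scratch[blockIdx.x * N + i] = partial;
    }
    __threadfence(); // make this block's scratch writes visible to the final block
    __syncthreads();

    if (threadIdx.x == 0) {
        // flag must be zeroed before launch (kernel 9 does this with cudaMemset)
        is_last = (atomicInc(flag, gridDim.x) == gridDim.x - 1);
    }
    __syncthreads();

    if (is_last) {
        // fixed per-block summation order => same bits every run
        for (int i = threadIdx.x; i < N; i += blockDim.x) {
            float acc = 0.0f;
            for (int b = 0; b < gridDim.x; b++) {
                acc += scratch[b * N + i];
            }
            out[i] = acc;
        }
    }
}

This follows the classic CUDA threadFenceReduction pattern: scratch must hold gridDim.x * N floats, and the flag is reset (e.g. via cudaMemset) before each launch, just as the kernel 9 launcher below does.
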
// ----------------------------------------------------------------------------
// kernel launchers

@@ -947,6 +1126,18 @@ void layernorm_backward8(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* s
layernorm_backward_kernel8<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
}

template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
void layernorm_backward9(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,
const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
int B, int T, int C, int block_size) {

const int grid_size = (1024/block_size) * cuda_num_SMs; // todo - heuristics for other GPUs?
size_t shared_mem_size = (2 * C + 2 * block_size + 1) * sizeof(float);

cudaMemset(scratch, 0, 1 * sizeof(float)); // just need to memset the flag for this version
layernorm_backward_kernel9<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
}
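
As a rough sizing example (numbers are illustrative, not from the commit): with C = 768, block_size = 512 and a 128-SM GPU, this launches (1024/512) * 128 = 256 blocks with (2*768 + 2*512 + 1) * 4 = 10,244 bytes of dynamic shared memory, i.e. about 10 KB for the dbias/dweight accumulators, the two per-block temporaries and the flag, comfortably under the default 48 KB per-block limit. The scratch buffer must hold 2*C floats of partial sums per block plus the flag's cacheline, which is why the benchmark below allocates (1024/32) * cuda_num_SMs * (2*C + 1) floats to cover the largest grid any block size can produce.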

// kernel version dispatch
void layernorm_backward(int kernel_num,
floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
@@ -982,6 +1173,9 @@ void layernorm_backward(int kernel_num,
case 8:
layernorm_backward8(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);
break;
case 9:
layernorm_backward9(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
@@ -1042,7 +1236,7 @@ int main(int argc, char **argv) {
cudaCheck(cudaMalloc(&d_weight, C * sizeof(floatX)));
cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(floatX)));
cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(floatX)));
cudaCheck(cudaMalloc(&d_scratch, cuda_num_SMs * (2 * C + 1) * sizeof(float)));
cudaCheck(cudaMalloc(&d_scratch, (1024/32) * cuda_num_SMs * (2 * C + 1) * sizeof(float)));
// copy over the "inputs" to the backward call
cudaCheck(memcpy_convert(d_dout, dout, B * T * C));
cudaCheck(memcpy_convert(d_inp, inp, B * T * C));
@@ -1051,7 +1245,8 @@ cudaCheck(memcpy_convert(d_rstd, rstd, B * T));
cudaCheck(memcpy_convert(d_rstd, rstd, B * T));

// launch the kernel
int block_sizes[] = {32, 64, 128, 256, 512, 768, 1024};
// removed 768 because it doesn't work for kernel9 despite being OK in train_gpt2.cu?!
int block_sizes[] = {32, 64, 128, 256, 512, /*768,*/ 1024};
for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
int block_size = block_sizes[j];
// init the "outputs" of the backward call to zeros