From b5e75dde8e8d20b13177b060f3ed364bbe50eb12 Mon Sep 17 00:00:00 2001
From: ademeure
Date: Tue, 21 May 2024 15:57:07 +0100
Subject: [PATCH] Fully deterministic encoder backward kernels for train_gpt2.cu

---
 profile_gpt2.cu |   2 +-
 test_gpt2.cu    |   2 +-
 train_gpt2.cu   | 245 ++++++++++++++++++++++++++++++++++++------------
 3 files changed, 189 insertions(+), 60 deletions(-)

diff --git a/profile_gpt2.cu b/profile_gpt2.cu
index f2ac0e84c..f79e9ada4 100644
--- a/profile_gpt2.cu
+++ b/profile_gpt2.cu
@@ -54,7 +54,7 @@ int main(int argc, char *argv[]) {
     // do a training step
     gpt2_forward(&model, x, y, B, T);
     gpt2_zero_grad(&model);
-    gpt2_backward(&model);
+    gpt2_backward(&model, x);
     gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1.f, 1, &multi_gpu_config);
     cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings

diff --git a/test_gpt2.cu b/test_gpt2.cu
index 50a291f18..d06734507 100644
--- a/test_gpt2.cu
+++ b/test_gpt2.cu
@@ -203,7 +203,7 @@ int main(int argc, char *argv[]) {
         clock_gettime(CLOCK_MONOTONIC, &start);
         gpt2_forward(&model, x, y, B, T);
         gpt2_zero_grad(&model);
-        gpt2_backward(&model);
+        gpt2_backward(&model, x);
         clock_gettime(CLOCK_MONOTONIC, &end);
         double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;

diff --git a/train_gpt2.cu b/train_gpt2.cu
index 1e8b54be2..899293f75 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -38,6 +38,9 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200),
 #include
 #include
 #include
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
 // GPU / CUDA related
 #include
 #include
@@ -532,50 +535,108 @@ __global__ void encoder_forward_kernel3(floatX* out,
     store128(out_btc, packed_out);
 }
 
-template <typename T>
-__device__ void atomicStochasticAdd(T* address, float val0, float val1, unsigned int seed) {
-    static_assert(sizeof(T) == 2, "Only 16-bit atomicStochasticAdd supported.");
-    float2 val = make_float2(val0, val1);
-    unsigned int* address_as_uint = (unsigned int*)address;
-    unsigned int old = *address_as_uint, assumed;
-    unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
-    do {
-        assumed = old;
-        float2 new_fp32 = make_float2((float)(reinterpret_cast<T*>(&old)[0]) + val.x,
-                                      (float)(reinterpret_cast<T*>(&old)[1]) + val.y);
-        T new_rounded[2];
-        stochastic_rounding(new_fp32.x, &new_rounded[0], random);
-        stochastic_rounding(new_fp32.y, &new_rounded[1], random >> 16);
-        old = atomicCAS(address_as_uint, assumed, *(unsigned int*)&new_rounded);
-    } while (assumed != old);
-}
-__device__ void atomicStochasticAdd(float* address, float val0, float val1, unsigned int seed) {
-    atomicAdd(address, val0);
-    atomicAdd(address + 1, val1);
-}
-
-__global__ void encoder_backward_kernel(floatX* dwte, floatX* dwpe,
-                                        const floatX* dout, const int* inp,
-                                        int B, int T, int C, unsigned int seed) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    int N = B * T * C;
-    idx *= 2; // 2 elements per thread
-    if (idx >= N) { return; }
+template <int BLOCK_SIZE=256>
+__global__ void wte_backward_kernel(floatX* dwte,
+                                    const int4* bucket_info, const int* workload_indices, const floatX* dout, const int* inp,
+                                    unsigned int seed, int B, int T, int C) {
+    // In order to be deterministic, we preprocess the inputs on the cpu into "buckets"
+    // Each bucket corresponds to (WARP_SIZE * x128::size) channels for a single vocabulary token
+    // Each thread handles x128::size channels, e.g. 256 per warp for BF16
+    // Each block handles (BLOCK_SIZE / WARP_SIZE) elements in a single bucket in parallel
+    // If a bucket has fewer than 8 elements, some warps will return immediately
+    // If a bucket has more than 8 elements, we will loop over all of them
+    // The buckets are sorted on the CPU so the largest buckets start first
+    int bucket = blockIdx.x;
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int c_per_warp = WARP_SIZE * x128::size;
+
+    int bucket_start_idx = bucket_info[bucket].x;
+    int bucket_size = bucket_info[bucket].y;
+    int bucket_ix = bucket_info[bucket].z;
+    int c = bucket_info[bucket].w * c_per_warp + (lane_id * x128::size);
+
+    // Each thread handles "x128::size" channels, so at fp8, each warp would handle 512 channels
+    // If C is not a multiple of this (e.g. 768), some buckets/c_groups cannot use the entire warp
+    if (c >= C) { return; }
+    // Exit early if this is a small bucket and this warp doesn't have any items to process
+    if (warp_id >= bucket_size) { return; }
+
+    float accum[x128::size] = {0.0f};
+    __shared__ float accum_shared[x128::size * BLOCK_SIZE];
+
+    for(int item = warp_id; item < bucket_size; item += BLOCK_SIZE/WARP_SIZE) {
+        int bt = workload_indices[bucket_start_idx + item];
+        int b = bt / T;
+        int t = bt % T;
+
+        const floatX* dout_btc = dout + b * T * C + t * C + c;
+        x128 packed_inp1 = load128cs(dout_btc);
+        for (int k = 0; k < packed_inp1.size; k++) {
+            accum[k] += (float)packed_inp1[k];
+        }
+    }
 
-    int bt = idx / C;
-    int b = bt / T;
-    int t = bt % T;
-    int c = idx % C;
+    if (warp_id != 0) {
+        // we accumulate into warp 0, so only the other warps need to write to shared memory
+        for (int k = 0; k < x128::size; k++) {
+            accum_shared[threadIdx.x + k * BLOCK_SIZE] = accum[k];
+        }
+        return; // only warp 0 is needed after writing to shared memory
+    }
 
-    int ix = inp[b * T + t];
+    // Read dwte for warp 0 even if other warps are not finished yet to maximise latency tolerance
+    floatX* dwte_ix = dwte + bucket_ix * C + c;
+    x128 packed_in_out = load128(dwte_ix);
 
-    const floatX* dout_btc = dout + b * T * C + t * C + c;
-    floatX* dwte_ix = dwte + ix * C + c;
-    floatX* dwpe_tc = dwpe + t * C + c;
+    // note: threads which have returned are considered synchronised by CUDA so no risk of deadlock
+    __syncthreads();
 
-    float2 dout_data = make_float2(dout_btc[0], dout_btc[1]);
-    atomicStochasticAdd(dwte_ix, dout_data.x, dout_data.y, seed);
-    atomicStochasticAdd(dwpe_tc, dout_data.x, dout_data.y, seed ^ 0xFFFFFFFF);
+    // Accumulate into warp 0's registers by reading the values of the other warps in shared memory
+    for (int i = threadIdx.x+WARP_SIZE; i < min(BLOCK_SIZE, bucket_size*WARP_SIZE); i += WARP_SIZE) {
+        for (int k = 0; k < x128::size; k++) {
+            accum[k] += accum_shared[i + k * BLOCK_SIZE];
+        }
+    }
+
+    // Add the result to dwte and write back to global memory (read-modify-write)
+    for (unsigned int k = 0; k < x128::size; k++) {
+        // We use stochastic rounding to go from FP32 to BF16, but the seed should be deterministic
+        stochastic_rounding(accum[k] + (float)packed_in_out[k], &packed_in_out[k], seed + k);
+    }
+    store128(dwte_ix, packed_in_out);
+}
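Aside (illustration only, not part of the patch): the old atomicAdd/atomicCAS path was nondeterministic because floating-point addition is not associative, so the result depends on the order in which the hardware happens to schedule the atomics; accumulating in FP32 registers and rounding once, as above, fixes the summation order. A minimal host-side sketch of the effect, with made-up values and a hypothetical file name:

    // demo_nonassoc.cpp - float summation order changes the result,
    // which is why unordered atomicAdd accumulation is nondeterministic
    #include <cstdio>

    int main() {
        float a = 1e8f, b = -1e8f, c = 1.0f;
        float left  = (a + b) + c;  // the large terms cancel first -> 1.0
        float right = a + (b + c);  // c is rounded away into b -> 0.0
        printf("(a+b)+c = %f, a+(b+c) = %f\n", left, right);
        return 0;
    }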
+
+__global__ void wpe_backward_kernel(floatX* dwpe,
+                                    const floatX* dout, const int* inp,
+                                    int B, int T, int C, unsigned int seed) {
+    // Each thread handles x128::size "channel positions", e.g. 256 per warp for BF16
+    // For gpt2-124M BF16, C=768 and T=1024, so 3 warps per token position and 3072 warps in total
+    // For each "channel position" we sum the gradients for every batch at that C/T element
+    // This way each dwpe element is only updated once, and the kernel is fully deterministic!
+    // The previous kernel was not deterministic, as batches were aggregated with atomicAdd
+    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
+    if (idx >= T * C) { return; }
+
+    // if C is not a multiple of WARP_SIZE*x128::size, it's OK for some warps to handle multiple t
+    int t = idx / C;
+    int c = idx % C;
+    float accum[x128::size] = {0.0f};
+
+    for (int b = 0; b < B; b++) {
+        x128 packed_dout = load128cs(dout + (b * T * C) + (t * C) + c); // will never be read again
+        for (int k = 0; k < x128::size; k++) {
+            accum[k] += (float)packed_dout[k];
+        }
+    }
+
+    floatX* dwpe_tc = dwpe + (t * C) + c;
+    x128 packed_dwpe = load128(dwpe_tc);
+    for (unsigned int k = 0; k < x128::size; k++) {
+        // We use stochastic rounding to go from FP32 to BF16, but the seed should be deterministic
+        stochastic_rounding(accum[k] + (float)packed_dwpe[k], &packed_dwpe[k], seed + k);
+    }
+    store128(dwpe_tc, packed_dwpe);
 }
 
 __global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd,
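Aside (illustration only): the index arithmetic in wpe_backward_kernel maps a flat thread id to one (t, c) group of x128::size channels. Below is a host-side sanity check of that mapping using the gpt2-124M numbers from the comment above (C=768, T=1024, x128::size=8 for BF16); the constants are hard-coded assumptions for the sketch and the file name is made up:

    // map_check.cpp - host-side check of the wpe_backward_kernel index math
    #include <cstdio>
    #include <initializer_list>

    int main() {
        const int T = 1024, C = 768;  // gpt2-124M
        const int x128_size = 8;      // BF16 elements per 128-bit load
        const int threads = (T * C) / x128_size;  // one thread per x128 group
        // spot-check a few flat thread indices -> (t, c) pairs
        for (int tid : {0, 95, 96, threads - 1}) {
            int idx = tid * x128_size;
            printf("thread %6d -> t=%4d c=%3d..%3d\n",
                   tid, idx / C, idx % C, idx % C + x128_size - 1);
        }
        printf("total: %d threads = %d warps\n", threads, threads / 32);
        return 0;
    }

It reports 3072 warps in total, matching the kernel comment, and shows thread 96 wrapping from (t=0, c=760..767) to (t=1, c=0..7).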
@@ -783,10 +844,9 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons
     // directly autoregressive, so we only compute the lower triangular part
     // uses the online softmax algorithm
     assert(T % 4 == 0);
-    const int warp_size = 32;
-    int lane_id = threadIdx.x % warp_size;
-    int warp_id = threadIdx.x / warp_size;
-    int num_warps = blockDim.x / warp_size;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int num_warps = blockDim.x / WARP_SIZE;
 
     // micro-optimization: we iterate backwards so that
     // after the softmax backward operation completes, the cache retains the
@@ -809,7 +869,7 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons
     float sumval = 0.0f;
 
     const floatX* x_aligned = reinterpret_cast<const floatX*>(__builtin_assume_aligned(x, 16));
-    for (int i = lane_id; i < pos_by_4; i += warp_size) {
+    for (int i = lane_id; i < pos_by_4; i += WARP_SIZE) {
         float regarray[4];
         for (int k = 0; k < 4; ++k) {
             regarray[k] = (float)x_aligned[4*i + k];
@@ -838,7 +898,7 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons
     float norm = 1.f / sum;
 
     // divide the whole row by the sum
-    for (int i = lane_id; i <= own_pos; i += warp_size) {
+    for (int i = lane_id; i <= own_pos; i += WARP_SIZE) {
         // recalculation is faster than doing the round-trip through memory.
         float ev = expf(inv_temperature * ((float)__ldcs(x + i) - global_maxval));
         __stcs(out + idx * T + i, (floatX)(ev * norm));
@@ -1354,14 +1414,70 @@ void encoder_forward(floatX* out,
     cudaCheck(cudaGetLastError());
 }
 
-void encoder_backward(floatX* dwte, floatX* dwpe,
-                      const floatX* dout, const int* inp,
-                      int B, int T, int C, unsigned int seed) {
+// Fully deterministic (see comments in wte_backward_kernel and wpe_backward_kernel for more details)
+void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch
+                      int* workload_indices, int4* bucket_info,    // cpu scratch buffers
+                      const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
+                      int B, int T, int C, unsigned int seed) {
     NVTX_RANGE_FN();
-    const int N = B * T * C;
+
+    // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte)
     const int block_size = 256;
-    const int grid_size = CEIL_DIV(N, block_size * 2); // each thread handles 2 elements
-    encoder_backward_kernel<<<grid_size, block_size>>>(dwte, dwpe, dout, inp, B, T, C, seed);
+    const int N = T * C / x128::size;
+    const int grid_size = CEIL_DIV(N, block_size);
+    wpe_backward_kernel<<<grid_size, block_size>>>(dwpe, dout, inp, B, T, C, seed);
+
+    // check the GPU scratch buffer is large enough to hold the bucket info and workload indices
+    // todo - this is trivially true given hardcoded scratch buffer size here, is this useful?
+    int num_c_groups = CEIL_DIV(C, x128::size * WARP_SIZE);
+    assert(B*T*num_c_groups * (sizeof(int4)+sizeof(int)) <= B*T*3*C * sizeof(floatX));
+
+    // Step 1: Sort inputs into buckets
+    int total_items = 0;
+    std::unordered_map<uint64_t, std::vector<uint64_t>> buckets;
+    for (uint64_t bt = 0; bt < B * T; bt++) {
+        for (uint64_t c_group = 0; c_group < num_c_groups; c_group++) {
+            // todo - passing c_group/inputs_cpu[bt] in data to avoid a second hash lookup is a bit hacky
+            uint64_t data = bt + (c_group<<32ULL) + ((uint64_t)inputs_cpu[bt]<<42ULL);
+            buckets[c_group + num_c_groups * inputs_cpu[bt]].push_back(data);
+            total_items++;
+        }
+    }
+
+    // Step 2: Sort buckets by size in descending order
+    // this is so the largest buckets are processed first by the GPU
+    // otherwise, if they started late, they would still be running with the rest of the GPU idle
+    std::vector<std::pair<uint64_t, std::vector<uint64_t>>> sortedBuckets(buckets.begin(), buckets.end());
+    std::sort(sortedBuckets.begin(), sortedBuckets.end(), // ugly because we don't have a typedef for the std::pair
+              [](const std::pair<uint64_t, std::vector<uint64_t>>& a, const std::pair<uint64_t, std::vector<uint64_t>>& b) {
+                  return a.second.size() > b.second.size();
+              });
+
+    int num_buckets = buckets.size();
+    int bucket_index = 0;
+    int workload_index = 0;
+    for (const auto& bucket : sortedBuckets) {
+        bucket_info[bucket_index].x = workload_index; // bucket start
+        bucket_info[bucket_index].y = bucket.second.size(); // bucket size
+        bucket_info[bucket_index].z = (bucket.second[0] >> 42ULL) & ((1ULL<<20ULL)-1); // bucket ix
+        bucket_info[bucket_index].w = (bucket.second[0] >> 32ULL) & ((1ULL<<10ULL)-1); // bucket c
+
+        for (uint64_t idx : bucket.second) {
+            workload_indices[workload_index++] = (int)(idx & ((1ULL<<31ULL)-1ULL));
+        }
+        bucket_index++;
+    }
+
+    // Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice)
+    // todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely
+    int4* d_bucket_info = (int4*)scratch;
+    int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * sizeof(int4));
+    cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice);
+    cudaMemcpy(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice);
+
+    // Launch wte kernel
+    // todo - profile block sizes on more content (depends on number of buckets and on GPU?)
+    wte_backward_kernel<256><<<num_buckets, 256>>>(dwte, d_bucket_info, d_workload_indices, dout, inp, seed, B, T, C);
     cudaCheck(cudaGetLastError());
 }
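Aside (illustration only): a self-contained toy version of the Step 1 / Step 2 bucketing above, using the same uint64 packing and the same unpack shifts as the patch; B*T, num_c_groups and the token ids are made up for the example, and the file name is hypothetical:

    // bucket_sketch.cpp - toy version of the deterministic wte bucketing
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <unordered_map>
    #include <vector>

    int main() {
        // toy problem: 8 token positions, 2 channel groups, tiny vocab
        const int BT = 8, num_c_groups = 2;
        const int inputs_cpu[BT] = {5, 3, 5, 5, 7, 3, 5, 7}; // token ids (made up)

        // Step 1: group every (position, channel-group) pair by (token, channel-group),
        // packing (bt, c_group, token) into one uint64 exactly like the patch does
        std::unordered_map<uint64_t, std::vector<uint64_t>> buckets;
        for (uint64_t bt = 0; bt < BT; bt++) {
            for (uint64_t c_group = 0; c_group < num_c_groups; c_group++) {
                uint64_t data = bt + (c_group << 32ULL) + ((uint64_t)inputs_cpu[bt] << 42ULL);
                buckets[c_group + num_c_groups * inputs_cpu[bt]].push_back(data);
            }
        }

        // Step 2: largest buckets first, so the longest-running blocks launch first
        std::vector<std::pair<uint64_t, std::vector<uint64_t>>> sorted(buckets.begin(), buckets.end());
        std::sort(sorted.begin(), sorted.end(),
                  [](const auto& a, const auto& b) { return a.second.size() > b.second.size(); });

        for (const auto& bucket : sorted) {
            uint64_t first = bucket.second[0];
            printf("token %2llu c_group %llu: %zu items\n",
                   (unsigned long long)((first >> 42ULL) & ((1ULL << 20ULL) - 1)),
                   (unsigned long long)((first >> 32ULL) & ((1ULL << 10ULL) - 1)),
                   bucket.second.size());
        }
        return 0;
    }

Because every (token, channel-group) pair becomes exactly one bucket, and the real kernel assigns exactly one block per bucket, each dwte location has a single writer and its gradient is summed in a fixed order regardless of GPU scheduling.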
@@ -1947,6 +2063,9 @@ typedef struct {
     unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
     int use_master_weights;
     int recompute;
+    // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
+    int* workload_indices; // encoder_backward, B*T*num_c_groups (int)
+    int4* bucket_info;     // encoder_backward, B*T*num_c_groups (int4) - size for worst case
 } GPT2;
 
 void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
@@ -2022,6 +2141,8 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
     model->inputs = NULL;
     model->targets = NULL;
     model->cpu_losses = NULL;
+    model->workload_indices = NULL;
+    model->bucket_info = NULL;
     model->batch_size = 0;
     model->seq_len = 0;
     model->mean_loss = -1.0f; // -1.0f will designate no loss
@@ -2195,7 +2316,7 @@ void gpt2_zero_grad(GPT2 *model) {
     }
 }
 
-void gpt2_backward(GPT2 *model) {
+void gpt2_backward(GPT2 *model, int* inputs) {
     NVTX_RANGE_FN();
     // double check we forwarded previously, with targets
     if (model->mean_loss == -1.0f) {
@@ -2221,6 +2342,11 @@ void gpt2_backward(GPT2 *model) {
         model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes);
         // init gradients of parameters and activations to zero
         gpt2_zero_grad(model);
+        // initialise cpu scratch buffers for encoder backward
+        size_t num_c_groups = model->config.channels / (WARP_SIZE * x128::size);
+        assert((size_t)(model->batch_size * model->seq_len) * num_c_groups < (1ULL<<31ULL)); // todo - maybe an issue for llama3-400B(?)
+        model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups);
+        model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);
     }
 
     // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
@@ -2241,7 +2367,8 @@ void gpt2_backward(GPT2 *model) {
     cudaCheck(cudaMemset(model->grads_acts.residual3, 0, B * T * C * sizeof(floatX)));
 
     // re-use the output buffer of the forward pass as a scratchpad during backward pass
-    float* scratchF = (float*)acts.output;
+    float* scratchF = (float*)acts.output;
+    floatX* scratchX = (floatX*)acts.output;
 
     // we kick off the chain rule by filling in dlosses with 1.0f/(B*T)
     // this was done in the fused classifier kernel as last step of forward pass
@@ -2323,7 +2450,6 @@
         floatX* buffer_a = l_atty;
         floatX* buffer_b = l_fch;        // this is B x T x 4C, so even larger than what we need
         floatX* dl_preatt = (floatX*)grads_acts.preatt; // dedicated scratchpad allocation
-        floatX* scratchX = (floatX*)acts.output;
         attention_backward(dl_bt4c, buffer_b, dl_preatt, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH);
         #endif
@@ -2332,7 +2458,8 @@
         // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above
         layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);
     }
-    encoder_backward(grads.wte, grads.wpe, dresidual, model->inputs, B, T, C, random_u32(&model->rng_state));
+    encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info,
+                     dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state));
 }
 
 // Compute a mean of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
@@ -2448,6 +2575,8 @@ void gpt2_free(GPT2 *model) {
     cudaCheck(cudaFree(model->inputs));
    cudaCheck(cudaFree(model->targets));
     cudaFreeHost(model->cpu_losses);
+    free(model->workload_indices);
+    free(model->bucket_info);
 }
 
 // ----------------------------------------------------------------------------
@@ -2477,7 +2606,7 @@ void common_free(GPT2 &model) {
     cudaCheck(cudaFree(cublaslt_workspace));
     cublasCheck(cublasDestroy(cublas_handle));
     cublasCheck(cublasLtDestroy(cublaslt_handle));
-    create_cudnn();
+    destroy_cudnn();
 }
 
 #ifndef TESTING
@@ -2880,7 +3009,7 @@ int main(int argc, char *argv[]) {
             gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T, grad_accum_steps);
             lossf += model.mean_loss; // the mean_loss was normalized by grad_accum_steps inside gpt2_forward
             // backward pass. all model params accumulate gradients with += inside this inner loop
-            gpt2_backward(&model);
+            gpt2_backward(&model, train_loader.inputs);
         }
         // override the mean loss, accounting for the gradient accumulation loop
         // this is esp important to do here in multigpu update below, where model.mean_loss gets allreduced
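Aside (illustration only): both new kernels end with stochastic_rounding(accum[k] + current, &packed[k], seed + k), rounding the FP32 accumulator to BF16 exactly once per element, with a seed that is itself produced deterministically from the model's RNG state (random_u32(&model->rng_state) at the call site). The sketch below is a generic FP32-to-BF16 stochastic rounding routine, not llm.c's stochastic_rounding implementation, and the xorshift hash is only a stand-in for whatever RNG the real function uses; it just shows why a fixed seed makes the rounding, and hence the gradients, bit-identical across runs:

    // sr_sketch.cpp - generic FP32 -> BF16 stochastic rounding (NOT llm.c's version)
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    static uint32_t hash_u32(uint32_t x) { // stand-in RNG (xorshift)
        x ^= x << 13; x ^= x >> 17; x ^= x << 5;
        return x;
    }

    // Add a random value in [0, 2^16) to the bits that truncation to BF16 discards;
    // the carry rounds up with probability equal to the discarded fraction, so the
    // FP32 value accumulated before rounding is preserved in expectation.
    static uint16_t stochastic_round_bf16(float v, uint32_t seed) {
        uint32_t bits;
        std::memcpy(&bits, &v, sizeof(bits));
        bits += hash_u32(seed) & 0xFFFFu; // deterministic given the seed
        return (uint16_t)(bits >> 16);
    }

    int main() {
        float v = 0.7f; // not exactly representable in BF16
        int up = 0, trials = 10000;
        for (int k = 0; k < trials; k++) {
            uint32_t wide = (uint32_t)stochastic_round_bf16(v, 1234u + k) << 16;
            float back;
            std::memcpy(&back, &wide, sizeof(back));
            up += (back > v); // same seeds -> same decisions, every run
        }
        printf("rounded up in %d of %d trials\n", up, trials);
        return 0;
    }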