From 3b5933ecfb9dca85e5663effdb662092aac11a7f Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 14 May 2024 22:32:24 +0300 Subject: [PATCH 01/17] considerably speed up CPU matmul while still keeping it relatively readable --- train_gpt2.c | 70 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 13 deletions(-) diff --git a/train_gpt2.c b/train_gpt2.c index 9706a2c0b..06cdfbb54 100644 --- a/train_gpt2.c +++ b/train_gpt2.c @@ -158,32 +158,76 @@ void layernorm_backward(float* dinp, float* dweight, float* dbias, } } +void matmul_forward_slow(float* out, + const float* inp, const float* weight, const float* bias, + int B, int T, int C, int OC) { + // basic implementation of matrix multiplication. This serves as a fallback + // for bad input shapes, and as an illustration for the most basic version + // of the algorithm. +#pragma omp parallel for collapse(2) + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + int bt = b * T + t; + for (int o = 0; o < OC; o++) { + float val = (bias != NULL) ? bias[o] : 0.0f; + for (int i = 0; i < C; i++) { + val += inp[bt * C + i] * weight[o*C + i]; + } + out[bt * OC + o] = val; + } + } + } +} + void matmul_forward(float* out, - float* inp, float* weight, float* bias, + const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC) { // most of the running time is spent here and in matmul_backward // OC is short for "output channels" // inp is (B,T,C), weight is (OC, C), bias is (OC) // out will be (B,T,OC) - #pragma omp parallel for collapse(2) - for (int b = 0; b < B; b++) { - for (int t = 0; t < T; t++) { - float* out_bt = out + b * T * OC + t * OC; - float* inp_bt = inp + b * T * C + t * C; - for (int o = 0; o < OC; o++) { - float val = (bias != NULL) ? bias[o] : 0.0f; - float* wrow = weight + o*C; - for (int i = 0; i < C; i++) { - val += inp_bt[i] * wrow[i]; + + // make sure the tiled loop will be correct, otherwise, fallback to slow version + const int LOOP_UNROLL = 8; + if (B*T % LOOP_UNROLL != 0) { + matmul_forward_slow(out, inp, weight, bias, B, T, C, OC); + return; + } + + // collapse the B and T loops into one and turn it into a strided loop. + // then we can tile the inner loop, and reuse the loaded weight LOOP_UNROLL many times + // for significant speed-ups. + #pragma omp parallel for + for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) { + for (int o = 0; o < OC; o++) { + // keep LOOP_UNROLL many results in register, initialized by the bias term. + float result[LOOP_UNROLL]; + for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) { + result[ibt] = (bias != NULL) ? bias[o] : 0.0f; + } + + // inner loops. Because we do LOOP_UNROLL steps of inner bt, we can cache + // the value of weight[i + o * C] and reuse it. + // we compile with -Ofast, so the compiler will turn the inner loop into a bunch of FMAs + for (int i = 0; i < C; i++) { + float w = weight[i + o * C]; + for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) { + int bt = obt + ibt; + result[ibt] += inp[bt * C + i] * w; } - out_bt[o] = val; + } + + // write back results to main memory + for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) { + int bt = obt + ibt; + out[bt * OC + o] = result[ibt]; } } } } void matmul_backward(float* dinp, float* dweight, float* dbias, - float* dout, float* inp, float* weight, + const float* dout, const float* inp, const float* weight, int B, int T, int C, int OC) { // most of the running time is spent here and in matmul_forward // this backward could be done in a single "round" of loops From b2a5508b84a0db561e371ef0092050e33c245a29 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 14 May 2024 22:48:59 +0300 Subject: [PATCH 02/17] constness fixes --- train_gpt2.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_gpt2.c b/train_gpt2.c index 06cdfbb54..0c5583e5e 100644 --- a/train_gpt2.c +++ b/train_gpt2.c @@ -237,10 +237,10 @@ void matmul_backward(float* dinp, float* dweight, float* dbias, #pragma omp parallel for collapse(2) for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { - float* dout_bt = dout + b * T * OC + t * OC; + const float* dout_bt = dout + b * T * OC + t * OC; float* dinp_bt = dinp + b * T * C + t * C; for (int o = 0; o < OC; o++) { - float* wrow = weight + o*C; + const float* wrow = weight + o*C; float d = dout_bt[o]; for (int i = 0; i < C; i++) { dinp_bt[i] += wrow[i] * d; @@ -253,8 +253,8 @@ void matmul_backward(float* dinp, float* dweight, float* dbias, for (int o = 0; o < OC; o++) { for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { - float* dout_bt = dout + b * T * OC + t * OC; - float* inp_bt = inp + b * T * C + t * C; + const float* dout_bt = dout + b * T * OC + t * OC; + const float* inp_bt = inp + b * T * C + t * C; float* dwrow = dweight + o*C; float d = dout_bt[o]; if (dbias != NULL) { dbias[o] += d; } From 6348d4196d6857244d7833988c405e44afe578d7 Mon Sep 17 00:00:00 2001 From: lancer Date: Sun, 19 May 2024 17:39:25 -0700 Subject: [PATCH 03/17] fix the unsupported block_size --- dev/cuda/matmul_backward_bias.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu index 12b167083..52d793ac7 100644 --- a/dev/cuda/matmul_backward_bias.cu +++ b/dev/cuda/matmul_backward_bias.cu @@ -421,6 +421,9 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s // version1: simple cuBLAS calls void matmul_backward_bias1(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { + if (block_size == 768) { + block_size = 1024; // block_size needs to be power of 2 due to the reduction + } dim3 block_dim(block_size); dim3 grid_dim(OC); size_t shared_mem_size = block_size * sizeof(float); From 2b0667aee15151622797d6bc209eec8f4742f3a7 Mon Sep 17 00:00:00 2001 From: lancer Date: Mon, 20 May 2024 08:00:39 -0700 Subject: [PATCH 04/17] update the utils function and assert --- dev/cuda/matmul_backward_bias.cu | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu index 52d793ac7..16172bcf2 100644 --- a/dev/cuda/matmul_backward_bias.cu +++ b/dev/cuda/matmul_backward_bias.cu @@ -27,6 +27,26 @@ sudo ncu --set full --import-source yes -o bias -f ./matmul_backward_bias 1 #define ENABLE_BF16 #include "common.h" + +// ---------------------------------------------------------------------------- +// utility functions +__host__ __device__ bool isPowerOfTwo(int n) { + return (n > 0) && ((n & (n - 1)) == 0); +} + +__host__ __device__ int largestPowerOfTwoLessOrEqual(int n) { + // Return the largest power of 2 less than or equal to n + if (n < 1) { + return 0; + } + + while ((n & (n - 1)) > 0) { + n = n & (n - 1); + } + + return n; +} + // ---------------------------------------------------------------------------- // CPU code reference @@ -421,9 +441,8 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s // version1: simple cuBLAS calls void matmul_backward_bias1(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { - if (block_size == 768) { - block_size = 1024; // block_size needs to be power of 2 due to the reduction - } + block_size = largestPowerOfTwoLessOrEqual(block_size); + assert(isPowerOfTwo(block_size)); // block_size needs to be power of 2 due to the reduction dim3 block_dim(block_size); dim3 grid_dim(OC); size_t shared_mem_size = block_size * sizeof(float); From b5e75dde8e8d20b13177b060f3ed364bbe50eb12 Mon Sep 17 00:00:00 2001 From: ademeure Date: Tue, 21 May 2024 15:57:07 +0100 Subject: [PATCH 05/17] Fully deterministic encoder backward kernels for train_gpt2.cu --- profile_gpt2.cu | 2 +- test_gpt2.cu | 2 +- train_gpt2.cu | 245 ++++++++++++++++++++++++++++++++++++------------ 3 files changed, 189 insertions(+), 60 deletions(-) diff --git a/profile_gpt2.cu b/profile_gpt2.cu index f2ac0e84c..f79e9ada4 100644 --- a/profile_gpt2.cu +++ b/profile_gpt2.cu @@ -54,7 +54,7 @@ int main(int argc, char *argv[]) { // do a training step gpt2_forward(&model, x, y, B, T); gpt2_zero_grad(&model); - gpt2_backward(&model); + gpt2_backward(&model, x); gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1.f, 1, &multi_gpu_config); cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings diff --git a/test_gpt2.cu b/test_gpt2.cu index 50a291f18..d06734507 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -203,7 +203,7 @@ int main(int argc, char *argv[]) { clock_gettime(CLOCK_MONOTONIC, &start); gpt2_forward(&model, x, y, B, T); gpt2_zero_grad(&model); - gpt2_backward(&model); + gpt2_backward(&model, x); clock_gettime(CLOCK_MONOTONIC, &end); double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; diff --git a/train_gpt2.cu b/train_gpt2.cu index 1e8b54be2..899293f75 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -38,6 +38,9 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200), #include #include #include +#include +#include +#include // GPU / CUDA related #include #include @@ -532,50 +535,108 @@ __global__ void encoder_forward_kernel3(floatX* out, store128(out_btc, packed_out); } -template -__device__ void atomicStochasticAdd(T* address, float val0, float val1, unsigned int seed) { - static_assert(sizeof(T) == 2, "Only 16-bit atomicStochasticAdd supported."); - float2 val = make_float2(val0, val1); - unsigned int* address_as_uint = (unsigned int*)address; - unsigned int old = *address_as_uint, assumed; - unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed); - do { - assumed = old; - float2 new_fp32 = make_float2((float)(reinterpret_cast(&old)[0]) + val.x, - (float)(reinterpret_cast(&old)[1]) + val.y); - T new_rounded[2]; - stochastic_rounding(new_fp32.x, &new_rounded[0], random); - stochastic_rounding(new_fp32.y, &new_rounded[1], random >> 16); - old = atomicCAS(address_as_uint, assumed, *(unsigned int*)&new_rounded); - } while (assumed != old); -} -__device__ void atomicStochasticAdd(float* address, float val0, float val1, unsigned int seed) { - atomicAdd(address, val0); - atomicAdd(address + 1, val1); -} - -__global__ void encoder_backward_kernel(floatX* dwte, floatX* dwpe, - const floatX* dout, const int* inp, - int B, int T, int C, unsigned int seed) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int N = B * T * C; - idx *= 2; // 2 elements per thread - if (idx >= N) { return; } +template +__global__ void wte_backward_kernel(floatX* dwte, + const int4* bucket_info, const int* workload_indices, const floatX* dout, const int* inp, + unsigned int seed, int B, int T, int C) { + // In order to be deterministic, we preprocess the inputs on the cpu into "buckets" + // Each bucket corresponds to (WARP_SIZE * x128::size) channels for a single vocabulary token + // Each thread handles x128::size channels, e.g. 256 per warp for BF16 + // Each block handles (BLOCK_SIZE / WARP_SIZE) elements in a single bucket in parallel + // If a bucket has less than 8 elements, some warps will return immediately + // If a bucket has more than 8 elements, we will loop over all of them + // The buckets are sorted on the CPU so the largest buckets start 1st + int bucket = blockIdx.x; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int c_per_warp = WARP_SIZE * x128::size; + + int bucket_start_idx = bucket_info[bucket].x; + int bucket_size = bucket_info[bucket].y; + int bucket_ix = bucket_info[bucket].z; + int c = bucket_info[bucket].w * c_per_warp + (lane_id * x128::size); + + // Each thread handles "x128::size" channels, so at fp8, each warp would handle 512 channels + // If C is not a multiple of this (e.g. 768), some buckets/c_groups cannot use the entire warp + if (c >= C) { return; } + // Exit early if this is a small bucket and this warp doesn't have any items to process + if (warp_id >= bucket_size) { return; } + + float accum[x128::size] = {0.0f}; + __shared__ float accum_shared[x128::size * BLOCK_SIZE]; + + for(int item = warp_id; item < bucket_size; item += BLOCK_SIZE/WARP_SIZE) { + int bt = workload_indices[bucket_start_idx + item]; + int b = bt / T; + int t = bt % T; + + const floatX* dout_btc = dout + b * T * C + t * C + c; + x128 packed_inp1 = load128cs(dout_btc); + for (int k = 0; k < packed_inp1.size; k++) { + accum[k] += (float)packed_inp1[k]; + } + } - int bt = idx / C; - int b = bt / T; - int t = bt % T; - int c = idx % C; + if (warp_id != 0) { + // we accumulate into warp 0, so only the other warps need to write to shared memory + for (int k = 0; k < x128::size; k++) { + accum_shared[threadIdx.x + k * BLOCK_SIZE] = accum[k]; + } + return; // only warp 0 is needed after writing to shared memory + } - int ix = inp[b * T + t]; + // Read dwte for warp 0 even if other warps are not finished yet to maximise latency tolerance + floatX* dwte_ix = dwte + bucket_ix * C + c; + x128 packed_in_out = load128(dwte_ix); - const floatX* dout_btc = dout + b * T * C + t * C + c; - floatX* dwte_ix = dwte + ix * C + c; - floatX* dwpe_tc = dwpe + t * C + c; + // note: threads which have returned are considered synchronised by CUDA so no risk of deadlock + __syncthreads(); - float2 dout_data = make_float2(dout_btc[0], dout_btc[1]); - atomicStochasticAdd(dwte_ix, dout_data.x, dout_data.y, seed); - atomicStochasticAdd(dwpe_tc, dout_data.x, dout_data.y, seed ^ 0xFFFFFFFF); + // Accumulate into warp 0's registers by reading the values of the other warps in shared memory + for (int i = threadIdx.x+WARP_SIZE; i < min(BLOCK_SIZE, bucket_size*WARP_SIZE); i += WARP_SIZE) { + for (int k = 0; k < x128::size; k++) { + accum[k] += accum_shared[i + k * BLOCK_SIZE]; + } + } + + // Add the result to dwte and write back to global memory (read-modify-write) + for (unsigned int k = 0; k < x128::size; k++) { + // We use stochastic rounding to go from FP32 to BF16 but the seed should be deterministic + stochastic_rounding(accum[k] + (float)packed_in_out[k], &packed_in_out[k], seed + k); + } + store128(dwte_ix, packed_in_out); +} + +__global__ void wpe_backward_kernel(floatX* dwpe, + const floatX* dout, const int* inp, + int B, int T, int C, unsigned int seed) { + // Each thread handles x128::size "channel positions", e.g. 256 per warp for BF16 + // For gpt2-124M BF16, C=768 and T=1024, so 3 warps per channel and 3072 warps in total + // For each "channel position" we sum the gradients for every batch at that C/T element + // This way each dwte element is only updated once, and the kernel is fully deterministic! + // The previous kernel was not deterministic, as batches were aggregated with atomicAdd + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; + if (idx >= T * C) { return; } + + // if C is not a multiple of WARP_SIZE*x128::size, it's OK for some warps to handle multiple t + int t = idx / C; + int c = idx % C; + float accum[x128::size] = {0.0f}; + + for (int b = 0; b < B; b++) { + x128 packed_dout = load128cs(dout + (b * T * C) + (t * C) + c); // will never be read again + for (int k = 0; k < x128::size; k++) { + accum[k] += (float)packed_dout[k]; + } + } + + floatX* dwpe_tc = dwpe + (t * C) + c; + x128 packed_dwpe = load128(dwpe_tc); + for (unsigned int k = 0; k < x128::size; k++) { + // We use stochastic rounding to go from FP32 to BF16 but the seed should be deterministic + stochastic_rounding(accum[k] + (float)packed_dwpe[k], &packed_dwpe[k], seed + k); + } + store128(dwpe_tc, packed_dwpe); } __global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd, @@ -783,10 +844,9 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons // directly autoregressive, so we only compute the lower triangular part // uses the online softmax algorithm assert(T % 4 == 0); - const int warp_size = 32; - int lane_id = threadIdx.x % warp_size; - int warp_id = threadIdx.x / warp_size; - int num_warps = blockDim.x / warp_size; + int lane_id = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int num_warps = blockDim.x / WARP_SIZE; // micro-optimization: we iterate backwards so that // after the softmax backward operation completes, the cache retains the @@ -809,7 +869,7 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons float sumval = 0.0f; const floatX* x_aligned = reinterpret_cast(__builtin_assume_aligned(x, 16)); - for (int i = lane_id; i < pos_by_4; i += warp_size) { + for (int i = lane_id; i < pos_by_4; i += WARP_SIZE) { float regarray[4]; for (int k = 0; k < 4; ++k) { regarray[k] = (float)x_aligned[4*i + k]; @@ -838,7 +898,7 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons float norm = 1.f / sum; // divide the whole row by the sum - for (int i = lane_id; i <= own_pos; i += warp_size) { + for (int i = lane_id; i <= own_pos; i += WARP_SIZE) { // recalculation is faster than doing the round-trip through memory. float ev = expf(inv_temperature * ((float)__ldcs(x + i) - global_maxval)); __stcs(out + idx * T + i, (floatX)(ev * norm)); @@ -1354,14 +1414,70 @@ void encoder_forward(floatX* out, cudaCheck(cudaGetLastError()); } -void encoder_backward(floatX* dwte, floatX* dwpe, - const floatX* dout, const int* inp, - int B, int T, int C, unsigned int seed) { +// Fully deterministic (see comments in wte_backward_kernel and wpe_backward_kernel for more details) +void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch + int* workload_indices, int4* bucket_info, // cpu scratch buffers + const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs + int B, int T, int C, unsigned int seed) { NVTX_RANGE_FN(); - const int N = B * T * C; + + // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte) const int block_size = 256; - const int grid_size = CEIL_DIV(N, block_size * 2); // each thread handles 2 elements - encoder_backward_kernel<<>>(dwte, dwpe, dout, inp, B, T, C, seed); + const int N = T * C / x128::size; + const int grid_size = CEIL_DIV(N, block_size); + wpe_backward_kernel<<>>(dwpe, dout, inp, B, T, C, seed); + + // check the GPU scratch buffer is large enough to hold the bucket info and workload indices + // todo - this is trivially true given hardcoded scratch buffer size here, is this useful? + int num_c_groups = CEIL_DIV(C, x128::size * WARP_SIZE); + assert(B*T*num_c_groups * (sizeof(int4)+sizeof(int)) <= B*T*3*C * sizeof(floatX)); + + // Step 1: Sort inputs into buckets + int total_items = 0; + std::unordered_map> buckets; + for (uint64_t bt = 0; bt < B * T; bt++) { + for (uint64_t c_group = 0; c_group < num_c_groups; c_group++) { + // todo - passing c_group/inputs_cpu[bt] in data to avoid a second hash lookup is a bit hacky + uint64_t data = bt + (c_group<<32ULL) + ((uint64_t)inputs_cpu[bt]<<42ULL); + buckets[c_group + num_c_groups * inputs_cpu[bt]].push_back(data); + total_items++; + } + } + + // Step 2: Sort buckets by size in descending order + // this is so the largest buckets are processed first by the GPU + // otherwise, if they started late, they would still be running with the rest of the GPU idle + std::vector>> sortedBuckets(buckets.begin(), buckets.end()); + std::sort(sortedBuckets.begin(), sortedBuckets.end(), // ugly because we don't have a typedef for the std::pair + [](const std::pair>& a, const std::pair>& b) { + return a.second.size() > b.second.size(); + }); + + int num_buckets = buckets.size(); + int bucket_index = 0; + int workload_index = 0; + for (const auto& bucket : sortedBuckets) { + bucket_info[bucket_index].x = workload_index; // bucket start + bucket_info[bucket_index].y = bucket.second.size(); // bucket size + bucket_info[bucket_index].z = (bucket.second[0] >> 42ULL) & ((1ULL<<20ULL)-1); // bucket ix + bucket_info[bucket_index].w = (bucket.second[0] >> 32ULL) & ((1ULL<<10ULL)-1); // bucket c + + for (uint64_t idx : bucket.second) { + workload_indices[workload_index++] = (int)(idx & ((1ULL<<31ULL)-1ULL)); + } + bucket_index++; + } + + // Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice) + // todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely + int4* d_bucket_info = (int4*)scratch; + int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * sizeof(int4)); + cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice); + cudaMemcpy(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice); + + // Launch wte kernel + // todo - profile block sizes on more content (depends on number of buckets and on GPU?) + wte_backward_kernel<256><<>>(dwte, d_bucket_info, d_workload_indices, dout, inp, seed, B, T, C); cudaCheck(cudaGetLastError()); } @@ -1947,6 +2063,9 @@ typedef struct { unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. int use_master_weights; int recompute; + // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch? + int* workload_indices; // encoder_backward, B*T*num_c_groups (int) + int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case } GPT2; void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { @@ -2022,6 +2141,8 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { model->inputs = NULL; model->targets = NULL; model->cpu_losses = NULL; + model->workload_indices = NULL; + model->bucket_info = NULL; model->batch_size = 0; model->seq_len = 0; model->mean_loss = -1.0f; // -1.0f will designate no loss @@ -2195,7 +2316,7 @@ void gpt2_zero_grad(GPT2 *model) { } } -void gpt2_backward(GPT2 *model) { +void gpt2_backward(GPT2 *model, int* inputs) { NVTX_RANGE_FN(); // double check we forwarded previously, with targets if (model->mean_loss == -1.0f) { @@ -2221,6 +2342,11 @@ void gpt2_backward(GPT2 *model) { model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes); // init gradients of parameters and activations to zero gpt2_zero_grad(model); + // initialise cpu scratch buffers for encoder backward + size_t num_c_groups = model->config.channels / (WARP_SIZE * x128::size); + assert((size_t)(model->batch_size * model->seq_len) * num_c_groups < (1ULL<<31ULL)); // todo - maybe an issue for llama3-400B(?) + model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups); + model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups); } // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow @@ -2241,7 +2367,8 @@ void gpt2_backward(GPT2 *model) { cudaCheck(cudaMemset(model->grads_acts.residual3, 0, B * T * C * sizeof(floatX))); // re-use the output buffer of the forward pass as a scratchpad during backward pass - float* scratchF = (float*)acts.output; + float* scratchF = (float*)acts.output; + floatX* scratchX = (floatX*)acts.output; // we kick off the chain rule by filling in dlosses with 1.0f/(B*T) // this was done in the fused classifier kernel as last step of forward pass @@ -2323,7 +2450,6 @@ void gpt2_backward(GPT2 *model) { floatX* buffer_a = l_atty; floatX* buffer_b = l_fch; // this is B x T x 4C, so even larger than what we need floatX* dl_preatt = (floatX*)grads_acts.preatt; // dedicated scratchpad allocation - floatX* scratchX = (floatX*)acts.output; attention_backward(dl_bt4c, buffer_b, dl_preatt, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH); #endif @@ -2332,7 +2458,8 @@ void gpt2_backward(GPT2 *model) { // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C); } - encoder_backward(grads.wte, grads.wpe, dresidual, model->inputs, B, T, C, random_u32(&model->rng_state)); + encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info, + dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state)); } // Compute a mean of a single CPU value across all GPU processes. No-op when multi-GPU is disabled. @@ -2448,6 +2575,8 @@ void gpt2_free(GPT2 *model) { cudaCheck(cudaFree(model->inputs)); cudaCheck(cudaFree(model->targets)); cudaFreeHost(model->cpu_losses); + free(model->workload_indices); + free(model->bucket_info); } // ---------------------------------------------------------------------------- @@ -2477,7 +2606,7 @@ void common_free(GPT2 &model) { cudaCheck(cudaFree(cublaslt_workspace)); cublasCheck(cublasDestroy(cublas_handle)); cublasCheck(cublasLtDestroy(cublaslt_handle)); - create_cudnn(); + destroy_cudnn(); } #ifndef TESTING @@ -2880,7 +3009,7 @@ int main(int argc, char *argv[]) { gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T, grad_accum_steps); lossf += model.mean_loss; // the mean_loss was normalized by grad_accum_steps inside gpt2_forward // backward pass. all model params accumulate gradients with += inside this inner loop - gpt2_backward(&model); + gpt2_backward(&model, train_loader.inputs); } // override the mean loss, accounting for the gradient accumulation loop // this is esp important to do here in multigpu update below, where model.mean_loss gets allreduced From a3801f01efae3434d6e4cdbef3dd455fcc10404f Mon Sep 17 00:00:00 2001 From: ademeure Date: Tue, 21 May 2024 16:53:11 +0100 Subject: [PATCH 06/17] added algorithm header for std::sort on windows (not sure about compile time impact...) --- train_gpt2.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/train_gpt2.cu b/train_gpt2.cu index 899293f75..16f8a4216 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -39,6 +39,7 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200), #include #include #include +#include #include #include // GPU / CUDA related From 7d0891f6ddebebeefc8a9a5c3f319484aa31f1d5 Mon Sep 17 00:00:00 2001 From: ademeure Date: Tue, 21 May 2024 22:37:09 +0100 Subject: [PATCH 07/17] Fully deterministic layernorm (slight perf loss) --- train_gpt2.cu | 110 +++++++++++++++++++++++++++++++++++++------------- 1 file changed, 82 insertions(+), 28 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 1e8b54be2..6c60b8a74 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -980,30 +980,34 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s } } -__global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with only 1024 threads? - layernorm_backward_kernel8(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, +__global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? + layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd, int B, int T, int C) { + constexpr int BLOCK_SIZE = 512; + constexpr int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block extern __shared__ float shared[]; // size = 2 * C + 1 + int warpId = threadIdx.x / WARP_SIZE; // warp index within a block - int warpsInBlock = blockDim.x / WARP_SIZE; //number of warps in block int baseIdx = blockIdx.x * warpsInBlock + warpId; int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp int warpsInGrid = gridDim.x * warpsInBlock; int C_per_iteration = WARP_SIZE * x128::size; - int iterations_C = C / C_per_iteration; + int iterations_C = CEIL_DIV(C, C_per_iteration); // the first half of shared memory is bias, second is weight float* dbias_shared = shared; float* dweight_shared = shared + C; + float* dbias_tmp_shared = shared + 2 * C; + float* dweight_tmp_shared = shared + 2 * C + BLOCK_SIZE; // init shared memory to zero - for(int i = threadIdx.x; i < C; i+= blockDim.x){ + for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE){ dbias_shared[i] = 0.0f; dweight_shared[i] = 0.0f; } - unsigned int *tmp_flag = (unsigned int*)(shared + C*2); + unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*BLOCK_SIZE); __syncthreads(); for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) { @@ -1041,6 +1045,10 @@ __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with for (int i = 0; i < iterations_C; i++) { int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration); int shared_index = warpThreadIdx + (i * C_per_iteration); + if (global_index >= C) { + break; + } + x128 dout128 = load128cs(dout_bt + global_index); x128 inp128 = load128cs(inp_bt + global_index); x128 dinp128 = load128(dinp_bt + global_index); @@ -1050,10 +1058,29 @@ __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with float dout_i = (float)dout128[x]; float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt; float dnorm_i = (float)weight128[x] * dout_i; - // gradient contribution to bias (using shared memory friendly index) - atomicAdd(&dbias_shared[shared_index + x*WARP_SIZE], dout_i); - // gradient contribution to weight (using shared memory friendly index) - atomicAdd(&dweight_shared[shared_index + x*WARP_SIZE], norm_bti * dout_i); + + // sum up the gradients for bias and weight across the entire block + // this is basically a reduction (but only inter-warp, not intra-warp) + // doing it this way allows us to avoid using atomics while using many warps + if (warpId != 0) { + dbias_tmp_shared[threadIdx.x] = dout_i; + dweight_tmp_shared[threadIdx.x] = norm_bti * dout_i; + } + __syncthreads(); + if (warpId == 0) { + float dbias_tmp = dout_i; + float dweight_tmp = norm_bti * dout_i; + for (int j = 1; j < warpsInBlock; j++) { + dbias_tmp += dbias_tmp_shared[threadIdx.x + j * WARP_SIZE]; + dweight_tmp += dweight_tmp_shared[threadIdx.x + j * WARP_SIZE]; + } + // gradient contribution to bias (using shared memory friendly index) + dbias_shared[shared_index + x*WARP_SIZE] += dbias_tmp; + // gradient contribution to weight (using shared memory friendly index) + dweight_shared[shared_index + x*WARP_SIZE] += dweight_tmp; + } + __syncthreads(); + // gradient contribution to input float dval = 0.0f; dval += dnorm_i; // term 1 @@ -1066,35 +1093,64 @@ __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with store128cg(dinp_bt + global_index, dinp128); } } - // Accumulate into a FP32 scratchpad - // BF16 atomics are potentially much slower... and this is more precise! - // todo - could potentially avoid the extra copy if floatX is FP32, fairly negligible though __syncthreads(); + // Each block writes its partial sum to global memory + // The last block to finish becomes responsible for summing up all the partial sums + // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel) + unsigned int* scratchFlag = (unsigned int*)(scratch); + // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned + scratch += 32; float* scratch_dbias = scratch; float* scratch_dweight = scratch + C; - unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C)); - for(int i = threadIdx.x; i < C; i+= blockDim.x) { - // global atomics in the same "shared memory banking friendly" order - atomicAdd(&scratch_dbias[i], dbias_shared[i]); - atomicAdd(&scratch_dweight[i], dweight_shared[i]); + for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE) { + // Write to global memory in the same "shared memory banking friendly" order + scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i]; + scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i]; } + __syncthreads(); if (threadIdx.x == 0) { *tmp_flag = atomicInc(scratchFlag, gridDim.x); } __syncthreads(); if (*tmp_flag == gridDim.x-1) { + // Reduction of the partial sums by the final block + // todo - there isn't enough parallelism even inside that single SM... + // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?! + for(int i = threadIdx.x * f128::size; i < C; i+= BLOCK_SIZE * f128::size) { + f128 dbias_accum(make_int4(0, 0, 0, 0)); + f128 dweight_accum(make_int4(0, 0, 0, 0)); + + for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) { + int offset = i + 2*C*read_block_idx; + f128 dbias128 = load128(scratch_dbias + offset); + f128 dweight128 = load128(scratch_dweight + offset); + for(int k = 0; k < f128::size; k++) { + dbias_accum[k] += dbias128[k]; + dweight_accum[k] += dweight128[k]; + } + } + store128(dbias_shared + i, dbias_accum); + store128(dweight_shared + i, dweight_accum); + } + __syncthreads(); + + // reorder from atomic/shared memory-friendly index to real global memory index + // and convert from float/FP32 to floatX/BF16 for the final write + // this is separate also because it cannot use as many warps as the above (f128 vs x128) + // todo - if we split this code into another kernel, we could maybe do it at the same time? for (int i = warpId; i < iterations_C; i += warpsInBlock) { - // reorder from atomic/shared memory-friendly index to real global memory index - // and convert from float/FP32 to floatX/BF16 for the final write int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration); int shared_index = warpThreadIdx + (i * C_per_iteration); + if (global_index >= C) { + break; + } x128 dbias128 = load128(dbias + global_index); x128 dweight128 = load128(dweight + global_index); for (int x = 0; x < x128::size; x++) { - float s_db = scratch_dbias[shared_index + x*WARP_SIZE]; - float s_dw = scratch_dweight[shared_index + x*WARP_SIZE]; + float s_db = dbias_shared[shared_index + x*WARP_SIZE]; + float s_dw = dweight_shared[shared_index + x*WARP_SIZE]; dbias128[x] = (floatX)(s_db + (float)dbias128[x]); dweight128[x] = (floatX)(s_dw + (float)dweight128[x]); } @@ -1603,15 +1659,13 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd, int B, int T, int C) { NVTX_RANGE_FN(); - // todo - forcing 3 x 512 threads per SM maximum is a bit hacky, but more than that results in - // cache thrashing and lower performance on A100... is there a better way? const int block_size = 512; - const int blocks_per_sm = min(3, (deviceProp.maxThreadsPerMultiProcessor / 1024)); + const int blocks_per_sm = 2; // supported on every architecture and less cache thrashing than 3 const int grid_size = blocks_per_sm * deviceProp.multiProcessorCount; - size_t shared_mem_size = (2 * C + 1) * sizeof(float); + size_t shared_mem_size = (2*C + 2*block_size + 1) * sizeof(float); // see kernel - cudaMemset(scratch, 0, (2 * C + 1) * sizeof(float)); - layernorm_backward_kernel8<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); + cudaMemset(scratch, 0, 1 * sizeof(float)); // only need to reset the flag to 0 + layernorm_backward_kernel9<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); cudaCheck(cudaGetLastError()); } From 7cbeefc7f371412bbaca2990abbf5873bb8547ae Mon Sep 17 00:00:00 2001 From: ademeure Date: Tue, 21 May 2024 23:26:54 +0100 Subject: [PATCH 08/17] added new layernorm backward to /dev/cuda/ --- dev/cuda/layernorm_backward.cu | 199 ++++++++++++++++++++++++++++++++- train_gpt2.cu | 20 ++-- 2 files changed, 206 insertions(+), 13 deletions(-) diff --git a/dev/cuda/layernorm_backward.cu b/dev/cuda/layernorm_backward.cu index 90dcb1674..d9502880b 100644 --- a/dev/cuda/layernorm_backward.cu +++ b/dev/cuda/layernorm_backward.cu @@ -856,6 +856,185 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) } } +__global__ void layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, + const floatX* dout, const floatX* inp, const floatX* weight, + const floatX* mean, const floatX* rstd, + int B, int T, int C) { + constexpr int WARP_SIZE = 32; + int BLOCK_SIZE = blockDim.x; + int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block + extern __shared__ float shared[]; // size = 2 * C + 1 + + int warpId = threadIdx.x / WARP_SIZE; // warp index within a block + int baseIdx = blockIdx.x * warpsInBlock + warpId; + int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp + int warpsInGrid = gridDim.x * warpsInBlock; + int C_per_iteration = WARP_SIZE * x128::size; + int iterations_C = ceil_div(C, C_per_iteration) + 2; + + // the first half of shared memory is bias, second is weight + float* dbias_shared = shared; + float* dweight_shared = shared + C; + float* dbias_tmp_shared = shared + 2 * C; + float* dweight_tmp_shared = shared + 2 * C + BLOCK_SIZE; + + // init shared memory to zero + for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE){ + dbias_shared[i] = 0.0f; + dweight_shared[i] = 0.0f; + } + unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*BLOCK_SIZE); + __syncthreads(); + + for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) { + int b = idx / T; + int t = idx % T; + + const floatX* dout_bt = dout + b * T * C + t * C; + const floatX* inp_bt = inp + b * T * C + t * C; + floatX* dinp_bt = dinp + b * T * C + t * C; + const float mean_bt = (float)mean[b * T + t]; + const float rstd_bt = (float)rstd[b * T + t]; + + // first: two reduce operations + float dnorm_mean = 0.0f; + float dnorm_norm_mean = 0.0f; + for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) { + x128 dout128_i = load128(dout_bt + i); + x128 inp128_i = load128(inp_bt + i); + x128 weight128_i = load128(weight + i); + for (int k = 0; k < x128::size; k++) { + float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt; + float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k]; + dnorm_mean += dnorm_i; + dnorm_norm_mean += dnorm_i * norm_bti; + } + } + dnorm_mean = warpReduceSum(dnorm_mean) / C; + dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C; + + // now iterate again and accumulate all the gradients + // unfortunately we cannot use the same index for x128 arrays and shared memory + // as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper) + // so this would result in an 8-way bank conflict, and kill performance + // so instead, we use a shared memory friendly index, and reorder before the final write + for (int i = 0; i < iterations_C; i++) { + int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration); + int shared_index = warpThreadIdx + (i * C_per_iteration); + if (global_index >= C) { + break; + } + + x128 dout128 = load128cs(dout_bt + global_index); + x128 inp128 = load128cs(inp_bt + global_index); + x128 dinp128 = load128(dinp_bt + global_index); + x128 weight128 = load128(weight + global_index); + + for (int x = 0; x < x128::size; x++) { + float dout_i = (float)dout128[x]; + float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt; + float dnorm_i = (float)weight128[x] * dout_i; + + // sum up the gradients for bias and weight across the entire block + // this is basically a reduction (but only inter-warp, not intra-warp) + // doing it this way allows us to avoid using atomics while using many warps + if (warpId != 0) { + dbias_tmp_shared[threadIdx.x] = dout_i; + dweight_tmp_shared[threadIdx.x] = norm_bti * dout_i; + } + __syncthreads(); + if (warpId == 0) { + float dbias_tmp = dout_i; + float dweight_tmp = norm_bti * dout_i; + for (int j = 1; j < warpsInBlock; j++) { + dbias_tmp += dbias_tmp_shared[threadIdx.x + j * WARP_SIZE]; + dweight_tmp += dweight_tmp_shared[threadIdx.x + j * WARP_SIZE]; + } + // gradient contribution to bias (using shared memory friendly index) + dbias_shared[shared_index + x*WARP_SIZE] += dbias_tmp; + // gradient contribution to weight (using shared memory friendly index) + dweight_shared[shared_index + x*WARP_SIZE] += dweight_tmp; + } + __syncthreads(); + + // gradient contribution to input + float dval = 0.0f; + dval += dnorm_i; // term 1 + dval -= dnorm_mean; // term 2 + dval -= norm_bti * dnorm_norm_mean; // term 3 + dval *= rstd_bt; // final scale + dinp128[x] = (floatX)((float)dinp128[x] + dval); + } + // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing + store128cg(dinp_bt + global_index, dinp128); + } + } + __syncthreads(); + // Each block writes its partial sum to global memory + // The last block to finish becomes responsible for summing up all the partial sums + // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel) + unsigned int* scratchFlag = (unsigned int*)(scratch); + // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned + scratch += 32; + float* scratch_dbias = scratch; + float* scratch_dweight = scratch + C; + for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE) { + // Write to global memory in the same "shared memory banking friendly" order + scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i]; + scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i]; + } + __syncthreads(); + if (threadIdx.x == 0) { + *tmp_flag = atomicInc(scratchFlag, gridDim.x); + } + __syncthreads(); + if (*tmp_flag == gridDim.x-1) { + // Reduction of the partial sums by the final block + // todo - there isn't enough parallelism even inside that single SM... + // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?! + for(int i = threadIdx.x * f128::size; i < C; i+= BLOCK_SIZE * f128::size) { + f128 dbias_accum(make_int4(0, 0, 0, 0)); + f128 dweight_accum(make_int4(0, 0, 0, 0)); + + for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) { + int offset = i + 2*C*read_block_idx; + f128 dbias128 = load128(scratch_dbias + offset); + f128 dweight128 = load128(scratch_dweight + offset); + for(int k = 0; k < f128::size; k++) { + dbias_accum[k] += dbias128[k]; + dweight_accum[k] += dweight128[k]; + } + } + store128(dbias_shared + i, dbias_accum); + store128(dweight_shared + i, dweight_accum); + } + __syncthreads(); + + // reorder from atomic/shared memory-friendly index to real global memory index + // and convert from float/FP32 to floatX/BF16 for the final write + // this is separate also because it cannot use as many warps as the above (f128 vs x128) + // todo - if we split this code into another kernel, we could maybe do it at the same time? + for (int i = warpId; i < iterations_C; i += warpsInBlock) { + int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration); + int shared_index = warpThreadIdx + (i * C_per_iteration); + if (global_index >= C) { + break; + } + + x128 dbias128 = load128(dbias + global_index); + x128 dweight128 = load128(dweight + global_index); + for (int x = 0; x < x128::size; x++) { + float s_db = dbias_shared[shared_index + x*WARP_SIZE]; + float s_dw = dweight_shared[shared_index + x*WARP_SIZE]; + dbias128[x] = (floatX)(s_db + (float)dbias128[x]); + dweight128[x] = (floatX)(s_dw + (float)dweight128[x]); + } + store128(dbias + global_index, dbias128); + store128(dweight + global_index, dweight128); + } + } +} + // ---------------------------------------------------------------------------- // kernel launchers @@ -947,6 +1126,18 @@ void layernorm_backward8(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* s layernorm_backward_kernel8<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); } +template +void layernorm_backward9(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch, + const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd, + int B, int T, int C, int block_size) { + + const int grid_size = (1024/block_size) * cuda_num_SMs; // todo - heuristics for other GPUs? + size_t shared_mem_size = (2 * C + 2 * block_size + 1) * sizeof(float); + + cudaMemset(scratch, 0, 1 * sizeof(float)); // just need to memset the flag for this version + layernorm_backward_kernel9<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); +} + // kernel version dispatch void layernorm_backward(int kernel_num, floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, @@ -982,6 +1173,9 @@ void layernorm_backward(int kernel_num, case 8: layernorm_backward8(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size); break; + case 9: + layernorm_backward9(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size); + break; default: printf("Invalid kernel number\n"); exit(1); @@ -1042,7 +1236,7 @@ int main(int argc, char **argv) { cudaCheck(cudaMalloc(&d_weight, C * sizeof(floatX))); cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(floatX))); cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(floatX))); - cudaCheck(cudaMalloc(&d_scratch, cuda_num_SMs * (2 * C + 1) * sizeof(float))); + cudaCheck(cudaMalloc(&d_scratch, (1024/32) * cuda_num_SMs * (2 * C + 1) * sizeof(float))); // copy over the "inputs" to the backward call cudaCheck(memcpy_convert(d_dout, dout, B * T * C)); cudaCheck(memcpy_convert(d_inp, inp, B * T * C)); @@ -1051,7 +1245,8 @@ int main(int argc, char **argv) { cudaCheck(memcpy_convert(d_rstd, rstd, B * T)); // launch the kernel - int block_sizes[] = {32, 64, 128, 256, 512, 768, 1024}; + // removed 768 because it doesn't work for kernel9 despite being OK in train_gpt2.cu?! + int block_sizes[] = {32, 64, 128, 256, 512, /*768,*/ 1024}; for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; // init the "outputs" of the backward call to zeros diff --git a/train_gpt2.cu b/train_gpt2.cu index 6c60b8a74..31b6db2b7 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -985,10 +985,8 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd, int B, int T, int C) { - constexpr int BLOCK_SIZE = 512; - constexpr int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block - extern __shared__ float shared[]; // size = 2 * C + 1 - + extern __shared__ float shared[]; // size = 2*C + 2*block_size + 1 + int warpsInBlock = blockDim.x / WARP_SIZE; //number of warps in block int warpId = threadIdx.x / WARP_SIZE; // warp index within a block int baseIdx = blockIdx.x * warpsInBlock + warpId; int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp @@ -1000,14 +998,14 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with float* dbias_shared = shared; float* dweight_shared = shared + C; float* dbias_tmp_shared = shared + 2 * C; - float* dweight_tmp_shared = shared + 2 * C + BLOCK_SIZE; + float* dweight_tmp_shared = shared + 2 * C + blockDim.x; // init shared memory to zero - for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE){ + for(int i = threadIdx.x; i < C; i+= blockDim.x){ dbias_shared[i] = 0.0f; dweight_shared[i] = 0.0f; } - unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*BLOCK_SIZE); + unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*blockDim.x); __syncthreads(); for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) { @@ -1102,12 +1100,14 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with scratch += 32; float* scratch_dbias = scratch; float* scratch_dweight = scratch + C; - for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE) { + for(int i = threadIdx.x; i < C; i+= blockDim.x) { // Write to global memory in the same "shared memory banking friendly" order scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i]; scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i]; } + // todo - everything below could become a separate kernel for better performance with maybe less code + // not enough parallelism even inside that single SM... do we need another level of reduction?! __syncthreads(); if (threadIdx.x == 0) { *tmp_flag = atomicInc(scratchFlag, gridDim.x); @@ -1115,9 +1115,7 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with __syncthreads(); if (*tmp_flag == gridDim.x-1) { // Reduction of the partial sums by the final block - // todo - there isn't enough parallelism even inside that single SM... - // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?! - for(int i = threadIdx.x * f128::size; i < C; i+= BLOCK_SIZE * f128::size) { + for(int i = threadIdx.x * f128::size; i < C; i+= blockDim.x * f128::size) { f128 dbias_accum(make_int4(0, 0, 0, 0)); f128 dweight_accum(make_int4(0, 0, 0, 0)); From d3f26951ef0a612c917217f992b49404fe8ae847 Mon Sep 17 00:00:00 2001 From: Ross Wheeler Date: Thu, 23 May 2024 00:06:18 -0700 Subject: [PATCH 09/17] Add glob() for windows Tested with tinyshakespeare and fineweb --- dataloader.h | 91 +++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 75 insertions(+), 16 deletions(-) diff --git a/dataloader.h b/dataloader.h index cd38fe343..cd4caf47f 100644 --- a/dataloader.h +++ b/dataloader.h @@ -1,6 +1,8 @@ /* Implements a medium simple DataLoader for a distributed training setup. */ +#ifndef DATALOADER_H +#define DATALOADER_H #include #include @@ -13,30 +15,85 @@ Implements a medium simple DataLoader for a distributed training setup. #include "utils.h" // ---------------------------------------------------------------------------- -// we need glob to list files matching a pattern -// windows does not have glob, so we fall back on a very simple implementation -// this implementation doesn't actually do a glob, it assumes that the "pattern" -// is exactly the single file of interest +// implementation of glob for Windows #ifndef _WIN32 #include #else typedef struct glob_t { - size_t gl_pathc; - char **gl_pathv; + size_t gl_pathc; // Count of matched pathnames + char **gl_pathv; // List of matched pathnames } glob_t; -int glob(const char *pattern, int flags, void *unused, glob_t *pglob) { - assert(strstr(pattern, "*") == NULL); // we don't support * here - pglob->gl_pathc = 1; - pglob->gl_pathv = (char **)malloc(sizeof(char *)); - if (pglob->gl_pathv == NULL) { exit(EXIT_FAILURE); } // ??? oom? - pglob->gl_pathv[0] = (char *)pattern; - return 0; +void replace_forward_slashes(char* str) { + while (*str) { + if (*str == '/') { + *str = '\\'; + } + str++; + } } -void globfree(glob_t* pglob) { - free(pglob->gl_pathv); +void globfree(glob_t *pglob) { + for (size_t i = 0; i < pglob->gl_pathc; ++i) { + free(pglob->gl_pathv[i]); // Free the allocated memory for each filename + } + free(pglob->gl_pathv); // Free the allocated memory for the list of filenames +} + +int glob(const char* pattern, int ignored_flags, int (*ignored_errfunc)(const char* epath, int eerrno), glob_t* pglob){ + struct _finddata_t find_file_data; + char full_path[576]; // stored in pglob->gl_pathv[n] + char directory_path[512] = {0}; // Store the directory path from the pattern + char pattern_copy[512]; // Copy of the pattern to modify + + strncpy_s(pattern_copy, sizeof(pattern_copy) - 1, pattern, sizeof(pattern_copy) - 1); + + replace_forward_slashes (pattern_copy); // Replace forward slashes with backslashes + + if (strchr(pattern_copy, '\\') != NULL) { + strncpy_s(directory_path, sizeof(directory_path) - 1, pattern_copy, strrchr(pattern_copy, '\\') - pattern_copy + 1); + directory_path[strrchr(pattern_copy, '\\') - pattern_copy + 1] = '\0'; + } + + // find the first file matching the pattern in the directory + intptr_t find_handle = _findfirst(pattern_copy, &find_file_data); + + if (find_handle == -1) { + return 1; // No files found + } + + size_t file_count = 0; + size_t max_files = 64000; // hard-coded limit for the number of files + + pglob->gl_pathv = (char **) malloc(max_files * sizeof(char*)); // freed in globfree + + if (pglob->gl_pathv == NULL) { + _findclose(find_handle); + return 2; // Memory allocation failed + } + + do { + if (file_count >= max_files) { + _findclose(find_handle); + return 2; // Too many files found + } + + snprintf(full_path, sizeof(full_path), "%s%s", directory_path, find_file_data.name); + + pglob->gl_pathv[file_count] = _strdup(full_path); // freed in globfree + + if (pglob->gl_pathv[file_count] == NULL) { + _findclose(find_handle); + return 2; // Memory allocation for filename failed + } + file_count++; + } while (_findnext(find_handle, &find_file_data) == 0); + + _findclose(find_handle); + + pglob->gl_pathc = file_count; + return 0; } #endif @@ -460,4 +517,6 @@ void evalloader_free(EvalLoader *loader) { free(loader->mask); free(loader->label); fcloseCheck(loader->eval_file); -} \ No newline at end of file +} + +#endif // DATALOADER_H \ No newline at end of file From 2a736cb9e280b911fdc468704a526fb43e740a27 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Fri, 24 May 2024 19:45:41 +0300 Subject: [PATCH 10/17] fix for large batch sizes --- train_gpt2.cu | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 81f82f7b4..c7a9ef39f 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1264,7 +1264,10 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs, const float dloss, const int* targets, int B, int T, int V, int P) { - int idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data + // note: idx is small enough that it easily fits into 32 bit; + // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P) + // are done is 64 bit + long idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -2044,7 +2047,8 @@ void gpt2_build_from_random(GPT2 *model, int depth) { model->config.num_layers = depth; // follows GPT-2 sizes int channels, num_heads; - if (depth == 12) { channels = 768; num_heads = 12; } // gpt2 (124M) + if (depth == 6) { channels = 384; num_heads = 6; } // gpt2-tiny (30M) + else if (depth == 12) { channels = 768; num_heads = 12; } // gpt2 (124M) else if (depth == 24) { channels = 1024; num_heads = 16; } // gpt2-medium (350M) else if (depth == 36) { channels = 1280; num_heads = 20; } // gpt2-large (774M) else if (depth == 48) { channels = 1600; num_heads = 25; } // gpt2-xl (1558M) From df2e0dadd2e54394e8f95a4d31ab4bef875fb3d5 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Fri, 24 May 2024 19:59:53 +0300 Subject: [PATCH 11/17] int64_t --- train_gpt2.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index c7a9ef39f..77f2e6eb4 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1210,7 +1210,7 @@ struct SoftmaxParams { float Offset; }; -__device__ SoftmaxParams prepare_softmax_blockwide3(int idx, const floatX* inp, int V, int P) { +__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) { // same but not float4 // one row of inp, i.e. inp[idx, :] of shape (V,) @@ -1267,7 +1267,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // note: idx is small enough that it easily fits into 32 bit; // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P) // are done is 64 bit - long idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data + int64_t idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) From 1b98637960581d0179188f57e645b9e8737b266d Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Fri, 24 May 2024 20:11:34 +0300 Subject: [PATCH 12/17] int -> int64_t --- dev/cuda/classifier_fused.cu | 62 +++++++++++++++++------------------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/dev/cuda/classifier_fused.cu b/dev/cuda/classifier_fused.cu index 092de5955..2125b874d 100644 --- a/dev/cuda/classifier_fused.cu +++ b/dev/cuda/classifier_fused.cu @@ -38,7 +38,7 @@ typedef Packed128 x128; void softmax_forward_cpu(float* out, const float* inp, int N, int C) { // inp is (N, C) // out is (N, C), each row of inp will get softmaxed - for (int i = 0; i < N; i++) { + for (int64_t i = 0; i < N; i++) { const float* inp_row = inp + i * C; float* out_row = out + i * C; @@ -66,13 +66,11 @@ void crossentropy_forward_cpu(float* losses, // output: losses is (B,T) of the individual losses at each position // input: probs are (B,T,V) of the probabilities // input: targets is (B,T) of integers giving the correct index in logits - for (int b = 0; b < B; b++) { - for (int t = 0; t < T; t++) { - // loss = -log(probs[target]) - const float* probs_bt = probs + b * T * V + t * V; - int ix = targets[b * T + t]; - losses[b * T + t] = -logf(probs_bt[ix]); - } + for (int64_t bt = 0; bt < B * T; bt++) { + // loss = -log(probs[target]) + const float* probs_bt = probs + bt * V; + int ix = targets[bt]; + losses[bt] = -logf(probs_bt[ix]); } } @@ -80,17 +78,15 @@ void crossentropy_softmax_backward_cpu(float* dlogits, const float* dlosses, const float* probs, const int* targets, int B, int T, int V) { // backwards through both softmax and crossentropy - for (int b = 0; b < B; b++) { - for (int t = 0; t < T; t++) { - float* dlogits_bt = dlogits + b * T * V + t * V; - const float* probs_bt = probs + b * T * V + t * V; - float dloss = dlosses[b * T + t]; - int ix = targets[b * T + t]; - for (int i = 0; i < V; i++) { - float p = probs_bt[i]; - float indicator = i == ix ? 1.0f : 0.0f; - dlogits_bt[i] = (p - indicator) * dloss; - } + for (int64_t bt = 0; bt < B * T; bt++) { + float* dlogits_bt = dlogits + bt * V; + const float* probs_bt = probs + bt * V; + float dloss = dlosses[bt]; + int ix = targets[bt]; + for (int i = 0; i < V; i++) { + float p = probs_bt[i]; + float indicator = i == ix ? 1.0f : 0.0f; + dlogits_bt[i] = (p - indicator) * dloss; } } } @@ -115,7 +111,7 @@ struct SoftmaxParams { }; namespace cg = cooperative_groups; __device__ SoftmaxParams prepare_softmax(cg::thread_block_tile<32>& warp, - int idx, const float* inp, int V, int P) { + int64_t idx, const float* inp, int V, int P) { // this warp (of 32) threads processes one row of inp, i.e. inp[idx, :] of shape (V,) // note that inp is actually (B * T, P) but we only use the first V elements // this function tehen calculates: @@ -155,7 +151,7 @@ __global__ void fused_classifier_kernel1(float* dlogits, float* losses, // each block of 4 warps is in charge of 4 rows of the input, one warp per row // meta_group_size is the number of warps per block (e.g. 4) // meta_group_rank is the index of the warp in the block (e.g. 0, 1, 2, 3) - int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); + int64_t idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); if (idx >= B * T) { // there are B * T rows in the input return; } @@ -192,7 +188,7 @@ __device__ float vec_at(const float4& vec, int index) { } __device__ SoftmaxParams prepare_softmax_blockwide(cg::thread_block_tile<32>& warp, - int idx, const float* inp, int V, int P) { + int64_t idx, const float* inp, int V, int P) { // one row of inp, i.e. inp[idx, :] of shape (V,) // float4 to get 128-bit loads and memory level parallelism const float4* x_vec4 = reinterpret_cast(inp + idx * P); @@ -256,7 +252,7 @@ __global__ void fused_classifier_kernel2(float* dlogits, float* losses, float* p namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); - int idx = blockIdx.x; + int64_t idx = blockIdx.x; int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -297,7 +293,7 @@ __global__ void fused_classifier_kernel2(float* dlogits, float* losses, float* p } __device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_tile<32>& warp, - int idx, const float* inp, int V, int P) { + int64_t idx, const float* inp, int V, int P) { // same but not float4 // one row of inp, i.e. inp[idx, :] of shape (V,) @@ -353,7 +349,7 @@ __global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* p namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); - int idx = blockIdx.x; + int64_t idx = blockIdx.x; int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -385,7 +381,7 @@ __global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* p } } -__device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const floatX* inp, int V, int P) { +__device__ SoftmaxParams prepare_softmax_blockwide2(int64_t idx, const floatX* inp, int V, int P) { // one row of inp, i.e. inp[idx, :] of shape (V,) const floatX* x = inp + idx * P; @@ -443,7 +439,7 @@ __device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const floatX* inp, __global__ void fused_classifier_kernel4(floatX* dlogits, floatX* losses, floatX* probs, const floatX* logits, const floatX* dlosses, const int* targets, int B, int T, int V, int P) { - int idx = blockIdx.x; + int64_t idx = blockIdx.x; int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -512,7 +508,7 @@ __device__ float blockReduce(float val, bool final_sync=false, float out_of_boun return block_val; } -__device__ SoftmaxParams prepare_softmax_blockwide3(int idx, const floatX* inp, int V, int P) { +__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) { // same but not float4 // one row of inp, i.e. inp[idx, :] of shape (V,) @@ -566,7 +562,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) fused_classifier_kernel5(floatX* dlogits, floatX* losses, floatX* probs, const floatX* logits, const floatX* dlosses, const int* targets, int B, int T, int V, int P) { - int idx = blockIdx.x; + int64_t idx = blockIdx.x; int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -702,10 +698,10 @@ void fused_classifier(int kernel_num, float* dlogits, float* losses, int main(int argc, char **argv) { srand(0); - int B = 8; // batch size - int T = 1024; // sequence length - int V = 50257; // vocab size - int P = (V + 63) & ~63; // padded vocab size, up to nearest multiple of 64 + int64_t B = 8; // batch size + int64_t T = 1024; // sequence length + int64_t V = 50257; // vocab size + int64_t P = (V + 63) & ~63; // padded vocab size, up to nearest multiple of 64 int deviceIdx = 0; cudaCheck(cudaSetDevice(deviceIdx)); From 25f17e6748d806b0de48ee2aadcb6e47a4be0449 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 24 May 2024 21:53:00 +0000 Subject: [PATCH 13/17] small formatting fix before merge --- train_gpt2.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 3530edf8a..936bfa8fd 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -986,10 +986,10 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s } __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? - layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, - const floatX* dout, const floatX* inp, const floatX* weight, - const floatX* mean, const floatX* rstd, - int B, int T, int C) { + layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, + const floatX* dout, const floatX* inp, const floatX* weight, + const floatX* mean, const floatX* rstd, + int B, int T, int C) { extern __shared__ float shared[]; // size = 2*C + 2*block_size + 1 int warpsInBlock = blockDim.x / WARP_SIZE; //number of warps in block int warpId = threadIdx.x / WARP_SIZE; // warp index within a block From 3221e4b2d29e90aa78a67f457dcb57143009b94d Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 24 May 2024 23:10:17 +0000 Subject: [PATCH 14/17] small cosmetic changes --- train_gpt2.c | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/train_gpt2.c b/train_gpt2.c index 775b2b8d8..b01abf09f 100644 --- a/train_gpt2.c +++ b/train_gpt2.c @@ -160,13 +160,13 @@ void layernorm_backward(float* dinp, float* dweight, float* dbias, } } -void matmul_forward_slow(float* out, +void matmul_forward_naive(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC) { - // basic implementation of matrix multiplication. This serves as a fallback - // for bad input shapes, and as an illustration for the most basic version - // of the algorithm. -#pragma omp parallel for collapse(2) + // the most naive implementation of matrix multiplication + // this serves as an algorithmic reference, and as a fallback for + // unfriendly input shapes inside matmul_forward(), below. + #pragma omp parallel for collapse(2) for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { int bt = b * T + t; @@ -185,42 +185,42 @@ void matmul_forward(float* out, const float* inp, const float* weight, const float* bias, int B, int T, int C, int OC) { // most of the running time is spent here and in matmul_backward + // therefore, the implementation below is very mildly optimized + // this function is otherwise identical to that of matmul_forward_naive() // OC is short for "output channels" // inp is (B,T,C), weight is (OC, C), bias is (OC) // out will be (B,T,OC) - // make sure the tiled loop will be correct, otherwise, fallback to slow version + // make sure the tiled loop will be correct or fallback to naive version const int LOOP_UNROLL = 8; if (B*T % LOOP_UNROLL != 0) { - matmul_forward_slow(out, inp, weight, bias, B, T, C, OC); + matmul_forward_naive(out, inp, weight, bias, B, T, C, OC); return; } // collapse the B and T loops into one and turn it into a strided loop. // then we can tile the inner loop, and reuse the loaded weight LOOP_UNROLL many times - // for significant speed-ups. #pragma omp parallel for for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) { for (int o = 0; o < OC; o++) { - // keep LOOP_UNROLL many results in register, initialized by the bias term. + // we'll keep LOOP_UNROLL many results in registers float result[LOOP_UNROLL]; - for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) { + // initialize the bias, if it exists + for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) { result[ibt] = (bias != NULL) ? bias[o] : 0.0f; } - // inner loops. Because we do LOOP_UNROLL steps of inner bt, we can cache // the value of weight[i + o * C] and reuse it. - // we compile with -Ofast, so the compiler will turn the inner loop into a bunch of FMAs + // we compile with -Ofast, so the compiler will turn the inner loop into FMAs for (int i = 0; i < C; i++) { float w = weight[i + o * C]; - for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) { + for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) { int bt = obt + ibt; result[ibt] += inp[bt * C + i] * w; } } - // write back results to main memory - for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) { + for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) { int bt = obt + ibt; out[bt * OC + o] = result[ibt]; } From e5083be900997bc13d9d3c6894583f8159d1c65e Mon Sep 17 00:00:00 2001 From: Ross Wheeler Date: Fri, 24 May 2024 16:56:52 -0700 Subject: [PATCH 15/17] Moved windows glob() over to dev/unistd.h Added header guard and changed long->int64_t in dataloader.h --- dataloader.h | 106 +++++++-------------------------------------------- dev/unistd.h | 82 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 95 insertions(+), 93 deletions(-) diff --git a/dataloader.h b/dataloader.h index cd4caf47f..6b63c34a1 100644 --- a/dataloader.h +++ b/dataloader.h @@ -15,88 +15,10 @@ Implements a medium simple DataLoader for a distributed training setup. #include "utils.h" // ---------------------------------------------------------------------------- -// implementation of glob for Windows +// implementation of glob for Windows is in dev/unistd.h #ifndef _WIN32 #include -#else - -typedef struct glob_t { - size_t gl_pathc; // Count of matched pathnames - char **gl_pathv; // List of matched pathnames -} glob_t; - -void replace_forward_slashes(char* str) { - while (*str) { - if (*str == '/') { - *str = '\\'; - } - str++; - } -} - -void globfree(glob_t *pglob) { - for (size_t i = 0; i < pglob->gl_pathc; ++i) { - free(pglob->gl_pathv[i]); // Free the allocated memory for each filename - } - free(pglob->gl_pathv); // Free the allocated memory for the list of filenames -} - -int glob(const char* pattern, int ignored_flags, int (*ignored_errfunc)(const char* epath, int eerrno), glob_t* pglob){ - struct _finddata_t find_file_data; - char full_path[576]; // stored in pglob->gl_pathv[n] - char directory_path[512] = {0}; // Store the directory path from the pattern - char pattern_copy[512]; // Copy of the pattern to modify - - strncpy_s(pattern_copy, sizeof(pattern_copy) - 1, pattern, sizeof(pattern_copy) - 1); - - replace_forward_slashes (pattern_copy); // Replace forward slashes with backslashes - - if (strchr(pattern_copy, '\\') != NULL) { - strncpy_s(directory_path, sizeof(directory_path) - 1, pattern_copy, strrchr(pattern_copy, '\\') - pattern_copy + 1); - directory_path[strrchr(pattern_copy, '\\') - pattern_copy + 1] = '\0'; - } - - // find the first file matching the pattern in the directory - intptr_t find_handle = _findfirst(pattern_copy, &find_file_data); - - if (find_handle == -1) { - return 1; // No files found - } - - size_t file_count = 0; - size_t max_files = 64000; // hard-coded limit for the number of files - - pglob->gl_pathv = (char **) malloc(max_files * sizeof(char*)); // freed in globfree - - if (pglob->gl_pathv == NULL) { - _findclose(find_handle); - return 2; // Memory allocation failed - } - - do { - if (file_count >= max_files) { - _findclose(find_handle); - return 2; // Too many files found - } - - snprintf(full_path, sizeof(full_path), "%s%s", directory_path, find_file_data.name); - - pglob->gl_pathv[file_count] = _strdup(full_path); // freed in globfree - - if (pglob->gl_pathv[file_count] == NULL) { - _findclose(find_handle); - return 2; // Memory allocation for filename failed - } - file_count++; - } while (_findnext(find_handle, &find_file_data) == 0); - - _findclose(find_handle); - - pglob->gl_pathc = file_count; - return 0; -} #endif - // ---------------------------------------------------------------------------- // Distributed Data Loader #define HEADER_SIZE 256 @@ -113,8 +35,8 @@ typedef struct { glob_t glob_result; // stores the result of glob, for all shards we want to iterate int current_shard; // the current shard we are reading from FILE* tokens_file; - long file_size; - long current_position; + int64_t file_size; + int64_t current_position; uint16_t* buffer; // we fread data from file into this buffer // public variables that could be accessed from outside size_t num_batches; @@ -122,7 +44,7 @@ typedef struct { int* targets; // target tokens for the transformer } DataLoader; -long dataloader_load_shard_(DataLoader *loader, int shard_index) { +int64_t dataloader_load_shard_(DataLoader *loader, int shard_index) { // use the first glob match as the filename for now const char* filename = loader->glob_result.gl_pathv[shard_index]; // open the input file for reading. also only a single file can be opened at a time @@ -140,14 +62,14 @@ long dataloader_load_shard_(DataLoader *loader, int shard_index) { exit(EXIT_FAILURE); } if (header[1] != 1) { printf("Bad version in data file\n"); exit(EXIT_FAILURE); } - long ntok = header[2]; // number of tokens in the file + int64_t ntok = header[2]; // number of tokens in the file assert(ntok > 0); // we expect some tokens in the file. this should never trip, right? // determine the file size and make sure it is consistent with the number of tokens fseekCheck(loader->tokens_file, 0, SEEK_END); // seek to end of file loader->file_size = ftell(loader->tokens_file); // read the offset, i.e. file size fseekCheck(loader->tokens_file, 0, SEEK_SET); // seek back to the beginning // we expect ntok in the file to be consistent with filesize, assert that is the case - long expected_file_size = HEADER_SIZE * sizeof(int) + ntok * sizeof(uint16_t); + int64_t expected_file_size = HEADER_SIZE * sizeof(int) + ntok * sizeof(uint16_t); if (loader->file_size != expected_file_size) { printf("Error: file size is not as expected\n"); exit(EXIT_FAILURE); @@ -158,8 +80,8 @@ long dataloader_load_shard_(DataLoader *loader, int shard_index) { void dataloader_reset(DataLoader *loader) { // fully resets the DataLoader object to init configuration // each process starts at a different offset in the file - long header_bytes = HEADER_SIZE * sizeof(int); - long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); + int64_t header_bytes = HEADER_SIZE * sizeof(int); + int64_t token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); loader->current_shard = 0; loader->current_position = header_bytes + token_bytes_offset; dataloader_load_shard_(loader, loader->current_shard); @@ -172,8 +94,8 @@ void dataloader_advance_(DataLoader *loader) { loader->current_shard = (loader->current_shard + 1) % loader->glob_result.gl_pathc; dataloader_load_shard_(loader, loader->current_shard); } - long header_bytes = HEADER_SIZE * sizeof(int); - long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); + int64_t header_bytes = HEADER_SIZE * sizeof(int); + int64_t token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); loader->current_position = header_bytes + token_bytes_offset; } @@ -202,9 +124,9 @@ void dataloader_init(DataLoader *loader, // inspect and validate all shards so we don't get any runtime errors later // if too slow / too many shards, may wish to revisit later - long ntok_total = 0; + int64_t ntok_total = 0; for (int shard_index = 0; shard_index < loader->glob_result.gl_pathc; shard_index++) { - long shard_ntok = dataloader_load_shard_(loader, shard_index); + int64_t shard_ntok = dataloader_load_shard_(loader, shard_index); // we need at least one batch/shard, the way things are written right now. // can be relaxed a lot later. assert(shard_ntok >= num_processes * B * T + 1); @@ -286,7 +208,7 @@ typedef struct { size_t T; // maximum context length of the model // input handling and its state FILE* eval_file; - long file_size; + int64_t file_size; uint16_t* buffer; // we fread data from file into this buffer // public variables that could be accessed from outside int num_examples; // in total across all processes @@ -318,7 +240,7 @@ void evalloader_reset(EvalLoader *loader) { } // now seek through the file to the start of that example // utilize for efficiency - long header_bytes = HEADER_SIZE * sizeof(int); + int64_t header_bytes = HEADER_SIZE * sizeof(int); fseekCheck(loader->eval_file, header_bytes, SEEK_SET); for (int i = 0; i < loader->start_example_index; i++) { uint16_t example_header[3]; diff --git a/dev/unistd.h b/dev/unistd.h index 18efc2206..348bbae0a 100644 --- a/dev/unistd.h +++ b/dev/unistd.h @@ -5,12 +5,14 @@ #define _CRT_SECURE_NO_WARNINGS #define _USE_MATH_DEFINES +#include + #include //#define gen_max_length 64 // compile as C++ to skip this VLA issue #include #define CLOCK_MONOTONIC 0 -int clock_gettime(int ignore_variable, struct timespec* tv) +static inline int clock_gettime(int ignore_variable, struct timespec* tv) { return timespec_get(tv, TIME_UTC); // TODO: not sure this is the best solution. Need to review. } @@ -23,4 +25,82 @@ int clock_gettime(int ignore_variable, struct timespec* tv) #define TURN_OFF_FP_FAST __pragma(float_control( precise, on, push )) // Save current setting and turn on /fp:precise #define TURN_ON_FP_FAST __pragma(float_control(pop)) // Restore file's default settings +#define _mkdir mkdir // add mkdir into namespace for windows + +typedef struct glob_t { + size_t gl_pathc; // Count of matched pathnames + char **gl_pathv; // List of matched pathnames +} glob_t; + +static inline void replace_forward_slashes(char* str) { + while (*str) { + if (*str == '/') { + *str = '\\'; + } + str++; + } +} + +static inline void globfree(glob_t *pglob) { + for (size_t i = 0; i < pglob->gl_pathc; ++i) { + free(pglob->gl_pathv[i]); // Free the allocated memory for each filename + } + free(pglob->gl_pathv); // Free the allocated memory for the list of filenames +} + +static inline int glob(const char* pattern, int ignored_flags, int (*ignored_errfunc)(const char* epath, int eerrno), glob_t* pglob){ + struct _finddata_t find_file_data; + char full_path[576]; // stored in pglob->gl_pathv[n] + char directory_path[512] = {0}; // Store the directory path from the pattern + char pattern_copy[512]; // Copy of the pattern to modify + + strncpy_s(pattern_copy, sizeof(pattern_copy) - 1, pattern, sizeof(pattern_copy) - 1); + + replace_forward_slashes (pattern_copy); // Replace forward slashes with backslashes + + if (strchr(pattern_copy, '\\') != NULL) { + strncpy_s(directory_path, sizeof(directory_path) - 1, pattern_copy, strrchr(pattern_copy, '\\') - pattern_copy + 1); + directory_path[strrchr(pattern_copy, '\\') - pattern_copy + 1] = '\0'; + } + + // find the first file matching the pattern in the directory + intptr_t find_handle = _findfirst(pattern_copy, &find_file_data); + + if (find_handle == -1) { + return 1; // No files found + } + + size_t file_count = 0; + size_t max_files = 64000; // hard-coded limit for the number of files + + pglob->gl_pathv = (char **) malloc(max_files * sizeof(char*)); // freed in globfree + + if (pglob->gl_pathv == NULL) { + _findclose(find_handle); + return 2; // Memory allocation failed + } + + do { + if (file_count >= max_files) { + _findclose(find_handle); + return 2; // Too many files found + } + + snprintf(full_path, sizeof(full_path), "%s%s", directory_path, find_file_data.name); + + pglob->gl_pathv[file_count] = _strdup(full_path); // freed in globfree + + if (pglob->gl_pathv[file_count] == NULL) { + _findclose(find_handle); + return 2; // Memory allocation for filename failed + } + file_count++; + } while (_findnext(find_handle, &find_file_data) == 0); + + _findclose(find_handle); + + pglob->gl_pathc = file_count; + return 0; +} + #endif From 79738d2ca4786eef178a82eec62249d976c3b86b Mon Sep 17 00:00:00 2001 From: Ross Wheeler Date: Fri, 24 May 2024 17:05:59 -0700 Subject: [PATCH 16/17] fixed mkdir change --- dev/unistd.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/unistd.h b/dev/unistd.h index 348bbae0a..861041d46 100644 --- a/dev/unistd.h +++ b/dev/unistd.h @@ -25,7 +25,7 @@ static inline int clock_gettime(int ignore_variable, struct timespec* tv) #define TURN_OFF_FP_FAST __pragma(float_control( precise, on, push )) // Save current setting and turn on /fp:precise #define TURN_ON_FP_FAST __pragma(float_control(pop)) // Restore file's default settings -#define _mkdir mkdir // add mkdir into namespace for windows +#define mkdir _mkdir // add mkdir into namespace for windows typedef struct glob_t { size_t gl_pathc; // Count of matched pathnames From 9f08882051426ac5d0fe30db283b195c20dad34b Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 25 May 2024 00:14:10 +0000 Subject: [PATCH 17/17] add weight decay, but only for 2D tensors, as done in GPT series and in general too. this forces us to break up our adamw kernel again into one call per tensor, so there is a small throughput hit, of about 0.5% for me. but we have to break up this kernel in near future anyway --- train_gpt2.cu | 30 +++++++++++++++++++++++++----- train_gpt2.py | 35 ++++++++++++++++++++++++++++++++--- 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 26d29927b..3e8964690 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -2711,14 +2711,34 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl // AdamW update int block_size = 512; - int num_blocks = CEIL_DIV(num_parameters, block_size); float beta1_correction = 1.0f - powf(beta1, t); float beta2_correction = 1.0f - powf(beta2, t); unsigned int seed = random_u32(&model->rng_state); - adamw_kernel3<<>>(params_memory, model->master_weights, grads_memory, - model->m_memory, model->v_memory, num_parameters, - learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, - grad_scale, seed); + + // individually call the adamw_kernel3 on all parameter tensors separately + floatX* params_memory_iter = params_memory; + float* master_weights_iter = model->master_weights; + floatX* grads_memory_iter = grads_memory; + float* m_memory_iter = (float*)model->m_memory; + float* v_memory_iter = (float*)model->v_memory; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + size_t num_parameters = model->param_elements[i]; + int num_blocks = CEIL_DIV(num_parameters, block_size); + // we only want to weight decay the 2D tensors and leave all 1D tensors alone + // in particular this also decays the embedding weights, but this is ok: + // - the token embeddings are weight shared and participate in the final projection to logits + // - the position embeddings actively participate at every forward/backward pass + float wd = (i == 0 || i == 1 || i == 4 || i == 6 || i == 10 || i == 12) ? weight_decay : 0.0f; + adamw_kernel3<<>>(params_memory_iter, master_weights_iter, grads_memory_iter, + m_memory_iter, v_memory_iter, num_parameters, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, wd, + grad_scale, seed); + params_memory_iter += num_parameters; + if (master_weights_iter != NULL) { master_weights_iter += num_parameters; } + grads_memory_iter += num_parameters; + m_memory_iter += num_parameters; + v_memory_iter += num_parameters; + } cudaCheck(cudaGetLastError()); return grad_norm_cpu; } diff --git a/train_gpt2.py b/train_gpt2.py index d32ab5042..f2b0ae302 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -19,6 +19,7 @@ import os import math import struct +import inspect from contextlib import nullcontext from dataclasses import dataclass @@ -228,6 +229,31 @@ def from_pretrained(cls, model_type): return model + def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): + # start with all of the candidate parameters + param_dict = {pn: p for pn, p in self.named_parameters()} + # filter out those that do not require grad + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. + decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ + {'params': decay_params, 'weight_decay': weight_decay}, + {'params': nodecay_params, 'weight_decay': 0.0} + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + # Create AdamW optimizer and use the fused version if it is available + fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and device_type == 'cuda' + extra_args = dict(fused=True) if use_fused else dict() + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + print(f"using fused AdamW: {use_fused}") + return optimizer + @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """ @@ -429,6 +455,7 @@ def print0(*args, **kwargs): parser.add_argument("--sequence_length", type=int, default=64, help="sequence length") parser.add_argument("--total_batch_size", type=int, default=256, help="total desired batch size, in units of #tokens") parser.add_argument("--grad_clip", type=float, default=1.0, help="maximum gradient magnitude") + parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay") parser.add_argument("--overfit_single_batch", type=int, default=1, help="overfit just one batch of data") args = parser.parse_args() B, T = args.batch_size, args.sequence_length @@ -467,6 +494,7 @@ def print0(*args, **kwargs): elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): device = "mps" print(f"using device: {device}") + device_type = 'cuda' if 'cuda' in device else 'cpu' # calculate gradient accumulation from the desired total batch size and the current run configuration tokens_per_fwdbwd = B * T * ddp_world_size @@ -478,7 +506,7 @@ def print0(*args, **kwargs): # set up a context manager following the desired dtype and device ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype] - ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) if device == "cuda" else nullcontext() + ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() # seed the random number generators (in DDP we want different processes to use different offsets) # in the code below we don't actually use random numbers because there is no active dataloader @@ -604,8 +632,9 @@ def get_batch(): raw_model = model.module if ddp else model # always contains the "raw" unwrapped model # init the optimizer - adam_use_fused = device == "cuda" # only works on CUDA (?) - optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95), weight_decay=0.0, fused=adam_use_fused) + optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, + learning_rate=1e-4, betas=(0.9, 0.95), + device_type=device) if device == "cuda": torch.cuda.reset_peak_memory_stats()