diff --git a/dataloader.h b/dataloader.h index cd38fe343..6b63c34a1 100644 --- a/dataloader.h +++ b/dataloader.h @@ -1,6 +1,8 @@ /* Implements a medium simple DataLoader for a distributed training setup. */ +#ifndef DATALOADER_H +#define DATALOADER_H #include #include @@ -13,33 +15,10 @@ Implements a medium simple DataLoader for a distributed training setup. #include "utils.h" // ---------------------------------------------------------------------------- -// we need glob to list files matching a pattern -// windows does not have glob, so we fall back on a very simple implementation -// this implementation doesn't actually do a glob, it assumes that the "pattern" -// is exactly the single file of interest +// implementation of glob for Windows is in dev/unistd.h #ifndef _WIN32 #include -#else - -typedef struct glob_t { - size_t gl_pathc; - char **gl_pathv; -} glob_t; - -int glob(const char *pattern, int flags, void *unused, glob_t *pglob) { - assert(strstr(pattern, "*") == NULL); // we don't support * here - pglob->gl_pathc = 1; - pglob->gl_pathv = (char **)malloc(sizeof(char *)); - if (pglob->gl_pathv == NULL) { exit(EXIT_FAILURE); } // ??? oom? - pglob->gl_pathv[0] = (char *)pattern; - return 0; -} - -void globfree(glob_t* pglob) { - free(pglob->gl_pathv); -} #endif - // ---------------------------------------------------------------------------- // Distributed Data Loader #define HEADER_SIZE 256 @@ -56,8 +35,8 @@ typedef struct { glob_t glob_result; // stores the result of glob, for all shards we want to iterate int current_shard; // the current shard we are reading from FILE* tokens_file; - long file_size; - long current_position; + int64_t file_size; + int64_t current_position; uint16_t* buffer; // we fread data from file into this buffer // public variables that could be accessed from outside size_t num_batches; @@ -65,7 +44,7 @@ typedef struct { int* targets; // target tokens for the transformer } DataLoader; -long dataloader_load_shard_(DataLoader *loader, int shard_index) { +int64_t dataloader_load_shard_(DataLoader *loader, int shard_index) { // use the first glob match as the filename for now const char* filename = loader->glob_result.gl_pathv[shard_index]; // open the input file for reading. also only a single file can be opened at a time @@ -83,14 +62,14 @@ long dataloader_load_shard_(DataLoader *loader, int shard_index) { exit(EXIT_FAILURE); } if (header[1] != 1) { printf("Bad version in data file\n"); exit(EXIT_FAILURE); } - long ntok = header[2]; // number of tokens in the file + int64_t ntok = header[2]; // number of tokens in the file assert(ntok > 0); // we expect some tokens in the file. this should never trip, right? // determine the file size and make sure it is consistent with the number of tokens fseekCheck(loader->tokens_file, 0, SEEK_END); // seek to end of file loader->file_size = ftell(loader->tokens_file); // read the offset, i.e. 
file size fseekCheck(loader->tokens_file, 0, SEEK_SET); // seek back to the beginning // we expect ntok in the file to be consistent with filesize, assert that is the case - long expected_file_size = HEADER_SIZE * sizeof(int) + ntok * sizeof(uint16_t); + int64_t expected_file_size = HEADER_SIZE * sizeof(int) + ntok * sizeof(uint16_t); if (loader->file_size != expected_file_size) { printf("Error: file size is not as expected\n"); exit(EXIT_FAILURE); @@ -101,8 +80,8 @@ long dataloader_load_shard_(DataLoader *loader, int shard_index) { void dataloader_reset(DataLoader *loader) { // fully resets the DataLoader object to init configuration // each process starts at a different offset in the file - long header_bytes = HEADER_SIZE * sizeof(int); - long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); + int64_t header_bytes = HEADER_SIZE * sizeof(int); + int64_t token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); loader->current_shard = 0; loader->current_position = header_bytes + token_bytes_offset; dataloader_load_shard_(loader, loader->current_shard); @@ -115,8 +94,8 @@ void dataloader_advance_(DataLoader *loader) { loader->current_shard = (loader->current_shard + 1) % loader->glob_result.gl_pathc; dataloader_load_shard_(loader, loader->current_shard); } - long header_bytes = HEADER_SIZE * sizeof(int); - long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); + int64_t header_bytes = HEADER_SIZE * sizeof(int); + int64_t token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t); loader->current_position = header_bytes + token_bytes_offset; } @@ -145,9 +124,9 @@ void dataloader_init(DataLoader *loader, // inspect and validate all shards so we don't get any runtime errors later // if too slow / too many shards, may wish to revisit later - long ntok_total = 0; + int64_t ntok_total = 0; for (int shard_index = 0; shard_index < loader->glob_result.gl_pathc; shard_index++) { - long shard_ntok = dataloader_load_shard_(loader, shard_index); + int64_t shard_ntok = dataloader_load_shard_(loader, shard_index); // we need at least one batch/shard, the way things are written right now. // can be relaxed a lot later. 
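// ----------------------------------------------------------------------------
// Why the long -> int64_t change throughout this file: on 64-bit Windows (LLP64)
// sizeof(long) == 4, so byte offsets for shards larger than 2 GiB overflow, while
// int64_t is 8 bytes on every platform. A standalone illustration (the token
// count is an example value, not from the patch):
#include <stdint.h>
#include <stdio.h>

int main(void) {
    int64_t ntok = 1500000000LL;  // a 1.5B-token shard
    int64_t bytes = 256 * (int64_t)sizeof(int) + ntok * (int64_t)sizeof(uint16_t);  // HEADER_SIZE = 256
    printf("shard bytes = %lld, fits in a 32-bit long? %s\n",
           (long long)bytes, bytes <= INT32_MAX ? "yes" : "no");  // prints "no": ~3.0e9 > 2^31-1
    return 0;
}
// ----------------------------------------------------------------------------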
assert(shard_ntok >= num_processes * B * T + 1); @@ -229,7 +208,7 @@ typedef struct { size_t T; // maximum context length of the model // input handling and its state FILE* eval_file; - long file_size; + int64_t file_size; uint16_t* buffer; // we fread data from file into this buffer // public variables that could be accessed from outside int num_examples; // in total across all processes @@ -261,7 +240,7 @@ void evalloader_reset(EvalLoader *loader) { } // now seek through the file to the start of that example // utilize for efficiency - long header_bytes = HEADER_SIZE * sizeof(int); + int64_t header_bytes = HEADER_SIZE * sizeof(int); fseekCheck(loader->eval_file, header_bytes, SEEK_SET); for (int i = 0; i < loader->start_example_index; i++) { uint16_t example_header[3]; @@ -460,4 +439,6 @@ void evalloader_free(EvalLoader *loader) { free(loader->mask); free(loader->label); fcloseCheck(loader->eval_file); -} \ No newline at end of file +} + +#endif // DATALOADER_H \ No newline at end of file diff --git a/dev/cuda/classifier_fused.cu b/dev/cuda/classifier_fused.cu index 092de5955..2125b874d 100644 --- a/dev/cuda/classifier_fused.cu +++ b/dev/cuda/classifier_fused.cu @@ -38,7 +38,7 @@ typedef Packed128 x128; void softmax_forward_cpu(float* out, const float* inp, int N, int C) { // inp is (N, C) // out is (N, C), each row of inp will get softmaxed - for (int i = 0; i < N; i++) { + for (int64_t i = 0; i < N; i++) { const float* inp_row = inp + i * C; float* out_row = out + i * C; @@ -66,13 +66,11 @@ void crossentropy_forward_cpu(float* losses, // output: losses is (B,T) of the individual losses at each position // input: probs are (B,T,V) of the probabilities // input: targets is (B,T) of integers giving the correct index in logits - for (int b = 0; b < B; b++) { - for (int t = 0; t < T; t++) { - // loss = -log(probs[target]) - const float* probs_bt = probs + b * T * V + t * V; - int ix = targets[b * T + t]; - losses[b * T + t] = -logf(probs_bt[ix]); - } + for (int64_t bt = 0; bt < B * T; bt++) { + // loss = -log(probs[target]) + const float* probs_bt = probs + bt * V; + int ix = targets[bt]; + losses[bt] = -logf(probs_bt[ix]); } } @@ -80,17 +78,15 @@ void crossentropy_softmax_backward_cpu(float* dlogits, const float* dlosses, const float* probs, const int* targets, int B, int T, int V) { // backwards through both softmax and crossentropy - for (int b = 0; b < B; b++) { - for (int t = 0; t < T; t++) { - float* dlogits_bt = dlogits + b * T * V + t * V; - const float* probs_bt = probs + b * T * V + t * V; - float dloss = dlosses[b * T + t]; - int ix = targets[b * T + t]; - for (int i = 0; i < V; i++) { - float p = probs_bt[i]; - float indicator = i == ix ? 1.0f : 0.0f; - dlogits_bt[i] = (p - indicator) * dloss; - } + for (int64_t bt = 0; bt < B * T; bt++) { + float* dlogits_bt = dlogits + bt * V; + const float* probs_bt = probs + bt * V; + float dloss = dlosses[bt]; + int ix = targets[bt]; + for (int i = 0; i < V; i++) { + float p = probs_bt[i]; + float indicator = i == ix ? 1.0f : 0.0f; + dlogits_bt[i] = (p - indicator) * dloss; } } } @@ -115,7 +111,7 @@ struct SoftmaxParams { }; namespace cg = cooperative_groups; __device__ SoftmaxParams prepare_softmax(cg::thread_block_tile<32>& warp, - int idx, const float* inp, int V, int P) { + int64_t idx, const float* inp, int V, int P) { // this warp (of 32) threads processes one row of inp, i.e. 
inp[idx, :] of shape (V,) // note that inp is actually (B * T, P) but we only use the first V elements // this function tehen calculates: @@ -155,7 +151,7 @@ __global__ void fused_classifier_kernel1(float* dlogits, float* losses, // each block of 4 warps is in charge of 4 rows of the input, one warp per row // meta_group_size is the number of warps per block (e.g. 4) // meta_group_rank is the index of the warp in the block (e.g. 0, 1, 2, 3) - int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); + int64_t idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); if (idx >= B * T) { // there are B * T rows in the input return; } @@ -192,7 +188,7 @@ __device__ float vec_at(const float4& vec, int index) { } __device__ SoftmaxParams prepare_softmax_blockwide(cg::thread_block_tile<32>& warp, - int idx, const float* inp, int V, int P) { + int64_t idx, const float* inp, int V, int P) { // one row of inp, i.e. inp[idx, :] of shape (V,) // float4 to get 128-bit loads and memory level parallelism const float4* x_vec4 = reinterpret_cast(inp + idx * P); @@ -256,7 +252,7 @@ __global__ void fused_classifier_kernel2(float* dlogits, float* losses, float* p namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); - int idx = blockIdx.x; + int64_t idx = blockIdx.x; int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -297,7 +293,7 @@ __global__ void fused_classifier_kernel2(float* dlogits, float* losses, float* p } __device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_tile<32>& warp, - int idx, const float* inp, int V, int P) { + int64_t idx, const float* inp, int V, int P) { // same but not float4 // one row of inp, i.e. inp[idx, :] of shape (V,) @@ -353,7 +349,7 @@ __global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* p namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); - int idx = blockIdx.x; + int64_t idx = blockIdx.x; int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -385,7 +381,7 @@ __global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* p } } -__device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const floatX* inp, int V, int P) { +__device__ SoftmaxParams prepare_softmax_blockwide2(int64_t idx, const floatX* inp, int V, int P) { // one row of inp, i.e. inp[idx, :] of shape (V,) const floatX* x = inp + idx * P; @@ -443,7 +439,7 @@ __device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const floatX* inp, __global__ void fused_classifier_kernel4(floatX* dlogits, floatX* losses, floatX* probs, const floatX* logits, const floatX* dlosses, const int* targets, int B, int T, int V, int P) { - int idx = blockIdx.x; + int64_t idx = blockIdx.x; int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -512,7 +508,7 @@ __device__ float blockReduce(float val, bool final_sync=false, float out_of_boun return block_val; } -__device__ SoftmaxParams prepare_softmax_blockwide3(int idx, const floatX* inp, int V, int P) { +__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) { // same but not float4 // one row of inp, i.e. 
inp[idx, :] of shape (V,) @@ -566,7 +562,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) fused_classifier_kernel5(floatX* dlogits, floatX* losses, floatX* probs, const floatX* logits, const floatX* dlosses, const int* targets, int B, int T, int V, int P) { - int idx = blockIdx.x; + int64_t idx = blockIdx.x; int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -702,10 +698,10 @@ void fused_classifier(int kernel_num, float* dlogits, float* losses, int main(int argc, char **argv) { srand(0); - int B = 8; // batch size - int T = 1024; // sequence length - int V = 50257; // vocab size - int P = (V + 63) & ~63; // padded vocab size, up to nearest multiple of 64 + int64_t B = 8; // batch size + int64_t T = 1024; // sequence length + int64_t V = 50257; // vocab size + int64_t P = (V + 63) & ~63; // padded vocab size, up to nearest multiple of 64 int deviceIdx = 0; cudaCheck(cudaSetDevice(deviceIdx)); diff --git a/dev/cuda/layernorm_backward.cu b/dev/cuda/layernorm_backward.cu index 90dcb1674..d9502880b 100644 --- a/dev/cuda/layernorm_backward.cu +++ b/dev/cuda/layernorm_backward.cu @@ -856,6 +856,185 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) } } +__global__ void layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, + const floatX* dout, const floatX* inp, const floatX* weight, + const floatX* mean, const floatX* rstd, + int B, int T, int C) { + constexpr int WARP_SIZE = 32; + int BLOCK_SIZE = blockDim.x; + int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block + extern __shared__ float shared[]; // size = 2 * C + 1 + + int warpId = threadIdx.x / WARP_SIZE; // warp index within a block + int baseIdx = blockIdx.x * warpsInBlock + warpId; + int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp + int warpsInGrid = gridDim.x * warpsInBlock; + int C_per_iteration = WARP_SIZE * x128::size; + int iterations_C = ceil_div(C, C_per_iteration) + 2; + + // the first half of shared memory is bias, second is weight + float* dbias_shared = shared; + float* dweight_shared = shared + C; + float* dbias_tmp_shared = shared + 2 * C; + float* dweight_tmp_shared = shared + 2 * C + BLOCK_SIZE; + + // init shared memory to zero + for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE){ + dbias_shared[i] = 0.0f; + dweight_shared[i] = 0.0f; + } + unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*BLOCK_SIZE); + __syncthreads(); + + for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) { + int b = idx / T; + int t = idx % T; + + const floatX* dout_bt = dout + b * T * C + t * C; + const floatX* inp_bt = inp + b * T * C + t * C; + floatX* dinp_bt = dinp + b * T * C + t * C; + const float mean_bt = (float)mean[b * T + t]; + const float rstd_bt = (float)rstd[b * T + t]; + + // first: two reduce operations + float dnorm_mean = 0.0f; + float dnorm_norm_mean = 0.0f; + for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) { + x128 dout128_i = load128(dout_bt + i); + x128 inp128_i = load128(inp_bt + i); + x128 weight128_i = load128(weight + i); + for (int k = 0; k < x128::size; k++) { + float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt; + float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k]; + dnorm_mean += dnorm_i; + dnorm_norm_mean += dnorm_i * norm_bti; + } + } + dnorm_mean = warpReduceSum(dnorm_mean) / C; + dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C; + + // now iterate again and accumulate all the 
gradients + // unfortunately we cannot use the same index for x128 arrays and shared memory + // as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper) + // so this would result in an 8-way bank conflict, and kill performance + // so instead, we use a shared memory friendly index, and reorder before the final write + for (int i = 0; i < iterations_C; i++) { + int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration); + int shared_index = warpThreadIdx + (i * C_per_iteration); + if (global_index >= C) { + break; + } + + x128 dout128 = load128cs(dout_bt + global_index); + x128 inp128 = load128cs(inp_bt + global_index); + x128 dinp128 = load128(dinp_bt + global_index); + x128 weight128 = load128(weight + global_index); + + for (int x = 0; x < x128::size; x++) { + float dout_i = (float)dout128[x]; + float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt; + float dnorm_i = (float)weight128[x] * dout_i; + + // sum up the gradients for bias and weight across the entire block + // this is basically a reduction (but only inter-warp, not intra-warp) + // doing it this way allows us to avoid using atomics while using many warps + if (warpId != 0) { + dbias_tmp_shared[threadIdx.x] = dout_i; + dweight_tmp_shared[threadIdx.x] = norm_bti * dout_i; + } + __syncthreads(); + if (warpId == 0) { + float dbias_tmp = dout_i; + float dweight_tmp = norm_bti * dout_i; + for (int j = 1; j < warpsInBlock; j++) { + dbias_tmp += dbias_tmp_shared[threadIdx.x + j * WARP_SIZE]; + dweight_tmp += dweight_tmp_shared[threadIdx.x + j * WARP_SIZE]; + } + // gradient contribution to bias (using shared memory friendly index) + dbias_shared[shared_index + x*WARP_SIZE] += dbias_tmp; + // gradient contribution to weight (using shared memory friendly index) + dweight_shared[shared_index + x*WARP_SIZE] += dweight_tmp; + } + __syncthreads(); + + // gradient contribution to input + float dval = 0.0f; + dval += dnorm_i; // term 1 + dval -= dnorm_mean; // term 2 + dval -= norm_bti * dnorm_norm_mean; // term 3 + dval *= rstd_bt; // final scale + dinp128[x] = (floatX)((float)dinp128[x] + dval); + } + // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing + store128cg(dinp_bt + global_index, dinp128); + } + } + __syncthreads(); + // Each block writes its partial sum to global memory + // The last block to finish becomes responsible for summing up all the partial sums + // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel) + unsigned int* scratchFlag = (unsigned int*)(scratch); + // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned + scratch += 32; + float* scratch_dbias = scratch; + float* scratch_dweight = scratch + C; + for(int i = threadIdx.x; i < C; i+= BLOCK_SIZE) { + // Write to global memory in the same "shared memory banking friendly" order + scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i]; + scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i]; + } + __syncthreads(); + if (threadIdx.x == 0) { + *tmp_flag = atomicInc(scratchFlag, gridDim.x); + } + __syncthreads(); + if (*tmp_flag == gridDim.x-1) { + // Reduction of the partial sums by the final block + // todo - there isn't enough parallelism even inside that single SM... + // ==> so could maybe split into another kernel with YET ANOTHER level of reduction?! 
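// ----------------------------------------------------------------------------
// Why the "shared memory friendly index" above matters: shared memory is served
// by 32 four-byte banks, bank = word_index % 32. If each lane stored its x128
// values at lane*x128::size + x (8 consecutive floats per lane in the BF16 build),
// the 32 lanes of a warp would all land in banks {0, 8, 16, 24} - the 8-way
// conflict mentioned in the comment. With the friendly index lane + x*WARP_SIZE,
// each lane maps to its own bank. A tiny host-side check of that arithmetic
// (assumes x128::size == 8, i.e. the BF16 case):
#include <stdio.h>

int main(void) {
    const int NUM_BANKS = 32, VEC = 8;                   // VEC stands in for x128::size
    for (int lane = 0; lane < 8; lane++) {
        int bank_conflicted = (lane * VEC) % NUM_BANKS;  // lane-major x128 indexing, x = 0
        int bank_friendly   = lane % NUM_BANKS;          // shared-memory friendly indexing, x = 0
        printf("lane %2d: x128-style bank %2d, friendly bank %2d\n",
               lane, bank_conflicted, bank_friendly);
    }
    return 0;
}
// ----------------------------------------------------------------------------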
+ for(int i = threadIdx.x * f128::size; i < C; i+= BLOCK_SIZE * f128::size) { + f128 dbias_accum(make_int4(0, 0, 0, 0)); + f128 dweight_accum(make_int4(0, 0, 0, 0)); + + for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) { + int offset = i + 2*C*read_block_idx; + f128 dbias128 = load128(scratch_dbias + offset); + f128 dweight128 = load128(scratch_dweight + offset); + for(int k = 0; k < f128::size; k++) { + dbias_accum[k] += dbias128[k]; + dweight_accum[k] += dweight128[k]; + } + } + store128(dbias_shared + i, dbias_accum); + store128(dweight_shared + i, dweight_accum); + } + __syncthreads(); + + // reorder from atomic/shared memory-friendly index to real global memory index + // and convert from float/FP32 to floatX/BF16 for the final write + // this is separate also because it cannot use as many warps as the above (f128 vs x128) + // todo - if we split this code into another kernel, we could maybe do it at the same time? + for (int i = warpId; i < iterations_C; i += warpsInBlock) { + int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration); + int shared_index = warpThreadIdx + (i * C_per_iteration); + if (global_index >= C) { + break; + } + + x128 dbias128 = load128(dbias + global_index); + x128 dweight128 = load128(dweight + global_index); + for (int x = 0; x < x128::size; x++) { + float s_db = dbias_shared[shared_index + x*WARP_SIZE]; + float s_dw = dweight_shared[shared_index + x*WARP_SIZE]; + dbias128[x] = (floatX)(s_db + (float)dbias128[x]); + dweight128[x] = (floatX)(s_dw + (float)dweight128[x]); + } + store128(dbias + global_index, dbias128); + store128(dweight + global_index, dweight128); + } + } +} + // ---------------------------------------------------------------------------- // kernel launchers @@ -947,6 +1126,18 @@ void layernorm_backward8(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* s layernorm_backward_kernel8<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); } +template +void layernorm_backward9(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch, + const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd, + int B, int T, int C, int block_size) { + + const int grid_size = (1024/block_size) * cuda_num_SMs; // todo - heuristics for other GPUs? 
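// ----------------------------------------------------------------------------
// The grid size above hard-codes (1024/block_size) blocks per SM. One possible
// alternative (a sketch of the idea, not what this patch does) is to ask the
// occupancy API how many blocks of this size fit per SM on the current GPU; the
// dummy kernel and sizes below are illustrative stand-ins:
#include <cuda_runtime.h>
#include <stdio.h>

__global__ void dummy_kernel(float* out) { out[blockIdx.x * blockDim.x + threadIdx.x] = 1.0f; }

int main() {
    int device = 0, num_sms = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);
    const int block_size = 512;
    const size_t shared_mem_size = 0;   // kernel9 would pass (2*C + 2*block_size + 1) * sizeof(float)
    int blocks_per_sm = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, dummy_kernel, block_size, shared_mem_size);
    printf("blocks per SM = %d, grid_size = %d\n", blocks_per_sm, blocks_per_sm * num_sms);
    return 0;
}
// ----------------------------------------------------------------------------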
+ size_t shared_mem_size = (2 * C + 2 * block_size + 1) * sizeof(float); + + cudaMemset(scratch, 0, 1 * sizeof(float)); // just need to memset the flag for this version + layernorm_backward_kernel9<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); +} + // kernel version dispatch void layernorm_backward(int kernel_num, floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, @@ -982,6 +1173,9 @@ void layernorm_backward(int kernel_num, case 8: layernorm_backward8(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size); break; + case 9: + layernorm_backward9(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size); + break; default: printf("Invalid kernel number\n"); exit(1); @@ -1042,7 +1236,7 @@ int main(int argc, char **argv) { cudaCheck(cudaMalloc(&d_weight, C * sizeof(floatX))); cudaCheck(cudaMalloc(&d_mean, B * T * sizeof(floatX))); cudaCheck(cudaMalloc(&d_rstd, B * T * sizeof(floatX))); - cudaCheck(cudaMalloc(&d_scratch, cuda_num_SMs * (2 * C + 1) * sizeof(float))); + cudaCheck(cudaMalloc(&d_scratch, (1024/32) * cuda_num_SMs * (2 * C + 1) * sizeof(float))); // copy over the "inputs" to the backward call cudaCheck(memcpy_convert(d_dout, dout, B * T * C)); cudaCheck(memcpy_convert(d_inp, inp, B * T * C)); @@ -1051,7 +1245,8 @@ int main(int argc, char **argv) { cudaCheck(memcpy_convert(d_rstd, rstd, B * T)); // launch the kernel - int block_sizes[] = {32, 64, 128, 256, 512, 768, 1024}; + // removed 768 because it doesn't work for kernel9 despite being OK in train_gpt2.cu?! + int block_sizes[] = {32, 64, 128, 256, 512, /*768,*/ 1024}; for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { int block_size = block_sizes[j]; // init the "outputs" of the backward call to zeros diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu index 12b167083..16172bcf2 100644 --- a/dev/cuda/matmul_backward_bias.cu +++ b/dev/cuda/matmul_backward_bias.cu @@ -27,6 +27,26 @@ sudo ncu --set full --import-source yes -o bias -f ./matmul_backward_bias 1 #define ENABLE_BF16 #include "common.h" + +// ---------------------------------------------------------------------------- +// utility functions +__host__ __device__ bool isPowerOfTwo(int n) { + return (n > 0) && ((n & (n - 1)) == 0); +} + +__host__ __device__ int largestPowerOfTwoLessOrEqual(int n) { + // Return the largest power of 2 less than or equal to n + if (n < 1) { + return 0; + } + + while ((n & (n - 1)) > 0) { + n = n & (n - 1); + } + + return n; +} + // ---------------------------------------------------------------------------- // CPU code reference @@ -421,6 +441,8 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s // version1: simple cuBLAS calls void matmul_backward_bias1(floatX* dbias, const floatX* dout, int B, int T, int OC, int block_size) { + block_size = largestPowerOfTwoLessOrEqual(block_size); + assert(isPowerOfTwo(block_size)); // block_size needs to be power of 2 due to the reduction dim3 block_dim(block_size); dim3 grid_dim(OC); size_t shared_mem_size = block_size * sizeof(float); diff --git a/dev/unistd.h b/dev/unistd.h index 18efc2206..861041d46 100644 --- a/dev/unistd.h +++ b/dev/unistd.h @@ -5,12 +5,14 @@ #define _CRT_SECURE_NO_WARNINGS #define _USE_MATH_DEFINES +#include + #include //#define gen_max_length 64 // compile as C++ to skip this VLA issue #include #define CLOCK_MONOTONIC 0 -int clock_gettime(int ignore_variable, struct timespec* tv) +static inline int clock_gettime(int 
ignore_variable, struct timespec* tv) { return timespec_get(tv, TIME_UTC); // TODO: not sure this is the best solution. Need to review. } @@ -23,4 +25,82 @@ int clock_gettime(int ignore_variable, struct timespec* tv) #define TURN_OFF_FP_FAST __pragma(float_control( precise, on, push )) // Save current setting and turn on /fp:precise #define TURN_ON_FP_FAST __pragma(float_control(pop)) // Restore file's default settings +#define mkdir _mkdir // add mkdir into namespace for windows + +typedef struct glob_t { + size_t gl_pathc; // Count of matched pathnames + char **gl_pathv; // List of matched pathnames +} glob_t; + +static inline void replace_forward_slashes(char* str) { + while (*str) { + if (*str == '/') { + *str = '\\'; + } + str++; + } +} + +static inline void globfree(glob_t *pglob) { + for (size_t i = 0; i < pglob->gl_pathc; ++i) { + free(pglob->gl_pathv[i]); // Free the allocated memory for each filename + } + free(pglob->gl_pathv); // Free the allocated memory for the list of filenames +} + +static inline int glob(const char* pattern, int ignored_flags, int (*ignored_errfunc)(const char* epath, int eerrno), glob_t* pglob){ + struct _finddata_t find_file_data; + char full_path[576]; // stored in pglob->gl_pathv[n] + char directory_path[512] = {0}; // Store the directory path from the pattern + char pattern_copy[512]; // Copy of the pattern to modify + + strncpy_s(pattern_copy, sizeof(pattern_copy) - 1, pattern, sizeof(pattern_copy) - 1); + + replace_forward_slashes (pattern_copy); // Replace forward slashes with backslashes + + if (strchr(pattern_copy, '\\') != NULL) { + strncpy_s(directory_path, sizeof(directory_path) - 1, pattern_copy, strrchr(pattern_copy, '\\') - pattern_copy + 1); + directory_path[strrchr(pattern_copy, '\\') - pattern_copy + 1] = '\0'; + } + + // find the first file matching the pattern in the directory + intptr_t find_handle = _findfirst(pattern_copy, &find_file_data); + + if (find_handle == -1) { + return 1; // No files found + } + + size_t file_count = 0; + size_t max_files = 64000; // hard-coded limit for the number of files + + pglob->gl_pathv = (char **) malloc(max_files * sizeof(char*)); // freed in globfree + + if (pglob->gl_pathv == NULL) { + _findclose(find_handle); + return 2; // Memory allocation failed + } + + do { + if (file_count >= max_files) { + _findclose(find_handle); + return 2; // Too many files found + } + + snprintf(full_path, sizeof(full_path), "%s%s", directory_path, find_file_data.name); + + pglob->gl_pathv[file_count] = _strdup(full_path); // freed in globfree + + if (pglob->gl_pathv[file_count] == NULL) { + _findclose(find_handle); + return 2; // Memory allocation for filename failed + } + file_count++; + } while (_findnext(find_handle, &find_file_data) == 0); + + _findclose(find_handle); + + pglob->gl_pathc = file_count; + return 0; +} + #endif diff --git a/profile_gpt2.cu b/profile_gpt2.cu index f2ac0e84c..f79e9ada4 100644 --- a/profile_gpt2.cu +++ b/profile_gpt2.cu @@ -54,7 +54,7 @@ int main(int argc, char *argv[]) { // do a training step gpt2_forward(&model, x, y, B, T); gpt2_zero_grad(&model); - gpt2_backward(&model); + gpt2_backward(&model, x); gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1.f, 1, &multi_gpu_config); cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings diff --git a/test_gpt2.cu b/test_gpt2.cu index 247cd322c..bde357f32 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -203,7 +203,7 @@ int main(int argc, char *argv[]) { clock_gettime(CLOCK_MONOTONIC, &start); 
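// ----------------------------------------------------------------------------
// Usage sketch for the glob shim above: callers keep the POSIX call shape on
// every platform, and on Windows the calls resolve through _findfirst/_findnext.
// The pattern and include path below are illustrative:
#include <stdio.h>
#include <stdlib.h>
#ifndef _WIN32
#include <glob.h>
#else
#include "dev/unistd.h"   // provides glob_t, glob() and globfree() on Windows
#endif

int main(void) {
    glob_t matches;
    if (glob("dev/data/tinyshakespeare/*.bin", 0, NULL, &matches) != 0 || matches.gl_pathc == 0) {
        fprintf(stderr, "no files matched the pattern\n");
        return EXIT_FAILURE;
    }
    for (size_t i = 0; i < matches.gl_pathc; i++) {
        printf("shard %zu: %s\n", i, matches.gl_pathv[i]);
    }
    globfree(&matches);
    return 0;
}
// ----------------------------------------------------------------------------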
gpt2_forward(&model, x, y, B, T); gpt2_zero_grad(&model); - gpt2_backward(&model); + gpt2_backward(&model, x); clock_gettime(CLOCK_MONOTONIC, &end); double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; diff --git a/train_gpt2.c b/train_gpt2.c index 57bdfe929..b01abf09f 100644 --- a/train_gpt2.c +++ b/train_gpt2.c @@ -160,32 +160,76 @@ void layernorm_backward(float* dinp, float* dweight, float* dbias, } } -void matmul_forward(float* out, - float* inp, float* weight, float* bias, - int B, int T, int C, int OC) { - // most of the running time is spent here and in matmul_backward - // OC is short for "output channels" - // inp is (B,T,C), weight is (OC, C), bias is (OC) - // out will be (B,T,OC) +void matmul_forward_naive(float* out, + const float* inp, const float* weight, const float* bias, + int B, int T, int C, int OC) { + // the most naive implementation of matrix multiplication + // this serves as an algorithmic reference, and as a fallback for + // unfriendly input shapes inside matmul_forward(), below. #pragma omp parallel for collapse(2) for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { - float* out_bt = out + b * T * OC + t * OC; - float* inp_bt = inp + b * T * C + t * C; + int bt = b * T + t; for (int o = 0; o < OC; o++) { float val = (bias != NULL) ? bias[o] : 0.0f; - float* wrow = weight + o*C; for (int i = 0; i < C; i++) { - val += inp_bt[i] * wrow[i]; + val += inp[bt * C + i] * weight[o*C + i]; + } + out[bt * OC + o] = val; + } + } + } +} + +void matmul_forward(float* out, + const float* inp, const float* weight, const float* bias, + int B, int T, int C, int OC) { + // most of the running time is spent here and in matmul_backward + // therefore, the implementation below is very mildly optimized + // this function is otherwise identical to that of matmul_forward_naive() + // OC is short for "output channels" + // inp is (B,T,C), weight is (OC, C), bias is (OC) + // out will be (B,T,OC) + + // make sure the tiled loop will be correct or fallback to naive version + const int LOOP_UNROLL = 8; + if (B*T % LOOP_UNROLL != 0) { + matmul_forward_naive(out, inp, weight, bias, B, T, C, OC); + return; + } + + // collapse the B and T loops into one and turn it into a strided loop. + // then we can tile the inner loop, and reuse the loaded weight LOOP_UNROLL many times + #pragma omp parallel for + for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) { + for (int o = 0; o < OC; o++) { + // we'll keep LOOP_UNROLL many results in registers + float result[LOOP_UNROLL]; + // initialize the bias, if it exists + for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) { + result[ibt] = (bias != NULL) ? bias[o] : 0.0f; + } + // inner loops. Because we do LOOP_UNROLL steps of inner bt, we can cache + // the value of weight[i + o * C] and reuse it. 
+ // we compile with -Ofast, so the compiler will turn the inner loop into FMAs + for (int i = 0; i < C; i++) { + float w = weight[i + o * C]; + for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) { + int bt = obt + ibt; + result[ibt] += inp[bt * C + i] * w; } - out_bt[o] = val; + } + // write back results to main memory + for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) { + int bt = obt + ibt; + out[bt * OC + o] = result[ibt]; } } } } void matmul_backward(float* dinp, float* dweight, float* dbias, - float* dout, float* inp, float* weight, + const float* dout, const float* inp, const float* weight, int B, int T, int C, int OC) { // most of the running time is spent here and in matmul_forward // this backward could be done in a single "round" of loops @@ -195,10 +239,10 @@ void matmul_backward(float* dinp, float* dweight, float* dbias, #pragma omp parallel for collapse(2) for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { - float* dout_bt = dout + b * T * OC + t * OC; + const float* dout_bt = dout + b * T * OC + t * OC; float* dinp_bt = dinp + b * T * C + t * C; for (int o = 0; o < OC; o++) { - float* wrow = weight + o*C; + const float* wrow = weight + o*C; float d = dout_bt[o]; for (int i = 0; i < C; i++) { dinp_bt[i] += wrow[i] * d; @@ -211,8 +255,8 @@ void matmul_backward(float* dinp, float* dweight, float* dbias, for (int o = 0; o < OC; o++) { for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { - float* dout_bt = dout + b * T * OC + t * OC; - float* inp_bt = inp + b * T * C + t * C; + const float* dout_bt = dout + b * T * OC + t * OC; + const float* inp_bt = inp + b * T * C + t * C; float* dwrow = dweight + o*C; float d = dout_bt[o]; if (dbias != NULL) { dbias[o] += d; } diff --git a/train_gpt2.cu b/train_gpt2.cu index f704d1064..2488e0f27 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -41,6 +41,10 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200), #include #include #include +#include +#include +#include +#include // GPU / CUDA related #include #include @@ -540,50 +544,108 @@ __global__ void encoder_forward_kernel3(floatX* out, store128(out_btc, packed_out); } -template -__device__ void atomicStochasticAdd(T* address, float val0, float val1, unsigned int seed) { - static_assert(sizeof(T) == 2, "Only 16-bit atomicStochasticAdd supported."); - float2 val = make_float2(val0, val1); - unsigned int* address_as_uint = (unsigned int*)address; - unsigned int old = *address_as_uint, assumed; - unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed); - do { - assumed = old; - float2 new_fp32 = make_float2((float)(reinterpret_cast(&old)[0]) + val.x, - (float)(reinterpret_cast(&old)[1]) + val.y); - T new_rounded[2]; - stochastic_rounding(new_fp32.x, &new_rounded[0], random); - stochastic_rounding(new_fp32.y, &new_rounded[1], random >> 16); - old = atomicCAS(address_as_uint, assumed, *(unsigned int*)&new_rounded); - } while (assumed != old); -} -__device__ void atomicStochasticAdd(float* address, float val0, float val1, unsigned int seed) { - atomicAdd(address, val0); - atomicAdd(address + 1, val1); -} - -__global__ void encoder_backward_kernel(floatX* dwte, floatX* dwpe, - const floatX* dout, const int* inp, - int B, int T, int C, unsigned int seed) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int N = B * T * C; - idx *= 2; // 2 elements per thread - if (idx >= N) { return; } +template +__global__ void wte_backward_kernel(floatX* dwte, + const int4* bucket_info, const int* workload_indices, const floatX* dout, const int* inp, + unsigned int seed, 
int B, int T, int C) { + // In order to be deterministic, we preprocess the inputs on the cpu into "buckets" + // Each bucket corresponds to (WARP_SIZE * x128::size) channels for a single vocabulary token + // Each thread handles x128::size channels, e.g. 256 per warp for BF16 + // Each block handles (BLOCK_SIZE / WARP_SIZE) elements in a single bucket in parallel + // If a bucket has less than 8 elements, some warps will return immediately + // If a bucket has more than 8 elements, we will loop over all of them + // The buckets are sorted on the CPU so the largest buckets start 1st + int bucket = blockIdx.x; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int c_per_warp = WARP_SIZE * x128::size; + + int bucket_start_idx = bucket_info[bucket].x; + int bucket_size = bucket_info[bucket].y; + int bucket_ix = bucket_info[bucket].z; + int c = bucket_info[bucket].w * c_per_warp + (lane_id * x128::size); + + // Each thread handles "x128::size" channels, so at fp8, each warp would handle 512 channels + // If C is not a multiple of this (e.g. 768), some buckets/c_groups cannot use the entire warp + if (c >= C) { return; } + // Exit early if this is a small bucket and this warp doesn't have any items to process + if (warp_id >= bucket_size) { return; } + + float accum[x128::size] = {0.0f}; + __shared__ float accum_shared[x128::size * BLOCK_SIZE]; + + for(int item = warp_id; item < bucket_size; item += BLOCK_SIZE/WARP_SIZE) { + int bt = workload_indices[bucket_start_idx + item]; + int b = bt / T; + int t = bt % T; + + const floatX* dout_btc = dout + b * T * C + t * C + c; + x128 packed_inp1 = load128cs(dout_btc); + for (int k = 0; k < packed_inp1.size; k++) { + accum[k] += (float)packed_inp1[k]; + } + } - int bt = idx / C; - int b = bt / T; - int t = bt % T; - int c = idx % C; + if (warp_id != 0) { + // we accumulate into warp 0, so only the other warps need to write to shared memory + for (int k = 0; k < x128::size; k++) { + accum_shared[threadIdx.x + k * BLOCK_SIZE] = accum[k]; + } + return; // only warp 0 is needed after writing to shared memory + } - int ix = inp[b * T + t]; + // Read dwte for warp 0 even if other warps are not finished yet to maximise latency tolerance + floatX* dwte_ix = dwte + bucket_ix * C + c; + x128 packed_in_out = load128(dwte_ix); - const floatX* dout_btc = dout + b * T * C + t * C + c; - floatX* dwte_ix = dwte + ix * C + c; - floatX* dwpe_tc = dwpe + t * C + c; + // note: threads which have returned are considered synchronised by CUDA so no risk of deadlock + __syncthreads(); + + // Accumulate into warp 0's registers by reading the values of the other warps in shared memory + for (int i = threadIdx.x+WARP_SIZE; i < min(BLOCK_SIZE, bucket_size*WARP_SIZE); i += WARP_SIZE) { + for (int k = 0; k < x128::size; k++) { + accum[k] += accum_shared[i + k * BLOCK_SIZE]; + } + } - float2 dout_data = make_float2(dout_btc[0], dout_btc[1]); - atomicStochasticAdd(dwte_ix, dout_data.x, dout_data.y, seed); - atomicStochasticAdd(dwpe_tc, dout_data.x, dout_data.y, seed ^ 0xFFFFFFFF); + // Add the result to dwte and write back to global memory (read-modify-write) + for (unsigned int k = 0; k < x128::size; k++) { + // We use stochastic rounding to go from FP32 to BF16 but the seed should be deterministic + stochastic_rounding(accum[k] + (float)packed_in_out[k], &packed_in_out[k], seed + k); + } + store128(dwte_ix, packed_in_out); +} + +__global__ void wpe_backward_kernel(floatX* dwpe, + const floatX* dout, const int* inp, + int B, int T, int C, unsigned int 
seed) { + // Each thread handles x128::size "channel positions", e.g. 256 per warp for BF16 + // For gpt2-124M BF16, C=768 and T=1024, so 3 warps per channel and 3072 warps in total + // For each "channel position" we sum the gradients for every batch at that C/T element + // This way each dwte element is only updated once, and the kernel is fully deterministic! + // The previous kernel was not deterministic, as batches were aggregated with atomicAdd + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; + if (idx >= T * C) { return; } + + // if C is not a multiple of WARP_SIZE*x128::size, it's OK for some warps to handle multiple t + int t = idx / C; + int c = idx % C; + float accum[x128::size] = {0.0f}; + + for (int b = 0; b < B; b++) { + x128 packed_dout = load128cs(dout + (b * T * C) + (t * C) + c); // will never be read again + for (int k = 0; k < x128::size; k++) { + accum[k] += (float)packed_dout[k]; + } + } + + floatX* dwpe_tc = dwpe + (t * C) + c; + x128 packed_dwpe = load128(dwpe_tc); + for (unsigned int k = 0; k < x128::size; k++) { + // We use stochastic rounding to go from FP32 to BF16 but the seed should be deterministic + stochastic_rounding(accum[k] + (float)packed_dwpe[k], &packed_dwpe[k], seed + k); + } + store128(dwpe_tc, packed_dwpe); } __global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd, @@ -791,10 +853,9 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons // directly autoregressive, so we only compute the lower triangular part // uses the online softmax algorithm assert(T % 4 == 0); - const int warp_size = 32; - int lane_id = threadIdx.x % warp_size; - int warp_id = threadIdx.x / warp_size; - int num_warps = blockDim.x / warp_size; + int lane_id = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int num_warps = blockDim.x / WARP_SIZE; // micro-optimization: we iterate backwards so that // after the softmax backward operation completes, the cache retains the @@ -817,7 +878,7 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons float sumval = 0.0f; const floatX* x_aligned = reinterpret_cast(__builtin_assume_aligned(x, 16)); - for (int i = lane_id; i < pos_by_4; i += warp_size) { + for (int i = lane_id; i < pos_by_4; i += WARP_SIZE) { float regarray[4]; for (int k = 0; k < 4; ++k) { regarray[k] = (float)x_aligned[4*i + k]; @@ -846,7 +907,7 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons float norm = 1.f / sum; // divide the whole row by the sum - for (int i = lane_id; i <= own_pos; i += warp_size) { + for (int i = lane_id; i <= own_pos; i += WARP_SIZE) { // recalculation is faster than doing the round-trip through memory. float ev = expf(inv_temperature * ((float)__ldcs(x + i) - global_maxval)); __stcs(out + idx * T + i, (floatX)(ev * norm)); @@ -988,30 +1049,32 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s } } -__global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with only 1024 threads? 
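// ----------------------------------------------------------------------------
// Both encoder-backward kernels above accumulate gradients in FP32 and convert
// to BF16 exactly once, using stochastic rounding with a deterministic seed. A
// CPU sketch of the idea (my own illustration, not the device helper used in the
// diff): BF16 keeps the top 16 bits of the FP32 pattern, and we round up with
// probability equal to the discarded fraction, so repeated small updates are not
// systematically lost to truncation.
#include <stdint.h>
#include <stdio.h>
#include <string.h>

static uint16_t stochastic_round_bf16(float x, uint32_t rand16) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));                    // (ignores NaN/Inf edge cases)
    uint32_t discarded = bits & 0xFFFFu;                // low 16 bits that BF16 drops
    uint32_t round_up = ((rand16 & 0xFFFFu) < discarded) ? 0x10000u : 0u;
    return (uint16_t)((bits + round_up) >> 16);         // the carry propagates like a round-up
}

int main(void) {
    // a fixed "random" value makes the result reproducible, which is the point here
    printf("bf16 bits: 0x%04x\n", (unsigned)stochastic_round_bf16(0.3f, 0x1234u));
    return 0;
}
// ----------------------------------------------------------------------------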
- layernorm_backward_kernel8(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, - const floatX* dout, const floatX* inp, const floatX* weight, - const floatX* mean, const floatX* rstd, - int B, int T, int C) { - extern __shared__ float shared[]; // size = 2 * C + 1 - int warpId = threadIdx.x / WARP_SIZE; // warp index within a block +__global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? + layernorm_backward_kernel9(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, + const floatX* dout, const floatX* inp, const floatX* weight, + const floatX* mean, const floatX* rstd, + int B, int T, int C) { + extern __shared__ float shared[]; // size = 2*C + 2*block_size + 1 int warpsInBlock = blockDim.x / WARP_SIZE; //number of warps in block + int warpId = threadIdx.x / WARP_SIZE; // warp index within a block int baseIdx = blockIdx.x * warpsInBlock + warpId; int warpThreadIdx = threadIdx.x % WARP_SIZE; // Thread index within the warp int warpsInGrid = gridDim.x * warpsInBlock; int C_per_iteration = WARP_SIZE * x128::size; - int iterations_C = C / C_per_iteration; + int iterations_C = CEIL_DIV(C, C_per_iteration); // the first half of shared memory is bias, second is weight float* dbias_shared = shared; float* dweight_shared = shared + C; + float* dbias_tmp_shared = shared + 2 * C; + float* dweight_tmp_shared = shared + 2 * C + blockDim.x; // init shared memory to zero for(int i = threadIdx.x; i < C; i+= blockDim.x){ dbias_shared[i] = 0.0f; dweight_shared[i] = 0.0f; } - unsigned int *tmp_flag = (unsigned int*)(shared + C*2); + unsigned int *tmp_flag = (unsigned int*)(shared + 2*C + 2*blockDim.x); __syncthreads(); for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) { @@ -1049,6 +1112,10 @@ __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with for (int i = 0; i < iterations_C; i++) { int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration); int shared_index = warpThreadIdx + (i * C_per_iteration); + if (global_index >= C) { + break; + } + x128 dout128 = load128cs(dout_bt + global_index); x128 inp128 = load128cs(inp_bt + global_index); x128 dinp128 = load128(dinp_bt + global_index); @@ -1058,10 +1125,29 @@ __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with float dout_i = (float)dout128[x]; float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt; float dnorm_i = (float)weight128[x] * dout_i; - // gradient contribution to bias (using shared memory friendly index) - atomicAdd(&dbias_shared[shared_index + x*WARP_SIZE], dout_i); - // gradient contribution to weight (using shared memory friendly index) - atomicAdd(&dweight_shared[shared_index + x*WARP_SIZE], norm_bti * dout_i); + + // sum up the gradients for bias and weight across the entire block + // this is basically a reduction (but only inter-warp, not intra-warp) + // doing it this way allows us to avoid using atomics while using many warps + if (warpId != 0) { + dbias_tmp_shared[threadIdx.x] = dout_i; + dweight_tmp_shared[threadIdx.x] = norm_bti * dout_i; + } + __syncthreads(); + if (warpId == 0) { + float dbias_tmp = dout_i; + float dweight_tmp = norm_bti * dout_i; + for (int j = 1; j < warpsInBlock; j++) { + dbias_tmp += dbias_tmp_shared[threadIdx.x + j * WARP_SIZE]; + dweight_tmp += dweight_tmp_shared[threadIdx.x + j * WARP_SIZE]; + } + // gradient contribution to bias (using shared memory friendly index) + dbias_shared[shared_index + x*WARP_SIZE] += dbias_tmp; + // gradient contribution to weight (using shared 
memory friendly index) + dweight_shared[shared_index + x*WARP_SIZE] += dweight_tmp; + } + __syncthreads(); + // gradient contribution to input float dval = 0.0f; dval += dnorm_i; // term 1 @@ -1074,35 +1160,64 @@ __global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with store128cg(dinp_bt + global_index, dinp128); } } - // Accumulate into a FP32 scratchpad - // BF16 atomics are potentially much slower... and this is more precise! - // todo - could potentially avoid the extra copy if floatX is FP32, fairly negligible though __syncthreads(); + // Each block writes its partial sum to global memory + // The last block to finish becomes responsible for summing up all the partial sums + // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel) + unsigned int* scratchFlag = (unsigned int*)(scratch); + // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned + scratch += 32; float* scratch_dbias = scratch; float* scratch_dweight = scratch + C; - unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C)); for(int i = threadIdx.x; i < C; i+= blockDim.x) { - // global atomics in the same "shared memory banking friendly" order - atomicAdd(&scratch_dbias[i], dbias_shared[i]); - atomicAdd(&scratch_dweight[i], dweight_shared[i]); + // Write to global memory in the same "shared memory banking friendly" order + scratch_dbias[i + 2*C*blockIdx.x] = dbias_shared[i]; + scratch_dweight[i + 2*C*blockIdx.x] = dweight_shared[i]; } + + // todo - everything below could become a separate kernel for better performance with maybe less code + // not enough parallelism even inside that single SM... do we need another level of reduction?! __syncthreads(); if (threadIdx.x == 0) { *tmp_flag = atomicInc(scratchFlag, gridDim.x); } __syncthreads(); if (*tmp_flag == gridDim.x-1) { + // Reduction of the partial sums by the final block + for(int i = threadIdx.x * f128::size; i < C; i+= blockDim.x * f128::size) { + f128 dbias_accum(make_int4(0, 0, 0, 0)); + f128 dweight_accum(make_int4(0, 0, 0, 0)); + + for (int read_block_idx = 0; read_block_idx < gridDim.x; read_block_idx++) { + int offset = i + 2*C*read_block_idx; + f128 dbias128 = load128(scratch_dbias + offset); + f128 dweight128 = load128(scratch_dweight + offset); + for(int k = 0; k < f128::size; k++) { + dbias_accum[k] += dbias128[k]; + dweight_accum[k] += dweight128[k]; + } + } + store128(dbias_shared + i, dbias_accum); + store128(dweight_shared + i, dweight_accum); + } + __syncthreads(); + + // reorder from atomic/shared memory-friendly index to real global memory index + // and convert from float/FP32 to floatX/BF16 for the final write + // this is separate also because it cannot use as many warps as the above (f128 vs x128) + // todo - if we split this code into another kernel, we could maybe do it at the same time? 
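// ----------------------------------------------------------------------------
// Minimal standalone sketch of the "last block finishes the reduction" pattern
// used above: each block publishes a partial sum, takes a ticket via atomicInc
// on a zero-initialised flag, and the block that draws (gridDim.x - 1) knows all
// others are done and reduces the partials. Names and sizes are illustrative.
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

__global__ void sum_kernel(const float* in, float* partials, unsigned int* flag, float* out, int n) {
    extern __shared__ float smem[];                       // blockDim.x floats
    __shared__ unsigned int ticket;
    float v = 0.0f;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
        v += in[i];
    }
    smem[threadIdx.x] = v;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {   // intra-block tree reduction
        if (threadIdx.x < stride) { smem[threadIdx.x] += smem[threadIdx.x + stride]; }
        __syncthreads();
    }
    if (threadIdx.x == 0) {
        partials[blockIdx.x] = smem[0];    // publish this block's partial sum
        __threadfence();                   // make it visible before taking a ticket
        ticket = atomicInc(flag, gridDim.x);
    }
    __syncthreads();
    if (ticket == gridDim.x - 1) {         // we are the last block to get here
        float total = 0.0f;
        for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) { total += partials[i]; }
        smem[threadIdx.x] = total;
        __syncthreads();
        for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
            if (threadIdx.x < stride) { smem[threadIdx.x] += smem[threadIdx.x + stride]; }
            __syncthreads();
        }
        if (threadIdx.x == 0) { *out = smem[0]; }
    }
}

int main() {
    const int n = 1 << 20, block = 256, grid = 64;
    float* h_in = (float*)malloc(n * sizeof(float));
    for (int i = 0; i < n; i++) { h_in[i] = 1.0f; }       // expected sum: n
    float *d_in, *d_partials, *d_out;
    unsigned int* d_flag;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_partials, grid * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMalloc(&d_flag, sizeof(unsigned int));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(d_flag, 0, sizeof(unsigned int));          // the flag must start at zero
    sum_kernel<<<grid, block, block * sizeof(float)>>>(d_in, d_partials, d_flag, d_out, n);
    float result;
    cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("sum = %.0f (expected %d)\n", result, n);
    return 0;
}
// ----------------------------------------------------------------------------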
for (int i = warpId; i < iterations_C; i += warpsInBlock) { - // reorder from atomic/shared memory-friendly index to real global memory index - // and convert from float/FP32 to floatX/BF16 for the final write int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration); int shared_index = warpThreadIdx + (i * C_per_iteration); + if (global_index >= C) { + break; + } x128 dbias128 = load128(dbias + global_index); x128 dweight128 = load128(dweight + global_index); for (int x = 0; x < x128::size; x++) { - float s_db = scratch_dbias[shared_index + x*WARP_SIZE]; - float s_dw = scratch_dweight[shared_index + x*WARP_SIZE]; + float s_db = dbias_shared[shared_index + x*WARP_SIZE]; + float s_dw = dweight_shared[shared_index + x*WARP_SIZE]; dbias128[x] = (floatX)(s_db + (float)dbias128[x]); dweight128[x] = (floatX)(s_dw + (float)dweight128[x]); } @@ -1213,7 +1328,7 @@ struct SoftmaxParams { float Offset; }; -__device__ SoftmaxParams prepare_softmax_blockwide3(int idx, const floatX* inp, int V, int P) { +__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) { // same but not float4 // one row of inp, i.e. inp[idx, :] of shape (V,) @@ -1267,7 +1382,10 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs, const float dloss, const int* targets, int B, int T, int V, int P) { - int idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data + // note: idx is small enough that it easily fits into 32 bit; + // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P) + // are done is 64 bit + int64_t idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -1362,14 +1480,70 @@ void encoder_forward(floatX* out, cudaCheck(cudaGetLastError()); } -void encoder_backward(floatX* dwte, floatX* dwpe, - const floatX* dout, const int* inp, - int B, int T, int C, unsigned int seed) { +// Fully deterministic (see comments in wte_backward_kernel and wpe_backward_kernel for more details) +void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch + int* workload_indices, int4* bucket_info, // cpu scratch buffers + const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs + int B, int T, int C, unsigned int seed) { NVTX_RANGE_FN(); - const int N = B * T * C; + + // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte) const int block_size = 256; - const int grid_size = CEIL_DIV(N, block_size * 2); // each thread handles 2 elements - encoder_backward_kernel<<>>(dwte, dwpe, dout, inp, B, T, C, seed); + const int N = T * C / x128::size; + const int grid_size = CEIL_DIV(N, block_size); + wpe_backward_kernel<<>>(dwpe, dout, inp, B, T, C, seed); + + // check the GPU scratch buffer is large enough to hold the bucket info and workload indices + // todo - this is trivially true given hardcoded scratch buffer size here, is this useful? 
+ int num_c_groups = CEIL_DIV(C, x128::size * WARP_SIZE); + assert(B*T*num_c_groups * (sizeof(int4)+sizeof(int)) <= B*T*3*C * sizeof(floatX)); + + // Step 1: Sort inputs into buckets + int total_items = 0; + std::unordered_map> buckets; + for (uint64_t bt = 0; bt < B * T; bt++) { + for (uint64_t c_group = 0; c_group < num_c_groups; c_group++) { + // todo - passing c_group/inputs_cpu[bt] in data to avoid a second hash lookup is a bit hacky + uint64_t data = bt + (c_group<<32ULL) + ((uint64_t)inputs_cpu[bt]<<42ULL); + buckets[c_group + num_c_groups * inputs_cpu[bt]].push_back(data); + total_items++; + } + } + + // Step 2: Sort buckets by size in descending order + // this is so the largest buckets are processed first by the GPU + // otherwise, if they started late, they would still be running with the rest of the GPU idle + std::vector>> sortedBuckets(buckets.begin(), buckets.end()); + std::sort(sortedBuckets.begin(), sortedBuckets.end(), // ugly because we don't have a typedef for the std::pair + [](const std::pair>& a, const std::pair>& b) { + return a.second.size() > b.second.size(); + }); + + int num_buckets = buckets.size(); + int bucket_index = 0; + int workload_index = 0; + for (const auto& bucket : sortedBuckets) { + bucket_info[bucket_index].x = workload_index; // bucket start + bucket_info[bucket_index].y = bucket.second.size(); // bucket size + bucket_info[bucket_index].z = (bucket.second[0] >> 42ULL) & ((1ULL<<20ULL)-1); // bucket ix + bucket_info[bucket_index].w = (bucket.second[0] >> 32ULL) & ((1ULL<<10ULL)-1); // bucket c + + for (uint64_t idx : bucket.second) { + workload_indices[workload_index++] = (int)(idx & ((1ULL<<31ULL)-1ULL)); + } + bucket_index++; + } + + // Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice) + // todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely + int4* d_bucket_info = (int4*)scratch; + int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * sizeof(int4)); + cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice); + cudaMemcpy(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice); + + // Launch wte kernel + // todo - profile block sizes on more content (depends on number of buckets and on GPU?) + wte_backward_kernel<256><<>>(dwte, d_bucket_info, d_workload_indices, dout, inp, seed, B, T, C); cudaCheck(cudaGetLastError()); } @@ -1611,15 +1785,13 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd, int B, int T, int C) { NVTX_RANGE_FN(); - // todo - forcing 3 x 512 threads per SM maximum is a bit hacky, but more than that results in - // cache thrashing and lower performance on A100... is there a better way? 
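// ----------------------------------------------------------------------------
// The bucket key packed above fits three fields into one uint64_t:
//   bits  0..31  bt       (position within B*T; masked to 31 bits on extraction)
//   bits 32..41  c_group  (channel group, < 1024)
//   bits 42..61  token id (< 2^20, i.e. up to ~1M vocabulary entries)
// A tiny host-side round-trip check of that layout (values are illustrative):
#include <assert.h>
#include <stdint.h>

int main(void) {
    uint64_t bt = 123456, c_group = 2, token = 50256;
    uint64_t data = bt + (c_group << 32ULL) + (token << 42ULL);
    assert((data & ((1ULL << 31) - 1)) == bt);
    assert(((data >> 32) & ((1ULL << 10) - 1)) == c_group);
    assert(((data >> 42) & ((1ULL << 20) - 1)) == token);
    return 0;
}
// ----------------------------------------------------------------------------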
const int block_size = 512; - const int blocks_per_sm = min(3, (deviceProp.maxThreadsPerMultiProcessor / 1024)); + const int blocks_per_sm = 2; // supported on every architecture and less cache thrashing than 3 const int grid_size = blocks_per_sm * deviceProp.multiProcessorCount; - size_t shared_mem_size = (2 * C + 1) * sizeof(float); + size_t shared_mem_size = (2*C + 2*block_size + 1) * sizeof(float); // see kernel - cudaMemset(scratch, 0, (2 * C + 1) * sizeof(float)); - layernorm_backward_kernel8<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); + cudaMemset(scratch, 0, 1 * sizeof(float)); // only need to reset the flag to 0 + layernorm_backward_kernel9<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); cudaCheck(cudaGetLastError()); } @@ -1956,6 +2128,9 @@ typedef struct { unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. int use_master_weights; int recompute; + // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch? + int* workload_indices; // encoder_backward, B*T*num_c_groups (int) + int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case } GPT2; void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { @@ -2057,6 +2232,8 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { model->targets = NULL; model->cpu_losses = NULL; model->cpu_losses_fp32 = NULL; + model->workload_indices = NULL; + model->bucket_info = NULL; model->batch_size = 0; model->seq_len = 0; model->mean_loss = -1.0f; // -1.0f will designate no loss @@ -2072,7 +2249,8 @@ void gpt2_build_from_random(GPT2 *model, int depth) { model->config.num_layers = depth; // follows GPT-2 sizes int channels, num_heads; - if (depth == 12) { channels = 768; num_heads = 12; } // gpt2 (124M) + if (depth == 6) { channels = 384; num_heads = 6; } // gpt2-tiny (30M) + else if (depth == 12) { channels = 768; num_heads = 12; } // gpt2 (124M) else if (depth == 24) { channels = 1024; num_heads = 16; } // gpt2-medium (350M) else if (depth == 36) { channels = 1280; num_heads = 20; } // gpt2-large (774M) else if (depth == 48) { channels = 1600; num_heads = 25; } // gpt2-xl (1558M) @@ -2339,7 +2517,7 @@ void gpt2_zero_grad(GPT2 *model) { } } -void gpt2_backward(GPT2 *model) { +void gpt2_backward(GPT2 *model, int* inputs) { NVTX_RANGE_FN(); // double check we forwarded previously, with targets if (model->mean_loss == -1.0f) { @@ -2365,6 +2543,11 @@ void gpt2_backward(GPT2 *model) { model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes); // init gradients of parameters and activations to zero gpt2_zero_grad(model); + // initialise cpu scratch buffers for encoder backward + size_t num_c_groups = model->config.channels / (WARP_SIZE * x128::size); + assert((size_t)(model->batch_size * model->seq_len) * num_c_groups < (1ULL<<31ULL)); // todo - maybe an issue for llama3-400B(?) 
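// ----------------------------------------------------------------------------
// Worked sizing for the two CPU scratch buffers allocated below, using the
// gpt2-124M BF16 case (C=768, x128::size=8, WARP_SIZE=32) and the B=8, T=1024
// settings that appear elsewhere in this diff (other configs scale the same way):
#include <stdio.h>

int main(void) {
    const size_t B = 8, T = 1024, C = 768, WARP_SIZE = 32, x128_size = 8;
    size_t num_c_groups = C / (WARP_SIZE * x128_size);        // = 3
    size_t entries = B * T * num_c_groups;                    // = 24576, comfortably below 2^31
    printf("workload_indices: %zu ints  (%zu KiB)\n", entries, entries * sizeof(int) / 1024);
    printf("bucket_info:      %zu int4s (%zu KiB)\n", entries, entries * 4 * sizeof(int) / 1024);
    return 0;
}
// ----------------------------------------------------------------------------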
+        model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups);
+        model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);
     }
 
     // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
@@ -2385,7 +2568,8 @@ void gpt2_backward(GPT2 *model) {
     cudaCheck(cudaMemset(model->grads_acts.residual3, 0, B * T * C * sizeof(floatX)));
 
     // re-use the output buffer of the forward pass as a scratchpad during backward pass
-    float* scratchF = (float*)acts.output;
+    float* scratchF = (float*)acts.output;
+    floatX* scratchX = (floatX*)acts.output;
 
     // we kick off the chain rule by filling in dlosses with 1.0f/(B*T)
     // this was done in the fused classifier kernel as last step of forward pass
@@ -2467,7 +2651,6 @@ void gpt2_backward(GPT2 *model) {
         floatX* buffer_a = l_atty;
         floatX* buffer_b = l_fch; // this is B x T x 4C, so even larger than what we need
         floatX* dl_preatt = (floatX*)grads_acts.preatt; // dedicated scratchpad allocation
-        floatX* scratchX = (floatX*)acts.output;
         attention_backward(dl_bt4c, buffer_b, dl_preatt, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH);
         #endif
 
@@ -2476,7 +2659,8 @@ void gpt2_backward(GPT2 *model) {
         // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above
         layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);
     }
-    encoder_backward(grads.wte, grads.wpe, dresidual, model->inputs, B, T, C, random_u32(&model->rng_state));
+    encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info,
+                     dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state));
 }
 
 // Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
@@ -2555,14 +2739,34 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl
     // AdamW update
     int block_size = 512;
-    int num_blocks = CEIL_DIV(num_parameters, block_size);
     float beta1_correction = 1.0f - powf(beta1, t);
     float beta2_correction = 1.0f - powf(beta2, t);
     unsigned int seed = random_u32(&model->rng_state);
-    adamw_kernel3<<<num_blocks, block_size>>>(params_memory, model->master_weights, grads_memory,
-                                              model->m_memory, model->v_memory, num_parameters,
-                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay,
-                                              grad_scale, seed);
+
+    // individually call the adamw_kernel3 on all parameter tensors separately
+    floatX* params_memory_iter = params_memory;
+    float* master_weights_iter = model->master_weights;
+    floatX* grads_memory_iter = grads_memory;
+    float* m_memory_iter = (float*)model->m_memory;
+    float* v_memory_iter = (float*)model->v_memory;
+    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
+        size_t num_parameters = model->param_elements[i];
+        int num_blocks = CEIL_DIV(num_parameters, block_size);
+        // we only want to weight decay the 2D tensors and leave all 1D tensors alone
+        // in particular this also decays the embedding weights, but this is ok:
+        // - the token embeddings are weight shared and participate in the final projection to logits
+        // - the position embeddings actively participate at every forward/backward pass
+        float wd = (i == 0 || i == 1 || i == 4 || i == 6 || i == 10 || i == 12) ? weight_decay : 0.0f;
+        adamw_kernel3<<<num_blocks, block_size>>>(params_memory_iter, master_weights_iter, grads_memory_iter,
+                                                  m_memory_iter, v_memory_iter, num_parameters,
+                                                  learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, wd,
+                                                  grad_scale, seed);
+        params_memory_iter += num_parameters;
+        if (master_weights_iter != NULL) { master_weights_iter += num_parameters; }
+        grads_memory_iter += num_parameters;
+        m_memory_iter += num_parameters;
+        v_memory_iter += num_parameters;
+    }
     cudaCheck(cudaGetLastError());
     return grad_norm_cpu;
 }
@@ -2593,6 +2797,8 @@ void gpt2_free(GPT2 *model) {
     cudaCheck(cudaFree(model->targets));
     cudaFreeHost(model->cpu_losses);
    cudaFreeHost(model->cpu_losses_fp32);
+    free(model->workload_indices);
+    free(model->bucket_info);
 }
 
 // ----------------------------------------------------------------------------
@@ -2622,7 +2828,7 @@ void common_free(GPT2 &model) {
     cudaCheck(cudaFree(cublaslt_workspace));
     cublasCheck(cublasDestroy(cublas_handle));
     cublasCheck(cublasLtDestroy(cublaslt_handle));
-    create_cudnn();
+    destroy_cudnn();
 }
 
 #ifndef TESTING
@@ -3041,7 +3247,7 @@ int main(int argc, char *argv[]) {
             gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T, grad_accum_steps);
             lossf += model.mean_loss; // the mean_loss was normalized by grad_accum_steps inside gpt2_forward
             // backward pass. all model params accumulate gradients with += inside this inner loop
-            gpt2_backward(&model);
+            gpt2_backward(&model, train_loader.inputs);
         }
         // override the mean loss, accounting for the gradient accumulation loop
         // this is esp important to do here in multigpu update below, where model.mean_loss gets allreduced
diff --git a/train_gpt2.py b/train_gpt2.py
index d32ab5042..f2b0ae302 100644
--- a/train_gpt2.py
+++ b/train_gpt2.py
@@ -19,6 +19,7 @@
 import os
 import math
 import struct
+import inspect
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -228,6 +229,31 @@ def from_pretrained(cls, model_type):
 
         return model
 
+    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
+        # start with all of the candidate parameters
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        # filter out those that do not require grad
+        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
+        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
+        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+        optim_groups = [
+            {'params': decay_params, 'weight_decay': weight_decay},
+            {'params': nodecay_params, 'weight_decay': 0.0}
+        ]
+        num_decay_params = sum(p.numel() for p in decay_params)
+        num_nodecay_params = sum(p.numel() for p in nodecay_params)
+        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+        # Create AdamW optimizer and use the fused version if it is available
+        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+        use_fused = fused_available and device_type == 'cuda'
+        extra_args = dict(fused=True) if use_fused else dict()
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
+        print(f"using fused AdamW: {use_fused}")
+        return optimizer
+
     @torch.no_grad()
     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
         """
@@ -429,6 +455,7 @@ def print0(*args, **kwargs):
     parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
     parser.add_argument("--total_batch_size", type=int, default=256, help="total desired batch size, in units of #tokens")
     parser.add_argument("--grad_clip", type=float, default=1.0, help="maximum gradient magnitude")
+    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay")
     parser.add_argument("--overfit_single_batch", type=int, default=1, help="overfit just one batch of data")
     args = parser.parse_args()
     B, T = args.batch_size, args.sequence_length
@@ -467,6 +494,7 @@ def print0(*args, **kwargs):
     elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
         device = "mps"
     print(f"using device: {device}")
+    device_type = 'cuda' if 'cuda' in device else 'cpu'
 
     # calculate gradient accumulation from the desired total batch size and the current run configuration
     tokens_per_fwdbwd = B * T * ddp_world_size
@@ -478,7 +506,7 @@ def print0(*args, **kwargs):
 
     # set up a context manager following the desired dtype and device
     ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
-    ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) if device == "cuda" else nullcontext()
+    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
 
     # seed the random number generators (in DDP we want different processes to use different offsets)
     # in the code below we don't actually use random numbers because there is no active dataloader
@@ -604,8 +632,9 @@ def get_batch():
     raw_model = model.module if ddp else model # always contains the "raw" unwrapped model
 
     # init the optimizer
-    adam_use_fused = device == "cuda" # only works on CUDA (?)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95), weight_decay=0.0, fused=adam_use_fused)
+    optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay,
+                                               learning_rate=1e-4, betas=(0.9, 0.95),
+                                               device_type=device)
 
     if device == "cuda":
         torch.cuda.reset_peak_memory_stats()
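Note on the encoder_backward change above: the CPU side groups every (batch*time, channel-group) work item by the (token id, channel group) it will update, then sorts the buckets largest-first so the longest-running buckets are launched first and do not leave the rest of the GPU idle. Below is a minimal Python sketch of that preprocessing, illustrative only and not part of the patch; bucket_workload and the toy inputs are made up, and the 64-bit packing of bt/c_group/token used in the CUDA code is omitted.

# illustrative sketch of the CPU-side bucketing done in encoder_backward
from collections import defaultdict

def bucket_workload(inputs_cpu, num_c_groups):
    # inputs_cpu: one token id per (b, t) position
    buckets = defaultdict(list)
    for bt, token in enumerate(inputs_cpu):
        for c_group in range(num_c_groups):
            # each work item updates dwte for exactly one (token, channel group)
            buckets[(token, c_group)].append(bt)
    # largest buckets first, mirroring "Step 2" above
    ordered = sorted(buckets.items(), key=lambda kv: len(kv[1]), reverse=True)
    bucket_info, workload_indices = [], []
    for (token, c_group), items in ordered:
        # (start offset, size, token id, channel group) per bucket
        bucket_info.append((len(workload_indices), len(items), token, c_group))
        workload_indices.extend(items)
    return bucket_info, workload_indices

if __name__ == "__main__":
    info, work = bucket_workload(inputs_cpu=[5, 7, 5, 5], num_c_groups=2)
    print(info)  # [(0, 3, 5, 0), (3, 3, 5, 1), (6, 1, 7, 0), (7, 1, 7, 1)]
    print(work)  # [0, 2, 3, 0, 2, 3, 1, 1]

Each tuple in bucket_info corresponds to one int4 entry, matching the x/y/z/w fields (bucket start, bucket size, bucket ix, bucket c) filled in by the patch above.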