diff --git a/Makefile b/Makefile index 2908e27aa..52d890582 100644 --- a/Makefile +++ b/Makefile @@ -85,6 +85,20 @@ else endif endif +# Precision settings, default to bf16 but ability to override +PRECISION ?= BF16 +VALID_PRECISIONS := FP32 FP16 BF16 +ifeq ($(filter $(PRECISION),$(VALID_PRECISIONS)),) + $(error Invalid precision $(PRECISION), valid precisions are $(VALID_PRECISIONS)) +endif +ifeq ($(PRECISION), FP32) + PFLAGS = -DENABLE_FP32 +else ifeq ($(PRECISION), FP16) + PFLAGS = -DENABLE_FP16 +else + PFLAGS = -DENABLE_BF16 +endif + # PHONY means these targets will always be executed .PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu @@ -108,7 +122,7 @@ test_gpt2: test_gpt2.c $(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) -o $@ train_gpt2cu: train_gpt2.cu - $(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@ + $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@ train_gpt2fp32cu: train_gpt2_fp32.cu $(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@ diff --git a/profile_gpt2.cu b/profile_gpt2.cu index fd9c71ba3..abb1b1327 100644 --- a/profile_gpt2.cu +++ b/profile_gpt2.cu @@ -24,6 +24,7 @@ For example, I have NVIDIA Nsight Compute installed on my Mac, and I rsync the profile.ncu-rep from a cloud box to local to pretty view. */ +#define ENABLE_BF16 #define TESTING #include "train_gpt2.cu" @@ -51,7 +52,7 @@ int main() { // build the GPT-2 model from a checkpoint GPT2 model; - gpt2_build_from_checkpoint(&model, "gpt2_124M.bin"); + gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin"); int B = 4; int T = 1024; diff --git a/test_gpt2.cu b/test_gpt2.cu index bb35b4fda..804c55f44 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -1,13 +1,27 @@ +#define ENABLE_BF16 #define TESTING #include "train_gpt2.cu" // poor man's tensor checker int check_tensor(float *a, float *b, int n, const char* label, float threshold=1e-0) { - int print_upto = 5; + // a is the calculated tensor, b is the reference tensor + int print_upto = 10; int ok = 1; + float max_diff = 0.0f; + float max_rel_error = 0.0f; + float max_a = 0.0f; + float max_b = 0.0f; printf("%s\n", label); for (int i = 0; i < n; i++) { - if (fabsf(a[i] - b[i]) <= threshold) { + float diff = fabsf(a[i] - b[i]); + if (diff > max_diff) { + max_diff = diff; + float denom = fabsf(b[i]); + max_rel_error = (denom == 0.0f) ? 
0.0f : diff / denom; + max_a = a[i]; + max_b = b[i]; + } + if (diff <= threshold) { if (i < print_upto) { printf("OK "); } } else { if (i < print_upto) { printf("NOT OK "); } @@ -17,13 +31,58 @@ int check_tensor(float *a, float *b, int n, const char* label, float threshold=1 } // print the final result if (ok) { - printf("TENSOR OK\n"); + printf("TENSOR OK, max diff: %e, with rel error: %e (calculated=%f, ref=%f)\n", + max_diff, max_rel_error, max_a, max_b); } else { - printf("TENSOR NOT OK\n"); + printf("TENSOR NOT OK, max diff: %e, with rel error: %e (calculated=%f, ref=%f)\n", + max_diff, max_rel_error, max_a, max_b); } return ok; } +// the same tensors as in the train file, but in float, which are used as reference +typedef struct { + float* wte; // (V, C) + float* wpe; // (maxT, C) + float* ln1w; // (L, C) + float* ln1b; // (L, C) + float* qkvw; // (L, 3*C, C) + float* qkvb; // (L, 3*C) + float* attprojw; // (L, C, C) + float* attprojb; // (L, C) + float* ln2w; // (L, C) + float* ln2b; // (L, C) + float* fcw; // (L, 4*C, C) + float* fcb; // (L, 4*C) + float* fcprojw; // (L, C, 4*C) + float* fcprojb; // (L, C) + float* lnfw; // (C) + float* lnfb; // (C) +} FloatParameterTensors; +static_assert(sizeof(FloatParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!"); + +// malloc_and_point, but in float and on CPU, because we use this data to check correctness on CPU +float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size_t* param_sizes) { + // calculate the total number of parameters + size_t num_parameters = 0; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + num_parameters += param_sizes[i]; + } + // everything is float so number of bytes to allocate is a simple multiplication + float* params_memory = (float*)mallocCheck(num_parameters * sizeof(float)); + float** ptrs[] = { + ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, + ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, + ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb + }; + float* params_memory_iterator = params_memory; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { + *(ptrs[i]) = params_memory_iterator; + params_memory_iterator += param_sizes[i]; + } + return params_memory; +} + int main(int argc, char *argv[]) { // set up the device @@ -48,19 +107,18 @@ int main(int argc, char *argv[]) { // build the GPT-2 model from a checkpoint GPT2 model; - gpt2_build_from_checkpoint(&model, "gpt2_124M.bin"); - - // int C = model.config.channels; - int V = model.config.vocab_size; - int maxT = model.config.max_seq_len; - // int L = model.config.num_layers; + gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin"); + size_t V = model.config.vocab_size; + size_t maxT = model.config.max_seq_len; + size_t L = model.config.num_layers; + size_t C = model.config.channels; // load additional information that we will use for debugging and error checking FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb"); int state_header[256]; freadCheck(state_header, sizeof(int), 256, state_file); - if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(1); } - if (state_header[1] != 1) { printf("Bad version in state file"); exit(1); } + if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(EXIT_FAILURE); } + if (state_header[1] != 1) { printf("Bad version in state file"); exit(EXIT_FAILURE); } int B = state_header[2]; // batch size, e.g. 4 int T = state_header[3]; // time / sequence length (e.g. 
64, up to maxT) assert(0 <= T && T <= maxT); @@ -68,26 +126,27 @@ int main(int argc, char *argv[]) { printf("batch_size: %d\n", B); printf("seq_len: %d\n", T); - ParameterTensors expected_grads; // will be read from file (from PyTorch) - ParameterTensors calculated_grads; // will be calculated by us - float* expected_grads_memory = malloc_and_point_parameters(&expected_grads, model.param_elements, model.param_sizeof, 0); - float* calculated_grads_memory = malloc_and_point_parameters(&calculated_grads, model.param_elements, model.param_sizeof, 0); - float* converted_grads_memory = (float*)mallocCheck(model.num_parameters * sizeof(float)); - - // inputs and expected outputs, only used for error checking + // read reference information from the file saved from Python/PyTorch side + // 1) input x and y int* x = (int*)mallocCheck(B * T * sizeof(int)); int* y = (int*)mallocCheck(B * T * sizeof(int)); - float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float)); - float* expected_loss = (float*) mallocCheck(1 * sizeof(float)); - - // read reference information from Python freadCheck(x, sizeof(int), B*T, state_file); freadCheck(y, sizeof(int), B*T, state_file); + // 2) results of forward pass (logits and loss) + float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float)); + float* expected_loss = (float*) mallocCheck(1 * sizeof(float)); freadCheck(expected_logits, sizeof(float), B*T*V, state_file); freadCheck(expected_loss, sizeof(float), 1, state_file); + // 3) results of backward pass (parameter gradients) + FloatParameterTensors expected_grads; // will be read from file. right now: all in fp32 + float* expected_grads_memory = float_cpu_malloc_and_point_parameters(&expected_grads, model.param_elements); freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file); fcloseCheck(state_file); + // this memory will be used to do one single copy of all (mixed precision) GPU grads to CPU grads + void* grads_memory_cpu = mallocCheck(model.num_parameters_bytes); + float* grads_memory_cpu_float = (float*)mallocCheck(model.num_parameters * sizeof(float)); + // overall OK signal for the test int allok = 1; @@ -103,25 +162,32 @@ int main(int argc, char *argv[]) { } int logits_ok = 1; - // FP16 and lower require very high tolerances unfortunately - float accuracy_threshold = 1e-2; + // FP16 and lower require very high tolerances unfortunately. 
TODO look into more + float logit_accuracy_threshold = 1e-2f; + float loss_diff_threshold = 0.05f; #if defined(ENABLE_BF16) || defined(ENABLE_F16) - accuracy_threshold = 23; + logit_accuracy_threshold = 15.0f; #endif + + float max_diff = 0.0f; for (int i=0; i= accuracy_threshold) { + float diff = fabsf(expected_logits[i] - logits_cpu[i]); + max_diff = fmaxf(max_diff, diff); + if (diff >= logit_accuracy_threshold) { printf("MISMATCH AT INDEX %d: ", i); printf("%f %f\n", expected_logits[i],logits_cpu[i]); logits_ok = 0; break; } } + allok = allok && logits_ok; if(!logits_ok) { printf("NOT "); } printf("OK (LOGITS)\n"); + printf("logit max diff: %f\n", max_diff); // let's do 10 training iterations, following the pytorch code float losses[10]; @@ -137,71 +203,63 @@ int main(int argc, char *argv[]) { if (step == 0) { // error checking at step 0 for reference activations - - allok = allok && logits_ok; - free(logits_cpu_raw); - free(logits_cpu); - // compare the achieved loss - if (fabsf(model.mean_loss - *expected_loss) >= accuracy_threshold) { + if (fabsf(model.mean_loss - *expected_loss) >= loss_diff_threshold) { printf("LOSS MISMATCH: %f %f\n", model.mean_loss, *expected_loss); allok = 0; } else { printf("LOSS OK: %f %f\n", model.mean_loss, *expected_loss); } - // and now compare the gradients on the parameters - // cudaMemcpy(calculated_grads.lnfw, model.grads.lnfw, C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.lnfb, model.grads.lnfb, C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.fcprojw, model.grads.fcprojw, L * C * 4*C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.fcprojb, model.grads.fcprojb, L * C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.fcw, model.grads.fcw, L * 4*C * C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.fcb, model.grads.fcb, L * 4*C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.ln2w, model.grads.ln2w, L * C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.ln2b, model.grads.ln2b, L * C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.attprojw, model.grads.attprojw, L * C * C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.attprojb, model.grads.attprojb, L * C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.qkvw, model.grads.qkvw, L * 3*C * C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.qkvb, model.grads.qkvb, L * 3*C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.ln1w, model.grads.ln1w, L * C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.ln1b, model.grads.ln1b, L * C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.wte, model.grads.wte, V * C * sizeof(float), cudaMemcpyDeviceToHost); - // cudaMemcpy(calculated_grads.wpe, model.grads.wpe, maxT * C * sizeof(float), cudaMemcpyDeviceToHost); - // check_tensor(calculated_grads.lnfb, expected_grads.lnfb, C, "lnfb"); - // check_tensor(calculated_grads.lnfw, expected_grads.lnfw, C, "lnfw"); - // check_tensor(calculated_grads.fcprojw, expected_grads.fcprojw, L * C * 4*C, "fcprojw"); - // check_tensor(calculated_grads.fcprojb, expected_grads.fcprojb, L * C, "fcprojb"); - // check_tensor(calculated_grads.fcw, expected_grads.fcw, L * 4*C * C, "fcw"); - // check_tensor(calculated_grads.fcb, expected_grads.fcb, L * 4*C, "fcb"); - // 
check_tensor(calculated_grads.ln2w, expected_grads.ln2w, L * C, "ln2w"); - // check_tensor(calculated_grads.ln2b, expected_grads.ln2b, L * C, "ln2b"); - // check_tensor(calculated_grads.attprojw, expected_grads.attprojw, L * C * C, "attprojw"); - // check_tensor(calculated_grads.attprojb, expected_grads.attprojb, L * C, "attprojb"); - // check_tensor(calculated_grads.qkvw, expected_grads.qkvw, L * 3*C * C, "qkvw"); - // check_tensor(calculated_grads.qkvb, expected_grads.qkvb, L * 3*C, "qkvb"); - // check_tensor(calculated_grads.ln1w, expected_grads.ln1w, L * C, "ln1w"); - // check_tensor(calculated_grads.ln1b, expected_grads.ln1b, L * C, "ln1b"); - // check_tensor(calculated_grads.wte, expected_grads.wte, V * C, "wte"); - // check_tensor(calculated_grads.wpe, expected_grads.wpe, maxT * C, "wpe"); - - // get gradients from GPU and convert all non-FP32 gradients back to FP32 for check_tensor - cudaMemcpy(calculated_grads_memory, model.grads_memory, model.num_parameters * sizeof(floatX), cudaMemcpyDeviceToHost); - char* src_iterator = (char*)calculated_grads_memory; - float* dst_iterator = (float*)converted_grads_memory; - for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { + // move the (mixed precision) grads from GPU to CPU + cudaMemcpy(grads_memory_cpu, model.grads_memory, model.num_parameters_bytes, cudaMemcpyDeviceToHost); + + // convert all gradients to float on the CPU + char* src_iterator = (char*)grads_memory_cpu; // can be lower precision, so we use char* + float* dst_iterator = (float*)grads_memory_cpu_float; // float* + float* exp_iterator = expected_grads_memory; // float* of expected gradients from Python + float* tensors1[NUM_PARAMETER_TENSORS]; + float* tensors2[NUM_PARAMETER_TENSORS]; + for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { if (model.param_sizeof[i] == sizeof(float)) { + // float tensor => copy over directly memcpy(dst_iterator, src_iterator, model.param_elements[i] * sizeof(float)); } else { - assert(model.param_sizeof[i] == sizeof(floatX)); + // low-precision tensor => convert to float + assert(model.param_sizeof[i] == sizeof(floatX)); // floatX is the single non-float supported atm for (size_t j = 0; j < model.param_elements[i]; j++) { - dst_iterator[j] = ((floatX*)src_iterator)[j]; + dst_iterator[j] = ((floatX*)src_iterator)[j]; // convert to float } } + // for convenience record the position of comparison for reality vs. expectation + tensors1[i] = dst_iterator; // reality + tensors2[i] = exp_iterator; // expectation + // advance the iterators src_iterator += model.param_elements[i] * model.param_sizeof[i]; dst_iterator += model.param_elements[i]; + exp_iterator += model.param_elements[i]; } - // compare the gradients ona the parameters all at once - check_tensor(converted_grads_memory, expected_grads_memory, model.num_parameters, "grads"); + + // compare the gradients on the parameters all at once, in fp32 + // I set the tolerances manually by inspecting the gradient differences for + // a few elements of each tensor. bf16 looks ok but not amazing here. + // It's possible we have bugs lurking, or maybe it is bf16. Not 100% sure. 
+ allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", 6e-1f); + allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", 1e-2f); + allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", 9e-2); // hmm a bit high + allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", 3e-2f); + allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", 3e-2f); + allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", 3e-2f); + allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", 9e-2f); // hmm a bit high + allok = allok & check_tensor(tensors1[7], tensors2[7], L * 4*C, "fcb", 9e-2f); // hmm a bit high + allok = allok & check_tensor(tensors1[8], tensors2[8], L * C * 4*C, "fcprojw", 9e-2f); // hmm a bit high + allok = allok & check_tensor(tensors1[9], tensors2[9], L * C, "fcprojb", 3e-2f); + allok = allok & check_tensor(tensors1[10], tensors2[10], L * C, "ln1w", 0.1f); // hmm bit higher + allok = allok & check_tensor(tensors1[11], tensors2[11], L * C, "ln1b", 3e-2f); + allok = allok & check_tensor(tensors1[12], tensors2[12], L * C, "ln2w", 0.1f); // hmm bit higher + allok = allok & check_tensor(tensors1[13], tensors2[13], L * C, "ln2b", 3e-2f); + allok = allok & check_tensor(tensors1[14], tensors2[14], C, "lnfw", 0.12f); // hmm bit higher + allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", 3e-2f); } gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, step+1); @@ -227,7 +285,7 @@ int main(int argc, char *argv[]) { // compare for (int i = 0; i < 10; i++) { - if (fabsf(losses[i] - expected_losses[i]) >= accuracy_threshold) { + if (fabsf(losses[i] - expected_losses[i]) >= loss_diff_threshold) { printf("LOSS MISMATCH AT STEP %d: %f %f\n", i, losses[i], expected_losses[i]); allok = 0; } else { @@ -241,11 +299,13 @@ int main(int argc, char *argv[]) { // free everything free(x); free(y); + free(logits_cpu_raw); + free(logits_cpu); free(expected_logits); free(expected_loss); free(expected_grads_memory); - free(calculated_grads_memory); - free(converted_grads_memory); + free(grads_memory_cpu); + free(grads_memory_cpu_float); gpt2_free(&model); cudaCheck(cudaFree(cublaslt_workspace)); cublasCheck(cublasDestroy(cublas_handle)); diff --git a/train_gpt2.cu b/train_gpt2.cu index 63515e3e8..e941277c7 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -37,12 +37,13 @@ mpirun -np 4 ./train_gpt2cu -b 8 -v 200 -s 200 -i data/TinyStories #include #include #include +// GPU / CUDA related #include #include #include #include #include - +// Multi-GPU related #ifdef MULTI_GPU #include #include @@ -51,18 +52,24 @@ mpirun -np 4 ./train_gpt2cu -b 8 -v 200 -s 200 -i data/TinyStories // ---------------------------------------------------------------------------- // CUDA precision settings -// turn on bf16 as default, done up here for now -#define ENABLE_BF16 +enum PrecisionMode { + PRECISION_FP32, + PRECISION_FP16, + PRECISION_BF16 +}; -// use bf16 (bfloat 16) -#if defined(ENABLE_BF16) -typedef __nv_bfloat16 floatX; +// fp32 +#if defined(ENABLE_FP32) +typedef float floatX; typedef float floatN; -#define CUBLAS_LOWP CUDA_R_16BF -#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F +#define CUBLAS_LOWP CUDA_R_32F +#define CUBLAS_LOWP_COMPUTE cublas_compute_type // auto-select FP32 vs TF32 +const char* load_filename = "gpt2_124M.bin"; // fp32 weights +PrecisionMode PRECISION_MODE = PRECISION_FP32; +const char* precision_mode_str = "fp32"; #ifdef MULTI_GPU -const ncclDataType_t ncclFloatX = 
ncclBfloat16; +const ncclDataType_t ncclFloatX = ncclFloat; const ncclDataType_t ncclFloatN = ncclFloat; #endif @@ -72,24 +79,29 @@ typedef half floatX; typedef float floatN; #define CUBLAS_LOWP CUDA_R_16F #define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F +const char* load_filename = "gpt2_124M.bin"; // fp32 weights +PrecisionMode PRECISION_MODE = PRECISION_FP16; +const char* precision_mode_str = "fp16"; #ifdef MULTI_GPU const ncclDataType_t ncclFloatX = ncclHalf; const ncclDataType_t ncclFloatN = ncclFloat; #endif -// fallback for fp32 +// bfloat16 (default!) #else -typedef float floatX; +typedef __nv_bfloat16 floatX; typedef float floatN; -#define CUBLAS_LOWP CUDA_R_32F -#define CUBLAS_LOWP_COMPUTE cublas_compute_type // auto-select FP32 vs TF32 +#define CUBLAS_LOWP CUDA_R_16BF +#define CUBLAS_LOWP_COMPUTE CUBLAS_COMPUTE_32F +const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights +PrecisionMode PRECISION_MODE = PRECISION_BF16; +const char* precision_mode_str = "bf16"; #ifdef MULTI_GPU -const ncclDataType_t ncclFloatX = ncclFloat; +const ncclDataType_t ncclFloatX = ncclBfloat16; const ncclDataType_t ncclFloatN = ncclFloat; #endif - #endif // ---------------------------------------------------------------------------- @@ -267,6 +279,7 @@ FILE *fopen_check(const char *path, const char *mode, const char *file, int line fprintf(stderr, " Line: %d\n", line); fprintf(stderr, " Path: %s\n", path); fprintf(stderr, " Mode: %s\n", mode); + fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); exit(EXIT_FAILURE); } return fp; @@ -1468,25 +1481,27 @@ typedef struct { } GPT2Config; // the parameters of the model -#define NUM_PARAMETER_TENSORS 16 +// note the layernorms are kept in higher precision (floatN) +constexpr const int NUM_PARAMETER_TENSORS = 16; typedef struct { - floatX* wte; // (V, C) - floatX* wpe; // (maxT, C) - floatN* ln1w; // (L, C) - floatN* ln1b; // (L, C) + floatX* wte; // (V, C) + floatX* wpe; // (maxT, C) + floatN* ln1w; // (L, C) + floatN* ln1b; // (L, C) floatX* qkvw; // (L, 3*C, C) floatX* qkvb; // (L, 3*C) floatX* attprojw; // (L, C, C) floatX* attprojb; // (L, C) - floatN* ln2w; // (L, C) - floatN* ln2b; // (L, C) + floatN* ln2w; // (L, C) + floatN* ln2b; // (L, C) floatX* fcw; // (L, 4*C, C) floatX* fcb; // (L, 4*C) floatX* fcprojw; // (L, C, 4*C) floatX* fcprojb; // (L, C) - floatN* lnfw; // (C) - floatN* lnfb; // (C) + floatN* lnfw; // (C) + floatN* lnfb; // (C) } ParameterTensors; +static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!"); void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) { size_t V = config.vocab_size; @@ -1510,11 +1525,10 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Conf param_sizes[14] = C; // lnfw param_sizes[15] = C; // lnfb - // Set parameter sizes - // floatN gives us an option to keep layernorm params in FP32 if we want to + // populate the parameter sizes in bytes for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { param_sizeof[i] = sizeof(floatX); - } + } // override layernorms here below param_sizeof[2] = sizeof(floatN); // ln1w param_sizeof[3] = sizeof(floatN); // ln1b param_sizeof[8] = sizeof(floatN); // ln2w @@ -1524,8 +1538,8 @@ void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Conf } // allocate memory for the parameters and point the individual tensors to the right places -float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elements, size_t 
*param_sizeof, int on_device) { - // calculate the number of parameters +void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elements, size_t *param_sizeof) { + // calculate the total number of parameters and bytes across all tensors size_t num_parameters = 0; size_t num_parameters_bytes = 0; for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { @@ -1533,13 +1547,8 @@ float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_eleme num_parameters_bytes += param_elements[i] * param_sizeof[i]; } // malloc all parameters all at once on the device - // on_device: 0 = CPU, 1 = GPU - float* params_memory; - if (on_device) { - cudaCheck(cudaMalloc((void**)¶ms_memory, num_parameters_bytes)); - } else { - params_memory = (float*)mallocCheck(num_parameters * sizeof(float)); // keep FP32 here - } + void* params_memory; + cudaCheck(cudaMalloc((void**)¶ms_memory, num_parameters_bytes)); // assign all the tensors their place in the array floatX** ptrs[] = { ¶ms->wte, ¶ms->wpe, (floatX**)¶ms->ln1w, (floatX**)¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, @@ -1700,12 +1709,27 @@ typedef struct { void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { + if (PRECISION_MODE == PRECISION_FP16) { + // TODO for later perhaps, would require us dynamically converting the + // model weights from fp32 to fp16 online, here in this function, or writing + // the fp16 weights directly from Python, which we only do for fp32/bf16 atm. + fprintf(stderr, "build_from_checkpoint() does not support fp16 right now.\n"); + exit(EXIT_FAILURE); + } + // read in model from a checkpoint file FILE *model_file = fopenCheck(checkpoint_path, "rb"); int model_header[256]; freadCheck(model_header, sizeof(int), 256, model_file); - if (model_header[0] != 20240326) { printf("Bad magic model file"); exit(EXIT_FAILURE); } - if (model_header[1] != 1) { printf("Bad version in model file"); exit(EXIT_FAILURE); } + if (model_header[0] != 20240326) { printf("Bad magic model file\n"); exit(EXIT_FAILURE); } + int version = model_header[1]; + if (!(version == 1 || version == 2)) { + // 1 = fp32, ordered layernorm at the end + // 2 = bf16, ordered layernorm at the end + fprintf(stderr, "Bad version in model file\n"); + fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); + exit(EXIT_FAILURE); + } // read in hyperparameters model->config.max_seq_len = model_header[2]; @@ -1723,34 +1747,14 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { model->num_parameters += model->param_elements[i]; model->num_parameters_bytes += model->param_elements[i] * model->param_sizeof[i]; } - size_t input_model_bytes = model->num_parameters * sizeof(float); // create memory for model parameters on the device - model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof, 1); + model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof); // read in all the parameters from file and copy them to device - float* params_memory_cpu = (float*)mallocCheck(input_model_bytes); - freadCheck(params_memory_cpu, 1, input_model_bytes, model_file); - - float* params_cpu_iterator = (float*)params_memory_cpu; - char* params_gpu_iterator = (char*)model->params_memory; - - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - if (model->param_sizeof[i] == sizeof(float)) { - cudaCheck(cudaMemcpy(params_gpu_iterator, params_cpu_iterator, model->param_elements[i] * sizeof(float), cudaMemcpyHostToDevice)); - } else 
{ - // TODO: Currently only support float or floatX (cannot mix and match FP16/BF16 etc...) - assert(model->param_sizeof[i] == sizeof(floatX)); - floatX* conversion_scratchpad = (floatX*)mallocCheck(model->param_elements[i] * sizeof(floatX)); - for (size_t j = 0; j < model->param_elements[i]; j++) { - conversion_scratchpad[j] = (floatX)params_cpu_iterator[j]; - } - cudaCheck(cudaMemcpy(params_gpu_iterator, conversion_scratchpad, model->param_elements[i] * sizeof(floatX), cudaMemcpyHostToDevice)); - free(conversion_scratchpad); - } - params_cpu_iterator += model->param_elements[i]; - params_gpu_iterator += model->param_elements[i] * model->param_sizeof[i]; - } + float* params_memory_cpu = (float*)mallocCheck(model->num_parameters_bytes); + freadCheck(params_memory_cpu, 1, model->num_parameters_bytes, model_file); + cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); free(params_memory_cpu); fcloseCheck(model_file); @@ -1922,14 +1926,12 @@ void gpt2_backward(GPT2 *model) { // lazily allocate the memory for gradients of the weights and activations, if needed if (model->grads_memory == NULL) { // allocate buffers for weight gradients - model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof, 1); + model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof); printf0("allocated %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024))); // we're going to be clever for the activations backward pass. we don't need to exactly - // mirror the forward pass acrtivations and we will save memory. + // mirror the forward pass activations and we will save memory. size_t bw_act_sizes[NUM_ACTIVATION_TENSORS]; - GPT2Config cfg = model->config; - cfg.num_layers = 1; // copy the configuration but override number of layers to 1 - fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, cfg); + fill_in_grad_act_sizes(bw_act_sizes, model->batch_size, model->seq_len, model->config); // count up and allocate the space model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes); model->num_grad_acts = 0; @@ -2087,15 +2089,13 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo float beta1_correction = 1.0f - powf(beta1, t); float beta2_correction = 1.0f - powf(beta2, t); - // Do adam per set of parameters + // Adam update // We need to know the parameter types (float or floatX) to process consecutive chunks - // TODO - optimise this to require fewer kernel launches and/or independent via CUDA streams char* params_mem = (char*)model->params_memory; char* grads_mem = (char*)model->grads_memory; size_t num_elements = model->param_elements[0]; size_t last_sizeof = model->param_sizeof[0]; size_t current_element = 0; - for (int i = 1; i <= NUM_PARAMETER_TENSORS; i++) { if (i == NUM_PARAMETER_TENSORS || model->param_sizeof[i] != last_sizeof) { unsigned int seed = random_u32(&model->rng_state); // seed for stochastic rounding @@ -2428,11 +2428,13 @@ int main(int argc, char *argv[]) { cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size)); printf0("| device | %-50s |\n", deviceProp.name); printf0("| TF32 | %-50s |\n", enable_tf32 ?
"enabled" : "disabled"); + printf0("| precision | %-50s |\n", precision_mode_str); printf0("+-----------------------+----------------------------------------------------+\n"); // build the GPT-2 model from a checkpoint GPT2 model; - gpt2_build_from_checkpoint(&model, "gpt2_124M.bin"); + gpt2_build_from_checkpoint(&model, load_filename); + printf0("| load_filename | %-50s |\n", load_filename); printf0("| max_sequence_length T | %-50d |\n", model.config.max_seq_len); printf0("| vocab_size V | %-50d |\n", model.config.vocab_size); printf0("| num_layers L | %-50d |\n", model.config.num_layers); diff --git a/train_gpt2.py b/train_gpt2.py index 5bd75f028..5c119c251 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -265,10 +265,14 @@ def write_tensors_fp32(model_tensors, L, file): write_fp32(model_tensors["transformer.ln_f.bias"], file) # (C, ) def write_tensors_bf16(model_tensors, L, file): - # same as fp32, but note we will re-order the tensors - # because we keep the layernorm in fp32, we place them all at the end + # same but we keep the layernorm in fp32 + # these two functions are so similar we can join them later most likely write_bf16(model_tensors["transformer.wte.weight"], file) # (V, C) write_bf16(model_tensors["transformer.wpe.weight"], file) # (T, C) + for i in range(L): # (L, C) + write_fp32(model_tensors[f"transformer.h.{i}.ln_1.weight"], file) + for i in range(L): # (L, C) + write_fp32(model_tensors[f"transformer.h.{i}.ln_1.bias"], file) for i in range(L): # (L, 3C, C) write_bf16(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file) for i in range(L): # (L, 3C) @@ -277,6 +281,10 @@ def write_tensors_bf16(model_tensors, L, file): write_bf16(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file) for i in range(L): # (L, C) write_bf16(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file) + for i in range(L): # (L, C) + write_fp32(model_tensors[f"transformer.h.{i}.ln_2.weight"], file) + for i in range(L): # (L, C) + write_fp32(model_tensors[f"transformer.h.{i}.ln_2.bias"], file) for i in range(L): # (L, 4C, C) write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file) for i in range(L): # (L, 4C) @@ -285,15 +293,6 @@ def write_tensors_bf16(model_tensors, L, file): write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file) for i in range(L): # (L, C) write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file) - # LayerNorms are at the end and kept in fp32 - for i in range(L): # (L, C) - write_fp32(model_tensors[f"transformer.h.{i}.ln_1.weight"], file) - for i in range(L): # (L, C) - write_fp32(model_tensors[f"transformer.h.{i}.ln_1.bias"], file) - for i in range(L): # (L, C) - write_fp32(model_tensors[f"transformer.h.{i}.ln_2.weight"], file) - for i in range(L): # (L, C) - write_fp32(model_tensors[f"transformer.h.{i}.ln_2.bias"], file) write_fp32(model_tensors["transformer.ln_f.weight"], file) # (C, ) write_fp32(model_tensors["transformer.ln_f.bias"], file) # (C, ) @@ -319,10 +318,8 @@ def write_model(model, filename, dtype): # write header file.write(header.numpy().tobytes()) # write params - if dtype == "float32": - write_tensors_fp32(params, model.config.n_layer, file) - elif dtype == "bfloat16": - write_tensors_bf16(params, model.config.n_layer, file) + write_fun = write_tensors_fp32 if dtype == "float32" else write_tensors_bf16 + write_fun(params, model.config.n_layer, file) print(f"wrote {filename}") def write_state(model, x, y, logits, loss, filename): diff --git a/train_gpt2_fp32.cu b/train_gpt2_fp32.cu index 
cfe5c7d90..f07e78612 100644 --- a/train_gpt2_fp32.cu +++ b/train_gpt2_fp32.cu @@ -1879,8 +1879,8 @@ void logger_free(Logger *logger) { void error_usage() { // default run = debugging run with TinyShakespeare // bigger run = train on TinyStories! e.g. val/sample less often, but sample more tokens, write to logfile - fprintf(stderr, "Usage: ./train_gpt2cu [options]\n"); - fprintf(stderr, "Example: ./train_gpt2cu -i data/TinyStories -v 100 -s 100 -g 144 -o stories.log\n"); + fprintf(stderr, "Usage: ./train_gpt2fp32cu [options]\n"); + fprintf(stderr, "Example: ./train_gpt2fp32cu -i data/TinyStories -v 100 -s 100 -g 144 -o stories.log\n"); fprintf(stderr, "Options:\n"); fprintf(stderr, " -i input dataset prefix (default = data/tiny_shakespeare)\n"); fprintf(stderr, " -o output log file (default = NULL)\n");
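
Reviewer note (not part of the patch): a minimal sketch of how the new PRECISION knob flows into the CUDA sources. Running make train_gpt2cu PRECISION=FP16 (valid values FP32/FP16/BF16, default BF16) makes the Makefile pass -DENABLE_FP16 to nvcc via PFLAGS, and the #if chain at the top of train_gpt2.cu then selects floatX, the cuBLAS low-precision type, and the checkpoint filename to load. The snippet below only illustrates that type selection as a standalone file; it assumes nothing beyond the standard CUDA headers and is not code taken from the patch.

    // sketch: precision selection driven by the -DENABLE_* define added via the Makefile's PFLAGS
    #include <cuda_fp16.h>   // half
    #include <cuda_bf16.h>   // __nv_bfloat16

    #if defined(ENABLE_FP32)
    typedef float floatX;            // full fp32 everywhere
    #elif defined(ENABLE_FP16)
    typedef half floatX;             // fp16 (checkpoint loading not supported by this patch yet)
    #else
    typedef __nv_bfloat16 floatX;    // bf16 is the default, matching PRECISION ?= BF16
    #endif

Note that profile_gpt2.cu and test_gpt2.cu pin #define ENABLE_BF16 before including train_gpt2.cu, so those two binaries are built for bf16 and load gpt2_124M_bf16.bin.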