From 835060e18d1ceeb6c19ad1c214b9f03cc13daf16 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 28 Apr 2024 18:47:03 +0000 Subject: [PATCH] padded vocab change. touched a lot of code. very stressful and error prone, but i think it is done. had to bump versions on all .bin files, invalidating the previous files. re-run the python training script to re-export the new version files. let's not do too much of things like this in the future lol. actually, fun fact i had a chance to do the padded vocab really really early in the history of llm.c development, and chose not do it, thinking i'll just do it later. i should have done it. such is life, you make mistakes, you accumulate scar tissue, and you learn, and you become better, faster, stronger. this is the mindset one must have to lead a happy and fulfilling life. it's not important that you are perfect at any point in time, it's only important that you keep improving, every day. --- test_gpt2.cu | 57 ++++++++++++++++++++++++++++------------------ test_gpt2_fp32.cu | 45 +++++++++++++++++++++++------------- train_gpt2.cu | 34 +++++++++++++++------------ train_gpt2.py | 39 ++++++++++++++++++++++++++++--- train_gpt2_fp32.cu | 32 +++++++++++++++++--------- 5 files changed, 141 insertions(+), 66 deletions(-) diff --git a/test_gpt2.cu b/test_gpt2.cu index 804c55f44..657d20768 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -42,7 +42,7 @@ int check_tensor(float *a, float *b, int n, const char* label, float threshold=1 // the same tensors as in the train file, but in float, which are used as reference typedef struct { - float* wte; // (V, C) + float* wte; // (Vp, C) float* wpe; // (maxT, C) float* ln1w; // (L, C) float* ln1b; // (L, C) @@ -109,6 +109,7 @@ int main(int argc, char *argv[]) { GPT2 model; gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin"); size_t V = model.config.vocab_size; + size_t Vp = model.config.padded_vocab_size; size_t maxT = model.config.max_seq_len; size_t L = model.config.num_layers; size_t C = model.config.channels; @@ -117,8 +118,12 @@ int main(int argc, char *argv[]) { FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb"); int state_header[256]; freadCheck(state_header, sizeof(int), 256, state_file); - if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(EXIT_FAILURE); } - if (state_header[1] != 1) { printf("Bad version in state file"); exit(EXIT_FAILURE); } + if (state_header[0] != 20240327) { fprintf(stderr, "Bad magic state file\n"); exit(EXIT_FAILURE); } + if (state_header[1] != 2) { + fprintf(stderr, "Bad version in state file\n"); + fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); + exit(EXIT_FAILURE); + } int B = state_header[2]; // batch size, e.g. 4 int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT) assert(0 <= T && T <= maxT); @@ -154,13 +159,12 @@ int main(int argc, char *argv[]) { gpt2_forward(&model, x, NULL, B, T); // at this point, target should be equal to expected_logits, let's compare // copy logits to CPU so we can compare them - floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * V * sizeof(floatX)); - float* logits_cpu = (float*)mallocCheck(B * T * V * sizeof(float)); - cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * V * sizeof(floatX), cudaMemcpyDeviceToHost); - for (int i = 0; i < B * T * V; i++) { + floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX)); + float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float)); + cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost); + for (int i = 0; i < B * T * Vp; i++) { logits_cpu[i] = (float)logits_cpu_raw[i]; } - int logits_ok = 1; // FP16 and lower require very high tolerances unfortunately. TODO look into more float logit_accuracy_threshold = 1e-2f; @@ -169,19 +173,25 @@ int main(int argc, char *argv[]) { logit_accuracy_threshold = 15.0f; #endif - + // compare the output logits from the forward pass + // also careful that we don't access and compare the padded columns of logits + int logits_ok = 1; float max_diff = 0.0f; - for (int i=0; i= logit_accuracy_threshold) { - printf("MISMATCH AT INDEX %d: ", i); - printf("%f %f\n", expected_logits[i],logits_cpu[i]); - logits_ok = 0; - break; + for (int bt = 0; bt < B*T; bt++) { + for (int v = 0; v < V; v++) { + int i = bt * Vp + v; // linearized index + if (i < 10) { + printf("%f, %f\n", expected_logits[i], logits_cpu[i]); + } + float diff = fabsf(expected_logits[bt*V + v] - logits_cpu[i]); + max_diff = fmaxf(max_diff, diff); + if (diff >= logit_accuracy_threshold) { + printf("MISMATCH AT INDEX %d,%d: ", bt, v); + printf("%f %f\n", expected_logits[bt*V + v], logits_cpu[i]); + logits_ok = 0; + bt = B*T; // to break out of both loops + break; + } } } allok = allok && logits_ok; @@ -244,10 +254,13 @@ int main(int argc, char *argv[]) { // I set the tolerances manually by inspecting the gradient differences for // a few elements of each tensor. bf16 looks ok but not amazing here. // It's possible we have bugs lurking, or maybe it is bf16. Not 100% sure. + // Also, if code changes and some of these get tripped, it could be ok if it's not by too much, + // because our use of stochastic rounding is adding some non-determinism "pepper noise". + // In that case it's ok to extend the tolerance by a bit, after a manual review. allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", 6e-1f); allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", 1e-2f); - allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", 9e-2); // hmm a bit high - allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", 3e-2f); + allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", 1.1e-1); // hmm a bit high + allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", 4e-2f); allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", 3e-2f); allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", 3e-2f); allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", 9e-2f); // hmm a bit high diff --git a/test_gpt2_fp32.cu b/test_gpt2_fp32.cu index cf6816572..01440072a 100644 --- a/test_gpt2_fp32.cu +++ b/test_gpt2_fp32.cu @@ -52,6 +52,7 @@ int main(int argc, char *argv[]) { // int C = model.config.channels; int V = model.config.vocab_size; + int Vp = model.config.padded_vocab_size; int maxT = model.config.max_seq_len; // int L = model.config.num_layers; @@ -59,8 +60,12 @@ int main(int argc, char *argv[]) { FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb"); int state_header[256]; freadCheck(state_header, sizeof(int), 256, state_file); - if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(1); } - if (state_header[1] != 1) { printf("Bad version in state file"); exit(1); } + if (state_header[0] != 20240327) { printf("Bad magic state file\n"); exit(EXIT_FAILURE); } + if (state_header[1] != 2) { + fprintf(stderr, "Bad version in state file\n"); + fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); + exit(EXIT_FAILURE); + } int B = state_header[2]; // batch size, e.g. 4 int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT) assert(0 <= T && T <= maxT); @@ -94,20 +99,31 @@ int main(int argc, char *argv[]) { gpt2_forward(&model, x, NULL, B, T); // at this point, target should be equal to expected_logits, let's compare // copy logits to CPU so we can compare them - float* logits_cpu = (float*)mallocCheck(B * T * V * sizeof(float)); - cudaMemcpy(logits_cpu, model.acts.output, B * T * V * sizeof(float), cudaMemcpyDeviceToHost); + float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float)); + cudaMemcpy(logits_cpu, model.acts.output, B * T * Vp * sizeof(float), cudaMemcpyDeviceToHost); + + // compare the output logits from the forward pass + // also careful that we don't access and compare the padded columns of logits int logits_ok = 1; - for (int i=0; i= 1e-2) { - printf("MISMATCH AT INDEX %d: ", i); - printf("%f %f\n", expected_logits[i],logits_cpu[i]); - logits_ok = 0; - break; + float max_diff = 0.0f; + for (int bt = 0; bt < B*T; bt++) { + for (int v = 0; v < V; v++) { + int i = bt * Vp + v; // linearized index + if (i < 10) { + printf("%f, %f\n", expected_logits[i], logits_cpu[i]); + } + float diff = fabsf(expected_logits[bt*V + v] - logits_cpu[i]); + max_diff = fmaxf(max_diff, diff); + if (diff >= 1e-2f) { + printf("MISMATCH AT INDEX %d,%d: ", bt, v); + printf("%f %f\n", expected_logits[bt*V + v], logits_cpu[i]); + logits_ok = 0; + bt = B*T; // to break out of both loops + break; + } } } + allok = allok && logits_ok; if(!logits_ok) { printf("NOT "); } printf("OK (LOGITS)\n"); @@ -124,9 +140,6 @@ int main(int argc, char *argv[]) { if (step == 0) { // error checking at step 0 for reference activations - - - allok = allok && logits_ok; free(logits_cpu); // compare the achieved loss diff --git a/train_gpt2.cu b/train_gpt2.cu index e941277c7..288f73d2c 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1108,10 +1108,12 @@ __global__ void fused_classifier_kernel3(Type* logits, Type* losses, Type* probs } // very sensible default for dlosses is 1/(B*T), which is the uniform loss - float dloss = dlosses != NULL ? (float)dlosses[idx] : 1.0f / (B*T); + float dloss = (dlosses != NULL) ? (float)dlosses[idx] : 1.0f / (B*T); // calculate the gradients directly, saves bandwidth from probs during training // but also supports writing probs for inference-only and debugging const Type* logits_vec = logits + idx * P; + // note that we use the padded dimension P to access data, but we only ever + // modify the elements up to V, ignoring the padded dimensions and leaving them at 0 for (int i = threadIdx.x; i < V; i += blockDim.x) { // this is the 2nd read of logits after the one in prepare_softmax2 // this data will never be needed again, so we reduce cache persistence @@ -1475,6 +1477,7 @@ void fused_classifier3(Type* logits, Type* losses, typedef struct { int max_seq_len; // max sequence length, e.g. 1024 int vocab_size; // vocab size, e.g. 50257 + int padded_vocab_size; // padded to e.g. %128==0, 50304 int num_layers; // number of layers, e.g. 12 int num_heads; // number of heads in attention, e.g. 12 int channels; // number of channels, e.g. 768 @@ -1504,11 +1507,11 @@ typedef struct { static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!"); void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) { - size_t V = config.vocab_size; + size_t Vp = config.padded_vocab_size; size_t C = config.channels; size_t maxT = config.max_seq_len; size_t L = config.num_layers; - param_sizes[0] = V * C; // wte + param_sizes[0] = Vp * C; // wte param_sizes[1] = maxT * C; // wpe param_sizes[2] = L * C; // ln1w param_sizes[3] = L * C; // ln1b @@ -1595,7 +1598,7 @@ typedef struct { } ActivationTensors; void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config config) { - size_t V = config.vocab_size; + size_t Vp = config.padded_vocab_size; size_t L = config.num_layers; size_t NH = config.num_heads; size_t C = config.channels; @@ -1619,7 +1622,7 @@ void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config act_sizes[17] = B * T; // lnf_rstd act_sizes[18] = B * T; // losses act_sizes[19] = L * B * T * 3*C; // qkvr - act_sizes[20] = B * T * max(3*C, max(NH*T, V)); // output / scratch + act_sizes[20] = B * T * max(3*C, max(NH*T, Vp)); // output / scratch } // Backward pass is conceptually quite different from forward, because we can discard @@ -1723,9 +1726,9 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { freadCheck(model_header, sizeof(int), 256, model_file); if (model_header[0] != 20240326) { printf("Bad magic model file\n"); exit(EXIT_FAILURE); } int version = model_header[1]; - if (!(version == 1 || version == 2)) { - // 1 = fp32, ordered layernorm at the end - // 2 = bf16, ordered layernorm at the end + if (!(version == 3 || version == 4)) { + // 3 = fp32, padded vocab + // 4 = bf16, padded vocab fprintf(stderr, "Bad version in model file\n"); fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); exit(EXIT_FAILURE); @@ -1737,6 +1740,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { model->config.num_layers = model_header[4]; model->config.num_heads = model_header[5]; model->config.channels = model_header[6]; + model->config.padded_vocab_size = model_header[7]; // allocate space for all the parameters and read them in fill_in_parameter_sizes(model->param_elements, model->param_sizeof, model->config); @@ -1786,6 +1790,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) { // convenience parameters size_t V = model->config.vocab_size; + size_t Vp = model->config.padded_vocab_size; size_t L = model->config.num_layers; size_t NH = model->config.num_heads; size_t C = model->config.channels; @@ -1890,13 +1895,13 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) { residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C); - matmul_forward_cublas(acts.output, acts.lnf, params.wte, NULL, B, T, C, V); + matmul_forward_cublas(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp); // also forward the cross-entropy loss function if we have the targets if (targets != NULL) { // fused classifier: does the forward pass and first part of the backward pass // we're passing dlosses = NULL, which will default them to 1.0f/(B*T), i.e. uniform loss - fused_classifier3(acts.output, acts.losses, (floatX*)NULL, model->targets, B, T, V, V); + fused_classifier3(acts.output, acts.losses, (floatX*)NULL, model->targets, B, T, V, Vp); // for convenience also evaluate the mean loss (TODO re-think this compute+sync point) // move the (B,T) losses to CPU cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(floatX), cudaMemcpyDeviceToHost)); @@ -1946,7 +1951,7 @@ void gpt2_backward(GPT2 *model) { // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow size_t B = model->batch_size; size_t T = model->seq_len; - size_t V = model->config.vocab_size; + size_t Vp = model->config.padded_vocab_size; size_t L = model->config.num_layers; size_t NH = model->config.num_heads; size_t C = model->config.channels; @@ -1962,7 +1967,7 @@ void gpt2_backward(GPT2 *model) { // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, B, T, C, V); + matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, B, T, C, Vp); // backward the final layernorm floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 floatX* dresidual = (floatX*)grads_acts.residual3; // the main buffer holding the gradient in the backward pass @@ -2437,6 +2442,7 @@ int main(int argc, char *argv[]) { printf0("| load_filename | %-50s |\n", load_filename); printf0("| max_sequence_length T | %-50d |\n", model.config.max_seq_len); printf0("| vocab_size V | %-50d |\n", model.config.vocab_size); + printf0("| padded_vocab_size Vp | %-50d |\n", model.config.padded_vocab_size); printf0("| num_layers L | %-50d |\n", model.config.num_layers); printf0("| num_heads NH | %-50d |\n", model.config.num_heads); printf0("| channels C | %-50d |\n", model.config.channels); @@ -2520,8 +2526,8 @@ int main(int argc, char *argv[]) { // we're in principle running B "inference streams" in parallel here // only using position 0 because it's a bit faster (copy less probs from GPU -> CPU) // get the V-dimensional vector probs[0, t-1, :] - floatX* logits = model.acts.output + (t - 1) * model.config.vocab_size; - // move probs back to CPU and sample + floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size; + // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding) cudaCheck(cudaMemcpy(cpu_logits_raw, logits, model.config.vocab_size * sizeof(floatX), cudaMemcpyDeviceToHost)); // convert to FP32 into cpu_logits (this does nothing useful if floatX == float) for (int i = 0; i < model.config.vocab_size; i++) { diff --git a/train_gpt2.py b/train_gpt2.py index 5c119c251..42c955d34 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -296,13 +296,34 @@ def write_tensors_bf16(model_tensors, L, file): write_fp32(model_tensors["transformer.ln_f.weight"], file) # (C, ) write_fp32(model_tensors["transformer.ln_f.bias"], file) # (C, ) +@torch.no_grad() +def pad_vocab(tensor, multiple=128, value=0): + """ + The dimension of the vocab size in GPT-2 is 50,257 + which is unfortunately a very unfriendly number for a lot of + matrix operations on the GPU. So we pad it to the nearest + friendlier multiple, e.g. 50,304 if multiple=128 when we + export the weights into C land. This is a NOOP algorithmically + and is only done to make the tensor operations more efficient. + """ + assert tensor.ndim == 2 + V, C = tensor.shape + assert V == 50257, "just being defensive here" + # calculate padded vocab size by rounding up to nearest multiple + Vp = ((V + multiple - 1) // multiple) * multiple + # pad the tensor + pad_rows = Vp - V + padded = tensor if pad_rows == 0 else F.pad(tensor, (0, 0, 0, pad_rows), value=value) + assert padded.shape == (Vp, C) + return padded + def write_model(model, filename, dtype): # everything we need to instantiate the model # 1) header is: version int, GPTConfig ints, padding to 1024 bytes assert dtype in {"float32", "bfloat16"} # float16 todo maybe later version = { - "float32": 1, - "bfloat16": 2, + "float32": 3, + "bfloat16": 4, }[dtype] header = torch.zeros(256, dtype=torch.int32) header[0] = 20240326 # magic @@ -314,6 +335,13 @@ def write_model(model, filename, dtype): header[6] = model.config.n_embd # 2) the parameters follow the header params = {name: param.cpu() for name, param in model.named_parameters()} + # pad the vocab to a multiple of 128 here at export, for efficiency in C + wte = params["transformer.wte.weight"] # (V, C) + wte_padded = pad_vocab(wte) # (Vp, C) + params["transformer.wte.weight"] = wte_padded # (Vp, C) + print(f"padded vocab size from {wte.size(0)} to {wte_padded.size(0)}") + header[7] = wte_padded.size(0) # padded vocab size store in header + # now write to file with open(filename, "wb") as file: # write header file.write(header.numpy().tobytes()) @@ -328,10 +356,15 @@ def write_state(model, x, y, logits, loss, filename): # this can be used for checking the computation correctness in C header = torch.zeros(256, dtype=torch.int32) header[0] = 20240327 # magic - header[1] = 1 # run state version = 1 + header[1] = 2 # run state version = 2 (1 -> 2 for padded vocab changes) header[2] = x.size(0) # batch size of the batch, B header[3] = x.size(1) # temporal extent of the batch, T grads = {name: param.grad.cpu() for name, param in model.named_parameters()} + # pad the vocab grads here as well, to mirror write_model + wte_grad = grads["transformer.wte.weight"] # (V, C) + wte_grad_padded = pad_vocab(wte_grad, value=0) # (Vp, C) # TODO later maybe pad with nan? + grads["transformer.wte.weight"] = wte_grad_padded # (Vp, C) + print(f"padded vocab size in reference grads from {wte_grad.size(0)} to {wte_grad_padded.size(0)}") with open(filename, "wb") as file: # header file.write(header.numpy().tobytes()) diff --git a/train_gpt2_fp32.cu b/train_gpt2_fp32.cu index f07e78612..19bc574dc 100644 --- a/train_gpt2_fp32.cu +++ b/train_gpt2_fp32.cu @@ -1085,6 +1085,7 @@ void fused_classifier3(float* logits, float* losses, typedef struct { int max_seq_len; // max sequence length, e.g. 1024 int vocab_size; // vocab size, e.g. 50257 + int padded_vocab_size; // padded to e.g. %128==0, 50304 int num_layers; // number of layers, e.g. 12 int num_heads; // number of heads in attention, e.g. 12 int channels; // number of channels, e.g. 768 @@ -1112,11 +1113,12 @@ typedef struct { } ParameterTensors; void fill_in_parameter_sizes(size_t* param_sizes, GPT2Config config) { + int Vp = config.padded_vocab_size; int V = config.vocab_size; int C = config.channels; int maxT = config.max_seq_len; int L = config.num_layers; - param_sizes[0] = V * C; // wte + param_sizes[0] = Vp * C; // wte param_sizes[1] = maxT * C; // wpe param_sizes[2] = L * C; // ln1w param_sizes[3] = L * C; // ln1b @@ -1196,7 +1198,7 @@ typedef struct { } ActivationTensors; void fill_in_activation_sizes(size_t* act_sizes, int B, int T, GPT2Config config) { - size_t V = config.vocab_size; + size_t Vp = config.padded_vocab_size; size_t L = config.num_layers; size_t NH = config.num_heads; size_t C = config.channels; @@ -1220,7 +1222,7 @@ void fill_in_activation_sizes(size_t* act_sizes, int B, int T, GPT2Config config act_sizes[17] = B * T; // lnf_rstd act_sizes[18] = B * T; // losses act_sizes[19] = L * B * T * 3*C; // qkvr - act_sizes[20] = B * T * max(3*C, max(NH*T, V)); // output / scratch + act_sizes[20] = B * T * max(3*C, max(NH*T, Vp)); // output / scratch } // Backward pass is conceptually quite different from forward, because we can discard @@ -1312,8 +1314,13 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { FILE *model_file = fopenCheck(checkpoint_path, "rb"); int model_header[256]; freadCheck(model_header, sizeof(int), 256, model_file); - if (model_header[0] != 20240326) { printf("Bad magic model file"); exit(EXIT_FAILURE); } - if (model_header[1] != 1) { printf("Bad version in model file"); exit(EXIT_FAILURE); } + if (model_header[0] != 20240326) { fprintf(stderr, "Bad magic model file\n"); exit(EXIT_FAILURE); } + if (model_header[1] != 3) { + // was bumped from 1 -> 3 to incorporate the padded vocab size + fprintf(stderr, "Bad version in model file\n"); + fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n"); + exit(EXIT_FAILURE); + } // read in hyperparameters model->config.max_seq_len = model_header[2]; @@ -1321,6 +1328,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { model->config.num_layers = model_header[4]; model->config.num_heads = model_header[5]; model->config.channels = model_header[6]; + model->config.padded_vocab_size = model_header[7]; // allocate space for all the parameters and read them in fill_in_parameter_sizes(model->param_sizes, model->config); @@ -1367,6 +1375,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) { // convenience parameters int V = model->config.vocab_size; + int Vp = model->config.padded_vocab_size; int L = model->config.num_layers; int NH = model->config.num_heads; int C = model->config.channels; @@ -1471,13 +1480,13 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) { residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C); - matmul_forward_cublas(acts.output, acts.lnf, params.wte, NULL, B, T, C, V); + matmul_forward_cublas(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp); // also forward the cross-entropy loss function if we have the targets if (targets != NULL) { // fused classifier: does the forward pass and first part of the backward pass // we're passing dlosses = NULL, which will default them to 1.0f/(B*T), i.e. uniform loss - fused_classifier3(acts.output, acts.losses, NULL, model->targets, B, T, V, V); + fused_classifier3(acts.output, acts.losses, NULL, model->targets, B, T, V, Vp); // for convenience also evaluate the mean loss (TODO re-think this compute+sync point) // move the (B,T) losses to CPU cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost)); @@ -1530,7 +1539,7 @@ void gpt2_backward(GPT2 *model) { // convenience shortcuts int B = model->batch_size; int T = model->seq_len; - int V = model->config.vocab_size; + int Vp = model->config.padded_vocab_size; int L = model->config.num_layers; int NH = model->config.num_heads; int C = model->config.channels; @@ -1546,7 +1555,7 @@ void gpt2_backward(GPT2 *model) { // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, B, T, C, V); + matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, B, T, C, Vp); // backward the final layernorm float* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 float* dresidual = grads_acts.residual3; // the main buffer holding the gradient in the backward pass @@ -1961,6 +1970,7 @@ int main(int argc, char *argv[]) { gpt2_build_from_checkpoint(&model, "gpt2_124M.bin"); printf("| max_sequence_length T | %-50d |\n", model.config.max_seq_len); printf("| vocab_size V | %-50d |\n", model.config.vocab_size); + printf("| padded_vocab_size Vp | %-50d |\n", model.config.padded_vocab_size); printf("| num_layers L | %-50d |\n", model.config.num_layers); printf("| num_heads NH | %-50d |\n", model.config.num_heads); printf("| channels C | %-50d |\n", model.config.channels); @@ -2037,8 +2047,8 @@ int main(int argc, char *argv[]) { // we're in principle running B "inference streams" in parallel here // only using position 0 because it's a bit faster (copy less probs from GPU -> CPU) // get the V-dimensional vector probs[0, t-1, :] - float* logits = model.acts.output + (t - 1) * model.config.vocab_size; - // move probs back to CPU and sample + float* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size; + // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding) cudaCheck(cudaMemcpy(cpu_logits, logits, model.config.vocab_size * sizeof(float), cudaMemcpyDeviceToHost)); float coin = random_f32(&rng_state); int next_token = sample_softmax(cpu_logits, model.config.vocab_size, coin);