From b7972ff928cec444b3d104641c244cf7ceb94352 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Sun, 28 Apr 2024 19:33:21 +0000
Subject: [PATCH] make padded vocab fixes in the .c code as well, i missed it
 in the previous PR, should satisfy the CI now

---
 test_gpt2.c  | 38 +++++++++++++++++++----------
 train_gpt2.c | 69 ++++++++++++++++++++++++++++++++++------------------
 2 files changed, 70 insertions(+), 37 deletions(-)

diff --git a/test_gpt2.c b/test_gpt2.c
index c6cfebf06..e49b73fad 100644
--- a/test_gpt2.c
+++ b/test_gpt2.c
@@ -44,6 +44,7 @@ int main(int argc, char *argv[]) {
     int C = model.config.channels;
     int V = model.config.vocab_size;
+    int Vp = model.config.padded_vocab_size;
     int maxT = model.config.max_seq_len;
     int L = model.config.num_layers;
@@ -52,8 +53,12 @@ int main(int argc, char *argv[]) {
     if (state_file == NULL) { printf("Error opening state file\n"); return 1; }
     int state_header[256];
     fread(state_header, sizeof(int), 256, state_file);
-    if (state_header[0] != 20240327) { printf("Bad magic state file"); return 1; }
-    if (state_header[1] != 1) { printf("Bad version in state file"); return 1; }
+    if (state_header[0] != 20240327) { printf("Bad magic state file\n"); return 1; }
+    if (state_header[1] != 2) {
+        printf("Bad version in state file\n");
+        printf("---> HINT: try to re-run `python train_gpt2.py`\n");
+        return 1;
+    }
     int B = state_header[2]; // batch size, e.g. 4
     int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
     printf("[State]\n");
@@ -107,22 +112,29 @@ int main(int argc, char *argv[]) {
         if (step == 0) {
             // error checking at step 0 for reference activations/gradients
-            // at this point, target should be equal to expected_logits, let's compare
             int logits_ok = 1;
-            for (int i=0; i<B*T*V; i++) {
-                if(i < 10) {
-                    printf("%f %f\n", expected_logits[i], model.acts.logits[i]);
-                }
-                if (fabsf(expected_logits[i] - model.acts.logits[i]) >= 1e-2) {
-                    printf("MISMATCH AT INDEX %d: ", i);
-                    printf("%f %f\n", expected_logits[i],model.acts.logits[i]);
-                    logits_ok = 0;
-                    break;
+            float* calculated_logits = model.acts.logits;
+            float max_diff = 0.0f;
+            for (int bt = 0; bt < B*T; bt++) {
+                for (int v = 0; v < V; v++) { // note we only loop to V (ignoring padding)
+                    int i = bt * Vp + v; // linearized index, using Vp
+                    if (i < 10) {
+                        printf("%f, %f\n", expected_logits[i], calculated_logits[i]);
+                    }
+                    float diff = fabsf(expected_logits[bt*V + v] - calculated_logits[i]);
+                    max_diff = fmaxf(max_diff, diff);
+                    if (diff >= 1e-2f) {
+                        printf("MISMATCH AT INDEX %d,%d: ", bt, v);
+                        printf("%f %f\n", expected_logits[bt*V + v], calculated_logits[i]);
+                        logits_ok = 0;
+                        bt = B*T; // to break out of both loops
+                        break;
+                    }
                 }
             }
             if(!logits_ok) { printf("NOT "); }
-            printf("OK (LOGITS)\n");
+            printf("OK (LOGITS), max_diff = %e\n", max_diff);
             allok = allok && logits_ok;

             // compare the achieved loss
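The test above walks a tightly packed reference tensor (stride V per (b,t) row) against the padded C activations (stride Vp per row). A minimal standalone sketch of that indexing convention follows; the helper name, parameters, and tolerance are hypothetical and not part of this patch:

    #include <math.h>
    #include <stdio.h>

    // Hypothetical helper: walk an unpadded reference (stride V per row) and a
    // padded tensor (stride Vp per row) in lockstep, checking only the first V
    // entries of each row. Returns 1 if every entry matches within tol.
    int compare_padded(const float* expected, const float* calculated,
                       int BT, int V, int Vp, float tol) {
        for (int bt = 0; bt < BT; bt++) {
            for (int v = 0; v < V; v++) {
                float diff = fabsf(expected[bt * V + v] - calculated[bt * Vp + v]);
                if (diff >= tol) {
                    printf("mismatch at (%d,%d)\n", bt, v);
                    return 0;
                }
            }
        }
        return 1;
    }
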
diff --git a/train_gpt2.c b/train_gpt2.c
index ccf29def8..2539e1479 100644
--- a/train_gpt2.c
+++ b/train_gpt2.c
@@ -395,15 +395,17 @@ void residual_backward(float* dinp1, float* dinp2, float* dout, int N) {
     }
 }

-void softmax_forward(float* probs, float* logits, int B, int T, int V) {
-    // output: probs are (B,T,V) of the probabilities (sums to 1.0 in each b,t position)
-    // input: logits is (B,T,V) of the unnormalized log probabilities
+void softmax_forward(float* probs, float* logits, int B, int T, int V, int Vp) {
+    // output: probs are (B,T,Vp) of the probabilities (sums to 1.0 in each b,t position)
+    // input: logits is (B,T,Vp) of the unnormalized log probabilities
+    // Vp is the padded vocab size (for efficiency), V is the "real" vocab size
+    // example: Vp is 50304 and V is 50257
     #pragma omp parallel for collapse(2)
     for (int b = 0; b < B; b++) {
         for (int t = 0; t < T; t++) {
             // probs <- softmax(logits)
-            float* logits_bt = logits + b * T * V + t * V;
-            float* probs_bt = probs + b * T * V + t * V;
+            float* logits_bt = logits + b * T * Vp + t * Vp;
+            float* probs_bt = probs + b * T * Vp + t * Vp;

             // maxval is only calculated and subtracted for numerical stability
             float maxval = -10000.0f; // TODO something better
@@ -417,23 +419,29 @@ void softmax_forward(float* probs, float* logits, int B, int T, int V) {
                 probs_bt[i] = expf(logits_bt[i] - maxval);
                 sum += probs_bt[i];
             }
+            // note we only loop to V, leaving the padded dimensions
             for (int i = 0; i < V; i++) {
                 probs_bt[i] /= sum;
             }
+            // for extra super safety we may wish to include this too,
+            // forcing the probabilities here to be zero, but it shouldn't matter
+            for (int i = V; i < Vp; i++) {
+                probs_bt[i] = 0.0f;
+            }
         }
     }
 }

 void crossentropy_forward(float* losses,
                           float* probs, int* targets,
-                          int B, int T, int V) {
+                          int B, int T, int Vp) {
     // output: losses is (B,T) of the individual losses at each position
-    // input: probs are (B,T,V) of the probabilities
+    // input: probs are (B,T,Vp) of the probabilities
     // input: targets is (B,T) of integers giving the correct index in logits
     for (int b = 0; b < B; b++) {
         for (int t = 0; t < T; t++) {
             // loss = -log(probs[target])
-            float* probs_bt = probs + b * T * V + t * V;
+            float* probs_bt = probs + b * T * Vp + t * Vp;
             int ix = targets[b * T + t];
             losses[b * T + t] = -logf(probs_bt[ix]);
         }
@@ -442,14 +450,16 @@ void crossentropy_forward(float* losses,

 void crossentropy_softmax_backward(float* dlogits,
                            float* dlosses, float* probs, int* targets,
-                           int B, int T, int V) {
+                           int B, int T, int V, int Vp) {
     // backwards through both softmax and crossentropy
     for (int b = 0; b < B; b++) {
         for (int t = 0; t < T; t++) {
-            float* dlogits_bt = dlogits + b * T * V + t * V;
-            float* probs_bt = probs + b * T * V + t * V;
+            float* dlogits_bt = dlogits + b * T * Vp + t * Vp;
+            float* probs_bt = probs + b * T * Vp + t * Vp;
             float dloss = dlosses[b * T + t];
             int ix = targets[b * T + t];
+            // note we only loop to V, leaving the padded dimensions
+            // of dlogits untouched, so gradient there stays at zero
             for (int i = 0; i < V; i++) {
                 float p = probs_bt[i];
                 float indicator = i == ix ? 1.0f : 0.0f;
@@ -555,6 +565,7 @@ float* malloc_and_point_activations(ActivationTensors* acts, size_t* act_sizes)
 typedef struct {
     int max_seq_len; // max sequence length, e.g. 1024
     int vocab_size; // vocab size, e.g. 50257
+    int padded_vocab_size; // padded to e.g. %128==0, 50304
     int num_layers; // number of layers, e.g. 12
     int num_heads; // number of heads in attention, e.g. 12
     int channels; // number of channels, e.g. 768
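The padded-row convention used by softmax_forward above can be seen in isolation in the sketch below. It is illustrative only (a single row of width Vp whose first V entries are real vocabulary, function name assumed): the max and the normalization only look at the first V logits, and the padded tail is explicitly zeroed, matching the hunk above.

    #include <math.h>

    // Illustrative single-row version of the padded softmax: only the first V
    // entries are treated as real logits, the tail [V, Vp) is padding.
    void softmax_row_padded(float* probs, const float* logits, int V, int Vp) {
        float maxval = logits[0];
        for (int i = 1; i < V; i++) {
            if (logits[i] > maxval) { maxval = logits[i]; } // max over real entries only
        }
        float sum = 0.0f;
        for (int i = 0; i < V; i++) {
            probs[i] = expf(logits[i] - maxval);
            sum += probs[i];
        }
        for (int i = 0; i < V; i++) { probs[i] /= sum; }  // normalize the real entries
        for (int i = V; i < Vp; i++) { probs[i] = 0.0f; } // padding gets zero probability
    }
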
@@ -596,25 +607,31 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
     if (model_file == NULL) { printf("Error opening model file\n"); exit(1); }
     int model_header[256];
     fread(model_header, sizeof(int), 256, model_file);
-    if (model_header[0] != 20240326) { printf("Bad magic model file"); exit(1); }
-    if (model_header[1] != 1) { printf("Bad version in model file"); exit(1); }
+    if (model_header[0] != 20240326) { printf("Bad magic model file\n"); exit(1); }
+    if (model_header[1] != 3) {
+        printf("Bad version in model file\n");
+        printf("---> HINT: try to re-run `python train_gpt2.py`\n");
+        exit(1);
+    }

     // read in hyperparameters
-    size_t maxT, V, L, NH, C; // size_t to prevent int overflow
+    size_t maxT, V, Vp, L, NH, C; // size_t to prevent int overflow
     model->config.max_seq_len = maxT = model_header[2];
     model->config.vocab_size = V = model_header[3];
     model->config.num_layers = L = model_header[4];
     model->config.num_heads = NH = model_header[5];
     model->config.channels = C = model_header[6];
+    model->config.padded_vocab_size = Vp = model_header[7];
     printf("[GPT-2]\n");
     printf("max_seq_len: %zu\n", maxT);
     printf("vocab_size: %zu\n", V);
+    printf("padded_vocab_size: %zu\n", Vp);
     printf("num_layers: %zu\n", L);
     printf("num_heads: %zu\n", NH);
     printf("channels: %zu\n", C);

     // allocate space for all the parameters and read them in
-    model->param_sizes[0] = V * C; // wte
+    model->param_sizes[0] = Vp * C; // wte
     model->param_sizes[1] = maxT * C; // wpe
     model->param_sizes[2] = L * C; // ln1w
     model->param_sizes[3] = L * C; // ln1b
@@ -668,6 +685,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {

     // convenience parameters (size_t to help prevent int overflow)
     size_t V = model->config.vocab_size;
+    size_t Vp = model->config.padded_vocab_size;
     size_t L = model->config.num_layers;
     size_t NH = model->config.num_heads;
     size_t C = model->config.channels;
@@ -706,8 +724,8 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
         model->act_sizes[17] = B * T * C; // lnf
         model->act_sizes[18] = B * T; // lnf_mean
         model->act_sizes[19] = B * T; // lnf_rstd
-        model->act_sizes[20] = B * T * V; // logits
-        model->act_sizes[21] = B * T * V; // probs
+        model->act_sizes[20] = B * T * Vp; // logits
+        model->act_sizes[21] = B * T * Vp; // probs
         model->act_sizes[22] = B * T; // losses
         size_t num_activations = 0;
         for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
@@ -789,12 +807,12 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
     }
     residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
     layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);
-    matmul_forward(acts.logits, acts.lnf, params.wte, NULL, B, T, C, V);
-    softmax_forward(acts.probs, acts.logits, B, T, V);
+    matmul_forward(acts.logits, acts.lnf, params.wte, NULL, B, T, C, Vp);
+    softmax_forward(acts.probs, acts.logits, B, T, V, Vp);

     // also forward the cross-entropy loss function if we have the targets
     if (targets != NULL) {
-        crossentropy_forward(model->acts.losses, model->acts.probs, targets, B, T, V);
+        crossentropy_forward(model->acts.losses, model->acts.probs, targets, B, T, Vp);
         // for convenience also evaluate the mean loss
         float mean_loss = 0.0f;
         for (int i=0; i<B*T; i++) { mean_loss += model->acts.losses[i]; }
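The only sizes that actually grow are the token embedding matrix (param_sizes[0]) and the logits/probs activations (act_sizes[20] and [21]). A small worked example of that arithmetic, using the values quoted in the comments above (V=50257, Vp=50304, C=768) and the test's B=4, T=64; these numbers are taken from the comments, not computed anywhere in the patch:

    #include <stdio.h>

    int main(void) {
        // worked example of the sizing change, values from the patch comments
        size_t V = 50257, Vp = 50304, C = 768, B = 4, T = 64;
        printf("wte:    %zu -> %zu floats\n", V * C, Vp * C);          // param_sizes[0]
        printf("logits: %zu -> %zu floats\n", B * T * V, B * T * Vp);  // act_sizes[20]; probs likewise
        return 0;
    }
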
@@ -830,6 +848,7 @@ void gpt2_backward(GPT2 *model) {
     size_t B = model->batch_size;
     size_t T = model->seq_len;
     size_t V = model->config.vocab_size;
+    size_t Vp = model->config.padded_vocab_size;
     size_t L = model->config.num_layers;
     size_t NH = model->config.num_heads;
     size_t C = model->config.channels;
@@ -846,8 +865,8 @@ void gpt2_backward(GPT2 *model) {
     float dloss_mean = 1.0f / (B*T);
     for (int i = 0; i < B*T; i++) { grads_acts.losses[i] = dloss_mean; }

-    crossentropy_softmax_backward(grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V);
-    matmul_backward(grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, V);
+    crossentropy_softmax_backward(grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V, Vp);
+    matmul_backward(grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, Vp);
     float* residual = acts.residual3 + (L-1) * B * T * C; // last layer's residual
     float* dresidual = grads_acts.residual3 + (L-1) * B * T * C; // write to last layer's residual
     layernorm_backward(dresidual, grads.lnfw, grads.lnfb, grads_acts.lnf, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);
@@ -1210,9 +1229,11 @@ int main() {
             // furthermore, below we're only using b=0 (i.e. the first row) of all B rows
             // we're in principle running B "inference streams" in parallel here
             // but only using position 0
-            // get the V-dimensional vector probs[0, t-1, :]
-            float* probs = model.acts.probs + (t-1) * model.config.vocab_size;
+            // get the Vp-dimensional vector probs[0, t-1, :]
+            float* probs = model.acts.probs + (t-1) * model.config.padded_vocab_size;
             float coin = random_f32(&rng_state);
+            // note we're only sampling from the first V elements, ignoring padding
+            // (the probabilities in the padded region should be zero anyway)
             int next_token = sample_mult(probs, model.config.vocab_size, coin);
             gen_tokens[t] = next_token;
             // print the generated token, either using the Tokenizer or a fallback
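Generation therefore samples only from the first V entries of each padded probability row; since softmax_forward zeroes the padded tail, passing V as the length to sample_mult is lossless. A minimal sketch of what such a multinomial draw looks like (illustrative only, with an assumed helper name; it mirrors, but is not copied from, the existing sample_mult in train_gpt2.c):

    // Illustrative multinomial draw over the first V entries of a padded
    // probability row; coin is uniform in [0,1). The padded tail [V, Vp)
    // carries zero probability, so restricting the scan to V loses nothing.
    int sample_first_v(const float* probs, int V, float coin) {
        float cdf = 0.0f;
        for (int i = 0; i < V; i++) {
            cdf += probs[i];
            if (coin < cdf) { return i; }
        }
        return V - 1; // fallback if rounding leaves the cdf slightly below 1
    }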