make padded vocab fixes in the .c code as well, i missed it in the previous PR, should satisfy the CI now
karpathy committed Apr 28, 2024
1 parent 835060e commit b7972ff
Showing 2 changed files with 70 additions and 37 deletions.
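The padding in question rounds the vocab size up to the next multiple of 128 (per the config comment in this diff, V = 50257 becomes Vp = 50304, chosen for efficiency), so the C code must now stride each logits/probs row by Vp while treating only the first V entries as real. A minimal sketch of the rounding; the helper name pad_vocab is illustrative, not from the repo:

#include <stdio.h>

// round V up to the nearest multiple of `multiple`
// (illustrative helper; the checkpoint written by train_gpt2.py carries both values)
int pad_vocab(int V, int multiple) {
    return ((V + multiple - 1) / multiple) * multiple;
}

int main(void) {
    int V = 50257;              // GPT-2 vocab size
    int Vp = pad_vocab(V, 128); // padded vocab size
    printf("V = %d, Vp = %d\n", V, Vp); // prints V = 50257, Vp = 50304
    return 0;
}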
38 changes: 25 additions & 13 deletions test_gpt2.c
@@ -44,6 +44,7 @@ int main(int argc, char *argv[]) {

int C = model.config.channels;
int V = model.config.vocab_size;
+ int Vp = model.config.padded_vocab_size;
int maxT = model.config.max_seq_len;
int L = model.config.num_layers;

@@ -52,8 +53,12 @@ int main(int argc, char *argv[]) {
if (state_file == NULL) { printf("Error opening state file\n"); return 1; }
int state_header[256];
fread(state_header, sizeof(int), 256, state_file);
- if (state_header[0] != 20240327) { printf("Bad magic state file"); return 1; }
- if (state_header[1] != 1) { printf("Bad version in state file"); return 1; }
+ if (state_header[0] != 20240327) { printf("Bad magic state file\n"); return 1; }
+ if (state_header[1] != 2) {
+     printf("Bad version in state file\n");
+     printf("---> HINT: try to re-run `python train_gpt2.py`\n");
+     return 1;
+ }
int B = state_header[2]; // batch size, e.g. 4
int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
printf("[State]\n");
@@ -107,22 +112,29 @@ int main(int argc, char *argv[]) {

if (step == 0) {
// error checking at step 0 for reference activations/gradients

// at this point, target should be equal to expected_logits, let's compare
int logits_ok = 1;
- for (int i=0; i<B*T*V; i++) {
-     if(i < 3) {
-         printf("%f %f\n", expected_logits[i], model.acts.logits[i]);
-     }
-     if (fabsf(expected_logits[i] - model.acts.logits[i]) >= 1e-2) {
-         printf("MISMATCH AT INDEX %d: ", i);
-         printf("%f %f\n", expected_logits[i], model.acts.logits[i]);
-         logits_ok = 0;
-         break;
-     }
- }
+ float* calculated_logits = model.acts.logits;
+ float max_diff = 0.0f;
+ for (int bt = 0; bt < B*T; bt++) {
+     for (int v = 0; v < V; v++) { // note we only loop to V (ignoring padding)
+         int i = bt * Vp + v; // linearized index, using Vp
+         if (i < 10) {
+             printf("%f, %f\n", expected_logits[i], calculated_logits[i]);
+         }
+         float diff = fabsf(expected_logits[bt*V + v] - calculated_logits[i]);
+         max_diff = fmaxf(max_diff, diff);
+         if (diff >= 1e-2f) {
+             printf("MISMATCH AT INDEX %d,%d: ", bt, v);
+             printf("%f %f\n", expected_logits[bt*V + v], calculated_logits[i]);
+             logits_ok = 0;
+             bt = B*T; // to break out of both loops
+             break;
+         }
+     }
+ }
if(!logits_ok) { printf("NOT "); }
printf("OK (LOGITS)\n");
printf("OK (LOGITS), max_diff = %e\n", max_diff);
allok = allok && logits_ok;

// compare the achieved loss
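Note the dual indexing in the rewritten logits check above: the reference logits in the state file are unpadded with shape (B,T,V), while the model now produces (B,T,Vp), so the expected side is indexed with bt*V + v and the calculated side with bt*Vp + v. A standalone toy demonstration of why that works (made-up sizes and data):

#include <stdio.h>
#include <math.h>

int main(void) {
    int BT = 2, V = 3, Vp = 4; // 2 positions, real vocab 3, padded to 4
    float expected[6]   = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};             // (BT, V)
    float calculated[8] = {0.1f, 0.2f, 0.3f, 9.9f, 0.4f, 0.5f, 0.6f, 9.9f}; // (BT, Vp)
    float max_diff = 0.0f;
    for (int bt = 0; bt < BT; bt++) {
        for (int v = 0; v < V; v++) { // only the first V columns are ever compared
            float diff = fabsf(expected[bt * V + v] - calculated[bt * Vp + v]);
            max_diff = fmaxf(max_diff, diff);
        }
    }
    printf("max_diff = %e\n", max_diff); // 0: the junk (9.9) in the padding is never read
    return 0;
}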
69 changes: 45 additions & 24 deletions train_gpt2.c
@@ -395,15 +395,17 @@ void residual_backward(float* dinp1, float* dinp2, float* dout, int N) {
}
}

- void softmax_forward(float* probs, float* logits, int B, int T, int V) {
-     // output: probs are (B,T,V) of the probabilities (sums to 1.0 in each b,t position)
-     // input: logits is (B,T,V) of the unnormalized log probabilities
+ void softmax_forward(float* probs, float* logits, int B, int T, int V, int Vp) {
+     // output: probs are (B,T,Vp) of the probabilities (sums to 1.0 in each b,t position)
+     // input: logits is (B,T,Vp) of the unnormalized log probabilities
+     // Vp is the padded vocab size (for efficiency), V is the "real" vocab size
+     // example: Vp is 50304 and V is 50257
#pragma omp parallel for collapse(2)
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
// probs <- softmax(logits)
- float* logits_bt = logits + b * T * V + t * V;
- float* probs_bt = probs + b * T * V + t * V;
+ float* logits_bt = logits + b * T * Vp + t * Vp;
+ float* probs_bt = probs + b * T * Vp + t * Vp;

// maxval is only calculated and subtracted for numerical stability
float maxval = -10000.0f; // TODO something better
@@ -417,23 +419,29 @@ void softmax_forward(float* probs, float* logits, int B, int T, int V) {
probs_bt[i] = expf(logits_bt[i] - maxval);
sum += probs_bt[i];
}
+ // note we only loop to V, leaving the padded dimensions
for (int i = 0; i < V; i++) {
probs_bt[i] /= sum;
}
+ // for extra super safety we may wish to include this too,
+ // forcing the probabilities here to be zero, but it shouldn't matter
+ for (int i = V; i < Vp; i++) {
+     probs_bt[i] = 0.0f;
+ }
}
}
}

void crossentropy_forward(float* losses,
float* probs, int* targets,
- int B, int T, int V) {
+ int B, int T, int Vp) {
// output: losses is (B,T) of the individual losses at each position
- // input: probs are (B,T,V) of the probabilities
+ // input: probs are (B,T,Vp) of the probabilities
// input: targets is (B,T) of integers giving the correct index in logits
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
// loss = -log(probs[target])
- float* probs_bt = probs + b * T * V + t * V;
+ float* probs_bt = probs + b * T * Vp + t * Vp;
int ix = targets[b * T + t];
losses[b * T + t] = -logf(probs_bt[ix]);
}
Expand All @@ -442,14 +450,16 @@ void crossentropy_forward(float* losses,

void crossentropy_softmax_backward(float* dlogits,
float* dlosses, float* probs, int* targets,
- int B, int T, int V) {
+ int B, int T, int V, int Vp) {
// backwards through both softmax and crossentropy
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
- float* dlogits_bt = dlogits + b * T * V + t * V;
- float* probs_bt = probs + b * T * V + t * V;
+ float* dlogits_bt = dlogits + b * T * Vp + t * Vp;
+ float* probs_bt = probs + b * T * Vp + t * Vp;
float dloss = dlosses[b * T + t];
int ix = targets[b * T + t];
+ // note we only loop to V, leaving the padded dimensions
+ // of dlogits untouched, so gradient there stays at zero
for (int i = 0; i < V; i++) {
float p = probs_bt[i];
float indicator = i == ix ? 1.0f : 0.0f;
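The hunk is truncated here, but the loop it opens goes on to accumulate (p - indicator) * dloss into dlogits_bt[i], the standard fused softmax-plus-cross-entropy gradient: d(-log softmax(logits)[ix]) / d(logit_i) = p_i - 1{i == ix}. A self-contained finite-difference check of that identity, with toy values rather than repo code:

#include <stdio.h>
#include <math.h>

// loss = -log(softmax(logits)[ix]) for one position with V logits
float ce_loss(const float* logits, int V, int ix) {
    float maxval = logits[0];
    for (int i = 1; i < V; i++) { if (logits[i] > maxval) maxval = logits[i]; }
    float sum = 0.0f;
    for (int i = 0; i < V; i++) { sum += expf(logits[i] - maxval); }
    return -(logits[ix] - maxval - logf(sum));
}

int main(void) {
    float logits[4] = {0.2f, -1.0f, 0.5f, 0.1f};
    int V = 4, ix = 2;
    float probs[4], sum = 0.0f;
    for (int i = 0; i < V; i++) { probs[i] = expf(logits[i]); sum += probs[i]; }
    for (int i = 0; i < V; i++) { probs[i] /= sum; }
    for (int i = 0; i < V; i++) {
        float analytic = probs[i] - (i == ix ? 1.0f : 0.0f); // p - indicator
        float eps = 1e-3f, save = logits[i];
        logits[i] = save + eps; float lp = ce_loss(logits, V, ix);
        logits[i] = save - eps; float lm = ce_loss(logits, V, ix);
        logits[i] = save;
        printf("i=%d analytic %f numeric %f\n", i, analytic, (lp - lm) / (2 * eps));
    }
    return 0;
}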
@@ -555,6 +565,7 @@ float* malloc_and_point_activations(ActivationTensors* acts, size_t* act_sizes)
typedef struct {
int max_seq_len; // max sequence length, e.g. 1024
int vocab_size; // vocab size, e.g. 50257
+ int padded_vocab_size; // padded to e.g. %128==0, 50304
int num_layers; // number of layers, e.g. 12
int num_heads; // number of heads in attention, e.g. 12
int channels; // number of channels, e.g. 768
@@ -596,25 +607,31 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
if (model_file == NULL) { printf("Error opening model file\n"); exit(1); }
int model_header[256];
fread(model_header, sizeof(int), 256, model_file);
- if (model_header[0] != 20240326) { printf("Bad magic model file"); exit(1); }
- if (model_header[1] != 1) { printf("Bad version in model file"); exit(1); }
+ if (model_header[0] != 20240326) { printf("Bad magic model file\n"); exit(1); }
+ if (model_header[1] != 3) {
+     printf("Bad version in model file\n");
+     printf("---> HINT: try to re-run `python train_gpt2.py`\n");
+     exit(1);
+ }

// read in hyperparameters
- size_t maxT, V, L, NH, C; // size_t to prevent int overflow
+ size_t maxT, V, Vp, L, NH, C; // size_t to prevent int overflow
model->config.max_seq_len = maxT = model_header[2];
model->config.vocab_size = V = model_header[3];
model->config.num_layers = L = model_header[4];
model->config.num_heads = NH = model_header[5];
model->config.channels = C = model_header[6];
+ model->config.padded_vocab_size = Vp = model_header[7];
printf("[GPT-2]\n");
printf("max_seq_len: %zu\n", maxT);
printf("vocab_size: %zu\n", V);
printf("padded_vocab_size: %zu\n", Vp);
printf("num_layers: %zu\n", L);
printf("num_heads: %zu\n", NH);
printf("channels: %zu\n", C);

// allocate space for all the parameters and read them in
- model->param_sizes[0] = V * C; // wte
+ model->param_sizes[0] = Vp * C; // wte
model->param_sizes[1] = maxT * C; // wpe
model->param_sizes[2] = L * C; // ln1w
model->param_sizes[3] = L * C; // ln1b
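Putting the header reads together: the model checkpoint starts with 256 int32 values, and this file now consumes slots 0 through 7. A standalone sketch that just dumps those fields; it assumes the gpt2_124M.bin checkpoint produced by train_gpt2.py is present in the working directory:

#include <stdio.h>

int main(void) {
    FILE* f = fopen("gpt2_124M.bin", "rb");
    if (f == NULL) { printf("could not open checkpoint\n"); return 1; }
    int header[256];
    if (fread(header, sizeof(int), 256, f) != 256) { fclose(f); return 1; }
    fclose(f);
    printf("magic:             %d\n", header[0]); // 20240326
    printf("version:           %d\n", header[1]); // 3 after this change
    printf("max_seq_len:       %d\n", header[2]);
    printf("vocab_size:        %d\n", header[3]); // 50257
    printf("num_layers:        %d\n", header[4]);
    printf("num_heads:         %d\n", header[5]);
    printf("channels:          %d\n", header[6]);
    printf("padded_vocab_size: %d\n", header[7]); // 50304, new in version 3
    return 0;
}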
@@ -668,6 +685,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {

// convenience parameters (size_t to help prevent int overflow)
size_t V = model->config.vocab_size;
+ size_t Vp = model->config.padded_vocab_size;
size_t L = model->config.num_layers;
size_t NH = model->config.num_heads;
size_t C = model->config.channels;
@@ -706,8 +724,8 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
model->act_sizes[17] = B * T * C; // lnf
model->act_sizes[18] = B * T; // lnf_mean
model->act_sizes[19] = B * T; // lnf_rstd
- model->act_sizes[20] = B * T * V; // logits
- model->act_sizes[21] = B * T * V; // probs
+ model->act_sizes[20] = B * T * Vp; // logits
+ model->act_sizes[21] = B * T * Vp; // probs
model->act_sizes[22] = B * T; // losses
size_t num_activations = 0;
for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
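The logits and probs buffers are by far the largest activations, so this is where the padding costs memory. With the example sizes from the test (B = 4, T = 64), each of the two tensors grows only slightly; a quick back-of-the-envelope check:

#include <stdio.h>

int main(void) {
    size_t B = 4, T = 64;         // example sizes from test_gpt2.c
    size_t V = 50257, Vp = 50304; // real vs padded vocab
    size_t old_bytes = B * T * V * sizeof(float);
    size_t new_bytes = B * T * Vp * sizeof(float);
    // each of logits and probs: ~51.5 MB, and the padding adds only ~47 KiB
    printf("%zu -> %zu bytes (+%zu)\n", old_bytes, new_bytes, new_bytes - old_bytes);
    return 0;
}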
@@ -789,12 +807,12 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
}
residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);
- matmul_forward(acts.logits, acts.lnf, params.wte, NULL, B, T, C, V);
- softmax_forward(acts.probs, acts.logits, B, T, V);
+ matmul_forward(acts.logits, acts.lnf, params.wte, NULL, B, T, C, Vp);
+ softmax_forward(acts.probs, acts.logits, B, T, V, Vp);

// also forward the cross-entropy loss function if we have the targets
if (targets != NULL) {
- crossentropy_forward(model->acts.losses, model->acts.probs, targets, B, T, V);
+ crossentropy_forward(model->acts.losses, model->acts.probs, targets, B, T, Vp);
// for convenience also evaluate the mean loss
float mean_loss = 0.0f;
for (int i=0; i<B*T; i++) { mean_loss += model->acts.losses[i]; }
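The loss reduction here is a plain mean over all B*T positions, which is also why gpt2_backward (below) seeds every entry of grads_acts.losses with the constant 1/(B*T): if L = (1/N) * sum_i l_i with N = B*T, then dL/dl_i = 1/N for every i. A tiny numeric illustration:

#include <stdio.h>

int main(void) {
    float losses[4] = {2.0f, 3.0f, 4.0f, 5.0f}; // N = B*T individual losses
    int N = 4;
    float mean_loss = 0.0f;
    for (int i = 0; i < N; i++) { mean_loss += losses[i]; }
    mean_loss /= N;
    float dloss_mean = 1.0f / N; // gradient of the mean w.r.t. each entry
    printf("mean_loss = %f, dL/dl_i = %f\n", mean_loss, dloss_mean); // 3.500000, 0.250000
    return 0;
}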
@@ -830,6 +848,7 @@ void gpt2_backward(GPT2 *model) {
size_t B = model->batch_size;
size_t T = model->seq_len;
size_t V = model->config.vocab_size;
+ size_t Vp = model->config.padded_vocab_size;
size_t L = model->config.num_layers;
size_t NH = model->config.num_heads;
size_t C = model->config.channels;
@@ -846,8 +865,8 @@
float dloss_mean = 1.0f / (B*T);
for (int i = 0; i < B*T; i++) { grads_acts.losses[i] = dloss_mean; }

- crossentropy_softmax_backward(grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V);
- matmul_backward(grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, V);
+ crossentropy_softmax_backward(grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V, Vp);
+ matmul_backward(grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, Vp);
float* residual = acts.residual3 + (L-1) * B * T * C; // last layer's residual
float* dresidual = grads_acts.residual3 + (L-1) * B * T * C; // write to last layer's residual
layernorm_backward(dresidual, grads.lnfw, grads.lnfb, grads_acts.lnf, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);
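Both backward calls now run over Vp columns, but per the comment added in crossentropy_softmax_backward the padded slice of dlogits stays at zero, and a zero output-gradient column contributes nothing through the matmul backward, whatever the padded weights hold. A toy check of that claim, not repo code:

#include <stdio.h>

int main(void) {
    // dinp[c] = sum over v of dout[v] * w[v][c]; the padded slot has dout == 0
    int Vp = 3, C = 2; // the real vocab would be V = 2 here
    float dout[3] = {0.5f, -1.0f, 0.0f};      // last entry is the padded slot
    float w[3][2] = {{1, 2}, {3, 4}, {7, 9}}; // junk weights in the padded row
    for (int c = 0; c < C; c++) {
        float dinp_c = 0.0f;
        for (int v = 0; v < Vp; v++) { dinp_c += dout[v] * w[v][c]; }
        printf("dinp[%d] = %f\n", c, dinp_c); // same result as looping only to V
    }
    return 0;
}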
@@ -1210,9 +1229,11 @@ int main() {
// furthermore, below we're only using b=0 (i.e. the first row) of all B rows
// we're in principle running B "inference streams" in parallel here
// but only using position 0
- // get the V-dimensional vector probs[0, t-1, :]
- float* probs = model.acts.probs + (t-1) * model.config.vocab_size;
+ // get the Vp-dimensional vector probs[0, t-1, :]
+ float* probs = model.acts.probs + (t-1) * model.config.padded_vocab_size;
float coin = random_f32(&rng_state);
+ // note we're only sampling from the first V elements, ignoring padding
+ // (the probabilities in the padded region should be zero anyway)
int next_token = sample_mult(probs, model.config.vocab_size, coin);
gen_tokens[t] = next_token;
// print the generated token, either using the Tokenizer or a fallback
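Sampling touches only the first V entries of the Vp-strided row, so any junk in the padding is never reachable. The real sample_mult is defined elsewhere in train_gpt2.c; a minimal illustrative version of multinomial sampling in that style:

#include <stdio.h>

// walk the CDF of the first n probabilities and return the sampled index;
// entries past n (the padded region) are simply never read
int sample_from_probs(const float* probabilities, int n, float coin) {
    float cdf = 0.0f; // coin is a random number in [0, 1)
    for (int i = 0; i < n; i++) {
        cdf += probabilities[i];
        if (coin < cdf) { return i; }
    }
    return n - 1; // fallback for rounding errors
}

int main(void) {
    float probs[4] = {0.2f, 0.5f, 0.3f, 0.0f}; // V = 3 real entries, Vp = 4
    printf("%d\n", sample_from_probs(probs, 3, 0.65f)); // prints 1 (0.2 + 0.5 > 0.65)
    return 0;
}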
