padded vocab change. touched a lot of code. very stressful and error prone, but i think it is done.

had to bump versions on all .bin files, invalidating the previous files. re-run the python training script to re-export the new version files. let's not do too much of things like this in the future lol. actually, fun fact: i had a chance to do the padded vocab really early in the history of llm.c development, and chose not to do it, thinking i'd just do it later. i should have done it. such is life: you make mistakes, you accumulate scar tissue, and you learn, and you become better, faster, stronger. this is the mindset one must have to lead a happy and fulfilling life. it's not important that you are perfect at any point in time, it's only important that you keep improving, every day.
karpathy committed Apr 28, 2024
1 parent d95b8d8 commit 835060e
Showing 5 changed files with 141 additions and 66 deletions.
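The core of the change: the wte embedding and the logits buffers are now sized with a padded vocab Vp (the true vocab size rounded up so that Vp % 128 == 0, e.g. 50257 -> 50304, per the GPT2Config comment in the diff below), which keeps the big matmul dimensions nicely aligned. Everywhere below, allocations and matmul shapes switch from V to Vp, while the loss and the test comparisons still only touch the first V columns. A minimal sketch of the padding rule the diffs assume (helper name illustrative, not from the repo):

    #include <stdio.h>

    // round the vocab size up to the nearest multiple of `align`
    size_t pad_vocab(size_t vocab_size, size_t align) {
        return ((vocab_size + align - 1) / align) * align;
    }

    int main(void) {
        size_t V = 50257;
        size_t Vp = pad_vocab(V, 128);
        // prints: V=50257 -> Vp=50304 (47 padding columns)
        printf("V=%zu -> Vp=%zu (%zu padding columns)\n", V, Vp, Vp - V);
        return 0;
    }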
57 changes: 35 additions & 22 deletions test_gpt2.cu
@@ -42,7 +42,7 @@ int check_tensor(float *a, float *b, int n, const char* label, float threshold=1

// the same tensors as in the train file, but in float, which are used as reference
typedef struct {
- float* wte; // (V, C)
+ float* wte; // (Vp, C)
float* wpe; // (maxT, C)
float* ln1w; // (L, C)
float* ln1b; // (L, C)
@@ -109,6 +109,7 @@ int main(int argc, char *argv[]) {
GPT2 model;
gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin");
size_t V = model.config.vocab_size;
+ size_t Vp = model.config.padded_vocab_size;
size_t maxT = model.config.max_seq_len;
size_t L = model.config.num_layers;
size_t C = model.config.channels;
@@ -117,8 +118,12 @@ int main(int argc, char *argv[]) {
FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb");
int state_header[256];
freadCheck(state_header, sizeof(int), 256, state_file);
- if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(EXIT_FAILURE); }
- if (state_header[1] != 1) { printf("Bad version in state file"); exit(EXIT_FAILURE); }
+ if (state_header[0] != 20240327) { fprintf(stderr, "Bad magic state file\n"); exit(EXIT_FAILURE); }
+ if (state_header[1] != 2) {
+     fprintf(stderr, "Bad version in state file\n");
+     fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n");
+     exit(EXIT_FAILURE);
+ }
int B = state_header[2]; // batch size, e.g. 4
int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
assert(0 <= T && T <= maxT);
@@ -154,13 +159,12 @@ int main(int argc, char *argv[]) {
gpt2_forward(&model, x, NULL, B, T);
// at this point, target should be equal to expected_logits, let's compare
// copy logits to CPU so we can compare them
- floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * V * sizeof(floatX));
- float* logits_cpu = (float*)mallocCheck(B * T * V * sizeof(float));
- cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * V * sizeof(floatX), cudaMemcpyDeviceToHost);
- for (int i = 0; i < B * T * V; i++) {
+ floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX));
+ float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float));
+ cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost);
+ for (int i = 0; i < B * T * Vp; i++) {
      logits_cpu[i] = (float)logits_cpu_raw[i];
  }
- int logits_ok = 1;

// FP16 and lower require very high tolerances unfortunately. TODO look into more
float logit_accuracy_threshold = 1e-2f;
@@ -169,19 +173,25 @@
logit_accuracy_threshold = 15.0f;
#endif


// compare the output logits from the forward pass
+ // also careful that we don't access and compare the padded columns of logits
+ int logits_ok = 1;
float max_diff = 0.0f;
- for (int i=0; i<B*T*V; i++) {
-     if (i < 10) {
-         printf("%f %f\n", expected_logits[i], logits_cpu[i]);
-     }
-     float diff = fabsf(expected_logits[i] - logits_cpu[i]);
-     max_diff = fmaxf(max_diff, diff);
-     if (diff >= logit_accuracy_threshold) {
-         printf("MISMATCH AT INDEX %d: ", i);
-         printf("%f %f\n", expected_logits[i], logits_cpu[i]);
-         logits_ok = 0;
-         break;
+ for (int bt = 0; bt < B*T; bt++) {
+     for (int v = 0; v < V; v++) {
+         int i = bt * Vp + v; // linearized index
+         if (i < 10) {
+             printf("%f, %f\n", expected_logits[i], logits_cpu[i]);
+         }
+         float diff = fabsf(expected_logits[bt*V + v] - logits_cpu[i]);
+         max_diff = fmaxf(max_diff, diff);
+         if (diff >= logit_accuracy_threshold) {
+             printf("MISMATCH AT INDEX %d,%d: ", bt, v);
+             printf("%f %f\n", expected_logits[bt*V + v], logits_cpu[i]);
+             logits_ok = 0;
+             bt = B*T; // to break out of both loops
+             break;
+         }
+     }
+ }
allok = allok && logits_ok;
@@ -244,10 +254,13 @@ int main(int argc, char *argv[]) {
// I set the tolerances manually by inspecting the gradient differences for
// a few elements of each tensor. bf16 looks ok but not amazing here.
// It's possible we have bugs lurking, or maybe it is bf16. Not 100% sure.
+ // Also, if code changes and some of these get tripped, it could be ok if it's not by too much,
+ // because our use of stochastic rounding is adding some non-determinism "pepper noise".
+ // In that case it's ok to extend the tolerance by a bit, after a manual review.
allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", 6e-1f);
allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", 1e-2f);
- allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", 9e-2); // hmm a bit high
- allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", 3e-2f);
+ allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", 1.1e-1); // hmm a bit high
+ allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", 4e-2f);
allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", 3e-2f);
allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", 3e-2f);
allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", 9e-2f); // hmm a bit high
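To make the new indexing concrete: the reference logits exported from Python are dense with row stride V, while the GPU logits have row stride Vp, so element (bt, v) lives at bt*V + v in one buffer and at bt*Vp + v in the other. A tiny standalone illustration with toy sizes (all values are hypothetical stand-ins):

    #include <stdio.h>

    int main(void) {
        enum { BT = 2, V = 5, Vp = 8 }; // stand-ins for the real B*T, 50257, 50304
        float ref[BT * V];              // dense reference logits, stride V
        float gpu[BT * Vp];             // padded logits, stride Vp
        for (int i = 0; i < BT * V; i++)  { ref[i] = (float)i; }
        for (int i = 0; i < BT * Vp; i++) { gpu[i] = 0.0f; } // padding columns stay 0
        for (int bt = 0; bt < BT; bt++) {
            for (int v = 0; v < V; v++) {
                gpu[bt * Vp + v] = ref[bt * V + v]; // the mapping the test loops use
            }
        }
        // element (bt=1, v=0) sits at dense index 5 but padded index 8
        printf("%.0f == %.0f\n", ref[1 * V + 0], gpu[1 * Vp + 0]);
        return 0;
    }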
45 changes: 29 additions & 16 deletions test_gpt2_fp32.cu
@@ -52,15 +52,20 @@ int main(int argc, char *argv[]) {

// int C = model.config.channels;
int V = model.config.vocab_size;
+ int Vp = model.config.padded_vocab_size;
int maxT = model.config.max_seq_len;
// int L = model.config.num_layers;

// load additional information that we will use for debugging and error checking
FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb");
int state_header[256];
freadCheck(state_header, sizeof(int), 256, state_file);
- if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(1); }
- if (state_header[1] != 1) { printf("Bad version in state file"); exit(1); }
+ if (state_header[0] != 20240327) { printf("Bad magic state file\n"); exit(EXIT_FAILURE); }
+ if (state_header[1] != 2) {
+     fprintf(stderr, "Bad version in state file\n");
+     fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n");
+     exit(EXIT_FAILURE);
+ }
int B = state_header[2]; // batch size, e.g. 4
int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
assert(0 <= T && T <= maxT);
@@ -94,20 +99,31 @@ int main(int argc, char *argv[]) {
gpt2_forward(&model, x, NULL, B, T);
// at this point, target should be equal to expected_logits, let's compare
// copy logits to CPU so we can compare them
- float* logits_cpu = (float*)mallocCheck(B * T * V * sizeof(float));
- cudaMemcpy(logits_cpu, model.acts.output, B * T * V * sizeof(float), cudaMemcpyDeviceToHost);
+ float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float));
+ cudaMemcpy(logits_cpu, model.acts.output, B * T * Vp * sizeof(float), cudaMemcpyDeviceToHost);

// compare the output logits from the forward pass
+ // also careful that we don't access and compare the padded columns of logits
int logits_ok = 1;
- for (int i=0; i<B*T*V; i++) {
-     if (i < 3) {
-         printf("%f %f\n", expected_logits[i], logits_cpu[i]);
-     }
-     if (fabsf(expected_logits[i] - logits_cpu[i]) >= 1e-2) {
-         printf("MISMATCH AT INDEX %d: ", i);
-         printf("%f %f\n", expected_logits[i], logits_cpu[i]);
-         logits_ok = 0;
-         break;
+ float max_diff = 0.0f;
+ for (int bt = 0; bt < B*T; bt++) {
+     for (int v = 0; v < V; v++) {
+         int i = bt * Vp + v; // linearized index
+         if (i < 10) {
+             printf("%f, %f\n", expected_logits[i], logits_cpu[i]);
+         }
+         float diff = fabsf(expected_logits[bt*V + v] - logits_cpu[i]);
+         max_diff = fmaxf(max_diff, diff);
+         if (diff >= 1e-2f) {
+             printf("MISMATCH AT INDEX %d,%d: ", bt, v);
+             printf("%f %f\n", expected_logits[bt*V + v], logits_cpu[i]);
+             logits_ok = 0;
+             bt = B*T; // to break out of both loops
+             break;
+         }
+     }
+ }
+ allok = allok && logits_ok;
if(!logits_ok) { printf("NOT "); }
printf("OK (LOGITS)\n");

@@ -124,9 +140,6 @@ int main(int argc, char *argv[]) {

if (step == 0) {
// error checking at step 0 for reference activations
- allok = allok && logits_ok;
free(logits_cpu);

// compare the achieved loss
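Both tests now parse the same debug-state header: 256 int32 slots, magic 20240327 at [0], the newly bumped version 2 at [1], and B and T at [2] and [3]. A sketch of that check pulled out into a helper (the function itself is illustrative; the real tests inline this and use the fopenCheck/freadCheck wrappers):

    #include <stdio.h>
    #include <stdlib.h>

    // sketch: read and validate the gpt2_124M_debug_state.bin header
    static void read_state_header(FILE* f, int* B, int* T) {
        int header[256];
        if (fread(header, sizeof(int), 256, f) != 256) { exit(EXIT_FAILURE); }
        if (header[0] != 20240327) { fprintf(stderr, "Bad magic state file\n"); exit(EXIT_FAILURE); }
        if (header[1] != 2) { // version 2 = the padded-vocab export
            fprintf(stderr, "Bad version in state file\n");
            fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n");
            exit(EXIT_FAILURE);
        }
        *B = header[2]; // batch size, e.g. 4
        *T = header[3]; // time / sequence length (e.g. 64, up to maxT)
    }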
34 changes: 20 additions & 14 deletions train_gpt2.cu
@@ -1108,10 +1108,12 @@ __global__ void fused_classifier_kernel3(Type* logits, Type* losses, Type* probs
}

// very sensible default for dlosses is 1/(B*T), which is the uniform loss
- float dloss = dlosses != NULL ? (float)dlosses[idx] : 1.0f / (B*T);
+ float dloss = (dlosses != NULL) ? (float)dlosses[idx] : 1.0f / (B*T);
// calculate the gradients directly, saves bandwidth from probs during training
// but also supports writing probs for inference-only and debugging
const Type* logits_vec = logits + idx * P;
+ // note that we use the padded dimension P to access data, but we only ever
+ // modify the elements up to V, ignoring the padded dimensions and leaving them at 0
for (int i = threadIdx.x; i < V; i += blockDim.x) {
// this is the 2nd read of logits after the one in prepare_softmax2
// this data will never be needed again, so we reduce cache persistence
@@ -1475,6 +1477,7 @@ void fused_classifier3(Type* logits, Type* losses,
typedef struct {
int max_seq_len; // max sequence length, e.g. 1024
int vocab_size; // vocab size, e.g. 50257
+ int padded_vocab_size; // padded to e.g. %128==0, 50304
int num_layers; // number of layers, e.g. 12
int num_heads; // number of heads in attention, e.g. 12
int channels; // number of channels, e.g. 768
@@ -1504,11 +1507,11 @@ typedef struct {
static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!");

void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) {
- size_t V = config.vocab_size;
+ size_t Vp = config.padded_vocab_size;
size_t C = config.channels;
size_t maxT = config.max_seq_len;
size_t L = config.num_layers;
- param_sizes[0] = V * C; // wte
+ param_sizes[0] = Vp * C; // wte
param_sizes[1] = maxT * C; // wpe
param_sizes[2] = L * C; // ln1w
param_sizes[3] = L * C; // ln1b
@@ -1595,7 +1598,7 @@ typedef struct {
} ActivationTensors;

void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config config) {
- size_t V = config.vocab_size;
+ size_t Vp = config.padded_vocab_size;
size_t L = config.num_layers;
size_t NH = config.num_heads;
size_t C = config.channels;
@@ -1619,7 +1622,7 @@ void fill_in_activation_sizes(size_t* act_sizes, size_t B, size_t T, GPT2Config
act_sizes[17] = B * T; // lnf_rstd
act_sizes[18] = B * T; // losses
act_sizes[19] = L * B * T * 3*C; // qkvr
- act_sizes[20] = B * T * max(3*C, max(NH*T, V)); // output / scratch
+ act_sizes[20] = B * T * max(3*C, max(NH*T, Vp)); // output / scratch
}

// Backward pass is conceptually quite different from forward, because we can discard
@@ -1723,9 +1726,9 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
freadCheck(model_header, sizeof(int), 256, model_file);
if (model_header[0] != 20240326) { printf("Bad magic model file\n"); exit(EXIT_FAILURE); }
int version = model_header[1];
- if (!(version == 1 || version == 2)) {
-     // 1 = fp32, ordered layernorm at the end
-     // 2 = bf16, ordered layernorm at the end
+ if (!(version == 3 || version == 4)) {
+     // 3 = fp32, padded vocab
+     // 4 = bf16, padded vocab
fprintf(stderr, "Bad version in model file\n");
fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n");
exit(EXIT_FAILURE);
@@ -1737,6 +1740,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
model->config.num_layers = model_header[4];
model->config.num_heads = model_header[5];
model->config.channels = model_header[6];
+ model->config.padded_vocab_size = model_header[7];

// allocate space for all the parameters and read them in
fill_in_parameter_sizes(model->param_elements, model->param_sizeof, model->config);
@@ -1786,6 +1790,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {

// convenience parameters
size_t V = model->config.vocab_size;
+ size_t Vp = model->config.padded_vocab_size;
size_t L = model->config.num_layers;
size_t NH = model->config.num_heads;
size_t C = model->config.channels;
@@ -1890,13 +1895,13 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {

residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);
- matmul_forward_cublas(acts.output, acts.lnf, params.wte, NULL, B, T, C, V);
+ matmul_forward_cublas(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp);

// also forward the cross-entropy loss function if we have the targets
if (targets != NULL) {
// fused classifier: does the forward pass and first part of the backward pass
// we're passing dlosses = NULL, which will default them to 1.0f/(B*T), i.e. uniform loss
- fused_classifier3(acts.output, acts.losses, (floatX*)NULL, model->targets, B, T, V, V);
+ fused_classifier3(acts.output, acts.losses, (floatX*)NULL, model->targets, B, T, V, Vp);
// for convenience also evaluate the mean loss (TODO re-think this compute+sync point)
// move the (B,T) losses to CPU
cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(floatX), cudaMemcpyDeviceToHost));
@@ -1946,7 +1951,7 @@ void gpt2_backward(GPT2 *model) {
// convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
size_t B = model->batch_size;
size_t T = model->seq_len;
- size_t V = model->config.vocab_size;
+ size_t Vp = model->config.padded_vocab_size;
size_t L = model->config.num_layers;
size_t NH = model->config.num_heads;
size_t C = model->config.channels;
@@ -1962,7 +1967,7 @@
// technically that is a small, inline backward() pass of calculating
// total, final loss as the mean over all losses over all (B,T) positions in the batch
// next: backward the classifier matmul
- matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, B, T, C, V);
+ matmul_backward(grads_acts.bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, B, T, C, Vp);
// backward the final layernorm
floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
floatX* dresidual = (floatX*)grads_acts.residual3; // the main buffer holding the gradient in the backward pass
@@ -2437,6 +2442,7 @@ int main(int argc, char *argv[]) {
printf0("| load_filename | %-50s |\n", load_filename);
printf0("| max_sequence_length T | %-50d |\n", model.config.max_seq_len);
printf0("| vocab_size V | %-50d |\n", model.config.vocab_size);
printf0("| padded_vocab_size Vp | %-50d |\n", model.config.padded_vocab_size);
printf0("| num_layers L | %-50d |\n", model.config.num_layers);
printf0("| num_heads NH | %-50d |\n", model.config.num_heads);
printf0("| channels C | %-50d |\n", model.config.channels);
@@ -2520,8 +2526,8 @@ int main(int argc, char *argv[]) {
// we're in principle running B "inference streams" in parallel here
// only using position 0 because it's a bit faster (copy less probs from GPU -> CPU)
// get the V-dimensional vector probs[0, t-1, :]
- floatX* logits = model.acts.output + (t - 1) * model.config.vocab_size;
- // move probs back to CPU and sample
+ floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size;
+ // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding)
cudaCheck(cudaMemcpy(cpu_logits_raw, logits, model.config.vocab_size * sizeof(floatX), cudaMemcpyDeviceToHost));
// convert to FP32 into cpu_logits (this does nothing useful if floatX == float)
for (int i = 0; i < model.config.vocab_size; i++) {
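For reference, the model checkpoint header after this change: 256 int32 slots with magic 20240326 at [0], version 3 (fp32) or 4 (bf16) at [1], and the padded vocab size riding in slot [7]. A sketch of the header-to-config mapping shown in this diff (the helper function is illustrative; slots [2] and [3] are inferred from the struct order):

    typedef struct {
        int max_seq_len;       // max sequence length, e.g. 1024
        int vocab_size;        // vocab size, e.g. 50257
        int padded_vocab_size; // padded to e.g. %128==0, 50304
        int num_layers;        // number of layers, e.g. 12
        int num_heads;         // number of heads in attention, e.g. 12
        int channels;          // number of channels, e.g. 768
    } GPT2Config;

    // sketch: fill the config from a version 3/4 model file header
    // (the real loader also validates the magic and the version first)
    void config_from_header(const int* model_header, GPT2Config* config) {
        config->max_seq_len       = model_header[2];
        config->vocab_size        = model_header[3];
        config->num_layers        = model_header[4];
        config->num_heads         = model_header[5];
        config->channels          = model_header[6];
        config->padded_vocab_size = model_header[7]; // new in this commit
    }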
