forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
the full forward pass of GPT-2 in one file of pure CUDA
- Loading branch information
Showing
8 changed files
with
1,300 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
#define TESTING | ||
#include "train_gpt2.cu" | ||
|
||
// poor man's tensor checker: element-wise compare of two float buffers
// within an absolute tolerance of 1e-2, echoing the first few entries.
// returns 1 when every element agrees, 0 otherwise.
int check_tensor(float *a, float *b, int n, char* label) {
    const int print_upto = 5;   // echo at most this many leading elements
    int all_match = 1;          // stays 1 only if every element is close
    printf("%s\n", label);
    for (int idx = 0; idx < n; idx++) {
        int close_enough = fabs(a[idx] - b[idx]) <= 1e-2;
        if (!close_enough) { all_match = 0; }
        if (idx < print_upto) {
            printf(close_enough ? "OK " : "NOT OK ");
            printf("%f %f\n", a[idx], b[idx]);
        }
    }
    // print the final verdict for the whole tensor
    printf(all_match ? "TENSOR OK\n" : "TENSOR NOT OK\n");
    return all_match;
}
|
||
// compares the CUDA GPT-2 forward pass against reference logits/loss that were
// dumped from the PyTorch implementation into gpt2_124M_debug_state.bin.
// exits 1 on any I/O or CUDA error; exit code reflects the test verdict.
int main(int argc, char *argv[]) {

    // build the GPT-2 model from a checkpoint
    GPT2 model;
    gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");

    int C = model.config.channels;
    int V = model.config.vocab_size;     // logits are laid out as (B, T, V)
    int maxT = model.config.max_seq_len;
    int L = model.config.num_layers;
    (void)C; (void)maxT; (void)L;        // kept for future checks; only V is used below

    // load additional information that we will use for debugging and error checking.
    // every fread is checked so a truncated/corrupt state file fails loudly
    // instead of silently comparing against garbage.
    FILE *state_file = fopen("gpt2_124M_debug_state.bin", "rb");
    if (state_file == NULL) { printf("Error opening state file\n"); exit(1); }
    int state_header[256];
    if (fread(state_header, sizeof(int), 256, state_file) != 256) { printf("Error reading state header\n"); exit(1); }
    if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(1); }
    if (state_header[1] != 1) { printf("Bad version in state file"); exit(1); }
    int B = state_header[2]; // batch size, e.g. 4
    int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
    printf("[State]\n");
    printf("batch_size: %d\n", B);
    printf("seq_len: %d\n", T);

    ParameterTensors expected_grads;
    float* expected_grads_memory = malloc_and_point_parameters(&expected_grads, model.param_sizes, 0);

    // inputs and expected outputs, only used for error checking
    int* x = (int*) malloc((size_t)B * T * sizeof(int));
    int* y = (int*) malloc((size_t)B * T * sizeof(int));
    float* expected_logits = (float*) malloc((size_t)B * T * V * sizeof(float));
    float* expected_loss = (float*) malloc(1 * sizeof(float));
    if (x == NULL || y == NULL || expected_logits == NULL || expected_loss == NULL) {
        printf("Host malloc failed\n"); exit(1);
    }

    // read reference information from Python
    if (fread(x, sizeof(int), (size_t)B*T, state_file) != (size_t)B*T) { printf("Error reading x\n"); exit(1); }
    if (fread(y, sizeof(int), (size_t)B*T, state_file) != (size_t)B*T) { printf("Error reading y\n"); exit(1); }
    if (fread(expected_logits, sizeof(float), (size_t)B*T*V, state_file) != (size_t)B*T*V) { printf("Error reading expected_logits\n"); exit(1); }
    if (fread(expected_loss, sizeof(float), 1, state_file) != 1) { printf("Error reading expected_loss\n"); exit(1); }
    if (fread(expected_grads_memory, sizeof(float), model.num_parameters, state_file) != (size_t)model.num_parameters) { printf("Error reading expected_grads\n"); exit(1); }
    fclose(state_file);

    // overall OK signal for the test
    int allok = 1;

    // run 10 forward passes, following the pytorch reference loop
    for (int step = 0; step < 10; step++) {
        struct timespec start, end;
        clock_gettime(CLOCK_MONOTONIC, &start);
        gpt2_forward(&model, x, y, B, T);
        clock_gettime(CLOCK_MONOTONIC, &end);
        double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
        // report the timing (previously computed but never used)
        printf("step %d: forward pass took %f ms\n", step, time_elapsed_s * 1000);

        if (step == 0) {
            // error checking at step 0 for reference activations

            // at this point, logits should equal expected_logits; copy them to
            // the CPU so we can compare, and check the transfer actually worked
            float* logits_cpu = (float*) malloc((size_t)B * T * V * sizeof(float));
            if (logits_cpu == NULL) { printf("Host malloc failed\n"); exit(1); }
            cudaError_t err = cudaMemcpy(logits_cpu, model.acts.logits, (size_t)B * T * V * sizeof(float), cudaMemcpyDeviceToHost);
            if (err != cudaSuccess) { printf("cudaMemcpy failed: %s\n", cudaGetErrorString(err)); exit(1); }
            int logits_ok = 1;
            for (int i = 0; i < B*T*V; i++) {
                if (i < 3) {
                    printf("%f %f\n", expected_logits[i], logits_cpu[i]);
                }
                // absolute tolerance of 1e-2: CPU/GPU float math won't match bit-exactly
                if (fabsf(expected_logits[i] - logits_cpu[i]) >= 1e-2f) {
                    printf("MISMATCH AT INDEX %d: ", i);
                    printf("%f %f\n", expected_logits[i], logits_cpu[i]);
                    logits_ok = 0;
                    break;
                }
            }
            if (!logits_ok) { printf("NOT "); }
            printf("OK (LOGITS)\n");
            allok = allok && logits_ok;
            free(logits_cpu);

            // compare the achieved loss
            if (fabsf(model.mean_loss - *expected_loss) >= 1e-2f) {
                printf("LOSS MISMATCH: %f %f\n", model.mean_loss, *expected_loss);
                allok = 0;
            } else {
                printf("LOSS OK: %f %f\n", model.mean_loss, *expected_loss);
            }
        }
    }

    printf("overall okay: %d\n", allok);

    // free everything
    free(x);
    free(y);
    free(expected_logits);
    free(expected_loss);
    free(expected_grads_memory);
    gpt2_free(&model);
    // propagate the verdict: previously this returned 0 even on mismatch,
    // so CI could not detect a failing test via the exit code
    return allok ? 0 : 1;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.