diff --git a/Makefile b/Makefile index 3f361078c..ce0ae9dd0 100644 --- a/Makefile +++ b/Makefile @@ -34,10 +34,10 @@ else endif # PHONY means these targets will always be executed -.PHONY: all train_gpt2 test_gpt2 +.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu # default target is all -all: train_gpt2 test_gpt2 +all: train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2: train_gpt2.c $(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) -o $@ @@ -45,5 +45,13 @@ train_gpt2: train_gpt2.c test_gpt2: test_gpt2.c $(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) -o $@ +# possibly may want to disable warnings? e.g. append -Xcompiler -Wno-unused-result +train_gpt2cu: train_gpt2.cu + nvcc -O3 --use_fast_math $< -lcublas -o $@ + +test_gpt2cu: test_gpt2.cu + nvcc -O3 --use_fast_math $< -lcublas -o $@ + clean: - rm -f train_gpt2 test_gpt2 + rm -f train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu + diff --git a/README.md b/README.md index 4132524c8..963cb434d 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,8 @@ The generation just gives you the token ids for now, which we have to decode bac ```python import tiktoken enc = tiktoken.get_encoding("gpt2") -print(enc.decode(list(map(int, "50256 16773 18162 21986 11 198 13681 263 23875 198 3152 262 11773 2910 198 1169 6002 6386 2583 286 262 11858 198 20424 428 3135 7596 995 3675 13 198 40 481 407 736 17903 11 329 703 6029 706 4082 198 42826 1028 1128 633 263 11 198 10594 407 198 2704 454 680 1028 262 1027 28860 286 198 3237 323".split())))) +ptok = lambda x: print(enc.decode(list(map(int, x.strip().split())))) +ptok("50256 16773 18162 21986 11 198 13681 263 23875 198 3152 262 11773 2910 198 1169 6002 6386 2583 286 262 11858 198 20424 428 3135 7596 995 3675 13 198 40 481 407 736 17903 11 329 703 6029 706 4082 198 42826 1028 1128 633 263 11 198 10594 407 198 2704 454 680 1028 262 1027 28860 286 198 3237 323") ``` which prints: @@ -99,7 +100,7 @@ I like how Netflix comes up, it's clear that the shadow of the training past is I am also attaching a simple unit test for making sure our C code agrees with the PyTorch code. Compile and run with: -``` +```bash make test_gpt2 ./test_gpt2 ``` @@ -114,6 +115,50 @@ I attached a very small tutorial here, in [doc/layernorm/layernorm.md](doc/layer CUDA port is WIP, I'm keeping the growing collection of kernels in the `dev` folder, e.g. see [dev/cuda/README.md](dev/cuda/README.md). +As of April 10, 2024 the full forward pass is now implemented in pure CUDA in one file. First we can check that all of the logits and the final loss matches the PyTorch reference: + +```bash +make test_gpt2cu +./test_gpt2cu +``` + +This prints `overall okay: 1`. Now that we are calculating all the right values, we can time our code. We can't train yet because the backward pass + update are not implemented yet, but we can run the training loop and see the timings: + +```bash +make train_gpt2cu +./train_gpt2cu +``` + +This will run GPT-2 (124M) in one file of pure CUDA (see [train_gpt2.cu](train_gpt2.cu)), using batch size 4 and sequence length 1024. This will print a bunch of hyperparameters and then the "training": + +``` +val loss 4.517294 +step 0: train loss 4.367857 (took 112.135004 ms) +step 1: train loss 4.406483 (took 112.555327 ms) +step 2: train loss 4.484838 (took 111.380248 ms) +... +``` + +The loss is changing because we are still loading real data batches from our dataset, but there is no training so they won't go down over time. In any case, on my A100 40GB PCIe GPU we are seeing about 111ms/iteration. We can compare this to PyTorch fp32 training by calling our python script like this: + +```bash +python train_gpt2.py --inference_only 1 --write_tensors 0 --sequence_length 1024 --batch_size 4 +``` + +Which shows time per iteration with the same hyperparameters (batch 4, time 1024) at 180ms/iteration. We can then enable `torch.compile` by adding the `--compile 1` flag: + +```bash +python train_gpt2.py --inference_only 1 --write_tensors 0 --sequence_length 1024 --batch_size 4 --compile 1 +``` + +And see that the first iteration now takes 20 seconds (compilation time), but all following iterations take ~86ms. And if we additionally turn on the use of fp32 tensorcores (only GPUs since Volta) with `--tensorcores 1`: + +```bash +python train_gpt2.py --inference_only 1 --write_tensors 0 --sequence_length 1024 --batch_size 4 --compile 1 --tensorcores 1 +``` + +The time drops down to 26ms/iteration. So we have a gap to close :)! At the current 111ms we are about 4.2X slower. + ## license MIT \ No newline at end of file diff --git a/dev/cuda/matmul_forward.cu b/dev/cuda/matmul_forward.cu index 4ebed6e86..1dc1e1560 100644 --- a/dev/cuda/matmul_forward.cu +++ b/dev/cuda/matmul_forward.cu @@ -8,7 +8,7 @@ nvcc -O3 --use_fast_math -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward version 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C OMP_NUM_THREADS=32 ./matmul_forward 1 -version 2 parallelizes over all of B,T,C +version 2 calls cuBLAS, very fast OMP_NUM_THREADS=32 ./matmul_forward 2 */ diff --git a/test_gpt2.c b/test_gpt2.c index 377c909d4..8cb31854d 100644 --- a/test_gpt2.c +++ b/test_gpt2.c @@ -166,6 +166,12 @@ int main(int argc, char *argv[]) { printf("overall okay: %d\n", allok); + // free everything + free(x); + free(y); + free(expected_logits); + free(expected_loss); + free(expected_grads_memory); gpt2_free(&model); return 0; } diff --git a/test_gpt2.cu b/test_gpt2.cu new file mode 100644 index 000000000..5ca7bfbcf --- /dev/null +++ b/test_gpt2.cu @@ -0,0 +1,124 @@ +#define TESTING +#include "train_gpt2.cu" + +// poor man's tensor checker +int check_tensor(float *a, float *b, int n, char* label) { + int print_upto = 5; + int ok = 1; + printf("%s\n", label); + for (int i = 0; i < n; i++) { + if (fabs(a[i] - b[i]) <= 1e-2) { + if (i < print_upto) { printf("OK "); } + } else { + if (i < print_upto) { printf("NOT OK "); } + ok = 0; + } + if (i < print_upto) { printf("%f %f\n", a[i], b[i]); } + } + // print the final result + if (ok) { + printf("TENSOR OK\n"); + } else { + printf("TENSOR NOT OK\n"); + } + return ok; +} + +int main(int argc, char *argv[]) { + + // build the GPT-2 model from a checkpoint + GPT2 model; + gpt2_build_from_checkpoint(&model, "gpt2_124M.bin"); + + int C = model.config.channels; + int V = model.config.vocab_size; + int maxT = model.config.max_seq_len; + int L = model.config.num_layers; + + // load additional information that we will use for debugging and error checking + FILE *state_file = fopen("gpt2_124M_debug_state.bin", "rb"); + if (state_file == NULL) { printf("Error opening state file\n"); exit(1); } + int state_header[256]; + fread(state_header, sizeof(int), 256, state_file); + if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(1); } + if (state_header[1] != 1) { printf("Bad version in state file"); exit(1); } + int B = state_header[2]; // batch size, e.g. 4 + int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT) + printf("[State]\n"); + printf("batch_size: %d\n", B); + printf("seq_len: %d\n", T); + + ParameterTensors expected_grads; + float* expected_grads_memory = malloc_and_point_parameters(&expected_grads, model.param_sizes, 0); + + // inputs and expected outputs, only used for error checking + int* x = (int*) malloc(B * T * sizeof(int)); + int* y = (int*) malloc(B * T * sizeof(int)); + float* expected_logits = (float*) malloc(B * T * V * sizeof(float)); + float* expected_loss = (float*) malloc(1 * sizeof(float)); + + // read reference information from Python + fread(x, sizeof(int), B*T, state_file); + fread(y, sizeof(int), B*T, state_file); + fread(expected_logits, sizeof(float), B*T*V, state_file); + fread(expected_loss, sizeof(float), 1, state_file); + fread(expected_grads_memory, sizeof(float), model.num_parameters, state_file); + fclose(state_file); + + // overall OK signal for the test + int allok = 1; + + // let's do 10 training iterations, following the pytorch code + float losses[10]; + for (int step = 0; step < 10; step++) { + struct timespec start, end; + clock_gettime(CLOCK_MONOTONIC, &start); + gpt2_forward(&model, x, y, B, T); + clock_gettime(CLOCK_MONOTONIC, &end); + double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; + + if (step == 0) { + // error checking at step 0 for reference activations + + // at this point, target should be equal to expected_logits, let's compare + // copy logits to CPU so we can compare them + float* logits_cpu = (float*) malloc(B * T * V * sizeof(float)); + cudaMemcpy(logits_cpu, model.acts.logits, B * T * V * sizeof(float), cudaMemcpyDeviceToHost); + int logits_ok = 1; + for (int i=0; i= 1e-2) { + printf("MISMATCH AT INDEX %d: ", i); + printf("%f %f\n", expected_logits[i],logits_cpu[i]); + logits_ok = 0; + break; + } + } + if(!logits_ok) { printf("NOT "); } + printf("OK (LOGITS)\n"); + allok = allok && logits_ok; + free(logits_cpu); + + // compare the achieved loss + if (fabs(model.mean_loss - *expected_loss) >= 1e-2) { + printf("LOSS MISMATCH: %f %f\n", model.mean_loss, *expected_loss); + allok = 0; + } else { + printf("LOSS OK: %f %f\n", model.mean_loss, *expected_loss); + } + } + } + + printf("overall okay: %d\n", allok); + + // free everything + free(x); + free(y); + free(expected_logits); + free(expected_loss); + free(expected_grads_memory); + gpt2_free(&model); + return 0; +} \ No newline at end of file diff --git a/train_gpt2.c b/train_gpt2.c index ab3092109..830085cde 100644 --- a/train_gpt2.c +++ b/train_gpt2.c @@ -925,7 +925,7 @@ void gpt2_free(GPT2 *model) { } #ifndef TESTING -// if we are TESTING (see test.c), we'll skip the int main below +// if we are TESTING (see test_gpt2.c), we'll skip the int main below // ---------------------------------------------------------------------------- // data loader lite diff --git a/train_gpt2.cu b/train_gpt2.cu new file mode 100644 index 000000000..c7d099283 --- /dev/null +++ b/train_gpt2.cu @@ -0,0 +1,1072 @@ +/* +GPT-2 Transformer Neural Net trained in raw CUDA +*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +// ---------------------------------------------------------------------------- +// CUDA utils + +// error checking +void cudaCheck(cudaError_t error, const char *file, int line) { + if (error != cudaSuccess) { + printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, + cudaGetErrorString(error)); + exit(EXIT_FAILURE); + } +}; +#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__)) +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +// ---------------------------------------------------------------------------- +// all the kernels + +// warp-level reduction for finding the maximum value +__device__ float warpReduceMax(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset)); + } + return val; +} + +// warp-level reduction for summing values +__device__ float warpReduceSum(float val) { + for (int offset = 16; offset > 0; offset /= 2) { + val += __shfl_down_sync(0xFFFFFFFF, val, offset); + } + return val; +} + +__global__ void encoder_forward_kernel2(float* out, + int* inp, float* wte, float* wpe, + int B, int T, int C) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int N = B * T * C; + + if (idx < N) { + int bt = idx / C; + int b = bt / T; + int t = bt % T; + int c = idx % C; + + int ix = inp[b * T + t]; + + float* out_btc = out + b * T * C + t * C + c; + float* wte_ix = wte + ix * C + c; + float* wpe_tc = wpe + t * C + c; + *out_btc = *wte_ix + *wpe_tc; + } +} + + +__global__ void mean_kernel(float* mean, float* inp, int N, int C, int block_size) { + extern __shared__ float shared[]; + int idx = blockIdx.x; // range [0, B*T) + int tid = threadIdx.x; // range [0, block_size) + float* x = inp + idx * C; + // thread coarsening + float sum = 0.0f; + for (int i = tid; i < C; i += block_size) { + sum += x[i]; + } + shared[tid] = sum; + __syncthreads(); + // reductions + for (int stride = block_size / 2; stride >= 1; stride /= 2) { + __syncthreads(); + if (tid < stride) { + shared[tid] += shared[tid + stride]; + } + } + // write the final result (at thread 0) to global memory + if (tid == 0) { + mean[idx] = shared[0] / C; + } +} + +__global__ void rstd_kernel(float* rstd, float* inp, float* mean, int N, int C, int block_size) { + extern __shared__ float shared[]; + int idx = blockIdx.x; // range [0, B*T) + int tid = threadIdx.x; // range [0, block_size) + float* x = inp + idx * C; + float m = mean[idx]; + // thread coarsening + float sum = 0.0f; + for (int i = tid; i < C; i += block_size) { + float diff = x[i] - m; + sum += diff * diff; + } + shared[tid] = sum; + __syncthreads(); + // reductions + for (int stride = block_size / 2; stride >= 1; stride /= 2) { + __syncthreads(); + if (tid < stride) { + shared[tid] += shared[tid + stride]; + } + } + // write the final result (at thread 0) to global memory + if (tid == 0) { + rstd[idx] = 1.0f / sqrtf(shared[0] / C + 1e-5f); + } +} + +__global__ void normalization_kernel(float* out, float* inp, float* mean, float* rstd, + float* weight, float* bias, int B, int T, int C) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + int bt = idx / C; + int c = idx % C; + + float m = mean[bt]; + float s = rstd[bt]; + float xi = inp[idx]; + float n = s * (xi - m); + float o = n * weight[c] + bias[c]; + + out[idx] = o; +} + +__global__ void add_bias(float* out, float* bias, int B, int T, int OC) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = idx; i < B*T*OC; i += stride) { + int col = i % OC; + out[i] += bias[col]; + } +} + + +__global__ void permute_kernel(float* q, float* k, float* v, + const float* inp, + int B, int N, int NH, int d) { + // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d) + // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_] + + if (idx < B * NH * N * d) { + int b = idx / (NH * N * d); + int rest = idx % (NH * N * d); + int nh_ = rest / (N * d); + rest = rest % (N * d); + int n = rest / d; + int d_ = rest % d; + + int inp_idx = \ + (b * N * 3 * NH * d) + + (n * 3 * NH * d) + + (0 * NH * d) + + (nh_ * d) + + d_; + + q[idx] = inp[inp_idx]; + k[idx] = inp[inp_idx + NH * d]; + v[idx] = inp[inp_idx + 2 * (NH * d)]; + } +} + +__global__ void unpermute_kernel(float* inp, float *out, int B, int N, int NH, int d) { + // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d) + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + // out[b][n][nh_][d_] <- inp[b][nh_][n][d_] + if (idx < B * NH * N * d) { + int b = idx / (NH * N * d); + int rest = idx % (NH * N * d); + int nh_ = rest / (N * d); + rest = rest % (N * d); + int n = rest / d; + int d_ = rest % d; + + int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_; + out[other_idx] = inp[idx]; + } +} + +__global__ void scale_kernel(float* inp, float scale, int B, int NH, int T) { + // scales the pre-softmax attention scores by scale + // and sets the autoregressive locations to -INFINITY + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < B * NH * T * T) { + int rest = idx % (NH * T * T); + rest = rest % (T * T); + int t2 = rest / T; + int t = rest % T; + if (t > t2) { + inp[idx] = -INFINITY; + } else { + inp[idx] *= scale; + } + } +} + +__global__ void softmax_forward_kernel4(float* out, float* inp, int N, int C) { + // out is (N, C) just like inp. Each row of inp will get softmaxed. + // same as kernel3, but can handle any block size (multiple of 32) + // each row of C elements is handled by block_size threads + // furthermore, each block_size threads get executed in warps of 32 threads + + // special reduction operations warpReduceMax/warpReduceSum are used for intra-warp reductions + // shared memory is used for inter-warp reduction + extern __shared__ float shared[]; + int idx = blockIdx.x; + int tid = threadIdx.x; + int warpId = threadIdx.x / 32; // warp index within a block + int laneId = threadIdx.x % 32; // thread index within a warp + + // the number of warps per block. recall that blockDim.x is block_size + int warpsPerBlock = blockDim.x / 32; + + // shared[] must be allocated to have 2 * warpsPerBlock elements + // first half for max values, the second half for sum values + float* maxvals = shared; + float* sumvals = &shared[warpsPerBlock]; + + // one row of inp, i.e. inp[idx, :] of shape (C,) + float* x = inp + idx * C; + + // first, thread coarsening by directly accessing global memory in series + float maxval = -INFINITY; + for (int i = tid; i < C; i += blockDim.x) { + maxval = fmaxf(maxval, x[i]); + } + // now within-warp reductions for maxval + maxval = warpReduceMax(maxval); + + // the 0th thread of each warp writes the maxval of that warp to shared memory + if (laneId == 0) maxvals[warpId] = maxval; + __syncthreads(); + + // now the 0th thread reduces the maxvals in shared memory, i.e. across warps + if (tid == 0) { + float val = maxvals[tid]; + for (int i = 1; i < warpsPerBlock; i++) { + val = fmaxf(val, maxvals[i]); + } + // store the final max in the first position + maxvals[0] = val; + } + __syncthreads(); + // broadcast the max to all threads + float offset = maxvals[0]; + + // compute expf and write the result to global memory + for (int i = tid; i < C; i += blockDim.x) { + out[idx * C + i] = expf(x[i] - offset); + } + + // okay now we calculated exp(x - max(x)) + // step 2: sum all the values and divide by the sum + + // thread coarsening for sum + x = out + idx * C; + float sumval = 0.0f; + for (int i = tid; i < C; i += blockDim.x) { + sumval += x[i]; + } + // within-warp reduction for sumval + sumval = warpReduceSum(sumval); + + // write sumval to shared memory + if (laneId == 0) sumvals[warpId] = sumval; + __syncthreads(); + + // inter-thread reduction of sum + if (tid == 0) { + float val = sumvals[tid]; + for (int i = 1; i < warpsPerBlock; ++i) { + val += sumvals[i]; + } + sumvals[0] = val; + } + __syncthreads(); + // broadcast the sum to all threads + float sum = sumvals[0]; + + // divide the whole row by the sum + for (int i = tid; i < C; i += blockDim.x) { + out[idx * C + i] = x[i] / sum; + } +} + +__global__ void residual_forward_kernel(float* out, float* inp1, float* inp2, int N) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < N) { + out[idx] = inp1[idx] + inp2[idx]; + } +} + +__global__ void gelu_kernel(float* out, const float* inp, int N) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + float s = sqrtf(2.0f / M_PI); + if (i < N) { + float xi = inp[i]; + float cube = 0.044715f * xi * xi * xi; + out[i] = 0.5f * xi * (1.0f + tanhf(s * (xi + cube))); + } +} + +__global__ void crossentropy_forward_kernel1(float* losses, + float* probs, int* targets, + int B, int T, int V) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < B * T) { + int b = i / T; + int t = i % T; + float* probs_bt = probs + b * T * V + t * V; + int ix = targets[b * T + t]; + losses[b * T + t] = -logf(probs_bt[ix]); + } +} + +// ---------------------------------------------------------------------------- +// kernel launchers + +void encoder_forward(float* out, + int* inp, float* wte, float* wpe, + int B, int T, int C) { + const int N = B * T * C; + const int block_size = 256; + const int grid_size = CEIL_DIV(N, block_size); + encoder_forward_kernel2<<>>(out, inp, wte, wpe, B, T, C); + cudaCheck(cudaGetLastError()); +} + +void layernorm_forward(float* out, float* mean, float* rstd, + float* inp, float* weight, float* bias, + int B, int T, int C) { + int N = B * T; + const int block_size = 128; + // in mean and rstd, threads cooperate within blocks via reductions + mean_kernel<<>>(mean, inp, N, C, block_size); + cudaCheck(cudaGetLastError()); + rstd_kernel<<>>(rstd, inp, mean, N, C, block_size); + cudaCheck(cudaGetLastError()); + // in the normalization, everything just gets flattened out + const int block_size2 = 256; + const int grid_size = CEIL_DIV(B * T * C, block_size2); + normalization_kernel<<>>(out, inp, mean, rstd, weight, bias, B, T, C); + cudaCheck(cudaGetLastError()); +} + +// kernel 1 is the most naive matmul kernel +void matmul_forward(float* out, + float* inp, float* weight, float* bias, + int B, int T, int C, int OC) { + const int sqrt_block_size = 32; + + cublasHandle_t handle; // cuBLAS context + cublasStatus_t stat = cublasCreate(&handle); // initialize CUBLAS context + const float alpha = 1.0f; + const float beta = 0.0f; + stat = cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, B*T, C, &alpha, weight, C, inp, C, &beta, out, OC); + if (stat != CUBLAS_STATUS_SUCCESS) { + printf("cublasSgemm failed\n"); + exit(1); + } + // and now we still have to add the bias... (ew) + if (bias != NULL) { + int block_size = sqrt_block_size * sqrt_block_size; + int grid_size = CEIL_DIV(OC * B * T, block_size); + add_bias<<>>(out, bias, B, T, OC); + cudaCheck(cudaGetLastError()); + } + cublasDestroy(handle); +} + +void attention_forward(float* out, float* vaccum, float* qkvr, float* preatt, float* att, + float* inp, + int B, int T, int C, int NH) { + const int block_size = 512; + int HS = C / NH; // head size + + // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS) + float *q, *k, *v; + q = qkvr + 0 * B * T * C; + k = qkvr + 1 * B * T * C; + v = qkvr + 2 * B * T * C; + int total_threads = B * NH * T * HS; + int num_blocks = CEIL_DIV(total_threads, block_size); + permute_kernel<<>>(q, k, v, inp, B, T, NH, HS); + + // batched matrix multiply with cuBLAS + cublasHandle_t handle; + cublasStatus_t stat = cublasCreate(&handle); + const float alpha = 1.0f; + const float beta = 0.0f; + stat = cublasSgemmStridedBatched(handle, + CUBLAS_OP_T, CUBLAS_OP_N, + T, T, HS, + &alpha, + k, HS, T * HS, + q, HS, T * HS, + &beta, + preatt, T, T * T, + B * NH); + if (stat != CUBLAS_STATUS_SUCCESS) { + printf("cublasSgemm failed\n"); + exit(1); + } + + // multiply all elements of preatt elementwise by scale + float scale = 1.0 / sqrtf(HS); + total_threads = B * NH * T * T; + num_blocks = CEIL_DIV(total_threads, block_size); + scale_kernel<<>>(preatt, scale, B, NH, T); + + // softmax. preatt is (B, NH, T, T) but we view it as (B * NH * T, T) and use the softmax kernel + int softmax_block_size = 256; + int grid_size = B * NH * T; + size_t shared_mem_size = 2 * softmax_block_size / 32 * sizeof(float); + softmax_forward_kernel4<<>>(att, preatt, B * NH * T, T); + + // new approach: first cuBLAS another batched matmul + // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) + stat = cublasSgemmStridedBatched(handle, + CUBLAS_OP_N, CUBLAS_OP_N, + HS, T, T, + &alpha, + v, HS, T * HS, + att, T, T * T, + &beta, + vaccum, HS, T * HS, + B * NH); + if (stat != CUBLAS_STATUS_SUCCESS) { + printf("cublasSgemm failed\n"); + exit(1); + } + + // now unpermute + // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + num_blocks = CEIL_DIV(B * T * C, block_size); + unpermute_kernel<<>>(vaccum, out, B, T, NH, HS); + + // cleanups + cublasDestroy(handle); +} + +void residual_forward(float* out, float* inp1, float* inp2, int N) { + const int block_size = 256; + const int grid_size = CEIL_DIV(N, block_size); + residual_forward_kernel<<>>(out, inp1, inp2, N); + cudaCheck(cudaGetLastError()); +} + + +void gelu_forward(float* out, const float* inp, int N) { + const int block_size = 128; + const int grid_size = CEIL_DIV(N, block_size); + gelu_kernel<<>>(out, inp, N); + cudaCheck(cudaGetLastError()); +} + +void softmax_forward(float* out, float* inp, int N, int C) { + const int block_size = 256; + int grid_size = N; + size_t shared_mem_size = 2 * block_size / 32 * sizeof(float); + softmax_forward_kernel4<<>>(out, inp, N, C); +} + +void crossentropy_forward(float* losses, + float* probs, int* targets, + int B, int T, int V) { + const int block_size = 128; + const int N = B * T; + const int grid_size = CEIL_DIV(N, block_size); + crossentropy_forward_kernel1<<>>(losses, probs, targets, B, T, V); + cudaCheck(cudaGetLastError()); +} + +// ---------------------------------------------------------------------------- +// GPT-2 model definition + +// the parameters of the model +#define NUM_PARAMETER_TENSORS 16 +typedef struct { + float* wte; // (V, C) + float* wpe; // (maxT, C) + float* ln1w; // (L, C) + float* ln1b; // (L, C) + float* qkvw; // (L, 3*C, C) + float* qkvb; // (L, 3*C) + float* attprojw; // (L, C, C) + float* attprojb; // (L, C) + float* ln2w; // (L, C) + float* ln2b; // (L, C) + float* fcw; // (L, 4*C, C) + float* fcb; // (L, 4*C) + float* fcprojw; // (L, C, 4*C) + float* fcprojb; // (L, C) + float* lnfw; // (C) + float* lnfb; // (C) +} ParameterTensors; + + +// allocate memory for the parameters and point the individual tensors to the right places +float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes, int on_device) { + // on_device: 0 = CPU, 1 = GPU + // calculate the number of parameters + size_t num_parameters = 0; + for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { + num_parameters += param_sizes[i]; + } + // malloc all parameters all at once on the device + float* params_memory; + if (on_device) { + cudaCheck(cudaMalloc((void**)¶ms_memory, num_parameters * sizeof(float))); + } else { + params_memory = (float*)malloc(num_parameters * sizeof(float)); + } + // assign all the tensors their place in the array + float** ptrs[] = { + ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, + ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, + ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb + }; + float* params_memory_iterator = params_memory; + for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { + *(ptrs[i]) = params_memory_iterator; + params_memory_iterator += param_sizes[i]; + } + return params_memory; +} + + +#define NUM_ACTIVATION_TENSORS 25 +typedef struct { + float* encoded; // (B, T, C) + float* ln1; // (L, B, T, C) + float* ln1_mean; // (L, B, T) + float* ln1_rstd; // (L, B, T) + float* qkv; // (L, B, T, 3*C) + float* atty; // (L, B, T, C) + float* preatt; // (L, B, NH, T, T) + float* att; // (L, B, NH, T, T) + float* attproj; // (L, B, T, C) + float* residual2; // (L, B, T, C) + float* ln2; // (L, B, T, C) + float* ln2_mean; // (L, B, T) + float* ln2_rstd; // (L, B, T) + float* fch; // (L, B, T, 4*C) + float* fch_gelu; // (L, B, T, 4*C) + float* fcproj; // (L, B, T, C) + float* residual3; // (L, B, T, C) + float* lnf; // (B, T, C) + float* lnf_mean; // (B, T) + float* lnf_rstd; // (B, T) + float* logits; // (B, T, V) + float* probs; // (B, T, V) + float* losses; // (B, T) + // adding these two compared to the CPU .c code, needed for attention kernel as buffers + float* qkvr; // (L, B, T, 3*C) + float* v_accum; // (L, B, T, C) +} ActivationTensors; + +float* malloc_and_point_activations(ActivationTensors* acts, size_t* act_sizes) { + size_t num_activations = 0; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + num_activations += act_sizes[i]; + } + float* acts_memory; + cudaCheck(cudaMalloc((void**)&acts_memory, num_activations * sizeof(float))); + float** ptrs[] = { + &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->qkv, &acts->atty, + &acts->preatt, &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean, + &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf, + &acts->lnf_mean, &acts->lnf_rstd, &acts->logits, &acts->probs, &acts->losses, + &acts->qkvr, &acts->v_accum + }; + float* acts_memory_iterator = acts_memory; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + *(ptrs[i]) = acts_memory_iterator; + acts_memory_iterator += act_sizes[i]; + } + return acts_memory; +} + +typedef struct { + int max_seq_len; // max sequence length, e.g. 1024 + int vocab_size; // vocab size, e.g. 50257 + int num_layers; // number of layers, e.g. 12 + int num_heads; // number of heads in attention, e.g. 12 + int channels; // number of channels, e.g. 768 +} GPT2Config; + +typedef struct { + GPT2Config config; + // the weights of the model, and their sizes + ParameterTensors params; + size_t param_sizes[NUM_PARAMETER_TENSORS]; + float* params_memory; + int num_parameters; + // gradients of the weights + ParameterTensors grads; + float* grads_memory; + // buffers for the AdamW optimizer + float* m_memory; + float* v_memory; + // the activations of the model, and their sizes + ActivationTensors acts; + size_t act_sizes[NUM_ACTIVATION_TENSORS]; + float* acts_memory; + int num_activations; + // gradients of the activations + ActivationTensors grads_acts; + float* grads_acts_memory; + // other run state configuration + int batch_size; // the batch size (B) of current forward pass + int seq_len; // the sequence length (T) of current forward pass + int* inputs; // the input tokens for the current forward pass + int* targets; // the target tokens for the current forward pass + float mean_loss; // after a forward pass with targets, will be populated with the mean loss +} GPT2; + + +void gpt2_build_from_checkpoint(GPT2 *model, char* checkpoint_path) { + + // read in model from a checkpoint file + FILE *model_file = fopen(checkpoint_path, "rb"); + if (model_file == NULL) { printf("Error opening model file\n"); exit(1); } + int model_header[256]; + fread(model_header, sizeof(int), 256, model_file); + if (model_header[0] != 20240326) { printf("Bad magic model file"); exit(1); } + if (model_header[1] != 1) { printf("Bad version in model file"); exit(1); } + + // read in hyperparameters + int maxT, V, L, NH, C; + model->config.max_seq_len = maxT = model_header[2]; + model->config.vocab_size = V = model_header[3]; + model->config.num_layers = L = model_header[4]; + model->config.num_heads = NH = model_header[5]; + model->config.channels = C = model_header[6]; + printf("[GPT-2]\n"); + printf("max_seq_len: %d\n", maxT); + printf("vocab_size: %d\n", V); + printf("num_layers: %d\n", L); + printf("num_heads: %d\n", NH); + printf("channels: %d\n", C); + + // allocate space for all the parameters and read them in + model->param_sizes[0] = V * C; + model->param_sizes[1] = maxT * C; + model->param_sizes[2] = L * C; + model->param_sizes[3] = L * C; + model->param_sizes[4] = L * (3 * C) * C; + model->param_sizes[5] = L * (3 * C); + model->param_sizes[6] = L * C * C; + model->param_sizes[7] = L * C; + model->param_sizes[8] = L * C; + model->param_sizes[9] = L * C; + model->param_sizes[10] = L * (4 * C) * C; + model->param_sizes[11] = L * (4 * C); + model->param_sizes[12] = L * C * (4 * C); + model->param_sizes[13] = L * C; + model->param_sizes[14] = C; + model->param_sizes[15] = C; + + // cound the number of paramaters + size_t num_parameters = 0; + for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) { + num_parameters += model->param_sizes[i]; + } + printf("num_parameters: %zu\n", num_parameters); + model->num_parameters = num_parameters; + + // create memory for model parameters on the device + model->params_memory = malloc_and_point_parameters(&model->params, model->param_sizes, 1); + + // read in all the parameters from file and copy them to device + float* params_memory_cpu = (float*)malloc(num_parameters * sizeof(float)); + fread(params_memory_cpu, sizeof(float), num_parameters, model_file); + cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, num_parameters * sizeof(float), cudaMemcpyHostToDevice)); + free(params_memory_cpu); + fclose(model_file); + + // other inits + model->acts_memory = NULL; + model->grads_memory = NULL; + model->m_memory = NULL; + model->v_memory = NULL; + model->grads_acts_memory = NULL; + model->inputs = NULL; + model->targets = NULL; + model->batch_size = 0; + model->seq_len = 0; + model->mean_loss = -1.0f; // -1.0f will designate no loss +} + +void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) { + // targets are optional and could be NULL + + // ensure the model was initialized or error out + if (model->params_memory == NULL) { + printf("Error: model was not initialized properly.\n"); + exit(1); + } + + // convenience parameters + int V = model->config.vocab_size; + int L = model->config.num_layers; + int NH = model->config.num_heads; + int C = model->config.channels; + + // allocate space for all the activations if needed (done here, lazily) + if(model->acts_memory == NULL) { + // record the current B,T as well + model->batch_size = B; + model->seq_len = T; + // and now allocate the space + model->act_sizes[0] = B * T * C; + model->act_sizes[1] = L * B * T * C; + model->act_sizes[2] = L * B * T; + model->act_sizes[3] = L * B * T; + model->act_sizes[4] = L * B * T * 3*C; + model->act_sizes[5] = L * B * T * C; + model->act_sizes[6] = L * B * NH * T * T; + model->act_sizes[7] = L * B * NH * T * T; + model->act_sizes[8] = L * B * T * C; + model->act_sizes[9] = L * B * T * C; + model->act_sizes[10] = L * B * T * C; + model->act_sizes[11] = L * B * T; + model->act_sizes[12] = L * B * T; + model->act_sizes[13] = L * B * T * 4*C; + model->act_sizes[14] = L * B * T * 4*C; + model->act_sizes[15] = L * B * T * C; + model->act_sizes[16] = L * B * T * C; + model->act_sizes[17] = B * T * C; + model->act_sizes[18] = B * T; + model->act_sizes[19] = B * T; + model->act_sizes[20] = B * T * V; + model->act_sizes[21] = B * T * V; + model->act_sizes[22] = B * T; + model->act_sizes[23] = L * B * T * 3*C; // qkvr + model->act_sizes[24] = L * B * T * C; // v_accum + size_t num_activations = 0; + for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { + num_activations += model->act_sizes[i]; + } + printf("num_activations: %zu\n", num_activations); + model->num_activations = num_activations; + model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes); + // also create memory for caching inputs and targets + cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int))); + cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int))); + } else { + // validate B,T is no larger than what was previously allocated + // in principle, we could re-allocate a larger chunk of memory, for now we just error out + if (B > model->batch_size || T > model->seq_len) { + printf("Error: batch size or sequence length is inadequately large\n"); + printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, B, T); + exit(1); + } + } + + // copy inputs/targets to the model + cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice)); + if (targets != NULL) { + cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); + } + + // forward pass + ParameterTensors params = model->params; // for brevity + ActivationTensors acts = model->acts; + float* residual; + encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0] + + for (int l = 0; l < L; l++) { + + residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; + + // get the pointers of the weights for this layer + float* l_ln1w = params.ln1w + l * C; + float* l_ln1b = params.ln1b + l * C; + float* l_qkvw = params.qkvw + l * 3*C * C; + float* l_qkvb = params.qkvb + l * 3*C; + float* l_attprojw = params.attprojw + l * C * C; + float* l_attprojb = params.attprojb + l * C; + float* l_ln2w = params.ln2w + l * C; + float* l_ln2b = params.ln2b + l * C; + float* l_fcw = params.fcw + l * 4*C * C; + float* l_fcb = params.fcb + l * 4*C; + float* l_fcprojw = params.fcprojw + l * C * 4*C; + float* l_fcprojb = params.fcprojb + l * C; + + // get the pointers of the activations for this layer + float* l_ln1 = acts.ln1 + l * B * T * C; + float* l_ln1_mean = acts.ln1_mean + l * B * T; + float* l_ln1_rstd = acts.ln1_rstd + l * B * T; + float* l_qkv = acts.qkv + l * B * T * 3*C; + float* l_qkvr = acts.qkvr + l * B * T * 3*C; + float* l_atty = acts.atty + l * B * T * C; + float* l_preatt = acts.preatt + l * B * NH * T * T; + float* l_att = acts.att + l * B * NH * T * T; + float* l_v_accum = acts.v_accum + l * B * T * C; + float* l_attproj = acts.attproj + l * B * T * C; + float* l_residual2 = acts.residual2 + l * B * T * C; + float* l_ln2 = acts.ln2 + l * B * T * C; + float* l_ln2_mean = acts.ln2_mean + l * B * T; + float* l_ln2_rstd = acts.ln2_rstd + l * B * T; + float* l_fch = acts.fch + l * B * T * 4*C; + float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C; + float* l_fcproj = acts.fcproj + l * B * T * C; + float* l_residual3 = acts.residual3 + l * B * T * C; + + // now do the forward pass + layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C); + matmul_forward(l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C); + attention_forward(l_atty, l_v_accum, l_qkvr, l_preatt, l_att, l_qkv, B, T, C, NH); + matmul_forward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C); + residual_forward(l_residual2, residual, l_attproj, B*T*C); + layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C); + matmul_forward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C); + gelu_forward(l_fch_gelu, l_fch, B*T*4*C); + matmul_forward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C); + residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C); + } + + residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 + layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C); + matmul_forward(acts.logits, acts.lnf, params.wte, NULL, B, T, C, V); + softmax_forward(acts.probs, acts.logits, B*T, V); + + // also forward the cross-entropy loss function if we have the targets + if (targets != NULL) { + crossentropy_forward(acts.losses, acts.probs, model->targets, B, T, V); + + // for convenience also evaluate the mean loss + // move the (B,T) losses to CPU + // TODO get rid of inline mallocs + float* cpu_losses = (float*)malloc(B * T * sizeof(float)); + cudaCheck(cudaMemcpy(cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost)); + float mean_loss = 0.0f; + for (int i=0; imean_loss = mean_loss; + free(cpu_losses); + + } else { + // if we don't have targets, we don't have a loss + model->mean_loss = -1.0f; + } +} + +void gpt2_free(GPT2 *model) { + cudaCheck(cudaFree(model->params_memory)); + cudaCheck(cudaFree(model->grads_memory)); + cudaCheck(cudaFree(model->m_memory)); + cudaCheck(cudaFree(model->v_memory)); + cudaCheck(cudaFree(model->acts_memory)); + cudaCheck(cudaFree(model->grads_acts_memory)); + cudaCheck(cudaFree(model->inputs)); + cudaCheck(cudaFree(model->targets)); +} + +#ifndef TESTING +// if we are TESTING (see test_gpt2.cu), we'll skip the int main below + +// ---------------------------------------------------------------------------- +// data loader lite +// returns random batches of data from a file of integers + +typedef struct { + // hyperparameters + int B; + int T; + // input handling and its state + FILE* tokens_file; + long file_size; + long current_position; + // output memory + int* batch; + int* inputs; + int* targets; + // convenience variables + int num_batches; +} DataLoader; + +void dataloader_init(DataLoader *loader, char* filename, int B, int T) { + loader->B = B; + loader->T = T; + + // open the input file for reading + loader->tokens_file = fopen(filename, "rb"); + if (loader->tokens_file == NULL) { + printf("Error opening tokens file\n"); + exit(1); + } + + // determine the file size + fseek(loader->tokens_file, 0, SEEK_END); + loader->file_size = ftell(loader->tokens_file); + fseek(loader->tokens_file, 0, SEEK_SET); + if (loader->file_size < (B * T + 1) * sizeof(int)) { + printf("Error: file size is too small for the batch size and sequence length\n"); + exit(1); + } + loader->current_position = 0; // start at the beginning + + // allocate space for B*T + 1 integers to store the inputs and targets + loader->batch = (int*) malloc((B * T + 1) * sizeof(int)); + loader->inputs = loader->batch; + loader->targets = loader->batch + 1; // targets are shifted by one + loader->num_batches = loader->file_size / (B * T * sizeof(int)); +} + +void dataloader_reset(DataLoader *loader) { + loader->current_position = 0; +} + +void dataloader_next_batch(DataLoader *loader) { + int B = loader->B; + int T = loader->T; + // if we are at the end of the file, loop back to the beginning + if (loader->current_position + (B*T+1) * sizeof(int) > loader->file_size) { + loader->current_position = 0; + } + // read the B*T+1 integers from the file into batch + fseek(loader->tokens_file, loader->current_position, SEEK_SET); + fread(loader->batch, sizeof(int), B*T+1, loader->tokens_file); + // advance the current position by B*T integers + loader->current_position += B*T * sizeof(int); +} + +void dataloader_free(DataLoader *loader) { + fclose(loader->tokens_file); + free(loader->batch); +} + + +// ---------------------------------------------------------------------------- +// sampler + +#define GPT2_EOT 50256 + +unsigned int random_u32(unsigned long long *state) { + // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A + *state ^= *state >> 12; + *state ^= *state << 25; + *state ^= *state >> 27; + return (*state * 0x2545F4914F6CDD1Dull) >> 32; +} +float random_f32(unsigned long long *state) { // random float32 in [0,1) + return (random_u32(state) >> 8) / 16777216.0f; +} + +int sample_mult(float* probabilities, int n, float coin) { + // sample index from probabilities (they must sum to 1!) + // coin is a random number in [0, 1), usually from random_f32() + float cdf = 0.0f; + for (int i = 0; i < n; i++) { + cdf += probabilities[i]; + if (coin < cdf) { + return i; + } + } + return n - 1; // in case of rounding errors +} + +// ---------------------------------------------------------------------------- +// main training loop +int main() { + + // build the GPT-2 model from a checkpoint + GPT2 model; + gpt2_build_from_checkpoint(&model, "gpt2_124M.bin"); + + // build the DataLoaders from tokens files. for now use tiny_shakespeare if available, else tiny_stories + char* tiny_stories_train = "data/TinyStories_train.bin"; + char* tiny_stories_val = "data/TinyStories_val.bin"; + char* tiny_shakespeare_train = "data/tiny_shakespeare_train.bin"; + char* tiny_shakespeare_val = "data/tiny_shakespeare_val.bin"; + char* train_tokens = access(tiny_shakespeare_train, F_OK) != -1 ? tiny_shakespeare_train : tiny_stories_train; + char* val_tokens = access(tiny_shakespeare_val, F_OK) != -1 ? tiny_shakespeare_val : tiny_stories_val; + int B = 4; + int T = 1024; + DataLoader train_loader; + dataloader_init(&train_loader, train_tokens, B, T); + printf("train dataset num_batches: %d\n", train_loader.num_batches); + DataLoader val_loader; + dataloader_init(&val_loader, val_tokens, B, T); + printf("val dataset num_batches: %d\n", val_loader.num_batches); + int val_num_batches = 10; + printf("batch size: %d\n", B); + printf("sequence length: %d\n", T); + printf("val_num_batches: %d\n", val_num_batches); + + // some memory for generating samples from the model + unsigned long long rng_state = 1337; + const int gen_max_length = 64; + int gen_tokens[gen_max_length]; + float* cpu_probs = (float*)malloc(model.config.vocab_size * sizeof(float)); + + // train + struct timespec start, end; + for (int step = 0; step <= 40; step++) { + + // once in a while estimate the validation loss + if (step % 10 == 0) { + float val_loss = 0.0f; + dataloader_reset(&val_loader); + for (int i = 0; i < val_num_batches; i++) { + dataloader_next_batch(&val_loader); + gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T); + val_loss += model.mean_loss; + } + val_loss /= val_num_batches; + printf("val loss %f\n", val_loss); + } + + // once in a while do model inference to print generated text + if (step > 0 && step % 20 == 0) { + gen_tokens[0] = GPT2_EOT; // the GPT-2 EOT token kicks off the generation + for (int t = 1; t < gen_max_length; t++) { + // note that inference is wasteful here because + // for each t, we re-compute all activations between 0 and t + // leaving this alone because you want separate code for inference anyway + // the inference here is just for sanity checking purposes + gpt2_forward(&model, gen_tokens, NULL, 1, t); + float* probs = model.acts.probs + (t-1) * model.config.vocab_size; + float coin = random_f32(&rng_state); + // move probs back to CPU and sample + cudaCheck(cudaMemcpy(cpu_probs, probs, model.config.vocab_size * sizeof(float), cudaMemcpyDeviceToHost)); + int next_token = sample_mult(cpu_probs, model.config.vocab_size, coin); + gen_tokens[t] = next_token; + } + printf("generated: "); + for (int t = 0; t < gen_max_length; t++) { + printf("%d ", gen_tokens[t]); + } + printf("\n"); + } + + // do a training step + clock_gettime(CLOCK_MONOTONIC, &start); + dataloader_next_batch(&train_loader); + gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T); + // these are still TODO + // gpt2_zero_grad(&model); + // gpt2_backward(&model); + // gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1); + cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings + clock_gettime(CLOCK_MONOTONIC, &end); + double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; + printf("step %d: train loss %f (took %f ms)\n", step, model.mean_loss, time_elapsed_s * 1000); + } + + // free + dataloader_free(&train_loader); + dataloader_free(&val_loader); + gpt2_free(&model); + free(cpu_probs); + return 0; +} +#endif \ No newline at end of file diff --git a/train_gpt2.py b/train_gpt2.py index a5d455304..79215f13a 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -292,8 +292,26 @@ def write_state(model, x, y, logits, loss, filename): if __name__ == "__main__": + import time + import argparse import tiktoken + # default settings will overfit a tiny batch of data + # and save model weights and debug state to disk on the first iteration + # if you'd like to e.g. time the forward pass only, call this script as: + # python train_gpt2.py --inference_only 1 --write_tensors 0 --sequence_length 1024 + parser = argparse.ArgumentParser() + parser.add_argument("--write_tensors", type=int, default=1, help="write tensors to disk") + parser.add_argument("--inference_only", type=int, default=0, help="only run inference") + parser.add_argument("--compile", type=int, default=0, help="torch.compile the model") + parser.add_argument("--tensorcores", type=int, default=0, help="use tensorcores") + parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run") + parser.add_argument("--batch_size", type=int, default=4, help="batch size") + parser.add_argument("--sequence_length", type=int, default=64, help="sequence length") + args = parser.parse_args() + B, T = args.batch_size, args.sequence_length + assert 1 <= T <= 1024 + # select a reasonable device to run on device = "cpu" if torch.cuda.is_available(): @@ -312,10 +330,16 @@ def write_state(model, x, y, logits, loss, filename): encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) decode = lambda l: enc.decode(l) + if args.tensorcores: + torch.set_float32_matmul_precision('high') + # load the GPT-2 model weights model = GPT.from_pretrained("gpt2") model.train() model.to(device) + if args.compile: + print("compiling the model...") + model = torch.compile(model) # load the tokens # prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories @@ -335,10 +359,6 @@ def write_state(model, x, y, logits, loss, filename): tokens = tokens.to(device) # lightweight dataloader - B = 4 # batch size - T = 64 # sequence length, up to 1024 - assert 1 <= T <= 1024 - def get_batch(): assert B*T+1 <= len(tokens), "not enough tokens" # for 338,025 tokens. E.g. with B=8 T=1024, this will yield 41 batches before looping @@ -351,20 +371,24 @@ def get_batch(): if i + B*T + 1 >= len(tokens): i = 0 # in prod we'd want to randomize the start point a bit - # forward backward for 3 iterations + # forward backward for a few iterations data_iter = iter(get_batch()) x, y = next(data_iter) # we'll overfit this batch below optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) - for i in range(10): + for i in range(args.num_iterations): + t0 = time.time() logits, loss = model(x, y) - optimizer.zero_grad() - loss.backward() - # on the first iteration only, save the state dict to file for later reference - if i == 0: - write_model(model, "gpt2_124M.bin") - write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin") - optimizer.step() - print(f"iteration {i}, loss: {loss.item()}") + if not args.inference_only: + optimizer.zero_grad() + loss.backward() + # on the first iteration only, save the state dict to file for later reference + if i == 0 and args.write_tensors: + write_model(model, "gpt2_124M.bin") + write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin") + optimizer.step() + torch.cuda.synchronize() + t1 = time.time() + print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms") # before we end, let's also do one round of inference # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence