Commit

Merge branch 'karpathy:master' into dataloader_win_fixes
rosslwheeler authored May 24, 2024
2 parents 5e0fa45 + e60c484 commit 1ec081e
Showing 1 changed file with 61 additions and 17 deletions.
train_gpt2.c: 78 changes (61 additions & 17 deletions)
@@ -160,32 +160,76 @@ void layernorm_backward(float* dinp, float* dweight, float* dbias,
     }
 }
 
-void matmul_forward(float* out,
-                    float* inp, float* weight, float* bias,
-                    int B, int T, int C, int OC) {
-    // most of the running time is spent here and in matmul_backward
-    // OC is short for "output channels"
-    // inp is (B,T,C), weight is (OC, C), bias is (OC)
-    // out will be (B,T,OC)
+void matmul_forward_naive(float* out,
+                          const float* inp, const float* weight, const float* bias,
+                          int B, int T, int C, int OC) {
+    // the most naive implementation of matrix multiplication
+    // this serves as an algorithmic reference, and as a fallback for
+    // unfriendly input shapes inside matmul_forward(), below.
     #pragma omp parallel for collapse(2)
     for (int b = 0; b < B; b++) {
         for (int t = 0; t < T; t++) {
-            float* out_bt = out + b * T * OC + t * OC;
-            float* inp_bt = inp + b * T * C + t * C;
+            int bt = b * T + t;
             for (int o = 0; o < OC; o++) {
                 float val = (bias != NULL) ? bias[o] : 0.0f;
-                float* wrow = weight + o*C;
                 for (int i = 0; i < C; i++) {
-                    val += inp_bt[i] * wrow[i];
+                    val += inp[bt * C + i] * weight[o*C + i];
                 }
+                out[bt * OC + o] = val;
+            }
+        }
+    }
+}
+
+void matmul_forward(float* out,
+                    const float* inp, const float* weight, const float* bias,
+                    int B, int T, int C, int OC) {
+    // most of the running time is spent here and in matmul_backward
+    // therefore, the implementation below is very mildly optimized
+    // this function is otherwise identical to that of matmul_forward_naive()
+    // OC is short for "output channels"
+    // inp is (B,T,C), weight is (OC, C), bias is (OC)
+    // out will be (B,T,OC)
+
+    // make sure the tiled loop will be correct or fallback to naive version
+    const int LOOP_UNROLL = 8;
+    if (B*T % LOOP_UNROLL != 0) {
+        matmul_forward_naive(out, inp, weight, bias, B, T, C, OC);
+        return;
+    }
+
+    // collapse the B and T loops into one and turn it into a strided loop.
+    // then we can tile the inner loop, and reuse the loaded weight LOOP_UNROLL many times
+    #pragma omp parallel for
+    for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) {
+        for (int o = 0; o < OC; o++) {
+            // we'll keep LOOP_UNROLL many results in registers
+            float result[LOOP_UNROLL];
+            // initialize the bias, if it exists
+            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
+                result[ibt] = (bias != NULL) ? bias[o] : 0.0f;
+            }
+            // inner loops. Because we do LOOP_UNROLL steps of inner bt, we can cache
+            // the value of weight[i + o * C] and reuse it.
+            // we compile with -Ofast, so the compiler will turn the inner loop into FMAs
+            for (int i = 0; i < C; i++) {
+                float w = weight[i + o * C];
+                for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
+                    int bt = obt + ibt;
+                    result[ibt] += inp[bt * C + i] * w;
+                }
+            }
-                out_bt[o] = val;
             }
+            // write back results to main memory
+            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
+                int bt = obt + ibt;
+                out[bt * OC + o] = result[ibt];
+            }
         }
     }
 }
 
 void matmul_backward(float* dinp, float* dweight, float* dbias,
-                     float* dout, float* inp, float* weight,
+                     const float* dout, const float* inp, const float* weight,
                      int B, int T, int C, int OC) {
     // most of the running time is spent here and in matmul_forward
     // this backward could be done in a single "round" of loops
@@ -195,10 +239,10 @@ void matmul_backward(float* dinp, float* dweight, float* dbias,
     #pragma omp parallel for collapse(2)
     for (int b = 0; b < B; b++) {
         for (int t = 0; t < T; t++) {
-            float* dout_bt = dout + b * T * OC + t * OC;
+            const float* dout_bt = dout + b * T * OC + t * OC;
             float* dinp_bt = dinp + b * T * C + t * C;
             for (int o = 0; o < OC; o++) {
-                float* wrow = weight + o*C;
+                const float* wrow = weight + o*C;
                 float d = dout_bt[o];
                 for (int i = 0; i < C; i++) {
                     dinp_bt[i] += wrow[i] * d;
@@ -211,8 +255,8 @@ void matmul_backward(float* dinp, float* dweight, float* dbias,
     for (int o = 0; o < OC; o++) {
         for (int b = 0; b < B; b++) {
             for (int t = 0; t < T; t++) {
-                float* dout_bt = dout + b * T * OC + t * OC;
-                float* inp_bt = inp + b * T * C + t * C;
+                const float* dout_bt = dout + b * T * OC + t * OC;
+                const float* inp_bt = inp + b * T * C + t * C;
                 float* dwrow = dweight + o*C;
                 float d = dout_bt[o];
                 if (dbias != NULL) { dbias[o] += d; }
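A note on the arithmetic behind the tiling: for every output channel o, the tiled kernel loads weight[i + o * C] once and reuses it across LOOP_UNROLL = 8 consecutive (b,t) rows, so the weight matrix is read from memory roughly 8x less often than in the naive version. The divisibility guard exists because the strided loop consumes B*T rows in whole groups of 8: with, say, B=4 and T=64 (illustrative shapes, not from the commit), B*T = 256 = 32 full groups and the tiled path runs, while B=1, T=3 gives B*T = 3 and falls back to matmul_forward_naive().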

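For readers who want to convince themselves the two kernels agree, here is a minimal, self-contained sketch, not part of this commit: the file name check_matmul.c and the names matmul_ref and matmul_tiled are made up for illustration. It restates the naive and register-tiled patterns without OpenMP and compares them on random data. Because both versions accumulate over i in the same order for each output, a plain -O2 build should match exactly; -Ofast may contract into FMAs and differ in the last bits.

/* check_matmul.c -- hypothetical standalone cross-check, not part of the commit.
   Build and run: gcc -O2 check_matmul.c -lm && ./a.out */
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

/* reference: out[bt,o] = bias[o] + sum_i inp[bt,i] * weight[o,i] */
static void matmul_ref(float* out, const float* inp, const float* weight,
                       const float* bias, int B, int T, int C, int OC) {
    for (int bt = 0; bt < B * T; bt++) {
        for (int o = 0; o < OC; o++) {
            float val = (bias != NULL) ? bias[o] : 0.0f;
            for (int i = 0; i < C; i++) {
                val += inp[bt * C + i] * weight[o * C + i];
            }
            out[bt * OC + o] = val;
        }
    }
}

/* same register-tiling pattern as the commit's matmul_forward (OpenMP omitted):
   each weight value is loaded once and reused for 8 consecutive (b,t) rows */
static void matmul_tiled(float* out, const float* inp, const float* weight,
                         const float* bias, int B, int T, int C, int OC) {
    const int LOOP_UNROLL = 8;
    if ((B * T) % LOOP_UNROLL != 0) {  /* same fallback guard as the commit */
        matmul_ref(out, inp, weight, bias, B, T, C, OC);
        return;
    }
    for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) {
        for (int o = 0; o < OC; o++) {
            float result[LOOP_UNROLL];
            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
                result[ibt] = (bias != NULL) ? bias[o] : 0.0f;
            }
            for (int i = 0; i < C; i++) {
                float w = weight[i + o * C];  /* one load, 8 uses */
                for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
                    result[ibt] += inp[(obt + ibt) * C + i] * w;
                }
            }
            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
                out[(obt + ibt) * OC + o] = result[ibt];
            }
        }
    }
}

int main(void) {
    int B = 4, T = 64, C = 256, OC = 256;  /* B*T = 256, divisible by 8 */
    float* inp    = malloc((size_t)B * T * C  * sizeof(float));
    float* weight = malloc((size_t)OC * C     * sizeof(float));
    float* bias   = malloc((size_t)OC         * sizeof(float));
    float* out1   = malloc((size_t)B * T * OC * sizeof(float));
    float* out2   = malloc((size_t)B * T * OC * sizeof(float));
    for (int i = 0; i < B * T * C; i++) { inp[i]    = (float)rand() / RAND_MAX - 0.5f; }
    for (int i = 0; i < OC * C; i++)    { weight[i] = (float)rand() / RAND_MAX - 0.5f; }
    for (int i = 0; i < OC; i++)        { bias[i]   = (float)rand() / RAND_MAX - 0.5f; }
    matmul_ref(out1, inp, weight, bias, B, T, C, OC);
    matmul_tiled(out2, inp, weight, bias, B, T, C, OC);
    float maxdiff = 0.0f;
    for (int i = 0; i < B * T * OC; i++) {
        float d = fabsf(out1[i] - out2[i]);
        if (d > maxdiff) { maxdiff = d; }
    }
    printf("max abs diff: %e\n", maxdiff);  /* expect 0.0 at -O2 */
    free(inp); free(weight); free(bias); free(out1); free(out2);
    return 0;
}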