diff --git a/train_gpt2.c b/train_gpt2.c
index 57bdfe929..b01abf09f 100644
--- a/train_gpt2.c
+++ b/train_gpt2.c
@@ -160,32 +160,76 @@ void layernorm_backward(float* dinp, float* dweight, float* dbias,
     }
 }
 
-void matmul_forward(float* out,
-                    float* inp, float* weight, float* bias,
-                    int B, int T, int C, int OC) {
-    // most of the running time is spent here and in matmul_backward
-    // OC is short for "output channels"
-    // inp is (B,T,C), weight is (OC, C), bias is (OC)
-    // out will be (B,T,OC)
+void matmul_forward_naive(float* out,
+                          const float* inp, const float* weight, const float* bias,
+                          int B, int T, int C, int OC) {
+    // the most naive implementation of matrix multiplication
+    // this serves as an algorithmic reference, and as a fallback for
+    // unfriendly input shapes inside matmul_forward(), below.
     #pragma omp parallel for collapse(2)
     for (int b = 0; b < B; b++) {
         for (int t = 0; t < T; t++) {
-            float* out_bt = out + b * T * OC + t * OC;
-            float* inp_bt = inp + b * T * C + t * C;
+            int bt = b * T + t;
             for (int o = 0; o < OC; o++) {
                 float val = (bias != NULL) ? bias[o] : 0.0f;
-                float* wrow = weight + o*C;
                 for (int i = 0; i < C; i++) {
-                    val += inp_bt[i] * wrow[i];
+                    val += inp[bt * C + i] * weight[o*C + i];
+                }
+                out[bt * OC + o] = val;
+            }
+        }
+    }
+}
+
+void matmul_forward(float* out,
+                    const float* inp, const float* weight, const float* bias,
+                    int B, int T, int C, int OC) {
+    // most of the running time is spent here and in matmul_backward
+    // therefore, the implementation below is very mildly optimized
+    // this function is otherwise identical to that of matmul_forward_naive()
+    // OC is short for "output channels"
+    // inp is (B,T,C), weight is (OC, C), bias is (OC)
+    // out will be (B,T,OC)
+
+    // make sure the tiled loop will be correct or fallback to naive version
+    const int LOOP_UNROLL = 8;
+    if (B*T % LOOP_UNROLL != 0) {
+        matmul_forward_naive(out, inp, weight, bias, B, T, C, OC);
+        return;
+    }
+
+    // collapse the B and T loops into one and turn it into a strided loop.
+    // then we can tile the inner loop, and reuse the loaded weight LOOP_UNROLL many times
+    #pragma omp parallel for
+    for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) {
+        for (int o = 0; o < OC; o++) {
+            // we'll keep LOOP_UNROLL many results in registers
+            float result[LOOP_UNROLL];
+            // initialize the bias, if it exists
+            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
+                result[ibt] = (bias != NULL) ? bias[o] : 0.0f;
+            }
+            // inner loops. Because we do LOOP_UNROLL steps of inner bt, we can cache
+            // the value of weight[i + o * C] and reuse it.
+            // we compile with -Ofast, so the compiler will turn the inner loop into FMAs
+            for (int i = 0; i < C; i++) {
+                float w = weight[i + o * C];
+                for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
+                    int bt = obt + ibt;
+                    result[ibt] += inp[bt * C + i] * w;
                 }
-                out_bt[o] = val;
+            }
+            // write back results to main memory
+            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
+                int bt = obt + ibt;
+                out[bt * OC + o] = result[ibt];
             }
         }
     }
 }
 
 void matmul_backward(float* dinp, float* dweight, float* dbias,
-                     float* dout, float* inp, float* weight,
+                     const float* dout, const float* inp, const float* weight,
                      int B, int T, int C, int OC) {
     // most of the running time is spent here and in matmul_forward
     // this backward could be done in a single "round" of loops
@@ -195,10 +239,10 @@ void matmul_backward(float* dinp, float* dweight, float* dbias,
     #pragma omp parallel for collapse(2)
     for (int b = 0; b < B; b++) {
         for (int t = 0; t < T; t++) {
-            float* dout_bt = dout + b * T * OC + t * OC;
+            const float* dout_bt = dout + b * T * OC + t * OC;
             float* dinp_bt = dinp + b * T * C + t * C;
             for (int o = 0; o < OC; o++) {
-                float* wrow = weight + o*C;
+                const float* wrow = weight + o*C;
                 float d = dout_bt[o];
                 for (int i = 0; i < C; i++) {
                     dinp_bt[i] += wrow[i] * d;
@@ -211,8 +255,8 @@ void matmul_backward(float* dinp, float* dweight, float* dbias,
     for (int o = 0; o < OC; o++) {
         for (int b = 0; b < B; b++) {
             for (int t = 0; t < T; t++) {
-                float* dout_bt = dout + b * T * OC + t * OC;
-                float* inp_bt = inp + b * T * C + t * C;
+                const float* dout_bt = dout + b * T * OC + t * OC;
+                const float* inp_bt = inp + b * T * C + t * C;
                 float* dwrow = dweight + o*C;
                 float d = dout_bt[o];
                 if (dbias != NULL) { dbias[o] += d; }
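Not part of the diff above: a minimal, hypothetical sanity check one could compile against the two implementations, assuming they are linked into the same binary (the file name, shapes, and tolerance here are illustrative only). It runs the naive reference and the tiled path on the same random inputs and reports the maximum absolute difference.

// check_matmul.c -- hypothetical standalone check, not part of this PR.
// Build together with the two matmul_forward* functions from train_gpt2.c.
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

// prototypes for the two implementations shown in the diff above
void matmul_forward_naive(float* out, const float* inp, const float* weight, const float* bias,
                          int B, int T, int C, int OC);
void matmul_forward(float* out, const float* inp, const float* weight, const float* bias,
                    int B, int T, int C, int OC);

static float frand(void) { return (float)rand() / (float)RAND_MAX - 0.5f; }

int main(void) {
    // B*T = 8 is divisible by LOOP_UNROLL = 8, so the tiled path is exercised
    int B = 2, T = 4, C = 16, OC = 32;
    float* inp    = (float*)malloc((size_t)B * T * C  * sizeof(float));
    float* weight = (float*)malloc((size_t)OC * C     * sizeof(float));
    float* bias   = (float*)malloc((size_t)OC         * sizeof(float));
    float* out1   = (float*)malloc((size_t)B * T * OC * sizeof(float));
    float* out2   = (float*)malloc((size_t)B * T * OC * sizeof(float));
    srand(42);
    for (int i = 0; i < B * T * C; i++) { inp[i]    = frand(); }
    for (int i = 0; i < OC * C;    i++) { weight[i] = frand(); }
    for (int i = 0; i < OC;        i++) { bias[i]   = frand(); }
    matmul_forward_naive(out1, inp, weight, bias, B, T, C, OC);
    matmul_forward(out2, inp, weight, bias, B, T, C, OC);
    // under -Ofast the compiler may reassociate the sums, so compare with a small tolerance
    float maxdiff = 0.0f;
    for (int i = 0; i < B * T * OC; i++) {
        float d = fabsf(out1[i] - out2[i]);
        if (d > maxdiff) { maxdiff = d; }
    }
    printf("max abs difference: %e (%s)\n", maxdiff, maxdiff < 1e-4f ? "OK" : "MISMATCH");
    free(inp); free(weight); free(bias); free(out1); free(out2);
    return 0;
}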