small cosmetic changes
karpathy committed May 24, 2024
1 parent 6e4296f commit 3221e4b
Showing 1 changed file, train_gpt2.c, with 15 additions and 15 deletions.
@@ -160,13 +160,13 @@ void layernorm_backward(float* dinp, float* dweight, float* dbias,
     }
 }

-void matmul_forward_slow(float* out,
+void matmul_forward_naive(float* out,
                          const float* inp, const float* weight, const float* bias,
                          int B, int T, int C, int OC) {
-    // basic implementation of matrix multiplication. This serves as a fallback
-    // for bad input shapes, and as an illustration for the most basic version
-    // of the algorithm.
-    #pragma omp parallel for collapse(2)
+    // the most naive implementation of matrix multiplication
+    // this serves as an algorithmic reference, and as a fallback for
+    // unfriendly input shapes inside matmul_forward(), below.
+    #pragma omp parallel for collapse(2)
     for (int b = 0; b < B; b++) {
         for (int t = 0; t < T; t++) {
             int bt = b * T + t;
@@ -185,42 +185,42 @@ void matmul_forward(float* out,
                     const float* inp, const float* weight, const float* bias,
                     int B, int T, int C, int OC) {
     // most of the running time is spent here and in matmul_backward
     // therefore, the implementation below is very mildly optimized
+    // this function is otherwise identical to that of matmul_forward_naive()
     // OC is short for "output channels"
     // inp is (B,T,C), weight is (OC, C), bias is (OC)
     // out will be (B,T,OC)

-    // make sure the tiled loop will be correct, otherwise, fallback to slow version
+    // make sure the tiled loop will be correct or fallback to naive version
     const int LOOP_UNROLL = 8;
     if (B*T % LOOP_UNROLL != 0) {
-        matmul_forward_slow(out, inp, weight, bias, B, T, C, OC);
+        matmul_forward_naive(out, inp, weight, bias, B, T, C, OC);
         return;
     }

     // collapse the B and T loops into one and turn it into a strided loop.
     // then we can tile the inner loop, and reuse the loaded weight LOOP_UNROLL many times
     // for significant speed-ups.
     #pragma omp parallel for
     for (int obt = 0; obt < B * T; obt += LOOP_UNROLL) {
         for (int o = 0; o < OC; o++) {
-            // keep LOOP_UNROLL many results in register, initialized by the bias term.
+            // we'll keep LOOP_UNROLL many results in registers
             float result[LOOP_UNROLL];
-            for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) {
+            // initialize the bias, if it exists
+            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
                 result[ibt] = (bias != NULL) ? bias[o] : 0.0f;
             }

             // inner loops. Because we do LOOP_UNROLL steps of inner bt, we can cache
             // the value of weight[i + o * C] and reuse it.
-            // we compile with -Ofast, so the compiler will turn the inner loop into a bunch of FMAs
+            // we compile with -Ofast, so the compiler will turn the inner loop into FMAs
             for (int i = 0; i < C; i++) {
                 float w = weight[i + o * C];
-                for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) {
+                for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
                     int bt = obt + ibt;
                     result[ibt] += inp[bt * C + i] * w;
                 }
             }

             // write back results to main memory
-            for (int ibt = 0; ibt < LOOP_UNROLL; ++ibt) {
+            for (int ibt = 0; ibt < LOOP_UNROLL; ibt++) {
                 int bt = obt + ibt;
                 out[bt * OC + o] = result[ibt];
             }
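
The new comments spell out the design: matmul_forward() collapses the B and T loops into one strided loop of step LOOP_UNROLL so each loaded weight value is reused eight times, and it falls back to matmul_forward_naive() whenever B*T is not a multiple of LOOP_UNROLL. A minimal sketch of a correctness check for that pair is below; it is not part of the repo, and it assumes the two functions from train_gpt2.c are pasted into (or linked with) the same file. The harness name, shapes, compiler flags, and tolerance reporting are illustrative.

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// signatures as in train_gpt2.c; paste the definitions above this main()
// and build e.g. with: gcc -Ofast -fopenmp check_matmul.c -lm  (flags assumed)
void matmul_forward_naive(float* out, const float* inp, const float* weight, const float* bias,
                          int B, int T, int C, int OC);
void matmul_forward(float* out, const float* inp, const float* weight, const float* bias,
                    int B, int T, int C, int OC);

int main(void) {
    int B = 4, T = 64, C = 768, OC = 768;  // B*T = 256 is a multiple of LOOP_UNROLL = 8
    float* inp    = malloc((size_t)B * T * C  * sizeof(float));
    float* weight = malloc((size_t)OC * C     * sizeof(float));
    float* bias   = malloc((size_t)OC         * sizeof(float));
    float* out1   = malloc((size_t)B * T * OC * sizeof(float));
    float* out2   = malloc((size_t)B * T * OC * sizeof(float));
    for (int i = 0; i < B * T * C; i++) { inp[i]    = (float)rand() / RAND_MAX - 0.5f; }
    for (int i = 0; i < OC * C;    i++) { weight[i] = (float)rand() / RAND_MAX - 0.5f; }
    for (int i = 0; i < OC;        i++) { bias[i]   = (float)rand() / RAND_MAX - 0.5f; }
    // run the reference and the unrolled version on the same inputs
    matmul_forward_naive(out1, inp, weight, bias, B, T, C, OC);
    matmul_forward(out2, inp, weight, bias, B, T, C, OC);
    // with -Ofast the compiler may vectorize and reorder the accumulation,
    // so compare with a small tolerance rather than exact equality
    float maxdiff = 0.0f;
    for (int i = 0; i < B * T * OC; i++) {
        float d = fabsf(out1[i] - out2[i]);
        if (d > maxdiff) { maxdiff = d; }
    }
    printf("max |naive - unrolled| = %e\n", maxdiff);
    free(inp); free(weight); free(bias); free(out1); free(out2);
    return 0;
}

Either path computes the same per-output-element dot product in the same i order; the unrolling only changes how many (b,t) rows are in flight at once, which is why the two results should agree to within floating-point noise.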