sorry there should be no mallocs inside these functions, have to pass in buffer memory from outside. getting tired
karpathy committed Apr 10, 2024
1 parent 919ae1f commit 46ee0a3
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions dev/cuda/attention_forward.cu
@@ -531,6 +531,10 @@ void attention_forward2(float* out,
                         float* inp,
                         int B, int T, int C, int NH,
                         const int block_size) {
+    // TODO there should be no mallocs inside any of these functions!
+    // not fixing this because we don't intend to use attention_forward2,
+    // it seems to be way too slow as is
+
     // these are hardcoded to 32 for now
     const int Bc = 32;
     const int Br = 32;
@@ -601,7 +605,7 @@ void attention_forward2(float* out,
     cudaCheck(cudaFree(v));
 }
 
-void attention_forward3(float* out, float* qkvr, float* preatt, float* att,
+void attention_forward3(float* out, float* vaccum, float* qkvr, float* preatt, float* att,
                         float* inp,
                         int B, int T, int C, int NH,
                         const int block_size) {
@@ -652,8 +656,6 @@ void attention_forward3(float* out, float* qkvr, float* preatt, float* att,
 
     // new approach: first cuBLAS another batched matmul
     // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
-    float* vaccum;
-    cudaCheck(cudaMalloc(&vaccum, B * NH * T * HS * sizeof(float)));
     stat = cublasSgemmStridedBatched(handle,
                                      CUBLAS_OP_N, CUBLAS_OP_N,
                                      HS, T, T,
@@ -679,7 +681,7 @@ void attention_forward3(float* out, float* qkvr, float* preatt, float* att,
 
 // kernel version dispatch
 void attention_forward(int kernel_num,
-                       float* out, float* qkvr, float* preatt, float* att,
+                       float* out, float* vaccum, float* qkvr, float* preatt, float* att,
                        float* inp,
                        int B, int T, int C, int NH,
                        const int block_size) {
@@ -691,7 +693,7 @@ void attention_forward(int kernel_num,
             attention_forward2(out, inp, B, T, C, NH, block_size);
             break;
         case 3:
-            attention_forward3(out, qkvr, preatt, att, inp, B, T, C, NH, block_size);
+            attention_forward3(out, vaccum, qkvr, preatt, att, inp, B, T, C, NH, block_size);
             break;
         default:
             printf("Invalid kernel number\n");
@@ -731,11 +733,13 @@ int main(int argc, char **argv) {
 
     // move to GPU
     float* d_out;
+    float* d_vaccum;
     float* d_qkvr;
     float* d_preatt;
     float* d_att;
     float* d_inp;
     cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
+    cudaCheck(cudaMalloc(&d_vaccum, B * T * C * sizeof(float)));
     cudaCheck(cudaMalloc(&d_qkvr, B * T * 3 * C * sizeof(float)));
     cudaCheck(cudaMalloc(&d_preatt, B * NH * T * T * sizeof(float)));
     cudaCheck(cudaMalloc(&d_att, B * NH * T * T * sizeof(float)));
@@ -751,7 +755,7 @@ int main(int argc, char **argv) {
 
     // first check the correctness of the kernel
     attention_forward_cpu(out, preatt, att, inp, B, T, C, NH);
-    attention_forward(kernel_num, d_out, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, 256);
+    attention_forward(kernel_num, d_out, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, 256);
 
     // compare the output
     float* out_gpu = (float*)malloc(B * T * C * sizeof(float));
@@ -781,7 +785,7 @@ int main(int argc, char **argv) {
     cudaCheck(cudaEventCreate(&stop));
     cudaCheck(cudaEventRecord(start, 0));
     for (int i = 0; i < repeat_times; i++) {
-        attention_forward(kernel_num, d_out, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size);
+        attention_forward(kernel_num, d_out, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size);
     }
     cudaCheck(cudaEventRecord(stop, 0));
     cudaCheck(cudaEventSynchronize(start));
@@ -799,6 +803,7 @@ int main(int argc, char **argv) {
     free(inp);
     free(out_gpu);
     cudaCheck(cudaFree(d_out));
+    cudaCheck(cudaFree(d_vaccum));
     cudaCheck(cudaFree(d_qkvr));
     cudaCheck(cudaFree(d_preatt));
     cudaCheck(cudaFree(d_att));
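For reference, the buffer-passing pattern the commit message describes looks roughly like the sketch below. It is a minimal illustration, not code from the repository: attention_forward3_like, the 100-iteration loop, and the sketch's own cudaCheck macro are placeholders, while the buffer names and sizes mirror the diff above (since C = NH * HS in this file, the caller-side B * T * C allocation holds the same number of floats as the B * NH * T * HS buffer that attention_forward3 previously cudaMalloc'd internally).

// Minimal sketch (assumed names) of "allocate in the caller, pass raw device pointers in".
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

#define cudaCheck(err) do { cudaError_t e_ = (err); if (e_ != cudaSuccess) { \
    fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(e_)); exit(EXIT_FAILURE); } } while (0)

// stand-in for attention_forward3: it only uses vaccum, it never allocates or frees it
void attention_forward3_like(float* out, float* vaccum, const float* inp,
                             int B, int T, int C, int NH) {
    // ... kernel launches / cuBLAS calls that write into vaccum and out go here ...
    (void)out; (void)vaccum; (void)inp; (void)B; (void)T; (void)C; (void)NH;
}

int main(void) {
    int B = 8, T = 1024, C = 768, NH = 12; // C = NH * HS, so B*NH*T*HS == B*T*C
    float *d_out, *d_vaccum, *d_inp;
    // allocate once, up front, in the caller
    cudaCheck(cudaMalloc(&d_out,    B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_vaccum, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_inp,    B * T * 3 * C * sizeof(float)));
    // the forward pass can now sit in a timing loop with no per-call cudaMalloc/cudaFree
    for (int i = 0; i < 100; i++) {
        attention_forward3_like(d_out, d_vaccum, d_inp, B, T, C, NH);
    }
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_vaccum));
    cudaCheck(cudaFree(d_inp));
    return 0;
}

Keeping allocations out of the forward function means the repeated benchmarking calls in main measure only kernel and cuBLAS time, not allocator traffic, which is the point of the change.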
