Skip to content

Commit

Permalink
ok crossentropy is fairly ez, just have to be careful with logs of ne…
Browse files Browse the repository at this point in the history
…gative numbers in debug tests. so this should be all the layers now
  • Loading branch information
karpathy committed Apr 10, 2024
1 parent 46ee0a3 commit d8e2a36
Showing 1 changed file with 198 additions and 0 deletions.
198 changes: 198 additions & 0 deletions dev/cuda/crossentropy_forward.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
Kernels for crossentropy forward pass.
Compile example:
nvcc -O3 --use_fast_math crossentropy_forward.cu -o crossentropy_forward
version 1 is a straight-forward port from CPU code to kernel, parallel over B,T
./crossentropy_forward 1
*/

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

// ----------------------------------------------------------------------------
// CUDA utils

// integer ceiling division: smallest k such that k * N >= M
// (used to compute the number of blocks needed to cover M elements)
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

// error checking
// Aborts the process with a file:line diagnostic if a CUDA runtime call
// returned anything other than cudaSuccess. Call through the macro below
// so the call site's file/line are captured automatically.
void cudaCheck(cudaError_t error, const char *file, int line) {
    if (error != cudaSuccess) {
        printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line,
               cudaGetErrorString(error));
        exit(EXIT_FAILURE);
    }
}
// note: the macro intentionally shadows the function of the same name;
// it is defined after the function so the function body itself is unaffected.
// (fix: removed a stray ';' after the function's closing brace — an empty
// top-level declaration that warns under -pedantic and is invalid in C89)
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))

// ----------------------------------------------------------------------------
// CPU code reference

void crossentropy_forward_cpu(float* losses,
                              float* probs, int* targets,
                              int B, int T, int V) {
    // CPU reference implementation.
    // output: losses is (B,T), the individual cross-entropy loss at each position
    // input:  probs   is (B,T,V), probabilities over the vocabulary
    // input:  targets is (B,T), the correct class index at each position
    // A (b,t) pair flattens to i = b*T + t, and that position's probability
    // row starts at probs + i*V, so one flat loop covers both dimensions.
    int N = B * T;
    for (int i = 0; i < N; i++) {
        float* probs_row = probs + i * V;
        int target_ix = targets[i];
        losses[i] = -logf(probs_row[target_ix]);
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// Naive kernel: one thread per (b,t) position, each reads a single probability
// and writes a single loss. Expects a 1D launch covering at least B*T threads.
__global__ void crossentropy_forward_kernel1(float* losses,
                                             float* probs, int* targets,
                                             int B, int T, int V) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // guard the grid tail: the last block may run past B*T
    if (idx >= B * T) { return; }
    // idx is already the flattened (b,t) position, i.e. b*T + t,
    // so this position's probability row begins at probs + idx*V
    // (same address as b*T*V + t*V in the CPU reference)
    float* probs_row = probs + idx * V;
    losses[idx] = -logf(probs_row[targets[idx]]);
}

// ----------------------------------------------------------------------------
// kernel launcher

// Launches kernel 1 with one thread per (b,t) position.
// block_size is the number of threads per block; the grid is sized to cover B*T.
void crossentropy_forward1(float* losses,
                           float* probs, int* targets,
                           int B, int T, int V,
                           const int block_size) {
    const int total = B * T;
    const int num_blocks = CEIL_DIV(total, block_size);
    crossentropy_forward_kernel1<<<num_blocks, block_size>>>(losses, probs, targets, B, T, V);
    // kernel launches don't return errors directly; catch launch-config errors here
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void crossentropy_forward(int kernel_num,
float* losses,
float* probs, int* targets,
int B, int T, int V,
const int block_size) {
switch (kernel_num) {
case 1:
crossentropy_forward1(losses, probs, targets, B, T, V, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
}
}

// ----------------------------------------------------------------------------
// random utils

// Allocates and fills a host buffer with N uniform random floats in [0, 1].
// (rand() can return RAND_MAX itself, so 1.0 is attainable — the range is
// closed, not half-open.) Caller owns the buffer and must free() it.
float* make_random_float(int N) {
    float* buf = (float*)malloc(N * sizeof(float));
    for (int idx = 0; idx < N; idx++) {
        buf[idx] = (float)rand() / (float)RAND_MAX;
    }
    return buf;
}

// Allocates and fills a host buffer with N random ints in [0, V).
// Used here to generate target class indices. Caller must free() the buffer.
int* make_random_int(int N, int V) {
    int* buf = (int*)malloc(N * sizeof(int));
    for (int idx = 0; idx < N; idx++) {
        buf[idx] = rand() % V;
    }
    return buf;
}

// ----------------------------------------------------------------------------

// Driver: verifies the GPU kernel against the CPU reference, then times it
// at several block sizes. Optional argv[1] selects the kernel version.
int main(int argc, char **argv) {
    srand(0);

    // GPT-2 style problem sizes: batch 8, sequence 1024, vocab 50257
    int B = 8;
    int T = 1024;
    int V = 50257;

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* out = (float*)malloc(B * T * sizeof(float));
    float* probs = make_random_float(B * T * V);
    int* targets = make_random_int(B * T, V);

    // move to GPU
    float* d_out;
    float* d_probs;
    int* d_targets;
    cudaCheck(cudaMalloc(&d_out, B * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_probs, B * T * V * sizeof(float)));
    cudaCheck(cudaMalloc(&d_targets, B * T * sizeof(int)));
    cudaCheck(cudaMemcpy(d_probs, probs, B * T * V * sizeof(float), cudaMemcpyHostToDevice));
    cudaCheck(cudaMemcpy(d_targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel against the CPU reference
    crossentropy_forward_cpu(out, probs, targets, B, T, V);
    crossentropy_forward(kernel_num, d_out, d_probs, d_targets, B, T, V, 256);
    float* out_gpu = (float*)malloc(B * T * sizeof(float));
    // this blocking memcpy also synchronizes with the kernel launch above
    cudaCheck(cudaMemcpy(out_gpu, d_out, B * T * sizeof(float), cudaMemcpyDeviceToHost));
    for (int i = 0; i < B * T; i++) {
        // print the first few comparisons
        if (i < 10) {
            printf("%f %f\n", out[i], out_gpu[i]);
        }
        // ensure correctness for all elements (absolute tolerance;
        // CPU and GPU float paths are not expected to be bit-exact)
        if (fabs(out[i] - out_gpu[i]) > 1e-5) {
            printf("Mismatch at %d: %f vs %f\n", i, out[i], out_gpu[i]);
            exit(1);
        }
    }
    printf("Results match at block_size=256!\n");

    // time the kernel at different block sizes
    int block_sizes[] = {32, 64, 128, 256, 512, 1024};

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 1000;
        cudaEvent_t start, stop;
        cudaCheck(cudaEventCreate(&start));
        cudaCheck(cudaEventCreate(&stop));
        cudaCheck(cudaEventRecord(start, 0));
        for (int i = 0; i < repeat_times; i++) {
            crossentropy_forward(kernel_num, d_out, d_probs, d_targets, B, T, V, block_size);
        }
        cudaCheck(cudaEventRecord(stop, 0));
        cudaCheck(cudaEventSynchronize(start));
        cudaCheck(cudaEventSynchronize(stop));
        float elapsed_time;
        cudaCheck(cudaEventElapsedTime(&elapsed_time, start, stop));
        // fix: destroy the events — they were created on every iteration of the
        // block-size loop but never released, leaking event resources
        cudaCheck(cudaEventDestroy(start));
        cudaCheck(cudaEventDestroy(stop));

        printf("block_size %4d | time %f ms\n", block_size, elapsed_time / repeat_times);
    }

    // free memory
    free(out);
    free(probs);
    free(targets);
    free(out_gpu);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_probs));
    cudaCheck(cudaFree(d_targets));

    return 0;
}

0 comments on commit d8e2a36

Please sign in to comment.