1+ /*
2+ Kernels for crossentropy forward pass.
3+
4+ Compile example:
5+ nvcc -O3 --use_fast_math crossentropy_forward.cu -o crossentropy_forward
6+
7+ version 1 is a straight-forward port from CPU code to kernel, parallel over B,T
8+ ./crossentropy_forward 1
9+ */
10+
11+ #include < stdio.h>
12+ #include < stdlib.h>
13+ #include < cuda_runtime.h>
14+
15+ // ----------------------------------------------------------------------------
16+ // CUDA utils
17+
18+ #define CEIL_DIV (M, N ) (((M) + (N)-1 ) / (N))
19+
20+ // error checking
21+ void cudaCheck (cudaError_t error, const char *file, int line) {
22+ if (error != cudaSuccess) {
23+ printf (" [CUDA ERROR] at file %s:%d:\n %s\n " , file, line,
24+ cudaGetErrorString (error));
25+ exit (EXIT_FAILURE);
26+ }
27+ };
28+ #define cudaCheck (err ) (cudaCheck(err, __FILE__, __LINE__))
29+
30+ // ----------------------------------------------------------------------------
31+ // CPU code reference
32+
33+ void crossentropy_forward_cpu (float * losses,
34+ float * probs, int * targets,
35+ int B, int T, int V) {
36+ // output: losses is (B,T) of the individual losses at each position
37+ // input: probs are (B,T,V) of the probabilities
38+ // input: targets is (B,T) of integers giving the correct index in logits
39+ for (int b = 0 ; b < B; b++) {
40+ for (int t = 0 ; t < T; t++) {
41+ // loss = -log(probs[target])
42+ float * probs_bt = probs + b * T * V + t * V;
43+ int ix = targets[b * T + t];
44+ losses[b * T + t] = -logf (probs_bt[ix]);
45+ }
46+ }
47+ }
48+
49+ // ----------------------------------------------------------------------------
50+ // GPU kernels
51+
52+ __global__ void crossentropy_forward_kernel1 (float * losses,
53+ float * probs, int * targets,
54+ int B, int T, int V) {
55+ int i = blockIdx .x * blockDim .x + threadIdx .x ;
56+ if (i < B * T) {
57+ int b = i / T;
58+ int t = i % T;
59+ float * probs_bt = probs + b * T * V + t * V;
60+ int ix = targets[b * T + t];
61+ losses[b * T + t] = -logf (probs_bt[ix]);
62+ }
63+ }
64+
65+ // ----------------------------------------------------------------------------
66+ // kernel launcher
67+
68+ void crossentropy_forward1 (float * losses,
69+ float * probs, int * targets,
70+ int B, int T, int V,
71+ const int block_size) {
72+ const int N = B * T;
73+ const int grid_size = CEIL_DIV (N, block_size);
74+ crossentropy_forward_kernel1<<<grid_size, block_size>>> (losses, probs, targets, B, T, V);
75+ cudaCheck (cudaGetLastError ());
76+ }
77+
78+ // kernel version dispatch
79+ void crossentropy_forward (int kernel_num,
80+ float * losses,
81+ float * probs, int * targets,
82+ int B, int T, int V,
83+ const int block_size) {
84+ switch (kernel_num) {
85+ case 1 :
86+ crossentropy_forward1 (losses, probs, targets, B, T, V, block_size);
87+ break ;
88+ default :
89+ printf (" Invalid kernel number\n " );
90+ exit (1 );
91+ }
92+ }
93+
94+ // ----------------------------------------------------------------------------
95+ // random utils
96+
97+ float * make_random_float (int N) {
98+ float * arr = (float *)malloc (N * sizeof (float ));
99+ for (int i = 0 ; i < N; i++) {
100+ arr[i] = ((float )rand () / RAND_MAX); // [0,1)
101+ }
102+ return arr;
103+ }
104+
105+ int * make_random_int (int N, int V) {
106+ int * arr = (int *)malloc (N * sizeof (int ));
107+ for (int i = 0 ; i < N; i++) {
108+ arr[i] = rand () % V;
109+ }
110+ return arr;
111+ }
112+
113+ // ----------------------------------------------------------------------------
114+
115+ int main (int argc, char **argv) {
116+ srand (0 );
117+
118+ int B = 8 ;
119+ int T = 1024 ;
120+ int V = 50257 ;
121+
122+ int deviceIdx = 0 ;
123+ cudaCheck (cudaSetDevice (deviceIdx));
124+
125+ // create host memory of random numbers
126+ float * out = (float *)malloc (B * T * sizeof (float ));
127+ float * probs = make_random_float (B * T * V);
128+ int * targets = make_random_int (B * T, V);
129+
130+ // move to GPU
131+ float * d_out;
132+ float * d_probs;
133+ int * d_targets;
134+ cudaCheck (cudaMalloc (&d_out, B * T * sizeof (float )));
135+ cudaCheck (cudaMalloc (&d_probs, B * T * V * sizeof (float )));
136+ cudaCheck (cudaMalloc (&d_targets, B * T * sizeof (int )));
137+ cudaCheck (cudaMemcpy (d_probs, probs, B * T * V * sizeof (float ), cudaMemcpyHostToDevice));
138+ cudaCheck (cudaMemcpy (d_targets, targets, B * T * sizeof (int ), cudaMemcpyHostToDevice));
139+
140+ // read kernel_num from command line
141+ int kernel_num = 1 ;
142+ if (argc > 1 ) {
143+ kernel_num = atoi (argv[1 ]);
144+ }
145+ printf (" Using kernel %d\n " , kernel_num);
146+
147+ // first check the correctness of the kernel
148+ crossentropy_forward_cpu (out, probs, targets, B, T, V);
149+ crossentropy_forward (kernel_num, d_out, d_probs, d_targets, B, T, V, 256 );
150+ float * out_gpu = (float *)malloc (B * T * sizeof (float ));
151+ cudaCheck (cudaMemcpy (out_gpu, d_out, B * T * sizeof (float ), cudaMemcpyDeviceToHost));
152+ for (int i = 0 ; i < B * T; i++) {
153+ // print the first few comparisons
154+ if (i < 10 ) {
155+ printf (" %f %f\n " , out[i], out_gpu[i]);
156+ }
157+ // ensure correctness for all elements
158+ if (fabs (out[i] - out_gpu[i]) > 1e-5 ) {
159+ printf (" Mismatch at %d: %f vs %f\n " , i, out[i], out_gpu[i]);
160+ exit (1 );
161+ }
162+ }
163+ printf (" Results match at block_size=256!\n " );
164+
165+ // time the kernel at different block sizes
166+ int block_sizes[] = {32 , 64 , 128 , 256 , 512 , 1024 };
167+
168+ for (int j = 0 ; j < sizeof (block_sizes) / sizeof (int ); j++) {
169+ int block_size = block_sizes[j];
170+
171+ int repeat_times = 1000 ;
172+ cudaEvent_t start, stop;
173+ cudaCheck (cudaEventCreate (&start));
174+ cudaCheck (cudaEventCreate (&stop));
175+ cudaCheck (cudaEventRecord (start, 0 ));
176+ for (int i = 0 ; i < repeat_times; i++) {
177+ crossentropy_forward (kernel_num, d_out, d_probs, d_targets, B, T, V, block_size);
178+ }
179+ cudaCheck (cudaEventRecord (stop, 0 ));
180+ cudaCheck (cudaEventSynchronize (start));
181+ cudaCheck (cudaEventSynchronize (stop));
182+ float elapsed_time;
183+ cudaCheck (cudaEventElapsedTime (&elapsed_time, start, stop));
184+
185+ printf (" block_size %4d | time %f ms\n " , block_size, elapsed_time / repeat_times);
186+ }
187+
188+ // free memory
189+ free (out);
190+ free (probs);
191+ free (targets);
192+ free (out_gpu);
193+ cudaCheck (cudaFree (d_out));
194+ cudaCheck (cudaFree (d_probs));
195+ cudaCheck (cudaFree (d_targets));
196+
197+ return 0 ;
198+ }
0 commit comments