From 06980b6b2f3f541e5bd2c1f3a60744e0a3a77616 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Wed, 1 May 2024 22:34:49 +0300 Subject: [PATCH 1/3] mixed precision utilities for dev/cuda --- dev/cuda/Makefile | 12 ++++--- dev/cuda/adamw.cu | 10 +----- dev/cuda/attention_backward.cu | 24 ++++---------- dev/cuda/attention_forward.cu | 3 +- dev/cuda/common.h | 60 ++++++++++++++++++++++++++++++++++ dev/cuda/gelu_forward.cu | 10 +----- dev/cuda/layernorm_backward.cu | 12 +------ dev/cuda/matmul_backward.cu | 5 --- dev/cuda/matmul_forward.cu | 10 ------ dev/cuda/trimat_forward.cu | 7 +--- 10 files changed, 79 insertions(+), 74 deletions(-) diff --git a/dev/cuda/Makefile b/dev/cuda/Makefile index b789dd316..af4d7d554 100644 --- a/dev/cuda/Makefile +++ b/dev/cuda/Makefile @@ -10,14 +10,15 @@ endif # Compiler flags CFLAGS = -O3 --use_fast_math +NVCCFLAGS = -lcublas -lcublasLt MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ # Default rule for our CUDA files %: %.cu - $(NVCC) $(CFLAGS) $< -o $@ -lcublas + $(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@ # Build all targets -TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward +TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward global_norm all: $(TARGETS) # Individual targets: forward pass @@ -32,7 +33,7 @@ softmax_forward: softmax_forward.cu trimat_forward: trimat_forward.cu # matmul fwd/bwd also uses OpenMP (optionally) and cuBLASLt libs matmul_forward: matmul_forward.cu - $(NVCC) $(CFLAGS) -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward -lcublas -lcublasLt + $(NVCC) $(CFLAGS) $(NVCCFLAGS) -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward # Individual targets: backward pass attention_backward: attention_backward.cu @@ -41,14 +42,15 @@ encoder_backward: encoder_backward.cu layernorm_backward: layernorm_backward.cu matmul_backward_bias: matmul_backward_bias.cu matmul_backward: matmul_backward.cu - $(NVCC) $(CFLAGS) -Xcompiler -fopenmp matmul_backward.cu -o matmul_backward -lcublas + $(NVCC) $(CFLAGS) $(NVCCFLAGS) -Xcompiler -fopenmp matmul_backward.cu -o matmul_backward # Update kernels adamw: adamw.cu +global_norm: global_norm.cu # NCCL communication kernels nccl_all_reduce: nccl_all_reduce.cu - $(NVCC) -lmpi -lnccl $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce + $(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce # Run all targets run_all: all diff --git a/dev/cuda/adamw.cu b/dev/cuda/adamw.cu index 15e9048cd..23770b2c3 100644 --- a/dev/cuda/adamw.cu +++ b/dev/cuda/adamw.cu @@ -145,7 +145,7 @@ void adamw(int kernel_num, // ---------------------------------------------------------------------------- int main(int argc, char **argv) { - srand(0); + setup_main(); const long num_parameters = 1048576; const int t = 10; @@ -156,14 +156,6 @@ int main(int argc, char **argv) { const float eps = 1e-8f; const float weight_decay = 0.0f; - - // set up the device - int deviceIdx = 0; - cudaCheck(cudaSetDevice(deviceIdx)); - cudaDeviceProp 
deviceProp; - cudaGetDeviceProperties(&deviceProp, deviceIdx); - printf("Device %d: %s\n", deviceIdx, deviceProp.name); - // create random data on host (to be used for the CPU reference implementation) float* params_memory = make_random_float(num_parameters); float* grads_memory = make_random_float(num_parameters); diff --git a/dev/cuda/attention_backward.cu b/dev/cuda/attention_backward.cu index 88ce4ae71..8e673d79f 100644 --- a/dev/cuda/attention_backward.cu +++ b/dev/cuda/attention_backward.cu @@ -31,11 +31,6 @@ OMP_NUM_THREADS=32 ./attention_backward 5 #include #include "common.h" -// ---------------------------------------------------------------------------- -// CUDA setup - -static cublasHandle_t cublas_handle; - // ---------------------------------------------------------------------------- // CPU code reference @@ -984,7 +979,7 @@ void attention_backward(int kernel_num, // ---------------------------------------------------------------------------- int main(int argc, char **argv) { - srand(0); // reproducibility + setup_main(); // hyperparameters int B = 4; @@ -992,11 +987,6 @@ int main(int argc, char **argv) { int C = 768; int NH = 12; - // set up CUDA / cuBLAS - int deviceIdx = 0; - cudaCheck(cudaSetDevice(deviceIdx)); - cublasCreate(&cublas_handle); - // read kernel_num from command line int kernel_num = 1; if (argc > 1) { @@ -1032,9 +1022,9 @@ int main(int argc, char **argv) { // check that preatt, att, and out match between the CPU and GPU versions printf("Checking the forward pass CPU <-> GPU...\n"); - printf("[preatt]\n"); validate_result(d_preatt, preatt, "preatt", B * T * C, 1e-4f); - printf("[att]\n"); validate_result(d_att, att, "att", B * T * C, 1e-4f); - printf("[out]\n"); validate_result(d_out, out, "out", B * T * C, 1e-4f); + printf("[preatt]\n"); validate_result(d_preatt, preatt, "preatt", B * T * C, 5e-3f); + printf("[att]\n"); validate_result(d_att, att, "att", B * T * C, 1e-3f); + printf("[out]\n"); validate_result(d_out, out, "out", B * T * C, 1e-3f); // set up the memory for the backward pass float* dout = make_random_float(B * T * C); // the gradients on the output @@ -1072,9 +1062,9 @@ int main(int argc, char **argv) { // the gradients at qkvr and vaccum will remain unchecked, but are // assumed to be correct if the other gradients are correct printf("Checking the backward pass CPU <-> GPU...\n"); - printf("[datt]\n"); validate_result(d_datt, datt, "datt", B * NH * T * T, 1e-4f); - printf("[dpreatt]\n"); validate_result(d_dpreatt, dpreatt, "dpreatt", B * NH * T * T, 1e-4f); - printf("[dinp]\n"); validate_result(d_dinp, dinp, "dinp", B * T * 3 * C, 1e-4f); + printf("[datt]\n"); validate_result(d_datt, datt, "datt", B * NH * T * T, 5e-3f); + printf("[dpreatt]\n"); validate_result(d_dpreatt, dpreatt, "dpreatt", B * NH * T * T, 1e-3f); + printf("[dinp]\n"); validate_result(d_dinp, dinp, "dinp", B * T * 3 * C, 1e-3f); // also let's manually step through the gradients here float* h_dinp = (float*)malloc(B * T * 3 * C * sizeof(float)); diff --git a/dev/cuda/attention_forward.cu b/dev/cuda/attention_forward.cu index b75230f1e..a9325f085 100644 --- a/dev/cuda/attention_forward.cu +++ b/dev/cuda/attention_forward.cu @@ -64,7 +64,6 @@ typedef __nv_bfloat16 floatX; // half or __nv_bfloat16 (or float) // ---------------------------------------------------------------------------- // CUDA & cuDNN setup -static cublasHandle_t cublas_handle; static bool first_run_validation = true; // always run e.g. 
permute on 1st run #ifdef ENABLE_CUDNN @@ -1285,7 +1284,7 @@ void attention_forward(int kernel_num, // ---------------------------------------------------------------------------- int main(int argc, char **argv) { - srand(0); + setup_main(); int B = 8; int T = 1024; diff --git a/dev/cuda/common.h b/dev/cuda/common.h index ae352a5bc..77e012fcd 100644 --- a/dev/cuda/common.h +++ b/dev/cuda/common.h @@ -2,6 +2,7 @@ #include #include #include +#include template @@ -32,6 +33,21 @@ void cublasCheck(cublasStatus_t status, const char *file, int line) } #define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); } +// ---------------------------------------------------------------------------- +// cuBLAS setup +// these will be initialized by setup_main + +// cuBLAS workspace. Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK +static size_t cublaslt_workspace_size = 32 * 1024 * 1024; +static void* cublaslt_workspace = NULL; +static cublasComputeType_t cublas_compute_type; +cublasHandle_t cublas_handle; +cublasLtHandle_t cublaslt_handle; +int cuda_arch_major = 0; +int cuda_arch_minor = 0; +int cuda_num_SMs = 0; // for persistent threads where we want 1 threadblock per SM +int cuda_threads_per_SM = 0; // needed to calculate how many blocks to launch to fill up the GPU + // ---------------------------------------------------------------------------- // Packed128 data structure, which forces the compiler to use 128-bit loads/stores // in GPUs that support (the LDG.128 and STS.128 instructions) @@ -135,6 +151,50 @@ float* make_ones_float(size_t N) { // ---------------------------------------------------------------------------- // testing and benchmarking utils +template +[[nodiscard]] cudaError_t memcpy_convert(TargetType* d_ptr, float* h_ptr, size_t count) { + // copy from host to device with data type conversion. + TargetType* converted = (TargetType*)malloc(count * sizeof(TargetType)); + for (int i = 0; i < count; i++) { + converted[i] = (TargetType)h_ptr[i]; + } + + cudaError_t status = cudaMemcpy(d_ptr, converted, count * sizeof(TargetType), cudaMemcpyHostToDevice); + free(converted); + + // instead of checking the status at cudaMemcpy, we return it from here. This way, we + // still need to use our checking macro, and get better line info as to where the error + // happened. + return status; +} + +void setup_main() { + srand(0); // determinism + + // set up the device + int deviceIdx = 0; + cudaCheck(cudaSetDevice(deviceIdx)); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, deviceIdx); + cuda_num_SMs = deviceProp.multiProcessorCount; + cuda_threads_per_SM = deviceProp.maxThreadsPerMultiProcessor; + cuda_arch_major = deviceProp.major; + cuda_arch_minor = deviceProp.minor; + + // setup cuBLAS and cuBLASLt + cublasCheck(cublasCreate(&cublas_handle)); + cublasCheck(cublasLtCreate(&cublaslt_handle)); + cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size)); + + // TF32 precision is equivalent to torch.set_float32_matmul_precision('high') + int enable_tf32 = cuda_arch_major >= 8 ? 1 : 0; + // TODO implement common CLI for all tests/benchmarks + // if (override_enable_tf32 == 0) { enable_tf32 = 0; } // force to zero via arg + cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; + cublasMath_t cublas_math_mode = enable_tf32 ? 
CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode)); +} + template void validate_result(D* device_result, const T* cpu_reference, const char* name, std::size_t num_elements, T tolerance=1e-4) { D* out_gpu = (D*)malloc(num_elements * sizeof(D)); diff --git a/dev/cuda/gelu_forward.cu b/dev/cuda/gelu_forward.cu index 6a582d509..bb40a329e 100644 --- a/dev/cuda/gelu_forward.cu +++ b/dev/cuda/gelu_forward.cu @@ -145,14 +145,7 @@ int main(int argc, const char **argv) { floatX* d_inp; cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX))); cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX))); - - floatX* inpX = (floatX*)malloc(B * T * C * sizeof(floatX)); - - for (int i = 0; i < B * T * C; i++) { - inpX[i] = (floatX)inp[i]; - } - - cudaCheck(cudaMemcpy(d_inp, inpX, B * T * C * sizeof(floatX), cudaMemcpyHostToDevice)); + cudaCheck(memcpy_convert(d_inp, inp, B * T * C)); // time the kernel at different block sizes int block_sizes[] = {32, 64, 128, 256, 512, 1024}; @@ -191,7 +184,6 @@ int main(int argc, const char **argv) { // free memory free(out); free(inp); - free(inpX); cudaCheck(cudaFree(d_out)); cudaCheck(cudaFree(d_inp)); diff --git a/dev/cuda/layernorm_backward.cu b/dev/cuda/layernorm_backward.cu index 4e95cabec..222de6f72 100644 --- a/dev/cuda/layernorm_backward.cu +++ b/dev/cuda/layernorm_backward.cu @@ -19,10 +19,6 @@ version 2 moves a lot of reduction to shared memory over global memory #include #include "common.h" -// ---------------------------------------------------------------------------- -// CUDA settings -int cuda_num_SMs = 0; // for persistent threads where we want 1 threadblock per SM - // turn on bf16 as default, done up here for now #define ENABLE_BF16 @@ -757,18 +753,12 @@ void layernorm_backward(int kernel_num, // ---------------------------------------------------------------------------- int main(int argc, char **argv) { - srand(0); + setup_main(); int B = 8; int T = 1024; int C = 768; - int deviceIdx = 0; - cudaCheck(cudaSetDevice(deviceIdx)); - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, deviceIdx); - cuda_num_SMs = deviceProp.multiProcessorCount; - // first do the forward pass in CPU float* out = (float*)malloc(B * T * C * sizeof(float)); float* mean = (float*)malloc(B * T * sizeof(float)); diff --git a/dev/cuda/matmul_backward.cu b/dev/cuda/matmul_backward.cu index 806bb9818..9d3763930 100644 --- a/dev/cuda/matmul_backward.cu +++ b/dev/cuda/matmul_backward.cu @@ -14,11 +14,6 @@ OMP_NUM_THREADS=32 ./matmul_backward 1 #include #include "common.h" -// ---------------------------------------------------------------------------- -// CUDA / cuBLAS setup - -static cublasHandle_t cublas_handle; - // ---------------------------------------------------------------------------- // CPU code reference diff --git a/dev/cuda/matmul_forward.cu b/dev/cuda/matmul_forward.cu index ec13805a3..fe22729e7 100644 --- a/dev/cuda/matmul_forward.cu +++ b/dev/cuda/matmul_forward.cu @@ -23,16 +23,6 @@ OMP_NUM_THREADS=32 ./matmul_forward 3 #include #include "common.h" -// ---------------------------------------------------------------------------- -// CUDA setup - -// cuBLAS workspace. 
Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK -static cublasHandle_t cublas_handle; -static cublasLtHandle_t cublaslt_handle; -static size_t cublaslt_workspace_size = 32 * 1024 * 1024; -static void* cublaslt_workspace = NULL; -static cublasComputeType_t cublas_compute_type; - // ---------------------------------------------------------------------------- // CPU code reference diff --git a/dev/cuda/trimat_forward.cu b/dev/cuda/trimat_forward.cu index af8981a95..133ced16f 100644 --- a/dev/cuda/trimat_forward.cu +++ b/dev/cuda/trimat_forward.cu @@ -33,7 +33,6 @@ tri4 #include #include "common.h" -static cublasHandle_t cublas_handle; static float* d_qkvr; // scratch for the cublas kernel /* ** Chapter I - Introduction ** @@ -532,17 +531,13 @@ void trimul_gpu(int kernel_num, int main(int argc, char **argv) { - srand(0); + setup_main(); int B = 8; int T = 1024; int C = 768; int NH = 12; - int deviceIdx = 0; - cudaCheck(cudaSetDevice(deviceIdx)); - cublasCreate(&cublas_handle); - // create host memory of random numbers float* out = (float*)malloc(B * NH * T * T * sizeof(float)); float* inp = make_random_float(B * T * 3 * C); From 91bc72d0c0478be3b97fcc3e506dd3dd4137a3f2 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Wed, 1 May 2024 22:40:26 +0300 Subject: [PATCH 2/3] fixup --- dev/cuda/gelu_forward.cu | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dev/cuda/gelu_forward.cu b/dev/cuda/gelu_forward.cu index bb40a329e..27aa9d598 100644 --- a/dev/cuda/gelu_forward.cu +++ b/dev/cuda/gelu_forward.cu @@ -117,15 +117,12 @@ void gelu_forward(int kernel_num, // ---------------------------------------------------------------------------- int main(int argc, const char **argv) { - srand(0); + setup_main(); int B = 8; int T = 1024; int C = 768; - int deviceIdx = 0; - cudaCheck(cudaSetDevice(deviceIdx)); - // create host memory of random numbers float* out = (float*)malloc(B * T * C * sizeof(float)); float* inp = make_random_float(B * T * C); From b42db706b22603e744c934589ed2b5b62937627e Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Wed, 1 May 2024 22:47:24 +0300 Subject: [PATCH 3/3] this kernel will be released later, sorry for the spoiler :) --- dev/cuda/Makefile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dev/cuda/Makefile b/dev/cuda/Makefile index af4d7d554..af5fded38 100644 --- a/dev/cuda/Makefile +++ b/dev/cuda/Makefile @@ -18,7 +18,7 @@ MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux- $(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@ # Build all targets -TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward global_norm +TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward all: $(TARGETS) # Individual targets: forward pass @@ -46,7 +46,6 @@ matmul_backward: matmul_backward.cu # Update kernels adamw: adamw.cu -global_norm: global_norm.cu # NCCL communication kernels nccl_all_reduce: nccl_all_reduce.cu
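
For reference, below is a minimal sketch of how a dev/cuda kernel benchmark is expected to look after this series, pieced together from the common.h hunk above (setup_main, memcpy_convert, make_random_float, validate_result, cudaCheck). It is not part of the patches: the file layout, tensor sizes, and the bf16 floatX choice are illustrative assumptions, mirroring gelu_forward.cu.

// hypothetical usage sketch, not part of the patch series
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include "common.h"

typedef __nv_bfloat16 floatX;  // illustrative; gelu_forward.cu uses bf16 this way

int main(int argc, char **argv) {
    // replaces the per-file srand(0) / cudaSetDevice / cublasCreate boilerplate:
    // seeds rand, selects device 0, fills cuda_num_SMs and friends, creates the
    // cuBLAS + cuBLASLt handles and workspace, and enables TF32 on sm_80+.
    setup_main();

    int B = 8, T = 1024, C = 768;          // illustrative sizes
    size_t N = (size_t)B * T * C;

    // fp32 reference data on the host, as before
    float* inp = make_random_float(N);

    // device buffer in the (possibly reduced-precision) kernel type
    floatX* d_inp;
    cudaCheck(cudaMalloc(&d_inp, N * sizeof(floatX)));

    // convert fp32 -> floatX while copying to the device; memcpy_convert returns
    // the cudaError_t so the existing cudaCheck macro still reports the caller's
    // file/line on failure, per the comment in common.h.
    cudaCheck(memcpy_convert(d_inp, inp, N));

    // ... launch the GPU kernel into some d_out, run the CPU reference into out,
    // then compare with the shared helper, e.g.:
    // validate_result(d_out, out, "out", N, 1e-3f);

    free(inp);
    cudaCheck(cudaFree(d_inp));
    return 0;
}

The [[nodiscard]] cudaError_t return on memcpy_convert (rather than checking inside the helper) is what lets callers keep the cudaCheck(...) wrapper and get accurate line information, which is why the gelu_forward.cu hunk collapses its manual malloc/convert/cudaMemcpy block into a single cudaCheck(memcpy_convert(...)) call.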