Merge pull request karpathy#325 from ngc92/dev-cuda-utils
mixed precision utilities for dev/cuda
karpathy authored May 1, 2024
2 parents d37639a + b42db70 commit ab95a11
Showing 10 changed files with 78 additions and 77 deletions.
9 changes: 5 additions & 4 deletions dev/cuda/Makefile
@@ -10,11 +10,12 @@ endif

# Compiler flags
CFLAGS = -O3 --use_fast_math
NVCCFLAGS = -lcublas -lcublasLt
MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/

# Default rule for our CUDA files
%: %.cu
$(NVCC) $(CFLAGS) $< -o $@ -lcublas
$(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@

# Build all targets
TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward
@@ -32,7 +33,7 @@ softmax_forward: softmax_forward.cu
trimat_forward: trimat_forward.cu
# matmul fwd/bwd also uses OpenMP (optionally) and cuBLASLt libs
matmul_forward: matmul_forward.cu
$(NVCC) $(CFLAGS) -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward -lcublas -lcublasLt
$(NVCC) $(CFLAGS) $(NVCCFLAGS) -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward

# Individual targets: backward pass
attention_backward: attention_backward.cu
@@ -41,14 +42,14 @@ encoder_backward: encoder_backward.cu
layernorm_backward: layernorm_backward.cu
matmul_backward_bias: matmul_backward_bias.cu
matmul_backward: matmul_backward.cu
$(NVCC) $(CFLAGS) -Xcompiler -fopenmp matmul_backward.cu -o matmul_backward -lcublas
$(NVCC) $(CFLAGS) $(NVCCFLAGS) -Xcompiler -fopenmp matmul_backward.cu -o matmul_backward

# Update kernels
adamw: adamw.cu

# NCCL communication kernels
nccl_all_reduce: nccl_all_reduce.cu
$(NVCC) -lmpi -lnccl $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce
$(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce

# Run all targets
run_all: all
10 changes: 1 addition & 9 deletions dev/cuda/adamw.cu
@@ -145,7 +145,7 @@ void adamw(int kernel_num,
// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
srand(0);
setup_main();

const long num_parameters = 1048576;
const int t = 10;
@@ -156,14 +156,6 @@ int main(int argc, char **argv) {
const float eps = 1e-8f;
const float weight_decay = 0.0f;


// set up the device
int deviceIdx = 0;
cudaCheck(cudaSetDevice(deviceIdx));
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, deviceIdx);
printf("Device %d: %s\n", deviceIdx, deviceProp.name);

// create random data on host (to be used for the CPU reference implementation)
float* params_memory = make_random_float(num_parameters);
float* grads_memory = make_random_float(num_parameters);
24 changes: 7 additions & 17 deletions dev/cuda/attention_backward.cu
@@ -31,11 +31,6 @@ OMP_NUM_THREADS=32 ./attention_backward 5
#include <cooperative_groups/scan.h>
#include "common.h"

// ----------------------------------------------------------------------------
// CUDA setup

static cublasHandle_t cublas_handle;

// ----------------------------------------------------------------------------
// CPU code reference

@@ -984,19 +979,14 @@ void attention_backward(int kernel_num,
// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
srand(0); // reproducibility
setup_main();

// hyperparameters
int B = 4;
int T = 1024;
int C = 768;
int NH = 12;

// set up CUDA / cuBLAS
int deviceIdx = 0;
cudaCheck(cudaSetDevice(deviceIdx));
cublasCreate(&cublas_handle);

// read kernel_num from command line
int kernel_num = 1;
if (argc > 1) {
@@ -1032,9 +1022,9 @@ int main(int argc, char **argv) {

// check that preatt, att, and out match between the CPU and GPU versions
printf("Checking the forward pass CPU <-> GPU...\n");
printf("[preatt]\n"); validate_result(d_preatt, preatt, "preatt", B * T * C, 1e-4f);
printf("[att]\n"); validate_result(d_att, att, "att", B * T * C, 1e-4f);
printf("[out]\n"); validate_result(d_out, out, "out", B * T * C, 1e-4f);
printf("[preatt]\n"); validate_result(d_preatt, preatt, "preatt", B * T * C, 5e-3f);
printf("[att]\n"); validate_result(d_att, att, "att", B * T * C, 1e-3f);
printf("[out]\n"); validate_result(d_out, out, "out", B * T * C, 1e-3f);

// set up the memory for the backward pass
float* dout = make_random_float(B * T * C); // the gradients on the output
@@ -1072,9 +1062,9 @@ int main(int argc, char **argv) {
// the gradients at qkvr and vaccum will remain unchecked, but are
// assumed to be correct if the other gradients are correct
printf("Checking the backward pass CPU <-> GPU...\n");
printf("[datt]\n"); validate_result(d_datt, datt, "datt", B * NH * T * T, 1e-4f);
printf("[dpreatt]\n"); validate_result(d_dpreatt, dpreatt, "dpreatt", B * NH * T * T, 1e-4f);
printf("[dinp]\n"); validate_result(d_dinp, dinp, "dinp", B * T * 3 * C, 1e-4f);
printf("[datt]\n"); validate_result(d_datt, datt, "datt", B * NH * T * T, 5e-3f);
printf("[dpreatt]\n"); validate_result(d_dpreatt, dpreatt, "dpreatt", B * NH * T * T, 1e-3f);
printf("[dinp]\n"); validate_result(d_dinp, dinp, "dinp", B * T * 3 * C, 1e-3f);

// also let's manually step through the gradients here
float* h_dinp = (float*)malloc(B * T * 3 * C * sizeof(float));
3 changes: 1 addition & 2 deletions dev/cuda/attention_forward.cu
@@ -64,7 +64,6 @@ typedef __nv_bfloat16 floatX; // half or __nv_bfloat16 (or float)

// ----------------------------------------------------------------------------
// CUDA & cuDNN setup
static cublasHandle_t cublas_handle;
static bool first_run_validation = true; // always run e.g. permute on 1st run

#ifdef ENABLE_CUDNN
@@ -1285,7 +1284,7 @@ void attention_forward(int kernel_num,
// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
srand(0);
setup_main();

int B = 8;
int T = 1024;
60 changes: 60 additions & 0 deletions dev/cuda/common.h
@@ -2,6 +2,7 @@
#include <stdio.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cublasLt.h>


template<class T>
@@ -32,6 +33,21 @@ void cublasCheck(cublasStatus_t status, const char *file, int line)
}
#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }

// ----------------------------------------------------------------------------
// cuBLAS setup
// these will be initialized by setup_main

// cuBLAS workspace. Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK
static size_t cublaslt_workspace_size = 32 * 1024 * 1024;
static void* cublaslt_workspace = NULL;
static cublasComputeType_t cublas_compute_type;
cublasHandle_t cublas_handle;
cublasLtHandle_t cublaslt_handle;
int cuda_arch_major = 0;
int cuda_arch_minor = 0;
int cuda_num_SMs = 0; // for persistent threads where we want 1 threadblock per SM
int cuda_threads_per_SM = 0; // needed to calculate how many blocks to launch to fill up the GPU

// ----------------------------------------------------------------------------
// Packed128 data structure, which forces the compiler to use 128-bit loads/stores
// in GPUs that support (the LDG.128 and STS.128 instructions)
@@ -135,6 +151,50 @@ float* make_ones_float(size_t N) {
// ----------------------------------------------------------------------------
// testing and benchmarking utils

template<class TargetType>
[[nodiscard]] cudaError_t memcpy_convert(TargetType* d_ptr, float* h_ptr, size_t count) {
// copy from host to device with data type conversion.
TargetType* converted = (TargetType*)malloc(count * sizeof(TargetType));
for (int i = 0; i < count; i++) {
converted[i] = (TargetType)h_ptr[i];
}

cudaError_t status = cudaMemcpy(d_ptr, converted, count * sizeof(TargetType), cudaMemcpyHostToDevice);
free(converted);

// instead of checking the status at cudaMemcpy, we return it from here. This way, we
// still need to use our checking macro, and get better line info as to where the error
// happened.
return status;
}

void setup_main() {
srand(0); // determinism

// set up the device
int deviceIdx = 0;
cudaCheck(cudaSetDevice(deviceIdx));
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, deviceIdx);
cuda_num_SMs = deviceProp.multiProcessorCount;
cuda_threads_per_SM = deviceProp.maxThreadsPerMultiProcessor;
cuda_arch_major = deviceProp.major;
cuda_arch_minor = deviceProp.minor;

// setup cuBLAS and cuBLASLt
cublasCheck(cublasCreate(&cublas_handle));
cublasCheck(cublasLtCreate(&cublaslt_handle));
cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size));

// TF32 precision is equivalent to torch.set_float32_matmul_precision('high')
int enable_tf32 = cuda_arch_major >= 8 ? 1 : 0;
// TODO implement common CLI for all tests/benchmarks
// if (override_enable_tf32 == 0) { enable_tf32 = 0; } // force to zero via arg
cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;
cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH;
cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode));
}

template<class D, class T>
void validate_result(D* device_result, const T* cpu_reference, const char* name, std::size_t num_elements, T tolerance=1e-4) {
D* out_gpu = (D*)malloc(num_elements * sizeof(D));
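Taken together, setup_main() and memcpy_convert() replace the per-file srand/device/cuBLAS boilerplate and the hand-rolled host-side precision conversions. Below is a minimal sketch (not part of this commit) of a kernel test harness that adopts them; the floatX typedef, buffer size, and the elided kernel launch and validation steps are illustrative placeholders. The gelu_forward.cu diff that follows shows the real adoption.

#include <cuda_bf16.h>
#include "common.h"

typedef __nv_bfloat16 floatX;   // reduced-precision type under test (illustrative)

int main(int argc, char **argv) {
    setup_main();               // srand(0), device query, cuBLAS/cuBLASLt handles + workspace, TF32 policy

    int N = 8 * 1024 * 768;     // e.g. B * T * C

    // random float data on the host, also used by the CPU reference
    float* inp = make_random_float(N);

    // device buffer in the reduced-precision type, filled via host-side conversion
    floatX* d_inp;
    cudaCheck(cudaMalloc(&d_inp, N * sizeof(floatX)));
    cudaCheck(memcpy_convert(d_inp, inp, N));

    // ... launch the kernel under test, benchmark it, and compare against the
    // CPU reference, e.g. validate_result(d_out, out, "out", N, 1e-2f); ...

    free(inp);
    cudaCheck(cudaFree(d_inp));
    return 0;
}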
15 changes: 2 additions & 13 deletions dev/cuda/gelu_forward.cu
@@ -117,15 +117,12 @@ void gelu_forward(int kernel_num,
// ----------------------------------------------------------------------------

int main(int argc, const char **argv) {
srand(0);
setup_main();

int B = 8;
int T = 1024;
int C = 768;

int deviceIdx = 0;
cudaCheck(cudaSetDevice(deviceIdx));

// create host memory of random numbers
float* out = (float*)malloc(B * T * C * sizeof(float));
float* inp = make_random_float(B * T * C);
@@ -145,14 +142,7 @@ int main(int argc, const char **argv) {
floatX* d_inp;
cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX)));
cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(floatX)));

floatX* inpX = (floatX*)malloc(B * T * C * sizeof(floatX));

for (int i = 0; i < B * T * C; i++) {
inpX[i] = (floatX)inp[i];
}

cudaCheck(cudaMemcpy(d_inp, inpX, B * T * C * sizeof(floatX), cudaMemcpyHostToDevice));
cudaCheck(memcpy_convert(d_inp, inp, B * T * C));

// time the kernel at different block sizes
int block_sizes[] = {32, 64, 128, 256, 512, 1024};
@@ -191,7 +181,6 @@ int main(int argc, const char **argv) {
// free memory
free(out);
free(inp);
free(inpX);

cudaCheck(cudaFree(d_out));
cudaCheck(cudaFree(d_inp));
12 changes: 1 addition & 11 deletions dev/cuda/layernorm_backward.cu
@@ -19,10 +19,6 @@ version 2 moves a lot of reduction to shared memory over global memory
#include <cooperative_groups/reduce.h>
#include "common.h"

// ----------------------------------------------------------------------------
// CUDA settings
int cuda_num_SMs = 0; // for persistent threads where we want 1 threadblock per SM

// turn on bf16 as default, done up here for now
#define ENABLE_BF16

@@ -757,18 +753,12 @@ void layernorm_backward(int kernel_num,
// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
srand(0);
setup_main();

int B = 8;
int T = 1024;
int C = 768;

int deviceIdx = 0;
cudaCheck(cudaSetDevice(deviceIdx));
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, deviceIdx);
cuda_num_SMs = deviceProp.multiProcessorCount;

// first do the forward pass in CPU
float* out = (float*)malloc(B * T * C * sizeof(float));
float* mean = (float*)malloc(B * T * sizeof(float));
5 changes: 0 additions & 5 deletions dev/cuda/matmul_backward.cu
@@ -14,11 +14,6 @@ OMP_NUM_THREADS=32 ./matmul_backward 1
#include <omp.h>
#include "common.h"

// ----------------------------------------------------------------------------
// CUDA / cuBLAS setup

static cublasHandle_t cublas_handle;

// ----------------------------------------------------------------------------
// CPU code reference

10 changes: 0 additions & 10 deletions dev/cuda/matmul_forward.cu
@@ -23,16 +23,6 @@ OMP_NUM_THREADS=32 ./matmul_forward 3
#include <omp.h>
#include "common.h"

// ----------------------------------------------------------------------------
// CUDA setup

// cuBLAS workspace. Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK
static cublasHandle_t cublas_handle;
static cublasLtHandle_t cublaslt_handle;
static size_t cublaslt_workspace_size = 32 * 1024 * 1024;
static void* cublaslt_workspace = NULL;
static cublasComputeType_t cublas_compute_type;

// ----------------------------------------------------------------------------
// CPU code reference

7 changes: 1 addition & 6 deletions dev/cuda/trimat_forward.cu
@@ -33,7 +33,6 @@ tri4
#include <cooperative_groups/reduce.h>
#include "common.h"

static cublasHandle_t cublas_handle;
static float* d_qkvr; // scratch for the cublas kernel

/* ** Chapter I - Introduction **
@@ -532,17 +531,13 @@ void trimul_gpu(int kernel_num,


int main(int argc, char **argv) {
srand(0);
setup_main();

int B = 8;
int T = 1024;
int C = 768;
int NH = 12;

int deviceIdx = 0;
cudaCheck(cudaSetDevice(deviceIdx));
cublasCreate(&cublas_handle);

// create host memory of random numbers
float* out = (float*)malloc(B * NH * T * T * sizeof(float));
float* inp = make_random_float(B * T * 3 * C);
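Across the kernel files above, the file-local static cublasHandle_t declarations are gone; kernels now rely on the shared cublas_handle and cublaslt_handle that common.h defines and setup_main() initializes. A hypothetical helper showing that pattern is sketched below; the GEMM shapes and arguments are illustrative and not taken from this diff.

#include "common.h"

// out (B*T, OC) = inp (B*T, C) @ weight^T; bias omitted for brevity
void matmul_forward_cublas(float* out, const float* inp, const float* weight,
                           int B, int T, int C, int OC) {
    const float alpha = 1.0f, beta = 0.0f;
    // row-major tensors expressed through cuBLAS's column-major convention,
    // using the globally shared handle from common.h
    cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
                            OC, B * T, C, &alpha,
                            weight, C, inp, C, &beta,
                            out, OC));
}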
