
Commit

updated code to adapt to latest changes
ngc92 committed May 18, 2024
1 parent c3a3b9d commit 589ead1
Showing 2 changed files with 9 additions and 22 deletions.
4 changes: 1 addition & 3 deletions profile_gpt2cu.py
@@ -50,8 +50,6 @@
  # model config
  CLS_START = -1
  CLS_NUM = 6
- NORM_ID = 44
- ADAM_ID = 45
  N_LAYERS = 12

  summaries = defaultdict(lambda: 0.0)
@@ -132,7 +130,7 @@
  # the classifier part, counts only once
  pass_name = "cls"
  phase = "bwd"
- elif "adamw" in kernel:
+ elif "adamw" in kernel or "global_norm" in kernel:
  # encoder layer or adam
  pass_name = "opt"
  # before the first optimizer run, we create weight copies.
27 changes: 8 additions & 19 deletions train_gpt2.cu
@@ -1201,34 +1201,23 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
  }

  template<class T>
- __global__ void norm_kernel(float* out, const T* data, size_t count) {
+ __global__ void global_norm_kernel(float* out, const T* data, size_t count) {
      // we want as few atomics as possible, so each block tries to do
      // the maximum amount of work (so no fixed chunk, but instead iterating
      // until we run out of data), and then we reduce inside the block
      // and finally have just one atomic per block.
-     namespace cg = cooperative_groups;
-     cg::thread_block block = cg::this_thread_block();
-     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
-
-     __shared__ float block_result[32];
-
-     // out will be updated atomically from all thread blocks
+     // out will be updated atomically from all thread blocks. It is a float, so the
+     // atomic op is unproblematic
      size_t index = threadIdx.x + blockDim.x * blockIdx.x;
      size_t grid_width = blockDim.x * gridDim.x;
      float accumulator = 0.f;
      for(size_t i = index; i < count; i += grid_width) {
          accumulator += (float)data[i] * (float)data[i];
      }
-     // warp-level reduce
-     float warp_result = cg::reduce(warp, accumulator, cg::plus<float>{});
-     block_result[warp.meta_group_rank()] = warp_result;
-     block.sync();
-     if(warp.meta_group_rank() == 0) {
-         float gather = warp.thread_rank() < warp.meta_group_size() ? block_result[warp.thread_rank()] : 0.f;
-         float block_sum = cg::reduce(warp, gather, cg::plus<float>{});
-         if(warp.thread_rank() == 0) {
-             atomicAdd(out, block_sum);
-         }
-     }
+     float block_sum = blockReduce<warpReduceSum>(accumulator);
+     if(threadIdx.x == 0) {
+         atomicAdd(out, block_sum);
+     }
  }
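For context, here is a minimal self-contained sketch of what the rewritten kernel computes. The warp_reduce_sum and block_reduce_sum helpers below are illustrative stand-ins for the repository's warpReduceSum and blockReduce utilities, which are not shown in this diff; their real definitions may differ.

#include <cuda_runtime.h>

// Illustrative stand-in for warpReduceSum (assumes all 32 lanes of the warp are active).
__device__ float warp_reduce_sum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFFu, val, offset);
    }
    return val;
}

// Illustrative stand-in for blockReduce (assumes blockDim.x is a multiple of 32).
__device__ float block_reduce_sum(float val) {
    __shared__ float warp_sums[32];           // one partial sum per warp
    int lane = threadIdx.x % 32;
    int warp_id = threadIdx.x / 32;
    int num_warps = blockDim.x / 32;
    val = warp_reduce_sum(val);               // reduce within each warp
    if (lane == 0) warp_sums[warp_id] = val;  // lane 0 publishes its warp's sum
    __syncthreads();
    // the first warp reduces the per-warp partial sums
    val = (threadIdx.x < num_warps) ? warp_sums[lane] : 0.0f;
    if (warp_id == 0) val = warp_reduce_sum(val);
    return val;                               // complete sum is valid in thread 0
}

// Sketch of global_norm_kernel: grid-stride accumulation of squared values,
// one block-wide reduction, and a single atomicAdd per block into *out.
template<class T>
__global__ void global_norm_kernel_sketch(float* out, const T* data, size_t count) {
    size_t index = threadIdx.x + blockDim.x * (size_t)blockIdx.x;
    size_t grid_width = blockDim.x * (size_t)gridDim.x;
    float accumulator = 0.0f;
    for (size_t i = index; i < count; i += grid_width) {
        accumulator += (float)data[i] * (float)data[i];
    }
    float block_sum = block_reduce_sum(accumulator);
    if (threadIdx.x == 0) {
        atomicAdd(out, block_sum);            // *out accumulates the squared L2 norm
    }
}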

@@ -1716,9 +1705,9 @@ void global_norm(float* out, const T* values, size_t count) {
      // one block too many is catastrophic, since it only can start once all the other
      // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512
      // on all gpus, so the division really is going to be exact.
-     const int grid_size = cuda_threads_per_SM * cuda_num_SMs / block_size;
+     const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;
      assert(grid_size > 0); // gives a better error than letting the call below fail
-     norm_kernel<<<grid_size, block_size>>>(out, values, count);
+     global_norm_kernel<<<grid_size, block_size>>>(out, values, count);
      cudaCheck(cudaGetLastError());
  }
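Since the kernel only adds partial sums of squares into *out, the output buffer has to be zeroed before the launch and square-rooted afterwards to obtain the actual norm. A hedged host-side usage sketch follows; the grads_memory and num_parameters names are hypothetical and not part of this commit.

float* norm_squared_d;
cudaCheck(cudaMalloc(&norm_squared_d, sizeof(float)));
cudaCheck(cudaMemset(norm_squared_d, 0, sizeof(float)));     // kernel accumulates via atomicAdd, so start from zero
global_norm(norm_squared_d, grads_memory, num_parameters);   // sums grad[i]^2 across all parameters
float norm_squared;
cudaCheck(cudaMemcpy(&norm_squared, norm_squared_d, sizeof(float), cudaMemcpyDeviceToHost));
float grad_norm = sqrtf(norm_squared);                       // global L2 norm of the gradients
cudaCheck(cudaFree(norm_squared_d));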

