
Commit 589ead1

updated code to adapt to latest changes
1 parent c3a3b9d commit 589ead1

2 files changed: +9 / -22 lines changed

2 files changed

+9
-22
lines changed

profile_gpt2cu.py

Lines changed: 1 addition & 3 deletions
@@ -50,8 +50,6 @@
 # model config
 CLS_START = -1
 CLS_NUM = 6
-NORM_ID = 44
-ADAM_ID = 45
 N_LAYERS = 12

 summaries = defaultdict(lambda: 0.0)

@@ -132,7 +130,7 @@
         # the classifier part, counts only once
         pass_name = "cls"
         phase = "bwd"
-    elif "adamw" in kernel:
+    elif "adamw" in kernel or "global_norm" in kernel:
         # encoder layer or adam
         pass_name = "opt"
         # before the first optimizer run, we create weight copies.

train_gpt2.cu

Lines changed: 8 additions & 19 deletions
@@ -1201,34 +1201,23 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
 }

 template<class T>
-__global__ void norm_kernel(float* out, const T* data, size_t count) {
+__global__ void global_norm_kernel(float* out, const T* data, size_t count) {
     // we want as few atomics as possible, so each block tries to do
     // the maximum amount of work (so no fixed chunk, but instead iterating
     // until we run out of data), and then we reduce inside the block
     // and finally have just one atomic per block.
-    namespace cg = cooperative_groups;
-    cg::thread_block block = cg::this_thread_block();
-    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
-
-    __shared__ float block_result[32];
-
-    // out will be updated atomically from all thread blocks
+    // out will be updated atomically from all thread blocks. It is a float, so the
+    // atomic op is unproblematic
     size_t index = threadIdx.x + blockDim.x * blockIdx.x;
     size_t grid_width = blockDim.x * gridDim.x;
     float accumulator = 0.f;
     for(size_t i = index; i < count; i += grid_width) {
         accumulator += (float)data[i] * (float)data[i];
     }
     // warp-level reduce
-    float warp_result = cg::reduce(warp, accumulator, cg::plus<float>{});
-    block_result[warp.meta_group_rank()] = warp_result;
-    block.sync();
-    if(warp.meta_group_rank() == 0) {
-        float gather = warp.thread_rank() < warp.meta_group_size() ? block_result[warp.thread_rank()] : 0.f;
-        float block_sum = cg::reduce(warp, gather, cg::plus<float>{});
-        if(warp.thread_rank() == 0) {
-            atomicAdd(out, block_sum);
-        }
+    float block_sum = blockReduce<warpReduceSum>(accumulator);
+    if(threadIdx.x == 0) {
+        atomicAdd(out, block_sum);
     }
 }
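The kernel keeps the same grid-stride accumulation loop but swaps the hand-written cooperative-groups reduction for a call to blockReduce<warpReduceSum>, so each block ends with a single atomicAdd issued by thread 0. Those reduction helpers are defined elsewhere in train_gpt2.cu and are not part of this diff; the following is only a sketch of how such shuffle-based helpers are commonly written, and its names and details are illustrative assumptions rather than the repository's exact code.

```cuda
// Sketch of warp/block reduction helpers in the style the new kernel expects.
// Assumption: the real warpReduceSum/blockReduce in train_gpt2.cu may differ.
#include <cuda_runtime.h>

using reduction_func_t = float (*)(float);

__device__ inline float warpReduceSum(float val) {
    // tree reduction across the 32 lanes of a warp using register shuffles
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFFu, val, offset);
    }
    return val;
}

template<reduction_func_t warp_reduction>
__device__ inline float blockReduce(float val) {
    // reduce within each warp, stage the per-warp partials in shared memory,
    // then let the first warp combine them; thread 0 ends up with the block sum
    __shared__ float shared[32];
    int lane_id   = threadIdx.x % 32;
    int warp_id   = threadIdx.x / 32;
    int num_warps = (blockDim.x + 31) / 32;

    float warp_val = warp_reduction(val);
    if (lane_id == 0) { shared[warp_id] = warp_val; }
    __syncthreads();

    float block_val = (threadIdx.x < num_warps) ? shared[threadIdx.x] : 0.0f;
    if (warp_id == 0) { block_val = warp_reduction(block_val); }
    return block_val;
}
```

With a helper along these lines, the kernel body collapses to the two added lines in the diff: reduce the per-thread accumulator across the block, then have thread 0 perform the one atomic update of out.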

@@ -1716,9 +1705,9 @@ void global_norm(float* out, const T* values, size_t count) {
     // one block too many is catastrophic, since it only can start once all the other
     // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512
     // on all gpus, so the division really is going to be exact.
-    const int grid_size = cuda_threads_per_SM * cuda_num_SMs / block_size;
+    const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size;
     assert(grid_size > 0); // gives a better error than letting the call below fail
-    norm_kernel<<<grid_size, block_size>>>(out, values, count);
+    global_norm_kernel<<<grid_size, block_size>>>(out, values, count);
     cudaCheck(cudaGetLastError());
 }
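The launcher now sizes the grid directly from cudaDeviceProp fields instead of the previously used cuda_threads_per_SM / cuda_num_SMs values, filling every SM with exactly one wave of resident blocks. A minimal host-side sketch of how those fields can be queried and the grid size derived; the standalone main, the device index 0, and the variable names other than the cudaDeviceProp fields are assumptions for illustration:

```cuda
// Hedged sketch: query device properties and size a one-wave grid from them.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, 0);   // properties of device 0 (assumed)

    const int block_size = 512;  // matches the "multiple of 512" assumption above
    // fill each SM with as many resident threads as it supports, and no more:
    // one extra block would have to wait for the whole wave to finish first
    const int grid_size = deviceProp.maxThreadsPerMultiProcessor
                        * deviceProp.multiProcessorCount / block_size;

    printf("SMs: %d, max threads/SM: %d -> grid of %d blocks of %d threads\n",
           deviceProp.multiProcessorCount, deviceProp.maxThreadsPerMultiProcessor,
           grid_size, block_size);
    return 0;
}
```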

0 commit comments
