Skip to content

Commit eec3190

Browse files
author
Changpeng Fang
committed
Tuning the inline and unroll to reduce the scratch usage
Summary: 1. Remove the noinline attribute from AllReduceTreeKernel; 2. Change AUTOUNROLL for the tree functions to 1 or 2. Combining 1 and 2 reduces the scratch usage from 1256 to 952.
1 parent de25e4c commit eec3190

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/collectives/device/all_reduce.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,7 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
102102
#endif
103103
}
104104

105-
template<int UNROLL, class FUNC, typename T>
106-
__attribute__((noinline))
105+
template<int UNUSED, class FUNC, typename T>
107106
__device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
108107
const int tid = threadIdx.x;
109108
const int nthreads = blockDim.x;
@@ -122,7 +121,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
122121

123122
do {
124123
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
125-
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
124+
ncclPrimitives<1, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount);
126125
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
127126
// Up
128127
ssize_t offset = gridOffset + bid*chunkSize;
@@ -139,7 +138,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
139138

140139
do {
141140
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
142-
ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
141+
ncclPrimitives<1, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount);
143142
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
144143
// Down
145144
ssize_t offset = gridOffset + bid*chunkSize;

0 commit comments

Comments
 (0)