diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index 3b91cebd85d60e..2cc18b29c4be62 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -486,6 +487,10 @@ ilpReduce(index_t shift, } offset = size - last + threadIdx.x; + if (offset < 0) { + // Ensure offset >= 0 + offset += round_up(-offset, blockDim.x); + } // Epilogue for (; offset < size; offset += blockDim.x) threadVal = r(threadVal, data[offset]); @@ -543,6 +548,10 @@ WriteFpropResultsVectorized( } offset = size - last + threadIdx.x; + if (offset < 0) { + // Ensure offset >= 0 + offset += round_up(-offset, blockDim.x); + } // handle the tail for (; offset < size; offset += blockDim.x) { output[offset] = epilogue(input[offset]); @@ -603,6 +612,10 @@ WriteBpropResultsVectorized( } offset = size - last + threadIdx.x; + if (offset < 0) { + // Ensure offset >= 0 + offset += round_up(-offset, blockDim.x); + } for (; offset < size; offset += blockDim.x) { gradInput[offset] = epilogue(gradOutput[offset], output[offset]); }