@@ -70,11 +70,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
7070 const size_t scale_stride_y, const size_t scale_t_stride_x,
7171 const size_t scale_t_stride_y, const float epsilon,
7272 const __grid_constant__ CUtensorMap tensor_map_output_t ,
73- bool pow_2_scaling) {
73+ bool pow_2_scaling, const float* noop_ptr) {
7474 using IVec = Vec<IType, THREAD_TILE_DIM_X>;
7575 using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
7676 using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
7777
78+ if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
79+ return;
80+ }
81+
7882 // shared mem for amax reduction in entire block, each warp produces one amax, there are
7983 // NUM_WARPS_IN_BLOCK amax to reduce
8084 __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
@@ -249,11 +253,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
249253 CType* const tile_scales_inv_c, CType* const tile_scales_inv_t , const size_t row_length,
250254 const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
251255 const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
252- bool pow_2_scaling) {
256+ bool pow_2_scaling, const float* noop_ptr) {
253257 using IVec = Vec<IType, THREAD_TILE_DIM_X>;
254258 using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
255259 using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
256260
261+ if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
262+ return;
263+ }
264+
257265 // shared mem for amax reduction in entire block, each warp produces one amax, there are
258266 // NUM_WARPS_IN_BLOCK amax to reduce
259267 __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
@@ -473,7 +481,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
473481 SimpleTensor& scale_inv_t , SimpleTensor& output,
474482 SimpleTensor& output_t , const float epsilon,
475483 const bool return_transpose, const bool pow_2_scale,
476- cudaStream_t stream) {
484+ const SimpleTensor& noop_tensor, cudaStream_t stream) {
477485 NVTE_API_CALL (quantize_transpose_square_blockwise);
478486 checkCuDriverContext (stream);
479487
@@ -494,6 +502,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
494502 size_t scale_t_stride_x = 0 ;
495503 size_t scale_t_stride_y = 0 ;
496504
505+ const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
506+
497507 if (return_transpose) {
498508 NVTE_CHECK (output_t .shape .size () == input.shape .size (),
499509 " output_t must have same number of dimensions as input." );
@@ -541,7 +551,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
541551 reinterpret_cast <float *>(scale_inv.dptr ),
542552 reinterpret_cast <float *>(scale_inv_t .dptr ), row_length, num_rows,
543553 scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
544- tensor_map_output_trans, pow_2_scale);
554+ tensor_map_output_trans, pow_2_scale, noop_ptr);
545555 } else {
546556 block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose , float , InputType,
547557 OutputType>
@@ -552,7 +562,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
552562 reinterpret_cast <float *>(scale_inv.dptr ),
553563 reinterpret_cast <float *>(scale_inv_t .dptr ), row_length, num_rows,
554564 scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
555- pow_2_scale);
565+ pow_2_scale, noop_ptr);
556566 } // full-tile
557567 ) // return_transpose
558568 ) // OutputType
0 commit comments