Skip to content

Commit 5eb58f9

Browse files
committed
fix for fp8 blockwise recipe
Signed-off-by: zhongboz <[email protected]>
1 parent 0064fbd commit 5eb58f9

File tree

4 files changed

+35
-15
lines changed

4 files changed

+35
-15
lines changed

transformer_engine/common/transpose/cast_transpose.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor
2727
SimpleTensor &scale_inv_t, SimpleTensor &output,
2828
SimpleTensor &output_t, const float epsilon,
2929
const bool return_transpose, const bool pow_2_scale,
30-
cudaStream_t stream);
30+
const SimpleTensor &noop_tensor, cudaStream_t stream);
3131

3232
// enum class for rowwise usage
3333
enum class FP8BlockwiseRowwiseOption {
@@ -59,7 +59,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
5959
SimpleTensor &output_t, const float epsilon,
6060
FP8BlockwiseRowwiseOption rowwise_option,
6161
FP8BlockwiseColumnwiseOption columnwise_option,
62-
const bool pow_2_scale, cudaStream_t stream);
62+
const bool pow_2_scale, const SimpleTensor &noop_tensor,
63+
cudaStream_t stream);
6364

6465
} // namespace transformer_engine::detail
6566

transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
7070
const size_t scale_stride_y, const size_t scale_t_stride_x,
7171
const size_t scale_t_stride_y, const float epsilon,
7272
const __grid_constant__ CUtensorMap tensor_map_output_t,
73-
bool pow_2_scaling) {
73+
bool pow_2_scaling, const float* noop_ptr) {
7474
using IVec = Vec<IType, THREAD_TILE_DIM_X>;
7575
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
7676
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
7777

78+
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
79+
return;
80+
}
81+
7882
// shared mem for amax reduction in entire block, each warp produces one amax, there are
7983
// NUM_WARPS_IN_BLOCK amax to reduce
8084
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
@@ -249,11 +253,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
249253
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
250254
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
251255
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
252-
bool pow_2_scaling) {
256+
bool pow_2_scaling, const float* noop_ptr) {
253257
using IVec = Vec<IType, THREAD_TILE_DIM_X>;
254258
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
255259
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
256260

261+
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
262+
return;
263+
}
264+
257265
// shared mem for amax reduction in entire block, each warp produces one amax, there are
258266
// NUM_WARPS_IN_BLOCK amax to reduce
259267
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
@@ -473,7 +481,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
473481
SimpleTensor& scale_inv_t, SimpleTensor& output,
474482
SimpleTensor& output_t, const float epsilon,
475483
const bool return_transpose, const bool pow_2_scale,
476-
cudaStream_t stream) {
484+
const SimpleTensor& noop_tensor, cudaStream_t stream) {
477485
NVTE_API_CALL(quantize_transpose_square_blockwise);
478486
checkCuDriverContext(stream);
479487

@@ -494,6 +502,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
494502
size_t scale_t_stride_x = 0;
495503
size_t scale_t_stride_y = 0;
496504

505+
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
506+
497507
if (return_transpose) {
498508
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
499509
"output_t must have same number of dimensions as input.");
@@ -541,7 +551,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
541551
reinterpret_cast<float*>(scale_inv.dptr),
542552
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
543553
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
544-
tensor_map_output_trans, pow_2_scale);
554+
tensor_map_output_trans, pow_2_scale, noop_ptr);
545555
} else {
546556
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
547557
OutputType>
@@ -552,7 +562,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
552562
reinterpret_cast<float*>(scale_inv.dptr),
553563
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
554564
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
555-
pow_2_scale);
565+
pow_2_scale, noop_ptr);
556566
} // full-tile
557567
) // return_transpose
558568
) // OutputType

transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
172172
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
173173
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
174174
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
175-
const bool pow_2_scaling) {
175+
const bool pow_2_scaling, const float* noop_ptr) {
176+
// skip execution if noop
177+
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
178+
return;
179+
}
180+
176181
bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
177182
bool return_columnwise_gemm_ready =
178183
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
@@ -520,7 +525,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
520525
SimpleTensor& output_t, const float epsilon,
521526
FP8BlockwiseRowwiseOption rowwise_option,
522527
FP8BlockwiseColumnwiseOption columnwise_option,
523-
const bool pow2_scale, cudaStream_t stream) {
528+
const bool pow2_scale, const SimpleTensor& noop_tensor,
529+
cudaStream_t stream) {
524530
NVTE_API_CALL(quantize_transpose_vector_blockwise);
525531

526532
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
@@ -585,6 +591,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
585591
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
586592
const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
587593

594+
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
595+
588596
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
589597
input.dtype, InputType,
590598

@@ -613,7 +621,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
613621
reinterpret_cast<float*>(scale_inv.dptr),
614622
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
615623
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
616-
columnwise_option, pow2_scale);) // kAligned
624+
columnwise_option, pow2_scale, noop_ptr);) // kAligned
617625
) // OutputType
618626
) // InputType
619627
NVTE_CHECK_CUDA(cudaGetLastError());

transformer_engine/common/util/cast_kernels.cuh

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,7 +1421,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
14211421
quantize_transpose_square_blockwise(
14221422
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
14231423
output_tensor->data, output_tensor->columnwise_data, epsilon,
1424-
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
1424+
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
1425+
/*noop_tensor=*/noop_tensor.data, stream);
14251426
break;
14261427
}
14271428
case NVTE_BLOCK_SCALING_1D: {
@@ -1449,10 +1450,10 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
14491450
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
14501451
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
14511452
}
1452-
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
1453-
output_tensor->columnwise_scale_inv, output_tensor->data,
1454-
output_tensor->columnwise_data, epsilon, rowwise_option,
1455-
columnwise_option, force_pow_2_scales, stream);
1453+
quantize_transpose_vector_blockwise(
1454+
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
1455+
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
1456+
columnwise_option, force_pow_2_scales, noop_tensor.data, stream);
14561457
break;
14571458
}
14581459
default:

0 commit comments

Comments (0)