Skip to content

Commit 060811c

Browse files
[Common] Fix checks in quantize_transpose_vector_blockwise_fp4 (#2299)
fix checks in unoptimized non-rht fp4 quantize kernel Signed-off-by: Jeremy Berchtold <[email protected]>
1 parent 6273ced commit 060811c

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -718,13 +718,11 @@ void quantize_transpose_vector_blockwise_fp4(
718718
// raise error if pow2_scale is true
719719
NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now");
720720

721-
if (!return_identity && !return_transpose) {
722-
return;
723-
}
721+
NVTE_CHECK(return_identity || return_transpose,
722+
"At least one of return_identity or return_transpose must be true.");
724723

725-
if (use_2d_quantization && !return_identity) {
726-
return;
727-
}
724+
NVTE_CHECK(return_identity || !use_2d_quantization,
725+
"2D block quantization is only supported when return_identity is true.");
728726

729727
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
730728
size_t num_elements = row_length;
@@ -777,7 +775,7 @@ void quantize_transpose_vector_blockwise_fp4(
777775
input.dtype, InputType,
778776

779777
TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(
780-
output.dtype, 2, OutputType,
778+
return_identity ? output.dtype : output_t.dtype, 2, OutputType,
781779

782780
dim3 grid(num_blocks_x, num_blocks_y, 1);
783781

0 commit comments

Comments
 (0)