diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu
index 5274342edc..db574748cc 100644
--- a/tests/cpp/operator/test_cast_mxfp8.cu
+++ b/tests/cpp/operator/test_cast_mxfp8.cu
@@ -111,7 +111,9 @@ void compute_ref_x1(const ProcessingMethod processing_method,
                     const size_t rows,
                     const size_t cols,
                     const size_t block_size_Y,
-                    const size_t block_size_X) {
+                    const size_t block_size_X,
+                    const size_t scales_stride)
+{
     std::vector<float> output_dbias_fp32(cols, 0);
 
     const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
@@ -123,7 +125,7 @@ void compute_ref_x1(const ProcessingMethod processing_method,
         for (size_t jj = 0; jj < blocks_X; ++jj) {
             const size_t j_min = jj * block_size_X;
             const size_t j_max = std::min((jj + 1) * block_size_X, cols);
-            const size_t scale_idx = ii * blocks_X + jj;
+            const size_t scale_idx = ii * scales_stride + jj;
             scale_block(
                 processing_method, input, act_input, output_c, output_dbias_fp32.data(),
                 output_scales, scale_idx, i_min, i_max, j_min, j_max, cols);
@@ -146,13 +148,15 @@ void compute_ref_x2(const ProcessingMethod processing_method,
                     const size_t rows,
                     const size_t cols,
                     const size_t block_size_Y,
-                    const size_t block_size_X) {
+                    const size_t block_size_X,
+                    const size_t scales_stride_rowwise,
+                    const size_t scales_stride_colwise) {
     compute_ref_x1(
         processing_method, input, act_input, output_rowwise, scales_rowwise, output_dbias,
-        rows, cols, 1, block_size_X);
+        rows, cols, 1, block_size_X, scales_stride_rowwise);
     compute_ref_x1(
         processing_method, input, act_input, output_colwise, scales_colwise, output_dbias,
-        rows, cols, block_size_Y, 1);
+        rows, cols, block_size_Y, 1, scales_stride_colwise);
 }
 
 /**
@@ -177,9 +181,20 @@ void performTest_x1(const ProcessingMethod processing_method,
     const size_t block_size_rows = rowwise ? 1 : 32;
     const size_t block_size_cols = colwise ? 1 : 32;
 
-    const size_t blocks_Y = (rows + block_size_rows - 1) / block_size_rows;
-    const size_t blocks_X = (cols + block_size_cols - 1) / block_size_cols;
-    const size_t blocks_num = blocks_Y * blocks_X;
+    const size_t unpadded_blocks_Y = (rows + block_size_rows - 1) / block_size_rows;
+    const size_t unpadded_blocks_X = (cols + block_size_cols - 1) / block_size_cols;
+
+    const size_t block_alignment_X = rowwise
+                                     ? scale_tensor_alignment_X_rowwise
+                                     : scale_tensor_alignment_X_colwise;
+    const size_t block_alignment_Y = rowwise
+                                     ? scale_tensor_alignment_Y_rowwise
+                                     : scale_tensor_alignment_Y_colwise;
+
+    // Roundup to the nearest multiple
+    const size_t blocks_Y = ((unpadded_blocks_Y + block_alignment_Y - 1) / block_alignment_Y) * block_alignment_Y;
+    const size_t blocks_X = ((unpadded_blocks_X + block_alignment_X - 1) / block_alignment_X) * block_alignment_X;
+    const size_t scales_stride = blocks_X;
 
     Tensor input({ rows, cols }, itype);
     Tensor act_input({ rows, cols }, itype);
@@ -254,16 +269,18 @@ void performTest_x1(const ProcessingMethod processing_method,
                     rows,
                     cols,
                     block_size_rows,
-                    block_size_cols);
+                    block_size_cols,
+                    scales_stride);
 
     auto [atol, rtol] = getTolerances(otype);
     compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
-    if (rowwise) {
-        compare_e8m0_scaling_factors("scales", output_c.rowwise_cpu_scale_inv_ptr(), ref_output_scales.get(), blocks_num);
-    }
-    if (colwise) {
-        compare_e8m0_scaling_factors("scales", output_c.columnwise_cpu_scale_inv_ptr(), ref_output_scales.get(), blocks_num);
-    }
+
+    const uint8_t * const gpu_scales_ptr = rowwise
+                                           ? output_c.rowwise_cpu_scale_inv_ptr()
+                                           : output_c.columnwise_cpu_scale_inv_ptr();
+
+    compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
+                                 unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
 
     if (processing_method == ProcessingMethod::CAST_DBIAS ||
         processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
         auto [atol_dbias, rtol_dbias] = getTolerances(itype);
@@ -294,10 +311,22 @@ void performTest_x2(const ProcessingMethod processing_method,
     DType itype = TypeInfo<InputType>::dtype;
     DType otype = TypeInfo<OutputType>::dtype;
 
-    const size_t blocks_Y = (rows + block_size_rows - 1) / block_size_rows;
-    const size_t blocks_X = (cols + block_size_cols - 1) / block_size_cols;
-    const size_t blocks_num_rowwise = rows * blocks_X;
-    const size_t blocks_num_colwise = blocks_Y * cols;
+    const size_t unpadded_blocks_Y_rowwise = rows;
+    const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols);
+    const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows);
+    const size_t unpadded_blocks_X_colwise = cols;
+
+    const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise,
+                                                                 scale_tensor_alignment_Y_rowwise);
+    const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise,
+                                                                 scale_tensor_alignment_X_rowwise);
+    const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise,
+                                                                 scale_tensor_alignment_Y_colwise);
+    const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise,
+                                                                 scale_tensor_alignment_X_colwise);
+
+    const size_t scales_stride_rowwise = blocks_X_rowwise;
+    const size_t scales_stride_colwise = blocks_X_colwise;
 
     Tensor input({ rows, cols }, itype);
     Tensor act_input({ rows, cols }, itype);
@@ -306,8 +335,8 @@ void performTest_x2(const ProcessingMethod processing_method,
     std::unique_ptr<OutputType[]> ref_output_c_rowwise = std::make_unique<OutputType[]>(rows * cols);
     std::unique_ptr<OutputType[]> ref_output_c_colwise = std::make_unique<OutputType[]>(rows * cols);
-    std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(rows * blocks_X);
-    std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y * cols);
+    std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_Y_rowwise * blocks_X_rowwise);
+    std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y_colwise * blocks_X_colwise);
     std::unique_ptr<InputType[]> ref_output_dbias = std::make_unique<InputType[]>(cols);
 
     fillCase(&input, fill_case);
@@ -376,15 +405,19 @@ void performTest_x2(const ProcessingMethod processing_method,
                     rows,
                     cols,
                     block_size_rows,
-                    block_size_cols);
+                    block_size_cols,
+                    scales_stride_rowwise,
+                    scales_stride_colwise);
 
     auto [atol, rtol] = getTolerances(otype);
     compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
     compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
     compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(),
-                                 ref_scales_rowwise.get(), blocks_num_rowwise);
+                                 ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
+                                 unpadded_blocks_X_rowwise, scales_stride_rowwise);
     compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(),
-                                 ref_scales_colwise.get(), blocks_num_colwise);
+                                 ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
+                                 unpadded_blocks_X_colwise, scales_stride_colwise);
 
     if (processing_method == ProcessingMethod::CAST_DBIAS ||
         processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
         auto [atol_dbias, rtol_dbias] = getTolerances(itype);
@@ -397,12 +430,16 @@ void performTest_x2(const ProcessingMethod processing_method,
 }
 
 std::vector<std::pair<size_t, size_t>> matrix_sizes = {
+    {1, 16},
+    {16, 48},
+    {65, 96},
     {128, 128},
     {256, 256},
+    {993, 512},
     {768, 1024},
-    // {256, 65536},
     // {2048, 12288},
     // {65536, 128},
+    // {16384, 1632},
     // {16384, 6144},
 };
 
diff --git a/tests/cpp/operator/test_dequantize_mxfp8.cu b/tests/cpp/operator/test_dequantize_mxfp8.cu
index 6b09c50366..c24a739b81 100644
--- a/tests/cpp/operator/test_dequantize_mxfp8.cu
+++ b/tests/cpp/operator/test_dequantize_mxfp8.cu
@@ -58,7 +58,8 @@ void compute_ref_x1(const InputType* input,
                     const size_t rows,
                     const size_t cols,
                     const size_t block_size_Y,
-                    const size_t block_size_X)
+                    const size_t block_size_X,
+                    const size_t scales_stride)
 {
     const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y;
     const size_t blocks_X = (cols + block_size_X - 1) / block_size_X;
@@ -69,7 +70,7 @@ void compute_ref_x1(const InputType* input,
         for (size_t jj = 0; jj < blocks_X; ++jj) {
             const size_t j_min = jj * block_size_X;
             const size_t j_max = std::min((jj + 1) * block_size_X, cols);
-            const size_t scale_idx = ii * blocks_X + jj;
+            const size_t scale_idx = ii * scales_stride + jj;
             dequantize_block(input, output, scales, scale_idx, i_min, i_max, j_min, j_max, cols);
         }
     }
@@ -85,10 +86,12 @@ void compute_ref_x2(const InputType* input,
                     const size_t rows,
                     const size_t cols,
                     const size_t block_size_Y,
-                    const size_t block_size_X)
+                    const size_t block_size_X,
+                    const size_t scales_stride_rowwise,
+                    const size_t scales_stride_colwise)
 {
-    compute_ref_x1(input, output_rowwise, scales_rowwise, rows, cols, 1, block_size_X);
-    compute_ref_x1(input, output_colwise, scales_colwise, rows, cols, block_size_Y, 1);
+    compute_ref_x1(input, output_rowwise, scales_rowwise, rows, cols, 1, block_size_X, scales_stride_rowwise);
+    compute_ref_x1(input, output_colwise, scales_colwise, rows, cols, block_size_Y, 1, scales_stride_colwise);
 }
 
 void generate_scales(fp8e8m0 * const scales_ref,
@@ -170,9 +173,26 @@ void performTest_x1(const size_t rows,
     const size_t block_size_rows = rowwise ? 1 : 32;
     const size_t block_size_cols = colwise ? 1 : 32;
 
-    const size_t blocks_Y = (rows + block_size_rows - 1) / block_size_rows;
-    const size_t blocks_X = (cols + block_size_cols - 1) / block_size_cols;
-    const size_t blocks_num = blocks_Y * blocks_X;
+
+    const size_t unpadded_blocks_Y_rowwise = rows;
+    const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols);
+    const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows);
+    const size_t unpadded_blocks_X_colwise = cols;
+
+    const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise,
+                                                                 scale_tensor_alignment_Y_rowwise);
+    const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise,
+                                                                 scale_tensor_alignment_X_rowwise);
+    const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise,
+                                                                 scale_tensor_alignment_Y_colwise);
+    const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise,
+                                                                 scale_tensor_alignment_X_colwise);
+
+    const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise;
+    const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise;
+
+    const size_t blocks_num = rowwise ? blocks_num_rowwise : blocks_num_colwise;
+    const size_t scales_stride = rowwise ? blocks_X_rowwise : blocks_X_colwise;
 
     Tensor input({ rows, cols }, itype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);
 
@@ -183,7 +203,7 @@ void performTest_x1(const size_t rows,
     std::unique_ptr<fp8e8m0[]> scales = std::make_unique<fp8e8m0[]>(blocks_num);
 
     fill_tensor_data(input, scales.get(), scales.get(), rowwise, colwise, rows, cols,
-                     blocks_num, blocks_num);
+                     blocks_num_rowwise, blocks_num_colwise);
 
     nvte_dequantize(input.data(), output.data(), 0);
 
@@ -201,7 +221,8 @@ void performTest_x1(const size_t rows,
                     rows,
                     cols,
                     block_size_rows,
-                    block_size_cols);
+                    block_size_cols,
+                    scales_stride);
 
     auto [atol, rtol] = getTolerances(otype);
     compareResults("output", output, ref_output.get(), true, atol, rtol);
@@ -273,18 +294,32 @@ void performTest_x2(const size_t rows,
     DType itype = TypeInfo<InputType>::dtype;
     DType otype = TypeInfo<OutputType>::dtype;
 
-    const size_t blocks_Y = (rows + block_size_rows - 1) / block_size_rows;
-    const size_t blocks_X = (cols + block_size_cols - 1) / block_size_cols;
-    const size_t blocks_num_rowwise = rows * blocks_X;
-    const size_t blocks_num_colwise = blocks_Y * cols;
+    const size_t unpadded_blocks_Y_rowwise = rows;
+    const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols);
+    const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows);
+    const size_t unpadded_blocks_X_colwise = cols;
+
+    const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise,
+                                                                 scale_tensor_alignment_Y_rowwise);
+    const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise,
+                                                                 scale_tensor_alignment_X_rowwise);
+    const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise,
+                                                                 scale_tensor_alignment_Y_colwise);
+    const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise,
+                                                                 scale_tensor_alignment_X_colwise);
+
+    const size_t scales_stride_rowwise = blocks_X_rowwise;
+    const size_t scales_stride_colwise = blocks_X_colwise;
+    const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise;
+    const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise;
 
     Tensor input({ rows, cols }, itype, true, true, NVTE_MXFP8_1D_SCALING);
     Tensor output({ rows, cols }, otype);
 
     std::unique_ptr<OutputType[]> ref_output_rowwise = std::make_unique<OutputType[]>(rows * cols);
     std::unique_ptr<OutputType[]> ref_output_colwise = std::make_unique<OutputType[]>(rows * cols);
-    std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(rows * blocks_X);
-    std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y * cols);
+    std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_num_rowwise);
+    std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_num_colwise);
 
     constexpr bool rowwise = true;
     constexpr bool colwise = true;
@@ -305,7 +340,9 @@ void performTest_x2(const size_t rows,
                     rows,
                     cols,
                     block_size_rows,
-                    block_size_cols);
+                    block_size_cols,
+                    scales_stride_rowwise,
+                    scales_stride_colwise);
 
     auto [atol, rtol] = getTolerances(otype);
     compareResults("output_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
@@ -313,14 +350,17 @@ void performTest_x2(const size_t rows,
 }
 
 std::vector<std::pair<size_t, size_t>> tensor_dims = {
+    {1, 16},
+    {16, 48},
+    {65, 96},
     {128, 128},
     {256, 256},
+    {993, 512},
     {768, 1024},
-    // {256, 65536},
     // {2048, 12288},
     // {65536, 128},
+    // {16384, 1632},
     // {16384, 6144},
-    // {2048, 16384},
 };
 
 std::vector<std::pair<size_t, size_t>> block_sizes = {
diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu
index 8238e9a1e6..a0b65318e4 100644
--- a/tests/cpp/test_common.cu
+++ b/tests/cpp/test_common.cu
@@ -586,8 +586,22 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
 }
 
 void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
-                                  size_t N) {
-    for (int i = 0; i < N; i++){
+                                  const size_t row_blocks, const size_t col_blocks, const size_t stride)
+{
+    for (int i = 0; i < row_blocks; ++i) {
+        for (int j = 0; j < col_blocks; ++j) {
+            const int idx = i * stride + j;
+            ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl
+                                                << "Mismatch: " << static_cast<int>(test[idx]) << " vs "
+                                                << static_cast<int>(ref[idx]) << " at index " << idx;
+        }
+    }
+}
+
+void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
+                                  const size_t N)
+{
+    for (int i = 0; i < N; i++) {
         ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl
                                         << "Mismatch: " << static_cast<int>(test[i]) << " vs "
                                         << static_cast<int>(ref[i]) << " at index " << i;
diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h
index 82ec1facd1..9ab59dfd96 100644
--- a/tests/cpp/test_common.h
+++ b/tests/cpp/test_common.h
@@ -274,6 +274,20 @@ class Tensor {
 constexpr uint32_t FP32_EXPONENT_BIAS = 127;
 constexpr uint32_t FP32_MANTISSA_BITS = 23;
 
+// [128,4] rowwise and [4,128] colwise alignment requirement
+constexpr size_t scale_tensor_alignment_X_rowwise = 4;
+constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
+constexpr size_t scale_tensor_alignment_X_colwise = 128;
+constexpr size_t scale_tensor_alignment_Y_colwise = 4;
+
+inline size_t divide_round_up(const size_t N, const size_t M) {
+    return (N - 1 + M) / M;
+}
+
+inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) {
+    return divide_round_up(N, M) * M;
+}
+
 template <typename T>
 struct Numeric_Traits {
     static constexpr double minSubnorm = 1.0;
@@ -403,7 +417,10 @@ void compareResults(const std::string &name, const float test, const float ref,
 void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
                     size_t N, float mismatch_rate_tol = 0.);
 void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
-                                  size_t N);
+                                  const size_t row_blocks, const size_t col_blocks, const size_t stride);
+void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
+                                  const size_t N);
+
 
 std::pair<double, double> getTolerances(const DType type);
 
diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu
index 6cd5abcceb..f9474363c7 100644
--- a/transformer_engine/common/common.cu
+++ b/transformer_engine/common/common.cu
@@ -97,7 +97,14 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
   const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
   void *dataPtr = reinterpret_cast<void *>(tensor.dptr);
-  NVTE_CHECK(isPointerAligned(dataPtr, 16), "Tensor data must be 16B aligned");
+
+  constexpr int TMA_gmem_alignment = 16;  // Alignment of the global memory address
+  NVTE_CHECK(isPointerAligned(dataPtr, TMA_gmem_alignment),
+             "Tensor data pointer must be 16B aligned");
+
+  const int TMA_needed_size = TMA_gmem_alignment / type_size;
+  NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size,
+             "-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX);
 
   // Create the tensor descriptor.
   NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index fa548f9a9e..f4999e8cdb 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -420,6 +420,12 @@ struct is_fp8<fp8e4m3> : std::true_type {};
 template <>
 struct is_fp8<fp8e5m2> : std::true_type {};
 
+// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors
+constexpr size_t scale_tensor_alignment_X_rowwise = 4;
+constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
+constexpr size_t scale_tensor_alignment_X_colwise = 128;
+constexpr size_t scale_tensor_alignment_Y_colwise = 4;
+
 size_t typeToSize(const DType type);
 
 void CheckNoopTensor(const Tensor &t, const std::string &name);
diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh
index afef29340f..d713738a4e 100644
--- a/transformer_engine/common/util/cast_kernels.cuh
+++ b/transformer_engine/common/util/cast_kernels.cuh
@@ -261,6 +261,9 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
       const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y;
       const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
 
+      const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
+      const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X;
+
       const int scales_rowwise_chunk_offset_Y =
           scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y;
       const int scales_rowwise_chunk_offset_X =
@@ -289,6 +292,8 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
       for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) {
         const int buff = iter % MXFP8_BUFFERS_NUM;
         const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM;
+        const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y;
+
         if (next_iter < MXFP8_ITERATIONS) {
           const int next_buff = next_iter % MXFP8_BUFFERS_NUM;
           const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y;
@@ -321,6 +326,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
             const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE;
             const int shmem_offset_y = thread_offset_Y + stage_offset_Y;
             const int shmem_offset_x = thread_offset_X_rowwise;
+
+            const size_t row = row_base + shmem_offset_y;
+            const bool row_out_of_bounds = (row >= rows);
+
             in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]);
             if constexpr (IS_DACT) {
               act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]);
@@ -331,6 +340,9 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
 
             #pragma unroll
             for (int j = 0; j < ELEMS_PER_THREAD; ++j) {
+              const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols);
+              const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
+
               float elt = static_cast<float>(in.data.elt[j]);
               if constexpr (IS_ACT || IS_DACT) {
                 elt = OP(elt, {});
@@ -339,10 +351,14 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
                 elt *= static_cast<float>(act_in.data.elt[j]);
               }
               if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) {
-                partial_dbias_rowwise[chunk_X].data.elt[j] += elt;
+                if (!out_of_bounds) {
+                  partial_dbias_rowwise[chunk_X].data.elt[j] += elt;
+                }
               }
               in_compute[j] = elt;
-              thread_amax = fmaxf(thread_amax, fabsf(elt));
+              if (!out_of_bounds) {
+                thread_amax = fmaxf(thread_amax, fabsf(elt));
+              }
             }
 
             __builtin_assume(block_amax >= 0);
@@ -375,11 +391,16 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
         }
 
         if constexpr (USE_COLWISE_SCALING) {
+          const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols);
           float in_compute[SCALE_DIM_Y];
 
           float amax = 0;
           #pragma unroll
           for (int i = 0; i < SCALE_DIM_Y; ++i) {
+            const size_t row = row_base + i;
+            const bool row_out_of_bounds = (row >= rows);
+            const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds);
+
             float elt = static_cast<float>(in_sh[buff][i][tid_colwise_X]);
             if constexpr (IS_ACT || IS_DACT) {
               elt = OP(elt, {});
@@ -388,10 +409,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
              elt *= static_cast<float>(act_in_sh[buff][i][tid_colwise_X]);
             }
             if constexpr (IS_DBIAS) {
-              partial_dbias_colwise[chunk_X] += elt;
+              if (!out_of_bounds) {
+                partial_dbias_colwise[chunk_X] += elt;
+              }
             }
             in_compute[i] = elt;
-            if (isfinite(elt)) {
+            if (!out_of_bounds) {
               amax = fmaxf(amax, fabsf(elt));
             }
           }
@@ -468,6 +491,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
       #pragma unroll
       for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) {
         Vec<float, ELEMS_PER_THREAD> other_row_dbias;
+        const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X;
+        const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X;
+
+        const int left_bound = dbias_rowwise_offset_X;
+        const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1;
+
         #pragma unroll
         for (int i = 0; i < Y; ++i) {
           other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]);
@@ -476,9 +505,16 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
            partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j];
           }
         }
-        const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X;
-        const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X;
-        partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]);
+
+        // Vectorized store when all elements are inside the boundaries
+        if (right_bound < cols) {
+          partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]);
+        } else if (left_bound < cols && right_bound >= cols) {
+          // Element-by-element store when some elements cross the boundaries
+          const int in_bound_elts_count = cols - left_bound;
+          partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0,
+                                                 in_bound_elts_count);
+        }
       }
     }
   } else {
@@ -486,7 +522,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
     for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) {
       const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X;
       const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X;
-      dbias_workspace[dbias_offset] = partial_dbias_colwise[i];
+      const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols);
+      if (!col_out_of_bounds) {
+        dbias_workspace[dbias_offset] = partial_dbias_colwise[i];
+      }
     }
   }
 }
@@ -908,10 +947,6 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
   const size_t dbias_rows = blocks_Y;
   const size_t dbias_cols = cols;
 
-  const int TMA_needed_size = 16 / typeToSize(output->data.dtype);
-  NVTE_CHECK(cols % TMA_needed_size == 0, "Shape not supported. Expected multiple of " +
-                                          std::to_string(TMA_needed_size) + ", got " +
-                                          std::to_string(cols));
 
   NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
   NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
@@ -949,8 +984,10 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
   if constexpr (IS_DACT) {
     create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y,
                          FP8_SHMEM_DIM_X, sizeof(IType));
-  } create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
-                         FP8_SHMEM_DIM_X, sizeof(OType));
+  }
+
+  create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
+                       FP8_SHMEM_DIM_X, sizeof(OType));
 
   cast_fp8_2D_kernel
       <<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
@@ -995,11 +1032,27 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
   const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X);
   const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y);
   const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X);
-  const size_t scale_stride_rowwise = DIVUP(cols, scale_dim_X_rowwise);
-  const size_t scale_stride_colwise = cols;
-  const bool isFullTile = (rows % MXFP8_CHUNK_DIM_Y == 0) && (cols % MXFP8_CHUNK_DIM_X == 0);
-  NVTE_CHECK(isFullTile, "Only full tiles are supported.");
+  const size_t unpadded_scales_Y_rowwise = rows;
+  const size_t unpadded_scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise);
+  const size_t unpadded_scales_Y_colwise = DIVUP(rows, scale_dim_Y_colwise);
+  const size_t unpadded_scales_X_colwise = cols;
+
+  const size_t scales_Y_rowwise =
+      DIVUP(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise) *
+      scale_tensor_alignment_Y_rowwise;
+  const size_t scales_X_rowwise =
+      DIVUP(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise) *
+      scale_tensor_alignment_X_rowwise;
+  const size_t scales_Y_colwise =
+      DIVUP(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise) *
+      scale_tensor_alignment_Y_colwise;
+  const size_t scales_X_colwise =
+      DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) *
+      scale_tensor_alignment_X_colwise;
+
+  const size_t scale_stride_rowwise = scales_X_rowwise;
+  const size_t scale_stride_colwise = scales_X_colwise;
 
   e8m0_t *const scales_rowwise_ptr = use_rowwise_scaling
                                          ? reinterpret_cast<e8m0_t *>(output->scale_inv.dptr)
                                          : nullptr;
diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh
index afffd290e5..59251f1e61 100644
--- a/transformer_engine/common/util/dequantize_kernels.cuh
+++ b/transformer_engine/common/util/dequantize_kernels.cuh
@@ -279,13 +279,29 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
   const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y);
   const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X);
 
-  NVTE_CHECK(cols % 32 == 0, "Tensor column dimension must be a multiple of 32.");
+  const size_t unpadded_scales_Y_rowwise = rows;
+  const size_t unpadded_scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise);
+  const size_t unpadded_scales_Y_colwise = DIVUP(rows, scale_dim_Y_colwise);
+  const size_t unpadded_scales_X_colwise = cols;
+
+  const size_t scales_Y_rowwise =
+      DIVUP(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise) *
+      scale_tensor_alignment_Y_rowwise;
+  const size_t scales_X_rowwise =
+      DIVUP(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise) *
+      scale_tensor_alignment_X_rowwise;
+  const size_t scales_Y_colwise =
+      DIVUP(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise) *
+      scale_tensor_alignment_Y_colwise;
+  const size_t scales_X_colwise =
+      DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) *
+      scale_tensor_alignment_X_colwise;
 
   const e8m0_t *const scales_ptr = use_rowwise_scaling
                                        ? reinterpret_cast<e8m0_t *>(input.scale_inv.dptr)
                                        : reinterpret_cast<e8m0_t *>(input.columnwise_scale_inv.dptr);
 
-  const size_t scales_stride = use_rowwise_scaling ? DIVUP(cols, scale_dim_X_rowwise) : cols;
+  const size_t scales_stride = use_rowwise_scaling ? scales_X_rowwise : scales_X_colwise;
 
   const SimpleTensor &input_data = use_rowwise_scaling ? input.data : input.columnwise_data;
 
diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh
index e1605e1f9e..63ce369892 100644
--- a/transformer_engine/common/utils.cuh
+++ b/transformer_engine/common/utils.cuh
@@ -965,7 +965,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
 }
 
 __device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
-  return exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
+  return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(biased_exp));
 }
 
 }  // namespace transformer_engine