From 536e03ed6ad288cb8477b01c7f1fd6752ac01743 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Sat, 18 May 2024 02:51:46 +0800 Subject: [PATCH 01/34] save --- .../group/gemm/impl/int4_dequantize_xe.hpp | 245 ++++++++++++------ include/subgroup/tile/impl/load_xe.hpp | 109 ++++---- include/subgroup/tile/impl/payload_xe.hpp | 3 +- .../int4_dequantization_bias/main_client.cpp | 170 ++++++++---- 4 files changed, 340 insertions(+), 187 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 528911fcf..cdeb445e5 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -115,7 +115,6 @@ class gemm_t< is_col_major_a ? tdesc_update_dir::y_dir : tdesc_update_dir::x_dir; static constexpr tdesc_update_dir update_dir_b = is_col_major_b ? tdesc_update_dir::x_dir : tdesc_update_dir::y_dir; - static_assert(!is_col_major_b, "only support MatB row-major for now"); static_assert( (!is_local_a) && (!is_local_b), "only support from global memory for now"); @@ -176,12 +175,20 @@ class gemm_t< // note: plane format, row-major // note: 4bit x 2, row-major - using matB_tile_desc_t = subgroup::tile_desc_t< - tile_size_x_b / pack_ratio, - tile_size_y_b, - block_size_x_b / pack_ratio, - block_size_y_b, - reg_layout::tiled>; + using matB_tile_desc_t = std::conditional_t< + is_col_major_b, + subgroup::tile_desc_t< + tile_size_x_b, + tile_size_y_b / pack_ratio, + block_size_x_b, + block_size_y_b / pack_ratio, + reg_layout::tiled>, + subgroup::tile_desc_t< + tile_size_x_b / pack_ratio, + tile_size_y_b, + block_size_x_b / pack_ratio, + block_size_y_b, + reg_layout::tiled>>; using matB_t = subgroup::tile_t; using matB_payload_t = subgroup::mem_payload_t< mem_desc_b_t, @@ -254,12 +261,22 @@ class gemm_t< scale_tile_desc_t, subgroup::msg_type_v, arch_tag>; - using zero_pt_tile_desc_t = subgroup::tile_desc_t< - tile_size_x_b / pack_ratio, - tile_size_y_zero_pt, - block_size_x_b / pack_ratio, - block_size_y_zero_pt, - reg_layout::tiled>; + + using zero_pt_tile_desc_t = std::conditional_t< + is_col_major_b, + subgroup::tile_desc_t< + tile_size_x_b, + (tile_size_y_zero_pt + pack_ratio - 1) / pack_ratio, + block_size_x_b, + (block_size_y_zero_pt + pack_ratio - 1) / pack_ratio, + reg_layout::tiled>, + subgroup::tile_desc_t< + tile_size_x_b / pack_ratio, + tile_size_y_zero_pt, + block_size_x_b / pack_ratio, + block_size_y_zero_pt, + reg_layout::tiled>>; + using zero_pt_t = subgroup::tile_t; using zero_pt_payload_t = subgroup::mem_payload_t< mem_desc_zero_pt_t, @@ -495,6 +512,7 @@ class gemm_t< matB, matB_payload); subgroup::tile_load( scale, scale_payload); + dump_mat(scale); if constexpr ( compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { subgroup::tile_load( @@ -548,6 +566,9 @@ class gemm_t< } subgroup::elemwise_cvt(matA_acc, matA); dequantize(matB_acc, matB, scale, zero_pt); + if constexpr (is_col_major_b) { + tile_transpose(matB_acc); + } SW_BARRIER(); tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); SW_BARRIER(); @@ -572,12 +593,11 @@ class gemm_t< matB_acc_t& matB_acc, matB_t& matB, scale_t& scale, - zero_pt_t& zero_pt) { + [[maybe_unused]] zero_pt_t& zero_pt) { // no tail, because this is matB constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; - constexpr uint32_t block_b_y_per_scale = dequant_s / block_size_y_b; #pragma unroll for (uint32_t i = 0; i < 
num_block_y; ++i) { #pragma unroll @@ -587,39 +607,12 @@ class gemm_t< .xetla_select( block_id * matB_t::block_elems) .xetla_format(); - int scale_block_id = (i / block_b_y_per_scale * num_block_x + j); - auto scale_vec = scale.reg.xetla_select( - scale_block_id * scale_t::block_size_x); auto dst_blk = matB_acc.reg.xetla_select( block_id * matB_acc_t::block_elems); // 2: int8 includes 2 4bits data. - xetla_vector cvt_blk; - xetla_vector cvt_blk_i32; - if constexpr (compute_policy::quant_type == quant_mode::S4_ASYM) { - auto zero_pt_vec = zero_pt.reg - .xetla_select( - scale_block_id * zero_pt_t::block_size_x) - .xetla_format(); - cvt_blk.xetla_select(0) = matB_blk & 0x0f; - cvt_blk.xetla_select(1) = matB_blk >> 4; - xetla_vector zero_pt_sub; - zero_pt_sub.xetla_select(0) = - zero_pt_vec & 0x0f; - zero_pt_sub.xetla_select(1) = zero_pt_vec >> 4; - xetla_vector zero_pt_blk; -#pragma unroll - for (uint32_t row = 0; row < block_size_y_b; row++) { - zero_pt_blk.xetla_select(row * block_size_x_b) - .xetla_format() = - zero_pt_sub.xetla_format() + int8_t(1); - } - cvt_blk_i32 = - (cvt_blk.xetla_format() - - zero_pt_blk.xetla_format()); - } if constexpr ( compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { xetla_vector cvt_blk_i8; @@ -632,50 +625,130 @@ class gemm_t< matB_blk.xetla_format() >> 4; cvt_blk_i32 = (cvt_blk_i8.xetla_format()); } - if constexpr (compute_policy::mma_engine == mma_engine::xmx) { - constexpr uint32_t vnni_rows = sizeof(uint32_t) / sizeof(dtype_mma_b); - xetla_vector - temp_blk; - temp_blk.xetla_select(0) = - cvt_blk_i32; #pragma unroll - for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { -#pragma unroll - for (uint32_t row = 0; row < vnni_rows; row++) { - temp_blk.xetla_select( - row + block_size_x_b * k * vnni_rows) = - temp_blk.xetla_select( - (k + row) * block_size_x_b * vnni_rows); - } - } - - xetla_vector scale_blk; -#pragma unroll - for (uint32_t row = 0; row < vnni_rows; row++) { - scale_blk.xetla_select(row) = scale_vec; - } - -#pragma unroll - for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { - dst_blk.xetla_select( - k * block_size_x_b) = - temp_blk.xetla_select( - k * block_size_x_b * vnni_rows) * - scale_blk; - } - } else { -#pragma unroll - for (uint32_t k = 0; k < block_size_y_b; k++) { - dst_blk.xetla_select(k * block_size_x_b) = - cvt_blk_i32.xetla_select( - k * block_size_x_b) * - scale_vec; - } + for (uint32_t k = 0; k < matB_acc_t::block_elems; k += dequant_s) { + dst_blk.xetla_select(k) = + cvt_blk_i32.xetla_select(k) * + scale.reg.xetla_select<1, 1>(k / dequant_s); } } } } + // inline void dequantize( + // matB_acc_t& matB_acc, + // matB_t& matB, + // scale_t& scale, + // zero_pt_t& zero_pt) { + // // no tail, because this is matB + // constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; + // constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; + + // constexpr uint32_t block_b_y_per_scale = dequant_s / block_size_y_b; + // constexpr uint32_t block_b_x_per_scale = dequant_s / block_size_x_b; + // #pragma unroll + // for (uint32_t i = 0; i < num_block_y; ++i) { + // #pragma unroll + // for (uint32_t j = 0; j < num_block_x; ++j) { + // int block_id = (i * num_block_x + j); + // auto matB_blk = matB.reg + // .xetla_select( + // block_id * matB_t::block_elems) + // .xetla_format(); + // int scale_block_id = (i / block_b_y_per_scale * num_block_x + j); + // auto scale_vec = scale.reg.xetla_select( + // scale_block_id * scale_t::block_size_x); + // auto dst_blk = matB_acc.reg.xetla_select( + // block_id * 
matB_acc_t::block_elems); + + // // 2: int8 includes 2 4bits data. + // xetla_vector cvt_blk; + + // xetla_vector cvt_blk_i32; + // if constexpr (compute_policy::quant_type == quant_mode::S4_ASYM) { + // auto zero_pt_vec = zero_pt.reg + // .xetla_select( + // scale_block_id * + // zero_pt_t::block_size_x) + // .xetla_format(); + // cvt_blk.xetla_select(0) = matB_blk & + // 0x0f; cvt_blk.xetla_select(1) = matB_blk + // >> 4; xetla_vector zero_pt_sub; + // zero_pt_sub.xetla_select(0) = + // zero_pt_vec & 0x0f; + // zero_pt_sub.xetla_select(1) = zero_pt_vec + // >> 4; xetla_vector + // zero_pt_blk; + // #pragma unroll + // for (uint32_t row = 0; row < block_size_y_b; row++) { + // zero_pt_blk.xetla_select(row * + // block_size_x_b) + // .xetla_format() = + // zero_pt_sub.xetla_format() + int8_t(1); + // } + // cvt_blk_i32 = + // (cvt_blk.xetla_format() - + // zero_pt_blk.xetla_format()); + // } + // if constexpr ( + // compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + // xetla_vector cvt_blk_i8; + // cvt_blk_i8.xetla_select(0) = matB_blk & + // 0x0f; cvt_blk_i8.xetla_select(0) = + // cvt_blk_i8.xetla_select(0) << 4; + // cvt_blk_i8.xetla_select(0) = + // cvt_blk_i8.xetla_select(0) >> 4; + // cvt_blk_i8.xetla_select(1) = + // matB_blk.xetla_format() >> 4; + // cvt_blk_i32 = (cvt_blk_i8.xetla_format()); + // } + // if constexpr (compute_policy::mma_engine == mma_engine::xmx) { + // constexpr uint32_t vnni_rows = sizeof(uint32_t) / + // sizeof(dtype_mma_b); xetla_vector + // temp_blk; + // temp_blk.xetla_select(0) = + // cvt_blk_i32; + + // #pragma unroll + // for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { + // #pragma unroll + // for (uint32_t row = 0; row < vnni_rows; row++) { + // temp_blk.xetla_select( + // row + block_size_x_b * k * vnni_rows) = + // temp_blk.xetla_select( + // (k + row) * block_size_x_b * vnni_rows); + // } + // } + + // xetla_vector scale_blk; + // #pragma unroll + // for (uint32_t row = 0; row < vnni_rows; row++) { + // scale_blk.xetla_select(row) = + // scale_vec; + // } + + // #pragma unroll + // for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { + // dst_blk.xetla_select( + // k * block_size_x_b) = + // temp_blk.xetla_select( + // k * block_size_x_b * vnni_rows) * + // scale_blk; + // } + // } else { + // #pragma unroll + // for (uint32_t k = 0; k < block_size_y_b; k++) { + // dst_blk.xetla_select(k * block_size_x_b) = + // cvt_blk_i32.xetla_select( + // k * block_size_x_b) * + // scale_vec; + // } + // } + // } + // } + // } /// @brief Updates tile base descriptor based on the tid. 
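// A standalone scalar sketch of the S4_FULLRANGE_NO_ZP dequantization
// above (illustration only; plain float stands in for the fp16 scale
// type, and dequant_s4_pair is a hypothetical helper, not a XeTLA API).
// Each int4x2 byte holds two signed 4-bit values; the low nibble is
// sign-extended with the same shift-left-then-arithmetic-shift-right
// trick applied to cvt_blk_i8 above.
#include <cstdint>
#include <utility>
inline std::pair<float, float> dequant_s4_pair(uint8_t packed, float scale) {
  int8_t lo = static_cast<int8_t>(packed << 4) >> 4; // bits [3:0], sign-extended
  int8_t hi = static_cast<int8_t>(packed) >> 4;      // bits [7:4], sign-extended
  return {lo * scale, hi * scale};
}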
__XETLA_API static void update_sg_tile_tdesc( arguments_t& args, @@ -685,9 +758,15 @@ class gemm_t< int32_t tile_offset_m = sg_idy * sg_tile_m; args.matA_base_desc.update_coord_y(tile_offset_m); - args.matB_base_desc.update_coord_x(tile_offset_n / pack_ratio); - args.scale_base_desc.update_coord_x(tile_offset_n); - args.zero_pt_base_desc.update_coord_x(tile_offset_n / pack_ratio); + if constexpr (is_col_major_b) { + args.matB_base_desc.update_coord_x(tile_offset_n); + args.scale_base_desc.update_coord_x(tile_offset_n); + args.zero_pt_base_desc.update_coord_x(tile_offset_n); + } else { + args.matB_base_desc.update_coord_x(tile_offset_n / pack_ratio); + args.scale_base_desc.update_coord_x(tile_offset_n); + args.zero_pt_base_desc.update_coord_x(tile_offset_n / pack_ratio); + } } }; /// @} xetla_gemm diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 216a57d96..71ba2cb5a 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -469,63 +469,66 @@ tile_load(tile_t& tile, payload_t& payload) { uint32_t offset_x = j * tile_desc::block_size_x; auto reg_sub = tile.reg.xetla_select( (i * tile_desc::num_block_x + j) * tile_desc::block_elems); -#pragma unroll - for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; - sub_block_y += num_channel) { - xetla_vector reg_tmp = 0; - uint32_t address_offset = payload_t::mem_transpose - ? offset_x * payload.pitch_in_bytes + (offset_y + 0) * sizeof(dtype) - : offset_x * sizeof(dtype) + - (offset_y + 0) * payload.pitch_in_bytes; - - const uint32_t sub_block_offset_x = payload.base_x + offset_x + 0; - const uint32_t sub_block_offset_y = - payload.base_y + offset_y + sub_block_y; - const auto offset_ch_dim = - payload_t::trans ? sub_block_offset_x : sub_block_offset_y; - const auto size_ch_dim = - payload_t::trans ? payload.width_in_elems : payload.height_in_elems; - - xetla_mask pred = offset_ch_dim + num_channel > size_ch_dim - ? (xetla_vector_gen(offset_ch_dim, 1) < - size_ch_dim) - : 1; + // #pragma unroll + // for (uint32_t sub_block_y = 0; sub_block_y < + // tile_desc::block_size_x; + // sub_block_y += num_channel) { + uint32_t sub_block_y = 0; + xetla_vector reg_tmp = 0; + uint32_t address_offset = payload_t::mem_transpose + ? offset_x * payload.pitch_in_bytes + (offset_y + 0) * sizeof(dtype) + : offset_x * sizeof(dtype) + (offset_y + 0) * payload.pitch_in_bytes; + + const uint32_t sub_block_offset_x = payload.base_x + offset_x + 0; + const uint32_t sub_block_offset_y = + payload.base_y + offset_y + sub_block_y; + const auto offset_ch_dim = + payload_t::trans ? sub_block_offset_x : sub_block_offset_y; + const auto size_ch_dim = + payload_t::trans ? payload.width_in_elems : payload.height_in_elems; + + xetla_mask pred = 1; + offset_ch_dim + num_channel > size_ch_dim + ? 
(xetla_vector_gen(offset_ch_dim, 1) < + size_ch_dim) + : 1; - reg_tmp = xetla_load_global< - load_dtype, - payload_t::simd_exec_size, - data_size::default_size, - L1, - L2, - payload_t::num_channel>( - payload.base_ptr, - payload.channel_offset + payload.base_offset + address_offset, - pred); - if constexpr (payload_t::simd_exec_size > 1) { - xetla_vector reg_tmp_trans; + reg_tmp = xetla_load_global< + load_dtype, + payload_t::simd_exec_size, + data_size::default_size, + L1, + L2, + payload_t::num_channel>( + payload.base_ptr, + payload.channel_offset + payload.base_offset + address_offset, + 1); + + if constexpr (payload_t::simd_exec_size > 1) { + xetla_vector reg_tmp_trans; #pragma unroll - for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { - if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix - reg_tmp_trans.xetla_select( - iii * payload_t::simd_exec_size) = - reg_tmp.xetla_select< - payload_t::simd_exec_size, - payload_t::num_channel>(iii); - else // TODO (dingyi): Delete after driver fix - reg_tmp_trans.xetla_select( - iii * payload_t::simd_exec_size) = 0; - } - reg_sub - .xetla_select( - sub_block_y * tile_desc::block_size_x) - .xetla_format() = reg_tmp_trans; - } else { - reg_sub - .xetla_select( - sub_block_y * tile_desc::block_size_x) - .xetla_format() = reg_tmp; + for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { + if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix + reg_tmp_trans.xetla_select( + iii * payload_t::simd_exec_size) = + reg_tmp.xetla_select< + payload_t::simd_exec_size, + payload_t::num_channel>(iii); + else // TODO (dingyi): Delete after driver fix + reg_tmp_trans.xetla_select( + iii * payload_t::simd_exec_size) = 0; } + reg_sub + .xetla_select( + sub_block_y * tile_desc::block_size_x) + .xetla_format() = reg_tmp_trans; + } else { + reg_sub + .xetla_select( + sub_block_y * tile_desc::block_size_x) + .xetla_format() = reg_tmp; } + // } } } diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index c895614e0..4a86dea1d 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1097,7 +1097,8 @@ struct mem_payload_t< static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = mem_transpose ^ reg_transpose; + static constexpr bool trans = + mem_transpose ^ reg_transpose && !std::is_same_v; static constexpr bool mem_transform = (sizeof(dtype) < 4) && (register_layout == reg_layout::vnni_tiled || diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp index a8e4da602..824641138 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp @@ -16,10 +16,10 @@ #include #include "xetla.hpp" -#define UT_DEBUG 1 +// #define UT_DEBUG using namespace gpu::xetla; // The number of times the kernel is executed -constexpr int ITER = 200; +constexpr int ITER = 1; enum optional_feature { NONE, ACT_SHUFFLE }; @@ -48,6 +48,30 @@ class act_shuf_feature_next_token { static constexpr size_t shuf_load_block = 16; }; +class test_col_major { + public: + // Extract the parameters required by different test cases + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 16; + static constexpr size_t mat_k = 16; + static constexpr size_t wg_m = 1; 
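// An aside on the boundary mask built in load_xe.hpp above: lane i of
// the gather covers channel offset_ch_dim + i, and lanes that fall past
// size_ch_dim are meant to be predicated off so the load never reads out
// of bounds. A host-side sketch under those assumptions
// (std::vector<bool> stands in for xetla_mask, and channel_pred is a
// hypothetical helper, not a XeTLA API):
#include <vector>
inline std::vector<bool> channel_pred(unsigned offset_ch_dim,
                                      unsigned num_channel,
                                      unsigned size_ch_dim) {
  std::vector<bool> pred(num_channel, true);
  if (offset_ch_dim + num_channel > size_ch_dim)
    for (unsigned i = 0; i < num_channel; ++i)
      pred[i] = offset_ch_dim + i < size_ch_dim;
  return pred;
}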
+ static constexpr size_t wg_n = 16 * 1; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 16; + static constexpr size_t sg_k = 16; + static constexpr size_t dequant_s = 16; + + static constexpr size_t local_kslicing = 1; + static constexpr size_t global_kslicing = 1; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; + using data_type_a = fp16; + using data_type_b = int4x2; + using data_type_c = fp16; +}; + class test1_xehpg { public: // Extract the parameters required by different test cases @@ -531,6 +555,72 @@ int gemm_result_validate( return result ? 0 : 1; } +template < + gpu::xetla::group::quant_mode quant_type = + gpu::xetla::group::S4_FULLRANGE_NO_ZP, + typename data_type_acc_in = fp16, + typename data_type_b, + typename data_type_scale, + typename data_type_zero_pt> +std::tuple convert_int4( + data_type_b int4_data, + data_type_scale scale, + [[maybe_unused]] data_type_zero_pt zero_pt) { + uint8_t data_even = (int4_data & 0x0f) << 4; + int8_t data_0; + int8_t data_1; + memcpy(&data_0, &data_even, 1); + memcpy(&data_1, &int4_data, 1); + data_0 = data_0 >> 4; + data_1 = data_1 >> 4; + return std::make_tuple(fp16(data_0) * scale, fp16(data_1) * scale); +} +template < + size_t dequant_s, + mem_layout layout_b = mem_layout::row_major, + gpu::xetla::group::quant_mode quant_type = + gpu::xetla::group::S4_FULLRANGE_NO_ZP, + typename data_type_acc_in = fp16, + typename data_type_b, + typename data_type_scale, + typename data_type_zero_pt> +std::vector dequantize_weight( + size_t matrix_k, + size_t matrix_n, + data_type_b* b, + data_type_scale* scale, + data_type_zero_pt* zero_pt) { + std::vector b_out(matrix_k * matrix_n, 0); + constexpr size_t pack_radio = 2 * sizeof(data_type_b); + size_t width = layout_b == mem_layout::row_major ? matrix_n / pack_radio + : matrix_k / pack_radio; + size_t height = layout_b == mem_layout::row_major ? matrix_k : matrix_n; + size_t step = layout_b == mem_layout::row_major ? 1 : dequant_s / pack_radio; + + for (uint32_t i = 0; i < height; i++) { + for (uint32_t j = 0; j < width; j += step) { + int start_b_in = i * width + j; + int start_zero_pt_in = start_b_in; + int start_scale_in = + layout_b == mem_layout::row_major ? 0 : start_b_in / step; + + int start_out = + layout_b == mem_layout::row_major ? 
0 : i * matrix_k + j * pack_radio; + + for (uint32_t jj = 0; jj < step; jj++) { + std::tie( + b_out[start_out + pack_radio * jj], + b_out[start_out + pack_radio * jj + 1]) = + convert_int4( + b[start_b_in + jj], + scale[start_scale_in], + zero_pt[start_zero_pt_in + jj]); + } + } + } + return b_out; +} + template void dequantize_gemm_run(int iter) { using namespace gpu; @@ -553,7 +643,7 @@ void dequantize_gemm_run(int iter) { using data_type_zero_pt = int4x2; using data_type_scale = fp16; using data_type_acc_in = fp16; - using data_type_acc = float; // modify + using data_type_acc = float; using data_type_bias = fp16; constexpr mem_layout layout_a = Test::layout_a; @@ -721,7 +811,7 @@ void dequantize_gemm_run(int iter) { for (unsigned i = 0; i < size_scale; ++i) { scale_h[i] = random_float(); #ifdef UT_DEBUG - scale_h[i] = 1.f; + scale_h[i] = i; #endif } @@ -831,7 +921,7 @@ void dequantize_gemm_run(int iter) { size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); - int constexpr warm = 100; + int constexpr warm = 0; try { for (int i = 0; i < iter + warm; i++) { if (i >= warm) @@ -919,11 +1009,11 @@ void dequantize_gemm_run(int iter) { epilogue_args); cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); - if (!gemm_op_t::can_implement(gemm_arg)) { - std::cout << "The arguments cannot be supported, aborting ... " - << std::endl; - FAIL(); - } + // if (!gemm_op_t::can_implement(gemm_arg)) { + // std::cout << "The arguments cannot be supported, aborting ... " + // << std::endl; + // FAIL(); + // } size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); @@ -954,29 +1044,9 @@ void dequantize_gemm_run(int iter) { // performance prof.print_profiling_result(profiling_selector::GPU); } - std::vector dequantize_b(matrix_k * matrix_n, 0); - for (uint32_t i = 0; i < matrix_k / dequant_s; i++) { - for (uint32_t j = 0; j < matrix_n / 2; j++) { - int start_in = i * dequant_s * matrix_n / 2 + j; - int start_out = i * dequant_s * matrix_n + j * 2; - int start_scale = i * size_scale_n + j * 2; - for (uint32_t ii = 0; ii < dequant_s; ii++) { - uint8_t data_in = B_h[start_in + ii * matrix_n / 2]; - uint8_t data_even = (data_in & 0x0f) << 4; - int8_t data_0; - int8_t data_1; - memcpy(&data_0, &data_even, 1); - memcpy(&data_1, &data_in, 1); - data_0 = data_0 >> 4; - data_1 = data_1 >> 4; - - dequantize_b[start_out + ii * matrix_n] = - fp16(data_0) * scale_h[start_scale]; - dequantize_b[start_out + ii * matrix_n + 1] = - fp16(data_1) * scale_h[start_scale + 1]; - } - } - } + std::vector dequantize_b = + dequantize_weight( + matrix_k, matrix_n, B_h, scale_h, zero_pt_h); queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait(); ASSERT_EQ( @@ -1017,28 +1087,28 @@ TYPED_TEST_P(dequantize_gemm_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_test, esimd); -using tests = ::testing::Types; +using tests = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P( dequantize_gemm_test_suite, dequantize_gemm_test, tests); -template -class dequantize_gemm_act_shuf_test : public ::testing::Test {}; -TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test); - -TYPED_TEST_P(dequantize_gemm_act_shuf_test, esimd) { - if constexpr (TypeParam::mat_m != 1) { - dequantize_gemm_run(ITER); - } else { - dequantize_gemm_run(ITER); - } -} +// template +// class dequantize_gemm_act_shuf_test : public ::testing::Test {}; +// 
TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test); -REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd); +// TYPED_TEST_P(dequantize_gemm_act_shuf_test, esimd) { +// if constexpr (TypeParam::mat_m != 1) { +// dequantize_gemm_run(ITER); +// } else { +// dequantize_gemm_run(ITER); +// } +// } -INSTANTIATE_TYPED_TEST_SUITE_P( - dequantize_gemm_act_shuf_test_suite, - dequantize_gemm_act_shuf_test, - tests); +// REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd); + +// INSTANTIATE_TYPED_TEST_SUITE_P( +// dequantize_gemm_act_shuf_test_suite, +// dequantize_gemm_act_shuf_test, +// tests); From 5949084e53cf080334e695c6245e006644cc5ce1 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Wed, 22 May 2024 04:04:22 +0800 Subject: [PATCH 02/34] save(some error with kslicing) --- .../group/gemm/impl/int4_dequantize_xe.hpp | 33 ++++++++++------ .../gemm/impl/int4_dequantize_kslicing_xe.hpp | 25 +++++++++--- .../int4_dequantization_bias/main_client.cpp | 39 +++++++++++-------- 3 files changed, 63 insertions(+), 34 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index cdeb445e5..18cd0e298 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -319,6 +319,8 @@ class gemm_t< /// @brief Is the memory description of matB, including base, shape and /// coordinate. mem_desc_b_t matB_base_desc; + /// @brief The tile starting from K-dim + uint32_t inner_loop_start; /// @brief Is the total inner loop count required to compute the entire /// K-dim. uint32_t inner_loop_count; @@ -335,11 +337,13 @@ class gemm_t< inline arguments_t( mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, + uint32_t loop_start, uint32_t loop_count, mem_desc_scale_t scale_desc, mem_desc_zero_pt_t zero_pt_desc) : matA_base_desc(matA_desc), matB_base_desc(matB_desc), + inner_loop_start(loop_start), inner_loop_count(loop_count), scale_base_desc(scale_desc), zero_pt_base_desc(zero_pt_desc) {} @@ -347,10 +351,12 @@ class gemm_t< inline arguments_t( mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, + uint32_t loop_start, uint32_t loop_count, mem_desc_scale_t scale_desc) : matA_base_desc(matA_desc), matB_base_desc(matB_desc), + inner_loop_start(loop_start), inner_loop_count(loop_count), scale_base_desc(scale_desc) {} // Be aware of the risks: Rule of three (copy constructor, copy assignment, @@ -359,12 +365,14 @@ class gemm_t< inline arguments_t(const arguments_t& args) : matA_base_desc(args.matA_base_desc), matB_base_desc(args.matB_base_desc), + inner_loop_start(args.inner_loop_start), inner_loop_count(args.inner_loop_count), scale_base_desc(args.scale_base_desc), zero_pt_base_desc(args.zero_pt_base_desc) {} inline arguments_t& operator=(const arguments_t& args) { this->matA_base_desc = args.matA_base_desc; this->matB_base_desc = args.matB_base_desc; + this->inner_loop_start = args.inner_loop_start; this->inner_loop_count = args.inner_loop_count; this->scale_base_desc = args.scale_base_desc; this->zero_pt_base_desc = args.zero_pt_base_desc; @@ -373,11 +381,13 @@ class gemm_t< inline void init( mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, + uint32_t loop_start, uint32_t loop_count, mem_desc_scale_t scale_desc, mem_desc_zero_pt_t zero_pt_desc) { matA_base_desc = matA_desc; matB_base_desc = matB_desc; + inner_loop_start = loop_start; inner_loop_count = loop_count; scale_base_desc = scale_desc; zero_pt_base_desc = zero_pt_desc; @@ -462,8 +472,8 
@@ class gemm_t< sg_idx + barrier_count_y + nbarrier_base, nbarrier_role::producer_consumer); - int scale_prefetch_addr_i = args.matB_base_desc.coord.y; - int scale_load_addr_i = args.matB_base_desc.coord.y; + int scale_prefetch_addr_i = args.inner_loop_start; + int scale_load_addr_i = args.inner_loop_start; SW_BARRIER(); #pragma unroll for (uint32_t i = 0; i < stages; i++) { @@ -480,7 +490,7 @@ class gemm_t< subgroup::tile_prefetch( zero_pt_prefetch_payload); } - scale_prefetch_addr_i += dequant_s; + scale_prefetch_addr_i += k_stride; matA_prefetch_payload.template update_tdesc( matA_t::tile_size_x); matB_prefetch_payload.template update_tdesc( @@ -512,13 +522,12 @@ class gemm_t< matB, matB_payload); subgroup::tile_load( scale, scale_payload); - dump_mat(scale); if constexpr ( compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { subgroup::tile_load( zero_pt, zero_pt_payload); } - scale_load_addr_i += matB_t::tile_size_y; + scale_load_addr_i+= k_stride; SW_BARRIER(); if constexpr (stages != 0) { subgroup::tile_prefetch( @@ -534,7 +543,7 @@ class gemm_t< subgroup::tile_prefetch( zero_pt_prefetch_payload); } - scale_prefetch_addr_i += dequant_s; + scale_prefetch_addr_i += k_stride; } SW_BARRIER(); matA_payload.template update_tdesc(matA_t::tile_size_x); @@ -542,7 +551,7 @@ class gemm_t< if ((scale_load_addr_i % dequant_s) == 0) { scale_payload.template update_tdesc( scale_t::tile_size_y); - zero_pt_payload.template update_tdesc( + zero_pt_payload.template update_tdesc( zero_pt_t::tile_size_y); } if constexpr (stages != 0) { @@ -625,12 +634,12 @@ class gemm_t< matB_blk.xetla_format() >> 4; cvt_blk_i32 = (cvt_blk_i8.xetla_format()); } - + constexpr uint32_t step = std::min(matB_acc_t::block_size_y, dequant_s); #pragma unroll - for (uint32_t k = 0; k < matB_acc_t::block_elems; k += dequant_s) { - dst_blk.xetla_select(k) = - cvt_blk_i32.xetla_select(k) * - scale.reg.xetla_select<1, 1>(k / dequant_s); + for (uint32_t k = 0; k < matB_acc_t::block_elems; k += step) { + dst_blk.xetla_select(k) = + cvt_blk_i32.xetla_select(k) * + scale.reg.xetla_select<1, 1>(k / step); } } } diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index 6c98df456..2e9d0831a 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -647,10 +647,17 @@ class gemm_universal_t< args.matA_base, {boundary_k, boundary_m, args.matA_ld}, {start_k, start_m}); - mem_desc_b.init( - args.matB_base, - {boundary_n / pack_ratio, boundary_k, args.matB_ld / pack_ratio}, - {int(start_n / pack_ratio), start_k}); + if constexpr (gemm_t::is_col_major_b) { + mem_desc_b.init( + args.matB_base, + {boundary_n, boundary_k / pack_ratio, args.matB_ld / pack_ratio}, + {start_n, int(start_k / pack_ratio)}); + } else { + mem_desc_b.init( + args.matB_base, + {boundary_n / pack_ratio, boundary_k, args.matB_ld / pack_ratio}, + {int(start_n / pack_ratio), start_k}); + } uint32_t scale_size_y = ((args.matrix_k + dequant_s - 1) / dequant_s); mem_desc_scale_t mem_desc_scale( @@ -658,13 +665,18 @@ class gemm_universal_t< {args.matrix_n, scale_size_y, args.scale_ld}, {start_x_scale, start_y_scale}); + uint32_t inner_loop_start = (start_k + k_stride - 1) / k_stride; uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride; gemm_args_t gemm_args; if constexpr ( gemm_t::compute_policy::quant_type == group::quant_mode::S4_FULLRANGE_NO_ZP) { - 
gemm_args = - gemm_args_t(mem_desc_a, mem_desc_b, inner_loop_count, mem_desc_scale); + gemm_args = gemm_args_t( + mem_desc_a, + mem_desc_b, + inner_loop_start, + inner_loop_count, + mem_desc_scale); } else { mem_desc_zero_pt_t mem_desc_zero_pt( args.zero_pt_base, @@ -675,6 +687,7 @@ class gemm_universal_t< gemm_args = gemm_args_t( mem_desc_a, mem_desc_b, + inner_loop_start, inner_loop_count, mem_desc_scale, mem_desc_zero_pt); diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp index 824641138..280a0ca67 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp @@ -52,16 +52,16 @@ class test_col_major { public: // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 16; - static constexpr size_t mat_k = 16; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 16 * 1; + static constexpr size_t wg_n = 16 * 2; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 16; + static constexpr size_t sg_k = 32; + static constexpr size_t dequant_s = 128; - static constexpr size_t local_kslicing = 1; + static constexpr size_t local_kslicing = 2; static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::col_major; @@ -601,8 +601,7 @@ std::vector dequantize_weight( for (uint32_t j = 0; j < width; j += step) { int start_b_in = i * width + j; int start_zero_pt_in = start_b_in; - int start_scale_in = - layout_b == mem_layout::row_major ? 0 : start_b_in / step; + int start_scale_in = j / step * matrix_n + i; int start_out = layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio; @@ -618,6 +617,12 @@ std::vector dequantize_weight( } } } + // for (size_t i = 0; i < matrix_n; i++) { + // for (size_t j = 0; j < matrix_k; j++) { + // std::cout << " " << float(b_out[i * matrix_k + j]); + // } + // std::cout << std::endl; + // } return b_out; } @@ -652,13 +657,13 @@ void dequantize_gemm_run(int iter) { constexpr size_t size_a = matrix_m * matrix_k; constexpr size_t size_b = matrix_k * matrix_n / 2; - constexpr size_t size_scale_m = matrix_k / dequant_s; + constexpr size_t size_scale_k = matrix_k / dequant_s; constexpr size_t size_scale_n = matrix_n; - constexpr size_t size_scale = size_scale_m * size_scale_n; + constexpr size_t size_scale = size_scale_k * size_scale_n; - constexpr size_t size_zero_pt_m = matrix_k / dequant_s; + constexpr size_t size_zero_pt_k = matrix_k / dequant_s; constexpr size_t size_zero_pt_n = matrix_n / 2; - constexpr size_t size_zero_pt = size_zero_pt_m * size_zero_pt_n; + constexpr size_t size_zero_pt = size_zero_pt_k * size_zero_pt_n; constexpr size_t size_c = matrix_m * matrix_n; constexpr size_t size_bias = matrix_n; @@ -666,8 +671,10 @@ void dequantize_gemm_run(int iter) { uint32_t lda = layout_a == mem_layout::row_major ? matrix_k : matrix_m; uint32_t ldb = layout_b == mem_layout::row_major ? matrix_n : matrix_k; uint32_t ldc = matrix_n; - // uint32_t ld_scale = size_scale_n; - // uint32_t ld_zero_pt = size_zero_pt_n; + uint32_t ld_scale = size_scale_n; + + // uint32_t ld_zero_pt = mem_layout::row_major ? 
size_zero_pt_n : + // size_zero_pt_k; // Turn on the enable_profiling property to facilitate subsequent profiling sycl::property_list properties{ @@ -811,7 +818,7 @@ void dequantize_gemm_run(int iter) { for (unsigned i = 0; i < size_scale; ++i) { scale_h[i] = random_float(); #ifdef UT_DEBUG - scale_h[i] = i; + scale_h[i] = 1.f; #endif } @@ -896,7 +903,7 @@ void dequantize_gemm_run(int iter) { C_d, ldc, scale_d, - matrix_n, + ld_scale, Acc_d, Cnt_d, epilogue_args); From 0669b12ba59d9dea11fc1ac36b320e01bd5fa5d2 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Wed, 22 May 2024 17:40:15 +0800 Subject: [PATCH 03/34] fix kslicing bug --- .../group/gemm/impl/int4_dequantize_xe.hpp | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 18cd0e298..285298d1c 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -490,12 +490,12 @@ class gemm_t< subgroup::tile_prefetch( zero_pt_prefetch_payload); } - scale_prefetch_addr_i += k_stride; + scale_prefetch_addr_i++; matA_prefetch_payload.template update_tdesc( matA_t::tile_size_x); matB_prefetch_payload.template update_tdesc( matB_t::tile_size_y); - if ((scale_prefetch_addr_i % dequant_s) == 0) { + if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); zero_pt_prefetch_payload.template update_tdesc( @@ -527,7 +527,7 @@ class gemm_t< subgroup::tile_load( zero_pt, zero_pt_payload); } - scale_load_addr_i+= k_stride; + scale_load_addr_i++; SW_BARRIER(); if constexpr (stages != 0) { subgroup::tile_prefetch( @@ -543,12 +543,12 @@ class gemm_t< subgroup::tile_prefetch( zero_pt_prefetch_payload); } - scale_prefetch_addr_i += k_stride; + scale_prefetch_addr_i++; } SW_BARRIER(); matA_payload.template update_tdesc(matA_t::tile_size_x); matB_payload.template update_tdesc(matB_t::tile_size_y); - if ((scale_load_addr_i % dequant_s) == 0) { + if ((scale_load_addr_i % scale_addr_update_freq) == 0) { scale_payload.template update_tdesc( scale_t::tile_size_y); zero_pt_payload.template update_tdesc( @@ -575,11 +575,18 @@ class gemm_t< } subgroup::elemwise_cvt(matA_acc, matA); dequantize(matB_acc, matB, scale, zero_pt); - if constexpr (is_col_major_b) { - tile_transpose(matB_acc); - } + // if constexpr (is_col_major_b) { + // tile_transpose(matB_acc); + // } SW_BARRIER(); - tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); + XETLA_PRINT(); + XETLA_PRINT(); + XETLA_PRINT(); + // if constexpr (is_col_major_b) { + // tile_mma::mma(matAcc, matAcc, matA_acc, matB_acc); + // } else { + tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); + // } SW_BARRIER(); if constexpr (enable_periodic_sync) { if ((i % sync_freq) == 0) { From aafe774b97ad01ee97059b221f2cbd27ae7b6fde Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Fri, 24 May 2024 22:58:48 +0800 Subject: [PATCH 04/34] save(g128 MTL 270Gflops bug on g32) save(g128 MTL 270Gflops bug on g32) add UT for gemv --- CMakeLists.txt | 2 +- .../group/gemm/compute_policy.hpp | 4 +- .../group/gemm/impl/int4_dequantize_xe.hpp | 201 ++++--- .../gemm/impl/int4_dequantize_kslicing_xe.hpp | 2 +- include/subgroup/tile/common.hpp | 8 +- include/subgroup/tile/impl/fma_xe.hpp | 135 ++++- include/subgroup/tile/impl/load_xe.hpp | 4 +- include/subgroup/tile/impl/op_function.hpp | 10 +- include/subgroup/tile/impl/payload_xe.hpp | 11 +- 
tests/integration/CMakeLists.txt | 1 + .../int4_dequantization_bias/main_client.cpp | 189 ++----- tests/integration/gemv/CMakeLists.txt | 3 + tests/integration/gemv/int4/CMakeLists.txt | 9 + tests/integration/gemv/int4/main.cpp | 497 ++++++++++++++++++ 14 files changed, 842 insertions(+), 234 deletions(-) create mode 100644 tests/integration/gemv/CMakeLists.txt create mode 100644 tests/integration/gemv/int4/CMakeLists.txt create mode 100644 tests/integration/gemv/int4/main.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c07ef0fc..e2213a9b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,7 +66,7 @@ else() set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs "${XETLA_OFFLINE_OPTIONS}") endif() -add_compile_options(-fsycl -fsycl-device-code-split=per_kernel) +add_compile_options(-fsycl -fsycl-device-code-split=per_kernel -ftemplate-backtrace-limit=0) add_compile_options(-Wall -Wextra -Werror) include(ProcessorCount) diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 8d7ffe33d..84d00b577 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -90,13 +90,13 @@ struct compute_policy_int4_dequantize< static constexpr uint32_t block_size_y_a = 16; using mma_attr = mma_attr_t; static constexpr uint32_t block_bytes_x_a = - (mma_engine == mma_engine::xmx) ? mma_attr::mma_k_in_bytes : 32; + (mma_engine == mma_engine::xmx) ? mma_attr::mma_k_in_bytes : 256; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); static constexpr uint32_t block_size_x_b = (mma_engine == mma_engine::xmx) ? mma_attr::mma_n_in_elem : 32; static constexpr uint32_t block_bytes_y_b = - (mma_engine == mma_engine::xmx) ? mma_attr::mma_k_in_bytes : 32; + (mma_engine == mma_engine::xmx) ? mma_attr::mma_k_in_bytes : 256; static constexpr uint32_t block_size_y_b = block_bytes_y_b / sizeof(dtype_mma_b); diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 285298d1c..bc044fe99 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -152,10 +152,22 @@ class gemm_t< compute_policy::mma_engine == mma_engine::xmx ? ((sizeof(dtype_a) < sizeof(uint32_t)) && is_col_major_a) : false; + + static constexpr reg_layout reg_layout_a_ = + compute_policy::mma_engine == mma_engine::fpu + ? reg_layout::transpose_tiled + : is_vnni_tiled_a ? reg_layout::vnni_tiled + : reg_layout::tiled; + + static constexpr reg_layout reg_layout_b_ = + compute_policy::mma_engine == mma_engine::fpu ? reg_layout::tiled + : (sizeof(dtype_mma_b) < sizeof(uint32_t)) ? reg_layout::vnni_tiled + : reg_layout::tiled; + static constexpr reg_layout reg_layout_a = - compute_policy::mma_engine == mma_engine::xmx - ? (is_vnni_tiled_a ? reg_layout::vnni_tiled : reg_layout::tiled) - : reg_layout::transpose_tiled; + is_col_major_b ? reg_layout_b_ : reg_layout_a_; + static constexpr reg_layout reg_layout_b = + is_col_major_b ? 
reg_layout_a_ : reg_layout_b_; using matA_tile_desc_t = subgroup::tile_desc_t< tile_size_x_a, @@ -182,13 +194,13 @@ class gemm_t< tile_size_y_b / pack_ratio, block_size_x_b, block_size_y_b / pack_ratio, - reg_layout::tiled>, + reg_layout_b>, subgroup::tile_desc_t< tile_size_x_b / pack_ratio, tile_size_y_b, block_size_x_b / pack_ratio, block_size_y_b, - reg_layout::tiled>>; + reg_layout_b>>; using matB_t = subgroup::tile_t; using matB_payload_t = subgroup::mem_payload_t< mem_desc_b_t, @@ -203,8 +215,7 @@ class gemm_t< tile_size_y_b, block_size_x_b, block_size_y_b, - compute_policy::mma_engine == mma_engine::xmx ? reg_layout::vnni_tiled - : reg_layout::tiled>; + reg_layout_b>; using matB_acc_t = subgroup::tile_t; public: @@ -240,15 +251,22 @@ class gemm_t< mem_space::global, mem_desc_b_t::alignment>; - using matAcc_tile_desc_t = subgroup::tile_desc_t< - tile_size_x_c, - tile_size_y_c, - block_size_x_b, - block_size_y_a, + using matC_tile_desc_t = subgroup::tile_desc_t< // M X N (Y x X) + tile_size_x_c, // sg_n + tile_size_y_c, // sg_m == 1 + block_size_x_b, // + block_size_y_a, // == 1 reg_layout::tiled>; - using matAcc_t = subgroup::tile_t; + using matC_t = subgroup::tile_t; private: + using matAcc_tile_desc_t = subgroup::tile_desc_t< // N x K (Y x X) + block_size_y_b, // K + tile_size_x_b, // N + block_size_y_b, // K + block_size_x_b, // N + reg_layout::tiled>; + using matAcc_t = subgroup::tile_t; using scale_tile_desc_t = subgroup::tile_desc_t< tile_size_x_b, tile_size_y_scale, @@ -264,18 +282,8 @@ class gemm_t< using zero_pt_tile_desc_t = std::conditional_t< is_col_major_b, - subgroup::tile_desc_t< - tile_size_x_b, - (tile_size_y_zero_pt + pack_ratio - 1) / pack_ratio, - block_size_x_b, - (block_size_y_zero_pt + pack_ratio - 1) / pack_ratio, - reg_layout::tiled>, - subgroup::tile_desc_t< - tile_size_x_b / pack_ratio, - tile_size_y_zero_pt, - block_size_x_b / pack_ratio, - block_size_y_zero_pt, - reg_layout::tiled>>; + subgroup::tile_desc_t<16, 16, 16, 16, reg_layout::tiled>, + subgroup::tile_desc_t<16, 16, 16, 16, reg_layout::tiled>>; using zero_pt_t = subgroup::tile_t; using zero_pt_payload_t = subgroup::mem_payload_t< @@ -288,13 +296,22 @@ class gemm_t< using zero_pt_prefetch_payload_t = subgroup:: prefetch_payload_t; - using tile_mma = subgroup::tile_mma_t< - matAcc_t, - matAcc_t, - matB_acc_t, - matA_acc_t, - compute_policy::mma_engine, - arch_tag>; + using tile_mma = std::conditional_t< + is_col_major_b, + subgroup::tile_fma_t< + matC_t, + matC_t, + matAcc_t, + matB_acc_t, + matA_acc_t, + arch_tag>, + subgroup::tile_mma_t< + matC_t, + matC_t, + matB_acc_t, + matA_acc_t, + compute_policy::mma_engine, + arch_tag>>; static constexpr bool enable_periodic_sync = (sync_freq != 0); static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; @@ -311,7 +328,8 @@ class gemm_t< static constexpr msg_type msg_type_b = matB_payload_t::message_type; /// @brief Arguments for gemm. - /// User should prepare matA_base_desc, matB_base_desc, inner_loop_count... + /// User should prepare matA_base_desc, matB_base_desc, + /// inner_loop_start inner_loop_count... struct arguments_t { /// @brief Is the memory description of matA, including base, shape and /// coordinate. 
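// A quick standalone check of the packed matB tile shapes selected
// above (assumptions: pack_ratio == 2 for int4x2, and the tile numbers
// are the test_col_major sizes, not fixed by the kernel). Only the
// dimension that is contiguous in memory shrinks by pack_ratio:
// row-major B packs two 4-bit values along N, col-major B packs them
// along K.
#include <cstdio>
int main() {
  const unsigned pack_ratio = 2, tile_k = 32, tile_n = 16;
  std::printf("row-major B tile: %u rows x %u packed cols\n",
              tile_k, tile_n / pack_ratio); // K rows, N/2 packed bytes
  std::printf("col-major B tile: %u rows x %u packed cols\n",
              tile_n, tile_k / pack_ratio); // N rows, K/2 packed bytes
  return 0;
}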
@@ -359,9 +377,9 @@ class gemm_t< inner_loop_start(loop_start), inner_loop_count(loop_count), scale_base_desc(scale_desc) {} - // Be aware of the risks: Rule of three (copy constructor, copy assignment, - // destructor) Please check if you need to add self-define destructor inline - // ~arguments_t(){} + // Be aware of the risks: Rule of three (copy constructor, copy + // assignment, destructor) Please check if you need to add self-define + // destructor inline ~arguments_t(){} inline arguments_t(const arguments_t& args) : matA_base_desc(args.matA_base_desc), matB_base_desc(args.matB_base_desc), @@ -435,13 +453,13 @@ class gemm_t< /// @brief Main execution function for gemm. /// The basic process is load data -> matrix multiply. /// @param g Is the workgroup of the current tile. - /// @param matAcc Is the reference of the accumulation buffer. + /// @param matC Is the reference of the accumulation buffer. /// @param args Is the gemm::arguments_t. /// @param slm_base Is the slm base address. /// @param nbarrier_base Is the named barrier base. __XETLA_API KERNEL_FUNC void operator()( work_group_t& g, - matAcc_t& matAcc, + matC_t& matC, arguments_t args, [[maybe_unused]] uint32_t slm_base = 0, uint32_t nbarrier_base = 0) { @@ -453,6 +471,8 @@ class gemm_t< matB_t matB; scale_t scale; zero_pt_t zero_pt; + matAcc_t matAcc; + matAcc.reg = 0; matA_payload_t matA_payload(args.matA_base_desc); matB_payload_t matB_payload(args.matB_base_desc); @@ -575,18 +595,16 @@ class gemm_t< } subgroup::elemwise_cvt(matA_acc, matA); dequantize(matB_acc, matB, scale, zero_pt); - // if constexpr (is_col_major_b) { - // tile_transpose(matB_acc); - // } + // XETLA_PRINT(); // 2 32(K) 2 16(K) + // dump_mat(matB_acc); + // dump_mat(scale); SW_BARRIER(); - XETLA_PRINT(); - XETLA_PRINT(); - XETLA_PRINT(); - // if constexpr (is_col_major_b) { - // tile_mma::mma(matAcc, matAcc, matA_acc, matB_acc); - // } else { - tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); - // } + if constexpr ( + is_col_major_b && compute_policy::mma_engine == mma_engine::fpu) { + tile_mma::fma(matAcc, matAcc, matB_acc, matA_acc); + } else { + tile_mma::mma(matC, matC, matB_acc, matA_acc); + } SW_BARRIER(); if constexpr (enable_periodic_sync) { if ((i % sync_freq) == 0) { @@ -602,6 +620,10 @@ class gemm_t< } } SW_BARRIER(); + if constexpr ( + is_col_major_b && compute_policy::mma_engine == mma_engine::fpu) { + tile_mma::reduce_acc_k(matAcc, matC); + } } private: @@ -613,7 +635,6 @@ class gemm_t< // no tail, because this is matB constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; - #pragma unroll for (uint32_t i = 0; i < num_block_y; ++i) { #pragma unroll @@ -641,12 +662,29 @@ class gemm_t< matB_blk.xetla_format() >> 4; cvt_blk_i32 = (cvt_blk_i8.xetla_format()); } - constexpr uint32_t step = std::min(matB_acc_t::block_size_y, dequant_s); + constexpr uint32_t step = std::min(block_size_y_b, dequant_s); + #pragma unroll - for (uint32_t k = 0; k < matB_acc_t::block_elems; k += step) { - dst_blk.xetla_select(k) = - cvt_blk_i32.xetla_select(k) * - scale.reg.xetla_select<1, 1>(k / step); + for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { + for (uint32_t jj = 0; jj < block_size_x_b; jj++) { + uint32_t offset_y_in_tile = i * block_size_y_b + ii; + uint32_t offset_x_in_tile = j * block_size_x_b + jj; + + uint32_t scale_idx = + (offset_y_in_tile) / dequant_s * scale_t::block_size_x + + offset_x_in_tile; + // uint32_t scale_idx = + // (k + (i * num_block_x + j) * 
matB_acc_t::block_elems) / step; + + dst_blk.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i32.xetla_select(jj * block_size_y_b + ii) * + scale.reg.xetla_select<1, 1>(scale_idx); + + // sycl::ext::oneapi::experimental::printf( + // "scale[%d] %f \n", + // scale_idx, + // float(sycl::half(scale.reg.xetla_select<1, 1>(scale_idx)))); + } } } } @@ -660,8 +698,9 @@ class gemm_t< // constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; // constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; - // constexpr uint32_t block_b_y_per_scale = dequant_s / block_size_y_b; - // constexpr uint32_t block_b_x_per_scale = dequant_s / block_size_x_b; + // constexpr uint32_t block_b_y_per_scale = dequant_s / + // block_size_y_b; constexpr uint32_t block_b_x_per_scale = dequant_s + // / block_size_x_b; // #pragma unroll // for (uint32_t i = 0; i < num_block_y; ++i) { // #pragma unroll @@ -671,29 +710,34 @@ class gemm_t< // .xetla_select( // block_id * matB_t::block_elems) // .xetla_format(); - // int scale_block_id = (i / block_b_y_per_scale * num_block_x + j); - // auto scale_vec = scale.reg.xetla_select( + // int scale_block_id = (i / block_b_y_per_scale * num_block_x + + // j); auto scale_vec = + // scale.reg.xetla_select( // scale_block_id * scale_t::block_size_x); - // auto dst_blk = matB_acc.reg.xetla_select( + // auto dst_blk = + // matB_acc.reg.xetla_select( // block_id * matB_acc_t::block_elems); // // 2: int8 includes 2 4bits data. // xetla_vector cvt_blk; - // xetla_vector cvt_blk_i32; - // if constexpr (compute_policy::quant_type == quant_mode::S4_ASYM) { + // xetla_vector + // cvt_blk_i32; if constexpr (compute_policy::quant_type == + // quant_mode::S4_ASYM) { // auto zero_pt_vec = zero_pt.reg - // .xetla_select( + // .xetla_select( // scale_block_id * // zero_pt_t::block_size_x) // .xetla_format(); // cvt_blk.xetla_select(0) = matB_blk & - // 0x0f; cvt_blk.xetla_select(1) = matB_blk + // 0x0f; cvt_blk.xetla_select(1) = + // matB_blk // >> 4; xetla_vector zero_pt_sub; // zero_pt_sub.xetla_select(0) = // zero_pt_vec & 0x0f; - // zero_pt_sub.xetla_select(1) = zero_pt_vec + // zero_pt_sub.xetla_select(1) = + // zero_pt_vec // >> 4; xetla_vector // zero_pt_blk; // #pragma unroll @@ -708,10 +752,12 @@ class gemm_t< // zero_pt_blk.xetla_format()); // } // if constexpr ( - // compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { - // xetla_vector cvt_blk_i8; - // cvt_blk_i8.xetla_select(0) = matB_blk & - // 0x0f; cvt_blk_i8.xetla_select(0) = + // compute_policy::quant_type == + // quant_mode::S4_FULLRANGE_NO_ZP) { + // xetla_vector + // cvt_blk_i8; cvt_blk_i8.xetla_select(0) = matB_blk & 0x0f; + // cvt_blk_i8.xetla_select(0) = // cvt_blk_i8.xetla_select(0) << 4; // cvt_blk_i8.xetla_select(0) = // cvt_blk_i8.xetla_select(0) >> 4; @@ -724,7 +770,8 @@ class gemm_t< // sizeof(dtype_mma_b); xetla_vector // temp_blk; - // temp_blk.xetla_select(0) = + // temp_blk.xetla_select(0) + // = // cvt_blk_i32; // #pragma unroll @@ -738,7 +785,8 @@ class gemm_t< // } // } - // xetla_vector scale_blk; + // xetla_vector + // scale_blk; // #pragma unroll // for (uint32_t row = 0; row < vnni_rows; row++) { // scale_blk.xetla_select(row) = @@ -756,7 +804,8 @@ class gemm_t< // } else { // #pragma unroll // for (uint32_t k = 0; k < block_size_y_b; k++) { - // dst_blk.xetla_select(k * block_size_x_b) = + // dst_blk.xetla_select(k * block_size_x_b) + // = // cvt_blk_i32.xetla_select( // k * block_size_x_b) * // scale_vec; @@ -776,13 +825,11 @@ class gemm_t< 
args.matA_base_desc.update_coord_y(tile_offset_m); if constexpr (is_col_major_b) { args.matB_base_desc.update_coord_x(tile_offset_n); - args.scale_base_desc.update_coord_x(tile_offset_n); - args.zero_pt_base_desc.update_coord_x(tile_offset_n); } else { args.matB_base_desc.update_coord_x(tile_offset_n / pack_ratio); - args.scale_base_desc.update_coord_x(tile_offset_n); - args.zero_pt_base_desc.update_coord_x(tile_offset_n / pack_ratio); } + args.scale_base_desc.update_coord_x(tile_offset_n); + args.zero_pt_base_desc.update_coord_x(tile_offset_n / pack_ratio); } }; /// @} xetla_gemm diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index 2e9d0831a..d2e905f78 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -95,7 +95,7 @@ class gemm_universal_t< using dtype_c = typename mem_desc_c_t::dtype; using dtype_scale = typename mem_desc_scale_t::dtype; using dtype_zero_pt = typename mem_desc_zero_pt_t::dtype; - using matAcc_t = typename gemm_t::matAcc_t; + using matAcc_t = typename gemm_t::matC_t; using dtype_acc = typename matAcc_t::dtype; using mem_desc_acc_t = mem_desc_t; diff --git a/include/subgroup/tile/common.hpp b/include/subgroup/tile/common.hpp index 9385c700f..0e2cda92d 100644 --- a/include/subgroup/tile/common.hpp +++ b/include/subgroup/tile/common.hpp @@ -331,8 +331,12 @@ struct msg_type_query { : msg_type::scatter); }; -template -constexpr msg_type msg_type_v = msg_type_query::value; +template < + typename tile_desc_, + mem_space memory_space, + mem_layout memory_layout = mem_layout::row_major> +constexpr msg_type msg_type_v = + msg_type_query::value; template < typename dtype, diff --git a/include/subgroup/tile/impl/fma_xe.hpp b/include/subgroup/tile/impl/fma_xe.hpp index c1ca0c6ff..77fca6b9c 100644 --- a/include/subgroup/tile/impl/fma_xe.hpp +++ b/include/subgroup/tile/impl/fma_xe.hpp @@ -24,6 +24,129 @@ namespace gpu::xetla::subgroup { /// @brief Is the tile mma operation functor, specialized for Xe and fpu engine. 
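// A host-side sketch of what tile_fma_t below computes for the m == 1
// (GEMV) case: fma() accumulates b[n][k] * a[k] into a per-(n, k)
// accumulator tile, and reduce_acc_k() collapses the K dimension after
// the K loop finishes. Plain float vectors stand in for the register
// tiles; gemv_ref is a hypothetical reference, not a XeTLA API.
#include <vector>
inline std::vector<float> gemv_ref(const std::vector<float>& a, // [K]
                                   const std::vector<float>& b, // [N*K], b[n*K+k]
                                   unsigned N, unsigned K) {
  std::vector<float> acc(N * K, 0.f), c(N, 0.f);
  for (unsigned n = 0; n < N; ++n)
    for (unsigned k = 0; k < K; ++k)
      acc[n * K + k] += b[n * K + k] * a[k]; // the fma() step, per K slice
  for (unsigned n = 0; n < N; ++n)
    for (unsigned k = 0; k < K; ++k)
      c[n] += acc[n * K + k];                // the reduce_acc_k() step
  return c;
}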
+template < + typename matDst_t_, + typename matSrc_t_, + typename matAcc_t_, + typename matB_t_, + typename matA_t_, + gpu_arch arch_tag_> +struct tile_fma_t { + using matA_t = matA_t_; + using matB_t = matB_t_; + using matSrc_t = matSrc_t_; + using matDst_t = matDst_t_; + using matAcc_t = matAcc_t_; + using dtype_a = typename matA_t::dtype; + using dtype_b = typename matB_t::dtype; + using dtype_src = typename matSrc_t::dtype; + using dtype_dst = typename matDst_t::dtype; + + using register_attr = + typename arch_attr_t::template register_attr<>; + + static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y; + static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x; + static constexpr uint32_t a_tile_elems = matA_t::tile_elems; + static constexpr uint32_t a_block_size_y = matA_t::block_size_y; + static constexpr uint32_t a_block_size_x = matA_t::block_size_x; + static constexpr uint32_t a_block_elems = matA_t::block_elems; + + static constexpr uint32_t b_tile_size_x = matB_t::tile_size_x; + static constexpr uint32_t b_tile_size_y = matB_t::tile_size_y; + static constexpr uint32_t b_tile_elems = matB_t::tile_elems; + static constexpr uint32_t b_block_size_x = matB_t::block_size_x; + static constexpr uint32_t b_block_size_y = matB_t::block_size_y; + static constexpr uint32_t b_block_elems = matB_t::block_elems; + + static constexpr uint32_t tile_size_m = a_tile_size_y; + static constexpr uint32_t tile_size_k = a_tile_size_x; + static constexpr uint32_t tile_size_n = b_tile_size_x; + static constexpr uint32_t block_size_m = a_block_size_y; + static constexpr uint32_t block_size_k = a_block_size_x; + static constexpr uint32_t block_size_n = b_block_size_x; + + static_assert( + a_tile_size_x == b_tile_size_y, + "matA tile k should match with matB tile k"); + static_assert( + a_block_size_x == b_block_size_y, + "matA block k should match with matB block k"); + static_assert( + b_block_size_y == matAcc_t::block_size_x, + "matA block k should match with matAcc block k"); + static_assert( + b_block_size_x == matAcc_t::block_size_y, + "matb block n should match with matAcc block n"); + + static_assert(tile_size_m == 1, "matA tile m must be 1"); + static_assert(a_block_size_y == 1, "matA block m must be 1"); + __XETLA_API static void fma( + matAcc_t& dst, + matAcc_t& src, + matB_t& b, + matA_t& a) { +#pragma unroll + for (uint32_t k = 0; k < tile_size_k / block_size_k; k++) { + auto a_block = + a.reg.xetla_select(k * block_size_k); +#pragma unroll + for (uint32_t n = 0; n < tile_size_n / block_size_n; n++) { + uint32_t b_block_idx = n * tile_size_k / block_size_k + k; + auto b_block = b.reg.xetla_select( + b_block_idx * matB_t::block_elems); + + uint32_t src_dst_idx = n * block_size_n; + auto src_block = + src.reg.xetla_select(src_dst_idx); + auto dst_block = + dst.reg.xetla_select(src_dst_idx); + fma_core( + dst_block, src_block, b_block, a_block); + } + } + } + template + __XETLA_API static void fma_core( + xetla_vector_ref __REF__ dst, + xetla_vector_ref __REF__ src, + xetla_vector_ref __REF__ b_block, + xetla_vector_ref __REF__ a_block) { + static_assert(blk_m == 1, "block m must be 1"); + auto dst_blk_2d = dst.xetla_format(); + auto src_blk_2d = src.xetla_format(); + auto b_blk_2d = b_block.xetla_format(); + auto a_blk_2d = a_block.xetla_format(); +#pragma unroll + for (uint32_t n = 0; n < blk_n; n++) { + dst_blk_2d.row(n) = b_blk_2d.row(n) * a_blk_2d.row(0) + src_blk_2d.row(n); + } + } + __XETLA_API static void reduce_acc_k(matAcc_t& matAcc, matDst_t_& matC) { + // matC 
[tx,ty,bx,by](matmul): tile_n, 1, block_n, 1 + // matAcc[tx,ty,bx,by](matmul): tile_n, block_k, block_n, block_k + // matAcc[tx,ty,bx,by](memory): block_k, tile_n, block_k, block_n + + static_assert( + matDst_t_::tile_size_y == 1 && matDst_t_::block_size_y == 1, + "matDst_t_ tile m and block m should match be 1"); + static_assert( + matAcc_t::tile_size_y == matDst_t_::tile_size_x, + "matAcc_t tile n should match with matDst_t_ tile n"); + static_assert( + matAcc_t::block_size_y == matDst_t_::block_size_x, + "matAcc_t block n should match with matDst_t_ block n"); + static constexpr auto block_k = matAcc_t::block_size_x; + static constexpr auto tile_n = matAcc_t::tile_size_y; + using dtype = matAcc_t::dtype; + + matC.reg = + recur_col_reduce(matAcc.reg); + } +}; + +/// @brief Is the tile mma operation functor, specialized for Xe and fpu +/// engine. template < typename matAcc_dst_t_, typename matAcc_src_t_, @@ -60,8 +183,8 @@ struct tile_mma_t< static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y; static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x; static constexpr uint32_t a_tile_elems = matA_t::tile_elems; - static constexpr uint32_t a_block_size_w = matA_t::block_size_y; - static constexpr uint32_t a_block_size_h = matA_t::block_size_x; + static constexpr uint32_t a_block_size_y = matA_t::block_size_y; + static constexpr uint32_t a_block_size_x = matA_t::block_size_x; static constexpr uint32_t a_block_elems = matA_t::block_elems; static constexpr uint32_t b_tile_size_x = matB_t::tile_size_x; @@ -76,7 +199,7 @@ struct tile_mma_t< static constexpr uint32_t tile_size_n = matDst_t::tile_size_x; static constexpr uint32_t tile_elems = tile_size_m * tile_size_n; static constexpr uint32_t block_size_n = matDst_t::block_size_x; - static constexpr uint32_t block_size_k = a_block_size_h; + static constexpr uint32_t block_size_k = a_block_size_x; static constexpr uint32_t block_size_m = matDst_t::block_size_y; static constexpr uint32_t block_elems = block_size_m * block_size_n; @@ -90,7 +213,7 @@ struct tile_mma_t< tile_size_n == matB_t::tile_size_x, "matAcc tile n should match with matB tile n"); static_assert( - block_size_m == a_block_size_w, + block_size_m == a_block_size_y, "matAcc block m should match with matA block m"); static_assert( block_size_n == b_block_size_x, @@ -194,7 +317,7 @@ struct tile_mma_t< constexpr uint32_t tail_start_m = tile_size_m / block_size_m * block_size_m; constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m; - constexpr uint32_t a_tail_blk_elems = a_block_size_h * a_tail_blk_w; + constexpr uint32_t a_tail_blk_elems = a_block_size_x * a_tail_blk_w; constexpr uint32_t tail_size_m = tile_size_m - tail_start_m; constexpr uint32_t acc_tail_blk_elems = tail_size_m * block_size_n; auto a_block = a.reg.xetla_select( @@ -236,7 +359,7 @@ struct tile_mma_t< constexpr uint32_t tail_start_m = tile_size_m / block_size_m * block_size_m; constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m; - constexpr uint32_t a_tail_blk_elems = a_block_size_h * a_tail_blk_w; + constexpr uint32_t a_tail_blk_elems = a_block_size_x * a_tail_blk_w; constexpr uint32_t tail_size_m = tile_size_m - tail_start_m; constexpr uint32_t acc_tail_blk_elems = tail_size_m * block_size_n; auto a_block = a.reg.xetla_select( diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 71ba2cb5a..6d13356c3 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -26,9 +26,9 @@ namespace 
gpu::xetla::subgroup { namespace detail { -template +template struct check_load_type { - static constexpr bool is_lsc_gather = true; + static constexpr bool is_lsc_gather = is_lsc_gather_; static constexpr bool is_global_block_2d = (payload_t::memory_space == mem_space::global && (payload_t::message_type == msg_type::block_2d)); diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index 85e83b45b..a4796815e 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -706,8 +706,8 @@ layout_convert(T_dst& dst, T_src& src) { template void dump_mat( T mat, - size_t tile_x = T::tile_size_x, - size_t tile_y = T::tile_size_y) { + size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x, + size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) { #pragma unroll for (size_t row = 0; row < tile_y; row++) { #pragma unroll @@ -715,7 +715,7 @@ void dump_mat( sycl::ext::oneapi::experimental::printf( "%d ", (int)(sycl::half)mat.reg[row * tile_x + col]); } - sycl::ext::oneapi::experimental::printf("\n "); + sycl::ext::oneapi::experimental::printf("\n"); } sycl::ext::oneapi::experimental::printf("\n "); } @@ -728,9 +728,9 @@ void dump_mat_reg(T mat, size_t tile_x, size_t tile_y) { sycl::ext::oneapi::experimental::printf( "%d ", (int)(sycl::half)mat[row * tile_x + col]); } - sycl::ext::oneapi::experimental::printf("\n "); + sycl::ext::oneapi::experimental::printf("\n"); } - sycl::ext::oneapi::experimental::printf("\n "); + sycl::ext::oneapi::experimental::printf("\n"); } } // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 4a86dea1d..70cbe2a0a 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -436,6 +436,7 @@ struct mem_payload_t< (bytes_per_row % sizeof(uint32_t) == 0), uint32_t, dtype>::type>::type; + static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype); uint64_t base_offset; @@ -1084,11 +1085,6 @@ struct mem_payload_t< static constexpr uint32_t block_size_x = tile_desc::block_size_x; static constexpr uint32_t block_size_y = tile_desc::block_size_y; - static constexpr uint32_t block_per_row_bytes = - alignment_in_bytes < block_size_x * sizeof(dtype) - ? alignment_in_bytes - : block_size_x * sizeof(dtype); - using this_payload_t = mem_payload_t; @@ -1110,6 +1106,11 @@ struct mem_payload_t< block_size_x * block_size_y * sizeof(dtype); // using mem_dtype = uint32_t; + + static constexpr uint32_t block_per_row_bytes = std::min( + (mem_transpose ? 
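// (Editor's sketch, an assumption about the intent of the std::conditional
// chains in this payload code: the 1D-message element type is promoted to
// the widest unsigned type that divides both the contiguous row bytes and
// the alignment, conceptually:
//
//   constexpr uint32_t pick_mem_dtype_bytes(uint32_t row_bytes,
//                                           uint32_t align_bytes) {
//     if (row_bytes % 8 == 0 && align_bytes % 8 == 0) return 8; // uint64_t
//     if (row_bytes % 4 == 0 && align_bytes % 4 == 0) return 4; // uint32_t
//     return 1; // fall back to plain dtype-sized accesses
//   }
//
// and the scale_factor added above then records how many dtype elements fit
// into one such mem_dtype lane.)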
block_size_y : block_size_x) * uint32_t(sizeof(dtype)), + alignment_in_bytes); + using mem_dtype = typename std::conditional< (block_per_row_bytes % sizeof(uint64_t) == 0), uint64_t, diff --git a/tests/integration/CMakeLists.txt b/tests/integration/CMakeLists.txt index 8e0e8e64b..fcb4a9fc2 100644 --- a/tests/integration/CMakeLists.txt +++ b/tests/integration/CMakeLists.txt @@ -21,6 +21,7 @@ endfunction() # add_subdirectory(vector_add) add_subdirectory(gemm) +add_subdirectory(gemv) add_subdirectory(row_reduction) add_subdirectory(layer_norm) add_subdirectory(data_transformer) diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp index 280a0ca67..c979a19ef 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp @@ -16,10 +16,10 @@ #include #include "xetla.hpp" -// #define UT_DEBUG +#define UT_DEBUG 1 using namespace gpu::xetla; // The number of times the kernel is executed -constexpr int ITER = 1; +constexpr int ITER = 200; enum optional_feature { NONE, ACT_SHUFFLE }; @@ -48,30 +48,6 @@ class act_shuf_feature_next_token { static constexpr size_t shuf_load_block = 16; }; -class test_col_major { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 16 * 2; - static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 32; - static constexpr size_t dequant_s = 128; - - static constexpr size_t local_kslicing = 2; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::col_major; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; - class test1_xehpg { public: // Extract the parameters required by different test cases @@ -555,77 +531,6 @@ int gemm_result_validate( return result ? 0 : 1; } -template < - gpu::xetla::group::quant_mode quant_type = - gpu::xetla::group::S4_FULLRANGE_NO_ZP, - typename data_type_acc_in = fp16, - typename data_type_b, - typename data_type_scale, - typename data_type_zero_pt> -std::tuple convert_int4( - data_type_b int4_data, - data_type_scale scale, - [[maybe_unused]] data_type_zero_pt zero_pt) { - uint8_t data_even = (int4_data & 0x0f) << 4; - int8_t data_0; - int8_t data_1; - memcpy(&data_0, &data_even, 1); - memcpy(&data_1, &int4_data, 1); - data_0 = data_0 >> 4; - data_1 = data_1 >> 4; - return std::make_tuple(fp16(data_0) * scale, fp16(data_1) * scale); -} -template < - size_t dequant_s, - mem_layout layout_b = mem_layout::row_major, - gpu::xetla::group::quant_mode quant_type = - gpu::xetla::group::S4_FULLRANGE_NO_ZP, - typename data_type_acc_in = fp16, - typename data_type_b, - typename data_type_scale, - typename data_type_zero_pt> -std::vector dequantize_weight( - size_t matrix_k, - size_t matrix_n, - data_type_b* b, - data_type_scale* scale, - data_type_zero_pt* zero_pt) { - std::vector b_out(matrix_k * matrix_n, 0); - constexpr size_t pack_radio = 2 * sizeof(data_type_b); - size_t width = layout_b == mem_layout::row_major ? 
matrix_n / pack_radio - : matrix_k / pack_radio; - size_t height = layout_b == mem_layout::row_major ? matrix_k : matrix_n; - size_t step = layout_b == mem_layout::row_major ? 1 : dequant_s / pack_radio; - - for (uint32_t i = 0; i < height; i++) { - for (uint32_t j = 0; j < width; j += step) { - int start_b_in = i * width + j; - int start_zero_pt_in = start_b_in; - int start_scale_in = j / step * matrix_n + i; - - int start_out = - layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio; - - for (uint32_t jj = 0; jj < step; jj++) { - std::tie( - b_out[start_out + pack_radio * jj], - b_out[start_out + pack_radio * jj + 1]) = - convert_int4( - b[start_b_in + jj], - scale[start_scale_in], - zero_pt[start_zero_pt_in + jj]); - } - } - } - // for (size_t i = 0; i < matrix_n; i++) { - // for (size_t j = 0; j < matrix_k; j++) { - // std::cout << " " << float(b_out[i * matrix_k + j]); - // } - // std::cout << std::endl; - // } - return b_out; -} - template void dequantize_gemm_run(int iter) { using namespace gpu; @@ -648,7 +553,7 @@ void dequantize_gemm_run(int iter) { using data_type_zero_pt = int4x2; using data_type_scale = fp16; using data_type_acc_in = fp16; - using data_type_acc = float; + using data_type_acc = float; // modify using data_type_bias = fp16; constexpr mem_layout layout_a = Test::layout_a; @@ -657,13 +562,13 @@ void dequantize_gemm_run(int iter) { constexpr size_t size_a = matrix_m * matrix_k; constexpr size_t size_b = matrix_k * matrix_n / 2; - constexpr size_t size_scale_k = matrix_k / dequant_s; + constexpr size_t size_scale_m = matrix_k / dequant_s; constexpr size_t size_scale_n = matrix_n; - constexpr size_t size_scale = size_scale_k * size_scale_n; + constexpr size_t size_scale = size_scale_m * size_scale_n; - constexpr size_t size_zero_pt_k = matrix_k / dequant_s; + constexpr size_t size_zero_pt_m = matrix_k / dequant_s; constexpr size_t size_zero_pt_n = matrix_n / 2; - constexpr size_t size_zero_pt = size_zero_pt_k * size_zero_pt_n; + constexpr size_t size_zero_pt = size_zero_pt_m * size_zero_pt_n; constexpr size_t size_c = matrix_m * matrix_n; constexpr size_t size_bias = matrix_n; @@ -671,10 +576,8 @@ void dequantize_gemm_run(int iter) { uint32_t lda = layout_a == mem_layout::row_major ? matrix_k : matrix_m; uint32_t ldb = layout_b == mem_layout::row_major ? matrix_n : matrix_k; uint32_t ldc = matrix_n; - uint32_t ld_scale = size_scale_n; - - // uint32_t ld_zero_pt = mem_layout::row_major ? size_zero_pt_n : - // size_zero_pt_k; + // uint32_t ld_scale = size_scale_n; + // uint32_t ld_zero_pt = size_zero_pt_n; // Turn on the enable_profiling property to facilitate subsequent profiling sycl::property_list properties{ @@ -903,7 +806,7 @@ void dequantize_gemm_run(int iter) { C_d, ldc, scale_d, - ld_scale, + matrix_n, Acc_d, Cnt_d, epilogue_args); @@ -928,7 +831,7 @@ void dequantize_gemm_run(int iter) { size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); - int constexpr warm = 0; + int constexpr warm = 100; try { for (int i = 0; i < iter + warm; i++) { if (i >= warm) @@ -1016,11 +919,11 @@ void dequantize_gemm_run(int iter) { epilogue_args); cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); - // if (!gemm_op_t::can_implement(gemm_arg)) { - // std::cout << "The arguments cannot be supported, aborting ... " - // << std::endl; - // FAIL(); - // } + if (!gemm_op_t::can_implement(gemm_arg)) { + std::cout << "The arguments cannot be supported, aborting ... 
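// (Editor's worked example, illustrative values only, for the inline host
// dequantization loop below: with matrix_n = 4096 and dequant_s = 128,
// group i = 1 and byte column j = 3 give
//   start_in    = 1 * 128 * 4096 / 2 + 3  -> packed byte 3 of row 128
//   start_out   = 1 * 128 * 4096 + 6      -> fp16 elements 6 and 7 of row 128
//   start_scale = 1 * 4096 + 6            -> scales for output columns 6 and 7
// so each packed byte expands to two adjacent outputs, each scaled by its
// own per-column scale.)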
" + << std::endl; + FAIL(); + } size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); @@ -1051,9 +954,29 @@ void dequantize_gemm_run(int iter) { // performance prof.print_profiling_result(profiling_selector::GPU); } - std::vector dequantize_b = - dequantize_weight( - matrix_k, matrix_n, B_h, scale_h, zero_pt_h); + std::vector dequantize_b(matrix_k * matrix_n, 0); + for (uint32_t i = 0; i < matrix_k / dequant_s; i++) { + for (uint32_t j = 0; j < matrix_n / 2; j++) { + int start_in = i * dequant_s * matrix_n / 2 + j; + int start_out = i * dequant_s * matrix_n + j * 2; + int start_scale = i * size_scale_n + j * 2; + for (uint32_t ii = 0; ii < dequant_s; ii++) { + uint8_t data_in = B_h[start_in + ii * matrix_n / 2]; + uint8_t data_even = (data_in & 0x0f) << 4; + int8_t data_0; + int8_t data_1; + memcpy(&data_0, &data_even, 1); + memcpy(&data_1, &data_in, 1); + data_0 = data_0 >> 4; + data_1 = data_1 >> 4; + + dequantize_b[start_out + ii * matrix_n] = + fp16(data_0) * scale_h[start_scale]; + dequantize_b[start_out + ii * matrix_n + 1] = + fp16(data_1) * scale_h[start_scale + 1]; + } + } + } queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait(); ASSERT_EQ( @@ -1094,28 +1017,28 @@ TYPED_TEST_P(dequantize_gemm_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_test, esimd); -using tests = ::testing::Types; +using tests = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P( dequantize_gemm_test_suite, dequantize_gemm_test, tests); -// template -// class dequantize_gemm_act_shuf_test : public ::testing::Test {}; -// TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test); - -// TYPED_TEST_P(dequantize_gemm_act_shuf_test, esimd) { -// if constexpr (TypeParam::mat_m != 1) { -// dequantize_gemm_run(ITER); -// } else { -// dequantize_gemm_run(ITER); -// } -// } +template +class dequantize_gemm_act_shuf_test : public ::testing::Test {}; +TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test); + +TYPED_TEST_P(dequantize_gemm_act_shuf_test, esimd) { + if constexpr (TypeParam::mat_m != 1) { + dequantize_gemm_run(ITER); + } else { + dequantize_gemm_run(ITER); + } +} -// REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd); +REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd); -// INSTANTIATE_TYPED_TEST_SUITE_P( -// dequantize_gemm_act_shuf_test_suite, -// dequantize_gemm_act_shuf_test, -// tests); +INSTANTIATE_TYPED_TEST_SUITE_P( + dequantize_gemm_act_shuf_test_suite, + dequantize_gemm_act_shuf_test, + tests); \ No newline at end of file diff --git a/tests/integration/gemv/CMakeLists.txt b/tests/integration/gemv/CMakeLists.txt new file mode 100644 index 000000000..8620dac8b --- /dev/null +++ b/tests/integration/gemv/CMakeLists.txt @@ -0,0 +1,3 @@ +include_directories(${CMAKE_SOURCE_DIR}/tests/integration/gemv) + +add_subdirectory(int4) diff --git a/tests/integration/gemv/int4/CMakeLists.txt b/tests/integration/gemv/int4/CMakeLists.txt new file mode 100644 index 000000000..f38ac28eb --- /dev/null +++ b/tests/integration/gemv/int4/CMakeLists.txt @@ -0,0 +1,9 @@ +get_filename_component(ProjectId ${CMAKE_CURRENT_SOURCE_DIR} NAME) +string(REPLACE " " "_" ProjectId ${ProjectId}) +set(ProjectIdClient ${ProjectId}) +set(ProjectIdXe ${ProjectId}) +string(PREPEND ProjectIdClient "gemv_") + +FILE(GLOB src main.cpp) +add_integration_test(${ProjectIdClient} ${src}) + diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp new file mode 100644 index 000000000..12c269637 --- 
/dev/null +++ b/tests/integration/gemv/int4/main.cpp @@ -0,0 +1,497 @@ +/******************************************************************************* + * Copyright (c) 2022-2023 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include +#include "xetla.hpp" +// #define UT_DEBUG +using namespace gpu::xetla; +// The number of times the kernel is executed +constexpr int ITER = 200; + +class test_col_major { + public: + // Extract the parameters required by different test cases + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 1; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 1; + static constexpr size_t sg_k = 1024; + static constexpr size_t dequant_s = 128; + + static constexpr size_t local_kslicing = 1; + static constexpr size_t global_kslicing = 1; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; + using data_type_a = fp16; + using data_type_b = int4x2; + using data_type_c = fp16; +}; + +template < + typename data_type_a, + typename data_type_b, + typename data_type_c, + typename data_type_bias, + typename data_type_acc = float> +int gemm_result_validate( + data_type_a* A, + data_type_b* B, + data_type_c* C, + data_type_bias* bias, + uint32_t m, + uint32_t k, + uint32_t n, + mem_layout mem_layout_a_ = mem_layout::row_major, + mem_layout mem_layout_b_ = mem_layout::row_major) { + buff_cmp::buff_vals data(C, m, n, n); + std::vector gold_C(m * n, 0); + get_gemm_gold( + m, n, k, mem_layout_a_, mem_layout_b_, A, B, gold_C.data()); + + // BiasAdd + for (uint32_t i = 0; i < gold_C.size(); ++i) { + uint32_t col = i % n; + gold_C[i] += bias[col]; + } + + buff_cmp::buff_vals other(gold_C.data(), m, n, n); + + bool result = buff_cmp::xetla_buff_cmp(data, other, "gemv validation"); + + std::cout << (!result ? "FAILED\n" : "PASSED\n"); + return result ? 
0 : 1; +} + +template < + gpu::xetla::group::quant_mode quant_type = + gpu::xetla::group::S4_FULLRANGE_NO_ZP, + typename data_type_acc_in = fp16, + typename data_type_b, + typename data_type_scale, + typename data_type_zero_pt> +std::tuple convert_int4( + data_type_b int4_data, + data_type_scale scale, + [[maybe_unused]] data_type_zero_pt zero_pt) { + uint8_t data_even = (int4_data & 0x0f) << 4; + int8_t data_0; + int8_t data_1; + memcpy(&data_0, &data_even, 1); + memcpy(&data_1, &int4_data, 1); + data_0 = data_0 >> 4; + data_1 = data_1 >> 4; + return std::make_tuple(fp16(data_0) * scale, fp16(data_1) * scale); +} +template < + size_t dequant_s, + mem_layout layout_b = mem_layout::row_major, + gpu::xetla::group::quant_mode quant_type = + gpu::xetla::group::S4_FULLRANGE_NO_ZP, + typename data_type_acc_in = fp16, + typename data_type_b, + typename data_type_scale, + typename data_type_zero_pt> +std::vector dequantize_weight( + size_t matrix_k, + size_t matrix_n, + data_type_b* b, + data_type_scale* scale, + data_type_zero_pt* zero_pt) { + std::vector b_out(matrix_k * matrix_n, 0); + constexpr size_t pack_radio = 2 * sizeof(data_type_b); + size_t width = layout_b == mem_layout::row_major ? matrix_n / pack_radio + : matrix_k / pack_radio; + size_t height = layout_b == mem_layout::row_major ? matrix_k : matrix_n; + size_t step = layout_b == mem_layout::row_major ? 1 : dequant_s / pack_radio; + + for (uint32_t i = 0; i < height; i++) { + for (uint32_t j = 0; j < width; j += step) { + int start_b_in = i * width + j; + int start_zero_pt_in = start_b_in; + + int start_scale_in = j / step * matrix_n + i; + + int start_out = + layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio; + + for (uint32_t jj = 0; jj < step; jj++) { + std::tie( + b_out[start_out + pack_radio * jj], + b_out[start_out + pack_radio * jj + 1]) = + convert_int4( + b[start_b_in + jj], + scale[start_scale_in], + zero_pt[start_zero_pt_in + jj]); + } + } + } + return b_out; +} + +template +void dequantize_gemv_run(int iter) { + using namespace gpu; + // Accept incoming parameters + constexpr size_t matrix_m = Test::mat_m; + constexpr size_t matrix_n = Test::mat_n; + constexpr size_t matrix_k = Test::mat_k; + constexpr uint32_t global_kslicing = Test::global_kslicing; + constexpr uint32_t local_kslicing = Test::local_kslicing; + + constexpr size_t wg_tile_m = Test::wg_m; + constexpr size_t wg_tile_n = Test::wg_n; + constexpr size_t sg_tile_m = Test::sg_m; + constexpr size_t sg_tile_n = Test::sg_n; + constexpr size_t sg_tile_k = Test::sg_k; + constexpr size_t dequant_s = Test::dequant_s; + using data_type_a = typename Test::data_type_a; + using data_type_b = typename Test::data_type_b; + using data_type_c = typename Test::data_type_c; + using data_type_zero_pt = int4x2; + using data_type_scale = fp16; + using data_type_acc_in = fp16; + using data_type_acc = float; + using data_type_bias = float; + + constexpr mem_layout layout_a = Test::layout_a; + constexpr mem_layout layout_b = Test::layout_b; + + constexpr size_t size_a = matrix_m * matrix_k; + constexpr size_t size_b = matrix_k * matrix_n / 2; + + constexpr size_t size_scale_k = matrix_k / dequant_s; + constexpr size_t size_scale_n = matrix_n; + constexpr size_t size_scale = size_scale_k * size_scale_n; + + constexpr size_t size_zero_pt_k = matrix_k / dequant_s; + constexpr size_t size_zero_pt_n = matrix_n / 2; + constexpr size_t size_zero_pt = size_zero_pt_k * size_zero_pt_n; + + constexpr size_t size_c = matrix_m * matrix_n; + constexpr size_t size_bias = 
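// (Editor's worked example, not patch content: convert_int4 above
// sign-extends each nibble by shifting it into the high half of an int8_t
// and arithmetic-shifting it back down. For int4_data = 0x9E with scale s:
//   low nibble  0xE: int8_t(0xE0) >> 4 == -2  -> first element  = -2 * s
//   high nibble 0x9: int8_t(0x9E) >> 4 == -7  -> second element = -7 * s
// so the low nibble always dequantizes first.)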
matrix_n; + + uint32_t lda = layout_a == mem_layout::row_major ? matrix_k : matrix_m; + uint32_t ldb = layout_b == mem_layout::row_major ? matrix_n : matrix_k; + uint32_t ldc = matrix_n; + uint32_t ld_scale = size_scale_n; + + // uint32_t ld_zero_pt = mem_layout::row_major ? size_zero_pt_n : + // size_zero_pt_k; + + // Turn on the enable_profiling property to facilitate subsequent profiling + sycl::property_list properties{ + sycl::property::queue::enable_profiling(), + sycl::property::queue::in_order()}; + auto queue = sycl::queue(properties); + auto context = queue.get_info(); + auto device = queue.get_info(); + + std::cout << "Running on " << device.get_info() << "\n"; + + using tile_shape = + xetla::group::tile_shape_t; + static constexpr uint32_t periodic_sync_interval = 0; + static constexpr uint32_t prefetch_distance = 0; + + using mem_desc_a_t = xetla::mem_desc_t< + data_type_a, + layout_a, + mem_space::global, + DEVICE_MEM_ALIGNMENT / sizeof(data_type_a)>; + using mem_desc_b_t = xetla::mem_desc_t< + data_type_b, + layout_b, + mem_space::global, + DEVICE_MEM_ALIGNMENT / sizeof(data_type_b)>; + using mem_desc_c_t = xetla::mem_desc_t< + data_type_c, + mem_layout::row_major, + mem_space::global, + DEVICE_MEM_ALIGNMENT / sizeof(data_type_c)>; + + using mem_desc_bias_t = xetla::mem_desc_t< + data_type_bias, + mem_layout::row_major, + mem_space::global, + DEVICE_MEM_ALIGNMENT / sizeof(data_type_bias)>; + + using compute_attr = xetla::group:: + compute_attr_t; + using perf_tuning_knob = xetla::group:: + perf_tuning_knob_t; + + using compute_policy = xetla::group::compute_policy_int4_dequantize< + compute_attr, + perf_tuning_knob, + data_type_scale, + data_type_zero_pt, + gpu::xetla::group::quant_mode::S4_FULLRANGE_NO_ZP, + dequant_s, + Test::mma_eng, + Test::arch>; + + using gemm_t = xetla::group:: + gemm_t; + + using bias_op_t = + gpu::xetla::subgroup::bias_add_op_t; + + using tile_op_t = gpu::xetla::subgroup::chained_tile_op_t; + + using epilogue_t = xetla::group::epilogue_t< + xetla::group::epilogue_policy_tile_op, + tile_shape, + mem_desc_c_t>; + + using group_swizzle = xetla::kernel::group_swizzle_default; + + using gemm_op_t = xetla::kernel::gemm_universal_t< + gpu::xetla::kernel::dispatch_policy_int4_dequantize_kslicing< + group_swizzle, + global_kslicing, + local_kslicing>, + gemm_t, + epilogue_t>; + + size_t size_acc = gemm_op_t::get_acc_buf_size(matrix_m, matrix_n); + size_t size_cnt = gemm_op_t::get_cnt_buf_size(matrix_m, matrix_n); + + // Define and initialize the data required for the calculation + auto* A_h = static_cast( + malloc_host(size_a * sizeof(data_type_a), context)); + auto* B_h = static_cast( + malloc_host(size_b * sizeof(data_type_b), context)); + auto* C_h = static_cast( + malloc_host(size_c * sizeof(data_type_c), context)); + auto* Acc_h = static_cast( + malloc_host(size_acc * sizeof(data_type_acc), context)); + auto* Cnt_h = + static_cast(malloc_host(size_cnt * sizeof(uint32_t), context)); + auto* scale_h = static_cast( + malloc_host(size_scale * sizeof(data_type_scale), context)); + auto* zero_pt_h = static_cast( + malloc_host(size_zero_pt * sizeof(data_type_zero_pt), context)); + auto* bias_h = static_cast( + malloc_host(size_bias * sizeof(data_type_bias), context)); + + auto* A_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, size_a * sizeof(data_type_a), device, context)); + auto* B_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, size_b * sizeof(data_type_b), device, context)); + auto* C_d = static_cast(aligned_alloc_device( + 
DEVICE_MEM_ALIGNMENT, size_c * sizeof(data_type_c), device, context)); + auto* Acc_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, size_acc * sizeof(data_type_acc), device, context)); + auto* Cnt_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, size_cnt * sizeof(uint32_t), device, context)); + auto* scale_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, + size_scale * sizeof(data_type_scale), + device, + context)); + auto* zero_pt_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, + size_zero_pt * sizeof(data_type_zero_pt), + device, + context)); + auto* bias_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, + size_bias * sizeof(data_type_bias), + device, + context)); + + for (unsigned i = 0; i < size_a; ++i) { + A_h[i] = random_float(); +#ifdef UT_DEBUG + A_h[i] = 1.f; + // A_h[i] = layout_a == mem_layout::row_major + // ? (i % matrix_k + i / matrix_k * 100) + // : (i % matrix_m + i / matrix_m * 100); +#endif + } + + for (unsigned i = 0; i < size_b; ++i) { + B_h[i] = uint8_t(random_uint8()); +#ifdef UT_DEBUG + B_h[i] = 17; +#endif + } + + for (unsigned i = 0; i < size_scale; ++i) { + scale_h[i] = random_float(); +#ifdef UT_DEBUG + scale_h[i] = i + 1; +#endif + } + + for (unsigned i = 0; i < size_zero_pt; ++i) { + zero_pt_h[i] = 0; + } + + for (unsigned i = 0; i < size_c; ++i) { + C_h[i] = random_float(); + } + + for (unsigned i = 0; i < size_acc; ++i) { + Acc_h[i] = random_float(); + } + + for (unsigned i = 0; i < size_cnt; ++i) { + Cnt_h[i] = random_uint8(); + } + + for (unsigned i = 0; i < size_bias; ++i) { + bias_h[i] = random_float(); +#ifdef UT_DEBUG + bias_h[i] = 0.f; +#endif + } + + queue.memcpy((void*)A_d, (void*)A_h, size_a * sizeof(data_type_a)).wait(); + queue.memcpy((void*)B_d, (void*)B_h, size_b * sizeof(data_type_b)).wait(); + queue.memcpy((void*)C_d, (void*)C_h, size_c * sizeof(data_type_c)).wait(); + queue.memcpy((void*)Acc_d, (void*)Acc_h, size_acc * sizeof(data_type_acc)) + .wait(); + queue.memcpy((void*)Cnt_d, (void*)Cnt_h, size_cnt * sizeof(uint32_t)).wait(); + queue + .memcpy( + (void*)scale_d, (void*)scale_h, size_scale * sizeof(data_type_scale)) + .wait(); + queue + .memcpy( + (void*)zero_pt_d, + (void*)zero_pt_h, + size_zero_pt * sizeof(data_type_zero_pt)) + .wait(); + queue.memcpy((void*)bias_d, (void*)bias_h, size_bias * sizeof(data_type_bias)) + .wait(); + + queue.memset(Cnt_d, 0, size_cnt * sizeof(uint32_t)).wait(); + queue.memset(Acc_d, 0, size_acc * sizeof(data_type_acc)).wait(); + // set up gemm arguments + typename bias_op_t::shape_t bias_add_shape(matrix_n, 1, matrix_n); + using epilogue_args_t = epilogue_t::arguments_t; + + epilogue_args_t epilogue_args( + {// epilogue_args init list + // It accepts the base pointer to matrix D, and its dimensions + {bias_d, bias_add_shape}}); + typename gemm_op_t::template arguments_t gemm_arg( + matrix_m, + matrix_k, + matrix_n, + A_d, + lda, + B_d, + ldb, + C_d, + ldc, + scale_d, + ld_scale, + Acc_d, + Cnt_d, + epilogue_args); + + cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); + if (!gemm_op_t::can_implement(gemm_arg)) { + std::cout << "The arguments cannot be supported, aborting ... 
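// (Editor's note, an assumption about the kslicing dispatch rather than
// patch content: with global_kslicing > 1 each workgroup accumulates a
// partial K-sum into the Acc_d scratch buffer while Cnt_d counts finished
// slices, conceptually:
//
//   acc[m][n] += partial_sum_over_my_k_slice;
//   if (atomic_inc(cnt) == global_kslicing - 1)
//     C[m][n] = acc[m][n]; // last finished slice writes the final result
//
// which is why both buffers are zeroed with memset before the launch.)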
" + << std::endl; + FAIL(); + } + + size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; + profiling_helper prof("dequantize_gemm", ops, "gflops"); + int constexpr warm = 0; + try { + for (int i = 0; i < iter + warm; i++) { + if (i >= warm) + prof.cpu_start(); + auto e_esimd = queue.submit([&](handler& cgh) { + cgh.parallel_for(nd_range, [=](nd_item<3> item) SYCL_ESIMD_KERNEL { + // allocate slm and nbarrier resource + slm_barrier_init(); + gemm_op_t gemm_op; + gemm_op(item, gemm_arg); + }); + }); + if (i >= warm) { + e_esimd.wait(); + prof.cpu_end(); + prof.add_gpu_event(e_esimd); + } + } + } catch (cl::sycl::exception const& e) { + std::cout << "SYCL exception caught: " << e.what() << '\n'; + FAIL(); + } + + // performance + prof.print_profiling_result(profiling_selector::GPU); + // check result + std::vector dequantize_b = + dequantize_weight( + matrix_k, matrix_n, B_h, scale_h, zero_pt_h); + + queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait(); + ASSERT_EQ( + 0, + gemm_result_validate( + A_h, + dequantize_b.data(), + C_h, + bias_h, + matrix_m, + matrix_k, + matrix_n, + layout_a, + layout_b)); + + free(A_h, context); + free(B_h, context); + free(C_h, context); + free(scale_h, context); + free(zero_pt_h, context); + free(A_d, context); + free(B_d, context); + free(C_d, context); + free(scale_d, context); + free(zero_pt_d, context); + free(Acc_h, context); + free(Cnt_h, context); + free(Acc_d, context); + free(Cnt_d, context); +} + +template +class dequantize_gemv_test : public ::testing::Test {}; +TYPED_TEST_SUITE_P(dequantize_gemv_test); + +TYPED_TEST_P(dequantize_gemv_test, esimd) { + dequantize_gemv_run(ITER); +} + +REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd); +using tests = ::testing::Types; + +INSTANTIATE_TYPED_TEST_SUITE_P( + dequantize_gemv_test_suite, + dequantize_gemv_test, + tests); From 1b9a4437d6367d1dc55f056a8a8e2853b01a5275 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Fri, 24 May 2024 23:34:34 +0800 Subject: [PATCH 05/34] add Specialized for FPU --- .../group/gemm/compute_policy.hpp | 77 ++++++++++++++++--- .../group/gemm/impl/int4_dequantize_xe.hpp | 7 +- 2 files changed, 67 insertions(+), 17 deletions(-) diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 84d00b577..274b2fce4 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -25,7 +25,7 @@ namespace gpu::xetla::group { enum quant_mode : uint8_t { S4_ASYM, S4_FULLRANGE_NO_ZP }; -/// @brief Compute policy for unaligned shape and xmx engine. +/// @brief Compute policy for int4 dequant gemm. /// @tparam compute_attr_ Is compute-related attributes. /// @tparam perf_tuning_knob_ Is performance-related knobs. /// @tparam arch_tag_ Is the HW architecture. @@ -41,7 +41,7 @@ template < typename enable = void> struct compute_policy_int4_dequantize {}; -/// @brief Specialized for XeHpc and XeHpg architecture. +/// @brief Specialized for xmx engine. 
 template <
     typename compute_attr_,
     typename perf_tuning_knob_,
     typename dtype_scale_,
     typename dtype_zero_pt_,
     quant_mode quant_type_,
     int dequant_s_,
     mma_engine mma_engine_,
     gpu_arch arch_tag_>
 struct compute_policy_int4_dequantize<
     compute_attr_,
     perf_tuning_knob_,
     dtype_scale_,
     dtype_zero_pt_,
     quant_type_,
     dequant_s_,
     mma_engine_,
     arch_tag_,
-    std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> {
+    std::enable_if_t<(
+        arch_tag_ <= gpu_arch::XeHpc && mma_engine_ == mma_engine::xmx)>> {
   using compute_attr = compute_attr_;
   using dtype_mma_acc = typename compute_attr::dtype_acc;
   using dtype_mma_a = typename compute_attr::dtype_a;
   using dtype_mma_b = typename compute_attr::dtype_b;
@@ -73,9 +74,7 @@ struct compute_policy_int4_dequantize<
   static constexpr mma_engine mma_engine = mma_engine_;
   static constexpr gpu_arch arch_tag = arch_tag_;
 
-  static_assert(
-      !(mma_engine == mma_engine::xmx && arch_tag == gpu_arch::XeLpg),
-      "XeLpg does not support xmx");
+  static_assert(!(arch_tag == gpu_arch::XeLpg), "XeLpg does not support xmx");
 
   static constexpr bool is_int4_matB_policy = true;
 
@@ -89,14 +88,68 @@ struct compute_policy_int4_dequantize<
   static constexpr uint32_t block_size_y_a = 16;
 
   using mma_attr = mma_attr_t;
-  static constexpr uint32_t block_bytes_x_a =
-      (mma_engine == mma_engine::xmx) ? mma_attr::mma_k_in_bytes : 256;
+  static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes;
   static constexpr uint32_t block_size_x_a =
       block_bytes_x_a / sizeof(dtype_mma_a);
+  static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem;
+  static constexpr uint32_t block_bytes_y_b = mma_attr::mma_k_in_bytes;
+  static constexpr uint32_t block_size_y_b =
+      block_bytes_y_b / sizeof(dtype_mma_b);
+
+  static_assert(
+      block_bytes_x_a == block_bytes_y_b,
+      "mat_a x needs to match with mat_b y");
+};
+
+/// @brief Specialized for fpu engine.
+template <
+    typename compute_attr_,
+    typename perf_tuning_knob_,
+    typename dtype_scale_,
+    typename dtype_zero_pt_,
+    quant_mode quant_type_,
+    int dequant_s_,
+    mma_engine mma_engine_,
+    gpu_arch arch_tag_>
+struct compute_policy_int4_dequantize<
+    compute_attr_,
+    perf_tuning_knob_,
+    dtype_scale_,
+    dtype_zero_pt_,
+    quant_type_,
+    dequant_s_,
+    mma_engine_,
+    arch_tag_,
+    std::enable_if_t<(
+        arch_tag_ <= gpu_arch::XeHpc && mma_engine_ == mma_engine::fpu)>> {
+  using compute_attr = compute_attr_;
+  using dtype_mma_acc = typename compute_attr::dtype_acc;
+  using dtype_mma_a = typename compute_attr::dtype_a;
+  using dtype_mma_b = typename compute_attr::dtype_b;
+
+  using perf_tuning_knob = perf_tuning_knob_;
+  static constexpr int stages = perf_tuning_knob::stages;
+  static constexpr int sync_freq = perf_tuning_knob::sync_freq;
+  static constexpr int k_stride = perf_tuning_knob::k_stride;
+  static constexpr mma_engine mma_engine = mma_engine_;
+  static constexpr gpu_arch arch_tag = arch_tag_;
+
+  static constexpr bool is_int4_matB_policy = true;
+
+  static constexpr uint32_t dequant_s = dequant_s_;
+  static_assert(
+      (dequant_s % (32 / sizeof(dtype_mma_b))) == 0,
+      "dequant_s should be a multiple of 32B");
+  using dtype_scale = dtype_scale_;
+  using dtype_zero_pt = dtype_zero_pt_;
+  static constexpr quant_mode quant_type = quant_type_;
+
+  static constexpr uint32_t block_size_y_a = 16;
+  static constexpr uint32_t block_bytes_x_a = 256;
   static constexpr uint32_t block_size_x_a =
       block_bytes_x_a / sizeof(dtype_mma_a);
-  static constexpr uint32_t block_size_x_b =
-      (mma_engine == mma_engine::xmx) ? mma_attr::mma_n_in_elem : 32;
-  static constexpr uint32_t block_bytes_y_b =
-      (mma_engine == mma_engine::xmx) ? 
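// (Editor's sketch, illustrative and simplified, of how the trailing
// std::enable_if_t parameter routes each (engine, arch) pair to exactly one
// specialization; the `policy` name is local to this sketch:)
//
//   template <mma_engine eng, gpu_arch arch, typename = void>
//   struct policy; // primary template, never instantiated
//
//   template <mma_engine eng, gpu_arch arch>
//   struct policy<eng, arch,
//       std::enable_if_t<eng == mma_engine::xmx && arch <= gpu_arch::XeHpc>> {
//     // xmx path: block shapes derived from mma_attr_t
//   };
//
//   template <mma_engine eng, gpu_arch arch>
//   struct policy<eng, arch,
//       std::enable_if_t<eng == mma_engine::fpu && arch <= gpu_arch::XeHpc>> {
//     // fpu path: fixed 256-byte / 32-element block shapes
//   };
//
// (Worked numbers for the fpu policy, assuming fp16 operands:
//   block_size_x_a = 256 / sizeof(fp16) = 128 elements along K,
//   block_size_y_b = 256 / sizeof(fp16) = 128 elements along K,
//   block_size_x_b = 32 elements along N,
// so block_bytes_x_a == block_bytes_y_b holds by construction.)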
mma_attr::mma_k_in_bytes : 256; + static constexpr uint32_t block_size_x_b = 32; + static constexpr uint32_t block_bytes_y_b = 256; static constexpr uint32_t block_size_y_b = block_bytes_y_b / sizeof(dtype_mma_b); diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index bc044fe99..07c33344e 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -222,9 +222,6 @@ class gemm_t< static_assert( (k_stride % (block_size_y_b) == 0), "k_stride%(block_size_y_b) == 0"); - static_assert( - (dequant_s % (block_size_y_b) == 0), - "dequant_s%(block_size_y_b) == 0"); static_assert( (k_stride % (dequant_s) == 0) || (dequant_s % (k_stride) == 0), "k_stride should match with dequant_s"); @@ -596,8 +593,8 @@ class gemm_t< subgroup::elemwise_cvt(matA_acc, matA); dequantize(matB_acc, matB, scale, zero_pt); // XETLA_PRINT(); // 2 32(K) 2 16(K) - // dump_mat(matB_acc); - // dump_mat(scale); + // dump_mat(matB_acc); + // dump_mat(scale); SW_BARRIER(); if constexpr ( is_col_major_b && compute_policy::mma_engine == mma_engine::fpu) { From 194ca35e644435ac38a5b781ebf124faa863fbc2 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Mon, 27 May 2024 20:23:47 +0800 Subject: [PATCH 06/34] support int scale col_major(with opt 10% perf when g = 32) --- CMakeLists.txt | 2 +- .../group/gemm/impl/int4_dequantize_xe.hpp | 56 +++++++++++-------- include/subgroup/tile/impl/payload_xe.hpp | 13 +++-- tests/integration/gemv/int4/main.cpp | 27 +++++---- 4 files changed, 59 insertions(+), 39 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e2213a9b3..6409d5a10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,7 +67,7 @@ else() endif() add_compile_options(-fsycl -fsycl-device-code-split=per_kernel -ftemplate-backtrace-limit=0) -add_compile_options(-Wall -Wextra -Werror) +add_compile_options(-Wall -Wextra) include(ProcessorCount) ProcessorCount(nproc) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 07c33344e..f93d9437d 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -239,29 +239,30 @@ class gemm_t< using mem_desc_scale_t = mem_desc_t< dtype_scale, - mem_layout::row_major, + mem_layout_b, mem_space::global, mem_desc_b_t::alignment>; + using mem_desc_zero_pt_t = mem_desc_t< dtype_zero_pt, - mem_layout::row_major, + mem_layout_b, mem_space::global, mem_desc_b_t::alignment>; - using matC_tile_desc_t = subgroup::tile_desc_t< // M X N (Y x X) - tile_size_x_c, // sg_n - tile_size_y_c, // sg_m == 1 - block_size_x_b, // - block_size_y_a, // == 1 + using matC_tile_desc_t = subgroup::tile_desc_t< + tile_size_x_c, + tile_size_y_c, + block_size_x_b, + block_size_y_a, reg_layout::tiled>; using matC_t = subgroup::tile_t; private: - using matAcc_tile_desc_t = subgroup::tile_desc_t< // N x K (Y x X) - block_size_y_b, // K - tile_size_x_b, // N - block_size_y_b, // K - block_size_x_b, // N + using matAcc_tile_desc_t = subgroup::tile_desc_t< + block_size_y_b, + tile_size_x_b, + block_size_y_b, + block_size_x_b, reg_layout::tiled>; using matAcc_t = subgroup::tile_t; using scale_tile_desc_t = subgroup::tile_desc_t< @@ -269,18 +270,31 @@ class gemm_t< tile_size_y_scale, block_size_x_b, block_size_y_scale, - reg_layout::tiled>; + reg_layout_b>; using scale_t = subgroup::tile_t; using 
scale_payload_t = subgroup::mem_payload_t< mem_desc_scale_t, scale_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v< + scale_tile_desc_t, + mem_space::global, + mem_desc_scale_t::layout>, arch_tag>; using zero_pt_tile_desc_t = std::conditional_t< is_col_major_b, - subgroup::tile_desc_t<16, 16, 16, 16, reg_layout::tiled>, - subgroup::tile_desc_t<16, 16, 16, 16, reg_layout::tiled>>; + subgroup::tile_desc_t< + tile_size_x_b, + (tile_size_y_zero_pt + pack_ratio - 1) / pack_ratio, + block_size_x_b, + (block_size_y_zero_pt + pack_ratio - 1) / pack_ratio, + reg_layout_b>, + subgroup::tile_desc_t< + (tile_size_x_b + pack_ratio - 1) / pack_ratio, + tile_size_y_zero_pt, + (block_size_x_b + pack_ratio - 1) / pack_ratio, + block_size_y_zero_pt, + reg_layout_b>>; using zero_pt_t = subgroup::tile_t; using zero_pt_payload_t = subgroup::mem_payload_t< @@ -513,9 +527,9 @@ class gemm_t< matB_prefetch_payload.template update_tdesc( matB_t::tile_size_y); if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { - scale_prefetch_payload.template update_tdesc( + scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); - zero_pt_prefetch_payload.template update_tdesc( + zero_pt_prefetch_payload.template update_tdesc( zero_pt_t::tile_size_y); } } @@ -566,8 +580,7 @@ class gemm_t< matA_payload.template update_tdesc(matA_t::tile_size_x); matB_payload.template update_tdesc(matB_t::tile_size_y); if ((scale_load_addr_i % scale_addr_update_freq) == 0) { - scale_payload.template update_tdesc( - scale_t::tile_size_y); + scale_payload.template update_tdesc(scale_t::tile_size_y); zero_pt_payload.template update_tdesc( zero_pt_t::tile_size_y); } @@ -592,9 +605,6 @@ class gemm_t< } subgroup::elemwise_cvt(matA_acc, matA); dequantize(matB_acc, matB, scale, zero_pt); - // XETLA_PRINT(); // 2 32(K) 2 16(K) - // dump_mat(matB_acc); - // dump_mat(scale); SW_BARRIER(); if constexpr ( is_col_major_b && compute_policy::mma_engine == mma_engine::fpu) { diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 70cbe2a0a..595919d4a 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -397,19 +397,20 @@ template < typename dtype_, typename tile_desc_, gpu_arch arch_tag_, - uint32_t alignment_> + uint32_t alignment_, + mem_layout memory_layout_> struct mem_payload_t< - mem_desc_t, + mem_desc_t, tile_desc_, msg_type::block_1d, arch_tag_, std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { using mem_desc_t = - mem_desc_t; + mem_desc_t; using dtype = dtype_; using tile_desc = tile_desc_; static constexpr mem_space memory_space = mem_space::global; - static constexpr mem_layout memory_layout = mem_layout::row_major; + static constexpr mem_layout memory_layout = memory_layout_; static constexpr msg_type message_type = msg_type::block_1d; static constexpr uint32_t alignment_in_bytes = mem_desc_t::alignment_in_bytes; static constexpr gpu_arch arch_tag = arch_tag_; @@ -427,7 +428,9 @@ struct mem_payload_t< mem_payload_t; public: - static constexpr uint32_t bytes_per_row = tile_size_x * sizeof(dtype); + static constexpr uint32_t bytes_per_row = + memory_layout == mem_layout::row_major ? 
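// (Editor's note, an assumption about intent: with the new memory_layout_
// parameter a col-major view counts its contiguous run along tile_size_y
// instead of tile_size_x. For example, a col-major fp16 tile with
// tile_size_y = 16 gives bytes_per_row = 16 * sizeof(fp16) = 32 bytes,
// still divisible by sizeof(uint64_t), so the widest 1D message element
// type remains usable.)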
tile_size_x * sizeof(dtype) + : tile_size_y * sizeof(dtype); using mem_dtype = typename std::conditional< (bytes_per_row % sizeof(uint64_t) == 0) && (alignment_in_bytes % sizeof(uint64_t) == 0), diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 12c269637..159e1a0ea 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -32,7 +32,7 @@ class test_col_major { static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 1; static constexpr size_t sg_k = 1024; - static constexpr size_t dequant_s = 128; + static constexpr size_t dequant_s = 32; static constexpr size_t local_kslicing = 1; static constexpr size_t global_kslicing = 1; @@ -127,7 +127,7 @@ std::vector dequantize_weight( int start_b_in = i * width + j; int start_zero_pt_in = start_b_in; - int start_scale_in = j / step * matrix_n + i; + int start_scale_in = start_b_in / step; int start_out = layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio; @@ -143,6 +143,12 @@ std::vector dequantize_weight( } } } + // for (uint32_t i = 0; i < matrix_n; i++) { + // for (uint32_t j = 0; j < matrix_k; j++) { + // std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; + // } + // std::cout << std::endl; + // } return b_out; } @@ -175,7 +181,7 @@ void dequantize_gemv_run(int iter) { constexpr mem_layout layout_b = Test::layout_b; constexpr size_t size_a = matrix_m * matrix_k; - constexpr size_t size_b = matrix_k * matrix_n / 2; + constexpr size_t size_b = matrix_k * matrix_n / (2 * sizeof(data_type_b)); constexpr size_t size_scale_k = matrix_k / dequant_s; constexpr size_t size_scale_n = matrix_n; @@ -191,7 +197,8 @@ void dequantize_gemv_run(int iter) { uint32_t lda = layout_a == mem_layout::row_major ? matrix_k : matrix_m; uint32_t ldb = layout_b == mem_layout::row_major ? matrix_n : matrix_k; uint32_t ldc = matrix_n; - uint32_t ld_scale = size_scale_n; + uint32_t ld_scale = + layout_b == mem_layout::row_major ? size_scale_n : size_scale_k; // uint32_t ld_zero_pt = mem_layout::row_major ? size_zero_pt_n : // size_zero_pt_k; @@ -411,15 +418,15 @@ void dequantize_gemv_run(int iter) { epilogue_args); cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); - if (!gemm_op_t::can_implement(gemm_arg)) { - std::cout << "The arguments cannot be supported, aborting ... " - << std::endl; - FAIL(); - } + // if (!gemm_op_t::can_implement(gemm_arg)) { + // std::cout << "The arguments cannot be supported, aborting ... 
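// (Editor's worked example for the col-major scale indexing above,
// illustrative values only: with dequant_s = 32 and pack_radio = 2,
// step = 32 / 2 = 16 packed bytes per quantization group, so
//   start_scale_in = start_b_in / step
// maps bytes 0..15 of a column to scale 0, bytes 16..31 to scale 1, and so
// on: one scale per 32 unpacked K elements of that column.)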
" + // << std::endl; + // FAIL(); + // } size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); - int constexpr warm = 0; + int constexpr warm = 100; try { for (int i = 0; i < iter + warm; i++) { if (i >= warm) From 2bc4877be745b4c886d8c5e15497c8ca328bc00b Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Tue, 28 May 2024 03:43:33 +0800 Subject: [PATCH 07/34] support int4x8 for int32 weight --- CMakeLists.txt | 2 +- include/common/core/base_types.hpp | 40 +++ include/experimental/common/base_types.hpp | 49 --- include/experimental/common/common.hpp | 23 -- include/experimental/experimental.hpp | 1 - include/experimental/group/gemm/common.hpp | 1 - .../group/gemm/impl/int4_dequantize_xe.hpp | 327 +++++++++++------- include/experimental/kernel/gemm/common.hpp | 1 - .../int4_dequantization_bias/main_client.cpp | 2 +- tests/integration/gemv/int4/main.cpp | 63 ++-- tests/utils/common.hpp | 3 + 11 files changed, 278 insertions(+), 234 deletions(-) delete mode 100644 include/experimental/common/base_types.hpp delete mode 100644 include/experimental/common/common.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 6409d5a10..e2213a9b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,7 +67,7 @@ else() endif() add_compile_options(-fsycl -fsycl-device-code-split=per_kernel -ftemplate-backtrace-limit=0) -add_compile_options(-Wall -Wextra) +add_compile_options(-Wall -Wextra -Werror) include(ProcessorCount) ProcessorCount(nproc) diff --git a/include/common/core/base_types.hpp b/include/common/core/base_types.hpp index 33ed26b74..ee2a77a81 100644 --- a/include/common/core/base_types.hpp +++ b/include/common/core/base_types.hpp @@ -55,6 +55,32 @@ using fp16 = sycl::half; /// using tf32 = sycl::ext::intel::experimental::esimd::tfloat32; +/// @brief xetla 4bits data packed as 8bits data type. +/// 2 4bit data pack to one byte +struct int4x2 { + uint8_t data; + + operator uint8_t() const { + return data; + } + int4x2(uint8_t val) { + data = val; + } +}; + +/// @brief xetla 4bits data packed as 32bits data type. +/// 8 4bit data pack to 4 bytes +struct int4x8 { + uint32_t data; + + operator uint32_t() const { + return data; + } + int4x8(uint32_t val) { + data = val; + } +}; + /// @brief mx_fp4(E2M1) data packed as 8bits data type. struct mx_fp4 { uint8_t data; @@ -89,6 +115,8 @@ template struct is_internal_type { static constexpr bool value = std::is_same, bf16>::value || std::is_same, tf32>::value || + std::is_same, int4x2>::value || + std::is_same, int4x8>::value || std::is_same, mx_fp4>::value; }; template @@ -137,6 +165,18 @@ struct native_type { using type = uint8_t; }; +/// @brief Set uint8_t as the native data type of int4x2. +template <> +struct native_type { + using type = uint8_t; +}; + +/// @brief Set uint8_t as the native data type of int4x8. +template <> +struct native_type { + using type = uint32_t; +}; + /// @brief Return the native data type of T template using native_type_t = typename native_type::type; diff --git a/include/experimental/common/base_types.hpp b/include/experimental/common/base_types.hpp deleted file mode 100644 index 8755a8ce0..000000000 --- a/include/experimental/common/base_types.hpp +++ /dev/null @@ -1,49 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2022-2023 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ - -/// @file -/// C++ API - -#pragma once - -namespace gpu::xetla { - -/// @brief xetla 4bits data packed as 8bits data type. -/// 2 4bit data pack to one byte -struct int4x2 { - uint8_t data; - - operator uint8_t() const { - return data; - } - int4x2(uint8_t val) { - data = val; - } -}; - -/// @brief Used to check if the type is xetla internal data type -template <> -struct is_internal_type { - static constexpr bool value = true; -}; - -/// @brief Set uint8_t as the native data type of int4x2. -template <> -struct native_type { - using type = uint8_t; -}; - -} // namespace gpu::xetla diff --git a/include/experimental/common/common.hpp b/include/experimental/common/common.hpp deleted file mode 100644 index b1cc9d38d..000000000 --- a/include/experimental/common/common.hpp +++ /dev/null @@ -1,23 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2022-2023 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
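// (Editor's usage sketch for the int4x2/int4x8 wrappers introduced above;
// the values and packing order are illustrative, low nibble first:)
//
//   uint32_t word = 0;
//   int8_t vals[8] = {-8, -1, 0, 1, 2, 3, 4, 7};
//   for (int i = 0; i < 8; ++i)
//     word |= (uint32_t(vals[i]) & 0xFu) << (4 * i); // nibble i = element i
//   gpu::xetla::int4x8 packed(word); // eight 4-bit values per 32-bit lane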
- *******************************************************************************/ - -/// @file -/// C++ API - -#pragma once - -#include -#include diff --git a/include/experimental/experimental.hpp b/include/experimental/experimental.hpp index 206768c7d..87f9883e0 100644 --- a/include/experimental/experimental.hpp +++ b/include/experimental/experimental.hpp @@ -19,7 +19,6 @@ #pragma once -#include #include #include #include \ No newline at end of file diff --git a/include/experimental/group/gemm/common.hpp b/include/experimental/group/gemm/common.hpp index 1c48e9e7c..73f8daff4 100644 --- a/include/experimental/group/gemm/common.hpp +++ b/include/experimental/group/gemm/common.hpp @@ -19,7 +19,6 @@ #pragma once -#include #include #include #include diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index f93d9437d..256540191 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -98,11 +98,14 @@ class gemm_t< using dtype_scale = typename compute_policy::dtype_scale; static_assert( - std::is_same, remove_const_t>::value, + std::is_same, remove_const_t>::value || + std::is_same, remove_const_t>::value, "this is for 4bit matB "); static_assert( std::is_same, remove_const_t>:: - value, + value || + std::is_same, remove_const_t>:: + value, "this is for 4bit zero_pt "); /******** set memory attribute **********/ @@ -221,7 +224,10 @@ class gemm_t< public: static_assert( (k_stride % (block_size_y_b) == 0), - "k_stride%(block_size_y_b) == 0"); + "k_stride % (block_size_y_b) == 0"); + static_assert( + (dequant_s % block_size_y_b == 0 || block_size_y_b % dequant_s == 0), + "dequant_s % block_size_y_b == 0 || block_size_y_b % dequant_s == 0"); static_assert( (k_stride % (dequant_s) == 0) || (dequant_s % (k_stride) == 0), "k_stride should match with dequant_s"); @@ -270,7 +276,7 @@ class gemm_t< tile_size_y_scale, block_size_x_b, block_size_y_scale, - reg_layout_b>; + is_col_major_b ? reg_layout::transpose_tiled : reg_layout::tiled>; using scale_t = subgroup::tile_t; using scale_payload_t = subgroup::mem_payload_t< mem_desc_scale_t, @@ -288,13 +294,13 @@ class gemm_t< (tile_size_y_zero_pt + pack_ratio - 1) / pack_ratio, block_size_x_b, (block_size_y_zero_pt + pack_ratio - 1) / pack_ratio, - reg_layout_b>, + is_col_major_b ? reg_layout::transpose_tiled : reg_layout::tiled>, subgroup::tile_desc_t< (tile_size_x_b + pack_ratio - 1) / pack_ratio, tile_size_y_zero_pt, (block_size_x_b + pack_ratio - 1) / pack_ratio, block_size_y_zero_pt, - reg_layout_b>>; + is_col_major_b ? 
reg_layout::transpose_tiled : reg_layout::tiled>>; using zero_pt_t = subgroup::tile_t; using zero_pt_payload_t = subgroup::mem_payload_t< @@ -696,131 +702,190 @@ class gemm_t< } } } - // inline void dequantize( - // matB_acc_t& matB_acc, - // matB_t& matB, - // scale_t& scale, - // zero_pt_t& zero_pt) { - // // no tail, because this is matB - // constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; - // constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; - - // constexpr uint32_t block_b_y_per_scale = dequant_s / - // block_size_y_b; constexpr uint32_t block_b_x_per_scale = dequant_s - // / block_size_x_b; - // #pragma unroll - // for (uint32_t i = 0; i < num_block_y; ++i) { - // #pragma unroll - // for (uint32_t j = 0; j < num_block_x; ++j) { - // int block_id = (i * num_block_x + j); - // auto matB_blk = matB.reg - // .xetla_select( - // block_id * matB_t::block_elems) - // .xetla_format(); - // int scale_block_id = (i / block_b_y_per_scale * num_block_x + - // j); auto scale_vec = - // scale.reg.xetla_select( - // scale_block_id * scale_t::block_size_x); - // auto dst_blk = - // matB_acc.reg.xetla_select( - // block_id * matB_acc_t::block_elems); - - // // 2: int8 includes 2 4bits data. - // xetla_vector cvt_blk; - - // xetla_vector - // cvt_blk_i32; if constexpr (compute_policy::quant_type == - // quant_mode::S4_ASYM) { - // auto zero_pt_vec = zero_pt.reg - // .xetla_select( - // scale_block_id * - // zero_pt_t::block_size_x) - // .xetla_format(); - // cvt_blk.xetla_select(0) = matB_blk & - // 0x0f; cvt_blk.xetla_select(1) = - // matB_blk - // >> 4; xetla_vector zero_pt_sub; - // zero_pt_sub.xetla_select(0) = - // zero_pt_vec & 0x0f; - // zero_pt_sub.xetla_select(1) = - // zero_pt_vec - // >> 4; xetla_vector - // zero_pt_blk; - // #pragma unroll - // for (uint32_t row = 0; row < block_size_y_b; row++) { - // zero_pt_blk.xetla_select(row * - // block_size_x_b) - // .xetla_format() = - // zero_pt_sub.xetla_format() + int8_t(1); - // } - // cvt_blk_i32 = - // (cvt_blk.xetla_format() - - // zero_pt_blk.xetla_format()); - // } - // if constexpr ( - // compute_policy::quant_type == - // quant_mode::S4_FULLRANGE_NO_ZP) { - // xetla_vector - // cvt_blk_i8; cvt_blk_i8.xetla_select(0) = matB_blk & 0x0f; - // cvt_blk_i8.xetla_select(0) = - // cvt_blk_i8.xetla_select(0) << 4; - // cvt_blk_i8.xetla_select(0) = - // cvt_blk_i8.xetla_select(0) >> 4; - // cvt_blk_i8.xetla_select(1) = - // matB_blk.xetla_format() >> 4; - // cvt_blk_i32 = (cvt_blk_i8.xetla_format()); - // } - // if constexpr (compute_policy::mma_engine == mma_engine::xmx) { - // constexpr uint32_t vnni_rows = sizeof(uint32_t) / - // sizeof(dtype_mma_b); xetla_vector - // temp_blk; - // temp_blk.xetla_select(0) - // = - // cvt_blk_i32; - - // #pragma unroll - // for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { - // #pragma unroll - // for (uint32_t row = 0; row < vnni_rows; row++) { - // temp_blk.xetla_select( - // row + block_size_x_b * k * vnni_rows) = - // temp_blk.xetla_select( - // (k + row) * block_size_x_b * vnni_rows); - // } - // } - - // xetla_vector - // scale_blk; - // #pragma unroll - // for (uint32_t row = 0; row < vnni_rows; row++) { - // scale_blk.xetla_select(row) = - // scale_vec; - // } - - // #pragma unroll - // for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { - // dst_blk.xetla_select( - // k * block_size_x_b) = - // temp_blk.xetla_select( - // k * block_size_x_b * vnni_rows) * - // scale_blk; - // } - // } else { - // #pragma unroll - // for (uint32_t k = 0; k < 
block_size_y_b; k++) { - // dst_blk.xetla_select(k * block_size_x_b) - // = - // cvt_blk_i32.xetla_select( - // k * block_size_x_b) * - // scale_vec; - // } - // } - // } - // } - // } + /* + inline void dequantize( + matB_acc_t& matB_acc, + matB_t& matB, + scale_t& scale, + [[maybe_unused]] zero_pt_t& zero_pt) { + // no tail, because this is matB + constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; + constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; +#pragma unroll + for (uint32_t i = 0; i < num_block_y; ++i) { +#pragma unroll + for (uint32_t j = 0; j < num_block_x; ++j) { + int block_id = (i * num_block_x + j); + auto matB_blk = matB.reg + .xetla_select( + block_id * matB_t::block_elems) + .xetla_format>(); + + auto dst_blk = matB_acc.reg.xetla_select( + block_id * matB_acc_t::block_elems); + + // 2: int8 includes 2 4bits data. + using dtype_8bit = std::conditional_t< + compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP, + int8_t, + uint8_t>; + xetla_vector cvt_blk_8bit; + if constexpr ( + compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { +#pragma unroll + for (uint32_t shift = 0; shift < sizeof(dtype_b) * 2; shift++) { + auto dequant_8bit = + cvt_blk_8bit + .xetla_select( + shift); + dequant_8bit = (matB_blk & 0xf); + dequant_8bit = dequant_8bit << 4; + dequant_8bit = dequant_8bit >> 4; + matB_blk = matB_blk >> 4; + } + } + + constexpr uint32_t step = std::min(block_size_y_b, dequant_s); +#pragma unroll + for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { + for (uint32_t jj = 0; jj < block_size_x_b; jj++) { + uint32_t offset_y_in_tile = i * block_size_y_b + ii; + uint32_t offset_x_in_tile = j * block_size_x_b + jj; + + uint32_t scale_idx = + (offset_y_in_tile) / dequant_s * scale_t::block_size_x + + offset_x_in_tile; + // uint32_t scale_idx = + // (k + (i * num_block_x + j) * matB_acc_t::block_elems) / step; + + dst_blk.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_8bit.xetla_select(jj * block_size_y_b + ii) * + scale.reg.xetla_select<1, 1>(scale_idx); + + // sycl::ext::oneapi::experimental::printf( + // "scale[%d] %f \n", + // scale_idx, + // float(sycl::half(scale.reg.xetla_select<1, 1>(scale_idx)))); + } + } + } + } + } + */ + /* + inline void dequantize( + matB_acc_t & matB_acc, + matB_t & matB, + scale_t & scale, + zero_pt_t & zero_pt) { + // no tail, because this is matB + constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; + constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; + + constexpr uint32_t block_b_y_per_scale = dequant_s / block_size_y_b; + constexpr uint32_t block_b_x_per_scale = dequant_s / block_size_x_b; +#pragma unroll + for (uint32_t i = 0; i < num_block_y; ++i) { +#pragma unroll + for (uint32_t j = 0; j < num_block_x; ++j) { + int block_id = (i * num_block_x + j); + auto matB_blk = matB.reg + .xetla_select( + block_id * matB_t::block_elems) + .xetla_format(); + int scale_block_id = (i / block_b_y_per_scale * num_block_x + j); + auto scale_vec = scale.reg.xetla_select( + scale_block_id * scale_t::block_size_x); + auto dst_blk = matB_acc.reg.xetla_select( + block_id * matB_acc_t::block_elems); + + // 2: int8 includes 2 4bits data. 
+ xetla_vector cvt_blk; + + xetla_vector cvt_blk_i32; + if constexpr (compute_policy::quant_type == quant_mode::S4_ASYM) { + auto zero_pt_vec = zero_pt.reg + .xetla_select( + scale_block_id * zero_pt_t::block_size_x) + .xetla_format(); + cvt_blk.xetla_select(0) = matB_blk & 0x0f; + cvt_blk.xetla_select(1) = matB_blk >> 4; + xetla_vector zero_pt_sub; + zero_pt_sub.xetla_select(0) = + zero_pt_vec & 0x0f; + zero_pt_sub.xetla_select(1) = + zero_pt_vec >> 4; + xetla_vector zero_pt_blk; +#pragma unroll + for (uint32_t row = 0; row < block_size_y_b; row++) { + zero_pt_blk.xetla_select(row * block_size_x_b) + .xetla_format() = + zero_pt_sub.xetla_format() + int8_t(1); + } + cvt_blk_i32 = + (cvt_blk.xetla_format() - + zero_pt_blk.xetla_format()); + } + if constexpr ( + compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + xetla_vector cvt_blk_i8; + cvt_blk_i8.xetla_select(0) = + matB_blk & 0x0f; + cvt_blk_i8.xetla_select(0) = + cvt_blk_i8.xetla_select(0) << 4; + cvt_blk_i8.xetla_select(0) = + cvt_blk_i8.xetla_select(0) >> 4; + cvt_blk_i8.xetla_select(1) = + matB_blk.xetla_format() >> 4; + cvt_blk_i32 = (cvt_blk_i8.xetla_format()); + } + if constexpr (compute_policy::mma_engine == mma_engine::xmx) { + constexpr uint32_t vnni_rows = + sizeof(uint32_t) / sizeof(dtype_mma_b); + xetla_vector + temp_blk; + temp_blk.xetla_select(0) = + cvt_blk_i32; + +#pragma unroll + for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { +#pragma unroll + for (uint32_t row = 0; row < vnni_rows; row++) { + temp_blk.xetla_select( + row + block_size_x_b * k * vnni_rows) = + temp_blk.xetla_select( + (k + row) * block_size_x_b * vnni_rows); + } + } + + xetla_vector scale_blk; +#pragma unroll + for (uint32_t row = 0; row < vnni_rows; row++) { + scale_blk.xetla_select(row) = + scale_vec; + } + +#pragma unroll + for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { + dst_blk.xetla_select( + k * block_size_x_b) = + temp_blk.xetla_select( + k * block_size_x_b * vnni_rows) * + scale_blk; + } + } else { +#pragma unroll + for (uint32_t k = 0; k < block_size_y_b; k++) { + dst_blk.xetla_select(k * block_size_x_b) = + cvt_blk_i32.xetla_select( + k * block_size_x_b) * + scale_vec; + } + } + } + } + } */ + /// @brief Updates tile base descriptor based on the tid. 
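// Aside on the vnni_rows shuffle in the XMX path above: with a 16-bit
// dtype_mma_b, vnni_rows = sizeof(uint32_t) / sizeof(dtype_mma_b) = 2, and
// the repack interleaves two consecutive K-rows of each column into one
// dword so DPAS can consume them together. A host-side sketch under that
// assumption (illustrative only; requires K even):
#include <cstdint>
#include <vector>
// Row-major [K][N] -> VNNI [K/2][N][2] for 16-bit elements.
std::vector<uint16_t> to_vnni_16bit(const std::vector<uint16_t>& src, int K, int N) {
  std::vector<uint16_t> dst(src.size());
  for (int k = 0; k < K; k += 2)
    for (int n = 0; n < N; ++n)
      for (int r = 0; r < 2; ++r) // the two rows that share one dword
        dst[(k / 2) * (N * 2) + n * 2 + r] = src[(k + r) * N + n];
  return dst;
}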
__XETLA_API static void update_sg_tile_tdesc( arguments_t& args, diff --git a/include/experimental/kernel/gemm/common.hpp b/include/experimental/kernel/gemm/common.hpp index 367583687..2004ce0cc 100644 --- a/include/experimental/kernel/gemm/common.hpp +++ b/include/experimental/kernel/gemm/common.hpp @@ -19,7 +19,6 @@ #pragma once -#include #include #include #include diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp index c979a19ef..e9d7a0fa9 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp @@ -16,7 +16,7 @@ #include #include "xetla.hpp" -#define UT_DEBUG 1 +// #define UT_DEBUG 1 using namespace gpu::xetla; // The number of times the kernel is executed constexpr int ITER = 200; diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 159e1a0ea..066dd9d1f 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -87,19 +87,26 @@ template < typename data_type_b, typename data_type_scale, typename data_type_zero_pt> -std::tuple convert_int4( - data_type_b int4_data, +std::vector convert_int4( + data_type_b data_b, data_type_scale scale, [[maybe_unused]] data_type_zero_pt zero_pt) { - uint8_t data_even = (int4_data & 0x0f) << 4; - int8_t data_0; - int8_t data_1; - memcpy(&data_0, &data_even, 1); - memcpy(&data_1, &int4_data, 1); - data_0 = data_0 >> 4; - data_1 = data_1 >> 4; - return std::make_tuple(fp16(data_0) * scale, fp16(data_1) * scale); + std::vector dequant_fp16(sizeof(data_type_b) * 2); + + using dtype_8bit = std::conditional_t< + quant_type == gpu::xetla::group::quant_mode::S4_FULLRANGE_NO_ZP, + int8_t, + uint8_t>; + + for (uint32_t i = 0; i < dequant_fp16.size(); i++) { + dtype_8bit dequant_8bit; + dequant_8bit = static_cast((data_b & 0xf) << 4) >> 4; + dequant_fp16[i] = scale * dequant_8bit; + data_b = data_b >> 4; + } + return dequant_fp16; } + template < size_t dequant_s, mem_layout layout_b = mem_layout::row_major, @@ -133,22 +140,24 @@ std::vector dequantize_weight( layout_b == mem_layout::row_major ? 
0 : i * matrix_k + j * pack_radio; for (uint32_t jj = 0; jj < step; jj++) { - std::tie( - b_out[start_out + pack_radio * jj], - b_out[start_out + pack_radio * jj + 1]) = - convert_int4( - b[start_b_in + jj], - scale[start_scale_in], - zero_pt[start_zero_pt_in + jj]); + std::vector dequant_fp16 = convert_int4( + b[start_b_in + jj], + scale[start_scale_in], + zero_pt[start_zero_pt_in + jj]); + for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) { + b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj]; + } } } } - // for (uint32_t i = 0; i < matrix_n; i++) { - // for (uint32_t j = 0; j < matrix_k; j++) { - // std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; - // } - // std::cout << std::endl; - // } +#ifdef UT_DEBUG + for (uint32_t i = 0; i < matrix_n; i++) { + for (uint32_t j = 0; j < matrix_k; j++) { + std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; + } + std::cout << std::endl; + } +#endif return b_out; } @@ -171,7 +180,7 @@ void dequantize_gemv_run(int iter) { using data_type_a = typename Test::data_type_a; using data_type_b = typename Test::data_type_b; using data_type_c = typename Test::data_type_c; - using data_type_zero_pt = int4x2; + using data_type_zero_pt = int4x8; using data_type_scale = fp16; using data_type_acc_in = fp16; using data_type_acc = float; @@ -336,10 +345,12 @@ void dequantize_gemv_run(int iter) { } for (unsigned i = 0; i < size_b; ++i) { - B_h[i] = uint8_t(random_uint8()); + for (unsigned j = 0; j < sizeof(data_type_b); j++) { + B_h[i] = random_uint32(); #ifdef UT_DEBUG - B_h[i] = 17; + B_h[i] = 0x11; #endif + } } for (unsigned i = 0; i < size_scale; ++i) { diff --git a/tests/utils/common.hpp b/tests/utils/common.hpp index 432727252..4a79de7c4 100644 --- a/tests/utils/common.hpp +++ b/tests/utils/common.hpp @@ -36,6 +36,9 @@ using namespace gpu::xetla; #define random_float() (generate_real_random()) #define random_uint8() (generate_int_random(0, 255)) +#define random_uint32() \ + (generate_int_random(0, std::numeric_limits::max())) + template inline auto getTypeName() { fprintf(stderr, "FAIL: Not implemented specialization\n"); From 8b9df8b9911ad5f0073ed9f3c1a52116c2a17c68 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Tue, 28 May 2024 10:26:11 +0800 Subject: [PATCH 08/34] Update include/experimental/group/gemm/compute_policy.hpp Co-authored-by: Meng, Hengyu --- include/experimental/group/gemm/compute_policy.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 274b2fce4..d1728764a 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -61,7 +61,7 @@ struct compute_policy_int4_dequantize< mma_engine_, arch_tag_, std::enable_if_t<( - arch_tag_ <= gpu_arch::XeHpc && mma_engine_ == mma_engine::xmx)>> { + mma_engine_ == mma_engine::xmx)>> { using compute_attr = compute_attr_; using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; From e8d3fbbbe27fa3ee5c73a86c94765d3e0b4a19d5 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Tue, 28 May 2024 10:33:51 +0800 Subject: [PATCH 09/34] Update include/experimental/group/gemm/compute_policy.hpp Co-authored-by: Meng, Hengyu --- include/experimental/group/gemm/compute_policy.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp 
index d1728764a..4b94b5c1e 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -74,7 +74,7 @@ struct compute_policy_int4_dequantize< static constexpr mma_engine mma_engine = mma_engine_; static constexpr gpu_arch arch_tag = arch_tag_; - static_assert(!(arch_tag == gpu_arch::XeLpg), "XeLpg does not support xmx"); + static_assert(arch_has_xmx(), "XeLpg does not support xmx"); static constexpr bool is_int4_matB_policy = true; From b0621df20a0c1f88edd812181a8fd28859a31cc8 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Wed, 29 May 2024 04:31:07 +0800 Subject: [PATCH 10/34] save(perf bug with int4x8 load) --- .../group/gemm/impl/int4_dequantize_xe.hpp | 137 +++++++----------- include/subgroup/tile/impl/payload_xe.hpp | 4 +- tests/integration/gemv/int4/main.cpp | 17 ++- 3 files changed, 63 insertions(+), 95 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 256540191..4ab664300 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -656,95 +656,58 @@ class gemm_t< auto matB_blk = matB.reg .xetla_select( block_id * matB_t::block_elems) - .xetla_format(); + .xetla_format(); auto dst_blk = matB_acc.reg.xetla_select( block_id * matB_acc_t::block_elems); - // 2: int8 includes 2 4bits data. - xetla_vector cvt_blk_i32; - if constexpr ( - compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { - xetla_vector cvt_blk_i8; - cvt_blk_i8.xetla_select(0) = matB_blk & 0x0f; - cvt_blk_i8.xetla_select(0) = - cvt_blk_i8.xetla_select(0) << 4; - cvt_blk_i8.xetla_select(0) = - cvt_blk_i8.xetla_select(0) >> 4; - cvt_blk_i8.xetla_select(1) = - matB_blk.xetla_format() >> 4; - cvt_blk_i32 = (cvt_blk_i8.xetla_format()); - } - constexpr uint32_t step = std::min(block_size_y_b, dequant_s); - -#pragma unroll - for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { - for (uint32_t jj = 0; jj < block_size_x_b; jj++) { - uint32_t offset_y_in_tile = i * block_size_y_b + ii; - uint32_t offset_x_in_tile = j * block_size_x_b + jj; - - uint32_t scale_idx = - (offset_y_in_tile) / dequant_s * scale_t::block_size_x + - offset_x_in_tile; - // uint32_t scale_idx = - // (k + (i * num_block_x + j) * matB_acc_t::block_elems) / step; - - dst_blk.xetla_select(jj * block_size_y_b + ii) = - cvt_blk_i32.xetla_select(jj * block_size_y_b + ii) * - scale.reg.xetla_select<1, 1>(scale_idx); - - // sycl::ext::oneapi::experimental::printf( - // "scale[%d] %f \n", - // scale_idx, - // float(sycl::half(scale.reg.xetla_select<1, 1>(scale_idx)))); - } - } - } - } - } - /* - inline void dequantize( - matB_acc_t& matB_acc, - matB_t& matB, - scale_t& scale, - [[maybe_unused]] zero_pt_t& zero_pt) { - // no tail, because this is matB - constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; - constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; -#pragma unroll - for (uint32_t i = 0; i < num_block_y; ++i) { -#pragma unroll - for (uint32_t j = 0; j < num_block_x; ++j) { - int block_id = (i * num_block_x + j); - auto matB_blk = matB.reg - .xetla_select( - block_id * matB_t::block_elems) - .xetla_format>(); - - auto dst_blk = matB_acc.reg.xetla_select( - block_id * matB_acc_t::block_elems); - - // 2: int8 includes 2 4bits data. + // int8 includes 2 4bits data. 
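// Aside: the branch below keeps an 8-bit intermediate -- signed int8 for
// S4_FULLRANGE_NO_ZP, unsigned when a zero point is subtracted later -- and
// then multiplies by one scale per dequant_s elements along K. A scalar
// reference for that scaling step (names are illustrative, not library API):
#include <cstddef>
#include <cstdint>
#include <vector>
// q: one column of weights, already unpacked and sign-extended to int8.
// scales: one value per group of dequant_s consecutive K elements.
std::vector<float> dequant_column(const std::vector<int8_t>& q,
                                  const std::vector<float>& scales,
                                  std::size_t dequant_s) {
  std::vector<float> w(q.size());
  for (std::size_t k = 0; k < q.size(); ++k)
    w[k] = scales[k / dequant_s] * static_cast<float>(q[k]); // group id = k / dequant_s
  return w;
}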
using dtype_8bit = std::conditional_t< compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP, int8_t, uint8_t>; - xetla_vector cvt_blk_8bit; - if constexpr ( - compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + xetla_vector cvt_blk_i8; + #pragma unroll - for (uint32_t shift = 0; shift < sizeof(dtype_b) * 2; shift++) { - auto dequant_8bit = - cvt_blk_8bit - .xetla_select( - shift); - dequant_8bit = (matB_blk & 0xf); - dequant_8bit = dequant_8bit << 4; - dequant_8bit = dequant_8bit >> 4; - matB_blk = matB_blk >> 4; + for (uint32_t i8_offset = 0; i8_offset < pack_ratio; i8_offset += 2) { + uint32_t i4_offset = i8_offset / 2; + // lowest 4 bit + { + auto dequant_i8_low_4bit = + cvt_blk_i8.xetla_select( + i8_offset); + dequant_i8_low_4bit = + matB_blk.xetla_select( + i4_offset) & + 0xf; + // Only int8 needs to reserve the sign bit + if constexpr (std::is_same_v) { + dequant_i8_low_4bit = dequant_i8_low_4bit << 4; + dequant_i8_low_4bit = dequant_i8_low_4bit >> 4; + } + } + // highest 4 bit + { + auto dequant_i8_high_4bit = + cvt_blk_i8.xetla_select( + i8_offset + 1); + if constexpr (std::is_same_v) { + dequant_i8_high_4bit = + matB_blk + .xetla_select( + i4_offset) + .xetla_format() >> + 4; + } else { + dequant_i8_high_4bit = + matB_blk.xetla_select( + i4_offset) >> + 4; + } } } + // int8 x scale = fp16 constexpr uint32_t step = std::min(block_size_y_b, dequant_s); #pragma unroll for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { @@ -759,7 +722,7 @@ class gemm_t< // (k + (i * num_block_x + j) * matB_acc_t::block_elems) / step; dst_blk.xetla_select(jj * block_size_y_b + ii) = - cvt_blk_8bit.xetla_select(jj * block_size_y_b + ii) * + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * scale.reg.xetla_select<1, 1>(scale_idx); // sycl::ext::oneapi::experimental::printf( @@ -771,7 +734,7 @@ class gemm_t< } } } - */ + /* inline void dequantize( matB_acc_t & matB_acc, @@ -784,9 +747,9 @@ class gemm_t< constexpr uint32_t block_b_y_per_scale = dequant_s / block_size_y_b; constexpr uint32_t block_b_x_per_scale = dequant_s / block_size_x_b; -#pragma unroll + #pragma unroll for (uint32_t i = 0; i < num_block_y; ++i) { -#pragma unroll + #pragma unroll for (uint32_t j = 0; j < num_block_x; ++j) { int block_id = (i * num_block_x + j); auto matB_blk = matB.reg @@ -816,7 +779,7 @@ class gemm_t< zero_pt_sub.xetla_select(1) = zero_pt_vec >> 4; xetla_vector zero_pt_blk; -#pragma unroll + #pragma unroll for (uint32_t row = 0; row < block_size_y_b; row++) { zero_pt_blk.xetla_select(row * block_size_x_b) .xetla_format() = @@ -847,9 +810,9 @@ class gemm_t< temp_blk.xetla_select(0) = cvt_blk_i32; -#pragma unroll + #pragma unroll for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { -#pragma unroll + #pragma unroll for (uint32_t row = 0; row < vnni_rows; row++) { temp_blk.xetla_select( row + block_size_x_b * k * vnni_rows) = @@ -859,13 +822,13 @@ class gemm_t< } xetla_vector scale_blk; -#pragma unroll + #pragma unroll for (uint32_t row = 0; row < vnni_rows; row++) { scale_blk.xetla_select(row) = scale_vec; } -#pragma unroll + #pragma unroll for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { dst_blk.xetla_select( k * block_size_x_b) = @@ -874,7 +837,7 @@ class gemm_t< scale_blk; } } else { -#pragma unroll + #pragma unroll for (uint32_t k = 0; k < block_size_y_b; k++) { dst_blk.xetla_select(k * block_size_x_b) = cvt_blk_i32.xetla_select( diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 595919d4a..724c36d59 100644 --- 
a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1096,8 +1096,8 @@ struct mem_payload_t< static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = - mem_transpose ^ reg_transpose && !std::is_same_v; + static constexpr bool trans = mem_transpose ^ reg_transpose && + !(std::is_same_v || std::is_same_v); static constexpr bool mem_transform = (sizeof(dtype) < 4) && (register_layout == reg_layout::vnni_tiled || diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 066dd9d1f..6889d8c8e 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -32,7 +32,7 @@ class test_col_major { static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 1; static constexpr size_t sg_k = 1024; - static constexpr size_t dequant_s = 32; + static constexpr size_t dequant_s = 128; static constexpr size_t local_kslicing = 1; static constexpr size_t global_kslicing = 1; @@ -41,7 +41,7 @@ class test_col_major { static constexpr mma_engine mma_eng = mma_engine::fpu; static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; - using data_type_b = int4x2; + using data_type_b = int4x8; using data_type_c = fp16; }; @@ -180,7 +180,7 @@ void dequantize_gemv_run(int iter) { using data_type_a = typename Test::data_type_a; using data_type_b = typename Test::data_type_b; using data_type_c = typename Test::data_type_c; - using data_type_zero_pt = int4x8; + using data_type_zero_pt = data_type_b; using data_type_scale = fp16; using data_type_acc_in = fp16; using data_type_acc = float; @@ -345,10 +345,15 @@ void dequantize_gemv_run(int iter) { } for (unsigned i = 0; i < size_b; ++i) { - for (unsigned j = 0; j < sizeof(data_type_b); j++) { + if constexpr (std::is_same_v) { + B_h[i] = random_uint8(); +#ifdef UT_DEBUG + B_h[i] = 0x12; +#endif + } else if constexpr (std::is_same_v) { B_h[i] = random_uint32(); #ifdef UT_DEBUG - B_h[i] = 0x11; + B_h[i] = 0x01234567; #endif } } @@ -356,7 +361,7 @@ void dequantize_gemv_run(int iter) { for (unsigned i = 0; i < size_scale; ++i) { scale_h[i] = random_float(); #ifdef UT_DEBUG - scale_h[i] = i + 1; + scale_h[i] = 1; #endif } From 56be57ac114f4d468f13517f2191069f976e78af Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Wed, 29 May 2024 23:19:16 +0800 Subject: [PATCH 11/34] save --- include/common/core/arch_config.hpp | 4 +- .../group/gemm/impl/int4_dequantize_xe.hpp | 70 ++++++++++--------- .../gemm/impl/int4_dequantize_kslicing_xe.hpp | 4 +- include/subgroup/tile/common.hpp | 20 +++--- include/subgroup/tile/impl/load_xe.hpp | 33 ++++----- include/subgroup/tile/impl/payload_xe.hpp | 51 ++++++-------- include/subgroup/tile/impl/store_xe.hpp | 35 ++++------ tests/integration/gemv/int4/main.cpp | 16 +++-- 8 files changed, 116 insertions(+), 117 deletions(-) diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp index 8c7c56463..3f3153090 100644 --- a/include/common/core/arch_config.hpp +++ b/include/common/core/arch_config.hpp @@ -93,8 +93,8 @@ inline constexpr bool arch_has_2d_load_store = template struct load_store_attr_t { - static constexpr uint32_t max_load_vec_len = 32; - static constexpr uint32_t max_store_vec_len = 32; + static constexpr uint32_t max_load_vec_len = 256; + static constexpr uint32_t max_store_vec_len = 256; static constexpr uint32_t max_prefetch_vec_len = 32; }; diff 
--git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 4ab664300..1c321a9da 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -182,7 +182,7 @@ class gemm_t< using matA_payload_t = subgroup::mem_payload_t< mem_desc_a_t, matA_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using matA_acc_t = subgroup::tile_t; using matA_prefetch_payload_t = subgroup:: @@ -306,7 +306,10 @@ class gemm_t< using zero_pt_payload_t = subgroup::mem_payload_t< mem_desc_zero_pt_t, zero_pt_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v< + zero_pt_tile_desc_t, + mem_space::global, + mem_desc_zero_pt_t::layout>, arch_tag>; using scale_prefetch_payload_t = subgroup:: prefetch_payload_t; @@ -516,28 +519,28 @@ class gemm_t< for (uint32_t i = 0; i < stages; i++) { subgroup::tile_prefetch( matA_prefetch_payload); - subgroup::tile_prefetch( - matB_prefetch_payload); + // subgroup::tile_prefetch( + // matB_prefetch_payload); // TODO 1D prefetch need pack to U32/U64 - subgroup::tile_prefetch( - scale_prefetch_payload); - if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { - // TODO 1D prefetch need pack to U32/U64 - subgroup::tile_prefetch( - zero_pt_prefetch_payload); - } - scale_prefetch_addr_i++; + // subgroup::tile_prefetch( + // scale_prefetch_payload); + // if constexpr ( + // compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + // // TODO 1D prefetch need pack to U32/U64 + // subgroup::tile_prefetch( + // zero_pt_prefetch_payload); + // } + // scale_prefetch_addr_i++; matA_prefetch_payload.template update_tdesc( matA_t::tile_size_x); - matB_prefetch_payload.template update_tdesc( - matB_t::tile_size_y); - if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { - scale_prefetch_payload.template update_tdesc( - scale_t::tile_size_y); - zero_pt_prefetch_payload.template update_tdesc( - zero_pt_t::tile_size_y); - } + // matB_prefetch_payload.template update_tdesc( + // matB_t::tile_size_y); + // if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { + // scale_prefetch_payload.template update_tdesc( + // scale_t::tile_size_y); + // zero_pt_prefetch_payload.template update_tdesc( + // zero_pt_t::tile_size_y); + // } } for (uint32_t i = 0; i < args.inner_loop_count; i++) { @@ -569,18 +572,18 @@ class gemm_t< if constexpr (stages != 0) { subgroup::tile_prefetch( matA_prefetch_payload); - subgroup::tile_prefetch( - matB_prefetch_payload); - // TODO 1D prefetch need pack to U32/U64 - subgroup::tile_prefetch( - scale_prefetch_payload); - if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { - // TODO 1D prefetch need pack to U32/U64 - subgroup::tile_prefetch( - zero_pt_prefetch_payload); - } - scale_prefetch_addr_i++; + // subgroup::tile_prefetch( + // matB_prefetch_payload); + // // TODO 1D prefetch need pack to U32/U64 + // subgroup::tile_prefetch( + // scale_prefetch_payload); + // if constexpr ( + // compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + // // TODO 1D prefetch need pack to U32/U64 + // subgroup::tile_prefetch( + // zero_pt_prefetch_payload); + // } + // scale_prefetch_addr_i++; } SW_BARRIER(); matA_payload.template update_tdesc(matA_t::tile_size_x); @@ -611,6 +614,7 @@ class gemm_t< } subgroup::elemwise_cvt(matA_acc, matA); dequantize(matB_acc, matB, scale, zero_pt); + // dump_mat(matB_acc); SW_BARRIER(); if 
constexpr ( is_col_major_b && compute_policy::mma_engine == mma_engine::fpu) { diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index d2e905f78..b2a767cf9 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -622,8 +622,8 @@ class gemm_universal_t< int start_x_scale = start_n; int start_y_scale = start_k / dequant_s; - int start_x_zero_pt = start_n / pack_ratio; - int start_y_zero_pt = start_k / dequant_s; + int start_x_zero_pt = start_n; + int start_y_zero_pt = start_k / (dequant_s * pack_ratio); // set up arguments uint32_t gemm_slm_base = slm_base; diff --git a/include/subgroup/tile/common.hpp b/include/subgroup/tile/common.hpp index 0e2cda92d..8c851028f 100644 --- a/include/subgroup/tile/common.hpp +++ b/include/subgroup/tile/common.hpp @@ -132,23 +132,21 @@ __XETLA_API typename std::enable_if_t< base_len != 0 && payload_t::memory_space == mem_space::global> process_1d_tail(tile_t& tile, payload_t& payload, uint32_t offset) { using dtype = typename payload_t::dtype; - using mem_dtype = typename payload_t::mem_dtype; if constexpr (remained_len >= base_len) { uint32_t address_offset = offset * sizeof(dtype); - auto reg_sub = - tile.reg.xetla_select(offset); + auto reg_sub = tile.reg.xetla_select(offset); if constexpr (flag == process_flag::load) { - reg_sub.xetla_format() = - xetla_load_global( + reg_sub.xetla_format() = + xetla_load_global( payload.base_ptr, payload.base_offset + address_offset); } else { - xetla_store_global( + xetla_store_global( payload.base_ptr, payload.base_offset + address_offset, - reg_sub.xetla_format()); + reg_sub.xetla_format()); } process_1d_tail> 1), flag, L1, L2>( - tile, payload, offset + base_len * payload_t::scale_factor); + tile, payload, offset + base_len); } else { process_1d_tail> 1), flag, L1, L2>( tile, payload, offset); @@ -321,8 +319,10 @@ template < mem_layout memory_layout = mem_layout::row_major> struct msg_type_query { static constexpr msg_type value = memory_space == mem_space::global - ? (((tile_desc_::tile_size_y == 1) && - (memory_layout == mem_layout::row_major)) + ? (((tile_desc_::tile_size_y == 1 && + memory_layout == mem_layout::row_major) || + (tile_desc_::tile_size_x == 1 && + memory_layout == mem_layout::col_major)) ? 
msg_type::block_1d : msg_type::block_2d) : (((tile_desc_::tile_size_y == 1) && diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 6d13356c3..23071a9f2 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -34,9 +34,8 @@ struct check_load_type { (payload_t::message_type == msg_type::block_2d)); static constexpr bool is_global_block_1d = - ((payload_t::memory_space == mem_space::global) && - (tile_t::tile_size_y == 1) && (tile_t::block_size_y == 1) && - (payload_t::message_type == msg_type::block_1d)); + (payload_t::memory_space == mem_space::global && + payload_t::message_type == msg_type::block_1d); static constexpr bool is_global_unaligned_2d_xe = ((payload_t::memory_space == mem_space::global) && @@ -395,33 +394,31 @@ template < __XETLA_API typename std::enable_if_t< detail::check_load_type::is_global_block_1d> tile_load(tile_t& tile, payload_t& payload) { - using dtype = typename tile_t::dtype; - using load_dtype = typename payload_t::mem_dtype; - - static constexpr uint32_t tile_size_x = tile_t::tile_size_x; - static constexpr uint32_t scale_factor = payload_t::scale_factor; - static constexpr uint32_t load_len = tile_size_x / scale_factor; + using dtype = typename payload_t::dtype; + static constexpr uint32_t load_len = tile_t::tile_elems; static constexpr gpu_arch arch_tag = payload_t::arch_tag; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_load_vec_len = load_store_attr::max_load_vec_len; + static constexpr uint32_t max_load_vec_elems = + max_load_vec_len / sizeof(dtype); - static constexpr uint32_t load_iter_steps = load_len / max_load_vec_len; - if constexpr (load_len >= max_load_vec_len) { + static constexpr uint32_t load_iter_steps = load_len / max_load_vec_elems; + if constexpr (load_len >= max_load_vec_elems) { #pragma unroll for (uint32_t i = 0; i < load_iter_steps; i++) { - uint32_t offset_x = i * max_load_vec_len * scale_factor; - auto reg_sub = - tile.reg.xetla_select(offset_x); + uint32_t offset_x = i * max_load_vec_elems; + auto reg_sub = tile.reg.xetla_select(offset_x); uint32_t address_offset = offset_x * sizeof(dtype); - reg_sub.xetla_format() = - xetla_load_global( + reg_sub.xetla_format() = + xetla_load_global( payload.base_ptr, payload.base_offset + address_offset); } } - constexpr uint32_t tail_len = load_len % max_load_vec_len; - uint32_t tail_offset = load_iter_steps * max_load_vec_len * scale_factor; + constexpr uint32_t tail_len = load_len % max_load_vec_elems * sizeof(dtype); + uint32_t tail_offset = load_iter_steps * max_load_vec_len; detail::process_1d_tail< tail_len, (max_load_vec_len >> 1), diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 724c36d59..90b63f3f3 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -407,7 +407,7 @@ struct mem_payload_t< std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { using mem_desc_t = mem_desc_t; - using dtype = dtype_; + using dtype = native_type_t; using tile_desc = tile_desc_; static constexpr mem_space memory_space = mem_space::global; static constexpr mem_layout memory_layout = memory_layout_; @@ -422,36 +422,26 @@ struct mem_payload_t< static constexpr uint32_t tile_size_x = tile_desc::tile_size_x; static constexpr uint32_t tile_size_y = tile_desc::tile_size_y; static_assert( - tile_size_y == 1, - "For tile_size_y > 1 case, please use 2d block message! 
"); + (tile_size_y == 1 && memory_layout == mem_layout::row_major) || + (tile_size_x == 1 && memory_layout == mem_layout::col_major), + "For tile_size_y > 1 or tile_size_x > 1 case, please use 2d block message! "); using this_payload_t = mem_payload_t; + static constexpr bool mem_transpose = memory_layout == mem_layout::col_major; public: - static constexpr uint32_t bytes_per_row = - memory_layout == mem_layout::row_major ? tile_size_x * sizeof(dtype) - : tile_size_y * sizeof(dtype); - using mem_dtype = typename std::conditional< - (bytes_per_row % sizeof(uint64_t) == 0) && - (alignment_in_bytes % sizeof(uint64_t) == 0), - uint64_t, - typename std::conditional< - (bytes_per_row % sizeof(uint32_t) == 0), - uint32_t, - dtype>::type>::type; - - static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype); - uint64_t base_offset; - mem_dtype* base_ptr; + dtype* base_ptr; uint32_t pitch_in_bytes; inline mem_payload_t(mem_desc_t& mem_tdesc) { pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); uint32_t offset_x = mem_tdesc.coord.x; uint32_t offset_y = mem_tdesc.coord.y; - base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); - base_ptr = (mem_dtype*)mem_tdesc.base.base; + base_offset = mem_transpose + ? offset_x * pitch_in_bytes + offset_y * sizeof(dtype) + : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + base_ptr = (dtype*)mem_tdesc.base.base; } inline mem_payload_t( @@ -464,16 +454,20 @@ struct mem_payload_t< pitch_in_bytes = surface_pitch * sizeof(dtype); uint32_t offset_x = surface_offset_x; uint32_t offset_y = surface_offset_y; - base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); - base_ptr = (mem_dtype*)p; + base_offset = mem_transpose + ? offset_x * pitch_in_bytes + offset_y * sizeof(dtype) + : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + base_ptr = (dtype*)p; } __XETLA_API void init(mem_desc_t& mem_tdesc) { pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); uint32_t offset_x = mem_tdesc.coord.x; uint32_t offset_y = mem_tdesc.coord.y; - base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); - base_ptr = (mem_dtype*)mem_tdesc.base.base; + base_offset = mem_transpose + ? offset_x * pitch_in_bytes + offset_y * sizeof(dtype) + : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + base_ptr = (dtype*)mem_tdesc.base.base; } __XETLA_API void init( @@ -486,8 +480,10 @@ struct mem_payload_t< pitch_in_bytes = surface_pitch * sizeof(dtype); uint32_t offset_x = surface_offset_x; uint32_t offset_y = surface_offset_y; - base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); - base_ptr = (mem_dtype*)p; + base_offset = mem_transpose + ? 
offset_x * pitch_in_bytes + offset_y * sizeof(dtype) + : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + base_ptr = (dtype*)p; } inline mem_payload_t(const this_payload_t& rhs) { @@ -1071,8 +1067,7 @@ struct mem_payload_t< msg_type::block_2d, arch_tag_, std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpg)>> { - using dtype = - std::conditional_t, uint8_t, dtype_>; + using dtype = native_type_t; using mem_desc_t = mem_desc_t; using tile_desc = tile_desc_; diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index 56196da6d..baa663076 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -34,9 +34,8 @@ struct check_store_type { (payload_t::message_type == msg_type::block_2d)); static constexpr bool is_global_block_1d_xe = - ((payload_t::memory_space == mem_space::global) && - (tile_t::tile_size_y == 1) && (tile_t::block_size_y == 1) && - (payload_t::message_type == msg_type::block_1d)); + (payload_t::memory_space == mem_space::global && + payload_t::message_type == msg_type::block_1d); static constexpr bool is_global_unaligned_2d_xe = (payload_t::memory_space == mem_space::global && @@ -280,36 +279,32 @@ template < __XETLA_API typename std::enable_if_t< detail::check_store_type::is_global_block_1d_xe> tile_store(tile_t& tile, payload_t& payload) { - using dtype = typename tile_t::dtype; - using store_dtype = typename payload_t::mem_dtype; - - static constexpr uint32_t tile_size_x = tile_t::tile_size_x; - static constexpr uint32_t scale_factor = payload_t::scale_factor; - - static constexpr uint32_t store_len = tile_size_x / scale_factor; - + using dtype = typename payload_t::dtype; + static constexpr uint32_t store_len = tile_t::tile_elems; static constexpr gpu_arch arch_tag = payload_t::arch_tag; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_store_vec_len = load_store_attr::max_store_vec_len; + static constexpr uint32_t max_store_vec_elems = + max_store_vec_len / sizeof(dtype); - static constexpr uint32_t store_iter_steps = store_len / max_store_vec_len; - if constexpr (store_len >= max_store_vec_len) { + static constexpr uint32_t store_iter_steps = store_len / max_store_vec_elems; + if constexpr (store_len >= max_store_vec_elems) { #pragma unroll for (uint32_t i = 0; i < store_iter_steps; i++) { - uint32_t offset_x = i * max_store_vec_len * scale_factor; - auto reg_sub = - tile.reg.xetla_select(offset_x); + uint32_t offset_x = i * max_store_vec_elems; + auto reg_sub = tile.reg.xetla_select(offset_x); uint32_t address_offset = offset_x * sizeof(dtype); - xetla_store_global( + xetla_store_global( payload.base_ptr, payload.base_offset + address_offset, - reg_sub.xetla_format()); + reg_sub.xetla_format()); } } - constexpr uint32_t tail_len = store_len % max_store_vec_len; - uint32_t tail_offset = store_iter_steps * max_store_vec_len * scale_factor; + constexpr uint32_t tail_len = store_len % max_store_vec_elems * sizeof(dtype); + uint32_t tail_offset = store_iter_steps * max_store_vec_len; detail::process_1d_tail< tail_len, (max_store_vec_len >> 1), diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 6889d8c8e..9de4630bf 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -19,7 +19,11 @@ // #define UT_DEBUG using namespace gpu::xetla; // The number of times the kernel is executed +#ifdef UT_DEBUG +constexpr int ITER = 1; +#else constexpr int ITER = 200; +#endif class test_col_major { 
public: @@ -41,7 +45,7 @@ class test_col_major { static constexpr mma_engine mma_eng = mma_engine::fpu; static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; - using data_type_b = int4x8; + using data_type_b = int4x2; using data_type_c = fp16; }; @@ -197,7 +201,7 @@ void dequantize_gemv_run(int iter) { constexpr size_t size_scale = size_scale_k * size_scale_n; constexpr size_t size_zero_pt_k = matrix_k / dequant_s; - constexpr size_t size_zero_pt_n = matrix_n / 2; + constexpr size_t size_zero_pt_n = matrix_n / (2 * sizeof(data_type_zero_pt)); constexpr size_t size_zero_pt = size_zero_pt_k * size_zero_pt_n; constexpr size_t size_c = matrix_m * matrix_n; @@ -225,7 +229,7 @@ void dequantize_gemv_run(int iter) { using tile_shape = xetla::group::tile_shape_t; static constexpr uint32_t periodic_sync_interval = 0; - static constexpr uint32_t prefetch_distance = 0; + static constexpr uint32_t prefetch_distance = 1; using mem_desc_a_t = xetla::mem_desc_t< data_type_a, @@ -361,7 +365,7 @@ void dequantize_gemv_run(int iter) { for (unsigned i = 0; i < size_scale; ++i) { scale_h[i] = random_float(); #ifdef UT_DEBUG - scale_h[i] = 1; + scale_h[i] = i + 1; #endif } @@ -442,7 +446,11 @@ void dequantize_gemv_run(int iter) { size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); +#ifdef UT_DEBUG + int constexpr warm = 0; +#else int constexpr warm = 100; +#endif try { for (int i = 0; i < iter + warm; i++) { if (i >= warm) From 2b3717300a3e2c8b6a22a825c79361815064c4bc Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Thu, 30 May 2024 18:20:29 +0800 Subject: [PATCH 12/34] add first token UT --- .../group/gemm/compute_policy.hpp | 10 ++- .../group/gemm/impl/int4_dequantize_xe.hpp | 66 +++++++++--------- include/subgroup/tile/impl/payload_xe.hpp | 24 +++++-- include/subgroup/tile/impl/prefetch_xe.hpp | 69 ++++++++++--------- tests/integration/gemv/int4/main.cpp | 29 +++++++- 5 files changed, 118 insertions(+), 80 deletions(-) diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 4b94b5c1e..e0eb2bb72 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -60,8 +60,7 @@ struct compute_policy_int4_dequantize< dequant_s_, mma_engine_, arch_tag_, - std::enable_if_t<( - mma_engine_ == mma_engine::xmx)>> { + std::enable_if_t> { using compute_attr = compute_attr_; using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; @@ -120,8 +119,7 @@ struct compute_policy_int4_dequantize< dequant_s_, mma_engine_, arch_tag_, - std::enable_if_t<( - arch_tag_ <= gpu_arch::XeHpc && mma_engine_ == mma_engine::fpu)>> { + std::enable_if_t> { using compute_attr = compute_attr_; using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; @@ -144,11 +142,11 @@ struct compute_policy_int4_dequantize< using dtype_zero_pt = dtype_zero_pt_; static constexpr quant_mode quant_type = quant_type_; - static constexpr uint32_t block_size_y_a = 16; + static constexpr uint32_t block_size_y_a = 1; static constexpr uint32_t block_bytes_x_a = 256; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_size_x_b = 32; + static constexpr uint32_t block_size_x_b = 1; static constexpr uint32_t block_bytes_y_b = 256; static constexpr uint32_t block_size_y_b = block_bytes_y_b / 
sizeof(dtype_mma_b); diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 1c321a9da..6d67648c2 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -519,28 +519,31 @@ class gemm_t< for (uint32_t i = 0; i < stages; i++) { subgroup::tile_prefetch( matA_prefetch_payload); - // subgroup::tile_prefetch( - // matB_prefetch_payload); + subgroup::tile_prefetch( + matB_prefetch_payload); // TODO 1D prefetch need pack to U32/U64 - // subgroup::tile_prefetch( - // scale_prefetch_payload); - // if constexpr ( - // compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { - // // TODO 1D prefetch need pack to U32/U64 - // subgroup::tile_prefetch( - // zero_pt_prefetch_payload); - // } - // scale_prefetch_addr_i++; + subgroup::tile_prefetch( + scale_prefetch_payload); + if constexpr ( + compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + // TODO 1D prefetch need pack to U32/U64 + subgroup::tile_prefetch( + zero_pt_prefetch_payload); + } + scale_prefetch_addr_i++; matA_prefetch_payload.template update_tdesc( matA_t::tile_size_x); - // matB_prefetch_payload.template update_tdesc( - // matB_t::tile_size_y); - // if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { - // scale_prefetch_payload.template update_tdesc( - // scale_t::tile_size_y); - // zero_pt_prefetch_payload.template update_tdesc( - // zero_pt_t::tile_size_y); - // } + matB_prefetch_payload.template update_tdesc( + matB_t::tile_size_y); + if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { + scale_prefetch_payload.template update_tdesc( + scale_t::tile_size_y); + if constexpr ( + compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + zero_pt_prefetch_payload.template update_tdesc( + zero_pt_t::tile_size_y); + } + } } for (uint32_t i = 0; i < args.inner_loop_count; i++) { @@ -572,18 +575,18 @@ class gemm_t< if constexpr (stages != 0) { subgroup::tile_prefetch( matA_prefetch_payload); - // subgroup::tile_prefetch( - // matB_prefetch_payload); - // // TODO 1D prefetch need pack to U32/U64 - // subgroup::tile_prefetch( - // scale_prefetch_payload); - // if constexpr ( - // compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { - // // TODO 1D prefetch need pack to U32/U64 - // subgroup::tile_prefetch( - // zero_pt_prefetch_payload); - // } - // scale_prefetch_addr_i++; + subgroup::tile_prefetch( + matB_prefetch_payload); + // TODO 1D prefetch need pack to U32/U64 + subgroup::tile_prefetch( + scale_prefetch_payload); + if constexpr ( + compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + // TODO 1D prefetch need pack to U32/U64 + subgroup::tile_prefetch( + zero_pt_prefetch_payload); + } + scale_prefetch_addr_i++; } SW_BARRIER(); matA_payload.template update_tdesc(matA_t::tile_size_x); @@ -614,7 +617,6 @@ class gemm_t< } subgroup::elemwise_cvt(matA_acc, matA); dequantize(matB_acc, matB, scale, zero_pt); - // dump_mat(matB_acc); SW_BARRIER(); if constexpr ( is_col_major_b && compute_policy::mma_engine == mma_engine::fpu) { diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 90b63f3f3..708fd365f 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1618,8 +1618,10 @@ struct prefetch_payload_t< arch_tag_, std::enable_if_t<( arch_tag_ <= gpu_arch::XeHpg && - (tile_size_y_ != 1 || block_size_y_ != 1))>> { - 
using dtype = dtype_; + ((tile_size_y_ != 1 || block_size_y_ != 1) || + ((tile_size_x_ != 1 || block_size_x_ != 1) && + reg_layout_ == reg_layout::transpose_tiled)))>> { + using dtype = native_type_t; using mem_desc_t = mem_desc_t; using tile_desc = tile_desc_t< @@ -1653,7 +1655,8 @@ struct prefetch_payload_t< static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = mem_transpose ^ reg_transpose; + static constexpr bool trans = mem_transpose ^ reg_transpose && + !(std::is_same_v || std::is_same_v); using prefetch_dtype = typename std::conditional< (alignment_in_bytes % (sizeof(uint64_t)) == 0), @@ -1665,12 +1668,22 @@ struct prefetch_payload_t< static constexpr uint32_t pack_factor = sizeof(prefetch_dtype) / sizeof(dtype); + static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype); + static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype); + static constexpr uint32_t simd_channel = + ((tile_bytes % max_store_bytes) == 0 && + (block_bytes % max_store_bytes) == 0) + ? 32 + : 16; + static constexpr uint32_t num_channel = mem_transpose + ? (simd_channel >= block_size_x) ? block_size_x : simd_channel + : (simd_channel >= block_size_y) ? block_size_y + : simd_channel; + static constexpr uint32_t simd_exec_size = (mem_transpose ? block_size_y : block_size_x) >= pack_factor ? (mem_transpose ? block_size_y : block_size_x) / pack_factor : 1; - static constexpr uint32_t num_channel = - mem_transpose ? block_size_x : block_size_y; static constexpr uint32_t mem_tile_size_w = mem_transpose ? tile_size_y : tile_size_x; @@ -2125,6 +2138,7 @@ struct prefetch_payload_t< mem_desc_t; // CL aligned, so we can use uint64_t using prefetch_dtype = uint64_t; + static constexpr msg_type message_type = msg_type::block_1d; using tile_desc = tile_desc_t; static constexpr mem_space memory_space = mem_space::global; static constexpr mem_layout memory_layout = mem_layout_; diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp index 0f9e2d3bb..2bda68cab 100644 --- a/include/subgroup/tile/impl/prefetch_xe.hpp +++ b/include/subgroup/tile/impl/prefetch_xe.hpp @@ -28,17 +28,16 @@ namespace detail { template struct check_prefetch_type { static constexpr bool is_global_2d = - ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y != 1)); + (payload_t::memory_space == mem_space::global && + payload_t::message_type == msg_type::block_2d); static constexpr bool is_global_block_1d = - ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y == 1)); + (payload_t::memory_space == mem_space::global && + payload_t::message_type == msg_type::block_1d); static constexpr bool is_global_unaligned_2d = - ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y != 1) && - (payload_t::message_type == msg_type::unaligned_2d)); + (payload_t::memory_space == mem_space::global && + payload_t::message_type == msg_type::unaligned_2d); static constexpr bool is_local = (payload_t::memory_space == mem_space::local); @@ -104,26 +103,25 @@ tile_prefetch(payload_t& payload) { #pragma unroll for (uint32_t j = 0; j < tile_desc::num_block_x; j++) { uint32_t offset_x = j * tile_desc::block_size_x; -#pragma unroll - for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; - sub_block_y += num_channel) { - uint32_t address_offset = payload_t::mem_transpose - ? 
offset_x * payload.pitch_in_bytes + - (offset_y + sub_block_y) * sizeof(dtype) - : offset_x * sizeof(dtype) + - (offset_y + sub_block_y) * payload.pitch_in_bytes; + // #pragma unroll + // for (uint32_t sub_block_y = 0; sub_block_y < + // tile_desc::block_size_y; + // sub_block_y += num_channel) { + uint32_t address_offset = payload_t::mem_transpose + ? offset_x * payload.pitch_in_bytes + (offset_y + 0) * sizeof(dtype) + : offset_x * sizeof(dtype) + (offset_y + 0) * payload.pitch_in_bytes; - xetla_prefetch_global< - prefetch_dtype, - payload_t::simd_exec_size, - data_size::default_size, - L1, - L2, - payload_t::num_channel>( - payload.base_ptr, - payload.channel_offset + payload.base_offset + address_offset, - 1); - } + xetla_prefetch_global< + prefetch_dtype, + payload_t::simd_exec_size, + data_size::default_size, + L1, + L2, + num_channel>( + payload.base_ptr, + payload.channel_offset + payload.base_offset + address_offset, + 1); + // } } } } @@ -150,7 +148,9 @@ tile_prefetch(payload_t& payload) { using prefetch_dtype = typename payload_t::prefetch_dtype; constexpr uint32_t prefetch_len = tile_desc::tile_size_x / payload_t::scale_factor; - constexpr uint32_t max_prefetch_in_bytes = load_store_attr_t::max_prefetch_vec_len; + constexpr uint32_t max_prefetch_in_bytes = + load_store_attr_t:: + max_prefetch_vec_len; if constexpr (prefetch_len >= max_prefetch_in_bytes) { #pragma unroll for (uint32_t j = 0; j < prefetch_len / max_prefetch_in_bytes; j++) { @@ -165,10 +165,11 @@ tile_prefetch(payload_t& payload) { } } constexpr uint32_t tail_len = prefetch_len % max_prefetch_in_bytes; - uint32_t tail_offset = - prefetch_len / max_prefetch_in_bytes * max_prefetch_in_bytes * payload_t::scale_factor; - detail::process_1d_tail( - payload, tail_offset); + uint32_t tail_offset = prefetch_len / max_prefetch_in_bytes * + max_prefetch_in_bytes * payload_t::scale_factor; + detail:: + process_1d_tail( + payload, tail_offset); } /// @brief Is prefetch data func. 
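// Aside on the prefetch path being reworked in this series: the gemm issues
// `stages` (prefetch_distance) tile prefetches up front, then one more per
// compute iteration so the cache-hinted loads stay ahead of consumption.
// A self-contained sketch of that software pipeline with stand-in functions:
#include <cstdint>
#include <cstdio>
static void prefetch_tile(uint32_t i) { std::printf("prefetch tile %u\n", i); } // stand-in
static void compute_tile(uint32_t i) { std::printf("compute tile %u\n", i); } // stand-in
int main() {
  constexpr uint32_t stages = 3; // plays the role of prefetch_distance
  constexpr uint32_t loop_count = 8; // plays the role of inner_loop_count
  for (uint32_t i = 0; i < stages && i < loop_count; ++i)
    prefetch_tile(i); // warm-up: fill the pipeline before any compute
  for (uint32_t i = 0; i < loop_count; ++i) {
    if (i + stages < loop_count)
      prefetch_tile(i + stages); // steady state: stay `stages` tiles ahead
    compute_tile(i); // tile i was prefetched `stages` iterations ago
  }
  return 0;
}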
@@ -183,8 +184,8 @@ template < cache_hint L1 = cache_hint::cached, cache_hint L2 = cache_hint::cached, typename payload_t> -__XETLA_API typename std::enable_if_t< - detail::check_prefetch_type::is_local> -tile_prefetch([[maybe_unused]] payload_t& payload) {} +__XETLA_API + typename std::enable_if_t::is_local> + tile_prefetch([[maybe_unused]] payload_t& payload) {} } // namespace gpu::xetla::subgroup diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 9de4630bf..eb3b93816 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -25,7 +25,7 @@ constexpr int ITER = 1; constexpr int ITER = 200; #endif -class test_col_major { +class test_col_major_1 { public: // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; @@ -48,6 +48,29 @@ class test_col_major { using data_type_b = int4x2; using data_type_c = fp16; }; +class test_col_major_2 { + public: + // Extract the parameters required by different test cases + static constexpr size_t mat_m = 32; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 1; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 1; + static constexpr size_t sg_k = 1024; + static constexpr size_t dequant_s = 128; + + static constexpr size_t local_kslicing = 1; + static constexpr size_t global_kslicing = 1; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; + using data_type_a = fp16; + using data_type_b = int4x2; + using data_type_c = fp16; +}; template < typename data_type_a, @@ -229,7 +252,7 @@ void dequantize_gemv_run(int iter) { using tile_shape = xetla::group::tile_shape_t; static constexpr uint32_t periodic_sync_interval = 0; - static constexpr uint32_t prefetch_distance = 1; + static constexpr uint32_t prefetch_distance = 0; using mem_desc_a_t = xetla::mem_desc_t< data_type_a, @@ -520,7 +543,7 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd); -using tests = ::testing::Types; +using tests = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P( dequantize_gemv_test_suite, From f973aa217f7e0704157325d98734ee03c3426778 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Thu, 30 May 2024 20:15:38 +0800 Subject: [PATCH 13/34] opt mma code --- .../group/gemm/impl/int4_dequantize_xe.hpp | 70 +++++++++---------- include/subgroup/tile/impl/fma_xe.hpp | 44 ++++++------ tests/integration/gemv/int4/main.cpp | 2 +- 3 files changed, 58 insertions(+), 58 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 6d67648c2..e1c77ae2c 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -88,6 +88,8 @@ class gemm_t< static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; static constexpr bool is_col_major_a = mem_layout_a == mem_layout::col_major; static constexpr bool is_col_major_b = mem_layout_b == mem_layout::col_major; + static constexpr bool is_gemv = is_col_major_b && + compute_policy::mma_engine == mma_engine::fpu && sg_tile_m == 1; private: /******** set data type **********/ @@ -134,21 +136,13 @@ class gemm_t< static constexpr 
uint32_t tile_size_y_c = sg_tile_m; static constexpr uint32_t block_size_x_a = - (compute_policy::block_size_x_a > tile_size_x_a) - ? tile_size_x_a - : compute_policy::block_size_x_a; + std::min(compute_policy::block_size_x_a, tile_size_x_a); static constexpr uint32_t block_size_y_a = - (compute_policy::block_size_y_a > tile_size_y_a) - ? tile_size_y_a - : compute_policy::block_size_y_a; + std::min(compute_policy::block_size_y_a, tile_size_y_a); static constexpr uint32_t block_size_x_b = - (compute_policy::block_size_x_b > tile_size_x_b) - ? tile_size_x_b - : compute_policy::block_size_x_b; + std::min(compute_policy::block_size_x_b, tile_size_x_b); static constexpr uint32_t block_size_y_b = - (compute_policy::block_size_y_b > tile_size_y_b) - ? tile_size_y_b - : compute_policy::block_size_y_b; + std::min(compute_policy::block_size_y_b, tile_size_y_b); /******** set tile **********/ static constexpr bool is_vnni_tiled_a = @@ -156,21 +150,21 @@ class gemm_t< ? ((sizeof(dtype_a) < sizeof(uint32_t)) && is_col_major_a) : false; - static constexpr reg_layout reg_layout_a_ = + static constexpr reg_layout reg_layout_a = + // fpu compute_policy::mma_engine == mma_engine::fpu - ? reg_layout::transpose_tiled + ? (is_gemv ? reg_layout::tiled : reg_layout::transpose_tiled) + // xmx : is_vnni_tiled_a ? reg_layout::vnni_tiled : reg_layout::tiled; - static constexpr reg_layout reg_layout_b_ = - compute_policy::mma_engine == mma_engine::fpu ? reg_layout::tiled - : (sizeof(dtype_mma_b) < sizeof(uint32_t)) ? reg_layout::vnni_tiled - : reg_layout::tiled; - - static constexpr reg_layout reg_layout_a = - is_col_major_b ? reg_layout_b_ : reg_layout_a_; static constexpr reg_layout reg_layout_b = - is_col_major_b ? reg_layout_a_ : reg_layout_b_; + // fpu + compute_policy::mma_engine == mma_engine::fpu + ? (is_gemv ? reg_layout::transpose_tiled : reg_layout::tiled) + // xmx + : (sizeof(dtype_mma_b) < sizeof(uint32_t)) ? reg_layout::vnni_tiled + : reg_layout::tiled; using matA_tile_desc_t = subgroup::tile_desc_t< tile_size_x_a, @@ -192,12 +186,14 @@ class gemm_t< // note: 4bit x 2, row-major using matB_tile_desc_t = std::conditional_t< is_col_major_b, + // compress int4 along K dimensions subgroup::tile_desc_t< tile_size_x_b, tile_size_y_b / pack_ratio, block_size_x_b, block_size_y_b / pack_ratio, reg_layout_b>, + // compress int4 along N dimensions subgroup::tile_desc_t< tile_size_x_b / pack_ratio, tile_size_y_b, @@ -289,12 +285,14 @@ class gemm_t< using zero_pt_tile_desc_t = std::conditional_t< is_col_major_b, + // compress int4 along K dimensions subgroup::tile_desc_t< tile_size_x_b, (tile_size_y_zero_pt + pack_ratio - 1) / pack_ratio, block_size_x_b, (block_size_y_zero_pt + pack_ratio - 1) / pack_ratio, is_col_major_b ? 
reg_layout::transpose_tiled : reg_layout::tiled>, + // compress int4 along N dimensions subgroup::tile_desc_t< (tile_size_x_b + pack_ratio - 1) / pack_ratio, tile_size_y_zero_pt, @@ -317,14 +315,8 @@ class gemm_t< prefetch_payload_t; using tile_mma = std::conditional_t< - is_col_major_b, - subgroup::tile_fma_t< - matC_t, - matC_t, - matAcc_t, - matB_acc_t, - matA_acc_t, - arch_tag>, + is_gemv, + subgroup::tile_fma_t, subgroup::tile_mma_t< matC_t, matC_t, @@ -618,10 +610,18 @@ class gemm_t< subgroup::elemwise_cvt(matA_acc, matA); dequantize(matB_acc, matB, scale, zero_pt); SW_BARRIER(); - if constexpr ( - is_col_major_b && compute_policy::mma_engine == mma_engine::fpu) { - tile_mma::fma(matAcc, matAcc, matB_acc, matA_acc); + if constexpr (is_gemv) { + tile_mma::mma( + matAcc, + matAcc, + matC, + matB_acc, + matA_acc, + i == args.inner_loop_count - 1); } else { + if constexpr (is_col_major_b) { + tile_transpose(matB_acc); + } tile_mma::mma(matC, matC, matB_acc, matA_acc); } SW_BARRIER(); @@ -639,10 +639,6 @@ class gemm_t< } } SW_BARRIER(); - if constexpr ( - is_col_major_b && compute_policy::mma_engine == mma_engine::fpu) { - tile_mma::reduce_acc_k(matAcc, matC); - } } private: diff --git a/include/subgroup/tile/impl/fma_xe.hpp b/include/subgroup/tile/impl/fma_xe.hpp index 77fca6b9c..df352311f 100644 --- a/include/subgroup/tile/impl/fma_xe.hpp +++ b/include/subgroup/tile/impl/fma_xe.hpp @@ -25,22 +25,19 @@ namespace gpu::xetla::subgroup { /// @brief Is the tile mma operation functor, specialized for Xe and fpu engine. template < - typename matDst_t_, - typename matSrc_t_, typename matAcc_t_, + typename matC_t_, typename matB_t_, typename matA_t_, gpu_arch arch_tag_> struct tile_fma_t { using matA_t = matA_t_; using matB_t = matB_t_; - using matSrc_t = matSrc_t_; - using matDst_t = matDst_t_; + using matC_t = matC_t_; using matAcc_t = matAcc_t_; using dtype_a = typename matA_t::dtype; using dtype_b = typename matB_t::dtype; - using dtype_src = typename matSrc_t::dtype; - using dtype_dst = typename matDst_t::dtype; + using dtype_acc = typename matAcc_t_::dtype; using register_attr = typename arch_attr_t::template register_attr<>; @@ -81,11 +78,13 @@ struct tile_fma_t { static_assert(tile_size_m == 1, "matA tile m must be 1"); static_assert(a_block_size_y == 1, "matA block m must be 1"); - __XETLA_API static void fma( - matAcc_t& dst, - matAcc_t& src, + __XETLA_API static void mma( + matAcc_t& acc_dst, + matAcc_t& acc_src, + matC_t& c, matB_t& b, - matA_t& a) { + matA_t& a, + bool reduce) { #pragma unroll for (uint32_t k = 0; k < tile_size_k / block_size_k; k++) { auto a_block = @@ -98,23 +97,28 @@ struct tile_fma_t { uint32_t src_dst_idx = n * block_size_n; auto src_block = - src.reg.xetla_select(src_dst_idx); + acc_src.reg.xetla_select( + src_dst_idx); auto dst_block = - dst.reg.xetla_select(src_dst_idx); + acc_dst.reg.xetla_select( + src_dst_idx); fma_core( dst_block, src_block, b_block, a_block); } } + if (reduce) { + reduce_acc_k(acc_dst, c); + } } template __XETLA_API static void fma_core( - xetla_vector_ref __REF__ dst, - xetla_vector_ref __REF__ src, + xetla_vector_ref __REF__ dst, + xetla_vector_ref __REF__ src, xetla_vector_ref __REF__ b_block, xetla_vector_ref __REF__ a_block) { static_assert(blk_m == 1, "block m must be 1"); - auto dst_blk_2d = dst.xetla_format(); - auto src_blk_2d = src.xetla_format(); + auto dst_blk_2d = dst.xetla_format(); + auto src_blk_2d = src.xetla_format(); auto b_blk_2d = b_block.xetla_format(); auto a_blk_2d = a_block.xetla_format(); #pragma unroll @@ 
-122,19 +126,19 @@ struct tile_fma_t { dst_blk_2d.row(n) = b_blk_2d.row(n) * a_blk_2d.row(0) + src_blk_2d.row(n); } } - __XETLA_API static void reduce_acc_k(matAcc_t& matAcc, matDst_t_& matC) { + __XETLA_API static void reduce_acc_k(matAcc_t& matAcc, matC_t& matC) { // matC [tx,ty,bx,by](matmul): tile_n, 1, block_n, 1 // matAcc[tx,ty,bx,by](matmul): tile_n, block_k, block_n, block_k // matAcc[tx,ty,bx,by](memory): block_k, tile_n, block_k, block_n static_assert( - matDst_t_::tile_size_y == 1 && matDst_t_::block_size_y == 1, + matC_t::tile_size_y == 1 && matC_t::block_size_y == 1, "matDst_t_ tile m and block m should match be 1"); static_assert( - matAcc_t::tile_size_y == matDst_t_::tile_size_x, + matAcc_t::tile_size_y == matC_t::tile_size_x, "matAcc_t tile n should match with matDst_t_ tile n"); static_assert( - matAcc_t::block_size_y == matDst_t_::block_size_x, + matAcc_t::block_size_y == matC_t::block_size_x, "matAcc_t block n should match with matDst_t_ block n"); static constexpr auto block_k = matAcc_t::block_size_x; static constexpr auto tile_n = matAcc_t::tile_size_y; diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index eb3b93816..7a07e95a8 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -543,7 +543,7 @@ TYPED_TEST_P(dequantize_gemv_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd); -using tests = ::testing::Types; +using tests = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P( dequantize_gemv_test_suite, From 0f36c04db9050f9b783540cf9d1188d97a7cbf71 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Thu, 30 May 2024 22:08:18 +0800 Subject: [PATCH 14/34] opt perf for int4x8 --- .../group/gemm/impl/int4_dequantize_xe.hpp | 62 ++++++------------- tests/integration/gemv/int4/main.cpp | 2 +- 2 files changed, 21 insertions(+), 43 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index e1c77ae2c..5542a9881 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -655,10 +655,9 @@ class gemm_t< #pragma unroll for (uint32_t j = 0; j < num_block_x; ++j) { int block_id = (i * num_block_x + j); - auto matB_blk = matB.reg - .xetla_select( - block_id * matB_t::block_elems) - .xetla_format(); + auto matB_blk = matB.reg.xetla_format() + .xetla_select( + block_id * matB_acc_t::block_elems / 2); auto dst_blk = matB_acc.reg.xetla_select( block_id * matB_acc_t::block_elems); @@ -670,45 +669,24 @@ class gemm_t< uint8_t>; xetla_vector cvt_blk_i8; -#pragma unroll - for (uint32_t i8_offset = 0; i8_offset < pack_ratio; i8_offset += 2) { - uint32_t i4_offset = i8_offset / 2; - // lowest 4 bit - { - auto dequant_i8_low_4bit = - cvt_blk_i8.xetla_select( - i8_offset); - dequant_i8_low_4bit = - matB_blk.xetla_select( - i4_offset) & - 0xf; - // Only int8 needs to reserve the sign bit - if constexpr (std::is_same_v) { - dequant_i8_low_4bit = dequant_i8_low_4bit << 4; - dequant_i8_low_4bit = dequant_i8_low_4bit >> 4; - } - } - // highest 4 bit - { - auto dequant_i8_high_4bit = - cvt_blk_i8.xetla_select( - i8_offset + 1); - if constexpr (std::is_same_v) { - dequant_i8_high_4bit = - matB_blk - .xetla_select( - i4_offset) - .xetla_format() >> - 4; - } else { - dequant_i8_high_4bit = - matB_blk.xetla_select( - i4_offset) >> - 4; - } - } + // lowest 4 bit + auto dequant_i8_low_4bit = + cvt_blk_i8.xetla_select(0); + 
dequant_i8_low_4bit = matB_blk & 0xf; + // Only int8 needs to reserve the sign bit + if constexpr (std::is_same_v) { + dequant_i8_low_4bit = dequant_i8_low_4bit << 4; + dequant_i8_low_4bit = dequant_i8_low_4bit >> 4; } - + // highest 4 bit + auto dequant_i8_high_4bit = + cvt_blk_i8.xetla_select(1); + if constexpr (std::is_same_v) { + dequant_i8_high_4bit = matB_blk.xetla_format() >> 4; + } else { + dequant_i8_high_4bit = matB_blk >> 4; + } + // int8 x scale = fp16 constexpr uint32_t step = std::min(block_size_y_b, dequant_s); #pragma unroll diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 7a07e95a8..38c63c4ec 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -45,7 +45,7 @@ class test_col_major_1 { static constexpr mma_engine mma_eng = mma_engine::fpu; static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; - using data_type_b = int4x2; + using data_type_b = int4x8; using data_type_c = fp16; }; class test_col_major_2 { From d9902d80b3c35bdba6cab580fafb25b4ec4246d1 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Fri, 31 May 2024 19:54:28 +0800 Subject: [PATCH 15/34] support load one fp16 data --- include/common/core/memory.hpp | 8 ++++- .../group/gemm/impl/int4_dequantize_xe.hpp | 33 +++++++++++-------- include/subgroup/tile/chained_tile_op.hpp | 3 ++ include/subgroup/tile/impl/prefetch_xe.hpp | 13 ++++---- tests/integration/gemv/int4/main.cpp | 2 +- 5 files changed, 37 insertions(+), 22 deletions(-) diff --git a/include/common/core/memory.hpp b/include/common/core/memory.hpp index 0bc360d6e..8bd90eab4 100644 --- a/include/common/core/memory.hpp +++ b/include/common/core/memory.hpp @@ -355,7 +355,13 @@ __XETLA_API xetla_vector xetla_load_global( __ESIMD_NS::cache_hint_L1, __ESIMD_NS::cache_hint_L2, __ESIMD_NS::alignment}; - return __ESIMD_NS::block_load(ptr, byte_offset, props); + if constexpr (sizeof(T) * N < sizeof(uint32_t)) { + auto padding_load = __ESIMD_NS::block_load( + ptr, byte_offset, props); + return padding_load.xetla_select(0); + } else { + return __ESIMD_NS::block_load(ptr, byte_offset, props); + } } /// @brief Stateless scattered load. 
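// Scalar model of the sub-dword path added to xetla_load_global above: when
// the request is narrower than 4 bytes (e.g. a single fp16), issue a
// dword-wide load and slice out the requested elements. Like the kernel, this
// sketch assumes the buffer carries enough tail padding for the over-read.
#include <cstdint>
#include <cstring>

template <typename T, int N>
void load_with_dword_padding(const void* base, size_t byte_offset,
                             T (&out)[N]) {
  static_assert(sizeof(T) * N < sizeof(uint32_t), "padded path only");
  uint32_t word;                           // one dword over-read
  std::memcpy(&word, static_cast<const uint8_t*>(base) + byte_offset,
              sizeof(word));
  std::memcpy(out, &word, sizeof(T) * N);  // keep the first N elements
}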
diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 5542a9881..be1a8a23b 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -655,6 +655,7 @@ class gemm_t< #pragma unroll for (uint32_t j = 0; j < num_block_x; ++j) { int block_id = (i * num_block_x + j); + // Must be little-endian auto matB_blk = matB.reg.xetla_format() .xetla_select( block_id * matB_acc_t::block_elems / 2); @@ -670,23 +671,27 @@ class gemm_t< xetla_vector cvt_blk_i8; // lowest 4 bit - auto dequant_i8_low_4bit = - cvt_blk_i8.xetla_select(0); - dequant_i8_low_4bit = matB_blk & 0xf; - // Only int8 needs to reserve the sign bit - if constexpr (std::is_same_v) { - dequant_i8_low_4bit = dequant_i8_low_4bit << 4; - dequant_i8_low_4bit = dequant_i8_low_4bit >> 4; + { + auto dequant_i8_low_4bit = + cvt_blk_i8.xetla_select(0); + dequant_i8_low_4bit = matB_blk & 0xf; + // Only int8 needs to reserve the sign bit + if constexpr (std::is_same_v) { + dequant_i8_low_4bit = dequant_i8_low_4bit << 4; + dequant_i8_low_4bit = dequant_i8_low_4bit >> 4; + } } // highest 4 bit - auto dequant_i8_high_4bit = - cvt_blk_i8.xetla_select(1); - if constexpr (std::is_same_v) { - dequant_i8_high_4bit = matB_blk.xetla_format() >> 4; - } else { - dequant_i8_high_4bit = matB_blk >> 4; + { + auto dequant_i8_high_4bit = + cvt_blk_i8.xetla_select(1); + if constexpr (std::is_same_v) { + dequant_i8_high_4bit = matB_blk.xetla_format() >> 4; + } else { + dequant_i8_high_4bit = matB_blk >> 4; + } } - + // int8 x scale = fp16 constexpr uint32_t step = std::min(block_size_y_b, dequant_s); #pragma unroll diff --git a/include/subgroup/tile/chained_tile_op.hpp b/include/subgroup/tile/chained_tile_op.hpp index 3d5698a01..294295082 100644 --- a/include/subgroup/tile/chained_tile_op.hpp +++ b/include/subgroup/tile/chained_tile_op.hpp @@ -71,6 +71,9 @@ struct chained_tile_op_arg_t inline chained_tile_op_arg_t( chained_tile_op_arg_t const& args) = default; + inline chained_tile_op_arg_t& operator=( + chained_tile_op_arg_t const& args) = + default; template inline T get() const { diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp index 2bda68cab..0372f96cc 100644 --- a/include/subgroup/tile/impl/prefetch_xe.hpp +++ b/include/subgroup/tile/impl/prefetch_xe.hpp @@ -28,16 +28,17 @@ namespace detail { template struct check_prefetch_type { static constexpr bool is_global_2d = - (payload_t::memory_space == mem_space::global && - payload_t::message_type == msg_type::block_2d); + ((payload_t::memory_space == mem_space::global) && + (payload_t::tile_desc::tile_size_y != 1)); static constexpr bool is_global_block_1d = - (payload_t::memory_space == mem_space::global && - payload_t::message_type == msg_type::block_1d); + ((payload_t::memory_space == mem_space::global) && + (payload_t::tile_desc::tile_size_y == 1)); static constexpr bool is_global_unaligned_2d = - (payload_t::memory_space == mem_space::global && - payload_t::message_type == msg_type::unaligned_2d); + ((payload_t::memory_space == mem_space::global) && + (payload_t::tile_desc::tile_size_y != 1) && + (payload_t::message_type == msg_type::unaligned_2d)); static constexpr bool is_local = (payload_t::memory_space == mem_space::local); diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 38c63c4ec..38ade9e80 100644 --- a/tests/integration/gemv/int4/main.cpp +++ 
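// The check_prefetch_type rewrite above stops keying on
// payload_t::message_type for the 1D/2D split and derives it from the tile
// shape instead: a tile one row high is a 1D stream, anything taller is
// block-2D. A minimal model with a hypothetical Payload type exposing the
// same members:
template <typename Payload>
inline constexpr bool prefetch_is_global_1d =
    Payload::is_global && Payload::tile_size_y == 1;

template <typename Payload>
inline constexpr bool prefetch_is_global_2d =
    Payload::is_global && Payload::tile_size_y != 1;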
b/tests/integration/gemv/int4/main.cpp @@ -211,7 +211,7 @@ void dequantize_gemv_run(int iter) { using data_type_scale = fp16; using data_type_acc_in = fp16; using data_type_acc = float; - using data_type_bias = float; + using data_type_bias = data_type_a; constexpr mem_layout layout_a = Test::layout_a; constexpr mem_layout layout_b = Test::layout_b; From 30b8e950e401b63783f3b1967fa6222556b91638 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Sat, 1 Jun 2024 03:07:58 +0800 Subject: [PATCH 16/34] support zero_pt --- .../group/gemm/impl/int4_dequantize_xe.hpp | 62 ++++++--- include/subgroup/tile/impl/op_function.hpp | 2 +- tests/integration/gemv/int4/main.cpp | 120 ++++++++++++------ 3 files changed, 125 insertions(+), 59 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index be1a8a23b..f6d7ee9c1 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -239,6 +239,10 @@ class gemm_t< static constexpr uint32_t scale_addr_update_freq = (k_stride < dequant_s) ? dequant_s / k_stride : 1; + static constexpr uint32_t zero_pt_addr_update_freq = + (k_stride < dequant_s * pack_ratio) ? (dequant_s * pack_ratio) / k_stride + : 1; + using mem_desc_scale_t = mem_desc_t< dtype_scale, mem_layout_b, @@ -505,7 +509,7 @@ class gemm_t< nbarrier_role::producer_consumer); int scale_prefetch_addr_i = args.inner_loop_start; - int scale_load_addr_i = args.inner_loop_start; + int tile_k_idx = args.inner_loop_start; SW_BARRIER(); #pragma unroll for (uint32_t i = 0; i < stages; i++) { @@ -562,7 +566,7 @@ class gemm_t< subgroup::tile_load( zero_pt, zero_pt_payload); } - scale_load_addr_i++; + tile_k_idx++; SW_BARRIER(); if constexpr (stages != 0) { subgroup::tile_prefetch( @@ -583,10 +587,15 @@ class gemm_t< SW_BARRIER(); matA_payload.template update_tdesc(matA_t::tile_size_x); matB_payload.template update_tdesc(matB_t::tile_size_y); - if ((scale_load_addr_i % scale_addr_update_freq) == 0) { + if (tile_k_idx % scale_addr_update_freq == 0) { scale_payload.template update_tdesc(scale_t::tile_size_y); - zero_pt_payload.template update_tdesc( - zero_pt_t::tile_size_y); + } + if constexpr ( + compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + if (tile_k_idx % zero_pt_addr_update_freq == 0) { + zero_pt_payload.template update_tdesc( + zero_pt_t::tile_size_y); + } } if constexpr (stages != 0) { matA_prefetch_payload.template update_tdesc( @@ -596,9 +605,12 @@ class gemm_t< if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); - zero_pt_prefetch_payload - .template update_tdesc( - zero_pt_t::tile_size_y); + if constexpr ( + compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + zero_pt_prefetch_payload + .template update_tdesc( + zero_pt_t::tile_size_y); + } } } SW_BARRIER(); @@ -609,6 +621,7 @@ class gemm_t< } subgroup::elemwise_cvt(matA_acc, matA); dequantize(matB_acc, matB, scale, zero_pt); + dump_mat(matB_acc); SW_BARRIER(); if constexpr (is_gemv) { tile_mma::mma( @@ -667,7 +680,7 @@ class gemm_t< using dtype_8bit = std::conditional_t< compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP, int8_t, - uint8_t>; + int8_t>; xetla_vector cvt_blk_i8; // lowest 4 bit @@ -692,23 +705,38 @@ class gemm_t< } } - // int8 x scale = fp16 + // (b_i8 - zero_pt_i8) x scale = fp16 constexpr uint32_t step = std::min(block_size_y_b, dequant_s); #pragma 
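// The descriptor-update logic above advances the scale tile once per
// quantization group along K, and (in this revision, with zeros still packed
// along K) the zero-point tile once per dequant_s * pack_ratio rows. Scalar
// model of the scale-row cadence, assuming k_stride divides dequant_s (the
// kernel's update frequency degenerates to 1 otherwise):
#include <cstdint>

void advance_quant_group_rows(uint32_t inner_loop_count, uint32_t k_stride,
                              uint32_t dequant_s, uint32_t& scale_row) {
  const uint32_t update_freq =
      (k_stride < dequant_s) ? dequant_s / k_stride : 1;
  for (uint32_t tile_k_idx = 1; tile_k_idx <= inner_loop_count; ++tile_k_idx)
    if (tile_k_idx % update_freq == 0)
      ++scale_row;  // one scale / zero-point row per dequant_s elements of K
}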
unroll - for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { - for (uint32_t jj = 0; jj < block_size_x_b; jj++) { + for (uint32_t jj = 0; jj < block_size_x_b; jj++) { + for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { uint32_t offset_y_in_tile = i * block_size_y_b + ii; uint32_t offset_x_in_tile = j * block_size_x_b + jj; uint32_t scale_idx = (offset_y_in_tile) / dequant_s * scale_t::block_size_x + offset_x_in_tile; - // uint32_t scale_idx = - // (k + (i * num_block_x + j) * matB_acc_t::block_elems) / step; - dst_blk.xetla_select(jj * block_size_y_b + ii) = - cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * - scale.reg.xetla_select<1, 1>(scale_idx); + if constexpr (compute_policy::quant_type == quant_mode::S4_ASYM) { + uint32_t zero_pt_idx = offset_y_in_tile / + (dequant_s * pack_ratio) * zero_pt_t::block_size_x + + offset_x_in_tile; + native_type_t zero_pt_pack = zero_pt.reg[zero_pt_idx]; + + uint8_t zero_pt_u8 = + (zero_pt_pack >> (4 * (scale_idx % pack_ratio))) & 0xf; + + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - + zero_pt_u8; + dst_blk.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * + scale.reg[scale_idx]; + } else { + dst_blk.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * + scale.reg[scale_idx]; + } // sycl::ext::oneapi::experimental::printf( // "scale[%d] %f \n", diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index a4796815e..2ed2a93a8 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -713,7 +713,7 @@ void dump_mat( #pragma unroll for (size_t col = 0; col < tile_x; col++) { sycl::ext::oneapi::experimental::printf( - "%d ", (int)(sycl::half)mat.reg[row * tile_x + col]); + "%f ", (float)(sycl::half)mat.reg[row * tile_x + col]); } sycl::ext::oneapi::experimental::printf("\n"); } diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 38ade9e80..e44c24c58 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -16,8 +16,9 @@ #include #include "xetla.hpp" -// #define UT_DEBUG +#define UT_DEBUG using namespace gpu::xetla; +using namespace gpu::xetla::group; // The number of times the kernel is executed #ifdef UT_DEBUG constexpr int ITER = 1; @@ -29,14 +30,16 @@ class test_col_major_1 { public: // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; + static constexpr size_t mat_n = 1; + static constexpr size_t mat_k = 128; static constexpr size_t wg_m = 1; static constexpr size_t wg_n = 1; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 1; - static constexpr size_t sg_k = 1024; - static constexpr size_t dequant_s = 128; + static constexpr size_t sg_k = 128; + static constexpr size_t dequant_s = 16; + static constexpr quant_mode quant_type = quant_mode::S4_ASYM; + // static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP; static constexpr size_t local_kslicing = 1; static constexpr size_t global_kslicing = 1; @@ -45,7 +48,7 @@ class test_col_major_1 { static constexpr mma_engine mma_eng = mma_engine::fpu; static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; - using data_type_b = int4x8; + using data_type_b = int4x2; using data_type_c = fp16; }; class 
test_col_major_2 { @@ -108,8 +111,7 @@ int gemm_result_validate( } template < - gpu::xetla::group::quant_mode quant_type = - gpu::xetla::group::S4_FULLRANGE_NO_ZP, + quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP, typename data_type_acc_in = fp16, typename data_type_b, typename data_type_scale, @@ -119,16 +121,20 @@ std::vector convert_int4( data_type_scale scale, [[maybe_unused]] data_type_zero_pt zero_pt) { std::vector dequant_fp16(sizeof(data_type_b) * 2); - using dtype_8bit = std::conditional_t< - quant_type == gpu::xetla::group::quant_mode::S4_FULLRANGE_NO_ZP, + quant_type == quant_mode::S4_FULLRANGE_NO_ZP, int8_t, uint8_t>; + uint8_t zero_pt_i8 = zero_pt & 0xf; for (uint32_t i = 0; i < dequant_fp16.size(); i++) { dtype_8bit dequant_8bit; dequant_8bit = static_cast((data_b & 0xf) << 4) >> 4; - dequant_fp16[i] = scale * dequant_8bit; + if constexpr (quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + dequant_fp16[i] = scale * dequant_8bit; + } else { + dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8); + } data_b = data_b >> 4; } return dequant_fp16; @@ -137,8 +143,7 @@ std::vector convert_int4( template < size_t dequant_s, mem_layout layout_b = mem_layout::row_major, - gpu::xetla::group::quant_mode quant_type = - gpu::xetla::group::S4_FULLRANGE_NO_ZP, + quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP, typename data_type_acc_in = fp16, typename data_type_b, typename data_type_scale, @@ -159,9 +164,8 @@ std::vector dequantize_weight( for (uint32_t i = 0; i < height; i++) { for (uint32_t j = 0; j < width; j += step) { int start_b_in = i * width + j; - int start_zero_pt_in = start_b_in; - int start_scale_in = start_b_in / step; + int start_zero_pt_in = start_scale_in / pack_radio; int start_out = layout_b == mem_layout::row_major ? 0 : i * matrix_k + j * pack_radio; @@ -170,7 +174,7 @@ std::vector dequantize_weight( std::vector dequant_fp16 = convert_int4( b[start_b_in + jj], scale[start_scale_in], - zero_pt[start_zero_pt_in + jj]); + zero_pt[start_zero_pt_in] >> (4 * (start_scale_in % pack_radio))); for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) { b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj]; } @@ -204,6 +208,7 @@ void dequantize_gemv_run(int iter) { constexpr size_t sg_tile_n = Test::sg_n; constexpr size_t sg_tile_k = Test::sg_k; constexpr size_t dequant_s = Test::dequant_s; + constexpr quant_mode quant_type = Test::quant_type; using data_type_a = typename Test::data_type_a; using data_type_b = typename Test::data_type_b; using data_type_c = typename Test::data_type_c; @@ -224,7 +229,7 @@ void dequantize_gemv_run(int iter) { constexpr size_t size_scale = size_scale_k * size_scale_n; constexpr size_t size_zero_pt_k = matrix_k / dequant_s; - constexpr size_t size_zero_pt_n = matrix_n / (2 * sizeof(data_type_zero_pt)); + constexpr size_t size_zero_pt_n = matrix_n; constexpr size_t size_zero_pt = size_zero_pt_k * size_zero_pt_n; constexpr size_t size_c = matrix_m * matrix_n; @@ -236,8 +241,8 @@ void dequantize_gemv_run(int iter) { uint32_t ld_scale = layout_b == mem_layout::row_major ? size_scale_n : size_scale_k; - // uint32_t ld_zero_pt = mem_layout::row_major ? size_zero_pt_n : - // size_zero_pt_k; + uint32_t ld_zero_pt = + layout_b == mem_layout::row_major ? 
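// Host-side reference for the asymmetric path exercised by convert_int4
// above: weight and zero point are both unsigned 4-bit codes and the
// dequantized value is (q - z) * s. Worked scalar:
#include <cstdint>

inline float dequant_s4_asym(uint8_t q4, uint8_t z4, float scale) {
  return scale * (static_cast<int>(q4 & 0xf) - static_cast<int>(z4 & 0xf));
}
// e.g. q4 = 0x7, z4 = 0x3, scale = 0.5f  ->  0.5f * (7 - 3) = 2.0f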
size_zero_pt_n : size_zero_pt_k; // Turn on the enable_profiling property to facilitate subsequent profiling sycl::property_list properties{ @@ -286,7 +291,7 @@ void dequantize_gemv_run(int iter) { perf_tuning_knob, data_type_scale, data_type_zero_pt, - gpu::xetla::group::quant_mode::S4_FULLRANGE_NO_ZP, + quant_type, dequant_s, Test::mma_eng, Test::arch>; @@ -375,12 +380,12 @@ void dequantize_gemv_run(int iter) { if constexpr (std::is_same_v) { B_h[i] = random_uint8(); #ifdef UT_DEBUG - B_h[i] = 0x12; + B_h[i] = 0x22; #endif } else if constexpr (std::is_same_v) { B_h[i] = random_uint32(); #ifdef UT_DEBUG - B_h[i] = 0x01234567; + B_h[i] = 0x22222222; #endif } } @@ -388,12 +393,23 @@ void dequantize_gemv_run(int iter) { for (unsigned i = 0; i < size_scale; ++i) { scale_h[i] = random_float(); #ifdef UT_DEBUG - scale_h[i] = i + 1; + scale_h[i] = 1; #endif } for (unsigned i = 0; i < size_zero_pt; ++i) { - zero_pt_h[i] = 0; + if constexpr (std::is_same_v) { + zero_pt_h[i] = random_uint8(); +#ifdef UT_DEBUG + zero_pt_h[i] = zero_pt_h[i] << 4 + i % 8; + zero_pt_h[i] = 0x33; +#endif + } else if constexpr (std::is_same_v) { + zero_pt_h[i] = random_uint32(); +#ifdef UT_DEBUG + zero_pt_h[i] = 0x01234567; +#endif + } } for (unsigned i = 0; i < size_c; ++i) { @@ -444,22 +460,44 @@ void dequantize_gemv_run(int iter) { {// epilogue_args init list // It accepts the base pointer to matrix D, and its dimensions {bias_d, bias_add_shape}}); - typename gemm_op_t::template arguments_t gemm_arg( - matrix_m, - matrix_k, - matrix_n, - A_d, - lda, - B_d, - ldb, - C_d, - ldc, - scale_d, - ld_scale, - Acc_d, - Cnt_d, - epilogue_args); - + typename gemm_op_t::template arguments_t gemm_arg; + if constexpr (compute_policy::quant_type == S4_FULLRANGE_NO_ZP) { + gemm_arg = + typename gemm_op_t::template arguments_t( + matrix_m, + matrix_k, + matrix_n, + A_d, + lda, + B_d, + ldb, + C_d, + ldc, + scale_d, + ld_scale, + Acc_d, + Cnt_d, + epilogue_args); + } else if constexpr (compute_policy::quant_type == S4_ASYM) { + gemm_arg = + typename gemm_op_t::template arguments_t( + matrix_m, + matrix_k, + matrix_n, + A_d, + lda, + B_d, + ldb, + C_d, + ldc, + scale_d, + ld_scale, + zero_pt_d, + ld_zero_pt, + Acc_d, + Cnt_d, + epilogue_args); + } cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); // if (!gemm_op_t::can_implement(gemm_arg)) { // std::cout << "The arguments cannot be supported, aborting ... 
" @@ -501,7 +539,7 @@ void dequantize_gemv_run(int iter) { prof.print_profiling_result(profiling_selector::GPU); // check result std::vector dequantize_b = - dequantize_weight( + dequantize_weight( matrix_k, matrix_n, B_h, scale_h, zero_pt_h); queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait(); From 885995fa6b8ab314ec103d44077ae1d849a5932f Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Mon, 3 Jun 2024 17:37:50 +0800 Subject: [PATCH 17/34] support ASYM and SYM --- .../group/gemm/impl/int4_dequantize_xe.hpp | 19 +++++++-------- tests/integration/gemv/int4/main.cpp | 24 +++++++++---------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index f6d7ee9c1..202ab637a 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -621,7 +621,6 @@ class gemm_t< } subgroup::elemwise_cvt(matA_acc, matA); dequantize(matB_acc, matB, scale, zero_pt); - dump_mat(matB_acc); SW_BARRIER(); if constexpr (is_gemv) { tile_mma::mma( @@ -677,11 +676,7 @@ class gemm_t< block_id * matB_acc_t::block_elems); // int8 includes 2 4bits data. - using dtype_8bit = std::conditional_t< - compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP, - int8_t, - int8_t>; - xetla_vector cvt_blk_i8; + xetla_vector cvt_blk_i8; // lowest 4 bit { @@ -689,7 +684,8 @@ class gemm_t< cvt_blk_i8.xetla_select(0); dequant_i8_low_4bit = matB_blk & 0xf; // Only int8 needs to reserve the sign bit - if constexpr (std::is_same_v) { + if constexpr ( + compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { dequant_i8_low_4bit = dequant_i8_low_4bit << 4; dequant_i8_low_4bit = dequant_i8_low_4bit >> 4; } @@ -698,8 +694,9 @@ class gemm_t< { auto dequant_i8_high_4bit = cvt_blk_i8.xetla_select(1); - if constexpr (std::is_same_v) { - dequant_i8_high_4bit = matB_blk.xetla_format() >> 4; + if constexpr ( + compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + dequant_i8_high_4bit = matB_blk.xetla_format() >> 4; } else { dequant_i8_high_4bit = matB_blk >> 4; } @@ -723,12 +720,12 @@ class gemm_t< offset_x_in_tile; native_type_t zero_pt_pack = zero_pt.reg[zero_pt_idx]; - uint8_t zero_pt_u8 = + int8_t zero_pt_i8 = (zero_pt_pack >> (4 * (scale_idx % pack_ratio))) & 0xf; cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - - zero_pt_u8; + zero_pt_i8; dst_blk.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * scale.reg[scale_idx]; diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index e44c24c58..498f712c6 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -16,7 +16,7 @@ #include #include "xetla.hpp" -#define UT_DEBUG +// #define UT_DEBUG using namespace gpu::xetla; using namespace gpu::xetla::group; // The number of times the kernel is executed @@ -30,14 +30,14 @@ class test_col_major_1 { public: // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 1; - static constexpr size_t mat_k = 128; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; static constexpr size_t wg_m = 1; static constexpr size_t wg_n = 1; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 1; - static constexpr size_t sg_k = 128; - static 
constexpr size_t dequant_s = 16; + static constexpr size_t sg_k = 1024; + static constexpr size_t dequant_s = 128; static constexpr quant_mode quant_type = quant_mode::S4_ASYM; // static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP; @@ -48,7 +48,7 @@ class test_col_major_1 { static constexpr mma_engine mma_eng = mma_engine::fpu; static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; - using data_type_b = int4x2; + using data_type_b = int4x8; using data_type_c = fp16; }; class test_col_major_2 { @@ -121,16 +121,14 @@ std::vector convert_int4( data_type_scale scale, [[maybe_unused]] data_type_zero_pt zero_pt) { std::vector dequant_fp16(sizeof(data_type_b) * 2); - using dtype_8bit = std::conditional_t< - quant_type == quant_mode::S4_FULLRANGE_NO_ZP, - int8_t, - uint8_t>; - uint8_t zero_pt_i8 = zero_pt & 0xf; + int8_t zero_pt_i8 = zero_pt & 0xf; for (uint32_t i = 0; i < dequant_fp16.size(); i++) { - dtype_8bit dequant_8bit; - dequant_8bit = static_cast((data_b & 0xf) << 4) >> 4; + int8_t dequant_8bit; + dequant_8bit = data_b & 0xf; if constexpr (quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + dequant_8bit = dequant_8bit << 4; + dequant_8bit = dequant_8bit >> 4; dequant_fp16[i] = scale * dequant_8bit; } else { dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8); From 7e99e687800076ffc3009e281f1a4654bc71b106 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Wed, 5 Jun 2024 00:55:05 +0800 Subject: [PATCH 18/34] save --- .../group/gemm/impl/int4_dequantize_xe.hpp | 21 ++++++++++--------- include/subgroup/tile/impl/op_function.hpp | 6 +++--- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 202ab637a..697645003 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -686,8 +686,8 @@ class gemm_t< // Only int8 needs to reserve the sign bit if constexpr ( compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { - dequant_i8_low_4bit = dequant_i8_low_4bit << 4; - dequant_i8_low_4bit = dequant_i8_low_4bit >> 4; + // dequant_i8_low_4bit = dequant_i8_low_4bit << 4; + // dequant_i8_low_4bit = dequant_i8_low_4bit >> 4; } } // highest 4 bit @@ -696,7 +696,7 @@ class gemm_t< cvt_blk_i8.xetla_select(1); if constexpr ( compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { - dequant_i8_high_4bit = matB_blk.xetla_format() >> 4; + dequant_i8_high_4bit = matB_blk >> 4; } else { dequant_i8_high_4bit = matB_blk >> 4; } @@ -726,14 +726,15 @@ class gemm_t< cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - zero_pt_i8; - dst_blk.xetla_select(jj * block_size_y_b + ii) = - cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * - scale.reg[scale_idx]; - } else { - dst_blk.xetla_select(jj * block_size_y_b + ii) = - cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * - scale.reg[scale_idx]; + } else if constexpr ( + compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - + int8_t(8); } + dst_blk.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * + scale.reg[scale_idx]; // sycl::ext::oneapi::experimental::printf( // "scale[%d] %f \n", diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp 
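// The hunk above drops the two's-complement decode of the NO_ZP path (the
// shift pair is commented out) and subtracts a fixed bias of 8 instead, i.e.
// offset-binary decoding. Both conventions span [-8, 7] but map the 16 codes
// differently, so kernel and host reference must agree:
#include <cstdint>

inline int decode_twos_complement(uint8_t q4) {  // earlier convention
  return static_cast<int8_t>((q4 & 0xf) << 4) >> 4;
}
inline int decode_offset_binary(uint8_t q4) {    // convention after this hunk
  return static_cast<int>(q4 & 0xf) - 8;
}
// q4 = 0x9: two's complement -> -7, offset binary -> +1.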
index 2ed2a93a8..613def4d7 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -713,11 +713,11 @@ void dump_mat( #pragma unroll for (size_t col = 0; col < tile_x; col++) { sycl::ext::oneapi::experimental::printf( - "%f ", (float)(sycl::half)mat.reg[row * tile_x + col]); + "%d ", (int)mat.reg[row * tile_x + col]); } - sycl::ext::oneapi::experimental::printf("\n"); + // sycl::ext::oneapi::experimental::printf("\n"); } - sycl::ext::oneapi::experimental::printf("\n "); + // sycl::ext::oneapi::experimental::printf("\n "); } template void dump_mat_reg(T mat, size_t tile_x, size_t tile_y) { From 150f7d3b6f5f57dc7f224b337a0fde1f3872c10b Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Thu, 6 Jun 2024 17:54:32 +0800 Subject: [PATCH 19/34] ut improve --- tests/integration/gemv/int4/main.cpp | 53 ++++++++++++++++++---------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 498f712c6..3dccefefe 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -25,6 +25,7 @@ constexpr int ITER = 1; #else constexpr int ITER = 200; #endif +constexpr size_t UNDEFINED_DATA_SIZE = 1024; class test_col_major_1 { public: @@ -38,8 +39,8 @@ class test_col_major_1 { static constexpr size_t sg_n = 1; static constexpr size_t sg_k = 1024; static constexpr size_t dequant_s = 128; - static constexpr quant_mode quant_type = quant_mode::S4_ASYM; - // static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP; + // static constexpr quant_mode quant_type = quant_mode::S4_ASYM; + static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP; static constexpr size_t local_kslicing = 1; static constexpr size_t global_kslicing = 1; @@ -127,9 +128,9 @@ std::vector convert_int4( int8_t dequant_8bit; dequant_8bit = data_b & 0xf; if constexpr (quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { - dequant_8bit = dequant_8bit << 4; - dequant_8bit = dequant_8bit >> 4; - dequant_fp16[i] = scale * dequant_8bit; + // dequant_8bit = dequant_8bit << 4; + // dequant_8bit = dequant_8bit >> 4; + dequant_fp16[i] = scale * (dequant_8bit - 8); } else { dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8); } @@ -323,25 +324,29 @@ void dequantize_gemv_run(int iter) { // Define and initialize the data required for the calculation auto* A_h = static_cast( malloc_host(size_a * sizeof(data_type_a), context)); - auto* B_h = static_cast( - malloc_host(size_b * sizeof(data_type_b), context)); + auto* B_h = static_cast(malloc_host( + (size_b + UNDEFINED_DATA_SIZE) * sizeof(data_type_b), context)); auto* C_h = static_cast( malloc_host(size_c * sizeof(data_type_c), context)); auto* Acc_h = static_cast( malloc_host(size_acc * sizeof(data_type_acc), context)); auto* Cnt_h = static_cast(malloc_host(size_cnt * sizeof(uint32_t), context)); - auto* scale_h = static_cast( - malloc_host(size_scale * sizeof(data_type_scale), context)); - auto* zero_pt_h = static_cast( - malloc_host(size_zero_pt * sizeof(data_type_zero_pt), context)); + auto* scale_h = static_cast(malloc_host( + (size_scale + UNDEFINED_DATA_SIZE) * sizeof(data_type_scale), context)); + auto* zero_pt_h = static_cast(malloc_host( + (size_zero_pt + UNDEFINED_DATA_SIZE) * sizeof(data_type_zero_pt), + context)); auto* bias_h = static_cast( malloc_host(size_bias * sizeof(data_type_bias), context)); auto* A_d = static_cast(aligned_alloc_device( DEVICE_MEM_ALIGNMENT, size_a * sizeof(data_type_a), device, 
context)); auto* B_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_b * sizeof(data_type_b), device, context)); + DEVICE_MEM_ALIGNMENT, + (size_b + UNDEFINED_DATA_SIZE) * sizeof(data_type_b), + device, + context)); auto* C_d = static_cast(aligned_alloc_device( DEVICE_MEM_ALIGNMENT, size_c * sizeof(data_type_c), device, context)); auto* Acc_d = static_cast(aligned_alloc_device( @@ -350,12 +355,12 @@ void dequantize_gemv_run(int iter) { DEVICE_MEM_ALIGNMENT, size_cnt * sizeof(uint32_t), device, context)); auto* scale_d = static_cast(aligned_alloc_device( DEVICE_MEM_ALIGNMENT, - size_scale * sizeof(data_type_scale), + (size_scale + UNDEFINED_DATA_SIZE) * sizeof(data_type_scale), device, context)); auto* zero_pt_d = static_cast(aligned_alloc_device( DEVICE_MEM_ALIGNMENT, - size_zero_pt * sizeof(data_type_zero_pt), + (size_zero_pt + UNDEFINED_DATA_SIZE) * sizeof(data_type_zero_pt), device, context)); auto* bias_d = static_cast(aligned_alloc_device( @@ -374,7 +379,7 @@ void dequantize_gemv_run(int iter) { #endif } - for (unsigned i = 0; i < size_b; ++i) { + for (unsigned i = 0; i < size_b + UNDEFINED_DATA_SIZE; ++i) { if constexpr (std::is_same_v) { B_h[i] = random_uint8(); #ifdef UT_DEBUG @@ -394,8 +399,11 @@ void dequantize_gemv_run(int iter) { scale_h[i] = 1; #endif } + for (unsigned i = size_scale; i < size_scale + UNDEFINED_DATA_SIZE; ++i) { + scale_h[i] = INFINITY; + } - for (unsigned i = 0; i < size_zero_pt; ++i) { + for (unsigned i = 0; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) { if constexpr (std::is_same_v) { zero_pt_h[i] = random_uint8(); #ifdef UT_DEBUG @@ -430,20 +438,27 @@ void dequantize_gemv_run(int iter) { } queue.memcpy((void*)A_d, (void*)A_h, size_a * sizeof(data_type_a)).wait(); - queue.memcpy((void*)B_d, (void*)B_h, size_b * sizeof(data_type_b)).wait(); + queue + .memcpy( + (void*)B_d, + (void*)B_h, + (size_b + UNDEFINED_DATA_SIZE) * sizeof(data_type_b)) + .wait(); queue.memcpy((void*)C_d, (void*)C_h, size_c * sizeof(data_type_c)).wait(); queue.memcpy((void*)Acc_d, (void*)Acc_h, size_acc * sizeof(data_type_acc)) .wait(); queue.memcpy((void*)Cnt_d, (void*)Cnt_h, size_cnt * sizeof(uint32_t)).wait(); queue .memcpy( - (void*)scale_d, (void*)scale_h, size_scale * sizeof(data_type_scale)) + (void*)scale_d, + (void*)scale_h, + (size_scale + UNDEFINED_DATA_SIZE) * sizeof(data_type_scale)) .wait(); queue .memcpy( (void*)zero_pt_d, (void*)zero_pt_h, - size_zero_pt * sizeof(data_type_zero_pt)) + (size_zero_pt + UNDEFINED_DATA_SIZE) * sizeof(data_type_zero_pt)) .wait(); queue.memcpy((void*)bias_d, (void*)bias_h, size_bias * sizeof(data_type_bias)) .wait(); From ddbac9741919a1e002893c2b1ccd6663f9ce04a6 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Fri, 7 Jun 2024 01:04:01 +0800 Subject: [PATCH 20/34] support sg_n > 1 --- .../group/gemm/compute_policy.hpp | 2 +- .../group/gemm/impl/int4_dequantize_xe.hpp | 27 ++---- .../kernel/col_major_shuf/col_major_shuf.hpp | 2 +- include/subgroup/tile/impl/fma_xe.hpp | 91 +++++++++++-------- include/subgroup/tile/impl/load_xe.hpp | 2 +- include/subgroup/tile/impl/op_function.hpp | 4 +- include/subgroup/tile/impl/payload_xe.hpp | 22 +++-- tests/integration/gemv/int4/main.cpp | 30 +++--- 8 files changed, 101 insertions(+), 79 deletions(-) diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index e0eb2bb72..d3cd003c5 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -142,7 +142,7 @@ struct 
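// Test-hardening pattern introduced in this hunk: every quantization buffer
// gets an UNDEFINED_DATA_SIZE tail of poison values (INFINITY for scales), so
// an out-of-bounds read by the kernel corrupts the output and fails
// validation instead of passing silently. Minimal model:
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> alloc_with_poison_tail(size_t n, size_t pad) {
  std::vector<float> buf(n + pad, 0.0f);  // first n elements hold real data
  for (size_t i = n; i < n + pad; ++i)
    buf[i] = INFINITY;                    // any read past n poisons the math
  return buf;
}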
compute_policy_int4_dequantize< using dtype_zero_pt = dtype_zero_pt_; static constexpr quant_mode quant_type = quant_type_; - static constexpr uint32_t block_size_y_a = 1; + static constexpr uint32_t block_size_y_a = 4; static constexpr uint32_t block_bytes_x_a = 256; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 697645003..9dc381950 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -89,7 +89,7 @@ class gemm_t< static constexpr bool is_col_major_a = mem_layout_a == mem_layout::col_major; static constexpr bool is_col_major_b = mem_layout_b == mem_layout::col_major; static constexpr bool is_gemv = is_col_major_b && - compute_policy::mma_engine == mma_engine::fpu && sg_tile_m == 1; + compute_policy::mma_engine == mma_engine::fpu && sg_tile_m <= 4; private: /******** set data type **********/ @@ -266,9 +266,9 @@ class gemm_t< private: using matAcc_tile_desc_t = subgroup::tile_desc_t< block_size_y_b, - tile_size_x_b, + tile_size_y_a, block_size_y_b, - block_size_x_b, + block_size_y_a, reg_layout::tiled>; using matAcc_t = subgroup::tile_t; using scale_tile_desc_t = subgroup::tile_desc_t< @@ -680,26 +680,13 @@ class gemm_t< // lowest 4 bit { - auto dequant_i8_low_4bit = - cvt_blk_i8.xetla_select(0); - dequant_i8_low_4bit = matB_blk & 0xf; - // Only int8 needs to reserve the sign bit - if constexpr ( - compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { - // dequant_i8_low_4bit = dequant_i8_low_4bit << 4; - // dequant_i8_low_4bit = dequant_i8_low_4bit >> 4; - } + cvt_blk_i8.xetla_select(0) = + matB_blk & 0xf; } // highest 4 bit { - auto dequant_i8_high_4bit = - cvt_blk_i8.xetla_select(1); - if constexpr ( - compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { - dequant_i8_high_4bit = matB_blk >> 4; - } else { - dequant_i8_high_4bit = matB_blk >> 4; - } + cvt_blk_i8.xetla_select(1) = + matB_blk >> 4; } // (b_i8 - zero_pt_i8) x scale = fp16 diff --git a/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp b/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp index 084ebdc88..1cd3b7153 100644 --- a/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp +++ b/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp @@ -20,6 +20,6 @@ #pragma once #include +#include #include #include -#include diff --git a/include/subgroup/tile/impl/fma_xe.hpp b/include/subgroup/tile/impl/fma_xe.hpp index df352311f..cf15e12a8 100644 --- a/include/subgroup/tile/impl/fma_xe.hpp +++ b/include/subgroup/tile/impl/fma_xe.hpp @@ -73,11 +73,11 @@ struct tile_fma_t { b_block_size_y == matAcc_t::block_size_x, "matA block k should match with matAcc block k"); static_assert( - b_block_size_x == matAcc_t::block_size_y, - "matb block n should match with matAcc block n"); + a_block_size_y == matAcc_t::block_size_y, + "mata block m should match with matAcc block m"); - static_assert(tile_size_m == 1, "matA tile m must be 1"); - static_assert(a_block_size_y == 1, "matA block m must be 1"); + // static_assert(tile_size_m == 1, "matA tile m must be 1"); + // static_assert(a_block_size_y == 1, "matA block m must be 1"); __XETLA_API static void mma( matAcc_t& acc_dst, matAcc_t& acc_src, @@ -87,23 +87,26 @@ struct tile_fma_t { bool reduce) { #pragma unroll for (uint32_t k = 0; k < tile_size_k / block_size_k; k++) { - auto 
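// Sketch of the m-tiled fpu path this commit enables (is_gemv now admits
// sg_tile_m <= 4 and block_size_y_a becomes 4): each of the m rows carries
// its own [block_k]-wide partial accumulator, reusing the same b block across
// rows. Plain C++ model under the same shape assumptions as the kernel:
#include <cstdint>
#include <vector>

void gemv_m_tiled(const float* a, const float* b, float* c, uint32_t M,
                  uint32_t K, uint32_t N, uint32_t block_k) {
  // a: [M x K] row-major, b: [N x K] (col-major B), c: [M x N], M <= 4
  for (uint32_t n = 0; n < N; ++n) {
    std::vector<float> acc(M * block_k, 0.0f);
    for (uint32_t k0 = 0; k0 < K; k0 += block_k)
      for (uint32_t m = 0; m < M; ++m)  // per-m accumulators
        for (uint32_t kk = 0; kk < block_k; ++kk)
          acc[m * block_k + kk] += b[n * K + k0 + kk] * a[m * K + k0 + kk];
    for (uint32_t m = 0; m < M; ++m) {  // reduce k per (m, n)
      float sum = 0.0f;
      for (uint32_t kk = 0; kk < block_k; ++kk) sum += acc[m * block_k + kk];
      c[m * N + n] = sum;
    }
  }
}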
a_block = - a.reg.xetla_select(k * block_size_k); #pragma unroll - for (uint32_t n = 0; n < tile_size_n / block_size_n; n++) { - uint32_t b_block_idx = n * tile_size_k / block_size_k + k; - auto b_block = b.reg.xetla_select( - b_block_idx * matB_t::block_elems); + for (uint32_t m = 0; m < tile_size_m / block_size_m; m++) { + uint32_t a_block_idx = m * tile_size_k / block_size_k + k; + auto a_block = a.reg.xetla_select( + a_block_idx * matA_t::block_elems); +#pragma unroll + for (uint32_t n = 0; n < tile_size_n / block_size_n; n++) { + uint32_t b_block_idx = n * tile_size_k / block_size_k + k; + auto b_block = b.reg.xetla_select( + b_block_idx * matB_t::block_elems); - uint32_t src_dst_idx = n * block_size_n; - auto src_block = - acc_src.reg.xetla_select( - src_dst_idx); - auto dst_block = - acc_dst.reg.xetla_select( - src_dst_idx); - fma_core( - dst_block, src_block, b_block, a_block); + auto src_block = + acc_src.reg.xetla_select( + m * matAcc_t::block_elems); + auto dst_block = + acc_dst.reg.xetla_select( + m * matAcc_t::block_elems); + fma_core( + dst_block, src_block, b_block, a_block); + } } } if (reduce) { @@ -112,18 +115,24 @@ struct tile_fma_t { } template __XETLA_API static void fma_core( - xetla_vector_ref __REF__ dst, - xetla_vector_ref __REF__ src, + xetla_vector_ref __REF__ dst_block, + xetla_vector_ref __REF__ src_block, xetla_vector_ref __REF__ b_block, xetla_vector_ref __REF__ a_block) { - static_assert(blk_m == 1, "block m must be 1"); - auto dst_blk_2d = dst.xetla_format(); - auto src_blk_2d = src.xetla_format(); + static_assert(blk_n == 1, "block n must be 1"); + auto dst_blk_2d = dst_block.xetla_format(); + auto src_blk_2d = src_block.xetla_format(); auto b_blk_2d = b_block.xetla_format(); auto a_blk_2d = a_block.xetla_format(); + +#pragma unroll + for (uint32_t m = 0; m < blk_m; m++) { + auto a_row = a_blk_2d.row(m); #pragma unroll - for (uint32_t n = 0; n < blk_n; n++) { - dst_blk_2d.row(n) = b_blk_2d.row(n) * a_blk_2d.row(0) + src_blk_2d.row(n); + for (uint32_t n = 0; n < blk_n; n++) { + auto b_row = b_blk_2d.row(n); + dst_blk_2d.row(m) = b_row * a_row + src_blk_2d.row(m); + } } } __XETLA_API static void reduce_acc_k(matAcc_t& matAcc, matC_t& matC) { @@ -131,21 +140,29 @@ struct tile_fma_t { // matAcc[tx,ty,bx,by](matmul): tile_n, block_k, block_n, block_k // matAcc[tx,ty,bx,by](memory): block_k, tile_n, block_k, block_n + // static_assert( + // matC_t::tile_size_y == 1 && matC_t::block_size_y == 1, + // "matDst_t_ tile m and block m should match be 1"); static_assert( - matC_t::tile_size_y == 1 && matC_t::block_size_y == 1, - "matDst_t_ tile m and block m should match be 1"); - static_assert( - matAcc_t::tile_size_y == matC_t::tile_size_x, - "matAcc_t tile n should match with matDst_t_ tile n"); + matAcc_t::tile_size_y == matC_t::tile_size_y, + "matAcc_t tile m should match with matDst_t_ tile m"); static_assert( - matAcc_t::block_size_y == matC_t::block_size_x, - "matAcc_t block n should match with matDst_t_ block n"); - static constexpr auto block_k = matAcc_t::block_size_x; - static constexpr auto tile_n = matAcc_t::tile_size_y; + matAcc_t::block_size_y == matC_t::block_size_y, + "matAcc_t block m should match with matDst_t_ block m"); + static constexpr uint32_t block_k = matAcc_t::block_size_x; + static constexpr uint32_t block_m = matAcc_t::block_size_y; using dtype = matAcc_t::dtype; - matC.reg = - recur_col_reduce(matAcc.reg); +#pragma unroll + for (uint32_t m = 0; m < a_tile_size_y / a_block_size_y; m++) { + matC.reg.xetla_select(m * block_m) = + 
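// recur_col_reduce, called per m-block in reduce_acc_k above, collapses the
// block_k lanes of each accumulator row. Modeled here as pairwise halving (an
// assumption about its strategy; any order gives the same sum up to fp
// rounding), with block_k a power of two:
#include <cstdint>

inline float halving_reduce(float* lanes, uint32_t block_k) {
  for (uint32_t stride = block_k / 2; stride > 0; stride /= 2)
    for (uint32_t i = 0; i < stride; ++i)
      lanes[i] += lanes[i + stride];  // halve the live lanes each step
  return lanes[0];
}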
recur_col_reduce( + matAcc.reg.xetla_select( + m * block_m * block_k)); + // matC.reg = + // recur_col_reduce(matAcc.reg); + } } }; diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 23071a9f2..d3faab6b0 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -397,7 +397,7 @@ tile_load(tile_t& tile, payload_t& payload) { using dtype = typename payload_t::dtype; static constexpr uint32_t load_len = tile_t::tile_elems; static constexpr gpu_arch arch_tag = payload_t::arch_tag; - + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_load_vec_len = load_store_attr::max_load_vec_len; diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index 613def4d7..67e115f16 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -715,9 +715,9 @@ void dump_mat( sycl::ext::oneapi::experimental::printf( "%d ", (int)mat.reg[row * tile_x + col]); } - // sycl::ext::oneapi::experimental::printf("\n"); + sycl::ext::oneapi::experimental::printf("\n"); } - // sycl::ext::oneapi::experimental::printf("\n "); + sycl::ext::oneapi::experimental::printf("\n "); } template void dump_mat_reg(T mat, size_t tile_x, size_t tile_y) { diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 708fd365f..44a494b4a 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1616,11 +1616,10 @@ struct prefetch_payload_t< reg_layout_>, num_coop_sg_, arch_tag_, - std::enable_if_t<( + std::enable_if_t< arch_tag_ <= gpu_arch::XeHpg && - ((tile_size_y_ != 1 || block_size_y_ != 1) || - ((tile_size_x_ != 1 || block_size_x_ != 1) && - reg_layout_ == reg_layout::transpose_tiled)))>> { + ((block_size_y_ != 1 && reg_layout_ == reg_layout::tiled) || + (block_size_x_ != 1 && reg_layout_ == reg_layout::transpose_tiled))>> { using dtype = native_type_t; using mem_desc_t = mem_desc_t; @@ -2121,7 +2120,9 @@ struct prefetch_payload_t< template < typename dtype_, uint32_t tile_size_x_, + uint32_t tile_size_y_, uint32_t block_size_x_, + uint32_t block_size_y_, mem_layout mem_layout_, uint32_t alignment_, uint32_t num_coop_sg_, @@ -2129,10 +2130,19 @@ template < gpu_arch arch_tag_> struct prefetch_payload_t< mem_desc_t, - tile_desc_t, + tile_desc_t< + tile_size_x_, + tile_size_y_, + block_size_x_, + block_size_y_, + reg_layout_>, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t< + ((tile_size_y_ == 1 || block_size_y_ == 1) && + reg_layout_ == reg_layout::tiled) || + ((tile_size_x_ == 1 || block_size_x_ == 1) && + reg_layout_ == reg_layout::transpose_tiled)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 3dccefefe..8e366f4d0 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -72,7 +72,7 @@ class test_col_major_2 { static constexpr mma_engine mma_eng = mma_engine::fpu; static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; - using data_type_b = int4x2; + using data_type_b = int4x8; using data_type_c = fp16; }; @@ -107,6 +107,14 @@ int gemm_result_validate( bool result = buff_cmp::xetla_buff_cmp(data, other, "gemv validation"); +#ifdef UT_DEBUG + for (uint32_t i = 0; i < m; i++) { + for (uint32_t j = 0; j < n; j++) { + std::cout << 
float(sycl::half(C[i * n + j])) << " "; + } + std::cout << std::endl; + } +#endif std::cout << (!result ? "FAILED\n" : "PASSED\n"); return result ? 0 : 1; } @@ -180,14 +188,14 @@ std::vector dequantize_weight( } } } -#ifdef UT_DEBUG - for (uint32_t i = 0; i < matrix_n; i++) { - for (uint32_t j = 0; j < matrix_k; j++) { - std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; - } - std::cout << std::endl; - } -#endif +// #ifdef UT_DEBUG +// for (uint32_t i = 0; i < matrix_n; i++) { +// for (uint32_t j = 0; j < matrix_k; j++) { +// std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; +// } +// std::cout << std::endl; +// } +// #endif return b_out; } @@ -388,7 +396,7 @@ void dequantize_gemv_run(int iter) { } else if constexpr (std::is_same_v) { B_h[i] = random_uint32(); #ifdef UT_DEBUG - B_h[i] = 0x22222222; + B_h[i] = i % 128; #endif } } @@ -396,7 +404,7 @@ void dequantize_gemv_run(int iter) { for (unsigned i = 0; i < size_scale; ++i) { scale_h[i] = random_float(); #ifdef UT_DEBUG - scale_h[i] = 1; + scale_h[i] = 1.f; #endif } for (unsigned i = size_scale; i < size_scale + UNDEFINED_DATA_SIZE; ++i) { From d2aff4b63dd781118fa1c543d66b4e86326723b1 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Fri, 7 Jun 2024 17:59:44 +0800 Subject: [PATCH 21/34] add #pragma unroll --- .../group/gemm/impl/int4_dequantize_xe.hpp | 1 + tests/integration/gemv/int4/main.cpp | 26 +++++++++---------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 9dc381950..7fc563de0 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -693,6 +693,7 @@ class gemm_t< constexpr uint32_t step = std::min(block_size_y_b, dequant_s); #pragma unroll for (uint32_t jj = 0; jj < block_size_x_b; jj++) { +#pragma unroll for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { uint32_t offset_y_in_tile = i * block_size_y_b + ii; uint32_t offset_x_in_tile = j * block_size_x_b + jj; diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 8e366f4d0..38b52ca63 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -188,14 +188,14 @@ std::vector dequantize_weight( } } } -// #ifdef UT_DEBUG -// for (uint32_t i = 0; i < matrix_n; i++) { -// for (uint32_t j = 0; j < matrix_k; j++) { -// std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; -// } -// std::cout << std::endl; -// } -// #endif + // #ifdef UT_DEBUG + // for (uint32_t i = 0; i < matrix_n; i++) { + // for (uint32_t j = 0; j < matrix_k; j++) { + // std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; + // } + // std::cout << std::endl; + // } + // #endif return b_out; } @@ -520,11 +520,11 @@ void dequantize_gemv_run(int iter) { epilogue_args); } cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); - // if (!gemm_op_t::can_implement(gemm_arg)) { - // std::cout << "The arguments cannot be supported, aborting ... " - // << std::endl; - // FAIL(); - // } + if (!gemm_op_t::can_implement(gemm_arg)) { + std::cout << "The arguments cannot be supported, aborting ... 
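// The "#pragma unroll" added above sits on a loop whose trip count is a
// compile-time constant, so the compiler can replace per-iteration vector
// selects with fixed register offsets. Shape of the pattern (constants are
// illustrative, not the kernel's values):
#include <cstdint>

void scale_block(const float* src, float* dst) {
  constexpr uint32_t kBlockX = 16, kStep = 4;
#pragma unroll
  for (uint32_t jj = 0; jj < kBlockX; ++jj)
#pragma unroll
    for (uint32_t ii = 0; ii < kStep; ++ii)
      dst[jj * kStep + ii] = src[jj * kStep + ii] * 2.0f;  // fully unrolled
}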
" + << std::endl; + FAIL(); + } size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); From 97c2481abfa38cf4af6ce2cfda00a17fcb883cbd Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Sat, 8 Jun 2024 05:18:03 +0800 Subject: [PATCH 22/34] support HF zero pt layout K x N, compress int4 along N dimensions --- .../group/gemm/impl/int4_dequantize_xe.hpp | 58 ++++++++-------- .../gemm/impl/int4_dequantize_kslicing_xe.hpp | 8 +-- include/subgroup/tile/impl/op_function.hpp | 3 +- tests/integration/gemv/int4/main.cpp | 67 +++++++++---------- 4 files changed, 68 insertions(+), 68 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 7fc563de0..5fd6cb465 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -240,8 +240,7 @@ class gemm_t< (k_stride < dequant_s) ? dequant_s / k_stride : 1; static constexpr uint32_t zero_pt_addr_update_freq = - (k_stride < dequant_s * pack_ratio) ? (dequant_s * pack_ratio) / k_stride - : 1; + (k_stride < dequant_s) ? dequant_s / k_stride : 1; using mem_desc_scale_t = mem_desc_t< dtype_scale, @@ -251,7 +250,7 @@ class gemm_t< using mem_desc_zero_pt_t = mem_desc_t< dtype_zero_pt, - mem_layout_b, + mem_layout::row_major, mem_space::global, mem_desc_b_t::alignment>; @@ -287,22 +286,13 @@ class gemm_t< mem_desc_scale_t::layout>, arch_tag>; - using zero_pt_tile_desc_t = std::conditional_t< - is_col_major_b, - // compress int4 along K dimensions - subgroup::tile_desc_t< - tile_size_x_b, - (tile_size_y_zero_pt + pack_ratio - 1) / pack_ratio, - block_size_x_b, - (block_size_y_zero_pt + pack_ratio - 1) / pack_ratio, - is_col_major_b ? reg_layout::transpose_tiled : reg_layout::tiled>, - // compress int4 along N dimensions - subgroup::tile_desc_t< - (tile_size_x_b + pack_ratio - 1) / pack_ratio, - tile_size_y_zero_pt, - (block_size_x_b + pack_ratio - 1) / pack_ratio, - block_size_y_zero_pt, - is_col_major_b ? reg_layout::transpose_tiled : reg_layout::tiled>>; + // compress int4 along N dimensions + using zero_pt_tile_desc_t = subgroup::tile_desc_t< + (tile_size_x_b + pack_ratio - 1) / pack_ratio, + tile_size_y_zero_pt, + (block_size_x_b + pack_ratio - 1) / pack_ratio, + block_size_y_zero_pt, + reg_layout::tiled>; using zero_pt_t = subgroup::tile_t; using zero_pt_payload_t = subgroup::mem_payload_t< @@ -332,6 +322,9 @@ class gemm_t< static constexpr bool enable_periodic_sync = (sync_freq != 0); static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? 
wg_size_y : 0; + uint32_t wg_start_m = 0; + uint32_t wg_start_n = 0; + uint32_t wg_start_k = 0; public: static constexpr uint32_t barrier_count = @@ -500,6 +493,10 @@ class gemm_t< zero_pt_prefetch_payload_t zero_pt_prefetch_payload( args.zero_pt_base_desc, 0); + wg_start_m = args.matA_base_desc.coord.y; + wg_start_n = args.scale_base_desc.coord.x; + wg_start_k = args.matA_base_desc.coord.x; + xetla_nbarrier_t nbarrier_a; nbarrier_a.init_nbarrier( sg_idy + nbarrier_base, nbarrier_role::producer_consumer); @@ -536,8 +533,9 @@ class gemm_t< scale_t::tile_size_y); if constexpr ( compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { - zero_pt_prefetch_payload.template update_tdesc( - zero_pt_t::tile_size_y); + zero_pt_prefetch_payload + .template update_tdesc( + zero_pt_t::tile_size_y); } } } @@ -593,7 +591,7 @@ class gemm_t< if constexpr ( compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { if (tile_k_idx % zero_pt_addr_update_freq == 0) { - zero_pt_payload.template update_tdesc( + zero_pt_payload.template update_tdesc( zero_pt_t::tile_size_y); } } @@ -658,7 +656,7 @@ class gemm_t< matB_acc_t& matB_acc, matB_t& matB, scale_t& scale, - [[maybe_unused]] zero_pt_t& zero_pt) { + zero_pt_t& zero_pt) { // no tail, because this is matB constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; @@ -703,13 +701,19 @@ class gemm_t< offset_x_in_tile; if constexpr (compute_policy::quant_type == quant_mode::S4_ASYM) { - uint32_t zero_pt_idx = offset_y_in_tile / - (dequant_s * pack_ratio) * zero_pt_t::block_size_x + - offset_x_in_tile; + uint32_t zero_pt_idx = + offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + + offset_x_in_tile / pack_ratio; native_type_t zero_pt_pack = zero_pt.reg[zero_pt_idx]; int8_t zero_pt_i8 = - (zero_pt_pack >> (4 * (scale_idx % pack_ratio))) & 0xf; + (zero_pt_pack >> + (4 * ((wg_start_n + offset_x_in_tile) % pack_ratio))) & + 0xf; + // sycl::ext::oneapi::experimental::printf( + // "zero_pt.reg[%d} %x zero_pt_i8 %x offset_x_in_tile:%d + // \n", zero_pt_idx, zero_pt_pack, (int32_t)zero_pt_i8 , + // offset_x_in_tile); cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index b2a767cf9..bfc3bcd74 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -622,8 +622,8 @@ class gemm_universal_t< int start_x_scale = start_n; int start_y_scale = start_k / dequant_s; - int start_x_zero_pt = start_n; - int start_y_zero_pt = start_k / (dequant_s * pack_ratio); + int start_x_zero_pt = start_n / pack_ratio; + int start_y_zero_pt = start_k / dequant_s; // set up arguments uint32_t gemm_slm_base = slm_base; @@ -680,8 +680,8 @@ class gemm_universal_t< } else { mem_desc_zero_pt_t mem_desc_zero_pt( args.zero_pt_base, - {args.matrix_n / pack_ratio, - scale_size_y, + {(args.matrix_n + pack_ratio - 1) / pack_ratio, + ((args.matrix_k + dequant_s - 1) / dequant_s), args.zero_pt_ld / pack_ratio}, {start_x_zero_pt, start_y_zero_pt}); gemm_args = gemm_args_t( diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index 67e115f16..567abcb61 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -713,7 
+713,8 @@ void dump_mat( #pragma unroll for (size_t col = 0; col < tile_x; col++) { sycl::ext::oneapi::experimental::printf( - "%d ", (int)mat.reg[row * tile_x + col]); + "%x ", + int(native_type_t(mat.reg[row * tile_x + col]))); } sycl::ext::oneapi::experimental::printf("\n"); } diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 38b52ca63..20ffd664f 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -39,8 +39,8 @@ class test_col_major_1 { static constexpr size_t sg_n = 1; static constexpr size_t sg_k = 1024; static constexpr size_t dequant_s = 128; - // static constexpr quant_mode quant_type = quant_mode::S4_ASYM; - static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP; + static constexpr quant_mode quant_type = quant_mode::S4_ASYM; + // static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP; static constexpr size_t local_kslicing = 1; static constexpr size_t global_kslicing = 1; @@ -128,16 +128,13 @@ template < std::vector convert_int4( data_type_b data_b, data_type_scale scale, - [[maybe_unused]] data_type_zero_pt zero_pt) { + data_type_zero_pt zero_pt) { std::vector dequant_fp16(sizeof(data_type_b) * 2); int8_t zero_pt_i8 = zero_pt & 0xf; for (uint32_t i = 0; i < dequant_fp16.size(); i++) { - int8_t dequant_8bit; - dequant_8bit = data_b & 0xf; + int8_t dequant_8bit = data_b & 0xf; if constexpr (quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { - // dequant_8bit = dequant_8bit << 4; - // dequant_8bit = dequant_8bit >> 4; dequant_fp16[i] = scale * (dequant_8bit - 8); } else { dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8); @@ -149,7 +146,7 @@ std::vector convert_int4( template < size_t dequant_s, - mem_layout layout_b = mem_layout::row_major, + mem_layout layout_b = mem_layout::col_major, quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP, typename data_type_acc_in = fp16, typename data_type_b, @@ -172,30 +169,29 @@ std::vector dequantize_weight( for (uint32_t j = 0; j < width; j += step) { int start_b_in = i * width + j; int start_scale_in = start_b_in / step; - int start_zero_pt_in = start_scale_in / pack_radio; - + int start_zero_pt_in = + (j / step) * (matrix_n / pack_radio) + i / pack_radio; int start_out = layout_b == mem_layout::row_major ? 
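// Worked model of the K x N zero-point layout this commit adopts: one 4-bit
// zero per (group, column), eight columns packed little-endian per uint32
// when pack_ratio == 8. For K = 4096, N = 4096, dequant_s = 128 that is a
// 32 x 512 matrix of uint32 words. Extraction for weight element (k, n),
// assuming N is a multiple of pack_ratio:
#include <cstdint>

inline uint8_t zero_point_at(const uint32_t* zeros, uint32_t k, uint32_t n,
                             uint32_t N, uint32_t dequant_s,
                             uint32_t pack_ratio) {
  const uint32_t row = k / dequant_s;             // quantization group
  const uint32_t col = n / pack_ratio;            // packed word in the row
  const uint32_t word = zeros[row * (N / pack_ratio) + col];
  return (word >> (4 * (n % pack_ratio))) & 0xf;  // select the nibble
}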
0 : i * matrix_k + j * pack_radio; - for (uint32_t jj = 0; jj < step; jj++) { std::vector dequant_fp16 = convert_int4( b[start_b_in + jj], scale[start_scale_in], - zero_pt[start_zero_pt_in] >> (4 * (start_scale_in % pack_radio))); + zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio))); for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) { b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj]; } } } } - // #ifdef UT_DEBUG - // for (uint32_t i = 0; i < matrix_n; i++) { - // for (uint32_t j = 0; j < matrix_k; j++) { - // std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; - // } - // std::cout << std::endl; - // } - // #endif +#ifdef UT_DEBUG + for (uint32_t i = 0; i < matrix_n; i++) { + for (uint32_t j = 0; j < matrix_k; j++) { + std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; + } + std::cout << std::endl; + } +#endif return b_out; } @@ -237,7 +233,8 @@ void dequantize_gemv_run(int iter) { constexpr size_t size_zero_pt_k = matrix_k / dequant_s; constexpr size_t size_zero_pt_n = matrix_n; - constexpr size_t size_zero_pt = size_zero_pt_k * size_zero_pt_n; + constexpr size_t size_zero_pt = + size_zero_pt_k * size_zero_pt_n / (2 * sizeof(data_type_b)); constexpr size_t size_c = matrix_m * matrix_n; constexpr size_t size_bias = matrix_n; @@ -247,9 +244,7 @@ void dequantize_gemv_run(int iter) { uint32_t ldc = matrix_n; uint32_t ld_scale = layout_b == mem_layout::row_major ? size_scale_n : size_scale_k; - - uint32_t ld_zero_pt = - layout_b == mem_layout::row_major ? size_zero_pt_n : size_zero_pt_k; + uint32_t ld_zero_pt = size_zero_pt_n; // Turn on the enable_profiling property to facilitate subsequent profiling sycl::property_list properties{ @@ -391,12 +386,12 @@ void dequantize_gemv_run(int iter) { if constexpr (std::is_same_v) { B_h[i] = random_uint8(); #ifdef UT_DEBUG - B_h[i] = 0x22; + B_h[i] = 0x77; #endif } else if constexpr (std::is_same_v) { B_h[i] = random_uint32(); #ifdef UT_DEBUG - B_h[i] = i % 128; + B_h[i] = 0x77777777; #endif } } @@ -410,22 +405,22 @@ void dequantize_gemv_run(int iter) { for (unsigned i = size_scale; i < size_scale + UNDEFINED_DATA_SIZE; ++i) { scale_h[i] = INFINITY; } - for (unsigned i = 0; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) { if constexpr (std::is_same_v) { zero_pt_h[i] = random_uint8(); #ifdef UT_DEBUG - zero_pt_h[i] = zero_pt_h[i] << 4 + i % 8; - zero_pt_h[i] = 0x33; + zero_pt_h[i] = 0x12 << i; #endif } else if constexpr (std::is_same_v) { zero_pt_h[i] = random_uint32(); #ifdef UT_DEBUG - zero_pt_h[i] = 0x01234567; + zero_pt_h[i] = 0x33333333; #endif } } - + zero_pt_h[0] = 0x12; + zero_pt_h[1] = 0x34; + for (unsigned i = 0; i < size_c; ++i) { C_h[i] = random_float(); } @@ -520,11 +515,11 @@ void dequantize_gemv_run(int iter) { epilogue_args); } cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); - if (!gemm_op_t::can_implement(gemm_arg)) { - std::cout << "The arguments cannot be supported, aborting ... " - << std::endl; - FAIL(); - } + // if (!gemm_op_t::can_implement(gemm_arg)) { + // std::cout << "The arguments cannot be supported, aborting ... 
" + // << std::endl; + // FAIL(); + // } size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); From f19c86f551ebb0cf9b0d4363776ab5685039967f Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Wed, 12 Jun 2024 05:21:29 +0800 Subject: [PATCH 23/34] save --- .../group/gemm/impl/int4_dequantize_xe.hpp | 6 + include/subgroup/tile/impl/load_xe.hpp | 142 +++++++++--------- tests/integration/gemv/int4/main.cpp | 6 +- 3 files changed, 81 insertions(+), 73 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 5fd6cb465..3ec16904e 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -191,12 +191,14 @@ class gemm_t< tile_size_x_b, tile_size_y_b / pack_ratio, block_size_x_b, + // block_size_y_b * sizeof(dtype_mma_b) / sizeof(dtype_b), block_size_y_b / pack_ratio, reg_layout_b>, // compress int4 along N dimensions subgroup::tile_desc_t< tile_size_x_b / pack_ratio, tile_size_y_b, + // block_size_x_b * sizeof(dtype_mma_b) / sizeof(dtype_b), block_size_x_b / pack_ratio, block_size_y_b, reg_layout_b>>; @@ -205,6 +207,8 @@ class gemm_t< mem_desc_b_t, matB_tile_desc_t, subgroup::msg_type_v, + // subgroup::msg_type_v, arch_tag>; using matB_prefetch_payload_t = subgroup:: prefetch_payload_t; @@ -557,6 +561,8 @@ class gemm_t< matA, matA_payload); subgroup::tile_load( matB, matB_payload); + // subgroup::tile_load( + // matB, matB_payload); subgroup::tile_load( scale, scale_payload); if constexpr ( diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index d3faab6b0..a6675fb75 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -466,66 +466,68 @@ tile_load(tile_t& tile, payload_t& payload) { uint32_t offset_x = j * tile_desc::block_size_x; auto reg_sub = tile.reg.xetla_select( (i * tile_desc::num_block_x + j) * tile_desc::block_elems); - // #pragma unroll - // for (uint32_t sub_block_y = 0; sub_block_y < - // tile_desc::block_size_x; - // sub_block_y += num_channel) { - uint32_t sub_block_y = 0; - xetla_vector reg_tmp = 0; - uint32_t address_offset = payload_t::mem_transpose - ? offset_x * payload.pitch_in_bytes + (offset_y + 0) * sizeof(dtype) - : offset_x * sizeof(dtype) + (offset_y + 0) * payload.pitch_in_bytes; +#pragma unroll + for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_x; + sub_block_y += num_channel) { + xetla_vector reg_tmp = 0; + uint32_t address_offset = payload_t::mem_transpose + ? offset_x * payload.pitch_in_bytes + + (offset_y + sub_block_y) * sizeof(dtype) + : offset_x * sizeof(dtype) + + (offset_y + sub_block_y) * payload.pitch_in_bytes; - const uint32_t sub_block_offset_x = payload.base_x + offset_x + 0; - const uint32_t sub_block_offset_y = - payload.base_y + offset_y + sub_block_y; - const auto offset_ch_dim = - payload_t::trans ? sub_block_offset_x : sub_block_offset_y; - const auto size_ch_dim = - payload_t::trans ? payload.width_in_elems : payload.height_in_elems; - - xetla_mask pred = 1; - offset_ch_dim + num_channel > size_ch_dim - ? (xetla_vector_gen(offset_ch_dim, 1) < - size_ch_dim) - : 1; + // const uint32_t sub_block_offset_x = payload.base_x + offset_x + 0; + // const uint32_t sub_block_offset_y = + // payload.base_y + offset_y + sub_block_y; + // const auto offset_ch_dim = + // payload_t::trans ? 
sub_block_offset_x : sub_block_offset_y; + // const auto size_ch_dim = + // payload_t::trans ? payload.width_in_elems : + // payload.height_in_elems; + + // xetla_mask pred = offset_ch_dim + num_channel > + // size_ch_dim + // ? (xetla_vector_gen(offset_ch_dim, 1) < + // size_ch_dim) + // : 1; + + reg_tmp = xetla_load_global< + load_dtype, + payload_t::simd_exec_size, + data_size::default_size, + L1, + L2, + payload_t::num_channel>( + payload.base_ptr, + payload.channel_offset + payload.base_offset + address_offset, + 1); - reg_tmp = xetla_load_global< - load_dtype, - payload_t::simd_exec_size, - data_size::default_size, - L1, - L2, - payload_t::num_channel>( - payload.base_ptr, - payload.channel_offset + payload.base_offset + address_offset, - 1); - - if constexpr (payload_t::simd_exec_size > 1) { - xetla_vector reg_tmp_trans; + if constexpr ( + payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) { + xetla_vector reg_tmp_trans; #pragma unroll - for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { - if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix - reg_tmp_trans.xetla_select( - iii * payload_t::simd_exec_size) = - reg_tmp.xetla_select< - payload_t::simd_exec_size, - payload_t::num_channel>(iii); - else // TODO (dingyi): Delete after driver fix - reg_tmp_trans.xetla_select( - iii * payload_t::simd_exec_size) = 0; + for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { + // if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix + // reg_tmp_trans.xetla_select( + // iii * payload_t::simd_exec_size) = + // reg_tmp.xetla_select< + // payload_t::simd_exec_size, + // payload_t::num_channel>(iii); + // else // TODO (dingyi): Delete after driver fix + // reg_tmp_trans.xetla_select( + // iii * payload_t::simd_exec_size) = 0; + } + reg_sub + .xetla_select( + sub_block_y * tile_desc::block_size_x) + .xetla_format() = reg_tmp_trans; + } else { + reg_sub + .xetla_select( + sub_block_y * tile_desc::block_size_x) + .xetla_format() = reg_tmp; } - reg_sub - .xetla_select( - sub_block_y * tile_desc::block_size_x) - .xetla_format() = reg_tmp_trans; - } else { - reg_sub - .xetla_select( - sub_block_y * tile_desc::block_size_x) - .xetla_format() = reg_tmp; } - // } } } @@ -565,7 +567,9 @@ __XETLA_API typename std::enable_if_t< tile_load(tile_t& tile, payload_t& payload) { using dtype = typename payload_t::dtype; using tile_desc = typename payload_t::tile_desc; - constexpr uint32_t load_elems = tile_desc::block_size_x; + constexpr uint32_t load_elems = payload_t::mem_transpose + ? tile_desc::block_size_y + : tile_desc::block_size_x; #pragma unroll for (uint32_t i = 0; i < tile_desc::num_block_y; i++) { @@ -575,21 +579,21 @@ tile_load(tile_t& tile, payload_t& payload) { uint32_t offset_x = j * tile_desc::block_size_x; auto reg_sub = tile.reg.xetla_select( (i * tile_desc::num_block_x + j) * tile_desc::block_elems); -#pragma unroll - for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; - sub_block_y += 1) { - uint32_t address_offset = payload_t::mem_transpose - ? offset_x * payload.pitch_in_bytes + - (offset_y + sub_block_y) * sizeof(dtype) - : offset_x * sizeof(dtype) + - (offset_y + sub_block_y) * payload.pitch_in_bytes; + // #pragma unroll + // for (uint32_t sub_block_y = 0; sub_block_y < + // tile_desc::block_size_x; + // sub_block_y += 1) { + uint32_t sub_block_y = 0; + uint32_t address_offset = payload_t::mem_transpose + ? 
offset_x * payload.pitch_in_bytes + (offset_y + 0) * sizeof(dtype) + : offset_x * sizeof(dtype) + (offset_y + 0) * payload.pitch_in_bytes; - reg_sub.xetla_select( - sub_block_y * tile_desc::block_size_x) = - xetla_load_global( - (dtype*)payload.base_ptr, payload.base_offset + address_offset); - } + reg_sub.xetla_select( + sub_block_y * tile_desc::block_size_x) = + xetla_load_global( + (dtype*)payload.base_ptr, payload.base_offset + address_offset); } + // } } if constexpr (payload_t::trans) { diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 20ffd664f..dae9739fa 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -39,8 +39,8 @@ class test_col_major_1 { static constexpr size_t sg_n = 1; static constexpr size_t sg_k = 1024; static constexpr size_t dequant_s = 128; - static constexpr quant_mode quant_type = quant_mode::S4_ASYM; - // static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP; + // static constexpr quant_mode quant_type = quant_mode::S4_ASYM; + static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP; static constexpr size_t local_kslicing = 1; static constexpr size_t global_kslicing = 1; @@ -418,8 +418,6 @@ void dequantize_gemv_run(int iter) { #endif } } - zero_pt_h[0] = 0x12; - zero_pt_h[1] = 0x34; for (unsigned i = 0; i < size_c; ++i) { C_h[i] = random_float(); From 897f5d588168f922fe44148f689e74c740936c23 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Fri, 14 Jun 2024 23:09:26 +0800 Subject: [PATCH 24/34] sg_m =4 for first token --- include/subgroup/tile/impl/fma_xe.hpp | 4 +- include/subgroup/tile/impl/load_xe.hpp | 160 +++++++++++---------- include/subgroup/tile/impl/op_function.hpp | 3 +- include/subgroup/tile/impl/payload_xe.hpp | 19 ++- tests/integration/gemv/int4/main.cpp | 12 +- 5 files changed, 108 insertions(+), 90 deletions(-) diff --git a/include/subgroup/tile/impl/fma_xe.hpp b/include/subgroup/tile/impl/fma_xe.hpp index cf15e12a8..e81d8d7df 100644 --- a/include/subgroup/tile/impl/fma_xe.hpp +++ b/include/subgroup/tile/impl/fma_xe.hpp @@ -76,8 +76,8 @@ struct tile_fma_t { a_block_size_y == matAcc_t::block_size_y, "mata block m should match with matAcc block m"); - // static_assert(tile_size_m == 1, "matA tile m must be 1"); - // static_assert(a_block_size_y == 1, "matA block m must be 1"); + static_assert(tile_size_n == 1, "matB tile n must be 1"); + static_assert(b_block_size_x == 1, "matB block n must be 1"); __XETLA_API static void mma( matAcc_t& acc_dst, matAcc_t& acc_src, diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index a6675fb75..f11cd2e9f 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -26,7 +26,7 @@ namespace gpu::xetla::subgroup { namespace detail { -template +template struct check_load_type { static constexpr bool is_lsc_gather = is_lsc_gather_; static constexpr bool is_global_block_2d = @@ -466,72 +466,78 @@ tile_load(tile_t& tile, payload_t& payload) { uint32_t offset_x = j * tile_desc::block_size_x; auto reg_sub = tile.reg.xetla_select( (i * tile_desc::num_block_x + j) * tile_desc::block_elems); + // #pragma unroll + // for (uint32_t sub_block_offset = 0; sub_block_offset < + // (payload_t::mem_transpose ? tile_desc::block_size_x + // : tile_desc::block_size_y); + // sub_block_offset += num_channel) { + uint32_t sub_block_offset = 0; + xetla_vector reg_tmp = 0; + uint32_t address_offset = payload_t::mem_transpose + ? 
(offset_x + sub_block_offset) * payload.pitch_in_bytes + + offset_y * sizeof(dtype) + : offset_x * sizeof(dtype) + + (offset_y + sub_block_offset) * payload.pitch_in_bytes; + xetla_mask pred = 1; + if constexpr (num_channel > 1) { + // For SDP load, need pred + const uint32_t sub_block_offset_x = payload.base_x + offset_x + + (payload_t::mem_transpose ? sub_block_offset : 0); + const uint32_t sub_block_offset_y = payload.base_y + offset_y + + (payload_t::mem_transpose ? 0 : sub_block_offset); + const auto offset_ch_dim = + payload_t::trans ? sub_block_offset_x : sub_block_offset_y; + const auto size_ch_dim = + payload_t::trans ? payload.width_in_elems : payload.height_in_elems; + + pred = offset_ch_dim + num_channel > size_ch_dim + ? (xetla_vector_gen(offset_ch_dim, 1) < + size_ch_dim) + : 1; + } + reg_tmp = xetla_load_global< + load_dtype, + payload_t::simd_exec_size, + data_size::default_size, + L1, + L2, + payload_t::num_channel>( + payload.base_ptr, + payload.channel_offset + payload.base_offset + address_offset, + pred); + + if constexpr ( + payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) { + xetla_vector reg_tmp_trans; #pragma unroll - for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_x; - sub_block_y += num_channel) { - xetla_vector reg_tmp = 0; - uint32_t address_offset = payload_t::mem_transpose - ? offset_x * payload.pitch_in_bytes + - (offset_y + sub_block_y) * sizeof(dtype) - : offset_x * sizeof(dtype) + - (offset_y + sub_block_y) * payload.pitch_in_bytes; - - // const uint32_t sub_block_offset_x = payload.base_x + offset_x + 0; - // const uint32_t sub_block_offset_y = - // payload.base_y + offset_y + sub_block_y; - // const auto offset_ch_dim = - // payload_t::trans ? sub_block_offset_x : sub_block_offset_y; - // const auto size_ch_dim = - // payload_t::trans ? payload.width_in_elems : - // payload.height_in_elems; - - // xetla_mask pred = offset_ch_dim + num_channel > - // size_ch_dim - // ? 
(xetla_vector_gen(offset_ch_dim, 1) < - // size_ch_dim) - // : 1; - - reg_tmp = xetla_load_global< - load_dtype, - payload_t::simd_exec_size, - data_size::default_size, - L1, - L2, - payload_t::num_channel>( - payload.base_ptr, - payload.channel_offset + payload.base_offset + address_offset, - 1); - - if constexpr ( - payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) { - xetla_vector reg_tmp_trans; -#pragma unroll - for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { - // if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix - // reg_tmp_trans.xetla_select( - // iii * payload_t::simd_exec_size) = - // reg_tmp.xetla_select< - // payload_t::simd_exec_size, - // payload_t::num_channel>(iii); - // else // TODO (dingyi): Delete after driver fix - // reg_tmp_trans.xetla_select( - // iii * payload_t::simd_exec_size) = 0; - } - reg_sub - .xetla_select( - sub_block_y * tile_desc::block_size_x) - .xetla_format() = reg_tmp_trans; - } else { - reg_sub - .xetla_select( - sub_block_y * tile_desc::block_size_x) - .xetla_format() = reg_tmp; + for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { + if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix + reg_tmp_trans.xetla_select( + iii * payload_t::simd_exec_size) = + reg_tmp.xetla_select< + payload_t::simd_exec_size, + payload_t::num_channel>(iii); + else // TODO (dingyi): Delete after driver fix + reg_tmp_trans.xetla_select( + iii * payload_t::simd_exec_size) = 0; } + reg_sub + .xetla_select( + sub_block_offset * tile_desc::block_size_x) + .xetla_format() = reg_tmp_trans; + } else { + reg_sub + .xetla_select( + sub_block_offset * tile_desc::block_size_x) + .xetla_format() = reg_tmp; } } + // } } - if constexpr (payload_t::trans) { + if constexpr ( + payload_t::trans && + !(std::is_same_v || std::is_same_v)) { SW_BARRIER(); tile_transpose(tile); } @@ -579,24 +585,28 @@ tile_load(tile_t& tile, payload_t& payload) { uint32_t offset_x = j * tile_desc::block_size_x; auto reg_sub = tile.reg.xetla_select( (i * tile_desc::num_block_x + j) * tile_desc::block_elems); - // #pragma unroll - // for (uint32_t sub_block_y = 0; sub_block_y < - // tile_desc::block_size_x; - // sub_block_y += 1) { - uint32_t sub_block_y = 0; - uint32_t address_offset = payload_t::mem_transpose - ? offset_x * payload.pitch_in_bytes + (offset_y + 0) * sizeof(dtype) - : offset_x * sizeof(dtype) + (offset_y + 0) * payload.pitch_in_bytes; +#pragma unroll + for (uint32_t sub_block_y = 0; + sub_block_y < (payload_t::mem_transpose ? tile_desc::block_size_x + : tile_desc::block_size_y); + sub_block_y += 1) { + uint32_t address_offset = payload_t::mem_transpose + ? 
(offset_x + sub_block_y) * payload.pitch_in_bytes + + offset_y * sizeof(dtype) + : offset_x * sizeof(dtype) + + (offset_y + sub_block_y) * payload.pitch_in_bytes; - reg_sub.xetla_select( - sub_block_y * tile_desc::block_size_x) = - xetla_load_global( - (dtype*)payload.base_ptr, payload.base_offset + address_offset); + reg_sub.xetla_select( + sub_block_y * tile_desc::block_size_x) = + xetla_load_global( + (dtype*)payload.base_ptr, payload.base_offset + address_offset); + } } - // } } - if constexpr (payload_t::trans) { + if constexpr ( + payload_t::trans && + !(std::is_same_v || std::is_same_v)) { SW_BARRIER(); tile_transpose(tile); } diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index 567abcb61..8c43b5a32 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -713,7 +713,8 @@ void dump_mat( #pragma unroll for (size_t col = 0; col < tile_x; col++) { sycl::ext::oneapi::experimental::printf( - "%x ", + "%x(%d) ", + int(native_type_t(mat.reg[row * tile_x + col])), int(native_type_t(mat.reg[row * tile_x + col]))); } sycl::ext::oneapi::experimental::printf("\n"); diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 44a494b4a..3d7751247 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1091,8 +1091,7 @@ struct mem_payload_t< static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = mem_transpose ^ reg_transpose && - !(std::is_same_v || std::is_same_v); + static constexpr bool trans = mem_transpose ^ reg_transpose; static constexpr bool mem_transform = (sizeof(dtype) < 4) && (register_layout == reg_layout::vnni_tiled || @@ -1849,7 +1848,10 @@ struct prefetch_payload_t< arch_tag_, std::enable_if_t< (arch_tag_ == gpu_arch::XeHpc) && - (tile_size_y_ != 1 || block_size_y_ != 1)>> { + (((tile_size_y_ != 1 || block_size_y_ != 1) && + mem_layout_ == mem_layout::row_major) || + ((tile_size_x_ != 1 || block_size_x_ != 1) && + mem_layout_ == mem_layout::col_major))>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -2140,16 +2142,21 @@ struct prefetch_payload_t< arch_tag_, std::enable_if_t< ((tile_size_y_ == 1 || block_size_y_ == 1) && - reg_layout_ == reg_layout::tiled) || + mem_layout_ == mem_layout::row_major) || ((tile_size_x_ == 1 || block_size_x_ == 1) && - reg_layout_ == reg_layout::transpose_tiled)>> { + mem_layout_ == mem_layout::col_major)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; // CL aligned, so we can use uint64_t using prefetch_dtype = uint64_t; static constexpr msg_type message_type = msg_type::block_1d; - using tile_desc = tile_desc_t; + using tile_desc = tile_desc_t< + tile_size_x_, + tile_size_y_, + block_size_x_, + block_size_x_, + reg_layout_>; static constexpr mem_space memory_space = mem_space::global; static constexpr mem_layout memory_layout = mem_layout_; static constexpr gpu_arch arch_tag = arch_tag_; diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index dae9739fa..8ace568eb 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -30,14 +30,14 @@ constexpr size_t UNDEFINED_DATA_SIZE = 1024; class test_col_major_1 { public: // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; + static 
constexpr size_t mat_m = 4; static constexpr size_t mat_n = 4096; static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 1; + static constexpr size_t wg_m = 4; static constexpr size_t wg_n = 1; - static constexpr size_t sg_m = 1; + static constexpr size_t sg_m = 4; static constexpr size_t sg_n = 1; - static constexpr size_t sg_k = 1024; + static constexpr size_t sg_k = 1024 / 4; static constexpr size_t dequant_s = 128; // static constexpr quant_mode quant_type = quant_mode::S4_ASYM; static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP; @@ -375,7 +375,7 @@ void dequantize_gemv_run(int iter) { for (unsigned i = 0; i < size_a; ++i) { A_h[i] = random_float(); #ifdef UT_DEBUG - A_h[i] = 1.f; + A_h[i] = i; // A_h[i] = layout_a == mem_layout::row_major // ? (i % matrix_k + i / matrix_k * 100) // : (i % matrix_m + i / matrix_m * 100); @@ -418,7 +418,7 @@ void dequantize_gemv_run(int iter) { #endif } } - + for (unsigned i = 0; i < size_c; ++i) { C_h[i] = random_float(); } From e7f27161f7da9df7f78c783b3e0aa61b1e953ae7 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Sat, 15 Jun 2024 02:01:08 +0800 Subject: [PATCH 25/34] Extract dequant func --- include/common/core/base_consts.hpp | 6 +- .../group/gemm/compute_policy.hpp | 5 +- .../group/gemm/impl/int4_dequantize_xe.hpp | 192 ++++++++++-------- .../gemm/impl/int4_dequantize_kslicing_xe.hpp | 16 +- include/subgroup/tile/impl/payload_xe.hpp | 28 +-- include/subgroup/tile/impl/prefetch_xe.hpp | 34 ++-- .../subgroup/tile/impl/tile_op_functor.hpp | 118 +++++++++++ tests/integration/gemv/int4/main.cpp | 14 +- 8 files changed, 270 insertions(+), 143 deletions(-) diff --git a/include/common/core/base_consts.hpp b/include/common/core/base_consts.hpp index 2f8bbd489..13cf0b369 100644 --- a/include/common/core/base_consts.hpp +++ b/include/common/core/base_consts.hpp @@ -23,9 +23,9 @@ namespace gpu::xetla { -/// @addtogroup xetla_core_base_types +/// @addtogroup xetla_core_base_consts /// @{ - -/// @} xetla_core_base_types +enum quant_mode : uint8_t { S4_ASYM, S4_FULLRANGE_NO_ZP }; +/// @} xetla_core_base_consts } // namespace gpu::xetla diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index d3cd003c5..47b503182 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -22,9 +22,6 @@ #include namespace gpu::xetla::group { - -enum quant_mode : uint8_t { S4_ASYM, S4_FULLRANGE_NO_ZP }; - /// @brief Compute policy for int4 dequant gemm. /// @tparam compute_attr_ Is compute-related attributes. /// @tparam perf_tuning_knob_ Is performance-related knobs. 
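Note: the two quant_mode values above differ only in how the 4-bit weight is re-centered before scaling. A minimal host-side sketch, assuming little-endian nibble packing (two weights per byte, low nibble first) as used throughout this series; the function name and signature are illustrative, not XeTLA API:

#include <cstdint>

// S4_ASYM:            w = scale * (q - zero_pt), with a stored 4-bit zero point.
// S4_FULLRANGE_NO_ZP: w = scale * (q - 8), fixed mid-point, no zero point.
inline float dequant_int4(uint8_t packed, bool high_nibble, float scale,
                          uint8_t zero_pt_nibble, bool asym) {
  int8_t q = high_nibble ? int8_t(packed >> 4) : int8_t(packed & 0xf);
  int8_t zp = asym ? int8_t(zero_pt_nibble & 0xf) : int8_t(8);
  return scale * float(q - zp);
}

The same arithmetic appears in the scalar reference convert_int4 in tests/integration/gemv/int4/main.cpp.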
@@ -35,7 +32,7 @@ template < typename dtype_scale_, typename dtype_zero_pt_, quant_mode quant_type_, - int dequant_s_, + uint32_t dequant_s_, mma_engine mma_engine_ = mma_engine::xmx, gpu_arch arch_tag_ = gpu_arch::XeHpc, typename enable = void> diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 3ec16904e..c0b4157ec 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -36,7 +36,7 @@ template < typename mem_desc_b_t_, typename dtype_scale_, typename dtype_zero_pt_, - int dequant_s_, + uint32_t dequant_s_, quant_mode quant_type_, mma_engine mma_engine_, typename pre_processing_t_, @@ -322,7 +322,13 @@ class gemm_t< matA_acc_t, compute_policy::mma_engine, arch_tag>>; - + using dequantize_t = subgroup::dequant_int4_weight_t< + matB_acc_t, + matB_t, + scale_t, + zero_pt_t, + dequant_s, + quant_type_>; static constexpr bool enable_periodic_sync = (sync_freq != 0); static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0; @@ -500,6 +506,9 @@ class gemm_t< wg_start_m = args.matA_base_desc.coord.y; wg_start_n = args.scale_base_desc.coord.x; wg_start_k = args.matA_base_desc.coord.x; + typename dequantize_t::arguments_t dequantize_args{ + wg_start_m, wg_start_n, wg_start_k}; + dequantize_t dequantize; xetla_nbarrier_t nbarrier_a; nbarrier_a.init_nbarrier( @@ -624,7 +633,8 @@ class gemm_t< subgroup::vnni_reverse(matA); } subgroup::elemwise_cvt(matA_acc, matA); - dequantize(matB_acc, matB, scale, zero_pt); + + dequantize(matB_acc, matB, scale, zero_pt, dequantize_args); SW_BARRIER(); if constexpr (is_gemv) { tile_mma::mma( @@ -658,91 +668,97 @@ class gemm_t< } private: - inline void dequantize( - matB_acc_t& matB_acc, - matB_t& matB, - scale_t& scale, - zero_pt_t& zero_pt) { - // no tail, because this is matB - constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; - constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; -#pragma unroll - for (uint32_t i = 0; i < num_block_y; ++i) { -#pragma unroll - for (uint32_t j = 0; j < num_block_x; ++j) { - int block_id = (i * num_block_x + j); - // Must be little-endian - auto matB_blk = matB.reg.xetla_format() - .xetla_select( - block_id * matB_acc_t::block_elems / 2); - - auto dst_blk = matB_acc.reg.xetla_select( - block_id * matB_acc_t::block_elems); - - // int8 includes 2 4bits data. 
- xetla_vector cvt_blk_i8; - - // lowest 4 bit - { - cvt_blk_i8.xetla_select(0) = - matB_blk & 0xf; - } - // highest 4 bit - { - cvt_blk_i8.xetla_select(1) = - matB_blk >> 4; - } - - // (b_i8 - zero_pt_i8) x scale = fp16 - constexpr uint32_t step = std::min(block_size_y_b, dequant_s); -#pragma unroll - for (uint32_t jj = 0; jj < block_size_x_b; jj++) { -#pragma unroll - for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { - uint32_t offset_y_in_tile = i * block_size_y_b + ii; - uint32_t offset_x_in_tile = j * block_size_x_b + jj; - - uint32_t scale_idx = - (offset_y_in_tile) / dequant_s * scale_t::block_size_x + - offset_x_in_tile; - - if constexpr (compute_policy::quant_type == quant_mode::S4_ASYM) { - uint32_t zero_pt_idx = - offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + - offset_x_in_tile / pack_ratio; - native_type_t zero_pt_pack = zero_pt.reg[zero_pt_idx]; - - int8_t zero_pt_i8 = - (zero_pt_pack >> - (4 * ((wg_start_n + offset_x_in_tile) % pack_ratio))) & - 0xf; - // sycl::ext::oneapi::experimental::printf( - // "zero_pt.reg[%d} %x zero_pt_i8 %x offset_x_in_tile:%d - // \n", zero_pt_idx, zero_pt_pack, (int32_t)zero_pt_i8 , - // offset_x_in_tile); - - cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = - cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - - zero_pt_i8; - } else if constexpr ( - compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { - cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = - cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - - int8_t(8); - } - dst_blk.xetla_select(jj * block_size_y_b + ii) = - cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * - scale.reg[scale_idx]; - - // sycl::ext::oneapi::experimental::printf( - // "scale[%d] %f \n", - // scale_idx, - // float(sycl::half(scale.reg.xetla_select<1, 1>(scale_idx)))); - } - } - } - } - } + // inline void dequantize( + // matB_acc_t& matB_acc, + // matB_t& matB, + // scale_t& scale, + // zero_pt_t& zero_pt) { + // // no tail, because this is matB + // constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; + // constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; + // #pragma unroll + // for (uint32_t i = 0; i < num_block_y; ++i) { + // #pragma unroll + // for (uint32_t j = 0; j < num_block_x; ++j) { + // int block_id = (i * num_block_x + j); + // // Must be little-endian + // auto matB_blk = matB.reg.xetla_format() + // .xetla_select( + // block_id * matB_acc_t::block_elems / 2); + + // auto dst_blk = matB_acc.reg.xetla_select( + // block_id * matB_acc_t::block_elems); + + // // int8 includes 2 4bits data. 
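+          // NOTE: the rest of this block is kept as a comment for reference
+          // only; the live implementation was extracted into
+          // subgroup::dequant_int4_weight_t in tile_op_functor.hpp below.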
+ // xetla_vector cvt_blk_i8; + + // // lowest 4 bit + // { + // cvt_blk_i8.xetla_select(0) = + // matB_blk & 0xf; + // } + // // highest 4 bit + // { + // cvt_blk_i8.xetla_select(1) = + // matB_blk >> 4; + // } + + // // (b_i8 - zero_pt_i8) x scale = fp16 + // constexpr uint32_t step = std::min(block_size_y_b, dequant_s); + // #pragma unroll + // for (uint32_t jj = 0; jj < block_size_x_b; jj++) { + // #pragma unroll + // for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { + // uint32_t offset_y_in_tile = i * block_size_y_b + ii; + // uint32_t offset_x_in_tile = j * block_size_x_b + jj; + + // uint32_t scale_idx = + // (offset_y_in_tile) / dequant_s * scale_t::block_size_x + + // offset_x_in_tile; + + // if constexpr (compute_policy::quant_type == + // quant_mode::S4_ASYM) { + // uint32_t zero_pt_idx = + // offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + + // offset_x_in_tile / pack_ratio; + // native_type_t zero_pt_pack = + // zero_pt.reg[zero_pt_idx]; + + // int8_t zero_pt_i8 = + // (zero_pt_pack >> + // (4 * ((wg_start_n + offset_x_in_tile) % pack_ratio))) & + // 0xf; + // // sycl::ext::oneapi::experimental::printf( + // // "zero_pt.reg[%d} %x zero_pt_i8 %x + // offset_x_in_tile:%d + // // \n", zero_pt_idx, zero_pt_pack, (int32_t)zero_pt_i8 , + // // offset_x_in_tile); + + // cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = + // cvt_blk_i8.xetla_select(jj * block_size_y_b + + // ii) - zero_pt_i8; + // } else if constexpr ( + // compute_policy::quant_type == + // quant_mode::S4_FULLRANGE_NO_ZP) { + // cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = + // cvt_blk_i8.xetla_select(jj * block_size_y_b + + // ii) - int8_t(8); + // } + // dst_blk.xetla_select(jj * block_size_y_b + ii) = + // cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) + // * scale.reg[scale_idx]; + + // // sycl::ext::oneapi::experimental::printf( + // // "scale[%d] %f \n", + // // scale_idx, + // // float(sycl::half(scale.reg.xetla_select<1, + // 1>(scale_idx)))); + // } + // } + // } + // } + // } /* inline void dequantize( diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index bfc3bcd74..358092bc1 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -159,7 +159,7 @@ class gemm_universal_t< /// @brief GEMM arguments. /// This is the interface for users to pass the application-related runtime /// variables. - template + template struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x /// k x n). @@ -295,7 +295,7 @@ class gemm_universal_t< } }; template <> - struct arguments_t { + struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x /// k x n). uint32_t matrix_m; @@ -486,7 +486,7 @@ class gemm_universal_t< /// @param args Is the GEMM arguments for application-related runtime /// variables. /// @return Expected nd_range. - template + template static cl::sycl::nd_range<3> get_nd_range(arguments_t& args) { cl::sycl::range<3> local_range = get_local_range(); cl::sycl::range<3> group_range = @@ -523,7 +523,7 @@ class gemm_universal_t< /// @param args Is the GEMM arguments for application-related runtime /// variables. /// @return Check result. 
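Note: can_implement below keeps the pack_ratio divisibility checks because int4 weights (and, for S4_ASYM, the zero points) are nibble-packed into dtype_b elements, so leading dimensions must cover whole packed elements. A scalar sketch of the constraint, with illustrative names that are not part of the kernel API:

#include <cstdint>

// Two 4-bit values per byte of the packed element type.
constexpr uint32_t pack_ratio(uint32_t elem_bytes) {
  return elem_bytes * 2;
}

inline bool int4_lds_supported(uint32_t matB_ld, uint32_t matrix_n,
                               uint32_t zero_pt_ld, bool has_zero_pt,
                               uint32_t elem_bytes) {
  const uint32_t pr = pack_ratio(elem_bytes);
  bool ok = (matB_ld % pr == 0) && (matrix_n % pr == 0);
  if (has_zero_pt)  // zero points are packed with the same ratio
    ok = ok && (zero_pt_ld % pr == 0);
  return ok;
}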
- template + template static bool can_implement(arguments_t& args) { bool implementable = true; if (gemm_t::msg_type_a != msg_type::unaligned_2d) { @@ -566,8 +566,7 @@ class gemm_universal_t< implementable &= ((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0)); if constexpr ( - gemm_t::compute_policy::quant_type != - group::quant_mode::S4_FULLRANGE_NO_ZP) { + gemm_t::compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { implementable &= (args.zero_pt_ld % pack_ratio == 0); } @@ -584,7 +583,7 @@ class gemm_universal_t< /// variables. /// @param slm_base Is the slm base address. /// @param nbarrier_base Is the named barrier base. - template + template __XETLA_API KERNEL_FUNC void operator()( sycl::nd_item<3>& item, const arguments_t& args, @@ -669,8 +668,7 @@ class gemm_universal_t< uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride; gemm_args_t gemm_args; if constexpr ( - gemm_t::compute_policy::quant_type == - group::quant_mode::S4_FULLRANGE_NO_ZP) { + gemm_t::compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { gemm_args = gemm_args_t( mem_desc_a, mem_desc_b, diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 3d7751247..21b80a9f0 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1091,7 +1091,8 @@ struct mem_payload_t< static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = mem_transpose ^ reg_transpose; + static constexpr bool trans = mem_transpose ^ reg_transpose && + !(std::is_same_v || std::is_same_v); static constexpr bool mem_transform = (sizeof(dtype) < 4) && (register_layout == reg_layout::vnni_tiled || @@ -1617,8 +1618,8 @@ struct prefetch_payload_t< arch_tag_, std::enable_if_t< arch_tag_ <= gpu_arch::XeHpg && - ((block_size_y_ != 1 && reg_layout_ == reg_layout::tiled) || - (block_size_x_ != 1 && reg_layout_ == reg_layout::transpose_tiled))>> { + ((block_size_y_ != 1 && mem_layout_ == mem_layout::row_major) || + (block_size_x_ != 1 && mem_layout_ == mem_layout::col_major))>> { using dtype = native_type_t; using mem_desc_t = mem_desc_t; @@ -1848,10 +1849,8 @@ struct prefetch_payload_t< arch_tag_, std::enable_if_t< (arch_tag_ == gpu_arch::XeHpc) && - (((tile_size_y_ != 1 || block_size_y_ != 1) && - mem_layout_ == mem_layout::row_major) || - ((tile_size_x_ != 1 || block_size_x_ != 1) && - mem_layout_ == mem_layout::col_major))>> { + (((block_size_y_ != 1) && mem_layout_ == mem_layout::row_major) || + ((block_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -2141,22 +2140,15 @@ struct prefetch_payload_t< num_coop_sg_, arch_tag_, std::enable_if_t< - ((tile_size_y_ == 1 || block_size_y_ == 1) && - mem_layout_ == mem_layout::row_major) || - ((tile_size_x_ == 1 || block_size_x_ == 1) && - mem_layout_ == mem_layout::col_major)>> { + ((block_size_y_ == 1) && mem_layout_ == mem_layout::row_major) || + ((block_size_x_ == 1) && mem_layout_ == mem_layout::col_major)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; // CL aligned, so we can use uint64_t using prefetch_dtype = uint64_t; static constexpr msg_type message_type = msg_type::block_1d; - using tile_desc = tile_desc_t< - tile_size_x_, - tile_size_y_, - block_size_x_, - block_size_x_, - reg_layout_>; + using tile_desc = tile_desc_t; static constexpr mem_space 
memory_space = mem_space::global; static constexpr mem_layout memory_layout = mem_layout_; static constexpr gpu_arch arch_tag = arch_tag_; @@ -2283,4 +2275,4 @@ struct prefetch_payload_t< __XETLA_API void update_tdesc([[maybe_unused]] int offset) {} }; -} // namespace gpu::xetla::subgroup +} // namespace gpu::xetla::subgroup \ No newline at end of file diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp index 0372f96cc..c821ab25d 100644 --- a/include/subgroup/tile/impl/prefetch_xe.hpp +++ b/include/subgroup/tile/impl/prefetch_xe.hpp @@ -25,23 +25,29 @@ namespace gpu::xetla::subgroup { namespace detail { + +template +struct check_prefetch_type; +template +struct check_prefetch_type< + payload_t, + std::enable_if_t> { + static constexpr bool is_global_2d = false; + static constexpr bool is_global_block_1d = false; + static constexpr bool is_global_unaligned_2d = false; + static constexpr bool is_local = true; +}; template -struct check_prefetch_type { +struct check_prefetch_type< + payload_t, + std::enable_if_t> { static constexpr bool is_global_2d = - ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y != 1)); - + payload_t::message_type == msg_type::block_2d; static constexpr bool is_global_block_1d = - ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y == 1)); - + payload_t::message_type == msg_type::block_1d; static constexpr bool is_global_unaligned_2d = - ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y != 1) && - (payload_t::message_type == msg_type::unaligned_2d)); - - static constexpr bool is_local = - (payload_t::memory_space == mem_space::local); + payload_t::message_type == msg_type::unaligned_2d; + static constexpr bool is_local = false; }; } // namespace detail @@ -189,4 +195,4 @@ __XETLA_API typename std::enable_if_t::is_local> tile_prefetch([[maybe_unused]] payload_t& payload) {} -} // namespace gpu::xetla::subgroup +} // namespace gpu::xetla::subgroup \ No newline at end of file diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index 644717df8..b59dfe935 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -51,6 +51,124 @@ struct none_op_t { } }; +template < + typename matB_acc_t, + typename matB_t, + typename scale_t, + typename zero_pt_t, + uint32_t dequant_s, + quant_mode quant_type> +struct dequant_int4_weight_t { + struct arguments_t { + uint32_t wg_start_m; + uint32_t wg_start_n; + uint32_t wg_start_k; + inline arguments_t() = default; + inline arguments_t( + uint32_t wg_start_m_, + uint32_t wg_start_n_, + uint32_t wg_start_k_) + : wg_start_m(wg_start_m_), + wg_start_n(wg_start_n_), + wg_start_k(wg_start_k_) {} + }; + __XETLA_API KERNEL_FUNC void operator()( + matB_acc_t& matB_acc, + matB_t& matB, + scale_t& scale, + zero_pt_t& zero_pt, + // [[maybe_unused]] const coord_t& coord, + [[maybe_unused]] const arguments_t& args, + [[maybe_unused]] uint32_t slm_base = 0, + [[maybe_unused]] uint32_t nbarrier_base = 0) { + // no tail, because this is matB + constexpr uint32_t tile_size_x_b = matB_acc_t::tile_size_x; + constexpr uint32_t tile_size_y_b = matB_acc_t::tile_size_y; + constexpr uint32_t block_size_x_b = matB_acc_t::block_size_x; + constexpr uint32_t block_size_y_b = matB_acc_t::block_size_y; + static constexpr uint32_t pack_ratio = sizeof(typename matB_t::dtype) * 2; + + constexpr uint32_t 
num_block_x = tile_size_x_b / block_size_x_b; + constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; +#pragma unroll + for (uint32_t i = 0; i < num_block_y; ++i) { +#pragma unroll + for (uint32_t j = 0; j < num_block_x; ++j) { + int block_id = (i * num_block_x + j); + // Must be little-endian + auto matB_blk = matB.reg.xetla_format() + .xetla_select( + block_id * matB_acc_t::block_elems / 2); + + auto dst_blk = matB_acc.reg.xetla_select( + block_id * matB_acc_t::block_elems); + + // int8 includes 2 4bits data. + xetla_vector cvt_blk_i8; + + // lowest 4 bit + { + cvt_blk_i8.xetla_select(0) = + matB_blk & 0xf; + } + // highest 4 bit + { + cvt_blk_i8.xetla_select(1) = + matB_blk >> 4; + } + + // (b_i8 - zero_pt_i8) x scale = fp16 + constexpr uint32_t step = std::min(block_size_y_b, dequant_s); +#pragma unroll + for (uint32_t jj = 0; jj < block_size_x_b; jj++) { +#pragma unroll + for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { + uint32_t offset_y_in_tile = i * block_size_y_b + ii; + uint32_t offset_x_in_tile = j * block_size_x_b + jj; + + uint32_t scale_idx = + (offset_y_in_tile) / dequant_s * scale_t::block_size_x + + offset_x_in_tile; + + if constexpr (quant_type == quant_mode::S4_ASYM) { + uint32_t zero_pt_idx = + offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + + offset_x_in_tile / pack_ratio; + native_type_t zero_pt_pack = + zero_pt.reg[zero_pt_idx]; + + int8_t zero_pt_i8 = + (zero_pt_pack >> + (4 * ((args.wg_start_n + offset_x_in_tile) % pack_ratio))) & + 0xf; + // sycl::ext::oneapi::experimental::printf( + // "zero_pt.reg[%d} %x zero_pt_i8 %x offset_x_in_tile:%d + // \n", zero_pt_idx, zero_pt_pack, (int32_t)zero_pt_i8 , + // offset_x_in_tile); + + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - + zero_pt_i8; + } else if constexpr (quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - + int8_t(8); + } + dst_blk.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * + scale.reg[scale_idx]; + + // sycl::ext::oneapi::experimental::printf( + // "scale[%d] %f \n", + // scale_idx, + // float(sycl::half(scale.reg.xetla_select<1, 1>(scale_idx)))); + } + } + } + } + } +}; + /// @brief Is the element-wise relu op functor. /// Get the relu input from matAcc, update the relu output in place, /// Used in epilogue::tile_op or chained_tile_op. 
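Note: as a cross-check on dequant_int4_weight_t above and on dequantize_weight in the test below, here is a self-contained scalar reference. It assumes the conventions of this series' unit test: B is a (K x N) int4 matrix, column-major, two values per byte along K (low nibble = even k); one scale per dequant_s x 1 group stored as scale[n * (K / dequant_s) + g]; zero points nibble-packed along N as zp[g * (N / 2) + n / 2]; uint8_t packed storage, i.e. pack_ratio == 2. All names are local to the sketch:

#include <cstdint>
#include <vector>

std::vector<float> dequant_ref(const std::vector<uint8_t>& b,
                               const std::vector<float>& scale,
                               const std::vector<uint8_t>& zp,
                               uint32_t K, uint32_t N, uint32_t dequant_s,
                               bool asym) {
  const uint32_t groups = K / dequant_s;
  std::vector<float> out(size_t(K) * N);
  for (uint32_t n = 0; n < N; ++n) {
    for (uint32_t k = 0; k < K; ++k) {
      // two weights per byte along K
      const uint8_t byte = b[(size_t(n) * K + k) / 2];
      const int8_t q = (k % 2) ? int8_t(byte >> 4) : int8_t(byte & 0xf);
      const uint32_t g = k / dequant_s;
      int8_t z = 8;  // S4_FULLRANGE_NO_ZP re-centers at 8
      if (asym) {    // S4_ASYM subtracts the stored 4-bit zero point
        const uint8_t zbyte = zp[size_t(g) * (N / 2) + n / 2];
        z = (n % 2) ? int8_t(zbyte >> 4) : int8_t(zbyte & 0xf);
      }
      out[size_t(n) * K + k] = scale[size_t(n) * groups + g] * float(q - z);
    }
  }
  return out;
}

The device-side functor performs the same per-element arithmetic on register blocks; its scale_idx and zero_pt_idx recover g and the N offset from tile-local coordinates plus args.wg_start_n.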
diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp
index 8ace568eb..117fe127e 100644
--- a/tests/integration/gemv/int4/main.cpp
+++ b/tests/integration/gemv/int4/main.cpp
@@ -30,14 +30,14 @@ constexpr size_t UNDEFINED_DATA_SIZE = 1024;
 class test_col_major_1 {
  public:
   // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 4;
+  static constexpr size_t mat_m = 1;
   static constexpr size_t mat_n = 4096;
   static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 4;
+  static constexpr size_t wg_m = 1;
   static constexpr size_t wg_n = 1;
-  static constexpr size_t sg_m = 4;
+  static constexpr size_t sg_m = 1;
   static constexpr size_t sg_n = 1;
-  static constexpr size_t sg_k = 1024 / 4;
+  static constexpr size_t sg_k = 1024 / 1;
   static constexpr size_t dequant_s = 128;
   // static constexpr quant_mode quant_type = quant_mode::S4_ASYM;
   static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP;
@@ -55,12 +55,12 @@ class test_col_major_1 {
 class test_col_major_2 {
  public:
   // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 32;
+  static constexpr size_t mat_m = 4;
   static constexpr size_t mat_n = 4096;
   static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 1;
+  static constexpr size_t wg_m = 4;
   static constexpr size_t wg_n = 1;
-  static constexpr size_t sg_m = 1;
+  static constexpr size_t sg_m = 4;
   static constexpr size_t sg_n = 1;
   static constexpr size_t sg_k = 1024;
   static constexpr size_t dequant_s = 128;

From 0ebd89065bd65a3853de14eca14aff4cc503b944 Mon Sep 17 00:00:00 2001
From: "Swift.Sun" 
Date: Mon, 17 Jun 2024 22:00:59 +0800
Subject: [PATCH 26/34] update row_major for the original PVC/ARC template

---
 include/common/core/base_consts.hpp           |  1 -
 include/common/core/common_types.hpp          |  9 ++++
 .../group/gemm/compute_policy.hpp             | 35 +++++++-------
 .../group/gemm/impl/int4_dequantize_xe.hpp    | 32 ++++++-------
 .../gemm/impl/int4_dequantize_kslicing_xe.hpp |  8 ++--
 include/subgroup/tile/impl/load_xe.hpp        | 10 ++--
 include/subgroup/tile/impl/payload_xe.hpp     |  5 +-
 .../subgroup/tile/impl/tile_op_functor.hpp    |  6 +--
 .../gemm/int4_dequantization/main.cpp         | 10 ++--
 .../int4_dequantization_bias/main_client.cpp  | 10 ++--
 .../gemm/int4_dequantization_bias/main_xe.cpp | 11 +++--
 tests/integration/gemv/int4/main.cpp          | 47 +++++++++----------
 12 files changed, 96 insertions(+), 88 deletions(-)

diff --git a/include/common/core/base_consts.hpp b/include/common/core/base_consts.hpp
index 13cf0b369..67bcd7e92 100644
--- a/include/common/core/base_consts.hpp
+++ b/include/common/core/base_consts.hpp
@@ -25,7 +25,6 @@ namespace gpu::xetla {
 /// @addtogroup xetla_core_base_consts
 /// @{
-enum quant_mode : uint8_t { S4_ASYM, S4_FULLRANGE_NO_ZP };
 /// @} xetla_core_base_consts
 } // namespace gpu::xetla
diff --git a/include/common/core/common_types.hpp b/include/common/core/common_types.hpp
index 2a23a9e5e..cbd174462 100644
--- a/include/common/core/common_types.hpp
+++ b/include/common/core/common_types.hpp
@@ -26,4 +26,13 @@ enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 };
 enum class grf_mode : uint8_t { normal = 0, double_grf = 1 };
 enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
+
+enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 };
+
+struct quant_info {
+  quant_mode quant_mode;
+  uint32_t dequant_s;
+  mem_layout weight_mem_layout;
+};
+
 } // namespace gpu::xetla
diff --git
a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 47b503182..1de706ccb 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -31,8 +31,7 @@ template < typename perf_tuning_knob_, typename dtype_scale_, typename dtype_zero_pt_, - quant_mode quant_type_, - uint32_t dequant_s_, + quant_info quant_info_, mma_engine mma_engine_ = mma_engine::xmx, gpu_arch arch_tag_ = gpu_arch::XeHpc, typename enable = void> @@ -44,8 +43,7 @@ template < typename perf_tuning_knob_, typename dtype_scale_, typename dtype_zero_pt_, - quant_mode quant_type_, - int dequant_s_, + quant_info quant_info_, mma_engine mma_engine_, gpu_arch arch_tag_> struct compute_policy_int4_dequantize< @@ -53,8 +51,7 @@ struct compute_policy_int4_dequantize< perf_tuning_knob_, dtype_scale_, dtype_zero_pt_, - quant_type_, - dequant_s_, + quant_info_, mma_engine_, arch_tag_, std::enable_if_t> { @@ -70,17 +67,17 @@ struct compute_policy_int4_dequantize< static constexpr mma_engine mma_engine = mma_engine_; static constexpr gpu_arch arch_tag = arch_tag_; - static_assert(arch_has_xmx(), "XeLpg does not support xmx"); + static_assert(arch_has_xmx, "XeLpg does not support xmx"); static constexpr bool is_int4_matB_policy = true; - static constexpr uint32_t dequant_s = dequant_s_; + static constexpr uint32_t dequant_s = quant_info_.dequant_s; static_assert( (dequant_s % (32 / sizeof(dtype_mma_b))) == 0, "dequant_s should be a multiply of 32B"); using dtype_scale = dtype_scale_; using dtype_zero_pt = dtype_zero_pt_; - static constexpr quant_mode quant_type = quant_type_; + static constexpr quant_mode quant_mode = quant_info_.quant_mode; static constexpr uint32_t block_size_y_a = 16; using mma_attr = mma_attr_t; @@ -103,8 +100,7 @@ template < typename perf_tuning_knob_, typename dtype_scale_, typename dtype_zero_pt_, - quant_mode quant_type_, - int dequant_s_, + quant_info quant_info_, mma_engine mma_engine_, gpu_arch arch_tag_> struct compute_policy_int4_dequantize< @@ -112,8 +108,7 @@ struct compute_policy_int4_dequantize< perf_tuning_knob_, dtype_scale_, dtype_zero_pt_, - quant_type_, - dequant_s_, + quant_info_, mma_engine_, arch_tag_, std::enable_if_t> { @@ -131,20 +126,22 @@ struct compute_policy_int4_dequantize< static constexpr bool is_int4_matB_policy = true; - static constexpr uint32_t dequant_s = dequant_s_; + static constexpr uint32_t dequant_s = quant_info_.dequant_s; static_assert( (dequant_s % (32 / sizeof(dtype_mma_b))) == 0, "dequant_s should be a multiply of 32B"); using dtype_scale = dtype_scale_; using dtype_zero_pt = dtype_zero_pt_; - static constexpr quant_mode quant_type = quant_type_; + static constexpr quant_mode quant_mode = quant_info_.quant_mode; + static constexpr bool is_col_major_b = + quant_info_.weight_mem_layout == mem_layout::col_major; - static constexpr uint32_t block_size_y_a = 4; - static constexpr uint32_t block_bytes_x_a = 256; + static constexpr uint32_t block_size_y_a = is_col_major_b ? 8 : 16; + static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 256 : 32; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_size_x_b = 1; - static constexpr uint32_t block_bytes_y_b = 256; + static constexpr uint32_t block_size_x_b = is_col_major_b ? 1 : 32; + static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 
256 : 32; static constexpr uint32_t block_size_y_b = block_bytes_y_b / sizeof(dtype_mma_b); diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index c0b4157ec..52cfecfd7 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -36,8 +36,7 @@ template < typename mem_desc_b_t_, typename dtype_scale_, typename dtype_zero_pt_, - uint32_t dequant_s_, - quant_mode quant_type_, + quant_info quant_info_, mma_engine mma_engine_, typename pre_processing_t_, gpu_arch arch_tag_> @@ -47,8 +46,7 @@ class gemm_t< perf_tuning_knob_, dtype_scale_, dtype_zero_pt_, - quant_type_, - dequant_s_, + quant_info_, mma_engine_, arch_tag_>, tile_shape_, // tile shape of workgroup-level gemm @@ -66,8 +64,7 @@ class gemm_t< perf_tuning_knob_, dtype_scale_, dtype_zero_pt_, - quant_type_, - dequant_s_, + quant_info_, mma_engine_, arch_tag_>; static constexpr uint32_t k_stride = compute_policy::k_stride; @@ -80,6 +77,7 @@ class gemm_t< constexpr static gpu_arch arch_tag = compute_policy::arch_tag; static constexpr uint32_t dequant_s = compute_policy::dequant_s; + static constexpr quant_mode quant_mode = compute_policy::quant_mode; using dtype_b = typename mem_desc_b_t::dtype; using dtype_zero_pt = typename compute_policy::dtype_zero_pt; static constexpr uint32_t pack_ratio = sizeof(dtype_b) * 2; @@ -328,7 +326,7 @@ class gemm_t< scale_t, zero_pt_t, dequant_s, - quant_type_>; + quant_mode>; static constexpr bool enable_periodic_sync = (sync_freq != 0); static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0; @@ -531,7 +529,7 @@ class gemm_t< subgroup::tile_prefetch( scale_prefetch_payload); if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( zero_pt_prefetch_payload); @@ -545,7 +543,7 @@ class gemm_t< scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { zero_pt_prefetch_payload .template update_tdesc( zero_pt_t::tile_size_y); @@ -575,7 +573,7 @@ class gemm_t< subgroup::tile_load( scale, scale_payload); if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { subgroup::tile_load( zero_pt, zero_pt_payload); } @@ -590,7 +588,7 @@ class gemm_t< subgroup::tile_prefetch( scale_prefetch_payload); if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( zero_pt_prefetch_payload); @@ -604,7 +602,7 @@ class gemm_t< scale_payload.template update_tdesc(scale_t::tile_size_y); } if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { if (tile_k_idx % zero_pt_addr_update_freq == 0) { zero_pt_payload.template update_tdesc( zero_pt_t::tile_size_y); @@ -619,7 +617,7 @@ class gemm_t< scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + 
compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { zero_pt_prefetch_payload .template update_tdesc( zero_pt_t::tile_size_y); @@ -717,7 +715,7 @@ class gemm_t< // (offset_y_in_tile) / dequant_s * scale_t::block_size_x + // offset_x_in_tile; - // if constexpr (compute_policy::quant_type == + // if constexpr (compute_policy::quant_mode == // quant_mode::S4_ASYM) { // uint32_t zero_pt_idx = // offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + @@ -739,7 +737,7 @@ class gemm_t< // cvt_blk_i8.xetla_select(jj * block_size_y_b + // ii) - zero_pt_i8; // } else if constexpr ( - // compute_policy::quant_type == + // compute_policy::quant_mode == // quant_mode::S4_FULLRANGE_NO_ZP) { // cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = // cvt_blk_i8.xetla_select(jj * block_size_y_b + @@ -791,7 +789,7 @@ class gemm_t< xetla_vector cvt_blk; xetla_vector cvt_blk_i32; - if constexpr (compute_policy::quant_type == quant_mode::S4_ASYM) { + if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) { auto zero_pt_vec = zero_pt.reg .xetla_select( scale_block_id * zero_pt_t::block_size_x) @@ -815,7 +813,7 @@ class gemm_t< zero_pt_blk.xetla_format()); } if constexpr ( - compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { xetla_vector cvt_blk_i8; cvt_blk_i8.xetla_select(0) = matB_blk & 0x0f; diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index 358092bc1..4801ccc67 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -159,7 +159,7 @@ class gemm_universal_t< /// @brief GEMM arguments. /// This is the interface for users to pass the application-related runtime /// variables. - template + template struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x /// k x n). @@ -295,7 +295,7 @@ class gemm_universal_t< } }; template <> - struct arguments_t { + struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x /// k x n). 
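Note: from this patch on, the quantization parameters travel together as one constexpr quant_info object used as a C++20 class-type non-type template parameter, replacing the separate quant_mode and dequant_s template arguments. A self-contained sketch of the pattern; policy_stub stands in for compute_policy_int4_dequantize and is not a XeTLA type:

#include <cstdint>

enum class mem_layout : uint8_t { row_major = 0, col_major = 1 };
enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 };

struct quant_info {
  quant_mode quant_mode;
  uint32_t dequant_s;
  mem_layout weight_mem_layout;
};

// quant_info is a structural type, so a constexpr instance can be
// passed directly as a template argument.
template <quant_info q_info>
struct policy_stub {
  static constexpr quant_mode mode = q_info.quant_mode;
  static constexpr uint32_t dequant_s = q_info.dequant_s;
};

inline constexpr quant_info fullrange{
    quant_mode::S4_FULLRANGE_NO_ZP, 128, mem_layout::col_major};
using my_policy = policy_stub<fullrange>;

The updated tests in this patch instantiate the real compute_policy_int4_dequantize the same way, building a constexpr quant_info from Test::dequant_s and layout_b.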
uint32_t matrix_m; @@ -566,7 +566,7 @@ class gemm_universal_t< implementable &= ((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0)); if constexpr ( - gemm_t::compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { implementable &= (args.zero_pt_ld % pack_ratio == 0); } @@ -668,7 +668,7 @@ class gemm_universal_t< uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride; gemm_args_t gemm_args; if constexpr ( - gemm_t::compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { gemm_args = gemm_args_t( mem_desc_a, mem_desc_b, diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index f11cd2e9f..c76d0fc26 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -100,7 +100,7 @@ tile_load(tile_t& tile, payload_t& payload) { static constexpr bool reg_transpose = tile_desc::reg_transpose; static constexpr bool mem_transpose = payload_t::mem_transpose; - static constexpr bool trans = reg_transpose ^ mem_transpose; + static constexpr bool trans = payload_t::trans; static constexpr uint32_t scale_factor = payload_t::scale_factor; static constexpr bool mem_transform = payload_t::mem_transform; @@ -535,9 +535,7 @@ tile_load(tile_t& tile, payload_t& payload) { // } } - if constexpr ( - payload_t::trans && - !(std::is_same_v || std::is_same_v)) { + if constexpr (payload_t::trans) { SW_BARRIER(); tile_transpose(tile); } @@ -604,9 +602,7 @@ tile_load(tile_t& tile, payload_t& payload) { } } - if constexpr ( - payload_t::trans && - !(std::is_same_v || std::is_same_v)) { + if constexpr (payload_t::trans) { SW_BARRIER(); tile_transpose(tile); } diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 21b80a9f0..6d0417f43 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -65,11 +65,14 @@ struct mem_payload_t< mem_payload_t; public: - static constexpr bool mem_transpose = memory_layout == mem_layout::col_major; + static constexpr bool mem_transpose = + memory_layout == mem_layout::col_major && + !(std::is_same_v || std::is_same_v); static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; + static constexpr bool trans = mem_transpose ^ reg_transpose; static constexpr bool mem_transform = (sizeof(dtype) < 4) && !mem_transpose && diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index b59dfe935..1945184c9 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -57,7 +57,7 @@ template < typename scale_t, typename zero_pt_t, uint32_t dequant_s, - quant_mode quant_type> + quant_mode quant_mode> struct dequant_int4_weight_t { struct arguments_t { uint32_t wg_start_m; @@ -130,7 +130,7 @@ struct dequant_int4_weight_t { (offset_y_in_tile) / dequant_s * scale_t::block_size_x + offset_x_in_tile; - if constexpr (quant_type == quant_mode::S4_ASYM) { + if constexpr (quant_mode == quant_mode::S4_ASYM) { uint32_t zero_pt_idx = offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + offset_x_in_tile / pack_ratio; @@ -149,7 +149,7 @@ struct dequant_int4_weight_t { cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj 
* block_size_y_b + ii) - zero_pt_i8; - } else if constexpr (quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + } else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - int8_t(8); diff --git a/tests/integration/gemm/int4_dequantization/main.cpp b/tests/integration/gemm/int4_dequantization/main.cpp index ab120d187..18e40ded5 100644 --- a/tests/integration/gemm/int4_dequantization/main.cpp +++ b/tests/integration/gemm/int4_dequantization/main.cpp @@ -164,6 +164,8 @@ void dequantize_gemm_run(uint32_t iter) { constexpr size_t matrix_m = Test::mat_m; constexpr size_t matrix_n = Test::mat_n; constexpr size_t matrix_k = Test::mat_k; + + static constexpr mem_layout layout_b = Test::layout_b; constexpr uint32_t global_kslicing = Test::global_kslicing; constexpr uint32_t local_kslicing = Test::local_kslicing; @@ -227,13 +229,15 @@ void dequantize_gemm_run(uint32_t iter) { compute_attr_t; using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; + + static constexpr quant_info quant_info{quant_mode::S4_ASYM, Test::dequant_s, layout_b}; + using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, perf_tuning_knob, data_type_scale, data_type_zero_pt, - gpu::xetla::group::quant_mode::S4_ASYM, - dequant_s, + quant_info, mma_engine::xmx, gpu_arch::XeHpg>; using gemm_t = xetla::group:: @@ -332,7 +336,7 @@ void dequantize_gemm_run(uint32_t iter) { .wait(); // set up gemm arguments - typename gemm_op_t::template arguments_t gemm_arg( + typename gemm_op_t::template arguments_t gemm_arg( matrix_m, matrix_k, matrix_n, diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp index e9d7a0fa9..69fdfc1fe 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp @@ -621,13 +621,15 @@ void dequantize_gemm_run(int iter) { using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; + static constexpr quant_info quant_info{ + quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b}; + using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, perf_tuning_knob, data_type_scale, data_type_zero_pt, - gpu::xetla::group::quant_mode::S4_FULLRANGE_NO_ZP, - dequant_s, + quant_info, Test::mma_eng, Test::arch>; @@ -794,7 +796,7 @@ void dequantize_gemm_run(int iter) { } queue.memcpy((void*)gidx_d, (void*)gidx_h, size_gidx * sizeof(uint32_t)) .wait(); - typename gemm_op_t::template arguments_t + typename gemm_op_t::template arguments_t gemm_arg( matrix_m, matrix_k, @@ -901,7 +903,7 @@ void dequantize_gemm_run(int iter) { free(A_d_shuf, context); } if constexpr (Feature::feature == optional_feature::NONE) { - typename gemm_op_t::template arguments_t + typename gemm_op_t::template arguments_t gemm_arg( matrix_m, matrix_k, diff --git a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp index b64928d8a..1c42454df 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp @@ -318,7 +318,7 @@ void dequantize_gemm_run(int iter) { constexpr size_t matrix_k = Test::mat_k; constexpr uint32_t global_kslicing = Test::global_kslicing; constexpr uint32_t local_kslicing = Test::local_kslicing; - + static constexpr mem_layout layout_b = Test::layout_b; constexpr 
size_t wg_tile_m = Test::wg_m; constexpr size_t wg_tile_n = Test::wg_n; constexpr size_t sg_tile_m = Test::sg_m; @@ -368,7 +368,7 @@ void dequantize_gemm_run(int iter) { DEVICE_MEM_ALIGNMENT / sizeof(data_type_a)>; using mem_desc_b_t = xetla::mem_desc_t< data_type_b, - mem_layout::row_major, + layout_b, mem_space::global, DEVICE_MEM_ALIGNMENT / sizeof(data_type_b)>; using mem_desc_c_t = xetla::mem_desc_t< @@ -387,14 +387,15 @@ void dequantize_gemm_run(int iter) { compute_attr_t; using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; + static constexpr quant_info quant_info{ + quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b}; using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, perf_tuning_knob, data_type_scale, data_type_zero_pt, - gpu::xetla::group::quant_mode::S4_FULLRANGE_NO_ZP, - dequant_s, + quant_info, mma_engine::xmx, gpu_arch::XeHpc>; @@ -531,7 +532,7 @@ void dequantize_gemm_run(int iter) { // It accepts the base pointer to matrix D, and its dimensions {bias_d, bias_add_shape}}); - typename gemm_op_t::template arguments_t gemm_arg( + typename gemm_op_t::template arguments_t gemm_arg( matrix_m, matrix_k, matrix_n, diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 117fe127e..d9ed21e01 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -38,9 +38,9 @@ class test_col_major_1 { static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 1; static constexpr size_t sg_k = 1024 / 1; - static constexpr size_t dequant_s = 128; - // static constexpr quant_mode quant_type = quant_mode::S4_ASYM; - static constexpr quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP; + static constexpr size_t dequant_s = 131072; + // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM; + static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP; static constexpr size_t local_kslicing = 1; static constexpr size_t global_kslicing = 1; @@ -63,7 +63,7 @@ class test_col_major_2 { static constexpr size_t sg_m = 4; static constexpr size_t sg_n = 1; static constexpr size_t sg_k = 1024; - static constexpr size_t dequant_s = 128; + static constexpr size_t dequant_s = 4096; static constexpr size_t local_kslicing = 1; static constexpr size_t global_kslicing = 1; @@ -120,7 +120,7 @@ int gemm_result_validate( } template < - quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP, + quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP, typename data_type_acc_in = fp16, typename data_type_b, typename data_type_scale, @@ -134,7 +134,7 @@ std::vector convert_int4( int8_t zero_pt_i8 = zero_pt & 0xf; for (uint32_t i = 0; i < dequant_fp16.size(); i++) { int8_t dequant_8bit = data_b & 0xf; - if constexpr (quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { dequant_fp16[i] = scale * (dequant_8bit - 8); } else { dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8); @@ -147,7 +147,7 @@ std::vector convert_int4( template < size_t dequant_s, mem_layout layout_b = mem_layout::col_major, - quant_mode quant_type = quant_mode::S4_FULLRANGE_NO_ZP, + quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP, typename data_type_acc_in = fp16, typename data_type_b, typename data_type_scale, @@ -174,7 +174,7 @@ std::vector dequantize_weight( int start_out = layout_b == mem_layout::row_major ? 
0 : i * matrix_k + j * pack_radio; for (uint32_t jj = 0; jj < step; jj++) { - std::vector dequant_fp16 = convert_int4( + std::vector dequant_fp16 = convert_int4( b[start_b_in + jj], scale[start_scale_in], zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio))); @@ -210,8 +210,8 @@ void dequantize_gemv_run(int iter) { constexpr size_t sg_tile_m = Test::sg_m; constexpr size_t sg_tile_n = Test::sg_n; constexpr size_t sg_tile_k = Test::sg_k; - constexpr size_t dequant_s = Test::dequant_s; - constexpr quant_mode quant_type = Test::quant_type; + constexpr size_t dequant_s = std::min(Test::dequant_s, matrix_k); + constexpr quant_mode quant_mode = Test::quant_mode; using data_type_a = typename Test::data_type_a; using data_type_b = typename Test::data_type_b; using data_type_c = typename Test::data_type_c; @@ -287,14 +287,13 @@ void dequantize_gemv_run(int iter) { compute_attr_t; using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; - + static constexpr quant_info quant_info{quant_mode, Test::dequant_s, layout_b}; using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, perf_tuning_knob, data_type_scale, data_type_zero_pt, - quant_type, - dequant_s, + quant_info, Test::mma_eng, Test::arch>; @@ -474,10 +473,10 @@ void dequantize_gemv_run(int iter) { {// epilogue_args init list // It accepts the base pointer to matrix D, and its dimensions {bias_d, bias_add_shape}}); - typename gemm_op_t::template arguments_t gemm_arg; - if constexpr (compute_policy::quant_type == S4_FULLRANGE_NO_ZP) { + typename gemm_op_t::template arguments_t gemm_arg; + if constexpr (compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { gemm_arg = - typename gemm_op_t::template arguments_t( + typename gemm_op_t::template arguments_t( matrix_m, matrix_k, matrix_n, @@ -492,9 +491,9 @@ void dequantize_gemv_run(int iter) { Acc_d, Cnt_d, epilogue_args); - } else if constexpr (compute_policy::quant_type == S4_ASYM) { + } else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) { gemm_arg = - typename gemm_op_t::template arguments_t( + typename gemm_op_t::template arguments_t( matrix_m, matrix_k, matrix_n, @@ -513,11 +512,11 @@ void dequantize_gemv_run(int iter) { epilogue_args); } cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); - // if (!gemm_op_t::can_implement(gemm_arg)) { - // std::cout << "The arguments cannot be supported, aborting ... " - // << std::endl; - // FAIL(); - // } + if (!gemm_op_t::can_implement(gemm_arg)) { + std::cout << "The arguments cannot be supported, aborting ... 
" + << std::endl; + FAIL(); + } size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); @@ -553,7 +552,7 @@ void dequantize_gemv_run(int iter) { prof.print_profiling_result(profiling_selector::GPU); // check result std::vector dequantize_b = - dequantize_weight( + dequantize_weight( matrix_k, matrix_n, B_h, scale_h, zero_pt_h); queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait(); From b2dfad5e5f687d38bf6b47c5c811c8655ca1f0e6 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Mon, 17 Jun 2024 14:52:36 +0000 Subject: [PATCH 27/34] save(fix HPC 2D load) --- .../softmax.hpp | 11 +- .../09_gate_recurrent_unit/kernel_func.hpp | 2 +- include/common/core/memory.hpp | 13 +- .../fused_op/row_reduction_fused_op_xe.hpp | 18 ++- .../group/gemm/impl/int4_dequantize_xe.hpp | 22 ++- .../group/reduction/row_reduce_store_xe.hpp | 12 +- .../col_major_shuf/col_major_shuf_xe.hpp | 4 +- .../data_transformer/data_transformer_xe.hpp | 5 +- .../gemm/impl/int4_dequantize_kslicing_xe.hpp | 24 ++-- .../kernel/layer_norm/layer_norm_bwd_xe.hpp | 18 ++- .../kernel/layer_norm/layer_norm_fwd_xe.hpp | 20 ++- .../mha_core_attention/mha_attn_reg.hpp | 133 +++++++++--------- .../mha_core_attention/mha_core_attn.hpp | 54 ++++--- .../kernel/reduction/row_reduction_xe.hpp | 6 +- include/group/cooperative_reduction.hpp | 10 +- include/group/epilogue/impl/tile_op_xe.hpp | 6 +- include/group/gemm/impl/default_fpu_xe.hpp | 4 +- include/group/gemm/impl/default_xmx_xe.hpp | 4 +- include/group/reduction/reduction_xe.hpp | 9 +- include/group/softmax/impl/softmax_bwd_xe.hpp | 2 +- include/subgroup/tile/common.hpp | 17 +-- include/subgroup/tile/impl/load_xe.hpp | 2 +- include/subgroup/tile/impl/store_xe.hpp | 1 - .../subgroup/tile/impl/tile_op_functor.hpp | 22 +-- tests/integration/gemv/int4/main.cpp | 2 +- third_party/pybind11 | 1 + 26 files changed, 236 insertions(+), 186 deletions(-) create mode 160000 third_party/pybind11 diff --git a/examples/08_scaled_dot_product_attention/softmax.hpp b/examples/08_scaled_dot_product_attention/softmax.hpp index a14386efe..40f986a70 100644 --- a/examples/08_scaled_dot_product_attention/softmax.hpp +++ b/examples/08_scaled_dot_product_attention/softmax.hpp @@ -60,18 +60,21 @@ struct xetla_softmax_fwd_t { using softmax_tile_desc_t = subgroup:: tile_desc_t; using softmax_load_t = subgroup::tile_t; + using mem_desc_in_t = mem_desc_t; using softmax_load_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, softmax_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; // this tile will store the softmax result to global memory using softmax_store_t = subgroup::tile_t; + using mem_desc_out_t = + mem_desc_t; using softmax_store_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_out_t, softmax_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; struct arguments_t { diff --git a/examples/09_gate_recurrent_unit/kernel_func.hpp b/examples/09_gate_recurrent_unit/kernel_func.hpp index b2e3994c3..0dd76188a 100644 --- a/examples/09_gate_recurrent_unit/kernel_func.hpp +++ b/examples/09_gate_recurrent_unit/kernel_func.hpp @@ -156,7 +156,7 @@ struct gru_layer { using mat_hidden_payload_t = mem_payload_t< mem_desc_a_t, matC_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; using matC_payload_t = mem_payload_t< mem_desc_c_t, diff --git a/include/common/core/memory.hpp b/include/common/core/memory.hpp index 8bd90eab4..1927ce97f 100644 --- 
a/include/common/core/memory.hpp +++ b/include/common/core/memory.hpp @@ -356,9 +356,8 @@ __XETLA_API xetla_vector xetla_load_global( __ESIMD_NS::cache_hint_L2, __ESIMD_NS::alignment}; if constexpr (sizeof(T) * N < sizeof(uint32_t)) { - auto padding_load = __ESIMD_NS::block_load( - ptr, byte_offset, props); - return padding_load.xetla_select(0); + xetla_vector offsets(byte_offset, sizeof(T)); + return __ESIMD_NS::gather(ptr, offsets); } else { return __ESIMD_NS::block_load(ptr, byte_offset, props); } @@ -501,7 +500,13 @@ __XETLA_API void xetla_store_global( __ESIMD_NS::cache_hint_L1, __ESIMD_NS::cache_hint_L2, __ESIMD_NS::alignment}; - __ESIMD_NS::block_store(ptr, byte_offset, vals, props); + + if constexpr (sizeof(T) * N < sizeof(uint32_t)) { + xetla_vector offsets(byte_offset, sizeof(T)); + return __ESIMD_NS::scatter(ptr, offsets, vals); + } else { + __ESIMD_NS::block_store(ptr, byte_offset, vals, props); + } } /// @brief Stateless scattered atomic (0 src). diff --git a/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp b/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp index 611a09b8c..142c55e18 100644 --- a/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp +++ b/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp @@ -139,10 +139,12 @@ struct row_reduction_fused_op_t< block_size_y, reg_layout::tiled>; using dgelu_w_in_t = subgroup::tile_t; + using mem_desc_in_t = + mem_desc_t; using dgelu_w_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, dgelu_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using dgelu_x_out_t = subgroup::tile_t; using dgelu_x_out_payload_t = subgroup::mem_payload_t< @@ -234,17 +236,21 @@ struct row_reduction_fused_op_t< block_size_y, reg_layout::tiled>; using mask_in_t = subgroup::tile_t; + using mem_desc_mask_t = + mem_desc_t; using mask_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_mask_t, reduction_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using dropout_bwd_out_t = subgroup::tile_t; + using mem_desc_out_t = + mem_desc_t; using dropout_bwd_out_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_out_t, reduction_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; if (dropout_prob != 0) { mask_in_t mask_in; diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 52cfecfd7..3d3ebf019 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -174,7 +174,7 @@ class gemm_t< using matA_payload_t = subgroup::mem_payload_t< mem_desc_a_t, matA_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using matA_acc_t = subgroup::tile_t; using matA_prefetch_payload_t = subgroup:: @@ -204,9 +204,13 @@ class gemm_t< using matB_payload_t = subgroup::mem_payload_t< mem_desc_b_t, matB_tile_desc_t, - subgroup::msg_type_v, - // subgroup::msg_type_v, + subgroup::msg_type_v< + matB_tile_desc_t, + mem_desc_t< + typename mem_desc_b_t::dtype, + mem_layout::row_major, + mem_desc_b_t::space>>, + // subgroup::msg_type_v, arch_tag>; using matB_prefetch_payload_t = subgroup:: prefetch_payload_t; @@ -282,10 +286,7 @@ class gemm_t< using scale_payload_t = subgroup::mem_payload_t< mem_desc_scale_t, scale_tile_desc_t, - subgroup::msg_type_v< - scale_tile_desc_t, - mem_space::global, - mem_desc_scale_t::layout>, + 
subgroup::msg_type_v, arch_tag>; // compress int4 along N dimensions @@ -300,10 +301,7 @@ class gemm_t< using zero_pt_payload_t = subgroup::mem_payload_t< mem_desc_zero_pt_t, zero_pt_tile_desc_t, - subgroup::msg_type_v< - zero_pt_tile_desc_t, - mem_space::global, - mem_desc_zero_pt_t::layout>, + subgroup::msg_type_v, arch_tag>; using scale_prefetch_payload_t = subgroup:: prefetch_payload_t; diff --git a/include/experimental/group/reduction/row_reduce_store_xe.hpp b/include/experimental/group/reduction/row_reduce_store_xe.hpp index a38d4e5bb..29a2474ef 100644 --- a/include/experimental/group/reduction/row_reduce_store_xe.hpp +++ b/include/experimental/group/reduction/row_reduce_store_xe.hpp @@ -58,10 +58,12 @@ struct group_row_reduce_store_t< using local_st_tile_desc_t = subgroup::tile_desc_t; using local_st_t = subgroup::tile_t; + using mem_desc_acc = + mem_desc_t; using local_st_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_acc, local_st_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using local_ld_tile_desc_t = subgroup::tile_desc_t< local_tile_size_x, @@ -70,10 +72,12 @@ struct group_row_reduce_store_t< wg_size_y, reg_layout::tiled>; using local_ld_t = subgroup::tile_t; + using mem_desc_ld_t = + mem_desc_t; using local_ld_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_ld_t, local_ld_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; // If the local tile size is small, we still can use 2D block store diff --git a/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp b/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp index 0395a77fd..6b860686d 100644 --- a/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp +++ b/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp @@ -83,7 +83,7 @@ struct col_major_shuf_t< using store_tile_payload_t = subgroup::mem_payload_t< mem_desc_store_tile_t, store_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_>; using mem_desc_gidx_t = mem_desc_t< @@ -97,7 +97,7 @@ struct col_major_shuf_t< using gidx_payload_t = subgroup::mem_payload_t< mem_desc_gidx_t, gidx_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_>; struct arguments_t { diff --git a/include/experimental/kernel/data_transformer/data_transformer_xe.hpp b/include/experimental/kernel/data_transformer/data_transformer_xe.hpp index 4cbd1498a..2ccf069d1 100644 --- a/include/experimental/kernel/data_transformer/data_transformer_xe.hpp +++ b/include/experimental/kernel/data_transformer/data_transformer_xe.hpp @@ -122,10 +122,11 @@ struct xetla_data_transformer< block_size_y, in_reg_layout>; using global_ld_t = subgroup::tile_t; + using mem_desc_ld_t = mem_desc_t; using global_ld_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_ld_t, global_ld_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using global_st_tile_desc_t = subgroup::tile_desc_t< diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index 4801ccc67..862cd30f8 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -550,18 +550,18 @@ class gemm_universal_t< args.matB_base.base, args.matB_ld / pack_ratio); } } - if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_t::msg_type_c == msg_type::block_2d) { - 
implementable &= kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m, - args.matC_ld); - } else { - implementable &= kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_t::msg_type_c == msg_type::block_2d) { + // implementable &= kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n, + // args.matrix_m, + // args.matC_ld); + // } else { + // implementable &= kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } // check for int4x2 implementable &= ((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0)); diff --git a/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp index f4c2fff61..b67d7f997 100644 --- a/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp +++ b/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp @@ -92,20 +92,26 @@ struct layer_norm_bwd_t< using gamma_in_t = subgroup::tile_t; using dx_out_t = subgroup::tile_t; + using mem_desc_y_t = + mem_desc_t; using dy_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_y_t, ln_bwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_x_t = + mem_desc_t; using x_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_x_t, ln_bwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_weight_t = + mem_desc_t; using gamma_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_weight_t, ln_bwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using dx_out_payload_t = subgroup::mem_payload_t< mem_desc_t, diff --git a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp index 08ad1eaf3..1726ea0bc 100644 --- a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp +++ b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp @@ -101,23 +101,29 @@ struct layer_norm_fwd_t< using beta_in_t = subgroup::tile_t; using y_out_t = subgroup::tile_t; + using mem_desc_x_t = + mem_desc_t; using x_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_x_t, ln_fwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_weight_t = + mem_desc_t; using gamma_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_weight_t, ln_fwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using beta_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_weight_t, ln_fwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_y_t = + mem_desc_t; using y_out_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_y_t, ln_fwd_tile_desc_t, msg_type::block_1d, gpu_arch::XeHpc>; diff --git a/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp b/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp index 0dac16a9d..398e6df96 100644 --- a/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp +++ b/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp @@ -226,54 +226,55 @@ struct xetla_mha_attn_reg_fwd_t { using matC_16x2048_t = subgroup::tile_t; using matC_128x64_t = subgroup::tile_t; + using mem_desc_c_t = mem_desc_t; using 
matC_128x128_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x256_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_64x384_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_64x384_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_64x512_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_32x1024_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_16x2048_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x64_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_128x128_t = @@ -288,40 +289,41 @@ struct xetla_mha_attn_reg_fwd_t { subgroup::tile_t; using matDpotMk_128x64_t = subgroup::tile_t; + using mem_desc_dpot_t = mem_desc_t; using matDpotMk_128x128_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_128x128_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_128x256_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_128x256_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_64x384_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_64x384_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_64x512_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_64x512_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_32x1024_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_32x1024_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_16x2048_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_16x2048_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_128x64_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_128x64_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; /// @brief Arguments for xetla_softmax_fwd_t::run. @@ -1780,47 +1782,48 @@ struct xetla_mha_attn_reg_bwd_t { using matC_32x1024_t = subgroup::tile_t; using matC_16x2048_t = subgroup::tile_t; + using mem_desc_c_t = mem_desc_t; using matC_128x128_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_128x128_tile_desc_t, (global_kslicing > 1) ? 
msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x256_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_64x384_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_64x384_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_64x512_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_32x1024_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_16x2048_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x64_tile_desc_t = subgroup::tile_desc_t< @@ -1862,39 +1865,42 @@ struct xetla_mha_attn_reg_bwd_t { subgroup::tile_t; using matC_256x64_trnp_af_t = subgroup::tile_t; - + using mem_desc_bot_t = mem_desc_t; using matC_128x64_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_256x64_trnp_a_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_256x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? 
msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matW_128x128_t = subgroup::tile_t; @@ -1904,35 +1910,36 @@ struct xetla_mha_attn_reg_bwd_t { using matW_32x1024_t = subgroup::tile_t; using matW_16x2048_t = subgroup::tile_t; + using mem_desc_w_t = mem_desc_t; using matW_128x128_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_128x128_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matW_128x256_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_128x256_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matW_64x384_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_64x384_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matW_64x512_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_64x512_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matW_32x1024_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_32x1024_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matW_16x2048_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_16x2048_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; #if 0 diff --git a/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp b/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp index e6764ff3c..41dfa7606 100644 --- a/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp +++ b/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp @@ -188,17 +188,19 @@ struct xetla_mha_core_attn_fwd_t { reg_layout::tiled>; using matElem_ld_t = gpu::xetla::subgroup::tile_t; + using mem_desc_elem_ld_t = + mem_desc_t; using matElem_ld_payload_t = gpu::xetla::subgroup::mem_payload_t< - mem_desc_t, + mem_desc_elem_ld_t, matElem_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matElem_st_t = gpu::xetla::subgroup::tile_t; using matElem_st_payload_t = gpu::xetla::subgroup::mem_payload_t< - mem_desc_t, + mem_desc_elem_ld_t, matElem_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matElem_reg_t = gpu::xetla::subgroup::tile_t< float, @@ -1065,55 +1067,60 @@ struct xetla_mha_core_attn_bwd_t { subgroup::tile_t; using matC_256x64_trnp_af_t = subgroup::tile_t; - + using mem_desc_c_t = mem_desc_t; using matC_128x128_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x256_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_bot_t = mem_desc_t; using matC_128x64_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? 
msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_256x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_256x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; // 512 = 16x32 or 8x64 @@ -1127,13 +1134,16 @@ struct xetla_mha_core_attn_bwd_t { gpu::xetla::subgroup::tile_t; using matElem_st_t = gpu::xetla::subgroup::tile_t; + + using mem_desc_elem_ld_t = + mem_desc_t; using matElem_ld_payload_t = gpu::xetla::subgroup::mem_payload_t< - mem_desc_t, + mem_desc_elem_ld_t, matElem_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matElem_st_payload_t = gpu::xetla::subgroup::mem_payload_t< - mem_desc_t, + mem_desc_elem_ld_t, matElem_tile_desc_t, msg_type::block_2d, gpu_arch::XeHpc>; diff --git a/include/experimental/kernel/reduction/row_reduction_xe.hpp b/include/experimental/kernel/reduction/row_reduction_xe.hpp index 1d24d50b0..c2a4a11c9 100644 --- a/include/experimental/kernel/reduction/row_reduction_xe.hpp +++ b/include/experimental/kernel/reduction/row_reduction_xe.hpp @@ -108,10 +108,12 @@ struct xetla_row_reduction_t< block_size_y, reg_layout::tiled>; using global_ld_t = subgroup::tile_t; + using mem_desc_in_t = + mem_desc_t; using global_ld_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, global_ld_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using mat_buffer_t = subgroup::tile_t< dtype_acc, diff --git a/include/group/cooperative_reduction.hpp b/include/group/cooperative_reduction.hpp index b5ab96d23..ecdc11c62 100644 --- a/include/group/cooperative_reduction.hpp +++ b/include/group/cooperative_reduction.hpp @@ -104,10 +104,12 @@ class cooperative_reduce_t< src_block_size_y, reg_layout::tiled>; using local_st_tile_t = subgroup::tile_t; + using mem_desc_st_t = + mem_desc_t; using local_st_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_st_t, local_st_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using local_ld_tile_desc_t = subgroup::tile_desc_t< tile_size_x, @@ -116,10 +118,12 @@ class cooperative_reduce_t< block_size_y, reg_layout::tiled>; using local_ld_tile_t = subgroup::tile_t; + using mem_desc_ld_t = + mem_desc_t; using local_ld_payload_t = subgroup::mem_payload_t< mem_desc_t, local_ld_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; public: diff --git a/include/group/epilogue/impl/tile_op_xe.hpp b/include/group/epilogue/impl/tile_op_xe.hpp index 656cdabde..d9afac465 100644 --- a/include/group/epilogue/impl/tile_op_xe.hpp +++ b/include/group/epilogue/impl/tile_op_xe.hpp @@ -106,10 +106,6 @@ class epilogue_t< } public: - static constexpr msg_type msg_type_c = - (mem_space_c == mem_space::global ? msg_type::block_2d - : msg_type::scatter); - /// @brief Default epilogue. /// 1) Call tile_op/chained_tile_op 2) Convert dtype_acc to dtype_c /// 3) Overwrite/reduce_sum to memory. 
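// Note on the hunk below: msg_type_c moves from a class-scope constant,
// which could only inspect mem_space_c, to a constant computed inside the
// store path from the full memory descriptor. A minimal sketch of the new
// shape, assuming the alias names used elsewhere in this series (mem_desc_c_t
// is the epilogue's C-matrix descriptor; not verified against the final tree):
//
//   using mat_tile_desc = typename matAcc_t::tile_desc;
//   static constexpr msg_type msg_type_c =
//       subgroup::msg_type_v<mat_tile_desc, mem_desc_c_t>;
//
// Because the query now sees dtype and layout as well as address space, a
// col-major or sub-dword C tile can resolve to scatter/unaligned messages
// instead of unconditionally claiming block_2d.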
@@ -131,6 +127,8 @@ class epilogue_t< uint32_t nbarrier_base = 0) { using mat_tile_desc = typename matAcc_t::tile_desc; using matC_t = subgroup::tile_t; + static constexpr msg_type msg_type_c = + subgroup::msg_type_v; using matC_payload_t = subgroup:: mem_payload_t; update_sg_tile_tdesc(g, mem_desc_c); diff --git a/include/group/gemm/impl/default_fpu_xe.hpp b/include/group/gemm/impl/default_fpu_xe.hpp index add8e6790..7345ea322 100644 --- a/include/group/gemm/impl/default_fpu_xe.hpp +++ b/include/group/gemm/impl/default_fpu_xe.hpp @@ -150,7 +150,7 @@ class gemm_t< using matA_payload_t = subgroup::mem_payload_t< mem_desc_a_t, matA_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; // the tile size in register may bigger than in memory because of the padding using matA_acc_t = subgroup::tile_t; @@ -171,7 +171,7 @@ class gemm_t< using matB_payload_t = subgroup::mem_payload_t< mem_desc_b_t, matB_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using matB_acc_t = subgroup::tile_t; using matB_prefetch_payload_t = subgroup::prefetch_payload_t< diff --git a/include/group/gemm/impl/default_xmx_xe.hpp b/include/group/gemm/impl/default_xmx_xe.hpp index 75a0ef79c..0626f2310 100644 --- a/include/group/gemm/impl/default_xmx_xe.hpp +++ b/include/group/gemm/impl/default_xmx_xe.hpp @@ -136,7 +136,7 @@ class gemm_t< using matA_payload_t = subgroup::mem_payload_t< mem_desc_a_t, matA_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using matA_acc_t = subgroup::tile_t; using matA_prefetch_payload_t = subgroup::prefetch_payload_t< @@ -157,7 +157,7 @@ class gemm_t< using matB_payload_t = subgroup::mem_payload_t< mem_desc_b_t, matB_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using matB_acc_t = subgroup::tile_t; using matB_prefetch_payload_t = subgroup::prefetch_payload_t< diff --git a/include/group/reduction/reduction_xe.hpp b/include/group/reduction/reduction_xe.hpp index ee39acc80..d07873551 100644 --- a/include/group/reduction/reduction_xe.hpp +++ b/include/group/reduction/reduction_xe.hpp @@ -41,14 +41,17 @@ struct group_reduce_t { subgroup::tile_desc_t; using local_ld_t = subgroup::tile_t; using local_st_t = subgroup::tile_t; + using mem_desc_ld_t = mem_desc_t; using local_ld_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_ld_t, local_ld_tile_desc, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_st_t = mem_desc_t; using local_st_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_st_t, local_st_tile_desc, + // subgroup::msg_type_v, msg_type::block_1d, gpu_arch::XeHpc>; inline group_reduce_t() = default; diff --git a/include/group/softmax/impl/softmax_bwd_xe.hpp b/include/group/softmax/impl/softmax_bwd_xe.hpp index dd31f9f1e..84d41643e 100644 --- a/include/group/softmax/impl/softmax_bwd_xe.hpp +++ b/include/group/softmax/impl/softmax_bwd_xe.hpp @@ -105,7 +105,7 @@ class softmax_t< using mat_in_payload_t = subgroup::mem_payload_t< mem_desc_in_t, mat_in_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; int32_t sg_idx = g.get_id() % wg_size_x; diff --git a/include/subgroup/tile/common.hpp b/include/subgroup/tile/common.hpp index 8c851028f..91c3c3c3c 100644 --- a/include/subgroup/tile/common.hpp +++ b/include/subgroup/tile/common.hpp @@ -313,11 +313,12 @@ struct is_floating_to_integer { is_integral::value; }; -template < - typename tile_desc_, - mem_space memory_space, - mem_layout memory_layout = mem_layout::row_major> +template 
struct msg_type_query { + using dtype = mem_desc_::dtype; + static constexpr mem_layout memory_layout = mem_desc_::layout; + static constexpr mem_space memory_space = mem_desc_::space; + static constexpr msg_type value = memory_space == mem_space::global ? (((tile_desc_::tile_size_y == 1 && memory_layout == mem_layout::row_major) || @@ -331,12 +332,8 @@ struct msg_type_query { : msg_type::scatter); }; -template < - typename tile_desc_, - mem_space memory_space, - mem_layout memory_layout = mem_layout::row_major> -constexpr msg_type msg_type_v = - msg_type_query<tile_desc_, memory_space, memory_layout>::value; +template <typename tile_desc_, typename mem_desc_> +constexpr msg_type msg_type_v = msg_type_query<tile_desc_, mem_desc_>::value; template < typename dtype, diff --git a/include/subgroup/tile/impl/load_xe.hpp index c76d0fc26..786cfac06 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -219,7 +219,7 @@ tile_load(tile_t& tile, payload_t& payload) { .xetla_format>() = reg_tmp .xetla_format< - load_dtype, + native_type_t<load_dtype>, block_size_x / scale_factor, ld_blk_height>() .xetla_select< diff --git a/include/subgroup/tile/impl/store_xe.hpp index baa663076..886a8b595 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -510,7 +510,6 @@ tile_store(tile_t& tile, payload_t& payload) { xetla_mask pred_y = channel_index < payload.height_in_elems; - xetla_store_global< store_dtype, payload_t::simd_exec_size, diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp index 1945184c9..060302448 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -605,7 +605,7 @@ struct bias_add_op_t< using bias_payload_t = mem_payload_t< mem_desc_bias_t, bias_tile_desc_t, - msg_type_v<bias_tile_desc_t, mem_space::global>, + msg_type_v<bias_tile_desc_t, mem_desc_bias_t>, arch_tag>; coord_t bias_coord(coord.x, 0); mem_desc_bias_t mem_desc_bias(args.base, args.shape, bias_coord); @@ -727,7 +727,7 @@ struct scale_v_offset_v_op_t< using scale_payload_t = mem_payload_t< scale_mem_desc_t, scale_tile_desc_t, - msg_type_v<scale_tile_desc_t, mem_space::global>, + msg_type_v<scale_tile_desc_t, scale_mem_desc_t>, arch_tag>; coord_t scale_coord(coord.x, 0); scale_mem_desc_t scale_mem_desc( @@ -743,7 +743,7 @@ struct scale_v_offset_v_op_t< using offset_payload_t = mem_payload_t< offset_mem_desc_t, offset_tile_desc_t, - msg_type_v<offset_tile_desc_t, mem_space::global>, + msg_type_v<offset_tile_desc_t, offset_mem_desc_t>, arch_tag>; coord_t offset_coord(coord.x, 0); offset_mem_desc_t offset_mem_desc( @@ -847,7 +847,7 @@ struct scale_v_op_t< using scale_payload_t = mem_payload_t< scale_mem_desc_t, scale_tile_desc_t, - msg_type_v<scale_tile_desc_t, mem_space::global>, + msg_type_v<scale_tile_desc_t, scale_mem_desc_t>, arch_tag>; coord_t scale_coord(coord.x, 0); scale_mem_desc_t scale_mem_desc( @@ -957,7 +957,7 @@ struct elemwise_reduce_op_t< using mat_in_payload_t = mem_payload_t< mem_desc_in_t, mat_in_tile_desc_t, - msg_type_v<mat_in_tile_desc_t, mem_space::global>, + msg_type_v<mat_in_tile_desc_t, mem_desc_in_t>, arch_tag>; using mat_in_tile_acc_t = tile_t; mem_desc_in_t mem_desc_in(args.base, args.shape, coord); @@ -998,7 +998,7 @@ struct elemwise_reduce_op_t< using mat_tail_in_payload_t = mem_payload_t< mem_desc_in_t, mat_tail_in_tile_desc_t, - msg_type_v<mat_tail_in_tile_desc_t, mem_space::global>, + msg_type_v<mat_tail_in_tile_desc_t, mem_desc_in_t>, arch_tag>; using mat_tail_in_tile_acc_t = tile_t; mat_tail_in_tile_t mat_tail_in; @@ -1085,7 +1085,7 @@ struct elemwise_reduce_op_stream_k_t< using mat_in_payload_t = mem_payload_t< mem_desc_in_t, mat_in_tile_desc_t, - msg_type_v<mat_in_tile_desc_t, mem_space::global>, + msg_type_v<mat_in_tile_desc_t, mem_desc_in_t>, arch_tag>; mem_desc_in_t mem_desc_in(args.base, args.shape, coord); mat_in_tile_t mat_in; @@ -1200,7 +1200,7 @@ struct dropout_op_t< using mask_in_payload_t = mem_payload_t< mem_desc_mask_t,
mask_in_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; mem_desc_mask_t mem_desc_mask(args.base, args.shape, coord); mask_in_tile_t mask_in; @@ -1298,7 +1298,7 @@ struct rng_dropout_op_t< using mask_out_payload_t = mem_payload_t< mem_desc_mask_t, mask_out_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; if (args.prob == 0) { return; @@ -1445,7 +1445,7 @@ struct linear_op_t< using mat_in_payload_t = mem_payload_t< mem_desc_in_t, mat_in_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; using mat_in_tile_acc_t = tile_t; mem_desc_in_t mem_desc_in(args.base, args.shape, coord); @@ -1489,7 +1489,7 @@ struct linear_op_t< using mat_tail_in_payload_t = mem_payload_t< mem_desc_in_t, mat_tail_in_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; using mat_tail_in_tile_acc_t = tile_t; diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index d9ed21e01..71982b8c4 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -47,7 +47,7 @@ class test_col_major_1 { static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::col_major; static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; + static constexpr gpu_arch arch = gpu_arch::XeHpc; using data_type_a = fp16; using data_type_b = int4x8; using data_type_c = fp16; diff --git a/third_party/pybind11 b/third_party/pybind11 new file mode 160000 index 000000000..dc9b39596 --- /dev/null +++ b/third_party/pybind11 @@ -0,0 +1 @@ +Subproject commit dc9b39596d986aeb061bd3debe52d30e2467dc48 From 8817f54e052aa3b40455f74d74e8348e25dc36a8 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Mon, 17 Jun 2024 15:27:17 +0000 Subject: [PATCH 28/34] fix XEHPC 2D load --- .../group/gemm/impl/int4_dequantize_xe.hpp | 2 +- include/subgroup/tile/impl/payload_xe.hpp | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 3d3ebf019..4ce0ab693 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -210,7 +210,7 @@ class gemm_t< typename mem_desc_b_t::dtype, mem_layout::row_major, mem_desc_b_t::space>>, - // subgroup::msg_type_v, + // subgroup::msg_type_v, arch_tag>; using matB_prefetch_payload_t = subgroup:: prefetch_payload_t; diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 6d0417f43..37b26cdcd 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -65,15 +65,14 @@ struct mem_payload_t< mem_payload_t; public: - static constexpr bool mem_transpose = - memory_layout == mem_layout::col_major && - !(std::is_same_v || std::is_same_v); + static constexpr bool mem_transpose = memory_layout == mem_layout::col_major; static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = mem_transpose ^ reg_transpose; + static constexpr bool trans = (mem_transpose ^ reg_transpose) && + !(std::is_same_v || std::is_same_v); static constexpr bool mem_transform = (sizeof(dtype) < 4) && !mem_transpose && (register_layout == reg_layout::vnni_tiled || @@ -1094,7 +1093,7 @@ struct mem_payload_t< static constexpr reg_layout register_layout = 
tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = mem_transpose ^ reg_transpose && + static constexpr bool trans = (mem_transpose ^ reg_transpose) && !(std::is_same_v || std::is_same_v); static constexpr bool mem_transform = (sizeof(dtype) < 4) && @@ -1657,7 +1656,7 @@ struct prefetch_payload_t< static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = mem_transpose ^ reg_transpose && + static constexpr bool trans = (mem_transpose ^ reg_transpose) && !(std::is_same_v || std::is_same_v); using prefetch_dtype = typename std::conditional< From 957c5a498c89ddc358c5fcb6c238b9e9ed8484fc Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Mon, 17 Jun 2024 16:06:11 +0000 Subject: [PATCH 29/34] fix compile for all UT --- .../multi_layer_perceptron.hpp | 28 +++++++++---------- include/kernel/gemm/impl/default_xe.hpp | 24 ++++++++-------- include/kernel/gemm/impl/stream_k_xe.hpp | 24 ++++++++-------- include/subgroup/tile/impl/load_xe.hpp | 1 - tests/integration/fmha/fmha_forward.hpp | 26 +++++++++++++---- tests/integration/fmha/fmha_utils.h | 10 +++++-- tests/integration/gemv/int4/main.cpp | 14 +++++----- .../integration/sg_dropout_op/kernel_func.hpp | 2 +- .../softmax/softmax_bwd_kernel.hpp | 17 ++++++----- .../softmax/softmax_fwd_kernel.hpp | 9 +++--- tests/unit/epilogue_tile_op/kernel_func.hpp | 6 ++-- tests/unit/tile_mma/kernel_func.hpp | 18 ++++++++---- tests/unit/tile_row_reduction/kernel_func.hpp | 8 ++++-- 13 files changed, 108 insertions(+), 79 deletions(-) diff --git a/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp b/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp index 67a454fe0..f9b4ca27f 100644 --- a/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp +++ b/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp @@ -409,20 +409,20 @@ class multi_layer_perceptron_t { args.matW_base.base, args.matW_ld); } } - if (epilogue_layer1_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_layer1_t::msg_type_c == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matB_base.base), - args.matrix_n_layer1, - args.matrix_m_layer1, - args.matB_ld); - } else { - implementable &= - kernel::general_1d::check_alignment( - args.matB_base.base, args.matB_ld); - } - } + // if (epilogue_layer1_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_layer1_t::msg_type_c == msg_type::block_2d) { + // implementable &= + // kernel::block_2d::check_tensor( + // (uint64_t)(args.matB_base.base), + // args.matrix_n_layer1, + // args.matrix_m_layer1, + // args.matB_ld); + // } else { + // implementable &= + // kernel::general_1d::check_alignment( + // args.matB_base.base, args.matB_ld); + // } + // } if (gemm_layer2_t::msg_type_a != msg_type::unaligned_2d) { if (gemm_layer2_t::msg_type_a == msg_type::block_2d) { implementable &= diff --git a/include/kernel/gemm/impl/default_xe.hpp b/include/kernel/gemm/impl/default_xe.hpp index cb6c5270b..d396d1055 100644 --- a/include/kernel/gemm/impl/default_xe.hpp +++ b/include/kernel/gemm/impl/default_xe.hpp @@ -275,18 +275,18 @@ class gemm_universal_t< args.matB_base.base, args.matB_ld); } } - if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= kernel::block_2d::check_tensor( - 
(uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m, - args.matC_ld); - } else { - implementable &= kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_t::msg_type_c == msg_type::block_2d) { + // implementable &= kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n, + // args.matrix_m, + // args.matC_ld); + // } else { + // implementable &= kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } return implementable; } diff --git a/include/kernel/gemm/impl/stream_k_xe.hpp b/include/kernel/gemm/impl/stream_k_xe.hpp index e281e53ae..01558d6ad 100644 --- a/include/kernel/gemm/impl/stream_k_xe.hpp +++ b/include/kernel/gemm/impl/stream_k_xe.hpp @@ -329,18 +329,18 @@ class gemm_universal_t< args.matB_base.base, args.matB_ld); } } - if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m, - args.matC_ld); - } else { - implementable &= kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_t::msg_type_c == msg_type::block_2d) { + // implementable &= kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n, + // args.matrix_m, + // args.matC_ld); + // } else { + // implementable &= kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } return implementable; } diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 786cfac06..08d6b8580 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -213,7 +213,6 @@ tile_load(tile_t& tile, payload_t& payload) { trans, mem_transform, arch_tag>(tdesc); - if constexpr (reg_transpose && trans) { reg_blk.xetla_select(ii * load_elems) .xetla_format>() = diff --git a/tests/integration/fmha/fmha_forward.hpp b/tests/integration/fmha/fmha_forward.hpp index 6231c6e58..623ea6d7f 100644 --- a/tests/integration/fmha/fmha_forward.hpp +++ b/tests/integration/fmha/fmha_forward.hpp @@ -620,8 +620,12 @@ class fmha_forward_t { mem_desc_Dp_Mask_t::layout, mem_desc_Dp_Mask_t::space>, dp_mask_tile_desc_t, - subgroup:: - msg_type_v, + subgroup::msg_type_v< + dp_mask_tile_desc_t, + mem_desc_t< + uint8_t, + mem_desc_Dp_Mask_t::layout, + mem_desc_Dp_Mask_t::space>>, gpu_arch::XeHpc>; load_payload_mask_t load_payload_mask(ctx.mem_desc_Dpij); subgroup::tile_load(mask_in, load_payload_mask); @@ -722,7 +726,12 @@ class fmha_forward_t { using matOi_store_t = subgroup::mem_payload_t< mem_desc_t, matOi_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v< + matOi_tile_desc_t, + mem_desc_t< + scalar_t, + mem_desc_Oi_t::layout, + mem_desc_Oi_t::space>>, arch_tag>; matOi_store_t matOi_store(mem_desc_Oi); subgroup::tile_store( @@ -762,12 +771,19 @@ class fmha_forward_t { using matQi_load_t = subgroup::mem_payload_t< mem_desc_t, matQi_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v< + matQi_tile_desc_t, + mem_desc_t>, arch_tag>; using matQi_store_t = subgroup::mem_payload_t< mem_desc_t, matQi_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v< + matQi_tile_desc_t, + mem_desc_t< + scalar_t, + mem_desc_Qi_L_t::layout, + mem_desc_Qi_L_t::space>>, 
arch_tag>; int32_t tile_offset_x = ctx.sg_idx * kSgHm; diff --git a/tests/integration/fmha/fmha_utils.h b/tests/integration/fmha/fmha_utils.h index fc1c11909..1aef9a5f4 100644 --- a/tests/integration/fmha/fmha_utils.h +++ b/tests/integration/fmha/fmha_utils.h @@ -156,7 +156,9 @@ struct group_row_reduce_t { using load_payload_t = subgroup::mem_payload_t< mem_desc_t, load_tile_desc, - subgroup::msg_type_v, + subgroup::msg_type_v< + load_tile_desc, + mem_desc_t>, arch_tag>; xetla_nbarrier_t nbarrier; @@ -243,10 +245,12 @@ struct bias_add_op_t { using bias_tile_desc_t = subgroup:: tile_desc_t; using bias_t = subgroup::tile_t; + using mem_desc_bias_t = + mem_desc_t; using bias_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bias_t, bias_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; coord_t bias_coord(coord.x, coord.y); mem_desc_bias_t mem_desc_bias(args.base, args.shape, bias_coord); diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 71982b8c4..7ff5987d2 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -38,7 +38,7 @@ class test_col_major_1 { static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 1; static constexpr size_t sg_k = 1024 / 1; - static constexpr size_t dequant_s = 131072; + static constexpr size_t dequant_s = 128; // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM; static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP; @@ -374,7 +374,7 @@ void dequantize_gemv_run(int iter) { for (unsigned i = 0; i < size_a; ++i) { A_h[i] = random_float(); #ifdef UT_DEBUG - A_h[i] = i; + A_h[i] = 1; // A_h[i] = layout_a == mem_layout::row_major // ? (i % matrix_k + i / matrix_k * 100) // : (i % matrix_m + i / matrix_m * 100); @@ -512,11 +512,11 @@ void dequantize_gemv_run(int iter) { epilogue_args); } cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); - if (!gemm_op_t::can_implement(gemm_arg)) { - std::cout << "The arguments cannot be supported, aborting ... " - << std::endl; - FAIL(); - } + // if (!gemm_op_t::can_implement(gemm_arg)) { + // std::cout << "The arguments cannot be supported, aborting ... 
" + // << std::endl; + // FAIL(); + // } size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; profiling_helper prof("dequantize_gemm", ops, "gflops"); diff --git a/tests/integration/sg_dropout_op/kernel_func.hpp b/tests/integration/sg_dropout_op/kernel_func.hpp index 36f4628a9..b566195d7 100644 --- a/tests/integration/sg_dropout_op/kernel_func.hpp +++ b/tests/integration/sg_dropout_op/kernel_func.hpp @@ -66,7 +66,7 @@ struct dropout_func_t { using mat_in_payload_t = subgroup::mem_payload_t< mem_desc_in_t, tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using tile_op_t = typename std::conditional< diff --git a/tests/integration/softmax/softmax_bwd_kernel.hpp b/tests/integration/softmax/softmax_bwd_kernel.hpp index fa6fddb56..e556c67c4 100644 --- a/tests/integration/softmax/softmax_bwd_kernel.hpp +++ b/tests/integration/softmax/softmax_bwd_kernel.hpp @@ -30,11 +30,6 @@ template < uint32_t sg_n, uint32_t sg_m> struct softmax_bwd_test_func { - using mem_desc_in_t = - mem_desc_t; - using mem_desc_out_t = - mem_desc_t; - using tile_shape = group::tile_shape_t; using work_group_t = typename tile_shape::work_group_t; static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; @@ -61,17 +56,21 @@ struct softmax_bwd_test_func { reg_layout::tiled>; using matAcc_t = subgroup::tile_t; using mat_in_t = subgroup::tile_t; + using mem_desc_in_t = + mem_desc_t; using mat_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using mat_out_t = subgroup::tile_t; + using mem_desc_out_t = + mem_desc_t; using mat_out_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_out_t, tile_desc_t, - (tile_size_y > 1) ? msg_type::block_2d : msg_type::block_1d, + subgroup::msg_type_v, gpu_arch::XeHpc>; using softmax_bwd_t = group::softmax_t< diff --git a/tests/integration/softmax/softmax_fwd_kernel.hpp b/tests/integration/softmax/softmax_fwd_kernel.hpp index 5237d47e4..abbd14696 100644 --- a/tests/integration/softmax/softmax_fwd_kernel.hpp +++ b/tests/integration/softmax/softmax_fwd_kernel.hpp @@ -60,16 +60,17 @@ struct softmax_fwd_test_func { reg_layout::tiled>; using matAcc_t = subgroup::tile_t; using mat_in_t = subgroup::tile_t; + using mat_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using mat_out_t = subgroup::tile_t; using mat_out_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, tile_desc_t, - (tile_size_y > 1) ? 
msg_type::block_2d : msg_type::block_1d, + subgroup::msg_type_v, gpu_arch::XeHpc>; using softmax_fwd_t = group::softmax_t< diff --git a/tests/unit/epilogue_tile_op/kernel_func.hpp b/tests/unit/epilogue_tile_op/kernel_func.hpp index 1d273857e..a13481044 100644 --- a/tests/unit/epilogue_tile_op/kernel_func.hpp +++ b/tests/unit/epilogue_tile_op/kernel_func.hpp @@ -41,7 +41,7 @@ struct tile_elemwise_op_func { using matA_payload_t = mem_payload_t< mem_desc_c_t, matA_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; using tile_shape = tile_shape_t; @@ -95,7 +95,7 @@ struct tile_elemwise_op_func< using matA_payload_t = mem_payload_t< mem_desc_b_t, matA_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; using tile_shape = tile_shape_t; using epilogue_policy = epilogue_policy_tile_op< @@ -150,7 +150,7 @@ struct tile_elemwise_op_func< using matA_payload_t = mem_payload_t< mem_desc_c_t, matA_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; using tile_shape = tile_shape_t; using epilogue_policy = epilogue_policy_tile_op< diff --git a/tests/unit/tile_mma/kernel_func.hpp b/tests/unit/tile_mma/kernel_func.hpp index 8822c8ca2..d8130c4a5 100644 --- a/tests/unit/tile_mma/kernel_func.hpp +++ b/tests/unit/tile_mma/kernel_func.hpp @@ -56,20 +56,26 @@ struct tile_mma_func { using matA_t = tile_t; using matB_t = tile_t; using matC_t = tile_t; + using mem_desc_a_t = + mem_desc_t; using matA_payload_t = mem_payload_t< - mem_desc_t, + mem_desc_a_t, matA_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; + using mem_desc_b_t = + mem_desc_t; using matB_payload_t = mem_payload_t< - mem_desc_t, + mem_desc_b_t, matB_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; + using mem_desc_c_t = + mem_desc_t; using matC_payload_t = mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_tile_desc_t, - msg_type::block_2d, + msg_type_v, gpu_arch::XeHpc>; using matAcc_t = tile_t>; diff --git a/tests/unit/tile_row_reduction/kernel_func.hpp b/tests/unit/tile_row_reduction/kernel_func.hpp index e4c1b5c3e..710143de5 100644 --- a/tests/unit/tile_row_reduction/kernel_func.hpp +++ b/tests/unit/tile_row_reduction/kernel_func.hpp @@ -48,12 +48,16 @@ struct tile_row_reduction_func { using matA_payload_t = mem_payload_t< mem_desc_t, matA_tile_desc_t, - msg_type_v, + msg_type_v< + matA_tile_desc_t, + mem_desc_t>, gpu_arch::XeHpc>; using matC_payload_t = mem_payload_t< mem_desc_t, matC_tile_desc_t, - msg_type_v, + msg_type_v< + matC_tile_desc_t, + mem_desc_t>, gpu_arch::XeHpc>; matA_t matA; matC_t matC; From 5456fc0bc689802475c73d2a6773bcc89fe6a239 Mon Sep 17 00:00:00 2001 From: "Ding, Yi1" Date: Wed, 19 Jun 2024 14:16:24 +0000 Subject: [PATCH 30/34] sync ipex 20240618 --- include/experimental/common/base_types.hpp | 49 +++++++++++++++++++ include/experimental/common/common.hpp | 23 +++++++++ .../gemm/impl/int4_dequantize_kslicing_xe.hpp | 3 +- .../kernel/layer_norm/layer_norm_fwd_xe.hpp | 2 +- include/kernel/gemm/impl/default_xe.hpp | 3 +- include/kernel/gemm/impl/stream_k_xe.hpp | 3 +- include/subgroup/tile/impl/payload_xe.hpp | 9 +--- 7 files changed, 81 insertions(+), 11 deletions(-) create mode 100644 include/experimental/common/base_types.hpp create mode 100644 include/experimental/common/common.hpp diff --git a/include/experimental/common/base_types.hpp b/include/experimental/common/base_types.hpp new file mode 100644 index 000000000..8755a8ce0 --- /dev/null +++ b/include/experimental/common/base_types.hpp @@ -0,0 +1,49 @@ 
+/*******************************************************************************
+ * Copyright (c) 2022-2023 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+/// @file
+/// C++ API
+
+#pragma once
+
+namespace gpu::xetla {
+
+/// @brief xetla 4-bit data packed as an 8-bit data type.
+/// Two 4-bit values are packed into one byte.
+struct int4x2 {
+  uint8_t data;
+
+  operator uint8_t() const {
+    return data;
+  }
+  int4x2(uint8_t val) {
+    data = val;
+  }
+};
+
+/// @brief Used to check if the type is a xetla internal data type.
+template <>
+struct is_internal_type<int4x2> {
+  static constexpr bool value = true;
+};
+
+/// @brief Set uint8_t as the native data type of int4x2.
+template <>
+struct native_type<int4x2> {
+  using type = uint8_t;
+};
+
+} // namespace gpu::xetla
diff --git a/include/experimental/common/common.hpp b/include/experimental/common/common.hpp
new file mode 100644
index 000000000..b1cc9d38d
--- /dev/null
+++ b/include/experimental/common/common.hpp
@@ -0,0 +1,23 @@
+/*******************************************************************************
+ * Copyright (c) 2022-2023 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/ + +/// @file +/// C++ API + +#pragma once + +#include +#include diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index 862cd30f8..b0a0a9f75 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -558,7 +558,8 @@ class gemm_universal_t< // args.matrix_m, // args.matC_ld); // } else { - // implementable &= kernel::general_1d::check_alignment( + // implementable &= kernel::general_1d::check_alignment( // args.matC_base.base, args.matC_ld); // } // } diff --git a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp index 1726ea0bc..ecd6bc25b 100644 --- a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp +++ b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp @@ -326,7 +326,7 @@ struct layer_norm_fwd_t< itr_count += 1; nbarrier.wait(); - xetla_vector mu_m2_vec = + xetla_vector mu_m2_vec = xetla_load_local(slm_load_base); xetla_vector mu_vec = mu_m2_vec.xetla_select(0); diff --git a/include/kernel/gemm/impl/default_xe.hpp b/include/kernel/gemm/impl/default_xe.hpp index d396d1055..644189db2 100644 --- a/include/kernel/gemm/impl/default_xe.hpp +++ b/include/kernel/gemm/impl/default_xe.hpp @@ -283,7 +283,8 @@ class gemm_universal_t< // args.matrix_m, // args.matC_ld); // } else { - // implementable &= kernel::general_1d::check_alignment( + // implementable &= kernel::general_1d::check_alignment( // args.matC_base.base, args.matC_ld); // } // } diff --git a/include/kernel/gemm/impl/stream_k_xe.hpp b/include/kernel/gemm/impl/stream_k_xe.hpp index 01558d6ad..0a23344bf 100644 --- a/include/kernel/gemm/impl/stream_k_xe.hpp +++ b/include/kernel/gemm/impl/stream_k_xe.hpp @@ -337,7 +337,8 @@ class gemm_universal_t< // args.matrix_m, // args.matC_ld); // } else { - // implementable &= kernel::general_1d::check_alignment( + // implementable &= kernel::general_1d::check_alignment( // args.matC_base.base, args.matC_ld); // } // } diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 37b26cdcd..14427b5f1 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -86,7 +86,7 @@ struct mem_payload_t< xetla_vector payloads; inline mem_payload_t(const this_payload_t& rhs) { - this->payload = rhs.payload; + this->payloads = rhs.payloads; } inline mem_payload_t(mem_desc_t& mem_desc) { @@ -159,7 +159,7 @@ struct mem_payload_t< // ~mem_payload_t(){} inline this_payload_t& operator=(const this_payload_t& rhs) { - this->payload = rhs.payload; + this->payloads = rhs.payloads; return *this; } @@ -1739,9 +1739,6 @@ struct prefetch_payload_t< this->width_in_elems = rhs.width_in_elems; this->height_in_elems = rhs.height_in_elems; - this->step_x = rhs.step_x; - this->step_y = rhs.step_y; - this->channel_offset = rhs.channel_offset; } @@ -1756,8 +1753,6 @@ struct prefetch_payload_t< this->width_in_elems = rhs.width_in_elems; this->height_in_elems = rhs.height_in_elems; - this->step_x = rhs.step_x; - this->step_y = rhs.step_y; this->channel_offset = rhs.channel_offset; return *this; } From 9185409cc3dadad6c2390a417069aacb6b2636d3 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Wed, 19 Jun 2024 08:13:04 +0000 Subject: [PATCH 31/34] opt PVC arch --- 
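This patch drops the LSC SLM block intrinsics in favor of the plain ESIMD block API, widens the NElts template parameter from uint8_t to int, and re-expresses XeHpc's 1D max load/store vector length as 512 bytes. A minimal sketch of the ESIMD calls the reworked wrappers lower to (sizes illustrative; assumes SLM was reserved beforehand, e.g. via slm_init):

    #include <sycl/sycl.hpp>
    #include <sycl/ext/intel/esimd.hpp>
    using namespace sycl::ext::intel::esimd;

    // Copy one 64-element fp16 row inside SLM; offsets are in bytes, and a
    // 128-byte transfer maps to a single flat block access.
    SYCL_EXTERNAL void copy_row(uint32_t src_off, uint32_t dst_off) SYCL_ESIMD_FUNCTION {
      simd<sycl::half, 64> row = slm_block_load<sycl::half, 64>(src_off);
      slm_block_store<sycl::half, 64>(dst_off, row);
    }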
examples/05_batch_gemm/batch_gemm.hpp | 28 ++--- .../multi_layer_perceptron.hpp | 28 ++--- include/common/core/arch_config.hpp | 4 +- include/common/core/memory.hpp | 116 +++++++++--------- include/group/cooperative_reduction.hpp | 2 +- include/group/epilogue/impl/default_xe.hpp | 29 +++-- include/group/epilogue/impl/tile_op_xe.hpp | 1 + include/kernel/gemm/impl/kslicing_xe.hpp | 25 ++-- include/subgroup/tile/common.hpp | 34 ++--- include/subgroup/tile/impl/load_xe.hpp | 57 ++++----- include/subgroup/tile/impl/store_xe.hpp | 42 +++---- tests/integration/gemm/fp16/common.hpp | 18 +-- tests/integration/gemm/fp16/kernel_func.hpp | 4 +- tests/integration/gemm/fp16/main.cpp | 2 +- tests/integration/gemv/int4/main.cpp | 4 +- 15 files changed, 187 insertions(+), 207 deletions(-) diff --git a/examples/05_batch_gemm/batch_gemm.hpp b/examples/05_batch_gemm/batch_gemm.hpp index 7c16208f4..528477d03 100644 --- a/examples/05_batch_gemm/batch_gemm.hpp +++ b/examples/05_batch_gemm/batch_gemm.hpp @@ -276,20 +276,20 @@ class batch_gemm_t { args.matB_base.base, args.matB_ld); } } - if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m * args.batch_size, - args.matC_ld); - } else { - implementable &= - kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_t::msg_type_c == msg_type::block_2d) { + // implementable &= + // kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n, + // args.matrix_m * args.batch_size, + // args.matC_ld); + // } else { + // implementable &= + // kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } return implementable; } diff --git a/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp b/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp index f9b4ca27f..ed2fc2489 100644 --- a/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp +++ b/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp @@ -451,20 +451,20 @@ class multi_layer_perceptron_t { args.matV_base.base, args.matV_ld); } } - if (epilogue_layer2_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_layer2_t::msg_type_c == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n_layer2, - args.matrix_m_layer2, - args.matC_ld); - } else { - implementable &= - kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_layer2_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_layer2_t::msg_type_c == msg_type::block_2d) { + // implementable &= + // kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n_layer2, + // args.matrix_m_layer2, + // args.matC_ld); + // } else { + // implementable &= + // kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } return implementable; } diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp index 3f3153090..7a1315ab9 100644 --- a/include/common/core/arch_config.hpp +++ b/include/common/core/arch_config.hpp @@ -100,8 +100,8 @@ struct load_store_attr_t { template <> struct load_store_attr_t { - static constexpr uint32_t max_load_vec_len = 64; - static constexpr uint32_t 
max_store_vec_len = 64; + static constexpr uint32_t max_load_vec_len = 512; + static constexpr uint32_t max_store_vec_len = 512; static constexpr uint32_t max_prefetch_vec_len = 64; }; diff --git a/include/common/core/memory.hpp b/include/common/core/memory.hpp index 1927ce97f..cd51832c6 100644 --- a/include/common/core/memory.hpp +++ b/include/common/core/memory.hpp @@ -256,7 +256,7 @@ constexpr __ESIMD_NS::atomic_op get_atomic_op(gpu::xetla::atomic_op ao) { /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, cache_hint L1H = cache_hint::cached, cache_hint L2H = cache_hint::cached, @@ -293,7 +293,7 @@ __XETLA_API void xetla_prefetch_global( /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, cache_hint L1H = cache_hint::cached, cache_hint L2H = cache_hint::cached> @@ -385,7 +385,7 @@ __XETLA_API xetla_vector xetla_load_global( /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, cache_hint L1H = cache_hint::none, cache_hint L2H = cache_hint::none, @@ -431,7 +431,7 @@ __XETLA_API xetla_vector xetla_load_global( /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, cache_hint L1H = cache_hint::none, cache_hint L2H = cache_hint::none, @@ -653,7 +653,7 @@ __XETLA_API void xetla_local_init() { /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, int N> __XETLA_API xetla_vector xetla_load_local( @@ -670,35 +670,31 @@ __XETLA_API xetla_vector xetla_load_local( xetla_cvt(offsets), pred); } -/// @brief SLM block load. (transposed gather with 1 channel). -/// Collects elements located at slm and returns them as a single \ref -/// xetla_vector object. -/// -/// Supported platforms: DG2, PVC -/// -/// VISA instruction: lsc_load.slm -/// -/// @tparam Ty is element type. -/// @tparam NElts is the number of elements to load per address (i.e. -/// vector_size per SIMD channel). -/// @tparam DS is the data size. -/// @param offset [in] is the zero-based offset for SLM buffer in bytes. -/// @return is a xetla_vector of type T and size NElts. -/// -template < - typename Ty, - uint8_t NElts = 1, - data_size DS = data_size::default_size> +/// Loads a contiguous block of SLM memory referenced by the given byte-offset +/// \p offset, then returns the loaded data as a simd object. +/// The generated code depends on the combination {T, N, Flags}. +/// Providing flags specifying the alignment of 16-bytes or more produces more +/// efficient code. If the alignment is smaller than 16-bytes, then less +/// efficient gather is generated. If the loaded vector is too long +/// for 1 flat-load GPU instruction, then a series of flat-loads and/or gathers +/// may be generated. +/// @tparam T Element type. +/// @tparam N Number of elements to load. +/// @tparam Flags The alignment specifier type tag. +/// @param byte_offset The byte-offset to load from. +/// @param Flags Specifies the alignment. +/// @return A vector of loaded elements. 
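A short usage sketch of the element-counted interface described above (sizes illustrative; fp16 is 2 bytes, so the second access lands 64 * sizeof(fp16) = 128 bytes further along):

    uint32_t off = 0; // byte offset into SLM
    xetla_vector<fp16, 64> row = xetla_load_local<fp16, 64>(off);
    xetla_store_local<fp16, 64>(off + 64 * sizeof(fp16), row);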
+/// +template __XETLA_API xetla_vector xetla_load_local(uint32_t offset) { using T = native_type_t; - DEBUG_INVOKE( - dbg_level::core, - core::general_1d::template check_restriction( - (uint64_t)offset)); + // DEBUG_INVOKE( + // dbg_level::core, + // core::general_1d::template + // check_restriction( + // (uint64_t)offset)); - return __ESIMD_ENS:: - lsc_slm_block_load( - offset); + return __ESIMD_NS::slm_block_load(offset); } /// @brief SLM scattered store. @@ -719,7 +715,7 @@ __XETLA_API xetla_vector xetla_load_local(uint32_t offset) { /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, int N> __XETLA_API void xetla_store_local( @@ -737,36 +733,38 @@ __XETLA_API void xetla_store_local( offsets, vals, pred); } -/// @brief SLM block store (transposed SLM scatter with 1 channel). -/// Scatters elements located to slm. -/// -/// Supported platforms: DG2, PVC -/// -/// VISA instruction: lsc_store.slm -/// -/// @tparam Ty is element type. -/// @tparam NElts is the number of elements to store per address (i.e. -/// vector_size per SIMD channel). -/// @tparam DS is the data size. -/// @param offset [in] is the zero-based offset for SLM buffer in bytes. -/// @param vals [in] is values to store. -/// -template < - typename Ty, - uint8_t NElts = 1, - data_size DS = data_size::default_size> +/// Stores elements of the vector \p vals to a contiguous block of SLM memory +/// at the given byte-offset \p offset. +/// The generated code depends on the combination {T, N, Flags}. +/// Providing flags specifying the alignment of 16-bytes or more produces more +/// efficient code. If the alignment is smaller than 16-bytes, then less +/// efficient scatter is generated. If the stored vector is too long +/// for 1 flat-store GPU instruction, then a series of flat-store and/or +/// scatters may be generated. +/// @tparam T Element type. +/// @tparam N Number of elements to store. +/// @tparam Flags The alignment specifier type tag. +/// @param offset The byte-offset to store at. +/// @param vals The vector to store. +/// @param Flags Specifies the alignment. +/// +template __XETLA_API void xetla_store_local( uint32_t offset, xetla_vector vals) { - using T = native_type_t; - DEBUG_INVOKE( - dbg_level::core, - core::general_1d::template check_restriction( - offset)); - - __ESIMD_ENS:: - lsc_slm_block_store( - offset, vals); + // using T = native_type_t; + // DEBUG_INVOKE( + // dbg_level::core, + // core::general_1d::template + // check_restriction( + // offset)); + + // __ESIMD_ENS:: + // lsc_slm_block_store( + // offset, vals); + // __ESIMD_NS::properties props{}; + + __ESIMD_NS::slm_block_store(offset, vals); } /// @brief SLM scattered atomic (0 src). diff --git a/include/group/cooperative_reduction.hpp b/include/group/cooperative_reduction.hpp index ecdc11c62..25682c0e2 100644 --- a/include/group/cooperative_reduction.hpp +++ b/include/group/cooperative_reduction.hpp @@ -95,7 +95,7 @@ class cooperative_reduce_t< static constexpr uint32_t block_size_x = gpu::xetla::subgroup::detail::gcd::value; static constexpr uint32_t block_size_y = - (tile_size_y > src_block_size_y) ? 
src_block_size_y : tile_size_y; + std::min(src_block_size_y, tile_size_y); using local_st_tile_desc_t = subgroup::tile_desc_t< sg_tile_n, diff --git a/include/group/epilogue/impl/default_xe.hpp b/include/group/epilogue/impl/default_xe.hpp index ab149396a..8d4061f49 100644 --- a/include/group/epilogue/impl/default_xe.hpp +++ b/include/group/epilogue/impl/default_xe.hpp @@ -70,9 +70,9 @@ class epilogue_t< } public: - static constexpr msg_type msg_type_c = - (mem_space_c == mem_space::global ? msg_type::block_2d - : msg_type::scatter); + // static constexpr msg_type msg_type_c = + // (mem_space_c == mem_space::global ? msg_type::block_2d + // : msg_type::scatter); /// @brief Default epilogue. /// 1) Convert dtype_acc to dtype_c 2) Overwrite to memory. @@ -94,6 +94,11 @@ class epilogue_t< [[maybe_unused]] uint32_t nbarrier_base = 0) { using mat_tile_desc = typename matAcc_t::tile_desc; using matC_t = subgroup::tile_t; + + static constexpr msg_type msg_type_c = + subgroup::msg_type_v; + using matC_payload_t = subgroup:: + mem_payload_t; using matC_payload_t = subgroup:: mem_payload_t; update_sg_tile_tdesc(g, mem_desc_c); @@ -143,9 +148,7 @@ class epilogue_t< using dtype_c = typename mem_desc_c_t::dtype; static constexpr mem_layout mem_layout_c = mem_desc_c_t::layout; static constexpr mem_space mem_space_c = mem_desc_c_t::space; - static constexpr msg_type msg_type_c = - (mem_space_c == mem_space::global ? msg_type::block_2d - : msg_type::scatter); + /// @brief Updates tile base descriptor based on the tid. __XETLA_API static void update_sg_tile_tdesc( work_group_t& g, @@ -165,8 +168,6 @@ class epilogue_t< } public: - static constexpr bool is_2d_block_c = (msg_type_c == msg_type::block_2d); - /// @brief Default epilogue. /// 1) Convert dtype_acc to dtype_c 2) Overwrite to memory. /// @tparam matAcc_t Is the type of the input tile. 
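Both epilogue specializations now ask subgroup::msg_type_v for the store path instead of hardcoding it. Reduced to its essentials, the selection looks like the sketch below (a simplification; the library's actual trait also weighs alignment and architecture support):

    enum class mem_space { global, local };
    enum class msg_type { block_2d, block_1d, scatter };

    // Mirror of the two hardcoded rules this patch removes: SLM goes through
    // scatter, multi-row global tiles through 2D block, single rows through 1D.
    template <mem_space space, int tile_size_y>
    constexpr msg_type pick_msg_type() {
      if constexpr (space == mem_space::local)
        return msg_type::scatter;
      else if constexpr (tile_size_y > 1)
        return msg_type::block_2d;
      else
        return msg_type::block_1d;
    }

    static_assert(pick_msg_type<mem_space::global, 8>() == msg_type::block_2d);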
@@ -190,11 +191,13 @@ class epilogue_t< [[maybe_unused]] uint32_t nbarrier_base = 0) { using mat_tile_desc = typename matAcc_t::tile_desc; using matC_t = subgroup::tile_t; - using matC_payload_t = subgroup::mem_payload_t< - mem_desc_t, - mat_tile_desc, - msg_type_c, - arch_tag>; + + // static constexpr msg_type msg_type_c = msg_type::block_2d; + static constexpr msg_type msg_type_c = + subgroup::msg_type_v; + + using matC_payload_t = subgroup:: + mem_payload_t; update_sg_tile_tdesc(g, mem_desc_c); diff --git a/include/group/epilogue/impl/tile_op_xe.hpp b/include/group/epilogue/impl/tile_op_xe.hpp index d9afac465..3afaec107 100644 --- a/include/group/epilogue/impl/tile_op_xe.hpp +++ b/include/group/epilogue/impl/tile_op_xe.hpp @@ -127,6 +127,7 @@ class epilogue_t< uint32_t nbarrier_base = 0) { using mat_tile_desc = typename matAcc_t::tile_desc; using matC_t = subgroup::tile_t; + // static constexpr msg_type msg_type_c = msg_type::block_2d; static constexpr msg_type msg_type_c = subgroup::msg_type_v; using matC_payload_t = subgroup:: diff --git a/include/kernel/gemm/impl/kslicing_xe.hpp b/include/kernel/gemm/impl/kslicing_xe.hpp index 7b74226e5..f4efca3a6 100644 --- a/include/kernel/gemm/impl/kslicing_xe.hpp +++ b/include/kernel/gemm/impl/kslicing_xe.hpp @@ -387,18 +387,19 @@ class gemm_universal_t< args.matB_base.base, args.matB_ld); } } - if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m, - args.matC_ld); - } else { - implementable &= kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_t::msg_type_c == msg_type::block_2d) { + // implementable &= kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n, + // args.matrix_m, + // args.matC_ld); + // } else { + // implementable &= kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } return implementable; } diff --git a/include/subgroup/tile/common.hpp b/include/subgroup/tile/common.hpp index 91c3c3c3c..0a30a1416 100644 --- a/include/subgroup/tile/common.hpp +++ b/include/subgroup/tile/common.hpp @@ -163,23 +163,21 @@ template < typename tile_t> __XETLA_API typename std::enable_if_t< base_len != 0 && payload_t::memory_space == mem_space::local> -process_1d_tail( - tile_t& tile, - payload_t& payload, - uint32_t offset, - uint32_t address_offset) { - using mem_dtype = typename payload_t::mem_dtype; +process_1d_tail(tile_t& tile, payload_t& payload, uint32_t offset) { + using dtype = typename payload_t::dtype; if constexpr (remained_len >= base_len) { - auto reg_sub = - tile.reg.xetla_select(offset); + uint32_t address_offset = offset * sizeof(dtype); + auto reg_sub = tile.reg.xetla_select(offset); if constexpr (flag == process_flag::load) { - reg_sub.xetla_format() = - xetla_load_local( - payload.base_address + payload.address + address_offset); + reg_sub.xetla_format() = xetla_load_local< + dtype, + base_len / sizeof(dtype), + data_size::default_size>( + payload.base_address + payload.address + address_offset); } else { - xetla_store_local( + xetla_store_local( payload.base_address + payload.address + address_offset, - reg_sub.xetla_format()); + reg_sub.xetla_format()); } process_1d_tail< remained_len - base_len, @@ -188,13 +186,7 @@ process_1d_tail( L1, L2, payload_t, - tile_t>( - tile, - 
payload, - offset + base_len * payload_t::scale_factor, - address_offset + - base_len * payload_t::scale_factor * - sizeof(typename tile_t::dtype)); + tile_t>(tile, payload, offset + base_len); } else { process_1d_tail< remained_len, @@ -203,7 +195,7 @@ process_1d_tail( L1, L2, payload_t, - tile_t>(tile, payload, offset, address_offset); + tile_t>(tile, payload, offset); } } diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 08d6b8580..c3c127bd5 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -857,46 +857,35 @@ __XETLA_API typename std::enable_if_t< detail::check_load_type::is_local_block_1d_xe> tile_load(tile_t& tile, payload_t& payload) { using dtype = typename tile_t::dtype; - using tile_desc = typename tile_t::tile_desc; - using load_dtype = typename payload_t::mem_dtype; + static constexpr uint32_t load_len = tile_t::tile_elems; + static constexpr gpu_arch arch_tag = payload_t::arch_tag; - constexpr uint32_t scale_factor = payload_t::scale_factor; - constexpr uint32_t load_len = tile_desc::tile_size_x / scale_factor; - constexpr gpu_arch arch_tag = payload_t::arch_tag; using load_store_attr = load_store_attr_t; - constexpr uint32_t max_load_vec_len = load_store_attr::max_load_vec_len; + static constexpr uint32_t max_load_vec_len = + load_store_attr::max_load_vec_len; + static constexpr uint32_t max_load_vec_elems = + max_load_vec_len / sizeof(dtype); - constexpr uint32_t load_iter_steps = load_len / max_load_vec_len; -#pragma unroll - for (uint32_t i = 0; i < tile_desc::tile_size_y; i++) { - uint32_t offset_y = i * tile_desc::tile_size_x; - uint32_t address_offset_y = i * payload.pitch_in_bytes; - if constexpr (load_len >= max_load_vec_len) { + static constexpr uint32_t load_iter_steps = load_len / max_load_vec_elems; + if constexpr (load_len >= max_load_vec_elems) { #pragma unroll - for (uint32_t j = 0; j < load_iter_steps; j++) { - uint32_t offset_x = j * max_load_vec_len * scale_factor; - auto reg_sub = - tile.reg.xetla_select( - offset_x + offset_y); - uint32_t address_offset = address_offset_y + offset_x * sizeof(dtype); - reg_sub.xetla_format() = xetla_load_local< - load_dtype, - max_load_vec_len, - data_size::default_size>( - payload.base_address + payload.address + address_offset); - } + for (uint32_t j = 0; j < load_iter_steps; j++) { + uint32_t offset_x = j * max_load_vec_elems; + auto reg_sub = tile.reg.xetla_select(offset_x); + uint32_t address_offset = offset_x * sizeof(dtype); + reg_sub.xetla_format() = + xetla_load_local( + payload.base_address + payload.address + address_offset); } - uint32_t tail_offset = - offset_y + load_iter_steps * max_load_vec_len * scale_factor; - uint32_t tail_address_offset = address_offset_y + - load_iter_steps * max_load_vec_len * scale_factor * sizeof(dtype); - detail::process_1d_tail< - load_len % max_load_vec_len, - (max_load_vec_len >> 1), - detail::process_flag::load, - L1, - L2>(tile, payload, tail_offset, tail_address_offset); } + constexpr uint32_t tail_len = load_len % max_load_vec_elems * sizeof(dtype); + uint32_t tail_offset = load_iter_steps * max_load_vec_len; + detail::process_1d_tail< + tail_len, + (max_load_vec_len >> 1), + detail::process_flag::load, + L1, + L2>(tile, payload, tail_offset); } } // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index 886a8b595..3c69d81a3 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ 
b/include/subgroup/tile/impl/store_xe.hpp @@ -293,9 +293,9 @@ tile_store(tile_t& tile, payload_t& payload) { if constexpr (store_len >= max_store_vec_elems) { #pragma unroll for (uint32_t i = 0; i < store_iter_steps; i++) { - uint32_t offset_x = i * max_store_vec_elems; - auto reg_sub = tile.reg.xetla_select(offset_x); - uint32_t address_offset = offset_x * sizeof(dtype); + uint32_t offset = i * max_store_vec_elems; + auto reg_sub = tile.reg.xetla_select(offset); + uint32_t address_offset = offset * sizeof(dtype); xetla_store_global( payload.base_ptr, @@ -997,40 +997,36 @@ __XETLA_API typename std::enable_if_t< tile_t::tile_size_y == 1 && tile_t::block_size_y == 1> tile_store(tile_t& tile, payload_t& payload) { using dtype = typename tile_t::dtype; - using tile_desc = typename payload_t::tile_desc; - using store_dtype = typename payload_t::mem_dtype; - - constexpr uint32_t scale_factor = payload_t::scale_factor; - static constexpr uint32_t store_len = tile_desc::tile_size_x / scale_factor; - + static constexpr uint32_t store_len = tile_t::tile_elems; static constexpr gpu_arch arch_tag = payload_t::arch_tag; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_store_vec_len = load_store_attr::max_store_vec_len; + static constexpr uint32_t max_store_vec_elems = + max_store_vec_len / sizeof(dtype); - static constexpr uint32_t store_iter_steps = store_len / max_store_vec_len; + static constexpr uint32_t store_iter_steps = store_len / max_store_vec_elems; - if constexpr (store_len >= max_store_vec_len) { + if constexpr (store_len >= max_store_vec_elems) { #pragma unroll - for (uint32_t j = 0; j < store_iter_steps; j++) { - uint32_t offset_x = j * max_store_vec_len * scale_factor; - auto reg_sub = tile.reg.xetla_select<64 * scale_factor, 1>(offset_x); + for (uint32_t i = 0; i < store_iter_steps; i++) { + uint32_t offset_x = i * max_store_vec_elems; + auto reg_sub = tile.reg.xetla_select(offset_x); uint32_t address_offset = offset_x * sizeof(dtype); - xetla_store_local( - payload.address + address_offset, - reg_sub.xetla_format()); + xetla_store_local( + payload.base_address + payload.address + address_offset, + reg_sub.xetla_format()); } } + constexpr uint32_t tail_len = store_len % max_store_vec_elems * sizeof(dtype); + uint32_t tail_offset = store_iter_steps * max_store_vec_len; detail::process_1d_tail< - store_len % max_store_vec_len, + tail_len, (max_store_vec_len >> 1), detail::process_flag::store, L1, - L2>( - tile, - payload, - store_iter_steps * max_store_vec_len * scale_factor, - store_iter_steps * max_store_vec_len * scale_factor * sizeof(dtype)); + L2>(tile, payload, tail_offset); } } // namespace gpu::xetla::subgroup diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index 7e7896f3e..7d11ab590 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -45,24 +45,24 @@ class TestBase { mem_layout_a_str + "_" + mem_layout_b_str; return name; } - static constexpr mma_engine engine = mma_engine::fpu; - static constexpr gpu_arch gpu_arch = gpu_arch::XeHpg; + static constexpr mma_engine engine = mma_engine::xmx; + static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc; }; class Test0 : public TestBase { public: static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 256; - static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 1; + static constexpr size_t mat_n = 1280; + static constexpr size_t mat_k = 8192; + static constexpr size_t wg_m = 8; static constexpr 
size_t wg_n = 32; - static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 32; + static constexpr size_t sg_m = 8; + static constexpr size_t sg_n = 16; static constexpr size_t sg_k = 16; static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; + static constexpr uint32_t local_kslicing = 8; static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = fp16; using data_type_b = fp16; using data_type_c = fp16; diff --git a/tests/integration/gemm/fp16/kernel_func.hpp b/tests/integration/gemm/fp16/kernel_func.hpp index 26a78ff91..191213794 100644 --- a/tests/integration/gemm/fp16/kernel_func.hpp +++ b/tests/integration/gemm/fp16/kernel_func.hpp @@ -41,8 +41,8 @@ template < gpu_arch gpu_arch> struct fp16_gemm_test_func { using tile_shape = tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 0 ; //8; - static constexpr uint32_t prefetch_distance = 0 ;//256 / (sg_k * sizeof(dtype_a)); + static constexpr uint32_t periodic_sync_interval = 1 ; //8; + static constexpr uint32_t prefetch_distance = 3 ;//256 / (sg_k * sizeof(dtype_a)); using compute_attr = typename std::conditional< (engine == mma_engine::fpu), diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp index 400d13276..eabc8915d 100644 --- a/tests/integration/gemm/fp16/main.cpp +++ b/tests/integration/gemm/fp16/main.cpp @@ -33,7 +33,7 @@ TYPED_TEST_P(fp16_gemm_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd); using tests = ::testing::Types< - Test4>; + Test0>; // Test1, // Test2, // Test3>; diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 7ff5987d2..55035b20d 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -32,12 +32,12 @@ class test_col_major_1 { // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; + static constexpr size_t mat_k = 11008; static constexpr size_t wg_m = 1; static constexpr size_t wg_n = 1; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 1; - static constexpr size_t sg_k = 1024 / 1; + static constexpr size_t sg_k = 256 / 1; static constexpr size_t dequant_s = 128; // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM; static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP; From 93c8ad1eca8787f3048398abab04b8dd6a650e46 Mon Sep 17 00:00:00 2001 From: "Swift.Sun" Date: Thu, 20 Jun 2024 04:46:16 +0800 Subject: [PATCH 32/34] fix group_qkv --- .../gemm/impl/int4_dequantize_kslicing_xe.hpp | 8 +-- include/group/epilogue/impl/default_xe.hpp | 7 +-- include/kernel/gemm/impl/kslicing_xe.hpp | 8 +-- include/subgroup/tile/impl/payload_xe.hpp | 29 +++++++++-- include/subgroup/tile/impl/store_xe.hpp | 51 ++++++++++--------- tests/integration/gemm/fp16/common.hpp | 2 +- tests/integration/gemv/int4/main.cpp | 6 +-- 7 files changed, 64 insertions(+), 47 deletions(-) diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index b0a0a9f75..79d37d517 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -598,12 +598,8 @@ class 
gemm_universal_t< int start_n = group_swizzle.template get_tile_idx<2>(item) * wg_tile_n; int start_k = 0; uint32_t wg_tile_k = args.matrix_k; - uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n - ? args.matrix_n - : (start_n + wg_tile_n); - uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m - ? args.matrix_m - : (start_m + wg_tile_m); + uint32_t boundary_n = std::min(start_n + wg_tile_n, args.matrix_n); + uint32_t boundary_m = std::min(start_m + wg_tile_m, args.matrix_m); uint32_t boundary_k = wg_tile_k; if constexpr (num_global_kslicing > 1) { wg_tile_k = (wg_tile_k + num_global_kslicing - 1) / num_global_kslicing; diff --git a/include/group/epilogue/impl/default_xe.hpp b/include/group/epilogue/impl/default_xe.hpp index 8d4061f49..d8dadc8a1 100644 --- a/include/group/epilogue/impl/default_xe.hpp +++ b/include/group/epilogue/impl/default_xe.hpp @@ -95,6 +95,7 @@ class epilogue_t< using mat_tile_desc = typename matAcc_t::tile_desc; using matC_t = subgroup::tile_t; + // static constexpr msg_type msg_type_c = msg_type::unaligned_2d; static constexpr msg_type msg_type_c = subgroup::msg_type_v; using matC_payload_t = subgroup:: @@ -192,9 +193,9 @@ class epilogue_t< using mat_tile_desc = typename matAcc_t::tile_desc; using matC_t = subgroup::tile_t; - // static constexpr msg_type msg_type_c = msg_type::block_2d; - static constexpr msg_type msg_type_c = - subgroup::msg_type_v; + static constexpr msg_type msg_type_c = msg_type::block_2d; + // static constexpr msg_type msg_type_c = + // subgroup::msg_type_v; using matC_payload_t = subgroup:: mem_payload_t; diff --git a/include/kernel/gemm/impl/kslicing_xe.hpp b/include/kernel/gemm/impl/kslicing_xe.hpp index f4efca3a6..64d15edb6 100644 --- a/include/kernel/gemm/impl/kslicing_xe.hpp +++ b/include/kernel/gemm/impl/kslicing_xe.hpp @@ -427,12 +427,8 @@ class gemm_universal_t< int start_n = group_swizzle.template get_tile_idx<2>(item) * wg_tile_n; int start_k = 0; uint32_t wg_tile_k = args.matrix_k; - uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n - ? args.matrix_n - : (start_n + wg_tile_n); - uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m - ? 
args.matrix_m - : (start_m + wg_tile_m); + uint32_t boundary_n = std::min(start_n + wg_tile_n, args.matrix_n); + uint32_t boundary_m = std::min(start_m + wg_tile_m, args.matrix_m); uint32_t boundary_k = wg_tile_k; if constexpr (num_global_kslicing > 1) { wg_tile_k = (wg_tile_k + num_global_kslicing - 1) / num_global_kslicing; diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 14427b5f1..b5c789bd5 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -435,9 +435,15 @@ struct mem_payload_t< uint64_t base_offset; dtype* base_ptr; uint32_t pitch_in_bytes; + uint32_t height_in_elems; + uint32_t width_in_elems; + uint32_t payload_bytes; inline mem_payload_t(mem_desc_t& mem_tdesc) { pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); + width_in_elems = mem_tdesc.shape.x; + height_in_elems = mem_tdesc.shape.y; + payload_bytes = width_in_elems * height_in_elems * sizeof(dtype); uint32_t offset_x = mem_tdesc.coord.x; uint32_t offset_y = mem_tdesc.coord.y; base_offset = mem_transpose @@ -448,14 +454,17 @@ struct mem_payload_t< inline mem_payload_t( dtype* p, - [[maybe_unused]] int surface_width, - [[maybe_unused]] int surface_height, + int surface_width, + int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y) { pitch_in_bytes = surface_pitch * sizeof(dtype); uint32_t offset_x = surface_offset_x; uint32_t offset_y = surface_offset_y; + width_in_elems = surface_width; + height_in_elems = surface_height; + payload_bytes = width_in_elems * height_in_elems * sizeof(dtype); base_offset = mem_transpose ? offset_x * pitch_in_bytes + offset_y * sizeof(dtype) : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); @@ -466,6 +475,9 @@ struct mem_payload_t< pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); uint32_t offset_x = mem_tdesc.coord.x; uint32_t offset_y = mem_tdesc.coord.y; + width_in_elems = mem_tdesc.shape.x; + height_in_elems = mem_tdesc.shape.y; + payload_bytes = width_in_elems * height_in_elems * sizeof(dtype); base_offset = mem_transpose ? offset_x * pitch_in_bytes + offset_y * sizeof(dtype) : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); @@ -474,14 +486,17 @@ struct mem_payload_t< __XETLA_API void init( dtype* p, - [[maybe_unused]] int surface_width, - [[maybe_unused]] int surface_height, + int surface_width, + int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y) { pitch_in_bytes = surface_pitch * sizeof(dtype); uint32_t offset_x = surface_offset_x; uint32_t offset_y = surface_offset_y; + width_in_elems = surface_width; + height_in_elems = surface_height; + payload_bytes = width_in_elems * height_in_elems * sizeof(dtype); base_offset = mem_transpose ? 
offset_x * pitch_in_bytes + offset_y * sizeof(dtype) : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); @@ -492,6 +507,9 @@ struct mem_payload_t< this->base_offset = rhs.base_offset; this->base_ptr = rhs.base_ptr; this->pitch_in_bytes = rhs.pitch_in_bytes; + this->width_in_elems = rhs.width_in_elems; + this->height_in_elems = rhs.height_in_elems; + this->payload_bytes = rhs.payload_bytes; } inline mem_payload_t() = default; @@ -499,6 +517,9 @@ struct mem_payload_t< this->base_offset = rhs.base_offset; this->base_ptr = rhs.base_ptr; this->pitch_in_bytes = rhs.pitch_in_bytes; + this->width_in_elems = rhs.width_in_elems; + this->height_in_elems = rhs.height_in_elems; + this->payload_bytes = rhs.payload_bytes; return *this; } diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index 3c69d81a3..701198a54 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -283,34 +283,38 @@ tile_store(tile_t& tile, payload_t& payload) { static constexpr uint32_t store_len = tile_t::tile_elems; static constexpr gpu_arch arch_tag = payload_t::arch_tag; - using load_store_attr = load_store_attr_t; - static constexpr uint32_t max_store_vec_len = - load_store_attr::max_store_vec_len; - static constexpr uint32_t max_store_vec_elems = - max_store_vec_len / sizeof(dtype); + if (payload.base_offset <= payload.payload_bytes) { + using load_store_attr = load_store_attr_t; + static constexpr uint32_t max_store_vec_len = + load_store_attr::max_store_vec_len; + static constexpr uint32_t max_store_vec_elems = + max_store_vec_len / sizeof(dtype); + static constexpr uint32_t store_iter_steps = + store_len / max_store_vec_elems; - static constexpr uint32_t store_iter_steps = store_len / max_store_vec_elems; - if constexpr (store_len >= max_store_vec_elems) { + if constexpr (store_len >= max_store_vec_elems) { #pragma unroll - for (uint32_t i = 0; i < store_iter_steps; i++) { - uint32_t offset = i * max_store_vec_elems; - auto reg_sub = tile.reg.xetla_select(offset); - uint32_t address_offset = offset * sizeof(dtype); + for (uint32_t i = 0; i < store_iter_steps; i++) { + uint32_t offset = i * max_store_vec_elems; + auto reg_sub = tile.reg.xetla_select(offset); + uint32_t address_offset = offset * sizeof(dtype); - xetla_store_global( - payload.base_ptr, - payload.base_offset + address_offset, - reg_sub.xetla_format()); + xetla_store_global( + payload.base_ptr, + payload.base_offset + address_offset, + reg_sub.xetla_format()); + } } + constexpr uint32_t tail_len = + store_len % max_store_vec_elems * sizeof(dtype); + uint32_t tail_offset = store_iter_steps * max_store_vec_len; + detail::process_1d_tail< + tail_len, + (max_store_vec_len >> 1), + detail::process_flag::store, + L1, + L2>(tile, payload, tail_offset); } - constexpr uint32_t tail_len = store_len % max_store_vec_elems * sizeof(dtype); - uint32_t tail_offset = store_iter_steps * max_store_vec_len; - detail::process_1d_tail< - tail_len, - (max_store_vec_len >> 1), - detail::process_flag::store, - L1, - L2>(tile, payload, tail_offset); } /// @brief Is the func storing data from register file to unaligned global @@ -348,7 +352,6 @@ tile_store( constexpr uint32_t num_channel_y = payload_t::num_channel_y; constexpr uint32_t store_elems = num_channel_y * payload_t::num_channel_x; constexpr uint32_t scale_factor = payload_t::scale_factor; - #pragma unroll for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y; i++) { diff --git 
a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index 7d11ab590..a675ab0b9 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -52,7 +52,7 @@ class TestBase { class Test0 : public TestBase { public: static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 1280; + static constexpr size_t mat_n = 64; static constexpr size_t mat_k = 8192; static constexpr size_t wg_m = 8; static constexpr size_t wg_n = 32; diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 55035b20d..3be90c381 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -32,12 +32,12 @@ class test_col_major_1 { // Extract the parameters required by different test cases static constexpr size_t mat_m = 1; static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 11008; + static constexpr size_t mat_k = 4096; static constexpr size_t wg_m = 1; static constexpr size_t wg_n = 1; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 1; - static constexpr size_t sg_k = 256 / 1; + static constexpr size_t sg_k = 1024 / 1; static constexpr size_t dequant_s = 128; // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM; static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP; @@ -47,7 +47,7 @@ class test_col_major_1 { static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::col_major; static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeHpc; + static constexpr gpu_arch arch = gpu_arch::XeLpg; using data_type_a = fp16; using data_type_b = int4x8; using data_type_c = fp16; From 8f0abc49ee2219ae6529098d6838b8c1de222915 Mon Sep 17 00:00:00 2001 From: "Sun, Jiwei1" Date: Thu, 20 Jun 2024 02:39:00 +0000 Subject: [PATCH 33/34] fix group_qkv --- include/group/epilogue/impl/default_xe.hpp | 2 -- include/subgroup/tile/impl/payload_xe.hpp | 20 +++++++++++++---- include/subgroup/tile/impl/store_xe.hpp | 4 ++-- tests/integration/gemv/int4/main.cpp | 26 +++++++++++----------- 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/include/group/epilogue/impl/default_xe.hpp b/include/group/epilogue/impl/default_xe.hpp index d8dadc8a1..b25397dcd 100644 --- a/include/group/epilogue/impl/default_xe.hpp +++ b/include/group/epilogue/impl/default_xe.hpp @@ -100,8 +100,6 @@ class epilogue_t< subgroup::msg_type_v; using matC_payload_t = subgroup:: mem_payload_t; - using matC_payload_t = subgroup:: - mem_payload_t; update_sg_tile_tdesc(g, mem_desc_c); matC_t matC; matC_payload_t matC_payload(mem_desc_c); diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index b5c789bd5..ca4b1751d 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -443,7 +443,10 @@ struct mem_payload_t< pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); width_in_elems = mem_tdesc.shape.x; height_in_elems = mem_tdesc.shape.y; - payload_bytes = width_in_elems * height_in_elems * sizeof(dtype); + payload_bytes = mem_transpose ? 
(mem_tdesc.shape.x - 1) * pitch_in_bytes +
+            mem_tdesc.shape.y * sizeof(dtype)
+        : (mem_tdesc.shape.y - 1) * pitch_in_bytes +
+            mem_tdesc.shape.x * sizeof(dtype);
     uint32_t offset_x = mem_tdesc.coord.x;
     uint32_t offset_y = mem_tdesc.coord.y;
     base_offset = mem_transpose
@@ -464,7 +467,10 @@ struct mem_payload_t<
     uint32_t offset_y = surface_offset_y;
     width_in_elems = surface_width;
     height_in_elems = surface_height;
-    payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
+    payload_bytes = mem_transpose ? (surface_width - 1) * pitch_in_bytes +
+            surface_height * sizeof(dtype)
+        : (surface_height - 1) * pitch_in_bytes +
+            surface_width * sizeof(dtype);
     base_offset = mem_transpose
         ? offset_x * pitch_in_bytes + offset_y * sizeof(dtype)
         : offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
@@ -477,7 +483,11 @@ struct mem_payload_t<
     uint32_t offset_y = mem_tdesc.coord.y;
     width_in_elems = mem_tdesc.shape.x;
     height_in_elems = mem_tdesc.shape.y;
-    payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
+    payload_bytes = mem_transpose ? (mem_tdesc.shape.x - 1) * pitch_in_bytes +
+            mem_tdesc.shape.y * sizeof(dtype)
+        : (mem_tdesc.shape.y - 1) * pitch_in_bytes +
+            mem_tdesc.shape.x * sizeof(dtype);
+
     base_offset = mem_transpose
         ? offset_x * pitch_in_bytes + offset_y * sizeof(dtype)
         : offset_y * pitch_in_bytes + offset_x * sizeof(dtype);
@@ -496,7 +506,9 @@ struct mem_payload_t<
     uint32_t offset_y = surface_offset_y;
     width_in_elems = surface_width;
     height_in_elems = surface_height;
-    payload_bytes = width_in_elems * height_in_elems * sizeof(dtype);
+    payload_bytes = mem_transpose
+        ? (surface_width - 1) * pitch_in_bytes + surface_height * sizeof(dtype)
+        : (surface_height - 1) * pitch_in_bytes + surface_width * sizeof(dtype);
     base_offset = mem_transpose
         ?
offset_x * pitch_in_bytes + offset_y * sizeof(dtype) : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index 701198a54..d5345b06b 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -282,8 +282,8 @@ tile_store(tile_t& tile, payload_t& payload) { using dtype = typename payload_t::dtype; static constexpr uint32_t store_len = tile_t::tile_elems; static constexpr gpu_arch arch_tag = payload_t::arch_tag; - - if (payload.base_offset <= payload.payload_bytes) { + if (payload.base_offset + store_len * sizeof(dtype) <= + payload.payload_bytes) { using load_store_attr = load_store_attr_t; static constexpr uint32_t max_store_vec_len = load_store_attr::max_store_vec_len; diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 3be90c381..622d6cf1b 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -47,7 +47,7 @@ class test_col_major_1 { static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::col_major; static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; + static constexpr gpu_arch arch = gpu_arch::XeHpc; using data_type_a = fp16; using data_type_b = int4x8; using data_type_c = fp16; @@ -108,12 +108,12 @@ int gemm_result_validate( bool result = buff_cmp::xetla_buff_cmp(data, other, "gemv validation"); #ifdef UT_DEBUG - for (uint32_t i = 0; i < m; i++) { - for (uint32_t j = 0; j < n; j++) { - std::cout << float(sycl::half(C[i * n + j])) << " "; - } - std::cout << std::endl; - } + // for (uint32_t i = 0; i < m; i++) { + // for (uint32_t j = 0; j < n; j++) { + // std::cout << float(sycl::half(C[i * n + j])) << " "; + // } + // std::cout << std::endl; + // } #endif std::cout << (!result ? "FAILED\n" : "PASSED\n"); return result ? 
0 : 1; @@ -185,12 +185,12 @@ std::vector dequantize_weight( } } #ifdef UT_DEBUG - for (uint32_t i = 0; i < matrix_n; i++) { - for (uint32_t j = 0; j < matrix_k; j++) { - std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; - } - std::cout << std::endl; - } + // for (uint32_t i = 0; i < matrix_n; i++) { + // for (uint32_t j = 0; j < matrix_k; j++) { + // std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; + // } + // std::cout << std::endl; + // } #endif return b_out; } From dc7d8123a8d17d00afc1c499614ab3f481e39943 Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Thu, 20 Jun 2024 16:58:09 +0800 Subject: [PATCH 34/34] init arch Xe2 enable 2d payload on Xe2 --- include/common/core/arch_config.hpp | 55 +++++++++++++++++++--- include/common/core/common_types.hpp | 2 +- include/group/gemm/compute_policy.hpp | 15 +++--- include/subgroup/tile/impl/payload_xe.hpp | 6 +-- include/subgroup/tile/impl/prefetch_xe.hpp | 2 +- 5 files changed, 60 insertions(+), 20 deletions(-) diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp index 7a1315ab9..8de2a70a3 100644 --- a/include/common/core/arch_config.hpp +++ b/include/common/core/arch_config.hpp @@ -31,9 +31,8 @@ struct load_store_attr_t { static constexpr bool has_hw_block_2d = false; }; -template <> -struct load_store_attr_t { - /// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490 +template +struct xe_plus_load_store_attr_t { static constexpr bool has_hw_block_2d = true; static constexpr uint32_t max_load_height_in_elem = 32; static constexpr uint32_t max_load_width_in_bytes = 64; @@ -55,10 +54,9 @@ struct load_store_attr_t { template struct client_load_store_attr_base_t { - /// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490 static constexpr bool has_hw_block_2d = false; - static constexpr uint32_t max_load_height_in_elem = 32; - static constexpr uint32_t max_load_width_in_bytes = 64; + static constexpr uint32_t max_load_height_in_elem = 0; + static constexpr uint32_t max_load_width_in_bytes = 0; static constexpr uint32_t max_trans_load_width_in_bytes = 32; static constexpr uint32_t max_vnni_load_width_in_elems = 16; static constexpr uint32_t min_vnni_load_height_in_bytes = 4; @@ -87,6 +85,18 @@ struct load_store_attr_t msg_type::block_2d, gpu_arch::XeLpg> {}; +template <> +struct load_store_attr_t + : public xe_plus_load_store_attr_base_t< + msg_type::block_2d, + gpu_arch::XeHpc> {}; + +template <> +struct load_store_attr_t + : public xe_plus_load_store_attr_base_t< + msg_type::block_2d, + gpu_arch::Xe2> {}; + template inline constexpr bool arch_has_2d_load_store = load_store_attr_t::has_hw_block_2d; @@ -105,6 +115,13 @@ struct load_store_attr_t { static constexpr uint32_t max_prefetch_vec_len = 64; }; +template <> +struct load_store_attr_t { + static constexpr uint32_t max_load_vec_len = 512; + static constexpr uint32_t max_store_vec_len = 512; + static constexpr uint32_t max_prefetch_vec_len = 64; +}; + struct dpas_attr_base_t { static constexpr bool has_xmx = true; static constexpr uint32_t systolic_depth = 8; @@ -129,6 +146,11 @@ struct dpas_attr_t : public dpas_attr_base_t { static constexpr uint32_t n_fixed_limit = 8; }; +template <> +struct dpas_attr_t : public dpas_attr_t { + static constexpr uint32_t systolic_depth = 4; +}; + template inline constexpr bool arch_has_xmx = dpas_attr_t::has_xmx; @@ -162,6 +184,10 @@ template <> struct register_bytes_t { static constexpr uint32_t reg_in_bytes = 32; }; +template <> +struct register_bytes_t 
{ + static constexpr uint32_t reg_in_bytes = 64; +}; template struct register_attr_t { @@ -236,10 +262,25 @@ struct arch_attr_t { using dpas_attr = dpas_attr_t; - static constexpr uint32_t max_wg_num = 64; + static constexpr uint32_t max_wg_num = 16; static constexpr uint32_t local_mem_size = 64 * 1024; }; +template <> +struct arch_attr_t { + template + using load_store_attr = load_store_attr_t; + + template + using register_attr = register_attr_t; + + using dpas_attr = dpas_attr_t; + + static constexpr uint32_t max_wg_num = 16; + static constexpr uint32_t local_mem_size = 128 * 1024; +}; + + /// @} xetla_core_arch_config } // namespace gpu::xetla diff --git a/include/common/core/common_types.hpp b/include/common/core/common_types.hpp index cbd174462..3c3ca4037 100644 --- a/include/common/core/common_types.hpp +++ b/include/common/core/common_types.hpp @@ -21,7 +21,7 @@ #include namespace gpu::xetla { -enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 }; +enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2, Xe2 = 3 }; enum class grf_mode : uint8_t { normal = 0, double_grf = 1 }; diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp index 0a0cd1c91..11033e907 100644 --- a/include/group/gemm/compute_policy.hpp +++ b/include/group/gemm/compute_policy.hpp @@ -118,16 +118,15 @@ struct compute_policy_default_fpu< static constexpr int sync_freq = perf_tuning_knob::sync_freq; static constexpr int k_stride = perf_tuning_knob::k_stride; - static constexpr uint32_t block_size_y_a = - arch_tag_ == gpu_arch::XeLpg ? 8 : 16; - static constexpr uint32_t block_bytes_x_a = 32; + static constexpr uint32_t block_size_y_a = 16; + using mma_attr = mma_attr_t; + static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_bytes_x_b = - arch_attr_t::template register_attr<>::reg_in_bytes; - static constexpr uint32_t block_size_x_b = - block_bytes_x_b / sizeof(dtype_mma_b); - static constexpr uint32_t block_size_y_b = block_size_x_a; + static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem; + static constexpr uint32_t block_bytes_y_b = mma_attr::mma_k_in_bytes; + static constexpr uint32_t block_size_y_b = + block_bytes_y_b / sizeof(dtype_mma_b); }; /// @} xetla_gemm diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index ca4b1751d..7a98b03ab 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -1101,7 +1101,7 @@ struct mem_payload_t< tile_desc_, msg_type::block_2d, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpg)>> { + std::enable_if_t<(arch_has_2d_load_store)>> { using dtype = native_type_t; using mem_desc_t = mem_desc_t; @@ -1652,7 +1652,7 @@ struct prefetch_payload_t< num_coop_sg_, arch_tag_, std::enable_if_t< - arch_tag_ <= gpu_arch::XeHpg && + arch_has_2d_load_store && ((block_size_y_ != 1 && mem_layout_ == mem_layout::row_major) || (block_size_x_ != 1 && mem_layout_ == mem_layout::col_major))>> { using dtype = native_type_t; @@ -2305,4 +2305,4 @@ struct prefetch_payload_t< __XETLA_API void update_tdesc([[maybe_unused]] int offset) {} }; -} // namespace gpu::xetla::subgroup \ No newline at end of file +} // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp index c821ab25d..d2eedfc77 100644 --- 
a/include/subgroup/tile/impl/prefetch_xe.hpp +++ b/include/subgroup/tile/impl/prefetch_xe.hpp @@ -195,4 +195,4 @@ __XETLA_API typename std::enable_if_t::is_local> tile_prefetch([[maybe_unused]] payload_t& payload) {} -} // namespace gpu::xetla::subgroup \ No newline at end of file +} // namespace gpu::xetla::subgroup
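With Xe2 in place, the FPU compute policy derives its block shapes from the per-arch mma_attr instead of fixed register constants. A condensed sketch of that arithmetic (the attribute values are illustrative assumptions, not hardware-queried):

    #include <cstdint>

    // Stand-in for mma_attr_t<arch>; 32-byte K and 16-element N are assumed.
    struct mma_attr_sketch {
      static constexpr uint32_t mma_k_in_bytes = 32;
      static constexpr uint32_t mma_n_in_elem = 16;
    };

    template <typename dtype_mma, typename mma_attr>
    struct fpu_block_shape {
      static constexpr uint32_t block_size_y_a = 16;
      // K extent of an A block, in elements of the mma dtype.
      static constexpr uint32_t block_size_x_a =
          mma_attr::mma_k_in_bytes / sizeof(dtype_mma);
      static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem;
      // B's K extent matches A's: block_bytes_y_b = mma_k_in_bytes.
      static constexpr uint32_t block_size_y_b =
          mma_attr::mma_k_in_bytes / sizeof(dtype_mma);
    };

    // fp16 is 2 bytes: a 32-byte mma_k yields 16-element K blocks.
    using fp16_stand_in = uint16_t;
    static_assert(
        fpu_block_shape<fp16_stand_in, mma_attr_sketch>::block_size_x_a == 16);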