diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c07ef0fc..e2213a9b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,7 +66,7 @@ else() set(XETLA_KERNEL_FLAGS ${XETLA_KERNEL_FLAGS} -Xs "${XETLA_OFFLINE_OPTIONS}") endif() -add_compile_options(-fsycl -fsycl-device-code-split=per_kernel) +add_compile_options(-fsycl -fsycl-device-code-split=per_kernel -ftemplate-backtrace-limit=0) add_compile_options(-Wall -Wextra -Werror) include(ProcessorCount) diff --git a/examples/05_batch_gemm/batch_gemm.hpp b/examples/05_batch_gemm/batch_gemm.hpp index 7c16208f4..528477d03 100644 --- a/examples/05_batch_gemm/batch_gemm.hpp +++ b/examples/05_batch_gemm/batch_gemm.hpp @@ -276,20 +276,20 @@ class batch_gemm_t { args.matB_base.base, args.matB_ld); } } - if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m * args.batch_size, - args.matC_ld); - } else { - implementable &= - kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_t::msg_type_c == msg_type::block_2d) { + // implementable &= + // kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n, + // args.matrix_m * args.batch_size, + // args.matC_ld); + // } else { + // implementable &= + // kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } return implementable; } diff --git a/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp b/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp index 67a454fe0..ed2fc2489 100644 --- a/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp +++ b/examples/07_multi_layer_perceptron/multi_layer_perceptron.hpp @@ -409,20 +409,20 @@ class multi_layer_perceptron_t { args.matW_base.base, args.matW_ld); } } - if (epilogue_layer1_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_layer1_t::msg_type_c == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matB_base.base), - args.matrix_n_layer1, - args.matrix_m_layer1, - args.matB_ld); - } else { - implementable &= - kernel::general_1d::check_alignment( - args.matB_base.base, args.matB_ld); - } - } + // if (epilogue_layer1_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_layer1_t::msg_type_c == msg_type::block_2d) { + // implementable &= + // kernel::block_2d::check_tensor( + // (uint64_t)(args.matB_base.base), + // args.matrix_n_layer1, + // args.matrix_m_layer1, + // args.matB_ld); + // } else { + // implementable &= + // kernel::general_1d::check_alignment( + // args.matB_base.base, args.matB_ld); + // } + // } if (gemm_layer2_t::msg_type_a != msg_type::unaligned_2d) { if (gemm_layer2_t::msg_type_a == msg_type::block_2d) { implementable &= @@ -451,20 +451,20 @@ class multi_layer_perceptron_t { args.matV_base.base, args.matV_ld); } } - if (epilogue_layer2_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_layer2_t::msg_type_c == msg_type::block_2d) { - implementable &= - kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n_layer2, - args.matrix_m_layer2, - args.matC_ld); - } else { - implementable &= - kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_layer2_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_layer2_t::msg_type_c == 
msg_type::block_2d) { + // implementable &= + // kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n_layer2, + // args.matrix_m_layer2, + // args.matC_ld); + // } else { + // implementable &= + // kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } return implementable; } diff --git a/examples/08_scaled_dot_product_attention/softmax.hpp b/examples/08_scaled_dot_product_attention/softmax.hpp index a14386efe..40f986a70 100644 --- a/examples/08_scaled_dot_product_attention/softmax.hpp +++ b/examples/08_scaled_dot_product_attention/softmax.hpp @@ -60,18 +60,21 @@ struct xetla_softmax_fwd_t { using softmax_tile_desc_t = subgroup:: tile_desc_t; using softmax_load_t = subgroup::tile_t; + using mem_desc_in_t = mem_desc_t; using softmax_load_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, softmax_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; // this tile will store the softmax result to global memory using softmax_store_t = subgroup::tile_t; + using mem_desc_out_t = + mem_desc_t; using softmax_store_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_out_t, softmax_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; struct arguments_t { diff --git a/examples/09_gate_recurrent_unit/kernel_func.hpp b/examples/09_gate_recurrent_unit/kernel_func.hpp index b2e3994c3..0dd76188a 100644 --- a/examples/09_gate_recurrent_unit/kernel_func.hpp +++ b/examples/09_gate_recurrent_unit/kernel_func.hpp @@ -156,7 +156,7 @@ struct gru_layer { using mat_hidden_payload_t = mem_payload_t< mem_desc_a_t, matC_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; using matC_payload_t = mem_payload_t< mem_desc_c_t, diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp index 8c7c56463..8de2a70a3 100644 --- a/include/common/core/arch_config.hpp +++ b/include/common/core/arch_config.hpp @@ -31,9 +31,8 @@ struct load_store_attr_t { static constexpr bool has_hw_block_2d = false; }; -template <> -struct load_store_attr_t { - /// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490 +template +struct xe_plus_load_store_attr_t { static constexpr bool has_hw_block_2d = true; static constexpr uint32_t max_load_height_in_elem = 32; static constexpr uint32_t max_load_width_in_bytes = 64; @@ -55,10 +54,9 @@ struct load_store_attr_t { template struct client_load_store_attr_base_t { - /// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490 static constexpr bool has_hw_block_2d = false; - static constexpr uint32_t max_load_height_in_elem = 32; - static constexpr uint32_t max_load_width_in_bytes = 64; + static constexpr uint32_t max_load_height_in_elem = 0; + static constexpr uint32_t max_load_width_in_bytes = 0; static constexpr uint32_t max_trans_load_width_in_bytes = 32; static constexpr uint32_t max_vnni_load_width_in_elems = 16; static constexpr uint32_t min_vnni_load_height_in_bytes = 4; @@ -87,21 +85,40 @@ struct load_store_attr_t msg_type::block_2d, gpu_arch::XeLpg> {}; +template <> +struct load_store_attr_t + : public xe_plus_load_store_attr_base_t< + msg_type::block_2d, + gpu_arch::XeHpc> {}; + +template <> +struct load_store_attr_t + : public xe_plus_load_store_attr_base_t< + msg_type::block_2d, + gpu_arch::Xe2> {}; + template inline constexpr bool arch_has_2d_load_store = load_store_attr_t::has_hw_block_2d; template struct load_store_attr_t { - static constexpr uint32_t max_load_vec_len = 32; - static 
constexpr uint32_t max_store_vec_len = 32; + static constexpr uint32_t max_load_vec_len = 256; + static constexpr uint32_t max_store_vec_len = 256; static constexpr uint32_t max_prefetch_vec_len = 32; }; template <> struct load_store_attr_t { - static constexpr uint32_t max_load_vec_len = 64; - static constexpr uint32_t max_store_vec_len = 64; + static constexpr uint32_t max_load_vec_len = 512; + static constexpr uint32_t max_store_vec_len = 512; + static constexpr uint32_t max_prefetch_vec_len = 64; +}; + +template <> +struct load_store_attr_t { + static constexpr uint32_t max_load_vec_len = 512; + static constexpr uint32_t max_store_vec_len = 512; static constexpr uint32_t max_prefetch_vec_len = 64; }; @@ -129,6 +146,11 @@ struct dpas_attr_t : public dpas_attr_base_t { static constexpr uint32_t n_fixed_limit = 8; }; +template <> +struct dpas_attr_t : public dpas_attr_t { + static constexpr uint32_t systolic_depth = 4; +}; + template inline constexpr bool arch_has_xmx = dpas_attr_t::has_xmx; @@ -162,6 +184,10 @@ template <> struct register_bytes_t { static constexpr uint32_t reg_in_bytes = 32; }; +template <> +struct register_bytes_t { + static constexpr uint32_t reg_in_bytes = 64; +}; template struct register_attr_t { @@ -236,10 +262,25 @@ struct arch_attr_t { using dpas_attr = dpas_attr_t; - static constexpr uint32_t max_wg_num = 64; + static constexpr uint32_t max_wg_num = 16; static constexpr uint32_t local_mem_size = 64 * 1024; }; +template <> +struct arch_attr_t { + template + using load_store_attr = load_store_attr_t; + + template + using register_attr = register_attr_t; + + using dpas_attr = dpas_attr_t; + + static constexpr uint32_t max_wg_num = 16; + static constexpr uint32_t local_mem_size = 128 * 1024; +}; + + /// @} xetla_core_arch_config } // namespace gpu::xetla diff --git a/include/common/core/base_consts.hpp b/include/common/core/base_consts.hpp index 2f8bbd489..67bcd7e92 100644 --- a/include/common/core/base_consts.hpp +++ b/include/common/core/base_consts.hpp @@ -23,9 +23,8 @@ namespace gpu::xetla { -/// @addtogroup xetla_core_base_types +/// @addtogroup xetla_core_base_consts /// @{ - -/// @} xetla_core_base_types +/// @} xetla_core_base_consts } // namespace gpu::xetla diff --git a/include/common/core/base_types.hpp b/include/common/core/base_types.hpp index 33ed26b74..ee2a77a81 100644 --- a/include/common/core/base_types.hpp +++ b/include/common/core/base_types.hpp @@ -55,6 +55,32 @@ using fp16 = sycl::half; /// using tf32 = sycl::ext::intel::experimental::esimd::tfloat32; +/// @brief xetla 4bits data packed as 8bits data type. +/// 2 4bit data pack to one byte +struct int4x2 { + uint8_t data; + + operator uint8_t() const { + return data; + } + int4x2(uint8_t val) { + data = val; + } +}; + +/// @brief xetla 4bits data packed as 32bits data type. +/// 8 4bit data pack to 4 bytes +struct int4x8 { + uint32_t data; + + operator uint32_t() const { + return data; + } + int4x8(uint32_t val) { + data = val; + } +}; + /// @brief mx_fp4(E2M1) data packed as 8bits data type. struct mx_fp4 { uint8_t data; @@ -89,6 +115,8 @@ template struct is_internal_type { static constexpr bool value = std::is_same, bf16>::value || std::is_same, tf32>::value || + std::is_same, int4x2>::value || + std::is_same, int4x8>::value || std::is_same, mx_fp4>::value; }; template @@ -137,6 +165,18 @@ struct native_type { using type = uint8_t; }; +/// @brief Set uint8_t as the native data type of int4x2. 
+template <> +struct native_type { + using type = uint8_t; +}; + +/// @brief Set uint8_t as the native data type of int4x8. +template <> +struct native_type { + using type = uint32_t; +}; + /// @brief Return the native data type of T template using native_type_t = typename native_type::type; diff --git a/include/common/core/common_types.hpp b/include/common/core/common_types.hpp index 2a23a9e5e..3c3ca4037 100644 --- a/include/common/core/common_types.hpp +++ b/include/common/core/common_types.hpp @@ -21,9 +21,18 @@ #include namespace gpu::xetla { -enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2 }; +enum class gpu_arch : uint8_t { XeLpg = 0, XeHpg = 1, XeHpc = 2, Xe2 = 3 }; enum class grf_mode : uint8_t { normal = 0, double_grf = 1 }; enum class mem_layout : uint8_t { row_major = 0, col_major = 1 }; + +enum class quant_mode : uint8_t { S4_ASYM = 0, S4_FULLRANGE_NO_ZP = 1 }; + +struct quant_info { + quant_mode quant_mode; + uint32_t dequant_s; + mem_layout weight_mem_layout; +}; + } // namespace gpu::xetla diff --git a/include/common/core/memory.hpp b/include/common/core/memory.hpp index 0bc360d6e..cd51832c6 100644 --- a/include/common/core/memory.hpp +++ b/include/common/core/memory.hpp @@ -256,7 +256,7 @@ constexpr __ESIMD_NS::atomic_op get_atomic_op(gpu::xetla::atomic_op ao) { /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, cache_hint L1H = cache_hint::cached, cache_hint L2H = cache_hint::cached, @@ -293,7 +293,7 @@ __XETLA_API void xetla_prefetch_global( /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, cache_hint L1H = cache_hint::cached, cache_hint L2H = cache_hint::cached> @@ -355,7 +355,12 @@ __XETLA_API xetla_vector xetla_load_global( __ESIMD_NS::cache_hint_L1, __ESIMD_NS::cache_hint_L2, __ESIMD_NS::alignment}; - return __ESIMD_NS::block_load(ptr, byte_offset, props); + if constexpr (sizeof(T) * N < sizeof(uint32_t)) { + xetla_vector offsets(byte_offset, sizeof(T)); + return __ESIMD_NS::gather(ptr, offsets); + } else { + return __ESIMD_NS::block_load(ptr, byte_offset, props); + } } /// @brief Stateless scattered load. @@ -380,7 +385,7 @@ __XETLA_API xetla_vector xetla_load_global( /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, cache_hint L1H = cache_hint::none, cache_hint L2H = cache_hint::none, @@ -426,7 +431,7 @@ __XETLA_API xetla_vector xetla_load_global( /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, cache_hint L1H = cache_hint::none, cache_hint L2H = cache_hint::none, @@ -495,7 +500,13 @@ __XETLA_API void xetla_store_global( __ESIMD_NS::cache_hint_L1, __ESIMD_NS::cache_hint_L2, __ESIMD_NS::alignment}; - __ESIMD_NS::block_store(ptr, byte_offset, vals, props); + + if constexpr (sizeof(T) * N < sizeof(uint32_t)) { + xetla_vector offsets(byte_offset, sizeof(T)); + return __ESIMD_NS::scatter(ptr, offsets, vals); + } else { + __ESIMD_NS::block_store(ptr, byte_offset, vals, props); + } } /// @brief Stateless scattered atomic (0 src). @@ -642,7 +653,7 @@ __XETLA_API void xetla_local_init() { /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, int N> __XETLA_API xetla_vector xetla_load_local( @@ -659,35 +670,31 @@ __XETLA_API xetla_vector xetla_load_local( xetla_cvt(offsets), pred); } -/// @brief SLM block load. (transposed gather with 1 channel). 
-/// Collects elements located at slm and returns them as a single \ref -/// xetla_vector object. -/// -/// Supported platforms: DG2, PVC -/// -/// VISA instruction: lsc_load.slm -/// -/// @tparam Ty is element type. -/// @tparam NElts is the number of elements to load per address (i.e. -/// vector_size per SIMD channel). -/// @tparam DS is the data size. -/// @param offset [in] is the zero-based offset for SLM buffer in bytes. -/// @return is a xetla_vector of type T and size NElts. -/// -template < - typename Ty, - uint8_t NElts = 1, - data_size DS = data_size::default_size> +/// Loads a contiguous block of SLM memory referenced by the given byte-offset +/// \p offset, then returns the loaded data as a simd object. +/// The generated code depends on the combination {T, N, Flags}. +/// Providing flags specifying the alignment of 16-bytes or more produces more +/// efficient code. If the alignment is smaller than 16-bytes, then less +/// efficient gather is generated. If the loaded vector is too long +/// for 1 flat-load GPU instruction, then a series of flat-loads and/or gathers +/// may be generated. +/// @tparam T Element type. +/// @tparam N Number of elements to load. +/// @tparam Flags The alignment specifier type tag. +/// @param byte_offset The byte-offset to load from. +/// @param Flags Specifies the alignment. +/// @return A vector of loaded elements. +/// +template __XETLA_API xetla_vector xetla_load_local(uint32_t offset) { using T = native_type_t; - DEBUG_INVOKE( - dbg_level::core, - core::general_1d::template check_restriction( - (uint64_t)offset)); + // DEBUG_INVOKE( + // dbg_level::core, + // core::general_1d::template + // check_restriction( + // (uint64_t)offset)); - return __ESIMD_ENS:: - lsc_slm_block_load( - offset); + return __ESIMD_NS::slm_block_load(offset); } /// @brief SLM scattered store. @@ -708,7 +715,7 @@ __XETLA_API xetla_vector xetla_load_local(uint32_t offset) { /// template < typename Ty, - uint8_t NElts = 1, + int NElts = 1, data_size DS = data_size::default_size, int N> __XETLA_API void xetla_store_local( @@ -726,36 +733,38 @@ __XETLA_API void xetla_store_local( offsets, vals, pred); } -/// @brief SLM block store (transposed SLM scatter with 1 channel). -/// Scatters elements located to slm. -/// -/// Supported platforms: DG2, PVC -/// -/// VISA instruction: lsc_store.slm -/// -/// @tparam Ty is element type. -/// @tparam NElts is the number of elements to store per address (i.e. -/// vector_size per SIMD channel). -/// @tparam DS is the data size. -/// @param offset [in] is the zero-based offset for SLM buffer in bytes. -/// @param vals [in] is values to store. -/// -template < - typename Ty, - uint8_t NElts = 1, - data_size DS = data_size::default_size> +/// Stores elements of the vector \p vals to a contiguous block of SLM memory +/// at the given byte-offset \p offset. +/// The generated code depends on the combination {T, N, Flags}. +/// Providing flags specifying the alignment of 16-bytes or more produces more +/// efficient code. If the alignment is smaller than 16-bytes, then less +/// efficient scatter is generated. If the stored vector is too long +/// for 1 flat-store GPU instruction, then a series of flat-store and/or +/// scatters may be generated. +/// @tparam T Element type. +/// @tparam N Number of elements to store. +/// @tparam Flags The alignment specifier type tag. +/// @param offset The byte-offset to store at. +/// @param vals The vector to store. +/// @param Flags Specifies the alignment. 
+/// +template __XETLA_API void xetla_store_local( uint32_t offset, xetla_vector vals) { - using T = native_type_t; - DEBUG_INVOKE( - dbg_level::core, - core::general_1d::template check_restriction( - offset)); - - __ESIMD_ENS:: - lsc_slm_block_store( - offset, vals); + // using T = native_type_t; + // DEBUG_INVOKE( + // dbg_level::core, + // core::general_1d::template + // check_restriction( + // offset)); + + // __ESIMD_ENS:: + // lsc_slm_block_store( + // offset, vals); + // __ESIMD_NS::properties props{}; + + __ESIMD_NS::slm_block_store(offset, vals); } /// @brief SLM scattered atomic (0 src). diff --git a/include/experimental/experimental.hpp b/include/experimental/experimental.hpp index 206768c7d..87f9883e0 100644 --- a/include/experimental/experimental.hpp +++ b/include/experimental/experimental.hpp @@ -19,7 +19,6 @@ #pragma once -#include #include #include #include \ No newline at end of file diff --git a/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp b/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp index 611a09b8c..142c55e18 100644 --- a/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp +++ b/include/experimental/group/fused_op/row_reduction_fused_op_xe.hpp @@ -139,10 +139,12 @@ struct row_reduction_fused_op_t< block_size_y, reg_layout::tiled>; using dgelu_w_in_t = subgroup::tile_t; + using mem_desc_in_t = + mem_desc_t; using dgelu_w_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, dgelu_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using dgelu_x_out_t = subgroup::tile_t; using dgelu_x_out_payload_t = subgroup::mem_payload_t< @@ -234,17 +236,21 @@ struct row_reduction_fused_op_t< block_size_y, reg_layout::tiled>; using mask_in_t = subgroup::tile_t; + using mem_desc_mask_t = + mem_desc_t; using mask_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_mask_t, reduction_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using dropout_bwd_out_t = subgroup::tile_t; + using mem_desc_out_t = + mem_desc_t; using dropout_bwd_out_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_out_t, reduction_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; if (dropout_prob != 0) { mask_in_t mask_in; diff --git a/include/experimental/group/gemm/common.hpp b/include/experimental/group/gemm/common.hpp index 1c48e9e7c..73f8daff4 100644 --- a/include/experimental/group/gemm/common.hpp +++ b/include/experimental/group/gemm/common.hpp @@ -19,7 +19,6 @@ #pragma once -#include #include #include #include diff --git a/include/experimental/group/gemm/compute_policy.hpp b/include/experimental/group/gemm/compute_policy.hpp index 8d7ffe33d..1de706ccb 100644 --- a/include/experimental/group/gemm/compute_policy.hpp +++ b/include/experimental/group/gemm/compute_policy.hpp @@ -22,10 +22,7 @@ #include namespace gpu::xetla::group { - -enum quant_mode : uint8_t { S4_ASYM, S4_FULLRANGE_NO_ZP }; - -/// @brief Compute policy for unaligned shape and xmx engine. +/// @brief Compute policy for int4 dequant gemm. /// @tparam compute_attr_ Is compute-related attributes. /// @tparam perf_tuning_knob_ Is performance-related knobs. /// @tparam arch_tag_ Is the HW architecture. 
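For orientation, a minimal sketch (not part of this patch) of how the new quant_info aggregate introduced in common_types.hpp above might be instantiated and handed to this policy once the hunk below folds the former quant_type_/dequant_s_ parameters into a single non-type template parameter; my_compute_attr and my_knobs are placeholder names for the usual compute-attribute and perf-tuning-knob types.
// Hypothetical usage sketch; only quant_info, quant_mode, mem_layout, int4x2,
// fp16 and compute_policy_int4_dequantize come from this patch.
using namespace gpu::xetla;
constexpr quant_info q_info{
    quant_mode::S4_FULLRANGE_NO_ZP, // no zero point, full-range int4
    /*dequant_s*/ 128,              // dequantization group size along K
    mem_layout::row_major};         // weight (matB) memory layout
using int4_policy = group::compute_policy_int4_dequantize<
    my_compute_attr,  // placeholder compute_attr_
    my_knobs,         // placeholder perf_tuning_knob_
    fp16,             // dtype_scale_
    int4x2,           // dtype_zero_pt_
    q_info,           // replaces the former quant_type_ + dequant_s_ pair
    mma_engine::xmx,
    gpu_arch::XeHpc>;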
@@ -34,21 +31,19 @@ template < typename perf_tuning_knob_, typename dtype_scale_, typename dtype_zero_pt_, - quant_mode quant_type_, - int dequant_s_, + quant_info quant_info_, mma_engine mma_engine_ = mma_engine::xmx, gpu_arch arch_tag_ = gpu_arch::XeHpc, typename enable = void> struct compute_policy_int4_dequantize {}; -/// @brief Specialized for XeHpc and XeHpg architecture. +/// @brief Specialized for xmx engine. template < typename compute_attr_, typename perf_tuning_knob_, typename dtype_scale_, typename dtype_zero_pt_, - quant_mode quant_type_, - int dequant_s_, + quant_info quant_info_, mma_engine mma_engine_, gpu_arch arch_tag_> struct compute_policy_int4_dequantize< @@ -56,11 +51,10 @@ struct compute_policy_int4_dequantize< perf_tuning_knob_, dtype_scale_, dtype_zero_pt_, - quant_type_, - dequant_s_, + quant_info_, mma_engine_, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t> { using compute_attr = compute_attr_; using dtype_mma_acc = typename compute_attr::dtype_acc; using dtype_mma_a = typename compute_attr::dtype_a; @@ -73,30 +67,81 @@ struct compute_policy_int4_dequantize< static constexpr mma_engine mma_engine = mma_engine_; static constexpr gpu_arch arch_tag = arch_tag_; - static_assert( - !(mma_engine == mma_engine::xmx && arch_tag == gpu_arch::XeLpg), - "XeLpg does not support xmx"); + static_assert(arch_has_xmx, "XeLpg does not support xmx"); static constexpr bool is_int4_matB_policy = true; - static constexpr uint32_t dequant_s = dequant_s_; + static constexpr uint32_t dequant_s = quant_info_.dequant_s; static_assert( (dequant_s % (32 / sizeof(dtype_mma_b))) == 0, "dequant_s should be a multiply of 32B"); using dtype_scale = dtype_scale_; using dtype_zero_pt = dtype_zero_pt_; - static constexpr quant_mode quant_type = quant_type_; + static constexpr quant_mode quant_mode = quant_info_.quant_mode; static constexpr uint32_t block_size_y_a = 16; using mma_attr = mma_attr_t; - static constexpr uint32_t block_bytes_x_a = - (mma_engine == mma_engine::xmx) ? mma_attr::mma_k_in_bytes : 32; + static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes; + static constexpr uint32_t block_size_x_a = + block_bytes_x_a / sizeof(dtype_mma_a); + static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem; + static constexpr uint32_t block_bytes_y_b = mma_attr::mma_k_in_bytes; + static constexpr uint32_t block_size_y_b = + block_bytes_y_b / sizeof(dtype_mma_b); + + static_assert( + block_bytes_x_a == block_bytes_y_b, + "mat_a x need to match with mat_b y"); +}; + +/// @brief Specialized for fpu engine. 
+template < + typename compute_attr_, + typename perf_tuning_knob_, + typename dtype_scale_, + typename dtype_zero_pt_, + quant_info quant_info_, + mma_engine mma_engine_, + gpu_arch arch_tag_> +struct compute_policy_int4_dequantize< + compute_attr_, + perf_tuning_knob_, + dtype_scale_, + dtype_zero_pt_, + quant_info_, + mma_engine_, + arch_tag_, + std::enable_if_t> { + using compute_attr = compute_attr_; + using dtype_mma_acc = typename compute_attr::dtype_acc; + using dtype_mma_a = typename compute_attr::dtype_a; + using dtype_mma_b = typename compute_attr::dtype_b; + + using perf_tuning_knob = perf_tuning_knob_; + static constexpr int stages = perf_tuning_knob::stages; + static constexpr int sync_freq = perf_tuning_knob::sync_freq; + static constexpr int k_stride = perf_tuning_knob::k_stride; + static constexpr mma_engine mma_engine = mma_engine_; + static constexpr gpu_arch arch_tag = arch_tag_; + + static constexpr bool is_int4_matB_policy = true; + + static constexpr uint32_t dequant_s = quant_info_.dequant_s; + static_assert( + (dequant_s % (32 / sizeof(dtype_mma_b))) == 0, + "dequant_s should be a multiply of 32B"); + using dtype_scale = dtype_scale_; + using dtype_zero_pt = dtype_zero_pt_; + static constexpr quant_mode quant_mode = quant_info_.quant_mode; + static constexpr bool is_col_major_b = + quant_info_.weight_mem_layout == mem_layout::col_major; + + static constexpr uint32_t block_size_y_a = is_col_major_b ? 8 : 16; + static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 256 : 32; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_size_x_b = - (mma_engine == mma_engine::xmx) ? mma_attr::mma_n_in_elem : 32; - static constexpr uint32_t block_bytes_y_b = - (mma_engine == mma_engine::xmx) ? mma_attr::mma_k_in_bytes : 32; + static constexpr uint32_t block_size_x_b = is_col_major_b ? 1 : 32; + static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 
256 : 32; static constexpr uint32_t block_size_y_b = block_bytes_y_b / sizeof(dtype_mma_b); diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 528911fcf..4ce0ab693 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -36,8 +36,7 @@ template < typename mem_desc_b_t_, typename dtype_scale_, typename dtype_zero_pt_, - int dequant_s_, - quant_mode quant_type_, + quant_info quant_info_, mma_engine mma_engine_, typename pre_processing_t_, gpu_arch arch_tag_> @@ -47,8 +46,7 @@ class gemm_t< perf_tuning_knob_, dtype_scale_, dtype_zero_pt_, - quant_type_, - dequant_s_, + quant_info_, mma_engine_, arch_tag_>, tile_shape_, // tile shape of workgroup-level gemm @@ -66,8 +64,7 @@ class gemm_t< perf_tuning_knob_, dtype_scale_, dtype_zero_pt_, - quant_type_, - dequant_s_, + quant_info_, mma_engine_, arch_tag_>; static constexpr uint32_t k_stride = compute_policy::k_stride; @@ -80,6 +77,7 @@ class gemm_t< constexpr static gpu_arch arch_tag = compute_policy::arch_tag; static constexpr uint32_t dequant_s = compute_policy::dequant_s; + static constexpr quant_mode quant_mode = compute_policy::quant_mode; using dtype_b = typename mem_desc_b_t::dtype; using dtype_zero_pt = typename compute_policy::dtype_zero_pt; static constexpr uint32_t pack_ratio = sizeof(dtype_b) * 2; @@ -88,6 +86,8 @@ class gemm_t< static constexpr mem_layout mem_layout_b = mem_desc_b_t::layout; static constexpr bool is_col_major_a = mem_layout_a == mem_layout::col_major; static constexpr bool is_col_major_b = mem_layout_b == mem_layout::col_major; + static constexpr bool is_gemv = is_col_major_b && + compute_policy::mma_engine == mma_engine::fpu && sg_tile_m <= 4; private: /******** set data type **********/ @@ -98,11 +98,14 @@ class gemm_t< using dtype_scale = typename compute_policy::dtype_scale; static_assert( - std::is_same, remove_const_t>::value, + std::is_same, remove_const_t>::value || + std::is_same, remove_const_t>::value, "this is for 4bit matB "); static_assert( std::is_same, remove_const_t>:: - value, + value || + std::is_same, remove_const_t>:: + value, "this is for 4bit zero_pt "); /******** set memory attribute **********/ @@ -115,7 +118,6 @@ class gemm_t< is_col_major_a ? tdesc_update_dir::y_dir : tdesc_update_dir::x_dir; static constexpr tdesc_update_dir update_dir_b = is_col_major_b ? tdesc_update_dir::x_dir : tdesc_update_dir::y_dir; - static_assert(!is_col_major_b, "only support MatB row-major for now"); static_assert( (!is_local_a) && (!is_local_b), "only support from global memory for now"); @@ -132,31 +134,35 @@ class gemm_t< static constexpr uint32_t tile_size_y_c = sg_tile_m; static constexpr uint32_t block_size_x_a = - (compute_policy::block_size_x_a > tile_size_x_a) - ? tile_size_x_a - : compute_policy::block_size_x_a; + std::min(compute_policy::block_size_x_a, tile_size_x_a); static constexpr uint32_t block_size_y_a = - (compute_policy::block_size_y_a > tile_size_y_a) - ? tile_size_y_a - : compute_policy::block_size_y_a; + std::min(compute_policy::block_size_y_a, tile_size_y_a); static constexpr uint32_t block_size_x_b = - (compute_policy::block_size_x_b > tile_size_x_b) - ? tile_size_x_b - : compute_policy::block_size_x_b; + std::min(compute_policy::block_size_x_b, tile_size_x_b); static constexpr uint32_t block_size_y_b = - (compute_policy::block_size_y_b > tile_size_y_b) - ? 
tile_size_y_b - : compute_policy::block_size_y_b; + std::min(compute_policy::block_size_y_b, tile_size_y_b); /******** set tile **********/ static constexpr bool is_vnni_tiled_a = compute_policy::mma_engine == mma_engine::xmx ? ((sizeof(dtype_a) < sizeof(uint32_t)) && is_col_major_a) : false; + static constexpr reg_layout reg_layout_a = - compute_policy::mma_engine == mma_engine::xmx - ? (is_vnni_tiled_a ? reg_layout::vnni_tiled : reg_layout::tiled) - : reg_layout::transpose_tiled; + // fpu + compute_policy::mma_engine == mma_engine::fpu + ? (is_gemv ? reg_layout::tiled : reg_layout::transpose_tiled) + // xmx + : is_vnni_tiled_a ? reg_layout::vnni_tiled + : reg_layout::tiled; + + static constexpr reg_layout reg_layout_b = + // fpu + compute_policy::mma_engine == mma_engine::fpu + ? (is_gemv ? reg_layout::transpose_tiled : reg_layout::tiled) + // xmx + : (sizeof(dtype_mma_b) < sizeof(uint32_t)) ? reg_layout::vnni_tiled + : reg_layout::tiled; using matA_tile_desc_t = subgroup::tile_desc_t< tile_size_x_a, @@ -168,7 +174,7 @@ class gemm_t< using matA_payload_t = subgroup::mem_payload_t< mem_desc_a_t, matA_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using matA_acc_t = subgroup::tile_t; using matA_prefetch_payload_t = subgroup:: @@ -176,17 +182,35 @@ class gemm_t< // note: plane format, row-major // note: 4bit x 2, row-major - using matB_tile_desc_t = subgroup::tile_desc_t< - tile_size_x_b / pack_ratio, - tile_size_y_b, - block_size_x_b / pack_ratio, - block_size_y_b, - reg_layout::tiled>; + using matB_tile_desc_t = std::conditional_t< + is_col_major_b, + // compress int4 along K dimensions + subgroup::tile_desc_t< + tile_size_x_b, + tile_size_y_b / pack_ratio, + block_size_x_b, + // block_size_y_b * sizeof(dtype_mma_b) / sizeof(dtype_b), + block_size_y_b / pack_ratio, + reg_layout_b>, + // compress int4 along N dimensions + subgroup::tile_desc_t< + tile_size_x_b / pack_ratio, + tile_size_y_b, + // block_size_x_b * sizeof(dtype_mma_b) / sizeof(dtype_b), + block_size_x_b / pack_ratio, + block_size_y_b, + reg_layout_b>>; using matB_t = subgroup::tile_t; using matB_payload_t = subgroup::mem_payload_t< mem_desc_b_t, matB_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v< + matB_tile_desc_t, + mem_desc_t< + typename mem_desc_b_t::dtype, + mem_layout::row_major, + mem_desc_b_t::space>>, + // subgroup::msg_type_v, arch_tag>; using matB_prefetch_payload_t = subgroup:: prefetch_payload_t; @@ -196,17 +220,16 @@ class gemm_t< tile_size_y_b, block_size_x_b, block_size_y_b, - compute_policy::mma_engine == mma_engine::xmx ? reg_layout::vnni_tiled - : reg_layout::tiled>; + reg_layout_b>; using matB_acc_t = subgroup::tile_t; public: static_assert( (k_stride % (block_size_y_b) == 0), - "k_stride%(block_size_y_b) == 0"); + "k_stride % (block_size_y_b) == 0"); static_assert( - (dequant_s % (block_size_y_b) == 0), - "dequant_s%(block_size_y_b) == 0"); + (dequant_s % block_size_y_b == 0 || block_size_y_b % dequant_s == 0), + "dequant_s % block_size_y_b == 0 || block_size_y_b % dequant_s == 0"); static_assert( (k_stride % (dequant_s) == 0) || (dequant_s % (k_stride) == 0), "k_stride should match with dequant_s"); @@ -222,66 +245,92 @@ class gemm_t< static constexpr uint32_t scale_addr_update_freq = (k_stride < dequant_s) ? dequant_s / k_stride : 1; + static constexpr uint32_t zero_pt_addr_update_freq = + (k_stride < dequant_s) ? 
dequant_s / k_stride : 1; + using mem_desc_scale_t = mem_desc_t< dtype_scale, - mem_layout::row_major, + mem_layout_b, mem_space::global, mem_desc_b_t::alignment>; + using mem_desc_zero_pt_t = mem_desc_t< dtype_zero_pt, mem_layout::row_major, mem_space::global, mem_desc_b_t::alignment>; - using matAcc_tile_desc_t = subgroup::tile_desc_t< + using matC_tile_desc_t = subgroup::tile_desc_t< tile_size_x_c, tile_size_y_c, block_size_x_b, block_size_y_a, reg_layout::tiled>; - using matAcc_t = subgroup::tile_t; + using matC_t = subgroup::tile_t; private: + using matAcc_tile_desc_t = subgroup::tile_desc_t< + block_size_y_b, + tile_size_y_a, + block_size_y_b, + block_size_y_a, + reg_layout::tiled>; + using matAcc_t = subgroup::tile_t; using scale_tile_desc_t = subgroup::tile_desc_t< tile_size_x_b, tile_size_y_scale, block_size_x_b, block_size_y_scale, - reg_layout::tiled>; + is_col_major_b ? reg_layout::transpose_tiled : reg_layout::tiled>; using scale_t = subgroup::tile_t; using scale_payload_t = subgroup::mem_payload_t< mem_desc_scale_t, scale_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; + + // compress int4 along N dimensions using zero_pt_tile_desc_t = subgroup::tile_desc_t< - tile_size_x_b / pack_ratio, + (tile_size_x_b + pack_ratio - 1) / pack_ratio, tile_size_y_zero_pt, - block_size_x_b / pack_ratio, + (block_size_x_b + pack_ratio - 1) / pack_ratio, block_size_y_zero_pt, reg_layout::tiled>; + using zero_pt_t = subgroup::tile_t; using zero_pt_payload_t = subgroup::mem_payload_t< mem_desc_zero_pt_t, zero_pt_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using scale_prefetch_payload_t = subgroup:: prefetch_payload_t; using zero_pt_prefetch_payload_t = subgroup:: prefetch_payload_t; - using tile_mma = subgroup::tile_mma_t< - matAcc_t, - matAcc_t, + using tile_mma = std::conditional_t< + is_gemv, + subgroup::tile_fma_t, + subgroup::tile_mma_t< + matC_t, + matC_t, + matB_acc_t, + matA_acc_t, + compute_policy::mma_engine, + arch_tag>>; + using dequantize_t = subgroup::dequant_int4_weight_t< matB_acc_t, - matA_acc_t, - compute_policy::mma_engine, - arch_tag>; - + matB_t, + scale_t, + zero_pt_t, + dequant_s, + quant_mode>; static constexpr bool enable_periodic_sync = (sync_freq != 0); static constexpr uint32_t barrier_count_x = wg_size_y > 1 ? wg_size_x : 0; static constexpr uint32_t barrier_count_y = wg_size_x > 1 ? wg_size_y : 0; + uint32_t wg_start_m = 0; + uint32_t wg_start_n = 0; + uint32_t wg_start_k = 0; public: static constexpr uint32_t barrier_count = @@ -294,7 +343,8 @@ class gemm_t< static constexpr msg_type msg_type_b = matB_payload_t::message_type; /// @brief Arguments for gemm. - /// User should prepare matA_base_desc, matB_base_desc, inner_loop_count... + /// User should prepare matA_base_desc, matB_base_desc, + /// inner_loop_start inner_loop_count... struct arguments_t { /// @brief Is the memory description of matA, including base, shape and /// coordinate. @@ -302,6 +352,8 @@ class gemm_t< /// @brief Is the memory description of matB, including base, shape and /// coordinate. mem_desc_b_t matB_base_desc; + /// @brief The tile starting from K-dim + uint32_t inner_loop_start; /// @brief Is the total inner loop count required to compute the entire /// K-dim. 
uint32_t inner_loop_count; @@ -318,11 +370,13 @@ class gemm_t< inline arguments_t( mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, + uint32_t loop_start, uint32_t loop_count, mem_desc_scale_t scale_desc, mem_desc_zero_pt_t zero_pt_desc) : matA_base_desc(matA_desc), matB_base_desc(matB_desc), + inner_loop_start(loop_start), inner_loop_count(loop_count), scale_base_desc(scale_desc), zero_pt_base_desc(zero_pt_desc) {} @@ -330,24 +384,28 @@ class gemm_t< inline arguments_t( mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, + uint32_t loop_start, uint32_t loop_count, mem_desc_scale_t scale_desc) : matA_base_desc(matA_desc), matB_base_desc(matB_desc), + inner_loop_start(loop_start), inner_loop_count(loop_count), scale_base_desc(scale_desc) {} - // Be aware of the risks: Rule of three (copy constructor, copy assignment, - // destructor) Please check if you need to add self-define destructor inline - // ~arguments_t(){} + // Be aware of the risks: Rule of three (copy constructor, copy + // assignment, destructor) Please check if you need to add self-define + // destructor inline ~arguments_t(){} inline arguments_t(const arguments_t& args) : matA_base_desc(args.matA_base_desc), matB_base_desc(args.matB_base_desc), + inner_loop_start(args.inner_loop_start), inner_loop_count(args.inner_loop_count), scale_base_desc(args.scale_base_desc), zero_pt_base_desc(args.zero_pt_base_desc) {} inline arguments_t& operator=(const arguments_t& args) { this->matA_base_desc = args.matA_base_desc; this->matB_base_desc = args.matB_base_desc; + this->inner_loop_start = args.inner_loop_start; this->inner_loop_count = args.inner_loop_count; this->scale_base_desc = args.scale_base_desc; this->zero_pt_base_desc = args.zero_pt_base_desc; @@ -356,11 +414,13 @@ class gemm_t< inline void init( mem_desc_a_t matA_desc, mem_desc_b_t matB_desc, + uint32_t loop_start, uint32_t loop_count, mem_desc_scale_t scale_desc, mem_desc_zero_pt_t zero_pt_desc) { matA_base_desc = matA_desc; matB_base_desc = matB_desc; + inner_loop_start = loop_start; inner_loop_count = loop_count; scale_base_desc = scale_desc; zero_pt_base_desc = zero_pt_desc; @@ -408,13 +468,13 @@ class gemm_t< /// @brief Main execution function for gemm. /// The basic process is load data -> matrix multiply. /// @param g Is the workgroup of the current tile. - /// @param matAcc Is the reference of the accumulation buffer. + /// @param matC Is the reference of the accumulation buffer. /// @param args Is the gemm::arguments_t. /// @param slm_base Is the slm base address. /// @param nbarrier_base Is the named barrier base. 
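Before the operator() definition that follows, a caller-side sketch of the extended arguments_t (my_gemm_t and the mem_desc_* objects are placeholders, not part of this patch): the new inner_loop_start is simply the K-tile index the workgroup starts from, computed with the same ceiling division as the existing inner_loop_count, mirroring the kernel-level code later in this patch.
// Hypothetical caller-side construction of the extended arguments_t.
uint32_t inner_loop_start = (start_k + k_stride - 1) / k_stride;   // first K tile index
uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride; // number of K tiles
typename my_gemm_t::arguments_t gemm_args(
    mem_desc_a, mem_desc_b,
    inner_loop_start, inner_loop_count,
    mem_desc_scale, mem_desc_zero_pt);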
__XETLA_API KERNEL_FUNC void operator()( work_group_t& g, - matAcc_t& matAcc, + matC_t& matC, arguments_t args, [[maybe_unused]] uint32_t slm_base = 0, uint32_t nbarrier_base = 0) { @@ -426,6 +486,8 @@ class gemm_t< matB_t matB; scale_t scale; zero_pt_t zero_pt; + matAcc_t matAcc; + matAcc.reg = 0; matA_payload_t matA_payload(args.matA_base_desc); matB_payload_t matB_payload(args.matB_base_desc); @@ -437,6 +499,13 @@ class gemm_t< zero_pt_prefetch_payload_t zero_pt_prefetch_payload( args.zero_pt_base_desc, 0); + wg_start_m = args.matA_base_desc.coord.y; + wg_start_n = args.scale_base_desc.coord.x; + wg_start_k = args.matA_base_desc.coord.x; + typename dequantize_t::arguments_t dequantize_args{ + wg_start_m, wg_start_n, wg_start_k}; + dequantize_t dequantize; + xetla_nbarrier_t nbarrier_a; nbarrier_a.init_nbarrier( sg_idy + nbarrier_base, nbarrier_role::producer_consumer); @@ -445,8 +514,8 @@ class gemm_t< sg_idx + barrier_count_y + nbarrier_base, nbarrier_role::producer_consumer); - int scale_prefetch_addr_i = args.matB_base_desc.coord.y; - int scale_load_addr_i = args.matB_base_desc.coord.y; + int scale_prefetch_addr_i = args.inner_loop_start; + int tile_k_idx = args.inner_loop_start; SW_BARRIER(); #pragma unroll for (uint32_t i = 0; i < stages; i++) { @@ -458,21 +527,25 @@ class gemm_t< subgroup::tile_prefetch( scale_prefetch_payload); if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( zero_pt_prefetch_payload); } - scale_prefetch_addr_i += dequant_s; + scale_prefetch_addr_i++; matA_prefetch_payload.template update_tdesc( matA_t::tile_size_x); matB_prefetch_payload.template update_tdesc( matB_t::tile_size_y); - if ((scale_prefetch_addr_i % dequant_s) == 0) { - scale_prefetch_payload.template update_tdesc( + if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { + scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); - zero_pt_prefetch_payload.template update_tdesc( - zero_pt_t::tile_size_y); + if constexpr ( + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + zero_pt_prefetch_payload + .template update_tdesc( + zero_pt_t::tile_size_y); + } } } @@ -493,14 +566,16 @@ class gemm_t< matA, matA_payload); subgroup::tile_load( matB, matB_payload); + // subgroup::tile_load( + // matB, matB_payload); subgroup::tile_load( scale, scale_payload); if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { subgroup::tile_load( zero_pt, zero_pt_payload); } - scale_load_addr_i += matB_t::tile_size_y; + tile_k_idx++; SW_BARRIER(); if constexpr (stages != 0) { subgroup::tile_prefetch( @@ -511,21 +586,25 @@ class gemm_t< subgroup::tile_prefetch( scale_prefetch_payload); if constexpr ( - compute_policy::quant_type != quant_mode::S4_FULLRANGE_NO_ZP) { + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { // TODO 1D prefetch need pack to U32/U64 subgroup::tile_prefetch( zero_pt_prefetch_payload); } - scale_prefetch_addr_i += dequant_s; + scale_prefetch_addr_i++; } SW_BARRIER(); matA_payload.template update_tdesc(matA_t::tile_size_x); matB_payload.template update_tdesc(matB_t::tile_size_y); - if ((scale_load_addr_i % dequant_s) == 0) { - scale_payload.template update_tdesc( - scale_t::tile_size_y); - zero_pt_payload.template update_tdesc( - zero_pt_t::tile_size_y); + if (tile_k_idx % scale_addr_update_freq == 0) { + 
scale_payload.template update_tdesc(scale_t::tile_size_y); + } + if constexpr ( + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + if (tile_k_idx % zero_pt_addr_update_freq == 0) { + zero_pt_payload.template update_tdesc( + zero_pt_t::tile_size_y); + } } if constexpr (stages != 0) { matA_prefetch_payload.template update_tdesc( @@ -535,9 +614,12 @@ class gemm_t< if ((scale_prefetch_addr_i % scale_addr_update_freq) == 0) { scale_prefetch_payload.template update_tdesc( scale_t::tile_size_y); - zero_pt_prefetch_payload - .template update_tdesc( - zero_pt_t::tile_size_y); + if constexpr ( + compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { + zero_pt_prefetch_payload + .template update_tdesc( + zero_pt_t::tile_size_y); + } } } SW_BARRIER(); @@ -547,9 +629,23 @@ class gemm_t< subgroup::vnni_reverse(matA); } subgroup::elemwise_cvt(matA_acc, matA); - dequantize(matB_acc, matB, scale, zero_pt); + + dequantize(matB_acc, matB, scale, zero_pt, dequantize_args); SW_BARRIER(); - tile_mma::mma(matAcc, matAcc, matB_acc, matA_acc); + if constexpr (is_gemv) { + tile_mma::mma( + matAcc, + matAcc, + matC, + matB_acc, + matA_acc, + i == args.inner_loop_count - 1); + } else { + if constexpr (is_col_major_b) { + tile_transpose(matB_acc); + } + tile_mma::mma(matC, matC, matB_acc, matA_acc); + } SW_BARRIER(); if constexpr (enable_periodic_sync) { if ((i % sync_freq) == 0) { @@ -568,19 +664,113 @@ class gemm_t< } private: + // inline void dequantize( + // matB_acc_t& matB_acc, + // matB_t& matB, + // scale_t& scale, + // zero_pt_t& zero_pt) { + // // no tail, because this is matB + // constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; + // constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; + // #pragma unroll + // for (uint32_t i = 0; i < num_block_y; ++i) { + // #pragma unroll + // for (uint32_t j = 0; j < num_block_x; ++j) { + // int block_id = (i * num_block_x + j); + // // Must be little-endian + // auto matB_blk = matB.reg.xetla_format() + // .xetla_select( + // block_id * matB_acc_t::block_elems / 2); + + // auto dst_blk = matB_acc.reg.xetla_select( + // block_id * matB_acc_t::block_elems); + + // // int8 includes 2 4bits data. 
+ // xetla_vector cvt_blk_i8; + + // // lowest 4 bit + // { + // cvt_blk_i8.xetla_select(0) = + // matB_blk & 0xf; + // } + // // highest 4 bit + // { + // cvt_blk_i8.xetla_select(1) = + // matB_blk >> 4; + // } + + // // (b_i8 - zero_pt_i8) x scale = fp16 + // constexpr uint32_t step = std::min(block_size_y_b, dequant_s); + // #pragma unroll + // for (uint32_t jj = 0; jj < block_size_x_b; jj++) { + // #pragma unroll + // for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { + // uint32_t offset_y_in_tile = i * block_size_y_b + ii; + // uint32_t offset_x_in_tile = j * block_size_x_b + jj; + + // uint32_t scale_idx = + // (offset_y_in_tile) / dequant_s * scale_t::block_size_x + + // offset_x_in_tile; + + // if constexpr (compute_policy::quant_mode == + // quant_mode::S4_ASYM) { + // uint32_t zero_pt_idx = + // offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + + // offset_x_in_tile / pack_ratio; + // native_type_t zero_pt_pack = + // zero_pt.reg[zero_pt_idx]; + + // int8_t zero_pt_i8 = + // (zero_pt_pack >> + // (4 * ((wg_start_n + offset_x_in_tile) % pack_ratio))) & + // 0xf; + // // sycl::ext::oneapi::experimental::printf( + // // "zero_pt.reg[%d} %x zero_pt_i8 %x + // offset_x_in_tile:%d + // // \n", zero_pt_idx, zero_pt_pack, (int32_t)zero_pt_i8 , + // // offset_x_in_tile); + + // cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = + // cvt_blk_i8.xetla_select(jj * block_size_y_b + + // ii) - zero_pt_i8; + // } else if constexpr ( + // compute_policy::quant_mode == + // quant_mode::S4_FULLRANGE_NO_ZP) { + // cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = + // cvt_blk_i8.xetla_select(jj * block_size_y_b + + // ii) - int8_t(8); + // } + // dst_blk.xetla_select(jj * block_size_y_b + ii) = + // cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) + // * scale.reg[scale_idx]; + + // // sycl::ext::oneapi::experimental::printf( + // // "scale[%d] %f \n", + // // scale_idx, + // // float(sycl::half(scale.reg.xetla_select<1, + // 1>(scale_idx)))); + // } + // } + // } + // } + // } + + /* inline void dequantize( - matB_acc_t& matB_acc, - matB_t& matB, - scale_t& scale, - zero_pt_t& zero_pt) { + matB_acc_t & matB_acc, + matB_t & matB, + scale_t & scale, + zero_pt_t & zero_pt) { // no tail, because this is matB constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; constexpr uint32_t block_b_y_per_scale = dequant_s / block_size_y_b; -#pragma unroll + constexpr uint32_t block_b_x_per_scale = dequant_s / block_size_x_b; + #pragma unroll for (uint32_t i = 0; i < num_block_y; ++i) { -#pragma unroll + #pragma unroll for (uint32_t j = 0; j < num_block_x; ++j) { int block_id = (i * num_block_x + j); auto matB_blk = matB.reg @@ -590,7 +780,6 @@ class gemm_t< int scale_block_id = (i / block_b_y_per_scale * num_block_x + j); auto scale_vec = scale.reg.xetla_select( scale_block_id * scale_t::block_size_x); - auto dst_blk = matB_acc.reg.xetla_select( block_id * matB_acc_t::block_elems); @@ -598,7 +787,7 @@ class gemm_t< xetla_vector cvt_blk; xetla_vector cvt_blk_i32; - if constexpr (compute_policy::quant_type == quant_mode::S4_ASYM) { + if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) { auto zero_pt_vec = zero_pt.reg .xetla_select( scale_block_id * zero_pt_t::block_size_x) @@ -608,9 +797,10 @@ class gemm_t< xetla_vector zero_pt_sub; zero_pt_sub.xetla_select(0) = zero_pt_vec & 0x0f; - zero_pt_sub.xetla_select(1) = zero_pt_vec >> 4; + zero_pt_sub.xetla_select(1) = + zero_pt_vec >> 4; xetla_vector zero_pt_blk; 
-#pragma unroll + #pragma unroll for (uint32_t row = 0; row < block_size_y_b; row++) { zero_pt_blk.xetla_select(row * block_size_x_b) .xetla_format() = @@ -621,9 +811,10 @@ class gemm_t< zero_pt_blk.xetla_format()); } if constexpr ( - compute_policy::quant_type == quant_mode::S4_FULLRANGE_NO_ZP) { + compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { xetla_vector cvt_blk_i8; - cvt_blk_i8.xetla_select(0) = matB_blk & 0x0f; + cvt_blk_i8.xetla_select(0) = + matB_blk & 0x0f; cvt_blk_i8.xetla_select(0) = cvt_blk_i8.xetla_select(0) << 4; cvt_blk_i8.xetla_select(0) = @@ -633,15 +824,16 @@ class gemm_t< cvt_blk_i32 = (cvt_blk_i8.xetla_format()); } if constexpr (compute_policy::mma_engine == mma_engine::xmx) { - constexpr uint32_t vnni_rows = sizeof(uint32_t) / sizeof(dtype_mma_b); + constexpr uint32_t vnni_rows = + sizeof(uint32_t) / sizeof(dtype_mma_b); xetla_vector temp_blk; temp_blk.xetla_select(0) = cvt_blk_i32; -#pragma unroll + #pragma unroll for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { -#pragma unroll + #pragma unroll for (uint32_t row = 0; row < vnni_rows; row++) { temp_blk.xetla_select( row + block_size_x_b * k * vnni_rows) = @@ -651,12 +843,13 @@ class gemm_t< } xetla_vector scale_blk; -#pragma unroll + #pragma unroll for (uint32_t row = 0; row < vnni_rows; row++) { - scale_blk.xetla_select(row) = scale_vec; + scale_blk.xetla_select(row) = + scale_vec; } -#pragma unroll + #pragma unroll for (uint32_t k = 0; k < block_size_y_b; k += vnni_rows) { dst_blk.xetla_select( k * block_size_x_b) = @@ -665,7 +858,7 @@ class gemm_t< scale_blk; } } else { -#pragma unroll + #pragma unroll for (uint32_t k = 0; k < block_size_y_b; k++) { dst_blk.xetla_select(k * block_size_x_b) = cvt_blk_i32.xetla_select( @@ -675,7 +868,8 @@ class gemm_t< } } } - } + } */ + /// @brief Updates tile base descriptor based on the tid. 
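A scalar rendering of what the commented-out vector code above computes per packed byte may help when reading the new dequant_int4_weight_t path; this is an illustration only, assuming the same little-endian nibble order and the (b - zero_pt) * scale / signed-4-bit conventions shown in the block above.
// Illustration only (not part of the patch): dequantize one int4x2 byte.
#include <cstdint>
inline void dequant_int4x2_byte(
    uint8_t packed,  // two 4-bit weights, low nibble is the first element
    float scale,     // per dequant_s-group scale
    int8_t zero_pt,  // per-group zero-point nibble (quant_mode::S4_ASYM)
    bool asym,       // S4_ASYM vs. S4_FULLRANGE_NO_ZP
    float out[2]) {
  int lo = packed & 0x0f;         // lowest 4 bits
  int hi = (packed >> 4) & 0x0f;  // highest 4 bits
  if (asym) {                     // (b - zero_pt) * scale
    out[0] = float(lo - zero_pt) * scale;
    out[1] = float(hi - zero_pt) * scale;
  } else {                        // full range: reinterpret nibble as signed 4-bit
    out[0] = float(lo >= 8 ? lo - 16 : lo) * scale;
    out[1] = float(hi >= 8 ? hi - 16 : hi) * scale;
  }
}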
__XETLA_API static void update_sg_tile_tdesc( arguments_t& args, @@ -685,7 +879,11 @@ class gemm_t< int32_t tile_offset_m = sg_idy * sg_tile_m; args.matA_base_desc.update_coord_y(tile_offset_m); - args.matB_base_desc.update_coord_x(tile_offset_n / pack_ratio); + if constexpr (is_col_major_b) { + args.matB_base_desc.update_coord_x(tile_offset_n); + } else { + args.matB_base_desc.update_coord_x(tile_offset_n / pack_ratio); + } args.scale_base_desc.update_coord_x(tile_offset_n); args.zero_pt_base_desc.update_coord_x(tile_offset_n / pack_ratio); } diff --git a/include/experimental/group/reduction/row_reduce_store_xe.hpp b/include/experimental/group/reduction/row_reduce_store_xe.hpp index a38d4e5bb..29a2474ef 100644 --- a/include/experimental/group/reduction/row_reduce_store_xe.hpp +++ b/include/experimental/group/reduction/row_reduce_store_xe.hpp @@ -58,10 +58,12 @@ struct group_row_reduce_store_t< using local_st_tile_desc_t = subgroup::tile_desc_t; using local_st_t = subgroup::tile_t; + using mem_desc_acc = + mem_desc_t; using local_st_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_acc, local_st_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using local_ld_tile_desc_t = subgroup::tile_desc_t< local_tile_size_x, @@ -70,10 +72,12 @@ struct group_row_reduce_store_t< wg_size_y, reg_layout::tiled>; using local_ld_t = subgroup::tile_t; + using mem_desc_ld_t = + mem_desc_t; using local_ld_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_ld_t, local_ld_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; // If the local tile size is small, we still can use 2D block store diff --git a/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp b/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp index 084ebdc88..1cd3b7153 100644 --- a/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp +++ b/include/experimental/kernel/col_major_shuf/col_major_shuf.hpp @@ -20,6 +20,6 @@ #pragma once #include +#include #include #include -#include diff --git a/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp b/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp index 0395a77fd..6b860686d 100644 --- a/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp +++ b/include/experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp @@ -83,7 +83,7 @@ struct col_major_shuf_t< using store_tile_payload_t = subgroup::mem_payload_t< mem_desc_store_tile_t, store_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_>; using mem_desc_gidx_t = mem_desc_t< @@ -97,7 +97,7 @@ struct col_major_shuf_t< using gidx_payload_t = subgroup::mem_payload_t< mem_desc_gidx_t, gidx_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_>; struct arguments_t { diff --git a/include/experimental/kernel/data_transformer/data_transformer_xe.hpp b/include/experimental/kernel/data_transformer/data_transformer_xe.hpp index 4cbd1498a..2ccf069d1 100644 --- a/include/experimental/kernel/data_transformer/data_transformer_xe.hpp +++ b/include/experimental/kernel/data_transformer/data_transformer_xe.hpp @@ -122,10 +122,11 @@ struct xetla_data_transformer< block_size_y, in_reg_layout>; using global_ld_t = subgroup::tile_t; + using mem_desc_ld_t = mem_desc_t; using global_ld_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_ld_t, global_ld_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using global_st_tile_desc_t = subgroup::tile_desc_t< diff --git 
a/include/experimental/kernel/gemm/common.hpp b/include/experimental/kernel/gemm/common.hpp index 367583687..2004ce0cc 100644 --- a/include/experimental/kernel/gemm/common.hpp +++ b/include/experimental/kernel/gemm/common.hpp @@ -19,7 +19,6 @@ #pragma once -#include #include #include #include diff --git a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp index 6c98df456..79d37d517 100644 --- a/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp +++ b/include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp @@ -95,7 +95,7 @@ class gemm_universal_t< using dtype_c = typename mem_desc_c_t::dtype; using dtype_scale = typename mem_desc_scale_t::dtype; using dtype_zero_pt = typename mem_desc_zero_pt_t::dtype; - using matAcc_t = typename gemm_t::matAcc_t; + using matAcc_t = typename gemm_t::matC_t; using dtype_acc = typename matAcc_t::dtype; using mem_desc_acc_t = mem_desc_t; @@ -159,7 +159,7 @@ class gemm_universal_t< /// @brief GEMM arguments. /// This is the interface for users to pass the application-related runtime /// variables. - template + template struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x /// k x n). @@ -295,7 +295,7 @@ class gemm_universal_t< } }; template <> - struct arguments_t { + struct arguments_t { /// @brief Is the size of the m dimension of the matrix multiplication (m x /// k x n). uint32_t matrix_m; @@ -486,7 +486,7 @@ class gemm_universal_t< /// @param args Is the GEMM arguments for application-related runtime /// variables. /// @return Expected nd_range. - template + template static cl::sycl::nd_range<3> get_nd_range(arguments_t& args) { cl::sycl::range<3> local_range = get_local_range(); cl::sycl::range<3> group_range = @@ -523,7 +523,7 @@ class gemm_universal_t< /// @param args Is the GEMM arguments for application-related runtime /// variables. /// @return Check result. - template + template static bool can_implement(arguments_t& args) { bool implementable = true; if (gemm_t::msg_type_a != msg_type::unaligned_2d) { @@ -550,24 +550,24 @@ class gemm_universal_t< args.matB_base.base, args.matB_ld / pack_ratio); } } - if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m, - args.matC_ld); - } else { - implementable &= kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_t::msg_type_c == msg_type::block_2d) { + // implementable &= kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n, + // args.matrix_m, + // args.matC_ld); + // } else { + // implementable &= kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } // check for int4x2 implementable &= ((args.matB_ld % pack_ratio == 0) && (args.matrix_n % pack_ratio == 0)); if constexpr ( - gemm_t::compute_policy::quant_type != - group::quant_mode::S4_FULLRANGE_NO_ZP) { + gemm_t::compute_policy::quant_mode != quant_mode::S4_FULLRANGE_NO_ZP) { implementable &= (args.zero_pt_ld % pack_ratio == 0); } @@ -584,7 +584,7 @@ class gemm_universal_t< /// variables. /// @param slm_base Is the slm base address. /// @param nbarrier_base Is the named barrier base. 
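The pack_ratio bookkeeping used in can_implement above and in the matB descriptor setup inside the operator() body below can be summarized with a small worked example; the matrix sizes are hypothetical, and pack_ratio itself is defined in the group-level gemm as sizeof(dtype_b) * 2.
// Worked example, illustration only (sizes are hypothetical).
constexpr uint32_t pack_ratio = sizeof(gpu::xetla::int4x2) * 2; // == 2 int4 values per byte
uint32_t matrix_k = 4096, matrix_n = 11008;
// Row-major B compresses int4 along N, so the N extent and matB_ld shrink,
// and can_implement requires matB_ld % pack_ratio == 0 and matrix_n % pack_ratio == 0.
uint32_t packed_n = matrix_n / pack_ratio;  // 5504 packed bytes per row
// Col-major B compresses int4 along K, so the K extent shrinks instead,
// which is what the new is_col_major_b branch of mem_desc_b.init accounts for.
uint32_t packed_k = matrix_k / pack_ratio;  // 2048 packed bytes per column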
- template + template __XETLA_API KERNEL_FUNC void operator()( sycl::nd_item<3>& item, const arguments_t& args, @@ -598,12 +598,8 @@ class gemm_universal_t< int start_n = group_swizzle.template get_tile_idx<2>(item) * wg_tile_n; int start_k = 0; uint32_t wg_tile_k = args.matrix_k; - uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n - ? args.matrix_n - : (start_n + wg_tile_n); - uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m - ? args.matrix_m - : (start_m + wg_tile_m); + uint32_t boundary_n = std::min(start_n + wg_tile_n, args.matrix_n); + uint32_t boundary_m = std::min(start_m + wg_tile_m, args.matrix_m); uint32_t boundary_k = wg_tile_k; if constexpr (num_global_kslicing > 1) { wg_tile_k = (wg_tile_k + num_global_kslicing - 1) / num_global_kslicing; @@ -647,10 +643,17 @@ class gemm_universal_t< args.matA_base, {boundary_k, boundary_m, args.matA_ld}, {start_k, start_m}); - mem_desc_b.init( - args.matB_base, - {boundary_n / pack_ratio, boundary_k, args.matB_ld / pack_ratio}, - {int(start_n / pack_ratio), start_k}); + if constexpr (gemm_t::is_col_major_b) { + mem_desc_b.init( + args.matB_base, + {boundary_n, boundary_k / pack_ratio, args.matB_ld / pack_ratio}, + {start_n, int(start_k / pack_ratio)}); + } else { + mem_desc_b.init( + args.matB_base, + {boundary_n / pack_ratio, boundary_k, args.matB_ld / pack_ratio}, + {int(start_n / pack_ratio), start_k}); + } uint32_t scale_size_y = ((args.matrix_k + dequant_s - 1) / dequant_s); mem_desc_scale_t mem_desc_scale( @@ -658,23 +661,28 @@ class gemm_universal_t< {args.matrix_n, scale_size_y, args.scale_ld}, {start_x_scale, start_y_scale}); + uint32_t inner_loop_start = (start_k + k_stride - 1) / k_stride; uint32_t inner_loop_count = (wg_tile_k + k_stride - 1) / k_stride; gemm_args_t gemm_args; if constexpr ( - gemm_t::compute_policy::quant_type == - group::quant_mode::S4_FULLRANGE_NO_ZP) { - gemm_args = - gemm_args_t(mem_desc_a, mem_desc_b, inner_loop_count, mem_desc_scale); + gemm_t::compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + gemm_args = gemm_args_t( + mem_desc_a, + mem_desc_b, + inner_loop_start, + inner_loop_count, + mem_desc_scale); } else { mem_desc_zero_pt_t mem_desc_zero_pt( args.zero_pt_base, - {args.matrix_n / pack_ratio, - scale_size_y, + {(args.matrix_n + pack_ratio - 1) / pack_ratio, + ((args.matrix_k + dequant_s - 1) / dequant_s), args.zero_pt_ld / pack_ratio}, {start_x_zero_pt, start_y_zero_pt}); gemm_args = gemm_args_t( mem_desc_a, mem_desc_b, + inner_loop_start, inner_loop_count, mem_desc_scale, mem_desc_zero_pt); diff --git a/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp index f4c2fff61..b67d7f997 100644 --- a/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp +++ b/include/experimental/kernel/layer_norm/layer_norm_bwd_xe.hpp @@ -92,20 +92,26 @@ struct layer_norm_bwd_t< using gamma_in_t = subgroup::tile_t; using dx_out_t = subgroup::tile_t; + using mem_desc_y_t = + mem_desc_t; using dy_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_y_t, ln_bwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_x_t = + mem_desc_t; using x_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_x_t, ln_bwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_weight_t = + mem_desc_t; using gamma_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_weight_t, ln_bwd_tile_desc_t, - 
subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using dx_out_payload_t = subgroup::mem_payload_t< mem_desc_t, diff --git a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp index 08ad1eaf3..ecd6bc25b 100644 --- a/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp +++ b/include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp @@ -101,23 +101,29 @@ struct layer_norm_fwd_t< using beta_in_t = subgroup::tile_t; using y_out_t = subgroup::tile_t; + using mem_desc_x_t = + mem_desc_t; using x_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_x_t, ln_fwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_weight_t = + mem_desc_t; using gamma_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_weight_t, ln_fwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using beta_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_weight_t, ln_fwd_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_y_t = + mem_desc_t; using y_out_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_y_t, ln_fwd_tile_desc_t, msg_type::block_1d, gpu_arch::XeHpc>; @@ -320,7 +326,7 @@ struct layer_norm_fwd_t< itr_count += 1; nbarrier.wait(); - xetla_vector mu_m2_vec = + xetla_vector mu_m2_vec = xetla_load_local(slm_load_base); xetla_vector mu_vec = mu_m2_vec.xetla_select(0); diff --git a/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp b/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp index 0dac16a9d..398e6df96 100644 --- a/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp +++ b/include/experimental/kernel/mha_core_attention/mha_attn_reg.hpp @@ -226,54 +226,55 @@ struct xetla_mha_attn_reg_fwd_t { using matC_16x2048_t = subgroup::tile_t; using matC_128x64_t = subgroup::tile_t; + using mem_desc_c_t = mem_desc_t; using matC_128x128_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x256_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_64x384_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_64x384_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_64x512_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_32x1024_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_16x2048_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x64_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, mat_128x64_tile_desc_t, (global_kslicing > 1) ? 
msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_128x128_t = @@ -288,40 +289,41 @@ struct xetla_mha_attn_reg_fwd_t { subgroup::tile_t; using matDpotMk_128x64_t = subgroup::tile_t; + using mem_desc_dpot_t = mem_desc_t; using matDpotMk_128x128_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_128x128_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_128x256_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_128x256_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_64x384_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_64x384_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_64x512_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_64x512_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_32x1024_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_32x1024_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_16x2048_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_16x2048_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matDpotMk_128x64_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_dpot_t, mat_128x64_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; /// @brief Arguments for xetla_softmax_fwd_t::run. @@ -1780,47 +1782,48 @@ struct xetla_mha_attn_reg_bwd_t { using matC_32x1024_t = subgroup::tile_t; using matC_16x2048_t = subgroup::tile_t; + using mem_desc_c_t = mem_desc_t; using matC_128x128_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x256_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_64x384_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_64x384_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_64x512_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_64x512_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_32x1024_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_32x1024_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_16x2048_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_16x2048_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x64_tile_desc_t = subgroup::tile_desc_t< @@ -1862,39 +1865,42 @@ struct xetla_mha_attn_reg_bwd_t { subgroup::tile_t; using matC_256x64_trnp_af_t = subgroup::tile_t; - + using mem_desc_bot_t = mem_desc_t; using matC_128x64_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_tile_desc_t, (global_kslicing > 1) ? 
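Every matC payload typedef in this file picks its store message with the same rule: once global k-slicing spreads the k loop over more than one workgroup, several groups add partial sums into the same C tile, so the store has to be msg_type::atomic_add; otherwise the payload keeps the message type deduced by subgroup::msg_type_v from the tile and memory descriptors. A standalone sketch of that selection, with stand-in names for the enum and the deduction helper:

    #include <cstdint>

    enum class msg_type { block_1d, block_2d, scatter, atomic_add, unaligned_2d };

    // Stand-in for subgroup::msg_type_v<tile_desc, mem_desc>.
    template <typename tile_desc, typename mem_desc>
    inline constexpr msg_type deduced_msg_type = msg_type::block_2d;

    // Partial sums from several workgroups land in the same C tile, so the
    // store must be an atomic read-modify-write instead of a plain overwrite.
    template <uint32_t global_kslicing, typename tile_desc, typename mem_desc>
    inline constexpr msg_type store_msg_type = (global_kslicing > 1)
        ? msg_type::atomic_add
        : deduced_msg_type<tile_desc, mem_desc>;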
msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_256x64_trnp_a_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_256x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matW_128x128_t = subgroup::tile_t; @@ -1904,35 +1910,36 @@ struct xetla_mha_attn_reg_bwd_t { using matW_32x1024_t = subgroup::tile_t; using matW_16x2048_t = subgroup::tile_t; + using mem_desc_w_t = mem_desc_t; using matW_128x128_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_128x128_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matW_128x256_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_128x256_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matW_64x384_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_64x384_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matW_64x512_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_64x512_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matW_32x1024_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_32x1024_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matW_16x2048_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_w_t, matC_16x2048_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; #if 0 diff --git a/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp b/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp index e6764ff3c..41dfa7606 100644 --- a/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp +++ b/include/experimental/kernel/mha_core_attention/mha_core_attn.hpp @@ -188,17 +188,19 @@ struct xetla_mha_core_attn_fwd_t { reg_layout::tiled>; using matElem_ld_t = gpu::xetla::subgroup::tile_t; + using mem_desc_elem_ld_t = + mem_desc_t; using matElem_ld_payload_t = gpu::xetla::subgroup::mem_payload_t< - mem_desc_t, + mem_desc_elem_ld_t, matElem_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matElem_st_t = gpu::xetla::subgroup::tile_t; using matElem_st_payload_t = gpu::xetla::subgroup::mem_payload_t< - mem_desc_t, + mem_desc_elem_ld_t, matElem_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matElem_reg_t = gpu::xetla::subgroup::tile_t< float, @@ -1065,55 +1067,60 @@ struct xetla_mha_core_attn_bwd_t { subgroup::tile_t; using matC_256x64_trnp_af_t = subgroup::tile_t; - + using mem_desc_c_t = mem_desc_t; using matC_128x128_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, 
matC_128x128_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x256_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_128x256_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_bot_t = mem_desc_t; using matC_128x64_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup::msg_type_v, gpu_arch::XeHpc>; using matC_128x64_trnp_a_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matC_256x64_trnp_a_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_256x64_trnp_a_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matC_128x64_trnp_af_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_128x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; using matC_256x64_trnp_af_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bot_t, matC_256x64_trnp_af_tile_desc_t, (global_kslicing > 1) ? msg_type::atomic_add - : subgroup::msg_type_v, + : subgroup:: + msg_type_v, gpu_arch::XeHpc>; // 512 = 16x32 or 8x64 @@ -1127,13 +1134,16 @@ struct xetla_mha_core_attn_bwd_t { gpu::xetla::subgroup::tile_t; using matElem_st_t = gpu::xetla::subgroup::tile_t; + + using mem_desc_elem_ld_t = + mem_desc_t; using matElem_ld_payload_t = gpu::xetla::subgroup::mem_payload_t< - mem_desc_t, + mem_desc_elem_ld_t, matElem_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using matElem_st_payload_t = gpu::xetla::subgroup::mem_payload_t< - mem_desc_t, + mem_desc_elem_ld_t, matElem_tile_desc_t, msg_type::block_2d, gpu_arch::XeHpc>; diff --git a/include/experimental/kernel/reduction/row_reduction_xe.hpp b/include/experimental/kernel/reduction/row_reduction_xe.hpp index 1d24d50b0..c2a4a11c9 100644 --- a/include/experimental/kernel/reduction/row_reduction_xe.hpp +++ b/include/experimental/kernel/reduction/row_reduction_xe.hpp @@ -108,10 +108,12 @@ struct xetla_row_reduction_t< block_size_y, reg_layout::tiled>; using global_ld_t = subgroup::tile_t; + using mem_desc_in_t = + mem_desc_t; using global_ld_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, global_ld_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using mat_buffer_t = subgroup::tile_t< dtype_acc, diff --git a/include/group/cooperative_reduction.hpp b/include/group/cooperative_reduction.hpp index b5ab96d23..25682c0e2 100644 --- a/include/group/cooperative_reduction.hpp +++ b/include/group/cooperative_reduction.hpp @@ -95,7 +95,7 @@ class cooperative_reduce_t< static constexpr uint32_t block_size_x = gpu::xetla::subgroup::detail::gcd::value; static constexpr uint32_t block_size_y = - (tile_size_y > src_block_size_y) ? 
src_block_size_y : tile_size_y; + std::min(src_block_size_y, tile_size_y); using local_st_tile_desc_t = subgroup::tile_desc_t< sg_tile_n, @@ -104,10 +104,12 @@ class cooperative_reduce_t< src_block_size_y, reg_layout::tiled>; using local_st_tile_t = subgroup::tile_t; + using mem_desc_st_t = + mem_desc_t; using local_st_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_st_t, local_st_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using local_ld_tile_desc_t = subgroup::tile_desc_t< tile_size_x, @@ -116,10 +118,12 @@ class cooperative_reduce_t< block_size_y, reg_layout::tiled>; using local_ld_tile_t = subgroup::tile_t; + using mem_desc_ld_t = + mem_desc_t; using local_ld_payload_t = subgroup::mem_payload_t< mem_desc_t, local_ld_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; public: diff --git a/include/group/epilogue/impl/default_xe.hpp b/include/group/epilogue/impl/default_xe.hpp index ab149396a..b25397dcd 100644 --- a/include/group/epilogue/impl/default_xe.hpp +++ b/include/group/epilogue/impl/default_xe.hpp @@ -70,9 +70,9 @@ class epilogue_t< } public: - static constexpr msg_type msg_type_c = - (mem_space_c == mem_space::global ? msg_type::block_2d - : msg_type::scatter); + // static constexpr msg_type msg_type_c = + // (mem_space_c == mem_space::global ? msg_type::block_2d + // : msg_type::scatter); /// @brief Default epilogue. /// 1) Convert dtype_acc to dtype_c 2) Overwrite to memory. @@ -94,6 +94,10 @@ class epilogue_t< [[maybe_unused]] uint32_t nbarrier_base = 0) { using mat_tile_desc = typename matAcc_t::tile_desc; using matC_t = subgroup::tile_t; + + // static constexpr msg_type msg_type_c = msg_type::unaligned_2d; + static constexpr msg_type msg_type_c = + subgroup::msg_type_v; using matC_payload_t = subgroup:: mem_payload_t; update_sg_tile_tdesc(g, mem_desc_c); @@ -143,9 +147,7 @@ class epilogue_t< using dtype_c = typename mem_desc_c_t::dtype; static constexpr mem_layout mem_layout_c = mem_desc_c_t::layout; static constexpr mem_space mem_space_c = mem_desc_c_t::space; - static constexpr msg_type msg_type_c = - (mem_space_c == mem_space::global ? msg_type::block_2d - : msg_type::scatter); + /// @brief Updates tile base descriptor based on the tid. __XETLA_API static void update_sg_tile_tdesc( work_group_t& g, @@ -165,8 +167,6 @@ class epilogue_t< } public: - static constexpr bool is_2d_block_c = (msg_type_c == msg_type::block_2d); - /// @brief Default epilogue. /// 1) Convert dtype_acc to dtype_c 2) Overwrite to memory. /// @tparam matAcc_t Is the type of the input tile. @@ -190,11 +190,13 @@ class epilogue_t< [[maybe_unused]] uint32_t nbarrier_base = 0) { using mat_tile_desc = typename matAcc_t::tile_desc; using matC_t = subgroup::tile_t; - using matC_payload_t = subgroup::mem_payload_t< - mem_desc_t, - mat_tile_desc, - msg_type_c, - arch_tag>; + + static constexpr msg_type msg_type_c = msg_type::block_2d; + // static constexpr msg_type msg_type_c = + // subgroup::msg_type_v; + + using matC_payload_t = subgroup:: + mem_payload_t; update_sg_tile_tdesc(g, mem_desc_c); diff --git a/include/group/epilogue/impl/tile_op_xe.hpp b/include/group/epilogue/impl/tile_op_xe.hpp index 656cdabde..3afaec107 100644 --- a/include/group/epilogue/impl/tile_op_xe.hpp +++ b/include/group/epilogue/impl/tile_op_xe.hpp @@ -106,10 +106,6 @@ class epilogue_t< } public: - static constexpr msg_type msg_type_c = - (mem_space_c == mem_space::global ? msg_type::block_2d - : msg_type::scatter); - /// @brief Default epilogue. 
/// 1) Call tile_op/chained_tile_op 2) Convert dtype_acc to dtype_c /// 3) Overwrite/reduce_sum to memory. @@ -131,6 +127,9 @@ class epilogue_t< uint32_t nbarrier_base = 0) { using mat_tile_desc = typename matAcc_t::tile_desc; using matC_t = subgroup::tile_t; + // static constexpr msg_type msg_type_c = msg_type::block_2d; + static constexpr msg_type msg_type_c = + subgroup::msg_type_v; using matC_payload_t = subgroup:: mem_payload_t; update_sg_tile_tdesc(g, mem_desc_c); diff --git a/include/group/gemm/compute_policy.hpp b/include/group/gemm/compute_policy.hpp index 0a0cd1c91..11033e907 100644 --- a/include/group/gemm/compute_policy.hpp +++ b/include/group/gemm/compute_policy.hpp @@ -118,16 +118,15 @@ struct compute_policy_default_fpu< static constexpr int sync_freq = perf_tuning_knob::sync_freq; static constexpr int k_stride = perf_tuning_knob::k_stride; - static constexpr uint32_t block_size_y_a = - arch_tag_ == gpu_arch::XeLpg ? 8 : 16; - static constexpr uint32_t block_bytes_x_a = 32; + static constexpr uint32_t block_size_y_a = 16; + using mma_attr = mma_attr_t; + static constexpr uint32_t block_bytes_x_a = mma_attr::mma_k_in_bytes; static constexpr uint32_t block_size_x_a = block_bytes_x_a / sizeof(dtype_mma_a); - static constexpr uint32_t block_bytes_x_b = - arch_attr_t::template register_attr<>::reg_in_bytes; - static constexpr uint32_t block_size_x_b = - block_bytes_x_b / sizeof(dtype_mma_b); - static constexpr uint32_t block_size_y_b = block_size_x_a; + static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem; + static constexpr uint32_t block_bytes_y_b = mma_attr::mma_k_in_bytes; + static constexpr uint32_t block_size_y_b = + block_bytes_y_b / sizeof(dtype_mma_b); }; /// @} xetla_gemm diff --git a/include/group/gemm/impl/default_fpu_xe.hpp b/include/group/gemm/impl/default_fpu_xe.hpp index add8e6790..7345ea322 100644 --- a/include/group/gemm/impl/default_fpu_xe.hpp +++ b/include/group/gemm/impl/default_fpu_xe.hpp @@ -150,7 +150,7 @@ class gemm_t< using matA_payload_t = subgroup::mem_payload_t< mem_desc_a_t, matA_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; // the tile size in register may bigger than in memory because of the padding using matA_acc_t = subgroup::tile_t; @@ -171,7 +171,7 @@ class gemm_t< using matB_payload_t = subgroup::mem_payload_t< mem_desc_b_t, matB_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using matB_acc_t = subgroup::tile_t; using matB_prefetch_payload_t = subgroup::prefetch_payload_t< diff --git a/include/group/gemm/impl/default_xmx_xe.hpp b/include/group/gemm/impl/default_xmx_xe.hpp index 75a0ef79c..0626f2310 100644 --- a/include/group/gemm/impl/default_xmx_xe.hpp +++ b/include/group/gemm/impl/default_xmx_xe.hpp @@ -136,7 +136,7 @@ class gemm_t< using matA_payload_t = subgroup::mem_payload_t< mem_desc_a_t, matA_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using matA_acc_t = subgroup::tile_t; using matA_prefetch_payload_t = subgroup::prefetch_payload_t< @@ -157,7 +157,7 @@ class gemm_t< using matB_payload_t = subgroup::mem_payload_t< mem_desc_b_t, matB_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; using matB_acc_t = subgroup::tile_t; using matB_prefetch_payload_t = subgroup::prefetch_payload_t< diff --git a/include/group/reduction/reduction_xe.hpp b/include/group/reduction/reduction_xe.hpp index ee39acc80..d07873551 100644 --- a/include/group/reduction/reduction_xe.hpp +++ b/include/group/reduction/reduction_xe.hpp @@ -41,14 +41,17 @@ 
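The compute_policy_default_fpu hunk above stops hard-coding the FPU block shape and derives it from the architecture's mma attributes: the k extent of an A block is mma_k_in_bytes converted to elements of the mma dtype, and the B block takes mma_n_in_elem for n and the same k bytes for its y extent. A small sketch of that bytes-to-elements derivation; the attribute struct here is a made-up stand-in for mma_attr_t, not the real arch table:

    #include <cstdint>

    struct mma_attr_stub {                       // stand-in for mma_attr_t<arch, ...>
      static constexpr uint32_t mma_k_in_bytes = 32;
      static constexpr uint32_t mma_n_in_elem = 16;
    };

    template <typename dtype_mma, typename mma_attr = mma_attr_stub>
    struct fpu_block_shape {
      static constexpr uint32_t block_size_x_a = mma_attr::mma_k_in_bytes / sizeof(dtype_mma);
      static constexpr uint32_t block_size_x_b = mma_attr::mma_n_in_elem;
      static constexpr uint32_t block_size_y_b = mma_attr::mma_k_in_bytes / sizeof(dtype_mma);
    };

    static_assert(fpu_block_shape<float>::block_size_x_a == 8); // 32 bytes of fp32 per k step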
struct group_reduce_t { subgroup::tile_desc_t; using local_ld_t = subgroup::tile_t; using local_st_t = subgroup::tile_t; + using mem_desc_ld_t = mem_desc_t; using local_ld_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_ld_t, local_ld_tile_desc, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; + using mem_desc_st_t = mem_desc_t; using local_st_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_st_t, local_st_tile_desc, + // subgroup::msg_type_v, msg_type::block_1d, gpu_arch::XeHpc>; inline group_reduce_t() = default; diff --git a/include/group/softmax/impl/softmax_bwd_xe.hpp b/include/group/softmax/impl/softmax_bwd_xe.hpp index dd31f9f1e..84d41643e 100644 --- a/include/group/softmax/impl/softmax_bwd_xe.hpp +++ b/include/group/softmax/impl/softmax_bwd_xe.hpp @@ -105,7 +105,7 @@ class softmax_t< using mat_in_payload_t = subgroup::mem_payload_t< mem_desc_in_t, mat_in_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; int32_t sg_idx = g.get_id() % wg_size_x; diff --git a/include/kernel/gemm/impl/default_xe.hpp b/include/kernel/gemm/impl/default_xe.hpp index cb6c5270b..644189db2 100644 --- a/include/kernel/gemm/impl/default_xe.hpp +++ b/include/kernel/gemm/impl/default_xe.hpp @@ -275,18 +275,19 @@ class gemm_universal_t< args.matB_base.base, args.matB_ld); } } - if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m, - args.matC_ld); - } else { - implementable &= kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_t::msg_type_c == msg_type::block_2d) { + // implementable &= kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n, + // args.matrix_m, + // args.matC_ld); + // } else { + // implementable &= kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } return implementable; } diff --git a/include/kernel/gemm/impl/kslicing_xe.hpp b/include/kernel/gemm/impl/kslicing_xe.hpp index 7b74226e5..64d15edb6 100644 --- a/include/kernel/gemm/impl/kslicing_xe.hpp +++ b/include/kernel/gemm/impl/kslicing_xe.hpp @@ -387,18 +387,19 @@ class gemm_universal_t< args.matB_base.base, args.matB_ld); } } - if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m, - args.matC_ld); - } else { - implementable &= kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_t::msg_type_c == msg_type::block_2d) { + // implementable &= kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n, + // args.matrix_m, + // args.matC_ld); + // } else { + // implementable &= kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } return implementable; } @@ -426,12 +427,8 @@ class gemm_universal_t< int start_n = group_swizzle.template get_tile_idx<2>(item) * wg_tile_n; int start_k = 0; uint32_t wg_tile_k = args.matrix_k; - uint32_t boundary_n = (start_n + wg_tile_n) > args.matrix_n - ? 
args.matrix_n - : (start_n + wg_tile_n); - uint32_t boundary_m = (start_m + wg_tile_m) > args.matrix_m - ? args.matrix_m - : (start_m + wg_tile_m); + uint32_t boundary_n = std::min(start_n + wg_tile_n, args.matrix_n); + uint32_t boundary_m = std::min(start_m + wg_tile_m, args.matrix_m); uint32_t boundary_k = wg_tile_k; if constexpr (num_global_kslicing > 1) { wg_tile_k = (wg_tile_k + num_global_kslicing - 1) / num_global_kslicing; diff --git a/include/kernel/gemm/impl/stream_k_xe.hpp b/include/kernel/gemm/impl/stream_k_xe.hpp index e281e53ae..0a23344bf 100644 --- a/include/kernel/gemm/impl/stream_k_xe.hpp +++ b/include/kernel/gemm/impl/stream_k_xe.hpp @@ -329,18 +329,19 @@ class gemm_universal_t< args.matB_base.base, args.matB_ld); } } - if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { - if (epilogue_t::msg_type_c == msg_type::block_2d) { - implementable &= kernel::block_2d::check_tensor( - (uint64_t)(args.matC_base.base), - args.matrix_n, - args.matrix_m, - args.matC_ld); - } else { - implementable &= kernel::general_1d::check_alignment( - args.matC_base.base, args.matC_ld); - } - } + // if (epilogue_t::msg_type_c != msg_type::unaligned_2d) { + // if (epilogue_t::msg_type_c == msg_type::block_2d) { + // implementable &= kernel::block_2d::check_tensor( + // (uint64_t)(args.matC_base.base), + // args.matrix_n, + // args.matrix_m, + // args.matC_ld); + // } else { + // implementable &= kernel::general_1d::check_alignment( + // args.matC_base.base, args.matC_ld); + // } + // } return implementable; } diff --git a/include/subgroup/tile/chained_tile_op.hpp b/include/subgroup/tile/chained_tile_op.hpp index 3d5698a01..294295082 100644 --- a/include/subgroup/tile/chained_tile_op.hpp +++ b/include/subgroup/tile/chained_tile_op.hpp @@ -71,6 +71,9 @@ struct chained_tile_op_arg_t inline chained_tile_op_arg_t( chained_tile_op_arg_t const& args) = default; + inline chained_tile_op_arg_t& operator=( + chained_tile_op_arg_t const& args) = + default; template inline T get() const { diff --git a/include/subgroup/tile/common.hpp b/include/subgroup/tile/common.hpp index 9385c700f..0a30a1416 100644 --- a/include/subgroup/tile/common.hpp +++ b/include/subgroup/tile/common.hpp @@ -132,23 +132,21 @@ __XETLA_API typename std::enable_if_t< base_len != 0 && payload_t::memory_space == mem_space::global> process_1d_tail(tile_t& tile, payload_t& payload, uint32_t offset) { using dtype = typename payload_t::dtype; - using mem_dtype = typename payload_t::mem_dtype; if constexpr (remained_len >= base_len) { uint32_t address_offset = offset * sizeof(dtype); - auto reg_sub = - tile.reg.xetla_select(offset); + auto reg_sub = tile.reg.xetla_select(offset); if constexpr (flag == process_flag::load) { - reg_sub.xetla_format() = - xetla_load_global( + reg_sub.xetla_format() = + xetla_load_global( payload.base_ptr, payload.base_offset + address_offset); } else { - xetla_store_global( + xetla_store_global( payload.base_ptr, payload.base_offset + address_offset, - reg_sub.xetla_format()); + reg_sub.xetla_format()); } process_1d_tail> 1), flag, L1, L2>( - tile, payload, offset + base_len * payload_t::scale_factor); + tile, payload, offset + base_len); } else { process_1d_tail> 1), flag, L1, L2>( tile, payload, offset); @@ -165,23 +163,21 @@ template < typename tile_t> __XETLA_API typename std::enable_if_t< base_len != 0 && payload_t::memory_space == mem_space::local> -process_1d_tail( - tile_t& tile, - payload_t& payload, - uint32_t offset, - uint32_t address_offset) { - using mem_dtype = typename payload_t::mem_dtype; 
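process_1d_tail now measures everything in bytes: base_len starts at half the architecture's max_load_vec_len and is halved on every recursion step, so an arbitrary remainder is covered by a descending sum of power-of-two transfers and the recursion bottoms out at base_len == 0. A standalone model of the same decomposition using plain recursion instead of the payload/tile machinery:

    #include <cstdint>
    #include <cstdio>

    // Issue one transfer of base_len bytes whenever it still fits in the remaining
    // tail, then retry with base_len halved; terminates when base_len reaches zero.
    template <uint32_t remained_len, uint32_t base_len>
    void process_tail(uint32_t offset) {
      if constexpr (base_len != 0) {
        if constexpr (remained_len >= base_len) {
          std::printf("copy %u bytes at byte offset %u\n", base_len, offset);
          process_tail<remained_len - base_len, (base_len >> 1)>(offset + base_len);
        } else {
          process_tail<remained_len, (base_len >> 1)>(offset);
        }
      }
    }

    // process_tail<52, 32>(0) issues 32 + 16 + 4 byte transfers, covering all 52 bytes.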
+process_1d_tail(tile_t& tile, payload_t& payload, uint32_t offset) { + using dtype = typename payload_t::dtype; if constexpr (remained_len >= base_len) { - auto reg_sub = - tile.reg.xetla_select(offset); + uint32_t address_offset = offset * sizeof(dtype); + auto reg_sub = tile.reg.xetla_select(offset); if constexpr (flag == process_flag::load) { - reg_sub.xetla_format() = - xetla_load_local( - payload.base_address + payload.address + address_offset); + reg_sub.xetla_format() = xetla_load_local< + dtype, + base_len / sizeof(dtype), + data_size::default_size>( + payload.base_address + payload.address + address_offset); } else { - xetla_store_local( + xetla_store_local( payload.base_address + payload.address + address_offset, - reg_sub.xetla_format()); + reg_sub.xetla_format()); } process_1d_tail< remained_len - base_len, @@ -190,13 +186,7 @@ process_1d_tail( L1, L2, payload_t, - tile_t>( - tile, - payload, - offset + base_len * payload_t::scale_factor, - address_offset + - base_len * payload_t::scale_factor * - sizeof(typename tile_t::dtype)); + tile_t>(tile, payload, offset + base_len); } else { process_1d_tail< remained_len, @@ -205,7 +195,7 @@ process_1d_tail( L1, L2, payload_t, - tile_t>(tile, payload, offset, address_offset); + tile_t>(tile, payload, offset); } } @@ -315,14 +305,17 @@ struct is_floating_to_integer { is_integral::value; }; -template < - typename tile_desc_, - mem_space memory_space, - mem_layout memory_layout = mem_layout::row_major> +template struct msg_type_query { + using dtype = mem_desc_::dtype; + static constexpr mem_layout memory_layout = mem_desc_::layout; + static constexpr mem_space memory_space = mem_desc_::space; + static constexpr msg_type value = memory_space == mem_space::global - ? (((tile_desc_::tile_size_y == 1) && - (memory_layout == mem_layout::row_major)) + ? (((tile_desc_::tile_size_y == 1 && + memory_layout == mem_layout::row_major) || + (tile_desc_::tile_size_x == 1 && + memory_layout == mem_layout::col_major)) ? msg_type::block_1d : msg_type::block_2d) : (((tile_desc_::tile_size_y == 1) && @@ -331,8 +324,8 @@ struct msg_type_query { : msg_type::scatter); }; -template -constexpr msg_type msg_type_v = msg_type_query::value; +template +constexpr msg_type msg_type_v = msg_type_query::value; template < typename dtype, diff --git a/include/subgroup/tile/impl/fma_xe.hpp b/include/subgroup/tile/impl/fma_xe.hpp index c1ca0c6ff..e81d8d7df 100644 --- a/include/subgroup/tile/impl/fma_xe.hpp +++ b/include/subgroup/tile/impl/fma_xe.hpp @@ -24,6 +24,150 @@ namespace gpu::xetla::subgroup { /// @brief Is the tile mma operation functor, specialized for Xe and fpu engine. 
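msg_type_query now takes the memory descriptor type instead of a bare mem_space, so the choice can see the layout as well: a global tile collapses to block_1d either when it is a single row of a row-major buffer or a single column of a col-major buffer, and stays block_2d otherwise, while the local-memory branch keeps its old rule (one row of a row-major tile is block_1d, anything else a scatter). A compact restatement of the decision table with plain enums:

    enum class mem_space { global, local };
    enum class mem_layout { row_major, col_major };
    enum class msg_type { block_1d, block_2d, scatter };

    constexpr msg_type pick_msg_type(mem_space space, mem_layout layout,
                                     unsigned tile_size_x, unsigned tile_size_y) {
      const bool one_row = tile_size_y == 1 && layout == mem_layout::row_major;
      const bool one_col = tile_size_x == 1 && layout == mem_layout::col_major;
      if (space == mem_space::global)
        return (one_row || one_col) ? msg_type::block_1d : msg_type::block_2d;
      return one_row ? msg_type::block_1d : msg_type::scatter;
    }

    // A one-column tile of a col-major global buffer is now a 1-D block message.
    static_assert(pick_msg_type(mem_space::global, mem_layout::col_major, 1, 64)
                  == msg_type::block_1d);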
+template < + typename matAcc_t_, + typename matC_t_, + typename matB_t_, + typename matA_t_, + gpu_arch arch_tag_> +struct tile_fma_t { + using matA_t = matA_t_; + using matB_t = matB_t_; + using matC_t = matC_t_; + using matAcc_t = matAcc_t_; + using dtype_a = typename matA_t::dtype; + using dtype_b = typename matB_t::dtype; + using dtype_acc = typename matAcc_t_::dtype; + + using register_attr = + typename arch_attr_t::template register_attr<>; + + static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y; + static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x; + static constexpr uint32_t a_tile_elems = matA_t::tile_elems; + static constexpr uint32_t a_block_size_y = matA_t::block_size_y; + static constexpr uint32_t a_block_size_x = matA_t::block_size_x; + static constexpr uint32_t a_block_elems = matA_t::block_elems; + + static constexpr uint32_t b_tile_size_x = matB_t::tile_size_x; + static constexpr uint32_t b_tile_size_y = matB_t::tile_size_y; + static constexpr uint32_t b_tile_elems = matB_t::tile_elems; + static constexpr uint32_t b_block_size_x = matB_t::block_size_x; + static constexpr uint32_t b_block_size_y = matB_t::block_size_y; + static constexpr uint32_t b_block_elems = matB_t::block_elems; + + static constexpr uint32_t tile_size_m = a_tile_size_y; + static constexpr uint32_t tile_size_k = a_tile_size_x; + static constexpr uint32_t tile_size_n = b_tile_size_x; + static constexpr uint32_t block_size_m = a_block_size_y; + static constexpr uint32_t block_size_k = a_block_size_x; + static constexpr uint32_t block_size_n = b_block_size_x; + + static_assert( + a_tile_size_x == b_tile_size_y, + "matA tile k should match with matB tile k"); + static_assert( + a_block_size_x == b_block_size_y, + "matA block k should match with matB block k"); + static_assert( + b_block_size_y == matAcc_t::block_size_x, + "matA block k should match with matAcc block k"); + static_assert( + a_block_size_y == matAcc_t::block_size_y, + "mata block m should match with matAcc block m"); + + static_assert(tile_size_n == 1, "matB tile n must be 1"); + static_assert(b_block_size_x == 1, "matB block n must be 1"); + __XETLA_API static void mma( + matAcc_t& acc_dst, + matAcc_t& acc_src, + matC_t& c, + matB_t& b, + matA_t& a, + bool reduce) { +#pragma unroll + for (uint32_t k = 0; k < tile_size_k / block_size_k; k++) { +#pragma unroll + for (uint32_t m = 0; m < tile_size_m / block_size_m; m++) { + uint32_t a_block_idx = m * tile_size_k / block_size_k + k; + auto a_block = a.reg.xetla_select( + a_block_idx * matA_t::block_elems); +#pragma unroll + for (uint32_t n = 0; n < tile_size_n / block_size_n; n++) { + uint32_t b_block_idx = n * tile_size_k / block_size_k + k; + auto b_block = b.reg.xetla_select( + b_block_idx * matB_t::block_elems); + + auto src_block = + acc_src.reg.xetla_select( + m * matAcc_t::block_elems); + auto dst_block = + acc_dst.reg.xetla_select( + m * matAcc_t::block_elems); + fma_core( + dst_block, src_block, b_block, a_block); + } + } + } + if (reduce) { + reduce_acc_k(acc_dst, c); + } + } + template + __XETLA_API static void fma_core( + xetla_vector_ref __REF__ dst_block, + xetla_vector_ref __REF__ src_block, + xetla_vector_ref __REF__ b_block, + xetla_vector_ref __REF__ a_block) { + static_assert(blk_n == 1, "block n must be 1"); + auto dst_blk_2d = dst_block.xetla_format(); + auto src_blk_2d = src_block.xetla_format(); + auto b_blk_2d = b_block.xetla_format(); + auto a_blk_2d = a_block.xetla_format(); + +#pragma unroll + for (uint32_t m = 0; m < blk_m; m++) { + auto a_row = 
a_blk_2d.row(m); +#pragma unroll + for (uint32_t n = 0; n < blk_n; n++) { + auto b_row = b_blk_2d.row(n); + dst_blk_2d.row(m) = b_row * a_row + src_blk_2d.row(m); + } + } + } + __XETLA_API static void reduce_acc_k(matAcc_t& matAcc, matC_t& matC) { + // matC [tx,ty,bx,by](matmul): tile_n, 1, block_n, 1 + // matAcc[tx,ty,bx,by](matmul): tile_n, block_k, block_n, block_k + // matAcc[tx,ty,bx,by](memory): block_k, tile_n, block_k, block_n + + // static_assert( + // matC_t::tile_size_y == 1 && matC_t::block_size_y == 1, + // "matDst_t_ tile m and block m should match be 1"); + static_assert( + matAcc_t::tile_size_y == matC_t::tile_size_y, + "matAcc_t tile m should match with matDst_t_ tile m"); + static_assert( + matAcc_t::block_size_y == matC_t::block_size_y, + "matAcc_t block m should match with matDst_t_ block m"); + static constexpr uint32_t block_k = matAcc_t::block_size_x; + static constexpr uint32_t block_m = matAcc_t::block_size_y; + using dtype = matAcc_t::dtype; + +#pragma unroll + for (uint32_t m = 0; m < a_tile_size_y / a_block_size_y; m++) { + matC.reg.xetla_select(m * block_m) = + recur_col_reduce( + matAcc.reg.xetla_select( + m * block_m * block_k)); + // matC.reg = + // recur_col_reduce(matAcc.reg); + } + } +}; + +/// @brief Is the tile mma operation functor, specialized for Xe and fpu +/// engine. template < typename matAcc_dst_t_, typename matAcc_src_t_, @@ -60,8 +204,8 @@ struct tile_mma_t< static constexpr uint32_t a_tile_size_y = matA_t::tile_size_y; static constexpr uint32_t a_tile_size_x = matA_t::tile_size_x; static constexpr uint32_t a_tile_elems = matA_t::tile_elems; - static constexpr uint32_t a_block_size_w = matA_t::block_size_y; - static constexpr uint32_t a_block_size_h = matA_t::block_size_x; + static constexpr uint32_t a_block_size_y = matA_t::block_size_y; + static constexpr uint32_t a_block_size_x = matA_t::block_size_x; static constexpr uint32_t a_block_elems = matA_t::block_elems; static constexpr uint32_t b_tile_size_x = matB_t::tile_size_x; @@ -76,7 +220,7 @@ struct tile_mma_t< static constexpr uint32_t tile_size_n = matDst_t::tile_size_x; static constexpr uint32_t tile_elems = tile_size_m * tile_size_n; static constexpr uint32_t block_size_n = matDst_t::block_size_x; - static constexpr uint32_t block_size_k = a_block_size_h; + static constexpr uint32_t block_size_k = a_block_size_x; static constexpr uint32_t block_size_m = matDst_t::block_size_y; static constexpr uint32_t block_elems = block_size_m * block_size_n; @@ -90,7 +234,7 @@ struct tile_mma_t< tile_size_n == matB_t::tile_size_x, "matAcc tile n should match with matB tile n"); static_assert( - block_size_m == a_block_size_w, + block_size_m == a_block_size_y, "matAcc block m should match with matA block m"); static_assert( block_size_n == b_block_size_x, @@ -194,7 +338,7 @@ struct tile_mma_t< constexpr uint32_t tail_start_m = tile_size_m / block_size_m * block_size_m; constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m; - constexpr uint32_t a_tail_blk_elems = a_block_size_h * a_tail_blk_w; + constexpr uint32_t a_tail_blk_elems = a_block_size_x * a_tail_blk_w; constexpr uint32_t tail_size_m = tile_size_m - tail_start_m; constexpr uint32_t acc_tail_blk_elems = tail_size_m * block_size_n; auto a_block = a.reg.xetla_select( @@ -236,7 +380,7 @@ struct tile_mma_t< constexpr uint32_t tail_start_m = tile_size_m / block_size_m * block_size_m; constexpr uint32_t a_tail_blk_w = a_tile_size_y - tail_start_m; - constexpr uint32_t a_tail_blk_elems = a_block_size_h * a_tail_blk_w; + constexpr uint32_t 
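tile_fma_t above is a new FPU path for the degenerate n == 1 (GEMV-like) tile: fma_core accumulates elementwise products of an A row and the single B row into the accumulator, and reduce_acc_k then folds the k lanes of that accumulator into the one-column C tile. A rough standalone model of the two steps on plain arrays; the shapes and the reduction helper here are illustrative and stand in for the library's register tiles and recur_col_reduce:

    #include <array>
    #include <cstddef>

    // One fma_core step for an M x K accumulator block with n == 1:
    // dst[m][k] = a[m][k] * b[k] + src[m][k].
    template <std::size_t M, std::size_t K>
    void fma_core(std::array<std::array<float, K>, M>& dst,
                  const std::array<std::array<float, K>, M>& src,
                  const std::array<std::array<float, K>, M>& a,
                  const std::array<float, K>& b) {
      for (std::size_t m = 0; m < M; ++m)
        for (std::size_t k = 0; k < K; ++k)
          dst[m][k] = a[m][k] * b[k] + src[m][k];
    }

    // reduce_acc_k equivalent: sum the k lanes of every row into the n == 1 output.
    template <std::size_t M, std::size_t K>
    std::array<float, M> reduce_acc_k(const std::array<std::array<float, K>, M>& acc) {
      std::array<float, M> c{};
      for (std::size_t m = 0; m < M; ++m)
        for (std::size_t k = 0; k < K; ++k)
          c[m] += acc[m][k];
      return c;
    }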
a_tail_blk_elems = a_block_size_x * a_tail_blk_w; constexpr uint32_t tail_size_m = tile_size_m - tail_start_m; constexpr uint32_t acc_tail_blk_elems = tail_size_m * block_size_n; auto a_block = a.reg.xetla_select( diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index 216a57d96..c3c127bd5 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -26,17 +26,16 @@ namespace gpu::xetla::subgroup { namespace detail { -template +template struct check_load_type { - static constexpr bool is_lsc_gather = true; + static constexpr bool is_lsc_gather = is_lsc_gather_; static constexpr bool is_global_block_2d = (payload_t::memory_space == mem_space::global && (payload_t::message_type == msg_type::block_2d)); static constexpr bool is_global_block_1d = - ((payload_t::memory_space == mem_space::global) && - (tile_t::tile_size_y == 1) && (tile_t::block_size_y == 1) && - (payload_t::message_type == msg_type::block_1d)); + (payload_t::memory_space == mem_space::global && + payload_t::message_type == msg_type::block_1d); static constexpr bool is_global_unaligned_2d_xe = ((payload_t::memory_space == mem_space::global) && @@ -101,7 +100,7 @@ tile_load(tile_t& tile, payload_t& payload) { static constexpr bool reg_transpose = tile_desc::reg_transpose; static constexpr bool mem_transpose = payload_t::mem_transpose; - static constexpr bool trans = reg_transpose ^ mem_transpose; + static constexpr bool trans = payload_t::trans; static constexpr uint32_t scale_factor = payload_t::scale_factor; static constexpr bool mem_transform = payload_t::mem_transform; @@ -214,13 +213,12 @@ tile_load(tile_t& tile, payload_t& payload) { trans, mem_transform, arch_tag>(tdesc); - if constexpr (reg_transpose && trans) { reg_blk.xetla_select(ii * load_elems) .xetla_format>() = reg_tmp .xetla_format< - load_dtype, + native_type_t, block_size_x / scale_factor, ld_blk_height>() .xetla_select< @@ -395,33 +393,31 @@ template < __XETLA_API typename std::enable_if_t< detail::check_load_type::is_global_block_1d> tile_load(tile_t& tile, payload_t& payload) { - using dtype = typename tile_t::dtype; - using load_dtype = typename payload_t::mem_dtype; - - static constexpr uint32_t tile_size_x = tile_t::tile_size_x; - static constexpr uint32_t scale_factor = payload_t::scale_factor; - static constexpr uint32_t load_len = tile_size_x / scale_factor; + using dtype = typename payload_t::dtype; + static constexpr uint32_t load_len = tile_t::tile_elems; static constexpr gpu_arch arch_tag = payload_t::arch_tag; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_load_vec_len = load_store_attr::max_load_vec_len; + static constexpr uint32_t max_load_vec_elems = + max_load_vec_len / sizeof(dtype); - static constexpr uint32_t load_iter_steps = load_len / max_load_vec_len; - if constexpr (load_len >= max_load_vec_len) { + static constexpr uint32_t load_iter_steps = load_len / max_load_vec_elems; + if constexpr (load_len >= max_load_vec_elems) { #pragma unroll for (uint32_t i = 0; i < load_iter_steps; i++) { - uint32_t offset_x = i * max_load_vec_len * scale_factor; - auto reg_sub = - tile.reg.xetla_select(offset_x); + uint32_t offset_x = i * max_load_vec_elems; + auto reg_sub = tile.reg.xetla_select(offset_x); uint32_t address_offset = offset_x * sizeof(dtype); - reg_sub.xetla_format() = - xetla_load_global( + reg_sub.xetla_format() = + xetla_load_global( payload.base_ptr, payload.base_offset + address_offset); } } - constexpr uint32_t tail_len = 
load_len % max_load_vec_len; - uint32_t tail_offset = load_iter_steps * max_load_vec_len * scale_factor; + constexpr uint32_t tail_len = load_len % max_load_vec_elems * sizeof(dtype); + uint32_t tail_offset = load_iter_steps * max_load_vec_len; detail::process_1d_tail< tail_len, (max_load_vec_len >> 1), @@ -469,64 +465,73 @@ tile_load(tile_t& tile, payload_t& payload) { uint32_t offset_x = j * tile_desc::block_size_x; auto reg_sub = tile.reg.xetla_select( (i * tile_desc::num_block_x + j) * tile_desc::block_elems); -#pragma unroll - for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; - sub_block_y += num_channel) { - xetla_vector reg_tmp = 0; - uint32_t address_offset = payload_t::mem_transpose - ? offset_x * payload.pitch_in_bytes + (offset_y + 0) * sizeof(dtype) - : offset_x * sizeof(dtype) + - (offset_y + 0) * payload.pitch_in_bytes; - - const uint32_t sub_block_offset_x = payload.base_x + offset_x + 0; - const uint32_t sub_block_offset_y = - payload.base_y + offset_y + sub_block_y; + // #pragma unroll + // for (uint32_t sub_block_offset = 0; sub_block_offset < + // (payload_t::mem_transpose ? tile_desc::block_size_x + // : tile_desc::block_size_y); + // sub_block_offset += num_channel) { + uint32_t sub_block_offset = 0; + xetla_vector reg_tmp = 0; + uint32_t address_offset = payload_t::mem_transpose + ? (offset_x + sub_block_offset) * payload.pitch_in_bytes + + offset_y * sizeof(dtype) + : offset_x * sizeof(dtype) + + (offset_y + sub_block_offset) * payload.pitch_in_bytes; + xetla_mask pred = 1; + if constexpr (num_channel > 1) { + // For SDP load, need pred + const uint32_t sub_block_offset_x = payload.base_x + offset_x + + (payload_t::mem_transpose ? sub_block_offset : 0); + const uint32_t sub_block_offset_y = payload.base_y + offset_y + + (payload_t::mem_transpose ? 0 : sub_block_offset); const auto offset_ch_dim = payload_t::trans ? sub_block_offset_x : sub_block_offset_y; const auto size_ch_dim = payload_t::trans ? payload.width_in_elems : payload.height_in_elems; - xetla_mask pred = offset_ch_dim + num_channel > size_ch_dim + pred = offset_ch_dim + num_channel > size_ch_dim ? 
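One reading note on the new tail computation above: % and * have equal precedence and associate left to right, so tail_len parses as (load_len % max_load_vec_elems) * sizeof(dtype), i.e. the leftover element count converted to bytes, which is the unit process_1d_tail now expects. A one-line check of that parse:

    static_assert(50 % 16 * 2 == (50 % 16) * 2); // 2 leftover elements -> 4 tail bytes, not 50 % 32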
(xetla_vector_gen(offset_ch_dim, 1) < size_ch_dim) : 1; - - reg_tmp = xetla_load_global< - load_dtype, - payload_t::simd_exec_size, - data_size::default_size, - L1, - L2, - payload_t::num_channel>( - payload.base_ptr, - payload.channel_offset + payload.base_offset + address_offset, - pred); - if constexpr (payload_t::simd_exec_size > 1) { - xetla_vector reg_tmp_trans; + } + reg_tmp = xetla_load_global< + load_dtype, + payload_t::simd_exec_size, + data_size::default_size, + L1, + L2, + payload_t::num_channel>( + payload.base_ptr, + payload.channel_offset + payload.base_offset + address_offset, + pred); + + if constexpr ( + payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) { + xetla_vector reg_tmp_trans; #pragma unroll - for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { - if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix - reg_tmp_trans.xetla_select( - iii * payload_t::simd_exec_size) = - reg_tmp.xetla_select< - payload_t::simd_exec_size, - payload_t::num_channel>(iii); - else // TODO (dingyi): Delete after driver fix - reg_tmp_trans.xetla_select( - iii * payload_t::simd_exec_size) = 0; - } - reg_sub - .xetla_select( - sub_block_y * tile_desc::block_size_x) - .xetla_format() = reg_tmp_trans; - } else { - reg_sub - .xetla_select( - sub_block_y * tile_desc::block_size_x) - .xetla_format() = reg_tmp; + for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { + if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix + reg_tmp_trans.xetla_select( + iii * payload_t::simd_exec_size) = + reg_tmp.xetla_select< + payload_t::simd_exec_size, + payload_t::num_channel>(iii); + else // TODO (dingyi): Delete after driver fix + reg_tmp_trans.xetla_select( + iii * payload_t::simd_exec_size) = 0; } + reg_sub + .xetla_select( + sub_block_offset * tile_desc::block_size_x) + .xetla_format() = reg_tmp_trans; + } else { + reg_sub + .xetla_select( + sub_block_offset * tile_desc::block_size_x) + .xetla_format() = reg_tmp; } } + // } } if constexpr (payload_t::trans) { @@ -565,7 +570,9 @@ __XETLA_API typename std::enable_if_t< tile_load(tile_t& tile, payload_t& payload) { using dtype = typename payload_t::dtype; using tile_desc = typename payload_t::tile_desc; - constexpr uint32_t load_elems = tile_desc::block_size_x; + constexpr uint32_t load_elems = payload_t::mem_transpose + ? tile_desc::block_size_y + : tile_desc::block_size_x; #pragma unroll for (uint32_t i = 0; i < tile_desc::num_block_y; i++) { @@ -576,11 +583,13 @@ tile_load(tile_t& tile, payload_t& payload) { auto reg_sub = tile.reg.xetla_select( (i * tile_desc::num_block_x + j) * tile_desc::block_elems); #pragma unroll - for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; + for (uint32_t sub_block_y = 0; + sub_block_y < (payload_t::mem_transpose ? tile_desc::block_size_x + : tile_desc::block_size_y); sub_block_y += 1) { uint32_t address_offset = payload_t::mem_transpose - ? offset_x * payload.pitch_in_bytes + - (offset_y + sub_block_y) * sizeof(dtype) + ? 
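The scattered load above only builds a non-trivial channel mask when the sub-block can step past the tensor edge: the predicate compares a per-channel index ramp, starting at the sub-block's offset along the channel dimension (rows or columns, depending on the transpose flag), against that dimension's extent, so out-of-range channels read as zero. A standalone sketch of the mask construction:

    #include <array>
    #include <cstdint>

    // Channel i covers position (offset + i) along the channel dimension and is
    // live only while that position is still inside the tensor.
    template <uint32_t num_channel>
    std::array<bool, num_channel> channel_mask(uint32_t offset, uint32_t size) {
      std::array<bool, num_channel> pred{};
      for (uint32_t i = 0; i < num_channel; ++i)
        pred[i] = (offset + i) < size;
      return pred;
    }

    // channel_mask<4>(30, 32) yields {1, 1, 0, 0}; when offset + num_channel <= size
    // the kernel skips the ramp and uses an all-ones predicate.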
(offset_x + sub_block_y) * payload.pitch_in_bytes + + offset_y * sizeof(dtype) : offset_x * sizeof(dtype) + (offset_y + sub_block_y) * payload.pitch_in_bytes; @@ -848,46 +857,35 @@ __XETLA_API typename std::enable_if_t< detail::check_load_type::is_local_block_1d_xe> tile_load(tile_t& tile, payload_t& payload) { using dtype = typename tile_t::dtype; - using tile_desc = typename tile_t::tile_desc; - using load_dtype = typename payload_t::mem_dtype; + static constexpr uint32_t load_len = tile_t::tile_elems; + static constexpr gpu_arch arch_tag = payload_t::arch_tag; - constexpr uint32_t scale_factor = payload_t::scale_factor; - constexpr uint32_t load_len = tile_desc::tile_size_x / scale_factor; - constexpr gpu_arch arch_tag = payload_t::arch_tag; using load_store_attr = load_store_attr_t; - constexpr uint32_t max_load_vec_len = load_store_attr::max_load_vec_len; + static constexpr uint32_t max_load_vec_len = + load_store_attr::max_load_vec_len; + static constexpr uint32_t max_load_vec_elems = + max_load_vec_len / sizeof(dtype); - constexpr uint32_t load_iter_steps = load_len / max_load_vec_len; -#pragma unroll - for (uint32_t i = 0; i < tile_desc::tile_size_y; i++) { - uint32_t offset_y = i * tile_desc::tile_size_x; - uint32_t address_offset_y = i * payload.pitch_in_bytes; - if constexpr (load_len >= max_load_vec_len) { + static constexpr uint32_t load_iter_steps = load_len / max_load_vec_elems; + if constexpr (load_len >= max_load_vec_elems) { #pragma unroll - for (uint32_t j = 0; j < load_iter_steps; j++) { - uint32_t offset_x = j * max_load_vec_len * scale_factor; - auto reg_sub = - tile.reg.xetla_select( - offset_x + offset_y); - uint32_t address_offset = address_offset_y + offset_x * sizeof(dtype); - reg_sub.xetla_format() = xetla_load_local< - load_dtype, - max_load_vec_len, - data_size::default_size>( - payload.base_address + payload.address + address_offset); - } + for (uint32_t j = 0; j < load_iter_steps; j++) { + uint32_t offset_x = j * max_load_vec_elems; + auto reg_sub = tile.reg.xetla_select(offset_x); + uint32_t address_offset = offset_x * sizeof(dtype); + reg_sub.xetla_format() = + xetla_load_local( + payload.base_address + payload.address + address_offset); } - uint32_t tail_offset = - offset_y + load_iter_steps * max_load_vec_len * scale_factor; - uint32_t tail_address_offset = address_offset_y + - load_iter_steps * max_load_vec_len * scale_factor * sizeof(dtype); - detail::process_1d_tail< - load_len % max_load_vec_len, - (max_load_vec_len >> 1), - detail::process_flag::load, - L1, - L2>(tile, payload, tail_offset, tail_address_offset); } + constexpr uint32_t tail_len = load_len % max_load_vec_elems * sizeof(dtype); + uint32_t tail_offset = load_iter_steps * max_load_vec_len; + detail::process_1d_tail< + tail_len, + (max_load_vec_len >> 1), + detail::process_flag::load, + L1, + L2>(tile, payload, tail_offset); } } // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index 85e83b45b..8c43b5a32 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -706,16 +706,18 @@ layout_convert(T_dst& dst, T_src& src) { template void dump_mat( T mat, - size_t tile_x = T::tile_size_x, - size_t tile_y = T::tile_size_y) { + size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x, + size_t tile_y = T::reg_transpose ? 
T::tile_size_x : T::tile_size_y) { #pragma unroll for (size_t row = 0; row < tile_y; row++) { #pragma unroll for (size_t col = 0; col < tile_x; col++) { sycl::ext::oneapi::experimental::printf( - "%d ", (int)(sycl::half)mat.reg[row * tile_x + col]); + "%x(%d) ", + int(native_type_t(mat.reg[row * tile_x + col])), + int(native_type_t(mat.reg[row * tile_x + col]))); } - sycl::ext::oneapi::experimental::printf("\n "); + sycl::ext::oneapi::experimental::printf("\n"); } sycl::ext::oneapi::experimental::printf("\n "); } @@ -728,9 +730,9 @@ void dump_mat_reg(T mat, size_t tile_x, size_t tile_y) { sycl::ext::oneapi::experimental::printf( "%d ", (int)(sycl::half)mat[row * tile_x + col]); } - sycl::ext::oneapi::experimental::printf("\n "); + sycl::ext::oneapi::experimental::printf("\n"); } - sycl::ext::oneapi::experimental::printf("\n "); + sycl::ext::oneapi::experimental::printf("\n"); } } // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index c895614e0..7a98b03ab 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -70,7 +70,9 @@ struct mem_payload_t< static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = mem_transpose ^ reg_transpose; + + static constexpr bool trans = (mem_transpose ^ reg_transpose) && + !(std::is_same_v || std::is_same_v); static constexpr bool mem_transform = (sizeof(dtype) < 4) && !mem_transpose && (register_layout == reg_layout::vnni_tiled || @@ -84,7 +86,7 @@ struct mem_payload_t< xetla_vector payloads; inline mem_payload_t(const this_payload_t& rhs) { - this->payload = rhs.payload; + this->payloads = rhs.payloads; } inline mem_payload_t(mem_desc_t& mem_desc) { @@ -157,7 +159,7 @@ struct mem_payload_t< // ~mem_payload_t(){} inline this_payload_t& operator=(const this_payload_t& rhs) { - this->payload = rhs.payload; + this->payloads = rhs.payloads; return *this; } @@ -397,19 +399,20 @@ template < typename dtype_, typename tile_desc_, gpu_arch arch_tag_, - uint32_t alignment_> + uint32_t alignment_, + mem_layout memory_layout_> struct mem_payload_t< - mem_desc_t, + mem_desc_t, tile_desc_, msg_type::block_1d, arch_tag_, std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { using mem_desc_t = - mem_desc_t; - using dtype = dtype_; + mem_desc_t; + using dtype = native_type_t; using tile_desc = tile_desc_; static constexpr mem_space memory_space = mem_space::global; - static constexpr mem_layout memory_layout = mem_layout::row_major; + static constexpr mem_layout memory_layout = memory_layout_; static constexpr msg_type message_type = msg_type::block_1d; static constexpr uint32_t alignment_in_bytes = mem_desc_t::alignment_in_bytes; static constexpr gpu_arch arch_tag = arch_tag_; @@ -421,75 +424,104 @@ struct mem_payload_t< static constexpr uint32_t tile_size_x = tile_desc::tile_size_x; static constexpr uint32_t tile_size_y = tile_desc::tile_size_y; static_assert( - tile_size_y == 1, - "For tile_size_y > 1 case, please use 2d block message! "); + (tile_size_y == 1 && memory_layout == mem_layout::row_major) || + (tile_size_x == 1 && memory_layout == mem_layout::col_major), + "For tile_size_y > 1 or tile_size_x > 1 case, please use 2d block message! 
"); using this_payload_t = mem_payload_t; + static constexpr bool mem_transpose = memory_layout == mem_layout::col_major; public: - static constexpr uint32_t bytes_per_row = tile_size_x * sizeof(dtype); - using mem_dtype = typename std::conditional< - (bytes_per_row % sizeof(uint64_t) == 0) && - (alignment_in_bytes % sizeof(uint64_t) == 0), - uint64_t, - typename std::conditional< - (bytes_per_row % sizeof(uint32_t) == 0), - uint32_t, - dtype>::type>::type; - static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype); - uint64_t base_offset; - mem_dtype* base_ptr; + dtype* base_ptr; uint32_t pitch_in_bytes; + uint32_t height_in_elems; + uint32_t width_in_elems; + uint32_t payload_bytes; inline mem_payload_t(mem_desc_t& mem_tdesc) { pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); + width_in_elems = mem_tdesc.shape.x; + height_in_elems = mem_tdesc.shape.y; + payload_bytes = mem_transpose ? (mem_tdesc.shape.x - 1) * pitch_in_bytes + + mem_tdesc.shape.y * sizeof(dtype) + : (mem_tdesc.shape.y - 1) * pitch_in_bytes + + mem_tdesc.shape.x * sizeof(dtype); uint32_t offset_x = mem_tdesc.coord.x; uint32_t offset_y = mem_tdesc.coord.y; - base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); - base_ptr = (mem_dtype*)mem_tdesc.base.base; + base_offset = mem_transpose + ? offset_x * pitch_in_bytes + offset_y * sizeof(dtype) + : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + base_ptr = (dtype*)mem_tdesc.base.base; } inline mem_payload_t( dtype* p, - [[maybe_unused]] int surface_width, - [[maybe_unused]] int surface_height, + int surface_width, + int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y) { pitch_in_bytes = surface_pitch * sizeof(dtype); uint32_t offset_x = surface_offset_x; uint32_t offset_y = surface_offset_y; - base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); - base_ptr = (mem_dtype*)p; + width_in_elems = surface_width; + height_in_elems = surface_height; + payload_bytes = mem_transpose ? (surface_offset_x - 1) * pitch_in_bytes + + surface_offset_y * sizeof(dtype) + : (surface_offset_y - 1) * pitch_in_bytes + + surface_offset_x * sizeof(dtype); + base_offset = mem_transpose + ? offset_x * pitch_in_bytes + offset_y * sizeof(dtype) + : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + base_ptr = (dtype*)p; } __XETLA_API void init(mem_desc_t& mem_tdesc) { pitch_in_bytes = mem_tdesc.shape.stride * sizeof(dtype); uint32_t offset_x = mem_tdesc.coord.x; uint32_t offset_y = mem_tdesc.coord.y; - base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); - base_ptr = (mem_dtype*)mem_tdesc.base.base; + width_in_elems = mem_tdesc.shape.x; + height_in_elems = mem_tdesc.shape.y; + payload_bytes = mem_transpose ? (mem_tdesc.shape.x - 1) * pitch_in_bytes + + mem_tdesc.shape.y * sizeof(dtype) + : (mem_tdesc.shape.y - 1) * pitch_in_bytes + + mem_tdesc.shape.x * sizeof(dtype); + + base_offset = mem_transpose + ? 
offset_x * pitch_in_bytes + offset_y * sizeof(dtype) + : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + base_ptr = (dtype*)mem_tdesc.base.base; } __XETLA_API void init( dtype* p, - [[maybe_unused]] int surface_width, - [[maybe_unused]] int surface_height, + int surface_width, + int surface_height, int surface_pitch, int surface_offset_x, int surface_offset_y) { pitch_in_bytes = surface_pitch * sizeof(dtype); uint32_t offset_x = surface_offset_x; uint32_t offset_y = surface_offset_y; - base_offset = offset_y * pitch_in_bytes + offset_x * sizeof(dtype); - base_ptr = (mem_dtype*)p; + width_in_elems = surface_width; + height_in_elems = surface_height; + payload_bytes = mem_transpose + ? (surface_width - 1) * pitch_in_bytes + surface_height * sizeof(dtype) + : (surface_height - 1) * pitch_in_bytes + surface_width * sizeof(dtype); + base_offset = mem_transpose + ? offset_x * pitch_in_bytes + offset_y * sizeof(dtype) + : offset_y * pitch_in_bytes + offset_x * sizeof(dtype); + base_ptr = (dtype*)p; } inline mem_payload_t(const this_payload_t& rhs) { this->base_offset = rhs.base_offset; this->base_ptr = rhs.base_ptr; this->pitch_in_bytes = rhs.pitch_in_bytes; + this->width_in_elems = rhs.width_in_elems; + this->height_in_elems = rhs.height_in_elems; + this->payload_bytes = rhs.payload_bytes; } inline mem_payload_t() = default; @@ -497,6 +529,9 @@ struct mem_payload_t< this->base_offset = rhs.base_offset; this->base_ptr = rhs.base_ptr; this->pitch_in_bytes = rhs.pitch_in_bytes; + this->width_in_elems = rhs.width_in_elems; + this->height_in_elems = rhs.height_in_elems; + this->payload_bytes = rhs.payload_bytes; return *this; } @@ -1066,9 +1101,8 @@ struct mem_payload_t< tile_desc_, msg_type::block_2d, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpg)>> { - using dtype = - std::conditional_t, uint8_t, dtype_>; + std::enable_if_t<(arch_has_2d_load_store)>> { + using dtype = native_type_t; using mem_desc_t = mem_desc_t; using tile_desc = tile_desc_; @@ -1084,11 +1118,6 @@ struct mem_payload_t< static constexpr uint32_t block_size_x = tile_desc::block_size_x; static constexpr uint32_t block_size_y = tile_desc::block_size_y; - static constexpr uint32_t block_per_row_bytes = - alignment_in_bytes < block_size_x * sizeof(dtype) - ? alignment_in_bytes - : block_size_x * sizeof(dtype); - using this_payload_t = mem_payload_t; @@ -1097,7 +1126,8 @@ struct mem_payload_t< static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = mem_transpose ^ reg_transpose; + static constexpr bool trans = (mem_transpose ^ reg_transpose) && + !(std::is_same_v || std::is_same_v); static constexpr bool mem_transform = (sizeof(dtype) < 4) && (register_layout == reg_layout::vnni_tiled || @@ -1109,6 +1139,11 @@ struct mem_payload_t< block_size_x * block_size_y * sizeof(dtype); // using mem_dtype = uint32_t; + + static constexpr uint32_t block_per_row_bytes = std::min( + (mem_transpose ? 
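Because the 1-D global payload now carries the memory layout, the linear base offset swaps the roles of x and y for column-major buffers: the coordinate along the strided dimension is scaled by pitch_in_bytes and the other coordinate by sizeof(dtype), as the init() overloads above compute. A minimal restatement:

    #include <cstdint>

    // Byte offset of element (x, y) in a strided 2-D buffer; pitch_in_bytes is the
    // stride of the outer dimension (rows when row-major, columns when col-major).
    constexpr uint64_t linear_base_offset(bool mem_transpose, uint32_t x, uint32_t y,
                                          uint32_t pitch_in_bytes, uint32_t elem_bytes) {
      return mem_transpose
          ? uint64_t(x) * pitch_in_bytes + uint64_t(y) * elem_bytes   // col-major
          : uint64_t(y) * pitch_in_bytes + uint64_t(x) * elem_bytes;  // row-major
    }

    static_assert(linear_base_offset(false, 3, 2, 256, 4) == 2 * 256 + 3 * 4);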
block_size_y : block_size_x) * uint32_t(sizeof(dtype)), + alignment_in_bytes); + using mem_dtype = typename std::conditional< (block_per_row_bytes % sizeof(uint64_t) == 0), uint64_t, @@ -1616,10 +1651,11 @@ struct prefetch_payload_t< reg_layout_>, num_coop_sg_, arch_tag_, - std::enable_if_t<( - arch_tag_ <= gpu_arch::XeHpg && - (tile_size_y_ != 1 || block_size_y_ != 1))>> { - using dtype = dtype_; + std::enable_if_t< + arch_has_2d_load_store && + ((block_size_y_ != 1 && mem_layout_ == mem_layout::row_major) || + (block_size_x_ != 1 && mem_layout_ == mem_layout::col_major))>> { + using dtype = native_type_t; using mem_desc_t = mem_desc_t; using tile_desc = tile_desc_t< @@ -1653,7 +1689,8 @@ struct prefetch_payload_t< static constexpr reg_layout register_layout = tile_desc::register_layout; static constexpr bool reg_transpose = register_layout == reg_layout::transpose_tiled; - static constexpr bool trans = mem_transpose ^ reg_transpose; + static constexpr bool trans = (mem_transpose ^ reg_transpose) && + !(std::is_same_v || std::is_same_v); using prefetch_dtype = typename std::conditional< (alignment_in_bytes % (sizeof(uint64_t)) == 0), @@ -1665,12 +1702,22 @@ struct prefetch_payload_t< static constexpr uint32_t pack_factor = sizeof(prefetch_dtype) / sizeof(dtype); + static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype); + static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype); + static constexpr uint32_t simd_channel = + ((tile_bytes % max_store_bytes) == 0 && + (block_bytes % max_store_bytes) == 0) + ? 32 + : 16; + static constexpr uint32_t num_channel = mem_transpose + ? (simd_channel >= block_size_x) ? block_size_x : simd_channel + : (simd_channel >= block_size_y) ? block_size_y + : simd_channel; + static constexpr uint32_t simd_exec_size = (mem_transpose ? block_size_y : block_size_x) >= pack_factor ? (mem_transpose ? block_size_y : block_size_x) / pack_factor : 1; - static constexpr uint32_t num_channel = - mem_transpose ? block_size_x : block_size_y; static constexpr uint32_t mem_tile_size_w = mem_transpose ? 
tile_size_y : tile_size_x; @@ -1725,9 +1772,6 @@ struct prefetch_payload_t< this->width_in_elems = rhs.width_in_elems; this->height_in_elems = rhs.height_in_elems; - this->step_x = rhs.step_x; - this->step_y = rhs.step_y; - this->channel_offset = rhs.channel_offset; } @@ -1742,8 +1786,6 @@ struct prefetch_payload_t< this->width_in_elems = rhs.width_in_elems; this->height_in_elems = rhs.height_in_elems; - this->step_x = rhs.step_x; - this->step_y = rhs.step_y; this->channel_offset = rhs.channel_offset; return *this; } @@ -1837,7 +1879,8 @@ struct prefetch_payload_t< arch_tag_, std::enable_if_t< (arch_tag_ == gpu_arch::XeHpc) && - (tile_size_y_ != 1 || block_size_y_ != 1)>> { + (((block_size_y_ != 1) && mem_layout_ == mem_layout::row_major) || + ((block_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; @@ -2108,7 +2151,9 @@ struct prefetch_payload_t< template < typename dtype_, uint32_t tile_size_x_, + uint32_t tile_size_y_, uint32_t block_size_x_, + uint32_t block_size_y_, mem_layout mem_layout_, uint32_t alignment_, uint32_t num_coop_sg_, @@ -2116,15 +2161,23 @@ template < gpu_arch arch_tag_> struct prefetch_payload_t< mem_desc_t, - tile_desc_t, + tile_desc_t< + tile_size_x_, + tile_size_y_, + block_size_x_, + block_size_y_, + reg_layout_>, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_tag_ <= gpu_arch::XeHpc)>> { + std::enable_if_t< + ((block_size_y_ == 1) && mem_layout_ == mem_layout::row_major) || + ((block_size_x_ == 1) && mem_layout_ == mem_layout::col_major)>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; // CL aligned, so we can use uint64_t using prefetch_dtype = uint64_t; + static constexpr msg_type message_type = msg_type::block_1d; using tile_desc = tile_desc_t; static constexpr mem_space memory_space = mem_space::global; static constexpr mem_layout memory_layout = mem_layout_; diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp index 0f9e2d3bb..d2eedfc77 100644 --- a/include/subgroup/tile/impl/prefetch_xe.hpp +++ b/include/subgroup/tile/impl/prefetch_xe.hpp @@ -25,23 +25,29 @@ namespace gpu::xetla::subgroup { namespace detail { + +template +struct check_prefetch_type; +template +struct check_prefetch_type< + payload_t, + std::enable_if_t> { + static constexpr bool is_global_2d = false; + static constexpr bool is_global_block_1d = false; + static constexpr bool is_global_unaligned_2d = false; + static constexpr bool is_local = true; +}; template -struct check_prefetch_type { +struct check_prefetch_type< + payload_t, + std::enable_if_t> { static constexpr bool is_global_2d = - ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y != 1)); - + payload_t::message_type == msg_type::block_2d; static constexpr bool is_global_block_1d = - ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y == 1)); - + payload_t::message_type == msg_type::block_1d; static constexpr bool is_global_unaligned_2d = - ((payload_t::memory_space == mem_space::global) && - (payload_t::tile_desc::tile_size_y != 1) && - (payload_t::message_type == msg_type::unaligned_2d)); - - static constexpr bool is_local = - (payload_t::memory_space == mem_space::local); + payload_t::message_type == msg_type::unaligned_2d; + static constexpr bool is_local = false; }; } // namespace detail @@ -104,26 +110,25 @@ tile_prefetch(payload_t& payload) { #pragma unroll for (uint32_t j = 0; j < tile_desc::num_block_x; j++) { uint32_t offset_x = j 
* tile_desc::block_size_x; -#pragma unroll - for (uint32_t sub_block_y = 0; sub_block_y < tile_desc::block_size_y; - sub_block_y += num_channel) { - uint32_t address_offset = payload_t::mem_transpose - ? offset_x * payload.pitch_in_bytes + - (offset_y + sub_block_y) * sizeof(dtype) - : offset_x * sizeof(dtype) + - (offset_y + sub_block_y) * payload.pitch_in_bytes; + // #pragma unroll + // for (uint32_t sub_block_y = 0; sub_block_y < + // tile_desc::block_size_y; + // sub_block_y += num_channel) { + uint32_t address_offset = payload_t::mem_transpose + ? offset_x * payload.pitch_in_bytes + (offset_y + 0) * sizeof(dtype) + : offset_x * sizeof(dtype) + (offset_y + 0) * payload.pitch_in_bytes; - xetla_prefetch_global< - prefetch_dtype, - payload_t::simd_exec_size, - data_size::default_size, - L1, - L2, - payload_t::num_channel>( - payload.base_ptr, - payload.channel_offset + payload.base_offset + address_offset, - 1); - } + xetla_prefetch_global< + prefetch_dtype, + payload_t::simd_exec_size, + data_size::default_size, + L1, + L2, + num_channel>( + payload.base_ptr, + payload.channel_offset + payload.base_offset + address_offset, + 1); + // } } } } @@ -150,7 +155,9 @@ tile_prefetch(payload_t& payload) { using prefetch_dtype = typename payload_t::prefetch_dtype; constexpr uint32_t prefetch_len = tile_desc::tile_size_x / payload_t::scale_factor; - constexpr uint32_t max_prefetch_in_bytes = load_store_attr_t::max_prefetch_vec_len; + constexpr uint32_t max_prefetch_in_bytes = + load_store_attr_t:: + max_prefetch_vec_len; if constexpr (prefetch_len >= max_prefetch_in_bytes) { #pragma unroll for (uint32_t j = 0; j < prefetch_len / max_prefetch_in_bytes; j++) { @@ -165,10 +172,11 @@ tile_prefetch(payload_t& payload) { } } constexpr uint32_t tail_len = prefetch_len % max_prefetch_in_bytes; - uint32_t tail_offset = - prefetch_len / max_prefetch_in_bytes * max_prefetch_in_bytes * payload_t::scale_factor; - detail::process_1d_tail( - payload, tail_offset); + uint32_t tail_offset = prefetch_len / max_prefetch_in_bytes * + max_prefetch_in_bytes * payload_t::scale_factor; + detail:: + process_1d_tail( + payload, tail_offset); } /// @brief Is prefetch data func. 
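// --------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: the byte-offset arithmetic used
// by the reworked tile_prefetch loop above and by the block_1d payloads
// earlier in this patch. For a row-major view, x advances by sizeof(dtype)
// and y by the surface pitch; when the payload is mem_transpose (col-major),
// the two roles swap. linear_byte_offset is a hypothetical helper name used
// only for this sketch, not a library API.
#include <cstdint>

inline uint64_t linear_byte_offset(
    uint32_t offset_x,
    uint32_t offset_y,
    uint32_t pitch_in_bytes,
    uint32_t elem_bytes,
    bool mem_transpose) {
  return mem_transpose
      ? uint64_t(offset_x) * pitch_in_bytes + uint64_t(offset_y) * elem_bytes
      : uint64_t(offset_y) * pitch_in_bytes + uint64_t(offset_x) * elem_bytes;
}
// --------------------------------------------------------------------------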
@@ -183,8 +191,8 @@ template < cache_hint L1 = cache_hint::cached, cache_hint L2 = cache_hint::cached, typename payload_t> -__XETLA_API typename std::enable_if_t< - detail::check_prefetch_type::is_local> -tile_prefetch([[maybe_unused]] payload_t& payload) {} +__XETLA_API + typename std::enable_if_t::is_local> + tile_prefetch([[maybe_unused]] payload_t& payload) {} } // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index 56196da6d..d5345b06b 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -34,9 +34,8 @@ struct check_store_type { (payload_t::message_type == msg_type::block_2d)); static constexpr bool is_global_block_1d_xe = - ((payload_t::memory_space == mem_space::global) && - (tile_t::tile_size_y == 1) && (tile_t::block_size_y == 1) && - (payload_t::message_type == msg_type::block_1d)); + (payload_t::memory_space == mem_space::global && + payload_t::message_type == msg_type::block_1d); static constexpr bool is_global_unaligned_2d_xe = (payload_t::memory_space == mem_space::global && @@ -280,42 +279,42 @@ template < __XETLA_API typename std::enable_if_t< detail::check_store_type::is_global_block_1d_xe> tile_store(tile_t& tile, payload_t& payload) { - using dtype = typename tile_t::dtype; - using store_dtype = typename payload_t::mem_dtype; - - static constexpr uint32_t tile_size_x = tile_t::tile_size_x; - static constexpr uint32_t scale_factor = payload_t::scale_factor; - - static constexpr uint32_t store_len = tile_size_x / scale_factor; - + using dtype = typename payload_t::dtype; + static constexpr uint32_t store_len = tile_t::tile_elems; static constexpr gpu_arch arch_tag = payload_t::arch_tag; - using load_store_attr = load_store_attr_t; - static constexpr uint32_t max_store_vec_len = - load_store_attr::max_store_vec_len; + if (payload.base_offset + store_len * sizeof(dtype) <= + payload.payload_bytes) { + using load_store_attr = load_store_attr_t; + static constexpr uint32_t max_store_vec_len = + load_store_attr::max_store_vec_len; + static constexpr uint32_t max_store_vec_elems = + max_store_vec_len / sizeof(dtype); + static constexpr uint32_t store_iter_steps = + store_len / max_store_vec_elems; - static constexpr uint32_t store_iter_steps = store_len / max_store_vec_len; - if constexpr (store_len >= max_store_vec_len) { + if constexpr (store_len >= max_store_vec_elems) { #pragma unroll - for (uint32_t i = 0; i < store_iter_steps; i++) { - uint32_t offset_x = i * max_store_vec_len * scale_factor; - auto reg_sub = - tile.reg.xetla_select(offset_x); - uint32_t address_offset = offset_x * sizeof(dtype); + for (uint32_t i = 0; i < store_iter_steps; i++) { + uint32_t offset = i * max_store_vec_elems; + auto reg_sub = tile.reg.xetla_select(offset); + uint32_t address_offset = offset * sizeof(dtype); - xetla_store_global( - payload.base_ptr, - payload.base_offset + address_offset, - reg_sub.xetla_format()); + xetla_store_global( + payload.base_ptr, + payload.base_offset + address_offset, + reg_sub.xetla_format()); + } } + constexpr uint32_t tail_len = + store_len % max_store_vec_elems * sizeof(dtype); + uint32_t tail_offset = store_iter_steps * max_store_vec_len; + detail::process_1d_tail< + tail_len, + (max_store_vec_len >> 1), + detail::process_flag::store, + L1, + L2>(tile, payload, tail_offset); } - constexpr uint32_t tail_len = store_len % max_store_vec_len; - uint32_t tail_offset = store_iter_steps * max_store_vec_len * scale_factor; - 
detail::process_1d_tail< - tail_len, - (max_store_vec_len >> 1), - detail::process_flag::store, - L1, - L2>(tile, payload, tail_offset); } /// @brief Is the func storing data from register file to unaligned global @@ -353,7 +352,6 @@ tile_store( constexpr uint32_t num_channel_y = payload_t::num_channel_y; constexpr uint32_t store_elems = num_channel_y * payload_t::num_channel_x; constexpr uint32_t scale_factor = payload_t::scale_factor; - #pragma unroll for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y; i++) { @@ -515,7 +513,6 @@ tile_store(tile_t& tile, payload_t& payload) { xetla_mask pred_y = channel_index < payload.height_in_elems; - xetla_store_global< store_dtype, payload_t::simd_exec_size, @@ -1003,40 +1000,36 @@ __XETLA_API typename std::enable_if_t< tile_t::tile_size_y == 1 && tile_t::block_size_y == 1> tile_store(tile_t& tile, payload_t& payload) { using dtype = typename tile_t::dtype; - using tile_desc = typename payload_t::tile_desc; - using store_dtype = typename payload_t::mem_dtype; - - constexpr uint32_t scale_factor = payload_t::scale_factor; - static constexpr uint32_t store_len = tile_desc::tile_size_x / scale_factor; - + static constexpr uint32_t store_len = tile_t::tile_elems; static constexpr gpu_arch arch_tag = payload_t::arch_tag; + using load_store_attr = load_store_attr_t; static constexpr uint32_t max_store_vec_len = load_store_attr::max_store_vec_len; + static constexpr uint32_t max_store_vec_elems = + max_store_vec_len / sizeof(dtype); - static constexpr uint32_t store_iter_steps = store_len / max_store_vec_len; + static constexpr uint32_t store_iter_steps = store_len / max_store_vec_elems; - if constexpr (store_len >= max_store_vec_len) { + if constexpr (store_len >= max_store_vec_elems) { #pragma unroll - for (uint32_t j = 0; j < store_iter_steps; j++) { - uint32_t offset_x = j * max_store_vec_len * scale_factor; - auto reg_sub = tile.reg.xetla_select<64 * scale_factor, 1>(offset_x); + for (uint32_t i = 0; i < store_iter_steps; i++) { + uint32_t offset_x = i * max_store_vec_elems; + auto reg_sub = tile.reg.xetla_select(offset_x); uint32_t address_offset = offset_x * sizeof(dtype); - xetla_store_local( - payload.address + address_offset, - reg_sub.xetla_format()); + xetla_store_local( + payload.base_address + payload.address + address_offset, + reg_sub.xetla_format()); } } + constexpr uint32_t tail_len = store_len % max_store_vec_elems * sizeof(dtype); + uint32_t tail_offset = store_iter_steps * max_store_vec_len; detail::process_1d_tail< - store_len % max_store_vec_len, + tail_len, (max_store_vec_len >> 1), detail::process_flag::store, L1, - L2>( - tile, - payload, - store_iter_steps * max_store_vec_len * scale_factor, - store_iter_steps * max_store_vec_len * scale_factor * sizeof(dtype)); + L2>(tile, payload, tail_offset); } } // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index 644717df8..060302448 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -51,6 +51,124 @@ struct none_op_t { } }; +template < + typename matB_acc_t, + typename matB_t, + typename scale_t, + typename zero_pt_t, + uint32_t dequant_s, + quant_mode quant_mode> +struct dequant_int4_weight_t { + struct arguments_t { + uint32_t wg_start_m; + uint32_t wg_start_n; + uint32_t wg_start_k; + inline arguments_t() = default; + inline arguments_t( + uint32_t wg_start_m_, + uint32_t wg_start_n_, + uint32_t wg_start_k_) 
+ : wg_start_m(wg_start_m_), + wg_start_n(wg_start_n_), + wg_start_k(wg_start_k_) {} + }; + __XETLA_API KERNEL_FUNC void operator()( + matB_acc_t& matB_acc, + matB_t& matB, + scale_t& scale, + zero_pt_t& zero_pt, + // [[maybe_unused]] const coord_t& coord, + [[maybe_unused]] const arguments_t& args, + [[maybe_unused]] uint32_t slm_base = 0, + [[maybe_unused]] uint32_t nbarrier_base = 0) { + // no tail, because this is matB + constexpr uint32_t tile_size_x_b = matB_acc_t::tile_size_x; + constexpr uint32_t tile_size_y_b = matB_acc_t::tile_size_y; + constexpr uint32_t block_size_x_b = matB_acc_t::block_size_x; + constexpr uint32_t block_size_y_b = matB_acc_t::block_size_y; + static constexpr uint32_t pack_ratio = sizeof(typename matB_t::dtype) * 2; + + constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; + constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; +#pragma unroll + for (uint32_t i = 0; i < num_block_y; ++i) { +#pragma unroll + for (uint32_t j = 0; j < num_block_x; ++j) { + int block_id = (i * num_block_x + j); + // Must be little-endian + auto matB_blk = matB.reg.xetla_format() + .xetla_select( + block_id * matB_acc_t::block_elems / 2); + + auto dst_blk = matB_acc.reg.xetla_select( + block_id * matB_acc_t::block_elems); + + // int8 includes 2 4bits data. + xetla_vector cvt_blk_i8; + + // lowest 4 bit + { + cvt_blk_i8.xetla_select(0) = + matB_blk & 0xf; + } + // highest 4 bit + { + cvt_blk_i8.xetla_select(1) = + matB_blk >> 4; + } + + // (b_i8 - zero_pt_i8) x scale = fp16 + constexpr uint32_t step = std::min(block_size_y_b, dequant_s); +#pragma unroll + for (uint32_t jj = 0; jj < block_size_x_b; jj++) { +#pragma unroll + for (uint32_t ii = 0; ii < block_size_y_b; ii += step) { + uint32_t offset_y_in_tile = i * block_size_y_b + ii; + uint32_t offset_x_in_tile = j * block_size_x_b + jj; + + uint32_t scale_idx = + (offset_y_in_tile) / dequant_s * scale_t::block_size_x + + offset_x_in_tile; + + if constexpr (quant_mode == quant_mode::S4_ASYM) { + uint32_t zero_pt_idx = + offset_y_in_tile / dequant_s * zero_pt_t::block_size_x + + offset_x_in_tile / pack_ratio; + native_type_t zero_pt_pack = + zero_pt.reg[zero_pt_idx]; + + int8_t zero_pt_i8 = + (zero_pt_pack >> + (4 * ((args.wg_start_n + offset_x_in_tile) % pack_ratio))) & + 0xf; + // sycl::ext::oneapi::experimental::printf( + // "zero_pt.reg[%d} %x zero_pt_i8 %x offset_x_in_tile:%d + // \n", zero_pt_idx, zero_pt_pack, (int32_t)zero_pt_i8 , + // offset_x_in_tile); + + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - + zero_pt_i8; + } else if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - + int8_t(8); + } + dst_blk.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * + scale.reg[scale_idx]; + + // sycl::ext::oneapi::experimental::printf( + // "scale[%d] %f \n", + // scale_idx, + // float(sycl::half(scale.reg.xetla_select<1, 1>(scale_idx)))); + } + } + } + } + } +}; + /// @brief Is the element-wise relu op functor. /// Get the relu input from matAcc, update the relu output in place, /// Used in epilogue::tile_op or chained_tile_op. 
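// --------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: a scalar reference for the
// int4 -> fp16 dequantization that dequant_int4_weight_t performs per block
// above. Each storage byte packs two 4-bit weights, low nibble first;
// quant_mode::S4_ASYM subtracts a 4-bit zero point, while
// quant_mode::S4_FULLRANGE_NO_ZP subtracts the mid-point 8, and the result is
// multiplied by the per-group scale. dequant_nibbles is a hypothetical helper
// name used only for this sketch, not a library API.
#include <cstdint>

inline void dequant_nibbles(
    uint8_t packed, float scale, int8_t zero_pt_i8, bool asym, float out[2]) {
  for (int i = 0; i < 2; ++i) {
    int8_t q = (packed >> (4 * i)) & 0xf; // unpack, low nibble first
    int8_t centered = asym ? static_cast<int8_t>(q - zero_pt_i8) // S4_ASYM
                           : static_cast<int8_t>(q - 8); // S4_FULLRANGE_NO_ZP
    out[i] = scale * static_cast<float>(centered);
  }
}
// --------------------------------------------------------------------------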
@@ -487,7 +605,7 @@ struct bias_add_op_t< using bias_payload_t = mem_payload_t< mem_desc_bias_t, bias_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; coord_t bias_coord(coord.x, 0); mem_desc_bias_t mem_desc_bias(args.base, args.shape, bias_coord); @@ -609,7 +727,7 @@ struct scale_v_offset_v_op_t< using scale_payload_t = mem_payload_t< scale_mem_desc_t, scale_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; coord_t scale_coord(coord.x, 0); scale_mem_desc_t scale_mem_desc( @@ -625,7 +743,7 @@ struct scale_v_offset_v_op_t< using offset_payload_t = mem_payload_t< offset_mem_desc_t, offset_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; coord_t offset_coord(coord.x, 0); offset_mem_desc_t offset_mem_desc( @@ -729,7 +847,7 @@ struct scale_v_op_t< using scale_payload_t = mem_payload_t< scale_mem_desc_t, scale_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; coord_t scale_coord(coord.x, 0); scale_mem_desc_t scale_mem_desc( @@ -839,7 +957,7 @@ struct elemwise_reduce_op_t< using mat_in_payload_t = mem_payload_t< mem_desc_in_t, mat_in_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; using mat_in_tile_acc_t = tile_t; mem_desc_in_t mem_desc_in(args.base, args.shape, coord); @@ -880,7 +998,7 @@ struct elemwise_reduce_op_t< using mat_tail_in_payload_t = mem_payload_t< mem_desc_in_t, mat_tail_in_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; using mat_tail_in_tile_acc_t = tile_t; mat_tail_in_tile_t mat_tail_in; @@ -967,7 +1085,7 @@ struct elemwise_reduce_op_stream_k_t< using mat_in_payload_t = mem_payload_t< mem_desc_in_t, mat_in_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; mem_desc_in_t mem_desc_in(args.base, args.shape, coord); mat_in_tile_t mat_in; @@ -1082,7 +1200,7 @@ struct dropout_op_t< using mask_in_payload_t = mem_payload_t< mem_desc_mask_t, mask_in_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; mem_desc_mask_t mem_desc_mask(args.base, args.shape, coord); mask_in_tile_t mask_in; @@ -1180,7 +1298,7 @@ struct rng_dropout_op_t< using mask_out_payload_t = mem_payload_t< mem_desc_mask_t, mask_out_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; if (args.prob == 0) { return; @@ -1327,7 +1445,7 @@ struct linear_op_t< using mat_in_payload_t = mem_payload_t< mem_desc_in_t, mat_in_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; using mat_in_tile_acc_t = tile_t; mem_desc_in_t mem_desc_in(args.base, args.shape, coord); @@ -1371,7 +1489,7 @@ struct linear_op_t< using mat_tail_in_payload_t = mem_payload_t< mem_desc_in_t, mat_tail_in_tile_desc_t, - msg_type_v, + msg_type_v, arch_tag>; using mat_tail_in_tile_acc_t = tile_t; diff --git a/tests/integration/CMakeLists.txt b/tests/integration/CMakeLists.txt index 8e0e8e64b..fcb4a9fc2 100644 --- a/tests/integration/CMakeLists.txt +++ b/tests/integration/CMakeLists.txt @@ -21,6 +21,7 @@ endfunction() # add_subdirectory(vector_add) add_subdirectory(gemm) +add_subdirectory(gemv) add_subdirectory(row_reduction) add_subdirectory(layer_norm) add_subdirectory(data_transformer) diff --git a/tests/integration/fmha/fmha_forward.hpp b/tests/integration/fmha/fmha_forward.hpp index 6231c6e58..623ea6d7f 100644 --- a/tests/integration/fmha/fmha_forward.hpp +++ b/tests/integration/fmha/fmha_forward.hpp @@ -620,8 +620,12 @@ class fmha_forward_t { mem_desc_Dp_Mask_t::layout, mem_desc_Dp_Mask_t::space>, dp_mask_tile_desc_t, - subgroup:: - msg_type_v, + subgroup::msg_type_v< + dp_mask_tile_desc_t, + mem_desc_t< + uint8_t, + mem_desc_Dp_Mask_t::layout, + mem_desc_Dp_Mask_t::space>>, gpu_arch::XeHpc>; load_payload_mask_t 
load_payload_mask(ctx.mem_desc_Dpij); subgroup::tile_load(mask_in, load_payload_mask); @@ -722,7 +726,12 @@ class fmha_forward_t { using matOi_store_t = subgroup::mem_payload_t< mem_desc_t, matOi_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v< + matOi_tile_desc_t, + mem_desc_t< + scalar_t, + mem_desc_Oi_t::layout, + mem_desc_Oi_t::space>>, arch_tag>; matOi_store_t matOi_store(mem_desc_Oi); subgroup::tile_store( @@ -762,12 +771,19 @@ class fmha_forward_t { using matQi_load_t = subgroup::mem_payload_t< mem_desc_t, matQi_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v< + matQi_tile_desc_t, + mem_desc_t>, arch_tag>; using matQi_store_t = subgroup::mem_payload_t< mem_desc_t, matQi_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v< + matQi_tile_desc_t, + mem_desc_t< + scalar_t, + mem_desc_Qi_L_t::layout, + mem_desc_Qi_L_t::space>>, arch_tag>; int32_t tile_offset_x = ctx.sg_idx * kSgHm; diff --git a/tests/integration/fmha/fmha_utils.h b/tests/integration/fmha/fmha_utils.h index fc1c11909..1aef9a5f4 100644 --- a/tests/integration/fmha/fmha_utils.h +++ b/tests/integration/fmha/fmha_utils.h @@ -156,7 +156,9 @@ struct group_row_reduce_t { using load_payload_t = subgroup::mem_payload_t< mem_desc_t, load_tile_desc, - subgroup::msg_type_v, + subgroup::msg_type_v< + load_tile_desc, + mem_desc_t>, arch_tag>; xetla_nbarrier_t nbarrier; @@ -243,10 +245,12 @@ struct bias_add_op_t { using bias_tile_desc_t = subgroup:: tile_desc_t; using bias_t = subgroup::tile_t; + using mem_desc_bias_t = + mem_desc_t; using bias_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_bias_t, bias_tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, arch_tag>; coord_t bias_coord(coord.x, coord.y); mem_desc_bias_t mem_desc_bias(args.base, args.shape, bias_coord); diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index 7e7896f3e..a675ab0b9 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -45,24 +45,24 @@ class TestBase { mem_layout_a_str + "_" + mem_layout_b_str; return name; } - static constexpr mma_engine engine = mma_engine::fpu; - static constexpr gpu_arch gpu_arch = gpu_arch::XeHpg; + static constexpr mma_engine engine = mma_engine::xmx; + static constexpr gpu_arch gpu_arch = gpu_arch::XeHpc; }; class Test0 : public TestBase { public: static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 256; - static constexpr size_t mat_k = 256; - static constexpr size_t wg_m = 1; + static constexpr size_t mat_n = 64; + static constexpr size_t mat_k = 8192; + static constexpr size_t wg_m = 8; static constexpr size_t wg_n = 32; - static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 32; + static constexpr size_t sg_m = 8; + static constexpr size_t sg_n = 16; static constexpr size_t sg_k = 16; static constexpr uint32_t global_kslicing = 1; - static constexpr uint32_t local_kslicing = 1; + static constexpr uint32_t local_kslicing = 8; static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = fp16; using data_type_b = fp16; using data_type_c = fp16; diff --git a/tests/integration/gemm/fp16/kernel_func.hpp b/tests/integration/gemm/fp16/kernel_func.hpp index 26a78ff91..191213794 100644 --- a/tests/integration/gemm/fp16/kernel_func.hpp +++ b/tests/integration/gemm/fp16/kernel_func.hpp @@ -41,8 +41,8 @@ template < gpu_arch 
gpu_arch> struct fp16_gemm_test_func { using tile_shape = tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 0 ; //8; - static constexpr uint32_t prefetch_distance = 0 ;//256 / (sg_k * sizeof(dtype_a)); + static constexpr uint32_t periodic_sync_interval = 1 ; //8; + static constexpr uint32_t prefetch_distance = 3 ;//256 / (sg_k * sizeof(dtype_a)); using compute_attr = typename std::conditional< (engine == mma_engine::fpu), diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp index 400d13276..eabc8915d 100644 --- a/tests/integration/gemm/fp16/main.cpp +++ b/tests/integration/gemm/fp16/main.cpp @@ -33,7 +33,7 @@ TYPED_TEST_P(fp16_gemm_test, esimd) { } REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd); using tests = ::testing::Types< - Test4>; + Test0>; // Test1, // Test2, // Test3>; diff --git a/tests/integration/gemm/int4_dequantization/main.cpp b/tests/integration/gemm/int4_dequantization/main.cpp index ab120d187..18e40ded5 100644 --- a/tests/integration/gemm/int4_dequantization/main.cpp +++ b/tests/integration/gemm/int4_dequantization/main.cpp @@ -164,6 +164,8 @@ void dequantize_gemm_run(uint32_t iter) { constexpr size_t matrix_m = Test::mat_m; constexpr size_t matrix_n = Test::mat_n; constexpr size_t matrix_k = Test::mat_k; + + static constexpr mem_layout layout_b = Test::layout_b; constexpr uint32_t global_kslicing = Test::global_kslicing; constexpr uint32_t local_kslicing = Test::local_kslicing; @@ -227,13 +229,15 @@ void dequantize_gemm_run(uint32_t iter) { compute_attr_t; using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; + + static constexpr quant_info quant_info{quant_mode::S4_ASYM, Test::dequant_s, layout_b}; + using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, perf_tuning_knob, data_type_scale, data_type_zero_pt, - gpu::xetla::group::quant_mode::S4_ASYM, - dequant_s, + quant_info, mma_engine::xmx, gpu_arch::XeHpg>; using gemm_t = xetla::group:: @@ -332,7 +336,7 @@ void dequantize_gemm_run(uint32_t iter) { .wait(); // set up gemm arguments - typename gemm_op_t::template arguments_t gemm_arg( + typename gemm_op_t::template arguments_t gemm_arg( matrix_m, matrix_k, matrix_n, diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp index a8e4da602..69fdfc1fe 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_client.cpp @@ -16,7 +16,7 @@ #include #include "xetla.hpp" -#define UT_DEBUG 1 +// #define UT_DEBUG 1 using namespace gpu::xetla; // The number of times the kernel is executed constexpr int ITER = 200; @@ -621,13 +621,15 @@ void dequantize_gemm_run(int iter) { using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; + static constexpr quant_info quant_info{ + quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b}; + using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, perf_tuning_knob, data_type_scale, data_type_zero_pt, - gpu::xetla::group::quant_mode::S4_FULLRANGE_NO_ZP, - dequant_s, + quant_info, Test::mma_eng, Test::arch>; @@ -794,7 +796,7 @@ void dequantize_gemm_run(int iter) { } queue.memcpy((void*)gidx_d, (void*)gidx_h, size_gidx * sizeof(uint32_t)) .wait(); - typename gemm_op_t::template arguments_t + typename gemm_op_t::template arguments_t gemm_arg( matrix_m, matrix_k, @@ -901,7 +903,7 @@ void dequantize_gemm_run(int iter) { free(A_d_shuf, context); } if constexpr 
(Feature::feature == optional_feature::NONE) { - typename gemm_op_t::template arguments_t + typename gemm_op_t::template arguments_t gemm_arg( matrix_m, matrix_k, @@ -1041,4 +1043,4 @@ REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_act_shuf_test, esimd); INSTANTIATE_TYPED_TEST_SUITE_P( dequantize_gemm_act_shuf_test_suite, dequantize_gemm_act_shuf_test, - tests); + tests); \ No newline at end of file diff --git a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp index b64928d8a..1c42454df 100644 --- a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp +++ b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp @@ -318,7 +318,7 @@ void dequantize_gemm_run(int iter) { constexpr size_t matrix_k = Test::mat_k; constexpr uint32_t global_kslicing = Test::global_kslicing; constexpr uint32_t local_kslicing = Test::local_kslicing; - + static constexpr mem_layout layout_b = Test::layout_b; constexpr size_t wg_tile_m = Test::wg_m; constexpr size_t wg_tile_n = Test::wg_n; constexpr size_t sg_tile_m = Test::sg_m; @@ -368,7 +368,7 @@ void dequantize_gemm_run(int iter) { DEVICE_MEM_ALIGNMENT / sizeof(data_type_a)>; using mem_desc_b_t = xetla::mem_desc_t< data_type_b, - mem_layout::row_major, + layout_b, mem_space::global, DEVICE_MEM_ALIGNMENT / sizeof(data_type_b)>; using mem_desc_c_t = xetla::mem_desc_t< @@ -387,14 +387,15 @@ void dequantize_gemm_run(int iter) { compute_attr_t; using perf_tuning_knob = xetla::group:: perf_tuning_knob_t; + static constexpr quant_info quant_info{ + quant_mode::S4_FULLRANGE_NO_ZP, Test::dequant_s, layout_b}; using compute_policy = xetla::group::compute_policy_int4_dequantize< compute_attr, perf_tuning_knob, data_type_scale, data_type_zero_pt, - gpu::xetla::group::quant_mode::S4_FULLRANGE_NO_ZP, - dequant_s, + quant_info, mma_engine::xmx, gpu_arch::XeHpc>; @@ -531,7 +532,7 @@ void dequantize_gemm_run(int iter) { // It accepts the base pointer to matrix D, and its dimensions {bias_d, bias_add_shape}}); - typename gemm_op_t::template arguments_t gemm_arg( + typename gemm_op_t::template arguments_t gemm_arg( matrix_m, matrix_k, matrix_n, diff --git a/tests/integration/gemv/CMakeLists.txt b/tests/integration/gemv/CMakeLists.txt new file mode 100644 index 000000000..8620dac8b --- /dev/null +++ b/tests/integration/gemv/CMakeLists.txt @@ -0,0 +1,3 @@ +include_directories(${CMAKE_SOURCE_DIR}/tests/integration/gemv) + +add_subdirectory(int4) diff --git a/tests/integration/gemv/int4/CMakeLists.txt b/tests/integration/gemv/int4/CMakeLists.txt new file mode 100644 index 000000000..f38ac28eb --- /dev/null +++ b/tests/integration/gemv/int4/CMakeLists.txt @@ -0,0 +1,9 @@ +get_filename_component(ProjectId ${CMAKE_CURRENT_SOURCE_DIR} NAME) +string(REPLACE " " "_" ProjectId ${ProjectId}) +set(ProjectIdClient ${ProjectId}) +set(ProjectIdXe ${ProjectId}) +string(PREPEND ProjectIdClient "gemv_") + +FILE(GLOB src main.cpp) +add_integration_test(${ProjectIdClient} ${src}) + diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp new file mode 100644 index 000000000..622d6cf1b --- /dev/null +++ b/tests/integration/gemv/int4/main.cpp @@ -0,0 +1,602 @@ +/******************************************************************************* + * Copyright (c) 2022-2023 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include +#include "xetla.hpp" +// #define UT_DEBUG +using namespace gpu::xetla; +using namespace gpu::xetla::group; +// The number of times the kernel is executed +#ifdef UT_DEBUG +constexpr int ITER = 1; +#else +constexpr int ITER = 200; +#endif +constexpr size_t UNDEFINED_DATA_SIZE = 1024; + +class test_col_major_1 { + public: + // Extract the parameters required by different test cases + static constexpr size_t mat_m = 1; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 1; + static constexpr size_t wg_n = 1; + static constexpr size_t sg_m = 1; + static constexpr size_t sg_n = 1; + static constexpr size_t sg_k = 1024 / 1; + static constexpr size_t dequant_s = 128; + // static constexpr quant_mode quant_mode = quant_mode::S4_ASYM; + static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP; + + static constexpr size_t local_kslicing = 1; + static constexpr size_t global_kslicing = 1; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeHpc; + using data_type_a = fp16; + using data_type_b = int4x8; + using data_type_c = fp16; +}; +class test_col_major_2 { + public: + // Extract the parameters required by different test cases + static constexpr size_t mat_m = 4; + static constexpr size_t mat_n = 4096; + static constexpr size_t mat_k = 4096; + static constexpr size_t wg_m = 4; + static constexpr size_t wg_n = 1; + static constexpr size_t sg_m = 4; + static constexpr size_t sg_n = 1; + static constexpr size_t sg_k = 1024; + static constexpr size_t dequant_s = 4096; + + static constexpr size_t local_kslicing = 1; + static constexpr size_t global_kslicing = 1; + static constexpr mem_layout layout_a = mem_layout::row_major; + static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mma_engine mma_eng = mma_engine::fpu; + static constexpr gpu_arch arch = gpu_arch::XeLpg; + using data_type_a = fp16; + using data_type_b = int4x8; + using data_type_c = fp16; +}; + +template < + typename data_type_a, + typename data_type_b, + typename data_type_c, + typename data_type_bias, + typename data_type_acc = float> +int gemm_result_validate( + data_type_a* A, + data_type_b* B, + data_type_c* C, + data_type_bias* bias, + uint32_t m, + uint32_t k, + uint32_t n, + mem_layout mem_layout_a_ = mem_layout::row_major, + mem_layout mem_layout_b_ = mem_layout::row_major) { + buff_cmp::buff_vals data(C, m, n, n); + std::vector gold_C(m * n, 0); + get_gemm_gold( + m, n, k, mem_layout_a_, mem_layout_b_, A, B, gold_C.data()); + + // BiasAdd + for (uint32_t i = 0; i < gold_C.size(); ++i) { + uint32_t col = i % n; + gold_C[i] += bias[col]; + } + + buff_cmp::buff_vals other(gold_C.data(), m, n, n); + + bool result = buff_cmp::xetla_buff_cmp(data, other, "gemv validation"); + +#ifdef UT_DEBUG + // for (uint32_t i = 0; i < m; i++) { + // for 
(uint32_t j = 0; j < n; j++) { + // std::cout << float(sycl::half(C[i * n + j])) << " "; + // } + // std::cout << std::endl; + // } +#endif + std::cout << (!result ? "FAILED\n" : "PASSED\n"); + return result ? 0 : 1; +} + +template < + quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP, + typename data_type_acc_in = fp16, + typename data_type_b, + typename data_type_scale, + typename data_type_zero_pt> +std::vector convert_int4( + data_type_b data_b, + data_type_scale scale, + data_type_zero_pt zero_pt) { + std::vector dequant_fp16(sizeof(data_type_b) * 2); + + int8_t zero_pt_i8 = zero_pt & 0xf; + for (uint32_t i = 0; i < dequant_fp16.size(); i++) { + int8_t dequant_8bit = data_b & 0xf; + if constexpr (quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + dequant_fp16[i] = scale * (dequant_8bit - 8); + } else { + dequant_fp16[i] = scale * (dequant_8bit - zero_pt_i8); + } + data_b = data_b >> 4; + } + return dequant_fp16; +} + +template < + size_t dequant_s, + mem_layout layout_b = mem_layout::col_major, + quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP, + typename data_type_acc_in = fp16, + typename data_type_b, + typename data_type_scale, + typename data_type_zero_pt> +std::vector dequantize_weight( + size_t matrix_k, + size_t matrix_n, + data_type_b* b, + data_type_scale* scale, + data_type_zero_pt* zero_pt) { + std::vector b_out(matrix_k * matrix_n, 0); + constexpr size_t pack_radio = 2 * sizeof(data_type_b); + size_t width = layout_b == mem_layout::row_major ? matrix_n / pack_radio + : matrix_k / pack_radio; + size_t height = layout_b == mem_layout::row_major ? matrix_k : matrix_n; + size_t step = layout_b == mem_layout::row_major ? 1 : dequant_s / pack_radio; + + for (uint32_t i = 0; i < height; i++) { + for (uint32_t j = 0; j < width; j += step) { + int start_b_in = i * width + j; + int start_scale_in = start_b_in / step; + int start_zero_pt_in = + (j / step) * (matrix_n / pack_radio) + i / pack_radio; + int start_out = + layout_b == mem_layout::row_major ? 
0 : i * matrix_k + j * pack_radio; + for (uint32_t jj = 0; jj < step; jj++) { + std::vector dequant_fp16 = convert_int4( + b[start_b_in + jj], + scale[start_scale_in], + zero_pt[start_zero_pt_in] >> (4 * (i % pack_radio))); + for (uint32_t jjj = 0; jjj < dequant_fp16.size(); jjj++) { + b_out[start_out + pack_radio * jj + jjj] = dequant_fp16[jjj]; + } + } + } + } +#ifdef UT_DEBUG + // for (uint32_t i = 0; i < matrix_n; i++) { + // for (uint32_t j = 0; j < matrix_k; j++) { + // std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; + // } + // std::cout << std::endl; + // } +#endif + return b_out; +} + +template +void dequantize_gemv_run(int iter) { + using namespace gpu; + // Accept incoming parameters + constexpr size_t matrix_m = Test::mat_m; + constexpr size_t matrix_n = Test::mat_n; + constexpr size_t matrix_k = Test::mat_k; + constexpr uint32_t global_kslicing = Test::global_kslicing; + constexpr uint32_t local_kslicing = Test::local_kslicing; + + constexpr size_t wg_tile_m = Test::wg_m; + constexpr size_t wg_tile_n = Test::wg_n; + constexpr size_t sg_tile_m = Test::sg_m; + constexpr size_t sg_tile_n = Test::sg_n; + constexpr size_t sg_tile_k = Test::sg_k; + constexpr size_t dequant_s = std::min(Test::dequant_s, matrix_k); + constexpr quant_mode quant_mode = Test::quant_mode; + using data_type_a = typename Test::data_type_a; + using data_type_b = typename Test::data_type_b; + using data_type_c = typename Test::data_type_c; + using data_type_zero_pt = data_type_b; + using data_type_scale = fp16; + using data_type_acc_in = fp16; + using data_type_acc = float; + using data_type_bias = data_type_a; + + constexpr mem_layout layout_a = Test::layout_a; + constexpr mem_layout layout_b = Test::layout_b; + + constexpr size_t size_a = matrix_m * matrix_k; + constexpr size_t size_b = matrix_k * matrix_n / (2 * sizeof(data_type_b)); + + constexpr size_t size_scale_k = matrix_k / dequant_s; + constexpr size_t size_scale_n = matrix_n; + constexpr size_t size_scale = size_scale_k * size_scale_n; + + constexpr size_t size_zero_pt_k = matrix_k / dequant_s; + constexpr size_t size_zero_pt_n = matrix_n; + constexpr size_t size_zero_pt = + size_zero_pt_k * size_zero_pt_n / (2 * sizeof(data_type_b)); + + constexpr size_t size_c = matrix_m * matrix_n; + constexpr size_t size_bias = matrix_n; + + uint32_t lda = layout_a == mem_layout::row_major ? matrix_k : matrix_m; + uint32_t ldb = layout_b == mem_layout::row_major ? matrix_n : matrix_k; + uint32_t ldc = matrix_n; + uint32_t ld_scale = + layout_b == mem_layout::row_major ? 
size_scale_n : size_scale_k; + uint32_t ld_zero_pt = size_zero_pt_n; + + // Turn on the enable_profiling property to facilitate subsequent profiling + sycl::property_list properties{ + sycl::property::queue::enable_profiling(), + sycl::property::queue::in_order()}; + auto queue = sycl::queue(properties); + auto context = queue.get_info(); + auto device = queue.get_info(); + + std::cout << "Running on " << device.get_info() << "\n"; + + using tile_shape = + xetla::group::tile_shape_t; + static constexpr uint32_t periodic_sync_interval = 0; + static constexpr uint32_t prefetch_distance = 0; + + using mem_desc_a_t = xetla::mem_desc_t< + data_type_a, + layout_a, + mem_space::global, + DEVICE_MEM_ALIGNMENT / sizeof(data_type_a)>; + using mem_desc_b_t = xetla::mem_desc_t< + data_type_b, + layout_b, + mem_space::global, + DEVICE_MEM_ALIGNMENT / sizeof(data_type_b)>; + using mem_desc_c_t = xetla::mem_desc_t< + data_type_c, + mem_layout::row_major, + mem_space::global, + DEVICE_MEM_ALIGNMENT / sizeof(data_type_c)>; + + using mem_desc_bias_t = xetla::mem_desc_t< + data_type_bias, + mem_layout::row_major, + mem_space::global, + DEVICE_MEM_ALIGNMENT / sizeof(data_type_bias)>; + + using compute_attr = xetla::group:: + compute_attr_t; + using perf_tuning_knob = xetla::group:: + perf_tuning_knob_t; + static constexpr quant_info quant_info{quant_mode, Test::dequant_s, layout_b}; + using compute_policy = xetla::group::compute_policy_int4_dequantize< + compute_attr, + perf_tuning_knob, + data_type_scale, + data_type_zero_pt, + quant_info, + Test::mma_eng, + Test::arch>; + + using gemm_t = xetla::group:: + gemm_t; + + using bias_op_t = + gpu::xetla::subgroup::bias_add_op_t; + + using tile_op_t = gpu::xetla::subgroup::chained_tile_op_t; + + using epilogue_t = xetla::group::epilogue_t< + xetla::group::epilogue_policy_tile_op, + tile_shape, + mem_desc_c_t>; + + using group_swizzle = xetla::kernel::group_swizzle_default; + + using gemm_op_t = xetla::kernel::gemm_universal_t< + gpu::xetla::kernel::dispatch_policy_int4_dequantize_kslicing< + group_swizzle, + global_kslicing, + local_kslicing>, + gemm_t, + epilogue_t>; + + size_t size_acc = gemm_op_t::get_acc_buf_size(matrix_m, matrix_n); + size_t size_cnt = gemm_op_t::get_cnt_buf_size(matrix_m, matrix_n); + + // Define and initialize the data required for the calculation + auto* A_h = static_cast( + malloc_host(size_a * sizeof(data_type_a), context)); + auto* B_h = static_cast(malloc_host( + (size_b + UNDEFINED_DATA_SIZE) * sizeof(data_type_b), context)); + auto* C_h = static_cast( + malloc_host(size_c * sizeof(data_type_c), context)); + auto* Acc_h = static_cast( + malloc_host(size_acc * sizeof(data_type_acc), context)); + auto* Cnt_h = + static_cast(malloc_host(size_cnt * sizeof(uint32_t), context)); + auto* scale_h = static_cast(malloc_host( + (size_scale + UNDEFINED_DATA_SIZE) * sizeof(data_type_scale), context)); + auto* zero_pt_h = static_cast(malloc_host( + (size_zero_pt + UNDEFINED_DATA_SIZE) * sizeof(data_type_zero_pt), + context)); + auto* bias_h = static_cast( + malloc_host(size_bias * sizeof(data_type_bias), context)); + + auto* A_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, size_a * sizeof(data_type_a), device, context)); + auto* B_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, + (size_b + UNDEFINED_DATA_SIZE) * sizeof(data_type_b), + device, + context)); + auto* C_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, size_c * sizeof(data_type_c), device, context)); + auto* Acc_d = 
static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, size_acc * sizeof(data_type_acc), device, context)); + auto* Cnt_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, size_cnt * sizeof(uint32_t), device, context)); + auto* scale_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, + (size_scale + UNDEFINED_DATA_SIZE) * sizeof(data_type_scale), + device, + context)); + auto* zero_pt_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, + (size_zero_pt + UNDEFINED_DATA_SIZE) * sizeof(data_type_zero_pt), + device, + context)); + auto* bias_d = static_cast(aligned_alloc_device( + DEVICE_MEM_ALIGNMENT, + size_bias * sizeof(data_type_bias), + device, + context)); + + for (unsigned i = 0; i < size_a; ++i) { + A_h[i] = random_float(); +#ifdef UT_DEBUG + A_h[i] = 1; + // A_h[i] = layout_a == mem_layout::row_major + // ? (i % matrix_k + i / matrix_k * 100) + // : (i % matrix_m + i / matrix_m * 100); +#endif + } + + for (unsigned i = 0; i < size_b + UNDEFINED_DATA_SIZE; ++i) { + if constexpr (std::is_same_v) { + B_h[i] = random_uint8(); +#ifdef UT_DEBUG + B_h[i] = 0x77; +#endif + } else if constexpr (std::is_same_v) { + B_h[i] = random_uint32(); +#ifdef UT_DEBUG + B_h[i] = 0x77777777; +#endif + } + } + + for (unsigned i = 0; i < size_scale; ++i) { + scale_h[i] = random_float(); +#ifdef UT_DEBUG + scale_h[i] = 1.f; +#endif + } + for (unsigned i = size_scale; i < size_scale + UNDEFINED_DATA_SIZE; ++i) { + scale_h[i] = INFINITY; + } + for (unsigned i = 0; i < size_zero_pt + UNDEFINED_DATA_SIZE; ++i) { + if constexpr (std::is_same_v) { + zero_pt_h[i] = random_uint8(); +#ifdef UT_DEBUG + zero_pt_h[i] = 0x12 << i; +#endif + } else if constexpr (std::is_same_v) { + zero_pt_h[i] = random_uint32(); +#ifdef UT_DEBUG + zero_pt_h[i] = 0x33333333; +#endif + } + } + + for (unsigned i = 0; i < size_c; ++i) { + C_h[i] = random_float(); + } + + for (unsigned i = 0; i < size_acc; ++i) { + Acc_h[i] = random_float(); + } + + for (unsigned i = 0; i < size_cnt; ++i) { + Cnt_h[i] = random_uint8(); + } + + for (unsigned i = 0; i < size_bias; ++i) { + bias_h[i] = random_float(); +#ifdef UT_DEBUG + bias_h[i] = 0.f; +#endif + } + + queue.memcpy((void*)A_d, (void*)A_h, size_a * sizeof(data_type_a)).wait(); + queue + .memcpy( + (void*)B_d, + (void*)B_h, + (size_b + UNDEFINED_DATA_SIZE) * sizeof(data_type_b)) + .wait(); + queue.memcpy((void*)C_d, (void*)C_h, size_c * sizeof(data_type_c)).wait(); + queue.memcpy((void*)Acc_d, (void*)Acc_h, size_acc * sizeof(data_type_acc)) + .wait(); + queue.memcpy((void*)Cnt_d, (void*)Cnt_h, size_cnt * sizeof(uint32_t)).wait(); + queue + .memcpy( + (void*)scale_d, + (void*)scale_h, + (size_scale + UNDEFINED_DATA_SIZE) * sizeof(data_type_scale)) + .wait(); + queue + .memcpy( + (void*)zero_pt_d, + (void*)zero_pt_h, + (size_zero_pt + UNDEFINED_DATA_SIZE) * sizeof(data_type_zero_pt)) + .wait(); + queue.memcpy((void*)bias_d, (void*)bias_h, size_bias * sizeof(data_type_bias)) + .wait(); + + queue.memset(Cnt_d, 0, size_cnt * sizeof(uint32_t)).wait(); + queue.memset(Acc_d, 0, size_acc * sizeof(data_type_acc)).wait(); + // set up gemm arguments + typename bias_op_t::shape_t bias_add_shape(matrix_n, 1, matrix_n); + using epilogue_args_t = epilogue_t::arguments_t; + + epilogue_args_t epilogue_args( + {// epilogue_args init list + // It accepts the base pointer to matrix D, and its dimensions + {bias_d, bias_add_shape}}); + typename gemm_op_t::template arguments_t gemm_arg; + if constexpr (compute_policy::quant_mode == quant_mode::S4_FULLRANGE_NO_ZP) { + gemm_arg = + 
typename gemm_op_t::template arguments_t( + matrix_m, + matrix_k, + matrix_n, + A_d, + lda, + B_d, + ldb, + C_d, + ldc, + scale_d, + ld_scale, + Acc_d, + Cnt_d, + epilogue_args); + } else if constexpr (compute_policy::quant_mode == quant_mode::S4_ASYM) { + gemm_arg = + typename gemm_op_t::template arguments_t( + matrix_m, + matrix_k, + matrix_n, + A_d, + lda, + B_d, + ldb, + C_d, + ldc, + scale_d, + ld_scale, + zero_pt_d, + ld_zero_pt, + Acc_d, + Cnt_d, + epilogue_args); + } + cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); + // if (!gemm_op_t::can_implement(gemm_arg)) { + // std::cout << "The arguments cannot be supported, aborting ... " + // << std::endl; + // FAIL(); + // } + + size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; + profiling_helper prof("dequantize_gemm", ops, "gflops"); +#ifdef UT_DEBUG + int constexpr warm = 0; +#else + int constexpr warm = 100; +#endif + try { + for (int i = 0; i < iter + warm; i++) { + if (i >= warm) + prof.cpu_start(); + auto e_esimd = queue.submit([&](handler& cgh) { + cgh.parallel_for(nd_range, [=](nd_item<3> item) SYCL_ESIMD_KERNEL { + // allocate slm and nbarrier resource + slm_barrier_init(); + gemm_op_t gemm_op; + gemm_op(item, gemm_arg); + }); + }); + if (i >= warm) { + e_esimd.wait(); + prof.cpu_end(); + prof.add_gpu_event(e_esimd); + } + } + } catch (cl::sycl::exception const& e) { + std::cout << "SYCL exception caught: " << e.what() << '\n'; + FAIL(); + } + + // performance + prof.print_profiling_result(profiling_selector::GPU); + // check result + std::vector dequantize_b = + dequantize_weight( + matrix_k, matrix_n, B_h, scale_h, zero_pt_h); + + queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait(); + ASSERT_EQ( + 0, + gemm_result_validate( + A_h, + dequantize_b.data(), + C_h, + bias_h, + matrix_m, + matrix_k, + matrix_n, + layout_a, + layout_b)); + + free(A_h, context); + free(B_h, context); + free(C_h, context); + free(scale_h, context); + free(zero_pt_h, context); + free(A_d, context); + free(B_d, context); + free(C_d, context); + free(scale_d, context); + free(zero_pt_d, context); + free(Acc_h, context); + free(Cnt_h, context); + free(Acc_d, context); + free(Cnt_d, context); +} + +template +class dequantize_gemv_test : public ::testing::Test {}; +TYPED_TEST_SUITE_P(dequantize_gemv_test); + +TYPED_TEST_P(dequantize_gemv_test, esimd) { + dequantize_gemv_run(ITER); +} + +REGISTER_TYPED_TEST_SUITE_P(dequantize_gemv_test, esimd); +using tests = ::testing::Types; + +INSTANTIATE_TYPED_TEST_SUITE_P( + dequantize_gemv_test_suite, + dequantize_gemv_test, + tests); diff --git a/tests/integration/sg_dropout_op/kernel_func.hpp b/tests/integration/sg_dropout_op/kernel_func.hpp index 36f4628a9..b566195d7 100644 --- a/tests/integration/sg_dropout_op/kernel_func.hpp +++ b/tests/integration/sg_dropout_op/kernel_func.hpp @@ -66,7 +66,7 @@ struct dropout_func_t { using mat_in_payload_t = subgroup::mem_payload_t< mem_desc_in_t, tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using tile_op_t = typename std::conditional< diff --git a/tests/integration/softmax/softmax_bwd_kernel.hpp b/tests/integration/softmax/softmax_bwd_kernel.hpp index fa6fddb56..e556c67c4 100644 --- a/tests/integration/softmax/softmax_bwd_kernel.hpp +++ b/tests/integration/softmax/softmax_bwd_kernel.hpp @@ -30,11 +30,6 @@ template < uint32_t sg_n, uint32_t sg_m> struct softmax_bwd_test_func { - using mem_desc_in_t = - mem_desc_t; - using mem_desc_out_t = - mem_desc_t; - using tile_shape = 
group::tile_shape_t; using work_group_t = typename tile_shape::work_group_t; static constexpr uint32_t wg_size_x = tile_shape::wg_size_x; @@ -61,17 +56,21 @@ struct softmax_bwd_test_func { reg_layout::tiled>; using matAcc_t = subgroup::tile_t; using mat_in_t = subgroup::tile_t; + using mem_desc_in_t = + mem_desc_t; using mat_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using mat_out_t = subgroup::tile_t; + using mem_desc_out_t = + mem_desc_t; using mat_out_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_out_t, tile_desc_t, - (tile_size_y > 1) ? msg_type::block_2d : msg_type::block_1d, + subgroup::msg_type_v, gpu_arch::XeHpc>; using softmax_bwd_t = group::softmax_t< diff --git a/tests/integration/softmax/softmax_fwd_kernel.hpp b/tests/integration/softmax/softmax_fwd_kernel.hpp index 5237d47e4..abbd14696 100644 --- a/tests/integration/softmax/softmax_fwd_kernel.hpp +++ b/tests/integration/softmax/softmax_fwd_kernel.hpp @@ -60,16 +60,17 @@ struct softmax_fwd_test_func { reg_layout::tiled>; using matAcc_t = subgroup::tile_t; using mat_in_t = subgroup::tile_t; + using mat_in_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, tile_desc_t, - subgroup::msg_type_v, + subgroup::msg_type_v, gpu_arch::XeHpc>; using mat_out_t = subgroup::tile_t; using mat_out_payload_t = subgroup::mem_payload_t< - mem_desc_t, + mem_desc_in_t, tile_desc_t, - (tile_size_y > 1) ? msg_type::block_2d : msg_type::block_1d, + subgroup::msg_type_v, gpu_arch::XeHpc>; using softmax_fwd_t = group::softmax_t< diff --git a/tests/unit/epilogue_tile_op/kernel_func.hpp b/tests/unit/epilogue_tile_op/kernel_func.hpp index 1d273857e..a13481044 100644 --- a/tests/unit/epilogue_tile_op/kernel_func.hpp +++ b/tests/unit/epilogue_tile_op/kernel_func.hpp @@ -41,7 +41,7 @@ struct tile_elemwise_op_func { using matA_payload_t = mem_payload_t< mem_desc_c_t, matA_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; using tile_shape = tile_shape_t; @@ -95,7 +95,7 @@ struct tile_elemwise_op_func< using matA_payload_t = mem_payload_t< mem_desc_b_t, matA_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; using tile_shape = tile_shape_t; using epilogue_policy = epilogue_policy_tile_op< @@ -150,7 +150,7 @@ struct tile_elemwise_op_func< using matA_payload_t = mem_payload_t< mem_desc_c_t, matA_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; using tile_shape = tile_shape_t; using epilogue_policy = epilogue_policy_tile_op< diff --git a/tests/unit/tile_mma/kernel_func.hpp b/tests/unit/tile_mma/kernel_func.hpp index 8822c8ca2..d8130c4a5 100644 --- a/tests/unit/tile_mma/kernel_func.hpp +++ b/tests/unit/tile_mma/kernel_func.hpp @@ -56,20 +56,26 @@ struct tile_mma_func { using matA_t = tile_t; using matB_t = tile_t; using matC_t = tile_t; + using mem_desc_a_t = + mem_desc_t; using matA_payload_t = mem_payload_t< - mem_desc_t, + mem_desc_a_t, matA_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; + using mem_desc_b_t = + mem_desc_t; using matB_payload_t = mem_payload_t< - mem_desc_t, + mem_desc_b_t, matB_tile_desc_t, - msg_type_v, + msg_type_v, gpu_arch::XeHpc>; + using mem_desc_c_t = + mem_desc_t; using matC_payload_t = mem_payload_t< - mem_desc_t, + mem_desc_c_t, matC_tile_desc_t, - msg_type::block_2d, + msg_type_v, gpu_arch::XeHpc>; using matAcc_t = tile_t>; diff --git a/tests/unit/tile_row_reduction/kernel_func.hpp b/tests/unit/tile_row_reduction/kernel_func.hpp index e4c1b5c3e..710143de5 
100644 --- a/tests/unit/tile_row_reduction/kernel_func.hpp +++ b/tests/unit/tile_row_reduction/kernel_func.hpp @@ -48,12 +48,16 @@ struct tile_row_reduction_func { using matA_payload_t = mem_payload_t< mem_desc_t, matA_tile_desc_t, - msg_type_v, + msg_type_v< + matA_tile_desc_t, + mem_desc_t>, gpu_arch::XeHpc>; using matC_payload_t = mem_payload_t< mem_desc_t, matC_tile_desc_t, - msg_type_v, + msg_type_v< + matC_tile_desc_t, + mem_desc_t>, gpu_arch::XeHpc>; matA_t matA; matC_t matC; diff --git a/tests/utils/common.hpp b/tests/utils/common.hpp index 432727252..4a79de7c4 100644 --- a/tests/utils/common.hpp +++ b/tests/utils/common.hpp @@ -36,6 +36,9 @@ using namespace gpu::xetla; #define random_float() (generate_real_random()) #define random_uint8() (generate_int_random(0, 255)) +#define random_uint32() \ + (generate_int_random(0, std::numeric_limits::max())) + template inline auto getTypeName() { fprintf(stderr, "FAIL: Not implemented specialization\n"); diff --git a/third_party/pybind11 b/third_party/pybind11 new file mode 160000 index 000000000..dc9b39596 --- /dev/null +++ b/third_party/pybind11 @@ -0,0 +1 @@ +Subproject commit dc9b39596d986aeb061bd3debe52d30e2467dc48
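// --------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: the packed-buffer element counts
// assumed by the new int4 GEMV test (tests/integration/gemv/int4/main.cpp).
// Weights and zero points store two 4-bit values per byte of the storage type
// (so int4x8 packs eight per 32-bit element), and one scale / zero point is
// kept per dequant_s-sized group along k. packed_int4_sizes is a hypothetical
// helper name used only for this sketch, not a library API.
#include <cstddef>

struct Int4BufSizes {
  size_t weights; // packed B elements (size_b)
  size_t scales; // fp16 scale elements (size_scale)
  size_t zero_pts; // packed zero-point elements (size_zero_pt)
};

inline Int4BufSizes packed_int4_sizes(
    size_t k, size_t n, size_t dequant_s, size_t bytes_per_elem) {
  size_t pack = 2 * bytes_per_elem; // 4-bit values per storage element
  return {k * n / pack, (k / dequant_s) * n, (k / dequant_s) * n / pack};
}
// --------------------------------------------------------------------------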