2828 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929 *
3030 **************************************************************************************************/
31-
3231#pragma once
3332
3433#include " cutlass/epilogue/collective/default_epilogue.hpp"
@@ -220,64 +219,6 @@ template <class FMHAChunkPrefillKernel, bool isVarLen> struct ExampleRunner {
220219 // Methods
221220 //
222221
223- /*
224- template <typename T>
225- void initialize_block_random(cutlass::DeviceAllocation<T>& block) {
226- if (block.size() == 0) {
227- return;
228- }
229- std::vector<T> host_tensor(block.size());
230- std::mt19937 gen(seed);
231- std::uniform_real_distribution<float> dis(-1.f, 1.f);
232-
233- for (size_t i = 0; i < host_tensor.size(); ++i) {
234- host_tensor[i] = static_cast<T>(dis(gen));
235- }
236- block.copy_from_host(host_tensor.data(), host_tensor.size());
237- }
238- */
239-
240- template <typename T>
241- void initialize_block_random (cutlass::DeviceAllocation<T>& block) {
242- if (block.size () == 0 ) {
243- return ;
244- }
245- std::vector<T> host_tensor (block.size ());
246- std::mt19937 gen (seed);
247- std::uniform_int_distribution<> dis (1 , 9 );
248-
249- for (size_t i = 0 ; i < host_tensor.size (); ++i) {
250- host_tensor[i] = static_cast <T>(dis (gen));
251- }
252- block.copy_from_host (host_tensor.data (), host_tensor.size ());
253- }
254-
255- template <typename T>
256- void initialize_block_identity (cutlass::DeviceAllocation<T>& block, int rows, int cols) {
257- if (block.size () == 0 ) {
258- return ;
259- }
260- std::vector<T> host_tensor (block.size (), T (0 .f ));
261- for (int i = 0 ; i < rows; ++i) {
262- if (i < cols) {
263- host_tensor[i * cols + i] = T (1 .f );
264- }
265- }
266- block.copy_from_host (host_tensor.data (), host_tensor.size ());
267- }
268-
269- template <typename T>
270- void initialize_block_iota (cutlass::DeviceAllocation<T>& block) {
271- if (block.size () == 0 ) {
272- return ;
273- }
274- std::vector<T> host_tensor (block.size ());
275- for (size_t i = 0 ; i < host_tensor.size (); ++i) {
276- host_tensor[i] = static_cast <T>(static_cast <float >(1.0 ));
277- }
278- block.copy_from_host (host_tensor.data (), host_tensor.size ());
279- }
280-
281222template <typename SrcType, typename DstType, typename Encoding>
282223void run_conversion_kernel (SrcType* src_ptr_in, DstType* dst_ptr_in, int64_t num_elements, float scale) {
283224 sycl::queue queue = compat::get_default_queue ();
@@ -300,28 +241,6 @@ void run_conversion_kernel(SrcType* src_ptr_in, DstType* dst_ptr_in, int64_t num
300241 });
301242}
302243
303- template <typename T>
304- void print_device_tensor (const char * name, T* ptr, size_t size, int max_elements_to_print = 1153 ) {
305- std::cout << " --- " << name << " ---" << std::endl;
306- if (ptr == nullptr || size == 0 ) {
307- std::cout << " (null)" << std::endl;
308- return ;
309- }
310- std::vector<T> host_tensor (size);
311- compat::memcpy (host_tensor.data (), ptr, size * sizeof (T));
312- compat::wait ();
313-
314- int count = 0 ;
315- for (const auto & val : host_tensor) {
316- if (count++ >= max_elements_to_print) {
317- std::cout << " ..." << std::endl;
318- break ;
319- }
320- std::cout << static_cast <float >(val) << " " ;
321- }
322- std::cout << std::endl << " --- End " << name << " ---" << std::endl;
323- }
324-
325244bool verify (ProblemShapeType problem_size, Options options, const float * q_scale, const float * k_scale, const float * v_scale) {
326245 std::vector<ElementOutput> host_O (block_ref_O.size ());
327246
@@ -351,7 +270,7 @@ bool verify(ProblemShapeType problem_size, Options options, const float* q_scale
351270 int offset_o = 0 ;
352271
353272 using namespace cutlass ;
354- using RefElement = bfloat16_t ; // half_t;
273+ using RefElement = bfloat16_t ;
355274 DeviceAllocation<RefElement> block_Q_ref, block_K_ref, block_V_ref;
356275
357276 // loop over the batch dimension to compute the output
@@ -479,22 +398,6 @@ bool verify(ProblemShapeType problem_size, Options options, const float* q_scale
479398 }
480399 compat::wait ();
481400
482- // Print inputs for the first batch item
483- if (b == 0 ) {
484- if constexpr (is_fp8_v<ElementQ>) {
485- std::cout << " \n ========= FP8 Kernel Inputs (Batch 0) =========\n " ;
486- print_device_tensor (" FP8 Input Q" , q_ptr_orig, seq_len_qo * num_heads_q * head_size_qk);
487- print_device_tensor (" FP8 Input K" , k_ptr_orig, seq_len_kv_total * num_heads_kv * head_size_qk);
488- print_device_tensor (" FP8 Input V" , v_ptr_orig, seq_len_kv_total * num_heads_kv * head_size_vo);
489- std::cout << " \n ========= Reference Kernel Inputs (Batch 0, Descaled) =========\n " ;
490- } else {
491- std::cout << " \n ========= FP16 Kernel and Reference Kernel Inputs (Batch 0) =========\n " ;
492- }
493- print_device_tensor (" Input Q" , reinterpret_cast <RefElement*>(q_ptr), seq_len_qo * num_heads_q * head_size_qk);
494- print_device_tensor (" Input K" , reinterpret_cast <RefElement*>(k_ptr), seq_len_kv_total * num_heads_kv * head_size_qk);
495- print_device_tensor (" Input V" , reinterpret_cast <RefElement*>(v_ptr), seq_len_kv_total * num_heads_kv * head_size_vo);
496- }
497-
498401 for (int q_group = 0 ; q_group < num_heads_q / q_group_size; q_group++) {
499402 for (int q_head = 0 ; q_head < q_group_size; q_head++) {
500403 cutlass::DeviceAllocation<ElementAccumulator> block_S;
@@ -646,11 +549,6 @@ bool verify(ProblemShapeType problem_size, Options options, const float* q_scale
646549 compat::wait ();
647550 compat::memcpy<ElementOutput>(block_ref_O.get (), host_O.data (), host_O.size ());
648551
649- std::cout << " \n ========= Kernel Outputs =========\n " ;
650- print_device_tensor (" Actual Kernel Output (block_O)" , block_O.get (), block_O.size ());
651- print_device_tensor (" Reference Kernel Output (block_ref_O)" , block_ref_O.get (), block_ref_O.size ());
652- std::cout << " \n ==================================\n " ;
653-
654552 // Check if output from CUTLASS kernel and reference kernel are equal or not
655553 bool passed = cutlass::reference::device::BlockCompareRelativelyEqual (block_ref_O.get (), block_O.get (),
656554 block_O.size (), ElementOutput{0.5 }, ElementOutput{0.5 });
@@ -806,18 +704,6 @@ bool verify(ProblemShapeType problem_size, Options options, const float* q_scale
806704 block_V_cache.reset (num_pages * paged_kv_cache.page_size * num_heads_kv * head_size_vo);
807705 }
808706
809- /* initialize_block_iota(block_Q);
810- initialize_block_iota(block_K);
811- initialize_block_iota(block_V); //, seq_len_kv, head_size_vo);
812- initialize_block_iota(block_K_cache);
813- initialize_block_iota(block_V_cache); //, seq_len_kv_cache, head_size_vo);*/
814- //
815- /* initialize_block_random(block_Q);
816- initialize_block_random(block_K);
817- initialize_block_random(block_V);
818- initialize_block_random(block_K_cache);
819- initialize_block_random(block_V_cache);*/
820-
821707 initialize_block (block_Q, seed + 2023 );
822708 initialize_block (block_K, seed + 2022 );
823709 initialize_block (block_V, seed + 2021 );
0 commit comments