
Commit 177ec96

Code cleanup
Signed-off-by: Aditya Chatterjee <[email protected]>
1 parent a43c1bd commit 177ec96


5 files changed: 30 additions, 322 deletions


applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_mma.hpp

Lines changed: 10 additions & 28 deletions
@@ -28,7 +28,6 @@
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  **************************************************************************************************/
-
 #pragma once
 
 #include "cutlass/cutlass.h"
@@ -280,8 +279,7 @@ struct FlashChunkPrefillMma<
   template <class FragQccum, class TensorQ, class TensorK, class FragSrc>
   CUTLASS_DEVICE void mmaQK(FragQccum &accum, TensorQ gQ, TensorK gK,
                             FragSrc const &frag_src, int const &k_tile_count,
-                            Params const &params, bool is_KV_cache,
-                            int const& q_head_coord, int const& kv_head_coord) {
+                            Params const &params, bool is_KV_cache) {
 
     auto &gmem_tiled_copy_k =
         is_KV_cache ? params.gmem_tiled_copy_k_cache : params.gmem_tiled_copy_k;
@@ -317,8 +315,9 @@ struct FlashChunkPrefillMma<
     Tensor tQgQ = thr_copy_Q.retile_S(tCgQ);
     Tensor tKgK = thr_copy_K.retile_S(tCgK);
 
-    float q_scale = params.ptr_q_scale[0]; //q_head_coord];
-    float k_scale = params.ptr_k_scale[0]; //kv_head_coord];
+    // Currently, supporting per-tensor scaling
+    float q_scale = params.ptr_q_scale[0];
+    float k_scale = params.ptr_k_scale[0];
 
     //
     // Mainloop
@@ -327,22 +326,15 @@ struct FlashChunkPrefillMma<
       copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ);
       copy(gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK);
 
-      // FP8 path: Convert FP8 fragments to FP16 IN-PLACE to avoid register spilling.
+      // FP8 path: Convert FP8 fragments to BF16
      if constexpr (is_fp8_v<ElementQ> || is_fp8_v<ElementK>) {
-        // Recast the memory region of the FP8 tensors as FP16 tensors.
-        // This does NOT allocate new registers. It reuses the existing ones.
-        //auto tCrQ_fp16 = cute::recast<half_t>(tCrQ);
-        //auto tCrK_fp16 = cute::recast<half_t>(tCrK);
-
        auto tCrQ_fp16 = make_fragment_like<bfloat16_t>(tCrQ);
        auto tCrK_fp16 = make_fragment_like<bfloat16_t>(tCrK);
 
-        // Perform the conversion, writing the FP16 results directly into the
-        // reused register space.
        if constexpr (is_fp8_v<ElementQ>) {
          convert_and_descale<ElementQ>(tCrQ, tCrQ_fp16, q_scale);
        } else {
-          // If Q is already FP16, just copy it to the correctly-named variable.
+          // If Q is already FP16, copy it.
          copy(tCrQ, tCrQ_fp16);
        }
 
@@ -352,11 +344,10 @@ struct FlashChunkPrefillMma<
          copy(tCrK, tCrK_fp16);
        }
 
-        // Now, gemm is called on the FP16 tensors which occupy the same
-        // register space as the original FP8 tensors did. Register pressure is not increased.
+        // Now, gemm is called on the BF16 tensors
        cute::gemm(tiled_mma, accum, tCrQ_fp16, tCrK_fp16, frag_src);
      } else {
-        // FP16 path (already fast)
+        // BF16 path
        cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src);
      }
 
@@ -404,12 +395,12 @@ struct FlashChunkPrefillMma<
             class FragSrc>
   CUTLASS_DEVICE void mmaPV(FragQccum &accum, FragS const &tSr, TensorV gV,
                             FragSrc const &frag_src, Params const &params,
-                            bool is_KV_cache, int const& kv_head_coord) {
+                            bool is_KV_cache) {
 
     auto &gmem_tiled_copy_v =
         is_KV_cache ? params.gmem_tiled_copy_v_cache : params.gmem_tiled_copy_v;
 
-    float v_scale = params.ptr_v_scale[0]; //kv_head_coord];
+    float v_scale = params.ptr_v_scale[0];
 
     int thread_idx = static_cast<int>(ThreadIdxX());
     // Instantiate the MMA object
@@ -461,20 +452,11 @@ struct FlashChunkPrefillMma<
       copy(gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV);
 
      if constexpr (is_fp8_v<ElementV>) {
-        // Correctly reuse the registers of tCrV for the new FP16 tensor.
-        // This avoids doubling the register pressure.
-        //auto tCrV_fp16 = cute::recast<half_t>(tCrV);
        auto tCrV_fp16 = make_fragment_like<bfloat16_t>(tCrV);
-
-        // Perform the conversion in-place, overwriting the old FP8 data
-        // with the new FP16 data in the same register space.
        convert_and_descale<ElementV>(tCrV, tCrV_fp16, v_scale);
 
-        // The GEMM now operates on an FP16 tensor that is in registers,
-        // preventing a catastrophic performance drop from register spilling.
        cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV_fp16, frag_src(_,_,_,i));
      } else {
-        // Native FP16 path (already fast)
        cute::gemm(tiled_mma, accum(_,_,_,i), tPr, tCrV, frag_src(_,_,_,i));
      }
    }
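
Note on the change above: with per-head scaling dropped, mmaQK and mmaPV read a single per-tensor factor (ptr_q_scale[0], ptr_k_scale[0], ptr_v_scale[0]) and route FP8 operands through make_fragment_like<bfloat16_t> plus convert_and_descale before cute::gemm. As a rough, hypothetical sketch of what that descale step amounts to on plain buffers (not the kernel's CuTe-fragment implementation; it assumes cutlass/numeric_types.h provides the FP8 and BF16 types, as in current CUTLASS):

#include <cstddef>
#include "cutlass/numeric_types.h"   // cutlass::bfloat16_t and the FP8 types

// Hypothetical helper, not convert_and_descale itself: dequantize an FP8 buffer
// into BF16 by multiplying each element with one per-tensor scale factor.
template <class Fp8T>
void descale_fp8_to_bf16(Fp8T const* src, cutlass::bfloat16_t* dst,
                         std::size_t n, float scale) {
  for (std::size_t i = 0; i < n; ++i) {
    float x = static_cast<float>(src[i]) * scale;  // FP8 -> FP32, apply per-tensor scale
    dst[i] = cutlass::bfloat16_t(x);               // FP32 -> BF16 operand for the MMA
  }
}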

applications/flash_attention_v2/kernel/xe_chunk_prefill.hpp

Lines changed: 6 additions & 6 deletions
@@ -454,8 +454,8 @@ class FMHAPrefillChunk {
     // Perform the collective scoped MMA
     CollectiveMainloop collective_mma;
 
-    auto q_group_size = num_heads_q / num_heads_kv;
-    auto kv_head_coord = q_head_coord / q_group_size;
+    // auto q_group_size = num_heads_q / num_heads_kv;
+    // auto kv_head_coord = q_head_coord / q_group_size;
 
     // when causal mask is true. It is not possible to set the scope
     // of the barrier to workgroup level as the number n block is
@@ -483,7 +483,7 @@ class FMHAPrefillChunk {
 
       collective_mma.mmaQK(tSr, gQ, gK_, tSr,
                            ceil_div(head_size_qk, QK_BLK_K), mainloop_params,
-                           is_KV_cache, q_head_coord, kv_head_coord);
+                           is_KV_cache);
 
      if constexpr (LocalMask) {
        // Sliding windows
@@ -577,7 +577,7 @@ class FMHAPrefillChunk {
 
      // 5) Perform GEMM O = S*V
      collective_mma.template mmaPV<VSlicer>(out_reg, tSr, gV_, out_reg,
-                                             mainloop_params, is_KV_cache, kv_head_coord);
+                                             mainloop_params, is_KV_cache);
 
      // ... prefetch next tile ...
      // Prefetch the next Q tile
@@ -624,7 +624,7 @@ class FMHAPrefillChunk {
     // 3) Perform GEMM S = Q*K
     collective_mma.mmaQK(tSr, gQ, gK(_, _, kv_splits_new - 1, _), tSr,
                          ceil_div(head_size_qk, QK_BLK_K), mainloop_params,
-                         false, q_head_coord, kv_head_coord);
+                         false);
 
     // we only need one block ahead, there is enough gap to prefetch it
     // while doing softmax. because the gap between the two MMA is big,
@@ -655,7 +655,7 @@ class FMHAPrefillChunk {
 
     collective_mma.template mmaPV<VSlicer>(out_reg, tSr,
                                            gV(_, _, kv_splits_new - 1),
-                                           out_reg, mainloop_params, false, kv_head_coord);
+                                           out_reg, mainloop_params, false);
     }
 
 
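
Note: the q_group_size / kv_head_coord lines are only commented out, not deleted; they implement the usual grouped-query-attention mapping in which every group of num_heads_q / num_heads_kv query heads shares one KV head, and they would be needed again if per-head scale indexing returns. A minimal sketch of that mapping, reusing the names from the commented-out lines (it assumes num_heads_q is an integer multiple of num_heads_kv):

// Grouped-query head mapping, as in the commented-out lines above.
int kv_head_for_q_head(int q_head_coord, int num_heads_q, int num_heads_kv) {
  int q_group_size = num_heads_q / num_heads_kv;  // query heads per KV head
  return q_head_coord / q_group_size;             // KV head shared by this query head
}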

examples/06_bmg_flash_attention/06_bmg_chunk_prefill_fp8.cpp

Lines changed: 11 additions & 21 deletions
@@ -29,38 +29,28 @@
  *
  **************************************************************************************************/
 /*! \file
-  \brief Flash Attention V2 Prefill for Intel BMG
+  \brief fp8 Chunk Prefill for Intel BMG
 
-  This example constructs and executes a Flash Attention Prefill with KV cache on Intel BMG. The
+  This example constructs and executes a FP8 Flash Attention Chunk Prefill on Intel BMG. The
   definition of the GEMM, options etc for this example are defined in the associated
-  bmg_flash_attn_cachedKV_runner.hpp header file.
+  bmg_flash_chunk_prefill_runner.hpp header file.
 
   See https://arxiv.org/pdf/2307.08691 for details of Flash Attention V2 algorithm
 
-  To run this example:
-  $ ./examples/sycl/06_bmg_flash_attention_cachedKV/06_bmg_prefill_attention_cachedKV --seq_len_qo=512
-  --seq_len_kv=512 --seq_len_kv_cache=512 --head_size_vo=128 --head_size_qk=128
-
-  Causal masking of the first matrix multiplication is supported (`--is_causal`)
-
   To build & run this example (from your build dir):
 
-  $ ninja 06_bmg_prefill_attention_cachedKV
-  $ ./examples/sycl/06_bmg_flash_attention_cachedKV/06_bmg_prefill_attention_cachedKV
+  $ ninja 06_bmg_chunk_prefill_fp8_hdim128
+  $ ./examples/06_bmg_flash_attention/06_bmg_chunk_prefill_fp8_hdim128
 
   Call with `--help` for information about available options
 */
 
 #include "bmg_flash_chunk_prefill_runner.hpp"
 
 int main(int argc, const char **argv) {
-  //
   // Parse options
-  //
 
   Options options;
-  // Override the default data type for this test
-  // options.dtype = "fp8";
   options.parse(argc, argv);
 
   if (options.help) {
@@ -118,12 +108,12 @@ int main(int argc, const char **argv) {
   // =================================================================================================
   // FP8 Type Definitions
   // =================================================================================================
-  using ElementInputQ = cutlass::float_e5m2_t;      // <- data type of elements in input matrix A
-  using ElementInputKV = cutlass::float_e5m2_t;     // <- data type of elements in input matrix B
-  using MMAOperation = XE_8x16x16_F32F16F16F32_TT;  //XE_8x16x16_F32BF16BF16F32_TT;
-  using GmemTiledCopyQ = XE_2D_U8x8x32_LD_N;        // XE_2D_U8x8x32_LD_N;
-  using GmemTiledCopyK = XE_2D_U8x16x16_LD_T;       // _T designates a transposed block load operation
-  using GmemTiledCopyV = XE_2D_U8x32x32_LD_V;
+  using ElementInputQ = cutlass::float_e5m2_t;      // data type of elements in input matrix A
+  using ElementInputKV = cutlass::float_e5m2_t;     // data type of elements in input matrix B
+  using MMAOperation = XE_8x16x16_F32F16F16F32_TT;  //XE_8x16x16_F32BF16BF16F32_TT;
+  using GmemTiledCopyQ = XE_2D_U8x8x32_LD_N;        // XE_2D_U8x8x32_LD_N;
+  using GmemTiledCopyK = XE_2D_U8x16x16_LD_T;       // _T designates a transposed block load operation
+  using GmemTiledCopyV = XE_2D_U8x32x32_LD_V;
 
   constexpr int PipelineStages = 2;
 
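
Note: Q, K and V are fixed here to cutlass::float_e5m2_t (5 exponent bits, 2 mantissa bits), which is why the kernel carries separate q/k/v descale factors. A tiny host-side round trip, shown only to illustrate the quantization error involved (not part of the example; it assumes cutlass/numeric_types.h exposes the FP8 type):

#include <iostream>
#include "cutlass/numeric_types.h"   // cutlass::float_e5m2_t

int main() {
  float x = 0.3f;
  cutlass::float_e5m2_t q(x);                                 // quantize to e5m2
  std::cout << x << " -> " << static_cast<float>(q) << "\n";  // dequantized value is approximate
  return 0;
}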

examples/06_bmg_flash_attention/bmg_flash_chunk_prefill_runner.hpp

Lines changed: 1 addition & 115 deletions
@@ -28,7 +28,6 @@
  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  *
  **************************************************************************************************/
-
 #pragma once
 
 #include "cutlass/epilogue/collective/default_epilogue.hpp"
@@ -220,64 +219,6 @@ template <class FMHAChunkPrefillKernel, bool isVarLen> struct ExampleRunner {
   // Methods
   //
 
-  /*
-  template <typename T>
-  void initialize_block_random(cutlass::DeviceAllocation<T>& block) {
-    if (block.size() == 0) {
-      return;
-    }
-    std::vector<T> host_tensor(block.size());
-    std::mt19937 gen(seed);
-    std::uniform_real_distribution<float> dis(-1.f, 1.f);
-
-    for (size_t i = 0; i < host_tensor.size(); ++i) {
-      host_tensor[i] = static_cast<T>(dis(gen));
-    }
-    block.copy_from_host(host_tensor.data(), host_tensor.size());
-  }
-  */
-
-  template <typename T>
-  void initialize_block_random(cutlass::DeviceAllocation<T>& block) {
-    if (block.size() == 0) {
-      return;
-    }
-    std::vector<T> host_tensor(block.size());
-    std::mt19937 gen(seed);
-    std::uniform_int_distribution<> dis(1, 9);
-
-    for (size_t i = 0; i < host_tensor.size(); ++i) {
-      host_tensor[i] = static_cast<T>(dis(gen));
-    }
-    block.copy_from_host(host_tensor.data(), host_tensor.size());
-  }
-
-  template <typename T>
-  void initialize_block_identity(cutlass::DeviceAllocation<T>& block, int rows, int cols) {
-    if (block.size() == 0) {
-      return;
-    }
-    std::vector<T> host_tensor(block.size(), T(0.f));
-    for (int i = 0; i < rows; ++i) {
-      if (i < cols) {
-        host_tensor[i * cols + i] = T(1.f);
-      }
-    }
-    block.copy_from_host(host_tensor.data(), host_tensor.size());
-  }
-
-  template <typename T>
-  void initialize_block_iota(cutlass::DeviceAllocation<T>& block) {
-    if (block.size() == 0) {
-      return;
-    }
-    std::vector<T> host_tensor(block.size());
-    for (size_t i = 0; i < host_tensor.size(); ++i) {
-      host_tensor[i] = static_cast<T>(static_cast<float>(1.0));
-    }
-    block.copy_from_host(host_tensor.data(), host_tensor.size());
-  }
-
   template <typename SrcType, typename DstType, typename Encoding>
   void run_conversion_kernel(SrcType* src_ptr_in, DstType* dst_ptr_in, int64_t num_elements, float scale) {
     sycl::queue queue = compat::get_default_queue();
@@ -300,28 +241,6 @@ void run_conversion_kernel(SrcType* src_ptr_in, DstType* dst_ptr_in, int64_t num
     });
   }
 
-  template<typename T>
-  void print_device_tensor(const char* name, T* ptr, size_t size, int max_elements_to_print = 1153) {
-    std::cout << "--- " << name << " ---" << std::endl;
-    if (ptr == nullptr || size == 0) {
-      std::cout << "(null)" << std::endl;
-      return;
-    }
-    std::vector<T> host_tensor(size);
-    compat::memcpy(host_tensor.data(), ptr, size * sizeof(T));
-    compat::wait();
-
-    int count = 0;
-    for (const auto& val : host_tensor) {
-      if (count++ >= max_elements_to_print) {
-        std::cout << "..." << std::endl;
-        break;
-      }
-      std::cout << static_cast<float>(val) << " ";
-    }
-    std::cout << std::endl << "--- End " << name << " ---" << std::endl;
-  }
-
   bool verify(ProblemShapeType problem_size, Options options, const float* q_scale, const float* k_scale, const float* v_scale) {
     std::vector<ElementOutput> host_O(block_ref_O.size());
 
@@ -351,7 +270,7 @@ bool verify(ProblemShapeType problem_size, Options options, const float* q_scale
     int offset_o = 0;
 
     using namespace cutlass;
-    using RefElement = bfloat16_t; //half_t;
+    using RefElement = bfloat16_t;
     DeviceAllocation<RefElement> block_Q_ref, block_K_ref, block_V_ref;
 
     // loop over the batch dimension to compute the output
@@ -479,22 +398,6 @@ bool verify(ProblemShapeType problem_size, Options options, const float* q_scale
      }
      compat::wait();
 
-      // Print inputs for the first batch item
-      if (b == 0) {
-        if constexpr (is_fp8_v<ElementQ>) {
-          std::cout << "\n========= FP8 Kernel Inputs (Batch 0) =========\n";
-          print_device_tensor("FP8 Input Q", q_ptr_orig, seq_len_qo * num_heads_q * head_size_qk);
-          print_device_tensor("FP8 Input K", k_ptr_orig, seq_len_kv_total * num_heads_kv * head_size_qk);
-          print_device_tensor("FP8 Input V", v_ptr_orig, seq_len_kv_total * num_heads_kv * head_size_vo);
-          std::cout << "\n========= Reference Kernel Inputs (Batch 0, Descaled) =========\n";
-        } else {
-          std::cout << "\n========= FP16 Kernel and Reference Kernel Inputs (Batch 0) =========\n";
-        }
-        print_device_tensor("Input Q", reinterpret_cast<RefElement*>(q_ptr), seq_len_qo * num_heads_q * head_size_qk);
-        print_device_tensor("Input K", reinterpret_cast<RefElement*>(k_ptr), seq_len_kv_total * num_heads_kv * head_size_qk);
-        print_device_tensor("Input V", reinterpret_cast<RefElement*>(v_ptr), seq_len_kv_total * num_heads_kv * head_size_vo);
-      }
-
      for (int q_group = 0; q_group < num_heads_q / q_group_size; q_group++) {
        for (int q_head = 0; q_head < q_group_size; q_head++) {
          cutlass::DeviceAllocation<ElementAccumulator> block_S;
@@ -646,11 +549,6 @@ bool verify(ProblemShapeType problem_size, Options options, const float* q_scale
     compat::wait();
     compat::memcpy<ElementOutput>(block_ref_O.get(), host_O.data(), host_O.size());
 
-    std::cout << "\n========= Kernel Outputs =========\n";
-    print_device_tensor("Actual Kernel Output (block_O)", block_O.get(), block_O.size());
-    print_device_tensor("Reference Kernel Output (block_ref_O)", block_ref_O.get(), block_ref_O.size());
-    std::cout << "\n==================================\n";
-
     // Check if output from CUTLASS kernel and reference kernel are equal or not
     bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
                                                                           block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});
@@ -806,18 +704,6 @@ bool verify(ProblemShapeType problem_size, Options options, const float* q_scale
       block_V_cache.reset(num_pages * paged_kv_cache.page_size * num_heads_kv * head_size_vo);
     }
 
-    /*initialize_block_iota(block_Q);
-    initialize_block_iota(block_K);
-    initialize_block_iota(block_V); //, seq_len_kv, head_size_vo);
-    initialize_block_iota(block_K_cache);
-    initialize_block_iota(block_V_cache); //, seq_len_kv_cache, head_size_vo);*/
-    //
-    /*initialize_block_random(block_Q);
-    initialize_block_random(block_K);
-    initialize_block_random(block_V);
-    initialize_block_random(block_K_cache);
-    initialize_block_random(block_V_cache);*/
-
     initialize_block(block_Q, seed + 2023);
     initialize_block(block_K, seed + 2022);
     initialize_block(block_V, seed + 2021);
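
Note: the runner keeps run_conversion_kernel, which descales the FP8 inputs into the RefElement = bfloat16_t buffers consumed by the reference attention. A hypothetical sketch of that kind of elementwise pass in plain SYCL (not the runner's actual kernel; it assumes USM device pointers and that waiting on the returned event is acceptable at the call site):

#include <sycl/sycl.hpp>
#include <cstddef>
#include <cstdint>

// Hypothetical helper, not run_conversion_kernel: element-wise dequantization of a
// device buffer, e.g. FP8 -> BF16, applying one per-tensor scale factor.
template <class SrcType, class DstType>
void descale_on_device(sycl::queue& q, SrcType const* src, DstType* dst,
                       int64_t num_elements, float scale) {
  q.parallel_for(sycl::range<1>(static_cast<std::size_t>(num_elements)),
                 [=](sycl::id<1> i) {
    dst[i] = DstType(static_cast<float>(src[i]) * scale);  // per-tensor descale
  }).wait();
}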
