GQA #623 (Draft)

31 changes: 27 additions & 4 deletions applications/flash_attention_v2/collective/copy_block_slm.hpp
@@ -32,6 +32,21 @@
#pragma once

namespace cute {
namespace detail {
template <class T>
struct slm_scalar_type {
using type = T;
};

template <>
struct slm_scalar_type<cutlass::bfloat16_t> {
using type = sycl::ext::oneapi::bfloat16;
};

template <class T>
using slm_scalar_type_t = typename slm_scalar_type<T>::type;

} // namespace detail

/* Flat copies */
template <class SrcEngine, class SrcLayout,
@@ -42,8 +57,10 @@ copy_block_r2s(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
static_assert(is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, "Expected rmem->smem copy");
using dtype = typename SrcEngine::value_type;
using slm_dtype = detail::slm_scalar_type_t<dtype>;

auto atom_r2s = Copy_Atom<XE_1D_STSM<float>, float>{}; // TODO: larger block messages
auto atom_r2s = Copy_Atom<XE_1D_STSM<slm_dtype>, slm_dtype>{}; // TODO: larger block messages

auto atom_shape = make_shape(_1{}, size(src));
auto src_v = src.compose(make_layout(atom_shape));
@@ -60,8 +77,10 @@ copy_block_s2r(Tensor<SrcEngine, SrcLayout> const& src,
Tensor<DstEngine, DstLayout> & dst)
{
static_assert(is_smem_v<SrcEngine> && is_rmem_v<DstEngine>, "Expected smem->rmem copy");
using dtype = typename SrcEngine::value_type;
using slm_dtype = detail::slm_scalar_type_t<dtype>;

auto atom_s2r = Copy_Atom<XE_1D_LDSM<float>, float>{};
auto atom_s2r = Copy_Atom<XE_1D_LDSM<slm_dtype>, slm_dtype>{};

auto atom_shape = make_shape(_1{}, size(dst));
auto src_v = src.compose(make_layout(atom_shape, Stride<_0, _16>{}));
@@ -83,8 +102,10 @@ copy_block_r2s(SubgroupTensor<SrcEngine, SrcLayout, SrcCoordLayout> const& src,

static_assert(is_rmem_v<SrcEngine> && is_smem_v<DstEngine>, "Expected rmem->smem copy");
static_assert(sizeof_bits_v<typename SrcEngine::value_type> == 32, "Only 32-bit data supported");
using dtype = typename SrcEngine::value_type;
using slm_dtype = detail::slm_scalar_type_t<dtype>;

auto atom_r2s = Copy_Atom<XE_1D_STSM<float>, float>{}; // TODO: larger block messages
auto atom_r2s = Copy_Atom<XE_1D_STSM<slm_dtype>, slm_dtype>{}; // TODO: larger block messages

auto atom_shape = make_shape(_1{}, size(SrcLayout{}));

@@ -109,8 +130,10 @@ copy_block_s2r(Tensor<SrcEngine, SrcLayout> const& src,

static_assert(is_smem_v<SrcEngine> && is_rmem_v<DstEngine>, "Expected smem->rmem copy");
static_assert(sizeof_bits_v<typename SrcEngine::value_type> == 32, "Only 32-bit data supported");
using dtype = typename SrcEngine::value_type;
using slm_dtype = detail::slm_scalar_type_t<dtype>;

auto atom_s2r = Copy_Atom<XE_1D_LDSM<float>, float>{};
auto atom_s2r = Copy_Atom<XE_1D_LDSM<slm_dtype>, slm_dtype>{};

auto atom_shape = make_shape(_1{}, size(DstLayout{}));

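Note (not part of the patch): the new `detail::slm_scalar_type` trait only remaps the element type used by the SLM copy atoms — `cutlass::bfloat16_t` fragments are staged in shared local memory as `sycl::ext::oneapi::bfloat16`, and every other type (e.g. `float`) passes through unchanged. A minimal compile-time sketch of that mapping, assuming the header above is on the include path:

```cpp
// Illustrative only: checks the type mapping introduced by slm_scalar_type_t.
// The include path is an assumption; adjust to the local build layout.
#include <type_traits>
#include "applications/flash_attention_v2/collective/copy_block_slm.hpp"

static_assert(std::is_same_v<cute::detail::slm_scalar_type_t<float>, float>,
              "non-bf16 element types are staged in SLM unchanged");
static_assert(std::is_same_v<cute::detail::slm_scalar_type_t<cutlass::bfloat16_t>,
                             sycl::ext::oneapi::bfloat16>,
              "bf16 fragments are staged in SLM as the SYCL bfloat16 type");
```

With this in place, `copy_block_r2s`/`copy_block_s2r` pick `XE_1D_STSM`/`XE_1D_LDSM` atoms parameterized on the SLM-side scalar type instead of hard-coding `float`.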
101 changes: 65 additions & 36 deletions applications/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp
@@ -64,12 +64,13 @@ class FMHAFwdEpilogue {

using TensorO = TensorO_;
using TensorO2D = decltype(TensorO_{}(append<rank_v<TensorO_>>(make_coord(_,_),0)));
using TensorO3D = decltype(TensorO_{}(append<rank_v<TensorO_>>(make_coord(_,_,_),0)));
using ElementO = typename TensorO_::value_type;

using FragA = typename CollectiveMainloop::FragA;
using FragARow = typename CollectiveMainloop::FragARow;
using ElementA = typename FragA::value_type;

static constexpr int QGroupSize = 2; // Query heads per KV head (GQA group size); hard-coded for now.
// Split k-reduced tiles between participating subgroups.
// Assumption: the A tile is contiguous.
using ReduceK = decltype(size<3>(typename TiledMMAPV::ThrLayoutVMNK{}));
@@ -135,50 +136,78 @@ class FMHAFwdEpilogue {
CUTLASS_HOST_DEVICE
FMHAFwdEpilogue(Params const&, SharedStorage& shared_) : shared(shared_) {}

template <typename QVCoord>
template <typename FragASLM, typename FragAMaxSLM, typename FragARowSLM, typename QVCoord>
CUTLASS_DEVICE
void
operator()(TensorO2D const& O, // Global O tensor: (q,v)
FragA & tArA, // O accumulator: (q,v)
FragARow & tA_max, // Softmax row-wise max accumulator
FragARow & tA_sum, // Softmax row-wise sum accumulator
QVCoord blk_qv, // WG tile indices: (q,v)
int thr_id) { // Work-item ID
operator()(TensorO3D const& O_3D, // Global O tensor: (q,v,h)
// FragA & tArA, // O accumulator: (q,v)
// FragARow & tA_max, // Softmax row-wise max accumulator
// FragARow & tA_sum, // Softmax row-wise sum accumulator
FragASLM &tArA_slm, // Output accumulator (q,v)
FragAMaxSLM &tA_max_slm, // Softmax row-wise max accumulator
FragARowSLM &tA_sum_slm, // Softmax row-wise sum accumulator
QVCoord blk_qv, // WG tile indices: (q,v)
int blk_head_kv, // kv head index
int thr_id) { // Work-item ID

using namespace cute;
using ElementA = typename FragA::element_type;

// Reduce k-blocks of A and A_sum across WG, if needed.
auto [rA, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id);

/* Some subgroups may not have any work to do; if so, quit early. */
if (!active) return;

/* Complete softmax, dividing out sums. */
int blk_head_q_start = blk_head_kv * QGroupSize;
auto sg = compat::get_nd_item<1>().get_sub_group();
int sg_id = sg.get_group_id()[0]; // subgroup index within the workgroup
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < rA_sum.size(); i++)
for (int Q = 0; Q < QGroupSize; Q++) {
TensorO2D O_2D = O_3D(_,_,blk_head_q_start + Q);
FragA tArA;
FragARow tA_max;
FragARow tA_sum;
// Load from SLM
// copy_block_s2r(tArA_slm(_,sg_id,Q), tArA);// all sg index
// copy_block_s2r(tA_max_slm(_,sg_id,Q), tA_max);// all sg index
// copy_block_s2r(tA_sum_slm(_,sg_id,Q), tA_sum);// all sg index
barrier_arrive(ScopeWorkgroup, SemanticsRelease | SemanticsWGMemory);
for (int i = 0; i < tA_max.size(); i++) {
tA_max(i) = tA_max_slm(i, thr_id % 16, sg_id, Q);
}
for (int i = 0; i < tArA.size(); i++) {
tArA(i) = tArA_slm(i, thr_id, Q);
}
for (int i = 0; i < tA_sum.size(); i++) {
tA_sum(i) = tA_sum_slm(i, thr_id, Q);
}
barrier_wait(ScopeWorkgroup, SemanticsAcquire | SemanticsWGMemory);
// Reduce k-blocks of A and A_sum across WG, if needed.
auto [rA, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id);

/* Some subgroups may not have any work to do; if so, quit early. */
if (!active) return;

/* Complete softmax, dividing out sums. */
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < rA_sum.size(); i++)
rA_sum(i) = ElementA(1) / rA_sum(i);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < rA.size(); i++)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < rA.size(); i++)
rA(i) *= broadcast<0>(rA_sum, rA, i);

/* Tile output */
Tensor cO = make_identity_tensor(O.shape()); // (q,v)
Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v)

/* Prepare slices */
TiledCopyO copy_o{O};
auto thr_copy_o = copy_o.get_slice(thr_id);

auto tOrO = thr_copy_o.partition_sg_fragment_S(gO);
auto tOgO = thr_copy_o.partition_D(gO);

/* Reorder tile and write out */
reorder(rA, tOrO);
copy(copy_o, tOrO, tOgO);

/* Tile output */
Tensor cO = make_identity_tensor(O_2D.shape()); // (q,v)
Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v)

/* Prepare slices */
TiledCopyO copy_o{O_2D};
auto thr_copy_o = copy_o.get_slice(thr_id);

auto tOrO = thr_copy_o.partition_sg_fragment_S(gO);
auto tOgO = thr_copy_o.partition_D(gO);

/* Reorder tile and write out */
reorder(rA, tOrO);
copy(copy_o, tOrO, tOgO);
}
}

// Reduce k-blocks of A and A_sum across WG, if needed.
// Note that each k block has its own scale factor based on A_max,
// so A/A_sum contributions need to be rescaled to match.
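Note (reviewer sketch, not part of the patch): the epilogue now iterates over the `QGroupSize` query heads that share one KV head, reading each head's accumulator, row-wise max, and row-wise sum back from SLM before finishing the softmax and writing that head's slice of the 3-D output tensor. The head indexing it relies on is `q_head = blk_head_kv * QGroupSize + Q`; a self-contained illustration follows (the helper name is hypothetical, and `QGroupSize` is hard-coded to 2 in this draft):

```cpp
// Hypothetical helper showing the grouped-query (GQA) head mapping the new
// epilogue loop uses: one KV head serves QGroupSize consecutive query heads.
constexpr int QGroupSize = 2;  // matches the constant added to FMHAFwdEpilogue

constexpr int q_head_of(int blk_head_kv, int q_in_group) {
  // blk_head_q_start = blk_head_kv * QGroupSize in operator(); the inner loop
  // then visits query heads [blk_head_q_start, blk_head_q_start + QGroupSize).
  return blk_head_kv * QGroupSize + q_in_group;
}

static_assert(q_head_of(0, 0) == 0 && q_head_of(0, 1) == 1,
              "KV head 0 -> query heads 0 and 1");
static_assert(q_head_of(3, 0) == 6 && q_head_of(3, 1) == 7,
              "KV head 3 -> query heads 6 and 7");
```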