Skip to content

Commit

Permalink
Fix FA3 Varlen Performance regression (Dao-AILab#1361)
Browse files Browse the repository at this point in the history
  • Loading branch information
kadeng authored Dec 2, 2024
1 parent ca71144 commit 0823cf7
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 117 deletions.
91 changes: 49 additions & 42 deletions hopper/copy_paged_sm90_tma_cutlass35.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ struct PagedCopyArgs {
};

CUTE_HOST_DEVICE
PagedCopyArgs(int64_t const block_table_batch_stride_, int const page_block_size_, int32_t *block_table_) : block_table_batch_stride{block_table_batch_stride_}, page_block_size(page_block_size_), block_table(block_table_) {
PagedCopyArgs(int64_t const block_table_batch_stride_, int const page_block_size_, const int32_t *const block_table_) : block_table_batch_stride{block_table_batch_stride_}, page_block_size(page_block_size_), block_table(block_table_) {
};

int64_t block_table_batch_stride; // The stride between block tables for different batches
int page_block_size; // The size of a page block in number of elements
int32_t* block_table; // The block table, must be properly sized or a nullptr
const int64_t block_table_batch_stride; // The stride between block tables for different batches
const int page_block_size; // The size of a page block in number of elements
const int32_t *const block_table; // The block table, must be properly sized or a nullptr
};

namespace cute {
Expand All @@ -38,26 +38,27 @@ namespace cute {
}
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr,
PagedCopyArgs const& pca,
PagedCopyArgs const* pca,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1)
{
CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 2D");
}
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr,
PagedCopyArgs const& pca,
copy(void const* desc_ptr, uint64_t* mbar_ptr,
PagedCopyArgs const* pca,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
{
if (pca.block_table == nullptr) {
return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1, crd2);
}
CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 3D");
// WARNING: Do not place anything else here, or a performance regression will occur.
// Look out for ptxas build warnings like "Potential Performance Loss: wgmma.mma_async instructions are serialized".
// This overload assumes pca == nullptr; that is deliberately not checked, because even an assert would kill performance.
return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, smem_ptr, crd0, crd1, crd2);
}
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr,
PagedCopyArgs const& pca,

CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr,
PagedCopyArgs const* pca,
void * smem_ptr,
// Index order reordered for TMA from PagedSeqLenTraits::get_kv_gmem_layout()
// via cute::make_tma_copy_atom ( see detail::construct_tma_gbasis )
Expand All @@ -73,21 +74,24 @@ namespace cute {
//auto log = pca.debug_log->nextline();
//log.append_threadinfo();
//log.snprintf("SM_90_TMA_LOAD_PAGED::copy(%d, %d, %d, %d) ", (int)crdM, (int)crdK, (int)crdH, (int)crdB);
if (pca.block_table == nullptr) {
if (pca == nullptr) {
return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, smem_ptr, crdK, crdM, crdH, crdB);
}
int32_t const page_idx_offset = crdM / pca.page_block_size; // page index within the batch entry
int32_t const seq_pos_offset = crdM - page_idx_offset*pca.page_block_size; // == crd1 % page_block_size_ -> sequence position within the page
int32_t const page_idx = pca.block_table[page_idx_offset + crdB*pca.block_table_batch_stride]; // The page index for the given batch and sequence position
auto const page_block_size = pca->page_block_size;
int32_t const page_idx_offset = crdM / page_block_size; // page index within the batch entry
int32_t const seq_pos_offset = crdM - page_idx_offset * page_block_size; // == crd1 % page_block_size_ -> sequence position within the page
int32_t const page_idx = pca->block_table[page_idx_offset + crdB*pca->block_table_batch_stride]; // The page index for the given batch and sequence position
//if (cute::thread0()) {
// printf("SM90_TMA_LOAD_PAGED::copy crdM=%d, crdB=%d, crdK=%d, crdH=%d, page_idx=%d, seq_pos_offset=%d, ptr=%p\n", (int)crdM, (int)crdB, (int) crdK, (int) crdH, (int)page_idx, (int)seq_pos_offset, (void*)desc_ptr);
//}

return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, smem_ptr, crdK, seq_pos_offset, crdH, page_idx);

}


CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr,
copy(void const* desc_ptr, uint64_t* mbar_ptr,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
{
Expand All @@ -107,28 +111,28 @@ struct SM90_TMA_LOAD_MULTICAST_PAGED
}
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
PagedCopyArgs const& pca,
PagedCopyArgs const* pca,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1)
{
CUTE_INVALID_CONTROL_PATH("not implemented");
}
CUTE_HOST_DEVICE static void
CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
PagedCopyArgs const& pca,
PagedCopyArgs const* pca,
void * smem_ptr,
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
{
if (pca.block_table == nullptr) {
return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2);
}
CUTE_INVALID_CONTROL_PATH("PAGED_COPY_OP not implemented for 3D");
// WARNING: Do not place anything else here, or a performance regression will occur.
// Look out for ptxas build warnings like "Potential Performance Loss: wgmma.mma_async instructions are serialized".
// This overload assumes pca == nullptr; that is deliberately not checked, because even an assert would kill performance.
return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crd0, crd1, crd2);
}


CUTE_HOST_DEVICE static void
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
PagedCopyArgs const& pca,
copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask,
PagedCopyArgs const* pca,
void * smem_ptr,
// Index order reordered for TMA from PagedSeqLenTraits::get_kv_gmem_layout()
// via cute::make_tma_copy_atom ( see detail::construct_tma_gbasis )
Expand All @@ -141,17 +145,20 @@ struct SM90_TMA_LOAD_MULTICAST_PAGED
int32_t const& crdH, // head dim
int32_t const& crdB) // batch dim
{
if (pca.block_table == nullptr) {
if (pca == nullptr) {
return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crdK, crdM, crdH, crdB);
}
int32_t const page_idx_offset = crdM / pca.page_block_size; // page index within the batch entry
int32_t const seq_pos_offset = crdM - page_idx_offset*pca.page_block_size; // == crd1 % page_block_size_ -> sequence position within the page
int32_t const page_idx = pca.block_table[page_idx_offset + crdB*pca.block_table_batch_stride]; // The page index for the given batch and sequence position
auto const page_block_size = pca->page_block_size;
int32_t const page_idx_offset = crdM / page_block_size; // page index within the batch entry
int32_t const seq_pos_offset = crdM - page_idx_offset*page_block_size; // == crd1 % page_block_size_ -> sequence position within the page
int32_t const page_idx = pca->block_table[page_idx_offset + crdB*pca->block_table_batch_stride]; // The page index for the given batch and sequence position
//if (cute::thread0()) {
// printf("SM90_TMA_LOAD_MULTICAST_PAGED::copy crdM=%d, crdB=%d, crdK=%d, crdH=%d, page_idx=%d, seq_pos_offset=%d, ptr=%p\n", (int)crdM, (int)crdB, (int) crdK, (int) crdH, (int)page_idx, (int)seq_pos_offset, (void*)desc_ptr);
//}
return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, smem_ptr, crdK, seq_pos_offset, crdH, page_idx);

}

};


Expand Down Expand Up @@ -194,30 +201,30 @@ struct Copy_Traits<SM90_TMA_LOAD_PAGED, NumBitsPerTMA, AuxParams_>
Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {{}, {&tma_desc_, &tma_mbar, PagedCopyArgs{} }};
return {{}, {&tma_desc_, &tma_mbar, nullptr }};
}

// Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {{}, {new_tma_desc, &tma_mbar, PagedCopyArgs{} }};
return {{}, {new_tma_desc, &tma_mbar, nullptr }};
}

CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask, PagedCopyArgs const & paged_copy_args ) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {{}, {&tma_desc_, &tma_mbar, paged_copy_args }};
return {{}, {&tma_desc_, &tma_mbar, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
}

// Construct an executable SM90_TMA_LOAD with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
with(TmaDescriptor const* new_tma_desc, uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask,PagedCopyArgs const &paged_copy_args ) const {
// We accept multicast_mask here to keep the API for both atoms consistent
return {{}, {new_tma_desc, &tma_mbar, paged_copy_args }};
return {{}, {new_tma_desc, &tma_mbar, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
}

// Generate the TMA coord tensor
Expand Down Expand Up @@ -255,7 +262,7 @@ struct Copy_Traits<SM90_TMA_LOAD_PAGED_OP, NumBitsPerTMA>
tuple<
TmaDescriptor const*,
uint64_t*, // smem mbarrier
PagedCopyArgs
PagedCopyArgs const*
> const opargs_;
};

Expand Down Expand Up @@ -295,28 +302,28 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED, NumBitsPerTMA, AuxParams_>
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const {
return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, PagedCopyArgs{} }};
return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, nullptr }};
}

// Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const {
return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, PagedCopyArgs{} }};
return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, nullptr }};
}

// Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask, PagedCopyArgs const& paged_copy_args) const {
return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, paged_copy_args }};
return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
}

// Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm)
CUTE_HOST_DEVICE constexpr
Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
with(TmaDescriptor const* new_tma_desc, uint64_t& tma_load_mbar, uint16_t const& multicast_mask, PagedCopyArgs const& paged_copy_args) const {
return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, paged_copy_args }};
return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, (paged_copy_args.block_table==nullptr) ? nullptr : &paged_copy_args }};
}

// Generate the TMA coord tensor
Expand Down Expand Up @@ -355,7 +362,7 @@ struct Copy_Traits<SM90_TMA_LOAD_MULTICAST_PAGED_OP, NumBitsPerTMA>
TmaDescriptor const*,
uint64_t*, // smem mbarrier
uint16_t, // multicast mask
PagedCopyArgs,
PagedCopyArgs const*
> const opargs_;
};

Expand Down
Loading

0 comments on commit 0823cf7

Please sign in to comment.