Skip to content

Commit

Permalink
WIP: add bshd_2sbhd, sbhd_2bshd
Browse files Browse the repository at this point in the history
Signed-off-by: Charlene Yang <[email protected]>
  • Loading branch information
cyanguwa committed Feb 22, 2025
1 parent 0341de7 commit 6bd61a7
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 36 deletions.
4 changes: 4 additions & 0 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
Expand Down Expand Up @@ -75,8 +77,10 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_THD_TH2D:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
return NVTE_QKV_Format::NVTE_THD;
case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
return NVTE_QKV_Format::NVTE_SBHD_2BSHD;
case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
return NVTE_QKV_Format::NVTE_BSHD_2SBHD;
case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/common/fused_attn/utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
strideA[hidden_transpose_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) ||
(matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) {
Expand All @@ -268,6 +269,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6
strideA[hidden_dim_idx] = 1;
}
break;
case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
Expand Down
22 changes: 12 additions & 10 deletions transformer_engine/common/include/transformer_engine/fused_attn.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,16 @@ enum NVTE_QKV_Layout {
NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */
NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */
NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */
NVTE_THD_BSHD_BSHD = 15, /*!< THD_BSHD_BSHD layout */
NVTE_THD_SBHD_SBHD = 16, /*!< THD_SBHD_SBHD layout */
NVTE_Paged_KV_BSHD_BSHD_BSHD = 17, /*!< Paged_KV_BSHD_BSHD_BSHD layout */
NVTE_Paged_KV_BSHD_SBHD_SBHD = 18, /*!< Paged_KV_BSHD_SBHD_SBHD layout */
NVTE_Paged_KV_SBHD_BSHD_BSHD = 19, /*!< Paged_KV_SBHD_BSHD_BSHD layout */
NVTE_Paged_KV_SBHD_SBHD_SBHD = 20, /*!< Paged_KV_SBHD_SBHD_SBHD layout */
NVTE_Paged_KV_THD_BSHD_BSHD = 21, /*!< Paged_KV_THD_BSHD_BSHD layout */
NVTE_Paged_KV_THD_SBHD_SBHD = 22, /*!< Paged_KV_THD_SBHD_SBHD layout */
NVTE_SBHD_BSHD_BSHD = 15, /*!< SBHD_BSHD_BSHD layout */
NVTE_BSHD_SBHD_SBHD = 16, /*!< BSHD_SBHD_SBHD layout */
NVTE_THD_BSHD_BSHD = 17, /*!< THD_BSHD_BSHD layout */
NVTE_THD_SBHD_SBHD = 18, /*!< THD_SBHD_SBHD layout */
NVTE_Paged_KV_BSHD_BSHD_BSHD = 19, /*!< Paged_KV_BSHD_BSHD_BSHD layout */
NVTE_Paged_KV_BSHD_SBHD_SBHD = 20, /*!< Paged_KV_BSHD_SBHD_SBHD layout */
NVTE_Paged_KV_SBHD_BSHD_BSHD = 21, /*!< Paged_KV_SBHD_BSHD_BSHD layout */
NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */
NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */
NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */
};

/*! \enum NVTE_QKV_Layout_Group
Expand Down Expand Up @@ -81,9 +83,9 @@ enum NVTE_QKV_Format {
NVTE_BSHD = 1,
/*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */
NVTE_THD = 2,
/*! BSHD format for Q and SBHD format for KV, i.e. Paged_KV_BSHD_SBHD_SBHD */
/*! BSHD format for Q and SBHD format for KV, i.e. BSHD_SBHD_SBHD, Paged_KV_BSHD_SBHD_SBHD */
NVTE_BSHD_2SBHD = 3,
/*! SBHD format for Q and BSHD format for KV, i.e. Paged_KV_SBHD_BSHD_BSHD */
/*! SBHD format for Q and BSHD format for KV, i.e. SBHD_BSHD_BSHD, Paged_KV_SBHD_BSHD_BSHD */
NVTE_SBHD_2BSHD = 4,
/*! THD format for Q and BSHD format for KV, i.e. THD_BSHD_BSHD, Paged_KV_THD_BSHD_BSHD */
NVTE_THD_2BSHD = 5,
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/common/util/pybind_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
.value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \
.value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \
.value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \
.value("NVTE_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD) \
.value("NVTE_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD) \
.value("NVTE_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD) \
.value("NVTE_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD) \
.value("NVTE_Paged_KV_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD) \
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
"thd_t2hd",
"thd_th2d",
"thd_thd_thd",
"sbhd_bshd_bshd",
"bshd_sbhd_sbhd",
"thd_bshd_bshd",
"thd_sbhd_sbhd",
"paged_kv_bshd_bshd_bshd",
Expand Down
54 changes: 28 additions & 26 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
"thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD,
"thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D,
"thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD,
"sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_SBHD_BSHD_BSHD,
"bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_BSHD_SBHD_SBHD,
"thd_bshd_bshd": NVTE_QKV_Layout.NVTE_THD_BSHD_BSHD,
"thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_THD_SBHD_SBHD,
"paged_kv_bshd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_BSHD_BSHD,
Expand Down Expand Up @@ -274,32 +276,32 @@ def fused_attn_fwd(

# execute kernel

print(max_seqlen_q,
max_seqlen_kv,
is_training,
attn_scale,
dropout,
fast_zero_fill,
QKVLayout[qkv_layout],
AttnBiasType[attn_bias_type],
AttnMaskType[attn_mask_type],
window_size,
cu_seqlens_q,
cu_seqlens_kv,
q.shape,
k.shape,
v.shape,
fake_dtype,
cu_seqlens_q_padded,
cu_seqlens_kv_padded,
page_table_k,
page_table_v,
s_quantizer,
o_quantizer,
attn_bias,
rng_gen,
rng_elts_per_thread,
)
#print(max_seqlen_q,
# max_seqlen_kv,
# is_training,
# attn_scale,
# dropout,
# fast_zero_fill,
# QKVLayout[qkv_layout],
# AttnBiasType[attn_bias_type],
# AttnMaskType[attn_mask_type],
# window_size,
# cu_seqlens_q,
# cu_seqlens_kv,
# q.shape,
# k.shape,
# v.shape,
# fake_dtype,
# cu_seqlens_q_padded,
# cu_seqlens_kv_padded,
# page_table_k,
# page_table_v,
# s_quantizer,
# o_quantizer,
# attn_bias,
# rng_gen,
# rng_elts_per_thread,
#)
output_tensors = tex.fused_attn_fwd(
max_seqlen_q,
max_seqlen_kv,
Expand Down

0 comments on commit 6bd61a7

Please sign in to comment.