Skip to content

Commit

Permalink
Revert "WIP: thd_bshd_bshd"
Browse files Browse the repository at this point in the history
This reverts commit 1c31b68.

Signed-off-by: Charlene Yang <[email protected]>
  • Loading branch information
cyanguwa committed Feb 21, 2025
1 parent 1c31b68 commit f5b91c6
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 521 deletions.
13 changes: 5 additions & 8 deletions tests/pytorch/fused_attn/test_paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ def step(self, dynamic_fill: bool = True):

@pytest.mark.parametrize("dtype", [torch.float16])#param_types)
@pytest.mark.parametrize("model", model_configs_infer.keys())
@pytest.mark.parametrize("qkv_format", ["thd"])#qkv_formats)
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("is_paged", [False, True])
@pytest.mark.parametrize("backend", ["FusedAttention"])#, "FlashAttention", "UnfusedAttention"])
@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"])
@pytest.mark.parametrize("is_cuda_graph", [False, True])
def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph):
reset_rng_states()
Expand All @@ -211,7 +211,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph):
# figure out supported backends
inference_params_qkv_format = "bshd"
if is_paged:
qkv_layout = "paged_kv_" + "_".join([inference_params_qkv_format] * 3)
qkv_layout = "paged_kv_" + inference_params_qkv_format + "_2" + inference_params_qkv_format
else:
qkv_layout = "_".join([inference_params_qkv_format] * 3)
available_backends, fused_attn_backends = _get_attention_backends(
Expand Down Expand Up @@ -356,7 +356,7 @@ def gen_data():
dtype=torch.int32,
)
sample_kwargs["inference_params"] = inference_params
sample_kwargs["attn_mask_type"] = "padding" #_causal"
sample_kwargs["attn_mask_type"] = "padding_causal"
sample_kwargs["max_seqlen_q"] = config.max_ctx_len
sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv
sample_kwargs["qkv_format"] = qkv_format
Expand Down Expand Up @@ -485,7 +485,7 @@ def gen_data():
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
inference_params=inference_params,
attn_mask_type="padding", #_causal",
attn_mask_type="padding_causal",
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
qkv_format=qkv_format,
Expand Down Expand Up @@ -525,9 +525,6 @@ def gen_data():
rtol=tols[dtype],
)
if qkv_format == "thd":
print('i ', i, seq, cu_seqlens_q)
print(full_output[seq, sim.t_total_lens[i] - 1, :4])
print(line_output[cu_seqlens_q[i + 1] - 1, :4])
torch.testing.assert_close(
#full_output[seq, sim.t_total_lens[i] - sim.step_lens[i]:sim.t_total_lens[i] - 1, :],
#line_output[cu_seqlens_q[i]:cu_seqlens_q[i + 1] - 1, :],
Expand Down
174 changes: 40 additions & 134 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,15 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD;
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2BSHD:
return NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD;
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2SBHD:
return NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD;
default:
NVTE_ERROR("qkv_layout not supported!");
}
Expand All @@ -60,68 +59,24 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_2BSHD:
return NVTE_QKV_Format::NVTE_SBHD;
case NVTE_QKV_Layout::NVTE_BS3HD:
case NVTE_QKV_Layout::NVTE_BSH3D:
case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_2SBHD:
return NVTE_QKV_Format::NVTE_BSHD;
case NVTE_QKV_Layout::NVTE_T3HD:
case NVTE_QKV_Layout::NVTE_TH3D:
case NVTE_QKV_Layout::NVTE_THD_T2HD:
case NVTE_QKV_Layout::NVTE_THD_TH2D:
case NVTE_QKV_Layout::NVTE_THD_THD_THD:
return NVTE_QKV_Format::NVTE_THD;
case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD:
return NVTE_QKV_Format::NVTE_SBHD_2BSHD;
case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD:
return NVTE_QKV_Format::NVTE_BSHD_2SBHD;
case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD:
return NVTE_QKV_Format::NVTE_THD_2BSHD;
case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
return NVTE_QKV_Format::NVTE_THD_2SBHD;
default:
NVTE_ERROR("qkv_layout not supported!");
}
}

// Map an NVTE_QKV_Layout to the NVTE_QKV_Format used by the Q tensor.
//
// The layout is first reduced to its overall format via nvte_get_qkv_format().
// Combined formats named <X>_2<Y> (e.g. NVTE_SBHD_2BSHD) resolve to <X> here,
// i.e. the part of the name before "_2"; plain formats map to themselves.
// Any other format value is reported through NVTE_ERROR.
NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) {
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
switch (qkv_format) {
// Q is SBHD
case NVTE_QKV_Format::NVTE_SBHD:
case NVTE_QKV_Format::NVTE_SBHD_2BSHD:
return NVTE_QKV_Format::NVTE_SBHD;
// Q is BSHD
case NVTE_QKV_Format::NVTE_BSHD:
case NVTE_QKV_Format::NVTE_BSHD_2SBHD:
return NVTE_QKV_Format::NVTE_BSHD;
// Q is THD
case NVTE_QKV_Format::NVTE_THD:
case NVTE_QKV_Format::NVTE_THD_2BSHD:
case NVTE_QKV_Format::NVTE_THD_2SBHD:
return NVTE_QKV_Format::NVTE_THD;
default:
// No return follows on this path; presumably NVTE_ERROR does not return
// normally (e.g. throws) -- TODO confirm against its definition.
NVTE_ERROR("qkv_layout not supported!");
}
}

// map NVTE_QKV_Layout to NVTE_QKV_Format for KV
NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
switch (qkv_format) {
case NVTE_QKV_Format::NVTE_SBHD:
case NVTE_QKV_Format::NVTE_BSHD_2SBHD:
case NVTE_QKV_Format::NVTE_THD_2SBHD:
return NVTE_QKV_Format::NVTE_SBHD;
case NVTE_QKV_Format::NVTE_BSHD:
case NVTE_QKV_Format::NVTE_SBHD_2BSHD:
case NVTE_QKV_Format::NVTE_THD_2BSHD:
return NVTE_QKV_Format::NVTE_BSHD;
case NVTE_QKV_Format::NVTE_THD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2SBHD:
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_2BSHD:
return NVTE_QKV_Format::NVTE_THD;
default:
NVTE_ERROR("qkv_layout not supported!");
Expand All @@ -140,8 +95,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
const int sm_arch_ = cuda::sm_arch(device_id);
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
auto cudnn_runtime_version = cudnnGetVersion();

Expand Down Expand Up @@ -265,11 +218,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
// 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right}
(cudnn_runtime_version >= 90500 &&
layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD &&
(layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD ||
layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD) &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) &&
//max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv &&
max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv &&
bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) ||
// 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv)
(cudnn_runtime_version >= 90600 &&
Expand All @@ -285,10 +239,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD ||
(qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 &&
((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) ||
cudnn_runtime_version >= 90600)) ||
((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) ||
kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) &&
cudnn_runtime_version >= 90700)) &&
cudnn_runtime_version >= 90600))) &&
// sliding window
// pre-9.2: full attn, causal
((cudnn_runtime_version < 90200 && window_size_left == -1 &&
Expand Down Expand Up @@ -324,7 +275,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(supported_ragged_offset_size)) {
flag_arb = true;
}
flag_arb = true;
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
Expand Down Expand Up @@ -542,7 +492,7 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
Expand All @@ -554,8 +504,6 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
const Tensor *input_cu_seqlens_kv = reinterpret_cast<const Tensor *>(cu_seqlens_kv);
const Tensor *input_cu_seqlens_q_padded = reinterpret_cast<const Tensor *>(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast<const Tensor *>(cu_seqlens_kv_padded);
const Tensor *input_page_table_k = reinterpret_cast<const Tensor *>(page_table_k);
const Tensor *input_page_table_v = reinterpret_cast<const Tensor *>(page_table_v);
const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
const Tensor *input_Q = reinterpret_cast<const Tensor *>(Q);
const Tensor *input_KV = reinterpret_cast<const Tensor *>(KV);
Expand All @@ -580,40 +528,11 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
}
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_KV->data.shape[0];
}
int64_t num_pages_k = 0;
int64_t num_pages_v = 0;
int64_t page_size_k = 0;
int64_t page_size_v = 0;
int64_t max_pages_per_seq_k = 0;
int64_t max_pages_per_seq_v = 0;
if (input_page_table_k->data.dptr != nullptr) {
max_pages_per_seq_k = input_page_table_k->data.shape[1];
}
if (input_page_table_v->data.dptr != nullptr) {
max_pages_per_seq_v = input_page_table_v->data.shape[1];
}
if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) {
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (kv_format == NVTE_QKV_Format::NVTE_BSHD) {
num_pages_k = input_KV->data.shape[0];
page_size_k = input_KV->data.shape[1];
num_pages_v = num_pages_v;
page_size_v = page_size_v;
} else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) {
num_pages_k = input_KV->data.shape[1];
page_size_k = input_KV->data.shape[0];
num_pages_v = num_pages_v;
page_size_v = page_size_v;
}
}

auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
Expand All @@ -635,11 +554,10 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q,
input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
handle);
#else
NVTE_ERROR(
Expand Down Expand Up @@ -701,12 +619,9 @@ void nvte_fused_attn_bwd_kvpacked(
}
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_KV->data.shape[0];
}

Expand Down Expand Up @@ -805,12 +720,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t d_v = input_V->data.shape[ndim_kv - 1];
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_K->data.shape[0];
}
int64_t num_pages_k = 0;
Expand All @@ -826,19 +738,16 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
max_pages_per_seq_v = input_page_table_v->data.shape[1];
}
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) {
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (kv_format == NVTE_QKV_Format::NVTE_BSHD) {
num_pages_k = input_K->data.shape[0];
page_size_k = input_K->data.shape[1];
num_pages_v = input_V->data.shape[0];
page_size_v = input_V->data.shape[1];
} else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) {
num_pages_k = input_K->data.shape[1];
page_size_k = input_K->data.shape[0];
num_pages_v = input_V->data.shape[1];
page_size_v = input_V->data.shape[0];
}
if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2BSHD) {
num_pages_k = input_K->data.shape[0];
page_size_k = input_K->data.shape[1];
num_pages_v = input_V->data.shape[0];
page_size_v = input_V->data.shape[1];
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_2SBHD) {
num_pages_k = input_K->data.shape[1];
page_size_k = input_K->data.shape[0];
num_pages_v = input_V->data.shape[1];
page_size_v = input_V->data.shape[0];
}

auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
Expand Down Expand Up @@ -924,12 +833,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t d_v = input_V->data.shape[ndim_kv - 1];
size_t t_q = 0;
size_t t_kv = 0;
NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout);
if (q_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
t_q = input_Q->data.shape[0];
}
if (kv_format == NVTE_QKV_Format::NVTE_THD) {
t_kv = input_K->data.shape[0];
}

Expand Down
Loading

0 comments on commit f5b91c6

Please sign in to comment.