WIP: all qkv_format combinations and merge CM files

Signed-off-by: Charlene Yang <[email protected]>
cyanguwa committed Feb 23, 2025
1 parent 9ec3649 commit 93235dd
Showing 9 changed files with 570 additions and 611 deletions.
5 changes: 3 additions & 2 deletions tests/pytorch/fused_attn/test_paged_attn.py
@@ -200,9 +200,9 @@ def step(self, dynamic_fill: bool = True):

@pytest.mark.parametrize("dtype", [torch.float16]) # param_types)
@pytest.mark.parametrize("model", model_configs_infer.keys())
@pytest.mark.parametrize("qkv_format", ["thd"]) # qkv_formats)
@pytest.mark.parametrize("qkv_format", qkv_formats)
@pytest.mark.parametrize("is_paged", [False, True])
@pytest.mark.parametrize("backend", ["FusedAttention"]) # , "FlashAttention", "UnfusedAttention"])
@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"])
@pytest.mark.parametrize("is_cuda_graph", [False, True])
def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph):
reset_rng_states()
@@ -319,6 +319,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, is_cuda_graph):
head_dim_q=config.head_dim_qk,
max_ctx_len=config.max_ctx_len,
qkv_format=qkv_format,
allow_query_conversion=backend!="FusedAttention",
)
inference_params.allocate_memory(layer_number, qkv_format)

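With the hard-coded "thd" and "FusedAttention" parametrizations replaced by the full qkv_formats and backend lists, the test now sweeps every qkv_format/backend combination, and the new allow_query_conversion flag is set for the non-fused backends (which, per comments removed from attention.py below, may need q converted to the same format as the KV cache). A rough sketch of the resulting matrix, not part of the diff; the contents of qkv_formats are an assumption here, its real definition lives elsewhere in the test module:

    import itertools

    qkv_formats = ["bshd", "sbhd", "thd"]   # assumed definition
    backends = ["FusedAttention", "FlashAttention", "UnfusedAttention"]

    cases = list(itertools.product(qkv_formats, [False, True], backends, [False, True]))
    print(len(cases))   # 3 formats x 2 (is_paged) x 3 backends x 2 (is_cuda_graph) = 36 per (dtype, model)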
50 changes: 10 additions & 40 deletions transformer_engine/pytorch/attention.py
@@ -547,6 +547,11 @@ def get_attention_backend(
"Disabling FlashAttention as KV caching requires flash-attn 2.2+, or 3.0"
" (Hopper only)"
)
if use_fused_attention and pad_between_seqs:
use_fused_attention = False
logger.debug(
"Disabling FusedAttention for pad_between_seqs = True and KV caching"
)
if inference_params.is_paged:
if use_fused_attention and cudnn_version < (9, 5, 0):
logger.debug("Disabling FusedAttention as paged attention requires cuDNN 9.5+")
@@ -5527,7 +5532,6 @@ def get_qkv_layout(
q_format = qkv_format
kv_format = qkv_format
is_same_q_kv_format = True
print("qkv format", qkv_format, is_same_q_kv_format, q_format, kv_format)

def run_iteratively(q, k, v):
# check data pointers
Expand Down Expand Up @@ -5616,7 +5620,6 @@ def run_iteratively(q, k, v):
# when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
# check_ptrs_qk=True or check_ptrs_kv=True
qkv_layout = "_".join(list([qkv_format]) * 3)
print("xxxxx0")
elif (
check_strides_kv
and check_shapes_kv
@@ -5628,7 +5631,6 @@ def run_iteratively(q, k, v):
# when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or
# check_ptrs_qk=True or check_ptrs_kv=True
qkv_layout = q_format + "_" + kv_format + "_" + kv_format
print("xxxxx1")
else:
qkv_layout = "not_supported"

@@ -5932,9 +5934,8 @@ def forward(
if inference_params is not None:
func = flash_attn_with_kvcache
fa_optional_forward_kwargs_kvcache = {}
fa_optional_forward_kwargs_kvcache["cache_seqlens"] = (
cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
)
cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
fa_optional_forward_kwargs_kvcache["cache_seqlens"] = cache_seqlens
fa_optional_forward_kwargs_kvcache["softmax_scale"] = self.softmax_scale
fa_optional_forward_kwargs_kvcache["causal"] = "causal" in attn_mask_type
if inference_params.is_paged:
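The change above only names the intermediate result: cache_seqlens is the per-sequence KV length obtained by differencing the cumulative offsets in cu_seqlens_kv. A quick illustration, not part of the diff:

    import torch

    cu_seqlens_kv = torch.tensor([0, 5, 12, 20], dtype=torch.int32)   # offsets for 3 sequences
    cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    print(cache_seqlens)   # tensor([5, 7, 8], dtype=torch.int32)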
@@ -7506,8 +7507,6 @@ def forward(

# convert causal to causal_bottom_right in inference when KV-caching is in use
# so users can run with the same attn_mask_type for training and inference
# if "padding" not in attn_mask_type:
# attn_mask_type = "padding_" + attn_mask_type
if attn_mask_type in ["causal", "padding_causal"]:
attn_mask_type = attn_mask_type + "_bottom_right"

@@ -7523,19 +7522,7 @@ def forward(
for x in [query_layer, key_layer, value_layer]
]

# reshape the query tensor
# cuDNN paged attention supports bshd_2bshd and sbhd_2bshd, but
# flash-attention and unfused attention will need q/k/v in the
# same qkv_format
# target_qkv_format = inference_params.qkv_format
# query_layer = inference_params.reshape_and_copy_q(
# query_layer, qkv_format, target_qkv_format, self.layer_number
# )

# update KV cache and return the full key/value tensors
# full key/value tensors are in inference_params.qkv_format format
# print('query_layer',query_layer.shape, query_layer.dtype)
# print('query_layer', query_layer[8,0,:4])
(
query_layer,
key_layer,
@@ -7553,17 +7540,8 @@ def forward(
value_layer,
qkv_format,
)
# print('ssss0 ',query_layer.shape, key_layer.shape, value_layer.shape)
# print('cu_seqlens_q',cu_seqlens_q)
# print('cu_seqlens_kv',cu_seqlens_kv)
# print('maxxxxx ',max_seqlen_q, max_seqlen_kv)

# update cu_seqlens tensors
# if inference_params.is_cuda_graph:
# cu_seqlens_q = inference_params.cu_seqlens_q_buffer
# cu_seqlens_kv = inference_params.cu_seqlens_kv_buffer
# max_seqlen_q = inference_params.max_seqlen_q
# max_seqlen_kv = inference_params.max_seqlen_kv
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None

if (
isinstance(query_layer, Float8Tensor)
@@ -7586,8 +7564,6 @@ def forward(
)
# convert qkv layout to its corresponding paged attention layout
if inference_params is not None and inference_params.is_paged:
# qkv_layout = "paged_kv_" + qkv_format + "_2" + qkv_format
# qkv_layout = "paged_kv_thd_2bshd"# + qkv_format + "_2" + qkv_format
qkv_layout = "paged_kv_" + qkv_layout

cp_size = 1
@@ -7632,10 +7608,6 @@ def forward(
max_seqlen_kv,
key_layer.device,
)
# print('max_seqlen_q ', max_seqlen_q)
# print('max_seqlen_kv ', max_seqlen_kv)
# print('cu_seqlens_q ', cu_seqlens_q)
# print('cu_seqlens_kv ', cu_seqlens_kv)

global _alibi_cache
if alibi_slopes is not None:
@@ -7860,9 +7832,6 @@ def forward(
fp8_meta=self.fp8_meta,
quantizers=self.quantizers,
)
# print('ooooooooooo ',output.shape)
# print(output[1,9,:4])
# print(output[1,10,:4])

from .cpu_offload import CPUOffloadEnabled

@@ -8623,6 +8592,7 @@ def forward(

# pylint: disable=fixme
# TODO: consider cases where sequences have different seqlens
# sequence_start = inference_params.get_seqlens_pre_step()
sequence_start = inference_params.seqlens[0]
sequence_end = sequence_start + sequence_length

2 changes: 1 addition & 1 deletion transformer_engine/pytorch/csrc/extensions.h
@@ -71,7 +71,7 @@ at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens,
int h_q, int d_q, int b, int max_ctx_len, int max_seq_len);
int h_q, int d_q, int b, int max_seq_len);
void reshape_o(torch::Tensor output, torch::Tensor output_buffer, torch::Tensor cu_new_lens, int h_o,
int d_o, int b, int max_seq_len, bool is_output_right_aligned);
void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache,
16 changes: 8 additions & 8 deletions transformer_engine/pytorch/csrc/extensions/attention.cu
@@ -1035,36 +1035,36 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t

template <typename scalar_t>
void reshape_q_launcher(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens,
int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) {
int h_q, int d_q, int b, int max_seq_len) {
transformer_engine::fused_attn::reshape_q_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<scalar_t *>(new_q.data_ptr<scalar_t>()),
reinterpret_cast<scalar_t *>(q_buffer.data_ptr<scalar_t>()), cu_new_lens.data_ptr<int>(),
h_q, d_q, b, max_ctx_len, max_seq_len);
h_q, d_q, b, max_seq_len);
}

void reshape_q(torch::Tensor new_q, torch::Tensor q_buffer, torch::Tensor cu_new_lens,
int h_q, int d_q, int b, int max_ctx_len, int max_seq_len) {
int h_q, int d_q, int b, int max_seq_len) {
NVTE_CHECK(new_q.scalar_type() == q_buffer.scalar_type(),
"new_q and q_buffer must be of the same data type.");
if (q_buffer.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b,
max_seq_len);
} else if (q_buffer.scalar_type() == at::ScalarType::BFloat16) {
using dtype = at::BFloat16;
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b,
max_seq_len);
} else if (q_buffer.scalar_type() == at::ScalarType::Float) {
using dtype = float;
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b,
max_seq_len);
} else if (q_buffer.scalar_type() == at::ScalarType::Float8_e4m3fn) {
using dtype = at::Float8_e4m3fn;
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b,
max_seq_len);
} else if (q_buffer.scalar_type() == at::ScalarType::Float8_e5m2) {
using dtype = at::Float8_e5m2;
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b, max_ctx_len,
reshape_q_launcher<dtype>(new_q, q_buffer, cu_new_lens, h_q, d_q, b,
max_seq_len);
} else {
NVTE_ERROR("Unsupported dtype for KV cache.\n");
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/csrc/kv_cache.cuh
@@ -10,8 +10,7 @@ namespace transformer_engine {
namespace fused_attn {
template <typename scalar_t>
__global__ void reshape_q_kernel(scalar_t *new_q, scalar_t *q_buffer, int *cu_new_lens,
int h_q, int d_q, int b,
int max_ctx_len, int max_seq_len) {
int h_q, int d_q, int b, int max_seq_len) {
// new_q: thd; q_buffer: bshd;
// cu_new_lens: [b + 1]
for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) {
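The kernel body is truncated above, but its header comment pins down the data movement: new_q is a packed "thd" tensor, q_buffer is a padded "bshd" buffer, and cu_new_lens holds b + 1 cumulative lengths. A rough PyTorch sketch of the equivalent copy, consistent with max_ctx_len being dropped from the reshape_q signatures elsewhere in this commit (the real kernel parallelizes over blocks and threads, and left-alignment inside q_buffer is an assumption here):

    import torch

    def reshape_q_reference(new_q, q_buffer, cu_new_lens):
        # new_q:       [total_tokens, h_q, d_q]    packed "thd" input
        # q_buffer:    [b, max_seq_len, h_q, d_q]  padded "bshd" output buffer
        # cu_new_lens: [b + 1]                     cumulative new-token counts
        b = cu_new_lens.numel() - 1
        for batch_idx in range(b):
            start = int(cu_new_lens[batch_idx])
            end = int(cu_new_lens[batch_idx + 1])
            q_buffer[batch_idx, : end - start].copy_(new_q[start:end])
        return q_buffer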