Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add paged attention support #1355

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
Expand Down
54 changes: 27 additions & 27 deletions tests/pytorch/fused_attn/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
check_set_window_size,
AttentionParams,
_attention_backends,
InferenceParams,
)
from transformer_engine.pytorch.constants import TE_DType
import transformer_engine.pytorch.cpp_extensions as ext
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(
num_layers: int = 1,
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
total_requests: int = 1,
):
self.batch_size = batch_size
self.num_heads = num_heads
Expand All @@ -107,6 +109,7 @@ def __init__(
self.num_layers = num_layers
self.bias_shape = bias_shape
self.window_size = window_size
self.total_requests = total_requests


@contextmanager
Expand All @@ -129,6 +132,7 @@ def _get_attention_backends(
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
inference_params: Optional[InferenceParams] = None,
) -> Tuple[List, List]:
"""Check if what attention backends support a model configuration"""

Expand Down Expand Up @@ -183,6 +187,7 @@ def test():
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
inference_params=inference_params,
)
_, _, fused_attention_backend, _, available_backends = get_attention_backend(
attention_params
Expand Down Expand Up @@ -2043,21 +2048,18 @@ def forward(
qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :],
fp8_dtype_forward,
FusedAttnBackend["FP8"],
None,
None,
None,
fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv
META_QKV, # d_scale_qkv_offset
fp8_meta["scaling_fwd"].scale_inv, # d_scale_s
META_S, # d_scale_s_offset
fp8_meta["scaling_fwd"].scale, # q_scale_s
META_S, # q_scale_s_offset
fp8_meta["scaling_fwd"].scale, # q_scale_o
META_O, # q_scale_o_offset
fp8_meta["scaling_fwd"].amax_history, # amax_s
META_S, # amax_s_offset
fp8_meta["scaling_fwd"].amax_history, # amax_o
META_O, # amax_o_offset
d_scale_qkv=fp8_meta["scaling_fwd"].scale_inv,
d_scale_qkv_offset=META_QKV,
d_scale_s=fp8_meta["scaling_fwd"].scale_inv,
d_scale_s_offset=META_S,
q_scale_s=fp8_meta["scaling_fwd"].scale,
q_scale_s_offset=META_S,
q_scale_o=fp8_meta["scaling_fwd"].scale,
q_scale_o_offset=META_O,
amax_s=fp8_meta["scaling_fwd"].amax_history,
amax_s_offset=META_S,
amax_o=fp8_meta["scaling_fwd"].amax_history,
amax_o_offset=META_O,
attn_scale=None,
dropout=p_dropout,
fast_zero_fill=fast_zero_fill,
Expand Down Expand Up @@ -2129,18 +2131,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
fp8_dtype_backward,
ctx.aux_ctx_tensors,
FusedAttnBackend["FP8"],
None,
None,
fwd_scale_inverses[META_QKV], # d_scale_qkv,
fwd_scale_inverses[META_S], # d_scale_s,
fwd_scale_inverses[META_O], # d_scale_o,
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do
ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp
fwd_scales[META_S], # q_scale_s
ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp
ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv
ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp
ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv
d_scale_qkv=fwd_scale_inverses[META_QKV],
d_scale_s=fwd_scale_inverses[META_S],
d_scale_o=fwd_scale_inverses[META_O],
d_scale_do=ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO],
d_scale_dp=ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP],
q_scale_s=fwd_scales[META_S],
q_scale_dp=ctx.fp8_meta["scaling_bwd"].scale[META_DP],
q_scale_dqkv=ctx.fp8_meta["scaling_bwd"].scale[META_DQKV],
amax_dp=ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP],
amax_dqkv=ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV],
attn_scale=None,
dropout=ctx.p_dropout,
fast_zero_fill=ctx.fast_zero_fill,
Expand Down
Loading
Loading