23 changes: 17 additions & 6 deletions tests/pytorch/attention/test_attention_with_cp.py
@@ -73,7 +73,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
 qkv_formats = ["bshd", "sbhd", "thd"]
 cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
 if test_essential:
-    configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"]
+    configs = ["cp_1_0", "cp_2_0", "cp_2_1", "cp_3_2", "cp_3_3"]
     model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs}
     dtypes = ["bf16"]
     qkv_formats = ["sbhd", "thd"]
@@ -183,7 +183,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
 qkv_formats = ["bshd", "sbhd", "thd"]
 cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
 if test_essential:
-    configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+    configs = ["cp_1_0", "cp_1_1", "cp_1_4", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
     model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
     dtypes = ["bf16", "fp8"]
     qkv_formats = ["sbhd", "thd"]
@@ -224,10 +224,14 @@ def test_cp_with_fused_attention(

     if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
         pytest.skip("THD format does not support post_scale_bias yet!")
-    if qkv_format == "thd" and cp_comm_type == "all_gather":
-        pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
-    if qkv_format == "thd" and "a2a" in cp_comm_type:
-        pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
+    if qkv_format == "thd":
+        if cp_comm_type == "all_gather":
+            pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
+        if cp_comm_type == "a2a+p2p":
+            pytest.skip(
+                "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
+                " yet!"
+            )
Comment on lines +229 to +234

logic: The skip at line 229 contradicts the PR's stated goal of enabling THD+CP with A2A. This should be removed or updated.

Collaborator Author

It doesn't, since A2A+P2P is different from A2A only.

You're right – A2A and A2A+P2P are different communication patterns. The skip at line 229 is specifically for all_gather, not for A2A, so there's no contradiction with the PR's goal of enabling THD+CP with A2A.

Looking at the code:

  • Line 229: Skips when cp_comm_type == "all_gather"
  • Lines 230-234: Skips when cp_comm_type == "a2a+p2p"

These are two separate conditions. The PR enables THD with A2A (plain "a2a"), which is not blocked by either of these skips. So the logic is actually correct and consistent with the PR's objective.
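
To make this concrete, here is a minimal standalone sketch (not part of the PR; the helper name `thd_is_skipped` is hypothetical) of how the two skip conditions resolve for `qkv_format == "thd"` across the communication types:

```python
# Illustrative only -- mirrors the two THD skip conditions quoted above, not code from the PR.
def thd_is_skipped(cp_comm_type: str) -> bool:
    """Return True if a THD test case would be skipped for this cp_comm_type."""
    if cp_comm_type == "all_gather":
        return True  # KV all-gather does not support THD yet
    if cp_comm_type == "a2a+p2p":
        return True  # hierarchical A2A+P2P does not support THD yet
    return False  # "p2p" and plain "a2a" fall through and actually run


assert thd_is_skipped("all_gather")
assert thd_is_skipped("a2a+p2p")
assert not thd_is_skipped("a2a")  # THD + A2A is exercised, matching the PR's goal
assert not thd_is_skipped("p2p")
```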

My apologies for the confusion in my original comment!

if dtype == "fp8" and cp_comm_type == "all_gather":
pytest.skip(
"CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
@@ -281,6 +285,13 @@ def test_cp_with_fused_attention(
     )

     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
+
+    if qkv_format == "thd":
+        if "causal" in config.attn_mask_type:
+            config.attn_mask_type = "padding_causal"
+        else:
+            config.attn_mask_type = "padding"
+
     fp8_meta = {}
     fp8_meta["recipe"] = None
     fp8_meta["local_recipes"] = []